#pragma once

#include <memory>
#include <optional>
#include <unordered_map>
#include <vector>

#include <ATen/core/ivalue.h>
#include <c10/core/ScalarType.h>

namespace torch::jit::mobile::nnc {

// Specify the requirements on an input tensor.
// TODO: support input tensor with dynamic shape (PR #54982)
struct TORCH_API InputSpec {
  InputSpec() = default;

  // Deserialize the spec from an IValue.
  explicit InputSpec(const c10::IValue& value);

  // Serialize the spec into an IValue.
  [[nodiscard]] c10::IValue serialize() const;

  // Check whether the input tensor adheres to the spec.
  [[nodiscard]] bool validate(const at::Tensor& input) const;

  std::vector<int64_t> sizes_;
  c10::ScalarType dtype_{c10::ScalarType::Undefined};
};

// Specify the sizes/dtype/... of the output tensor so it can be preallocated.
// TODO: support the case where the kernel allocates output tensors dynamically.
struct TORCH_API OutputSpec {
  OutputSpec() = default;

  // Deserialize the spec from an IValue.
  explicit OutputSpec(const c10::IValue& value);

  // Serialize the spec into an IValue.
  [[nodiscard]] c10::IValue serialize() const;

  // Allocate an output tensor in accordance with the spec.
  [[nodiscard]] at::Tensor allocate() const;

  std::vector<int64_t> sizes_;
  c10::ScalarType dtype_{c10::ScalarType::Undefined};
  std::optional<double> qscale_;
  std::optional<int64_t> qzero_;
};

// Hold the temporary buffers / states needed during the execution.
struct TORCH_API ExecutionState {
  ExecutionState() = default;
  ExecutionState(const ExecutionState&) = delete;
  ExecutionState(ExecutionState&&) = default;
  ExecutionState& operator=(const ExecutionState&) = delete;
  ExecutionState& operator=(ExecutionState&&) = default;

  // Preallocated buffers needed by the NNC kernel.
  std::vector<c10::DataPtr> preallocations_;

  // The NNC kernel expects the following arguments layout:
  //   input tensor 1
  //   ...
  //   input tensor INPUT_NUM
  //   output tensor 1
  //   ...
  //   output tensor OUTPUT_NUM
  //   parameter tensor 1
  //   ...
  //   parameter tensor PARAM_NUM
  //   temporary buffer 1
  //   ...
  //   temporary buffer BUFFER_NUM
  std::vector<void*> arguments_;
};

// Specify how to allocate temporary buffers at initialization.
struct TORCH_API MemoryPlan {
  MemoryPlan() = default;

  explicit MemoryPlan(const c10::IValue& value);

  [[nodiscard]] c10::IValue serialize() const;

  void allocate(ExecutionState* state) const;

  std::vector<int64_t> buffer_sizes_;
};

// Location of a symbolic shape among the dimensions of the inputs.
struct TORCH_API SymbolicShapePosition {
  SymbolicShapePosition() = default;
  SymbolicShapePosition(int64_t input_idx, int64_t dim_idx)
      : input_idx_(input_idx), dim_idx_(dim_idx) {}

  int64_t input_idx_;
  int64_t dim_idx_;
};

// Represents a compiled NNC function which has a 1-1 correspondence with a
// `Method` (e.g. `forward`). It's similar to torch::jit::mobile::Function.
class TORCH_API Function {
 public:
  explicit Function() = default;

  // Deserialize from an IValue that is generated by the 'serialize()' method.
  explicit Function(const c10::IValue& value);

  // Serialize into an IValue.
  c10::IValue serialize() const;

  // Execute the compiled NNC function.
  c10::impl::GenericList run(const c10::impl::GenericList& inputs) const;

  // The name of the function as specified in the model code.
  c10::QualifiedName name() const {
    return name_;
  }

  void set_name(const c10::QualifiedName& name) {
    name_ = name;
  }

  // The unique id of the generated NNC kernel corresponding to the function.
  const std::string& nnc_kernel_id() const {
    return nnc_kernel_id_;
  }

  void set_nnc_kernel_id(const std::string& name) {
    nnc_kernel_id_ = name;
  }

  // The parameters (e.g. weights / bias tensors) to be passed to the generated
  // NNC kernel.
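  // For example, for a hypothetical model with one conv layer followed by one
  // linear layer, this list would roughly hold {conv.weight, conv.bias,
  // linear.weight, linear.bias}, in the order the generated kernel expects its
  // parameter arguments (the layer names here are illustrative only).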
  const c10::impl::GenericList& parameters() const {
    return parameters_;
  }

  void set_parameters(const c10::impl::GenericList& parameters) {
    parameters_ = parameters;
  }

  const std::vector<InputSpec>& input_specs() const {
    return input_specs_;
  }

  void set_input_specs(const std::vector<InputSpec>& input_specs) {
    input_specs_ = input_specs;
  }

  const std::vector<OutputSpec>& output_specs() const {
    return output_specs_;
  }

  void set_output_specs(const std::vector<OutputSpec>& output_specs) {
    output_specs_ = output_specs;
  }

  const MemoryPlan& memory_plan() const {
    return memory_plan_;
  }

  void set_memory_plan(const MemoryPlan& memory_plan) {
    memory_plan_ = memory_plan;
  }

  const std::vector<SymbolicShapePosition>& sym_shape_positions() const {
    return sym_shape_positions_;
  }

  void set_sym_shape_positions(
      const std::vector<SymbolicShapePosition>& sym_shape_pos) {
    sym_shape_positions_ = sym_shape_pos;
  }

 private:
  void init_execution_state() const;

  c10::QualifiedName name_;
  std::string nnc_kernel_id_;
  c10::impl::GenericList parameters_{at::AnyType::get()};
  std::vector<InputSpec> input_specs_;
  std::vector<OutputSpec> output_specs_;
  std::vector<SymbolicShapePosition> sym_shape_positions_;
  MemoryPlan memory_plan_;
  mutable std::unique_ptr<ExecutionState> execution_state_;
};

// CompilationUnit consists of a set of compiled NNC functions. It has a 1-1
// correspondence with a `Module`.
// It's similar to torch::jit::mobile::CompilationUnit.
class TORCH_API CompilationUnit {
 public:
  CompilationUnit() = default;
  CompilationUnit(const CompilationUnit&) = delete;
  CompilationUnit(CompilationUnit&&) = default;
  CompilationUnit& operator=(const CompilationUnit&) = delete;
  CompilationUnit& operator=(CompilationUnit&&) = default;

  // Deserialize from an IValue that is generated by the 'serialize()' method.
  explicit CompilationUnit(const c10::IValue& value);

  // Serialize all registered functions into an IValue. The IValue will be
  // saved into the compiled TorchScript model file ahead-of-time on the host,
  // and will be deserialized at runtime on the target device.
  [[nodiscard]] c10::IValue serialize() const;

  // Execute a registered function.
  [[nodiscard]] c10::impl::GenericList run(
      const c10::QualifiedName& function_name,
      const c10::impl::GenericList& inputs) const;

  // Register a function to the compilation unit.
  void register_function(std::unique_ptr<Function> fn);

 private:
  [[nodiscard]] Function* find_function(const c10::QualifiedName& qn) const;

  std::unordered_map<c10::QualifiedName, std::unique_ptr<Function>> functions_;
};

} // namespace torch::jit::mobile::nnc
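// Usage sketch (illustrative only): how a CompilationUnit might be assembled
// and serialized ahead-of-time, then deserialized and executed on the target
// device. The kernel id string, the input shape, and how the serialized IValue
// is transported are assumptions for the example, not part of this header.
//
//   // Ahead-of-time, on the host:
//   auto fn = std::make_unique<torch::jit::mobile::nnc::Function>();
//   fn->set_name(c10::QualifiedName("forward"));
//   fn->set_nnc_kernel_id("model:forward:v1");  // assumed id format
//   // (input/output specs, parameters and the memory plan are also set here)
//   torch::jit::mobile::nnc::CompilationUnit cu;
//   cu.register_function(std::move(fn));
//   c10::IValue payload = cu.serialize();  // stored with the mobile model
//
//   // At runtime, on the target device:
//   torch::jit::mobile::nnc::CompilationUnit runtime_cu(payload);
//   c10::impl::GenericList inputs(at::AnyType::get());
//   inputs.push_back(at::ones({1, 3, 224, 224}));  // must satisfy InputSpec
//   c10::impl::GenericList outputs =
//       runtime_cu.run(c10::QualifiedName("forward"), inputs);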