#pragma once #include #include #include #if AT_MKLDNN_ENABLED() namespace at::native::mkldnn { const static std::map fusion_attr_map = { {"none", ideep::attr_t()}, {"relu", ideep::attr_t::fuse_relu()}, }; using SerializationTypeConvPrePack = std::tuple< Tensor, std::optional, std::vector, std::vector, std::vector, int64_t, std::vector, std::string>; class ConvOpContext : public torch::jit::CustomClassHolder { protected: Tensor orig_weight_; std::optional orig_bias_; std::vector stride_; std::vector padding_; std::vector dilation_; int64_t groups_; std::vector input_size_; std::string attr_; public: SerializationTypeConvPrePack unpack() { return std::make_tuple( orig_weight_, orig_bias_, stride_, padding_, dilation_, groups_, input_size_, attr_); } virtual Tensor run(const Tensor& input) = 0; virtual void run(const Tensor& input, void* output) = 0; }; class MkldnnConvOpContext final : public ConvOpContext { private: ContextConv op_context_; public: MkldnnConvOpContext( Tensor&& weight, std::optional&& bias, std::vector&& padding, std::vector&& stride, std::vector&& dilation, uint64_t groups, std::vector&& input_size, ContextConv&& op_context) : op_context_(std::move(op_context)) { orig_weight_ = std::move(weight); orig_bias_ = std::move(bias); padding_ = std::move(padding); stride_ = std::move(stride); dilation_ = std::move(dilation); groups_ = groups; input_size_ = std::move(input_size); } Tensor run(const Tensor& input) override; void run(const Tensor& input, void* output) override; static c10::intrusive_ptr create_context( Tensor&& weight, std::optional&& bias, std::vector&& padding, std::vector&& stride, std::vector&& dilation, int64_t groups, std::vector&& input_size, const ideep::attr_t& attr); }; } // namespace at #endif // AT_MKLDNN_ENABLED()