#pragma once #include #include namespace at::native::metal { using SerializationTypeConv2dPrePack = std::tuple< Tensor, std::optional, std::vector, std::vector, std::vector, int64_t, std::optional, std::optional>; class Conv2dOpContext : public torch::jit::CustomClassHolder { public: SerializationTypeConv2dPrePack pack() { return std::make_tuple( weight_, bias_, stride_, padding_, dilation_, groups_, output_min_, output_max_); } Conv2dOpContext() = delete; Conv2dOpContext( at::Tensor&& weight, std::optional&& bias, std::vector stride, std::vector padding, std::vector dilation, int64_t groups, std::optional output_min, std::optional output_max) : weight_(std::move(weight)), bias_(std::move(bias)), stride_(std::move(stride)), padding_(std::move(padding)), dilation_(std::move(dilation)), groups_(groups), output_min_(std::move(output_min)), output_max_(std::move(output_max)) {} ~Conv2dOpContext() override { if (releaseCallback_) { releaseCallback_(conv2dOp_); } } void release_resources() override { if (releaseCallback_) { releaseCallback_(conv2dOp_); } } const Tensor& get_weight() const { return weight_; } const std::optional& get_bias() const { return bias_; } const std::vector& get_stride() const { return stride_; } const std::vector& get_padding() const { return padding_; } const std::vector& get_dilation() const { return dilation_; } int64_t get_groups() const { return groups_; } const std::optional& get_output_min() const { return output_min_; } const std::optional& get_output_max() const { return output_max_; } void set_conv2dOpPtr(void* ptr) { conv2dOp_ = ptr; } void* get_conv2dOpPtr() const { return conv2dOp_; } void set_releaseCallback(const std::function& func) { releaseCallback_ = func; } std::function& get_releaseCallback() { return releaseCallback_; } private: Tensor weight_; std::optional bias_; std::vector stride_; std::vector padding_; std::vector dilation_; int64_t groups_; std::optional output_min_; std::optional output_max_; std::function releaseCallback_ = nullptr; void* conv2dOp_ = nullptr; // reserved to hold MPSCNNConv2dOp objects }; using SerializationTypeLinearPrePack = std::tuple< Tensor, std::optional, std::optional, std::optional>; class LinearOpContext : public torch::jit::CustomClassHolder { public: SerializationTypeLinearPrePack pack() { return std::make_tuple(weight_, bias_, output_min_, output_max_); } LinearOpContext() = delete; LinearOpContext( at::Tensor&& weight, std::optional&& bias, std::optional output_min, std::optional output_max) : weight_(std::move(weight)), bias_(std::move(bias)), output_min_(std::move(output_min)), output_max_(std::move(output_max)) {} ~LinearOpContext() override { if (releaseCallback_) { releaseCallback_(opaqueOpPtr_); } } void release_resources() override { if (releaseCallback_) { releaseCallback_(opaqueOpPtr_); } } const Tensor& get_weight() const { return weight_; } const std::optional& get_bias() const { return bias_; } const std::optional& get_output_min() const { return output_min_; } const std::optional& get_output_max() const { return output_max_; } void set_opaqueOpPtr(void* ptr) { opaqueOpPtr_ = ptr; } void* get_opaqueOpPtr() const { return opaqueOpPtr_; } void set_releaseCallback(const std::function& func) { releaseCallback_ = func; } std::function& get_releaseCallback() { return releaseCallback_; } private: Tensor weight_; std::optional bias_; std::optional output_min_; std::optional output_max_; void* opaqueOpPtr_ = nullptr; // reserved to hold MPSCNNFullyConnected objects std::function releaseCallback_ = nullptr; }; } // namespace at::native::metal