#pragma once #include #include #ifdef USE_FBGEMM #include #include #include namespace ao::sparse { struct TORCH_API PackedLinearWeight : public LinearPackedParamsBase { PackedLinearWeight(std::unique_ptr> w, std::optional bias, std::vector col_offsets, std::vector w_scale, std::vector w_zp, c10::QScheme q_scheme, const int64_t out_features_block_size /* block sparsity size across output_features */, const int64_t in_features_block_size /* block sparsity size across input_features */) : LinearPackedParamsBase( out_features_block_size, in_features_block_size), w(std::move(w)), bias_(std::move(bias)), col_offsets(std::move(col_offsets)), w_scale(std::move(w_scale)), w_zp(std::move(w_zp)), q_scheme(q_scheme) {} std::unique_ptr> w; std::optional bias_; std::vector col_offsets; std::vector w_scale; std::vector w_zp; c10::QScheme q_scheme; at::Tensor apply( const at::Tensor& input, double output_scale, int64_t output_zero_point) override; at::Tensor apply_relu( const at::Tensor& input, double output_scale, int64_t output_zero_point) override; at::Tensor apply_dynamic(const at::Tensor& input) override { TORCH_INTERNAL_ASSERT( false, "Sparse quantized dynamic linear with fused relu is not yet " "supported on qnnpack backend."); return at::Tensor(); } at::Tensor apply_dynamic_relu(const at::Tensor& input) override { TORCH_INTERNAL_ASSERT( false, "Sparse quantized dynamic linear with fused relu is not yet " "supported on qnnpack backend."); return at::Tensor(); } LinearPackedSerializationType unpack() override; BCSRSerializationType serialize() override; static c10::intrusive_ptr deserialize( const BCSRSerializationType& serialized); std::optional bias() override { return bias_; } static c10::intrusive_ptr prepack( const at::Tensor& weight, const std::optional& bias, const int64_t out_features_block_size, const int64_t in_features_block_size); private: template at::Tensor apply_impl( const at::Tensor& input, double output_scale, int64_t output_zero_point); }; } // namespace ao::sparse #endif // USE_FBGEMM namespace ao::sparse { int register_linear_params(); } // namespace ao::sparse