#pragma once #ifdef USE_VULKAN_API #include #include #include #include #include namespace at { namespace native { namespace vulkan { namespace ops { template void stage_pack_weights( api::Context* const context, vTensor& v_weight, const Tensor& weight, const int64_t src_kb_sz, const int64_t src_kh_sz, const int64_t src_kw_sz, const int64_t dst_kh_sz, const int64_t dst_kw_sz) { const int64_t src_matrix_sz = src_kw_sz * src_kh_sz; const int64_t dst_plane_sz = dst_kw_sz * dst_kh_sz; const int64_t dst_matrix_sz = dst_plane_sz * 4; const T* const src_weight_ptr = weight.const_data_ptr(); api::StorageBuffer staging(context, api::kFloat, v_weight.gpu_numel()); { api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE); T* dst_weight_ptr = mapping.template data(); memset(dst_weight_ptr, 0, v_weight.nbytes()); for (const auto src_b : c10::irange(src_kb_sz)) { for (const auto src_h : c10::irange(src_kh_sz)) { for (const auto src_w : c10::irange(src_kw_sz)) { int64_t dst_plane = 2 * (src_h % 2) + (src_w % 2); int64_t dst_index = (src_h / 2) * dst_kw_sz + (src_w / 2); memcpy( dst_weight_ptr + src_b * dst_matrix_sz + dst_plane * dst_plane_sz + dst_index, src_weight_ptr + src_b * src_matrix_sz + src_h * src_kw_sz + src_w, sizeof(T)); } } } } utils::pack_staging_to_vtensor(staging.buffer(), v_weight); } class LinearPackedContext final : virtual public VulkanPackedContext, public torch::jit::CustomClassHolder { private: c10::impl::GenericList unpacked_; public: LinearPackedContext( const Tensor& weight, const std::optional& bias, const bool use_batch = false); /* * Assigns a name to each index in the unpacked list. */ struct Unpacked final { static constexpr uint32_t Weight = 0u; static constexpr uint32_t Bias = 1u; static constexpr uint32_t NumArgs = 2u; }; /* * Assigns a name to each index in the packed list. */ struct Packed final { static constexpr uint32_t Weight = 0u; static constexpr uint32_t Bias = 1u; static constexpr uint32_t WeightSizes = 2u; static constexpr uint32_t BiasDefined = 3u; static constexpr uint32_t NumArgs = 4u; }; static LinearPackedContext pack(c10::impl::GenericList); const c10::impl::GenericList unpack() const override { TORCH_CHECK(!unpacked_.empty(), "unpacked_ does not have any elements!"); return unpacked_; } }; c10::intrusive_ptr create_linear_context( Tensor&& weight, std::optional&& bias); Tensor run_linear_context( const Tensor& input, const c10::intrusive_ptr& context); Tensor run_qlinear_context( const Tensor& input, double output_scale, int64_t output_zero_point, const c10::intrusive_ptr& context); } // namespace ops } // namespace vulkan } // namespace native } // namespace at #endif /* USE_VULKAN_API */