#pragma once

#ifdef USE_VULKAN_API

#include <cstdint>
#include <optional>

#include <ATen/native/vulkan/ops/Common.h>
#include <ATen/native/vulkan/ops/VulkanPackedContext.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {

/*
 * Packed context for the Vulkan layer-norm op.
 *
 * The context keeps its arguments in an unpacked c10::impl::GenericList
 * (`unpacked_`), whose slots are named by ListArgs (weight, bias, eps).
 * NOTE(review): the constructor and pack() are defined out of line; confirm
 * there that they populate `unpacked_` in the ListArgs order.
 */
class LayernormPackedContext final : virtual public VulkanPackedContext,
                                     public torch::jit::CustomClassHolder {
 private:
  // Argument list returned by unpack(); indexed by the ListArgs constants.
  c10::impl::GenericList unpacked_;

 public:
  // weight / bias are optional tensors; eps is a scalar hyper-parameter
  // (presumably the layer-norm variance epsilon -- confirm in the .cpp).
  LayernormPackedContext(
      const std::optional<Tensor>& weight,
      const std::optional<Tensor>& bias,
      double eps);

  /*
   * Assigns a name to each index in the unpacked list.
   */
  struct ListArgs final {
    static constexpr uint32_t kWeight = 0u;
    static constexpr uint32_t kBias = 1u;
    static constexpr uint32_t kEps = 2u;

    // Number of entries in the unpacked list.
    static constexpr uint32_t kNumArgs = 3u;
  };

  // Builds a context from an argument list; presumably expects the layout
  // described by ListArgs. (A top-level `const` on a by-value parameter is
  // not part of the function type, so it is omitted from this declaration.)
  static LayernormPackedContext pack(c10::impl::GenericList);

  // Returns a copy of the stored argument list. Fails a TORCH_CHECK when the
  // list is empty. The exact signature, including the const by-value return
  // type, must match the base-class virtual this overrides, so it is kept.
  const c10::impl::GenericList unpack() const override {
    TORCH_CHECK(!unpacked_.empty(), "unpacked_ does not have any elements!");

    return unpacked_;
  }
};

// Creates a reference-counted LayernormPackedContext from weight, bias and
// eps. Defined out of line.
c10::intrusive_ptr<LayernormPackedContext> create_layernorm_context(
    std::optional<Tensor>&& weight,
    std::optional<Tensor>&& bias,
    double eps);

// Applies layer norm to `input` using the arguments held by `context`.
// `normalized_shape` presumably has the same meaning as in at::layer_norm
// (the trailing dimensions to normalize over) -- confirm in the .cpp.
Tensor run_layernorm_context(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const c10::intrusive_ptr<LayernormPackedContext>& context);

} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */