#pragma once #include #include #include #include #include #include #include #include namespace torch::jit::fuser::cuda { // query codegen output arch and target TORCH_CUDA_CU_API void codegenOutputQuery( const cudaDeviceProp* const prop, int& major, int& minor, bool& compile_to_sass); // A class holding metadata for an actual CUDA function. // Note: CUDA functions are per device. struct TORCH_CUDA_CU_API FusedKernelCUDA : public ::torch::jit::fuser::FusedKernel { FusedKernelCUDA( at::DeviceIndex device, std::string name, std::string code, std::vector input_desc, std::vector output_desc, std::vector chunk_desc, std::vector concat_desc, bool has_random); ~FusedKernelCUDA() override; void launch_raw(const uint32_t numel, std::vector& arguments) const override; at::Backend backend() const override { return at::Backend::CUDA; } private: static constexpr auto kBlockSize = 128; // Note: per device to store device properties and compute launch heuristics // Acquiring these values at launch time would be too slow at::DeviceIndex device_; int maxBlocks_{}; cudaDeviceProp* prop_{}; std::vector ptx_; CUmodule module_{}; CUfunction function_{}; }; } // namespace torch::jit::fuser::cuda