#include #include #include #include "pybind11/pybind11.h" #include "pybind11/cast.h" #include "pybind11/stl.h" #include "cudnn_frontend.h" namespace py = pybind11; using namespace pybind11::literals; namespace cudnn_frontend::python_bindings { // This class is only meant direct pythonic API calls to c++ Graph class. class PyGraph { public: template std::shared_ptr pointwise_ternary(std::shared_ptr& a, std::shared_ptr& b, std::shared_ptr& c, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); template std::shared_ptr pointwise_binary(std::shared_ptr& a, std::shared_ptr& b, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); template std::shared_ptr pointwise_unary(std::shared_ptr& a, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); // This Graph class is the sole structure which implicitly makes PyGraph own all tensors, nodes, and cudnn // descriptors. cudnn_frontend::graph::Graph graph; cudnnHandle_t handle; bool is_handle_owner = false; PyGraph(std::string const&, cudnn_frontend::DataType_t io_data_type, cudnn_frontend::DataType_t intermediate_data_type, cudnn_frontend::DataType_t compute_data_type, std::optional handle_, py::object sm_count, py::object sm_version, std::shared_ptr kernel_cache) { graph.set_compute_data_type(compute_data_type) .set_intermediate_data_type(intermediate_data_type) .set_io_data_type(io_data_type); if (handle_.has_value()) { handle = static_cast((void*)(handle_.value())); } else { detail::create_handle(&handle); is_handle_owner = true; } if (sm_count.is(py::none()) == false) { graph.set_sm_count(sm_count.cast()); } if (sm_version.is(py::none()) == false) { graph.set_sm_version(sm_version.cast()); } if (kernel_cache) { graph.set_kernel_cache(kernel_cache); graph.set_dynamic_shape_enabled(true); } } ~PyGraph() { if (is_handle_owner) { detail::destroy_handle(handle); } } std::shared_ptr tensor(std::vector const& dim, std::vector const& stride, cudnn_frontend::DataType_t const& data_type, bool const& is_virtual, bool const& is_pass_by_value, std::shared_ptr const& ragged_offset, std::string const& name); std::shared_ptr tensor_like(std::shared_ptr const& pyobj, std::string const&); std::shared_ptr tensor_like(py::object const& pyobj); std::vector> batchnorm(std::shared_ptr& x, std::shared_ptr& scale, std::shared_ptr& bias, std::shared_ptr& in_running_mean, std::shared_ptr& in_running_var, std::shared_ptr& epsilon, std::shared_ptr& momentum, std::vector>& peer_stats, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); std::vector> layernorm(cudnn_frontend::NormFwdPhase_t const forward_phase, std::shared_ptr& x, std::shared_ptr& scale, std::shared_ptr& bias, std::shared_ptr& epsilon, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); std::shared_ptr batchnorm_inference(std::shared_ptr& x, std::shared_ptr& mean, std::shared_ptr& inv_variance, std::shared_ptr& scale, std::shared_ptr& bias, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); std::vector> layernorm_backward(std::shared_ptr const& dy, std::shared_ptr const& x, std::shared_ptr const& scale, std::shared_ptr const& mean, std::shared_ptr const& inv_variance, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); std::vector> batchnorm_backward(std::shared_ptr const& dy, std::shared_ptr const& x, std::shared_ptr const& scale, std::shared_ptr const& mean, std::shared_ptr const& inv_variance, std::vector>& peer_stats, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); std::shared_ptr slice(std::shared_ptr& input, std::vector const& slices, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); std::shared_ptr conv_fprop(std::shared_ptr& image, std::shared_ptr& weight, std::vector const& pre_padding, std::vector const& post_padding, std::vector const& stride, std::vector const& dilation, cudnn_frontend::ConvolutionMode_t const& conv_mode, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); std::shared_ptr conv_dgrad(std::shared_ptr& loss, std::shared_ptr& filter, std::vector const& pre_padding, std::vector const& post_padding, std::vector const& stride, std::vector const& dilation, cudnn_frontend::ConvolutionMode_t const& conv_mode, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); std::shared_ptr conv_wgrad(std::shared_ptr& image, std::shared_ptr& loss, std::vector const& pre_padding, std::vector const& post_padding, std::vector const& stride, std::vector const& dilation, cudnn_frontend::ConvolutionMode_t const& conv_mode, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); std::shared_ptr matmul(std::shared_ptr& A, std::shared_ptr& B, cudnn_frontend::DataType_t const& compute_data_type, double const padding, std::string const& name); std::shared_ptr relu(std::shared_ptr& input, std::optional const& negative_slope, std::optional const& lower_clip, std::optional const& upper_clip, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); std::shared_ptr gen_index(std::shared_ptr& input, int64_t const axis, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); std::shared_ptr relu_backward(std::shared_ptr& loss, std::shared_ptr& input, std::optional const& negative_slope, std::optional const& lower_clip, std::optional const& upper_clip, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); std::shared_ptr leaky_relu_backward(std::shared_ptr& loss, std::shared_ptr& input, float const negative_slope, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); std::shared_ptr leaky_relu(std::shared_ptr& input, float const negative_slope, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); std::array, 2UL> genstats(std::shared_ptr& input, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); std::shared_ptr reduction(std::shared_ptr& input, cudnn_frontend::ReductionMode_t const mode, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); std::shared_ptr reshape(std::shared_ptr& input, std::string const& name); std::vector> rmsnorm(cudnn_frontend::NormFwdPhase_t const forward_phase, std::shared_ptr& x, std::shared_ptr& scale, std::shared_ptr& bias, std::shared_ptr& epsilon, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); std::vector> rmsnorm_backward(std::shared_ptr const& dy, std::shared_ptr const& x, std::shared_ptr const& scale, std::shared_ptr const& inv_variance, bool const has_dbias, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); std::vector> instancenorm(cudnn_frontend::NormFwdPhase_t const forward_phase, std::shared_ptr& x, std::shared_ptr& scale, std::shared_ptr& bias, std::shared_ptr& epsilon, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); std::vector> instancenorm_backward(std::shared_ptr const& dy, std::shared_ptr const& x, std::shared_ptr const& scale, std::shared_ptr const& mean, std::shared_ptr const& inv_variance, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); // return [o, stats] std::array, 2> sdpa(std::shared_ptr& q, std::shared_ptr& k, std::shared_ptr& v, bool const is_inference, py::object const& attn_scale, std::shared_ptr& bias, bool const use_alibi_mask, bool const use_padding_mask, std::shared_ptr& seq_len_q, std::shared_ptr& seq_len_kv, bool const use_causal_mask, bool const use_causal_mask_bottom_right, py::object const& sliding_window_length, cudnn_frontend::DiagonalAlignment_t const& diagonal_alignment, py::object const& left_bound, py::object const& right_bound, py::object const& dropout, std::shared_ptr& rng_dump, std::shared_ptr& paged_attention_k_table, std::shared_ptr& paged_attention_v_table, py::object const& paged_attention_max_seq_len_kv, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); // return [dQ, dK, dV] std::array, 3> sdpa_backward(std::shared_ptr& q, std::shared_ptr& k, std::shared_ptr& v, std::shared_ptr& o, std::shared_ptr& dO, std::shared_ptr& stats, py::object const& attn_scale, std::shared_ptr& bias, std::shared_ptr& dBias, bool const use_alibi_mask, bool const use_padding_mask, std::shared_ptr& seq_len_q, std::shared_ptr& seq_len_kv, py::object const& max_total_seq_len_q, py::object const& max_total_seq_len_kv, bool const use_causal_mask, bool const use_causal_mask_bottom_right, py::object const& sliding_window_length, cudnn_frontend::DiagonalAlignment_t const& diagonal_alignment, py::object const& left_bound, py::object const& right_bound, py::object const& dropout, std::shared_ptr& rng_dump, bool const use_deterministic_algorithm, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); // return [o, stats, amax_s, amax_o] std::array, 4> sdpa_fp8(std::shared_ptr& q, std::shared_ptr& k, std::shared_ptr& v, std::shared_ptr& descale_q, std::shared_ptr& descale_k, std::shared_ptr& descale_v, std::shared_ptr& descale_s, std::shared_ptr& scale_s, std::shared_ptr& scale_o, bool const is_inference, py::object const& attn_scale, bool const use_padding_mask, std::shared_ptr& seq_len_q, std::shared_ptr& seq_len_kv, bool const use_causal_mask, bool const use_causal_mask_bottom_right, py::object const& dropout, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); // return [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] std::array, 7> sdpa_fp8_backward(std::shared_ptr& q, std::shared_ptr& k, std::shared_ptr& v, std::shared_ptr& o, std::shared_ptr& dO, std::shared_ptr& stats, std::shared_ptr& descale_q, std::shared_ptr& descale_k, std::shared_ptr& descale_v, std::shared_ptr& descale_o, std::shared_ptr& descale_dO, std::shared_ptr& descale_s, std::shared_ptr& descale_dP, std::shared_ptr& scale_s, std::shared_ptr& scale_dQ, std::shared_ptr& scale_dK, std::shared_ptr& scale_dV, std::shared_ptr& scale_dP, py::object const& attn_scale, bool const use_padding_mask, std::shared_ptr& seq_len_q, std::shared_ptr& seq_len_kv, bool const use_causal_mask, bool const use_causal_mask_bottom_right, py::object const& dropout, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); void validate(); size_t key(); void build_operation_graph(); void create_execution_plans(std::vector const&); void create_execution_plan(int64_t const engine_id, std::unordered_map const& knobs); int64_t get_engine_count(); std::vector get_knobs_for_engine(int64_t const engine_id); void build_plans(BuildPlanPolicy_t const); void build_plan_at_index(int64_t const index); void check_support(); void build(std::vector const&); int64_t get_workspace_size(); void populate_cuda_graph(std::intptr_t handle, std::unordered_map var_pack, std::intptr_t workspace, std::intptr_t cuda_graph); void update_cuda_graph(std::intptr_t handle, std::unordered_map var_pack, std::intptr_t workspace, std::intptr_t cuda_graph); void execute(std::unordered_map var_pack, int64_t workspace, std::optional); void execute_plan_at_index(std::unordered_map var_pack, int64_t workspace, int64_t index, std::optional); std::vector get_behavior_notes(); std::vector get_behavior_notes_for_plan_at_index(int64_t const index); void select_numeric_notes(std::vector const& notes) { graph.select_numeric_notes(notes); return; } void select_behavior_notes(std::vector const& notes) { graph.select_behavior_notes(notes); return; } void deselect_engines(std::vector const& engine_names) { graph.deselect_engines(engine_names); return; } void deselect_numeric_notes(std::vector const& notes) { graph.deselect_numeric_notes(notes); return; } void deselect_behavior_notes(std::vector const& notes) { graph.deselect_behavior_notes(notes); return; } void deselect_workspace_greater_than(int64_t const workspace) { graph.deselect_workspace_greater_than(workspace); return; } std::vector serialize() const; void deserialize(py::object const& pyobj); int64_t get_execution_plan_count() const { return graph.get_execution_plan_count(); } int64_t get_workspace_size_plan_at_index(int64_t index); std::shared_ptr query_tensor_attributes_of_uid(int64_t const uid) const; std::string get_plan_name_at_index(int64_t index); }; } // namespace cudnn_frontend::python_bindings