#pragma once // @generated by torchgen/gen.py from Function.h #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace at { // aten::_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor rng_state, Tensor unused, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor) inline ::std::tuple _flash_attention_backward(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, bool is_causal, const at::Tensor & rng_state, const at::Tensor & unused, ::std::optional scale=::std::nullopt, ::std::optional window_size_left=::std::nullopt, ::std::optional window_size_right=::std::nullopt) { return at::_ops::_flash_attention_backward::call(grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, rng_state, unused, scale, window_size_left.has_value() ? ::std::make_optional(c10::SymInt(*window_size_left)) : ::std::nullopt, window_size_right.has_value() ? ::std::make_optional(c10::SymInt(*window_size_right)) : ::std::nullopt); } namespace symint { template >> ::std::tuple _flash_attention_backward(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, bool is_causal, const at::Tensor & rng_state, const at::Tensor & unused, ::std::optional scale=::std::nullopt, ::std::optional window_size_left=::std::nullopt, ::std::optional window_size_right=::std::nullopt) { return at::_ops::_flash_attention_backward::call(grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, rng_state, unused, scale, window_size_left.has_value() ? ::std::make_optional(c10::SymInt(*window_size_left)) : ::std::nullopt, window_size_right.has_value() ? ::std::make_optional(c10::SymInt(*window_size_right)) : ::std::nullopt); } } // aten::_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor rng_state, Tensor unused, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor) inline ::std::tuple _flash_attention_backward_symint(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const at::Tensor & rng_state, const at::Tensor & unused, ::std::optional scale=::std::nullopt, ::std::optional window_size_left=::std::nullopt, ::std::optional window_size_right=::std::nullopt) { return at::_ops::_flash_attention_backward::call(grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, rng_state, unused, scale, window_size_left, window_size_right); } namespace symint { template >> ::std::tuple _flash_attention_backward(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const at::Tensor & rng_state, const at::Tensor & unused, ::std::optional scale=::std::nullopt, ::std::optional window_size_left=::std::nullopt, ::std::optional window_size_right=::std::nullopt) { return at::_ops::_flash_attention_backward::call(grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, rng_state, unused, scale, window_size_left, window_size_right); } } }