#pragma once // SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include namespace aiter { namespace torch_itfs { std::vector mha_bwd(const at::Tensor& dout, // [b, sq, hq, d] const at::Tensor& q, // [b, sq, hq, d] const at::Tensor& k, // [b, sk, hk, d] const at::Tensor& v, // [b, sk, hk, d] const at::Tensor& out, // [b, sq, hq, d] const at::Tensor& lse, // [b, hq, sq] float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, bool deterministic, std::optional dq, // [b, sq, hq, d] std::optional dk, // [b, sk, hk, d] std::optional dv, // [b, sk, hk, d] std::optional dbias_, // [sq, sk] std::optional bias_, // [sq, sk] std::optional alibi_slopes, // [hq] or [b, hq] std::optional rng_state, std::optional gen); } // namespace torch_itfs } // namespace aiter