#pragma once

// NOTE: each preprocessing directive must sit on its own line. When this
// header was collapsed onto a single line after `#pragma once`, every
// following token (the #include and all declarations) was swallowed by the
// pragma and never compiled.
#include "flash.h"

namespace FLASH_NAMESPACE {

// Forward-pass parameters for the sparse attention path.
//
// Extends Flash_fwd_params (from flash.h) with the extra inputs the sparse
// kernels need. The struct only carries non-owning `const int*` pointers and
// a few sizes; allocation and lifetime of the pointed-to buffers are managed
// by whoever fills in this struct.
//
// NOTE(review): the exact layout/shape of the index buffers below is defined
// by the sparse kernels and the host-side API, which are not visible here.
// The per-field descriptions are inferred from the names only; confirm them
// against those sources before relying on them.
struct Flash_fwd_params_sparse : public Flash_fwd_params {
    // For sparse attention
    const int* block_count;   // presumably: number of selected KV blocks per query row-block
    const int* block_offset;  // presumably: offsets of the selected KV blocks
    const int* column_count;  // presumably: number of selected individual columns per query row-block
    const int* column_index;  // presumably: indices of the selected individual columns
    int NUM_ROWS;             // presumably: number of query row-blocks described by the arrays above
    int NNZ_S;                // presumably: capacity (stride) of block_offset per row-block
    int NNZ_V;                // presumably: capacity (stride) of column_index per row-block
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Dispatches the sparse forward attention kernel for element type `T` and
// head dimension `Headdim` on `stream`. Definitions/specializations live in
// the per-configuration translation units.
//
// NOTE(review): the template parameter list was lost in the mangled source
// (only `template void run_mha_fwd_sparse_(...)` survived). It is restored
// here to mirror the `run_mha_fwd_<T, Headdim, Is_causal>` convention
// presumably declared in flash.h; confirm it matches the explicit
// specializations before merging.
template<typename T, int Headdim, bool Is_causal>
void run_mha_fwd_sparse_(Flash_fwd_params_sparse &params, cudaStream_t stream);

}  // namespace FLASH_NAMESPACE