o
    )i                     @   s4  d Z ddlZddlmZ ddlmZ ddlmZ ddlm	Z	 ddl
mZ ddlmZmZmZmZmZmZmZmZmZ ddlZdd	lmZ dd
lmZ ddlmZmZmZmZm Z m!Z! ddl"m#Z#m$Z$m%Z%m&Z& ddl'm(Z( ddl)m*Z* ddl+m,Z,m-Z-m.Z. ddl/m0Z0 ddl1m2Z2 ddl3m4Z4 ddl5m6Z6m7Z7m8Z8m9Z9 e4rddl:m;Z; ndZ;z
ddl<m=Z= dZ>W n e?y   dZ>zddl@m=Z= W n e?y   dZ=Y nw Y nw erddlAmBZB e2C ZDG dd deZEedddZFG dd de eeF ZGe	G d d deZHG d!d" d"eeF eeF ZIG d#d$ d$e!eF eeF ZJdS )%a  
# MLA Common Components

This file implements common components for MLA implementations.

First we define:

Sq      as Q sequence length
Skv     as KV sequence length

MLA has two possible ways of computing, a data-movement friendly approach and a
compute friendly approach, we generally want to use the compute friendly
approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1)
and the data-movement friendly approach for "decode" (i.e. the ratio
Sq / Skv is "large").

NOTE what we deem small and large is currently determined by if its labelled
prefill or decode by the scheduler, but this is something we should probably
tune.

Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

Deepseek's MLA attention works the following way:
* Use a single latent vector to represent the per-token entry of the KV cache.
* For decode (i.e. the memory friendly approach) the attention "simulates" a
multi-head attention, while the compute is similar to multi-query attention.

Below is example of both paths assuming batchsize = 1

## More Extent Definitions:

C           Context length, `Skv - Sq`
H           hidden size
N           number of attention heads
Lq          latent dimension for Q              1536 in DSV3
Lkv         latent dimension for K/V            512 in DSV3
P           nope dimension, no rope.            128 in DSV3
R           rope dimension, goes through rope.  64 in DSV3
V           V head dim.                         128 in DSV3

## Vector/Matrix Definitions

h_t         hidden states (input to attention)  shape [Sq, H]
q_c         latent/compressed Q                 shape [Sq, Lq]
q_nope      uncompressed Q (no-rope)            shape [Sq, N, P]
q_pe        uncompressed Q (rope)               shape [Sq, N, R]
kv_c        latent/compressed KV                shape [Skv, Lkv]
k_pe        decoupled k position embeddings     shape [Skv, R]
new_kv_c    new kv_c from current iter          shape [Sq, Lkv]
new_k_pe    new k_pe from current iter          shape [Sq, R]
cache_kv_c  cached k_c from previous iters      shape [C, Lkv]
cache_k_pe  cached k_pe from previous iters     shape [C, R]
W_DQ        project h_t to q_c                  shape [H, Lq]
W_UQ        project q_c to q_nope               shape [Lq, N * P]
W_QR        project q_c to q_pe                 shape [Lq, N * R]
W_DKV       project h_t to kv_c                 shape [H, Lkv]
W_UK        project kv_c to k_nope              shape [Lkv, N, P]
W_KR        project h_t to k_pe                 shape [H, R]
W_UV        project kv_c to v                   shape [Lkv, N, V]
W_O         project v to h_t                    shape [N * V, H]


## Compute Friendly Approach (i.e. "_forward_prefill"):

q_c      = h_t @ W_DQ
q_nope   = (q_c @ W_UQ).view(Sq, N, P)
q_pe     = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c     = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe     = torch.cat([new_k_pe, cache_k_pe], dim=0)
k_nope   = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
v        = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)

// MHA with QK headdim = P + R
//           V headdim = V
//      spda_o shape [Sq, N, V]
spda_o = scaled_dot_product_attention(
    torch.cat([q_nope, q_pe], dim=-1),
    torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
    v
) 
return spda_o @ W_O

NOTE: in the actual code, 
    `kv_b_proj` is [W_UK; W_UV] concatenated per head
    `q_b_proj` is [W_UQ; W_QR] concatenated per head
    `out_proj` is W_O


## Data-Movement Friendly Approach (i.e. "_forward_decode"):

Runtime
q_c      = h_t @ W_DQ
q_nope   = (q_c @ W_UQ).view(-1, N, P)
ql_nope  = einsum("snh,lnh->snl", q, W_UK)
q_pe     = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c     = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe     = torch.cat([new_k_pe, cache_k_pe], dim=0)

// MQA with QK headdim = Lkv + R
//           V headdim = Lkv
//      spda_o shape [Sq, N, Lkv]
// NOTE: this is less compute-friendly since Lkv > P
//       but is more data-movement friendly since its MQA vs MHA
spda_o = scaled_dot_product_attention(
    torch.cat([ql_nope, q_pe], dim=-1),
    torch.cat([kv_c, k_pe], dim=-1),
    kv_c
)

o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
return o.view(-1, N * V) @ self.num_heads @ W_O


## Chunked Prefill

For chunked prefill we want to use the compute friendly algorithm. We are 
assuming sufficiently large Sq / Skv ratio, in the future may want to switch to 
the data-movement friendly approach if the chunk (i.e. `Sq`) is small.

However, the compute-friendly approach can potentially run out of memory if Skv
is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)`

To mitigate this, we chunk the computation of attention with respect to the 
current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a 
fixed workspace size.

The chunked prefill approach is as follows:

MCC        Max chunk of context to process per iter, computed dynamically, 
           used to bound the memory usage

q_c        = h_t @ W_DQ
q_nope     = (q_c @ W_UQ).view(Sq, N, P)
q_pe       = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c   = h_t @ W_DKV
new_k_pe   = RoPE(h_t @ W_KR)
new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
new_v      = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)

// MHA between queries and new KV
//     with QK headdim = P + R
//           V headdim = V
//    curr_o   shape [Sq, N, V]
//    curr_lse shape [N, Sq], this is just order FA returns
curr_o, curr_lse = scaled_dot_product_attention(
    torch.cat([q_nope, q_pe], dim=-1),
    torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
    new_v,
    casual=True,
    return_softmax_lse=True
) 

// Compute attention with the already existing context
for chunk_idx in range(cdiv(C, MCC)):
    chunk_start  = chunk_idx * MCC
    chunk_end    = min(chunk_start + MCC, C)
    Sc           = chunk_end - chunk_start
    cache_kv_c_chunk   = cache_kv_c[chunk_start:chunk_end]
    cache_k_pe_chunk   = cache_k_pe[chunk_start:chunk_end]
    cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
    cache_v_chunk      = (cache_kv_c_chunk @ W_UV).view(-1, N, V)

    chunk_o, chunk_lse = scaled_dot_product_attention(
        torch.cat([q_nope, q_pe], dim=-1),
        torch.cat([cache_k_nope_chunk,
                   cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
                   dim=-1),
        cache_v_chunk,
        casual=False,
        return_softmax_lse=True
    )

    curr_o, curr_lse = merge_attn_states(
        suffix_output=curr_o,
        suffix_lse=curr_lse,
        prefix_output=chunk_o,
        prefix_lse=chunk_lse,
    )

return curr_o @ W_O
    N)abstractmethod)defaultdict)contextmanager)	dataclass)
accumulate)	TYPE_CHECKINGAnyDictGenericListOptionalTupleTypeTypeVar)_custom_ops)envs)AttentionBackendAttentionLayerAttentionMetadataAttentionMetadataBuilderAttentionStateMLAAttentionImpl)PAD_SLOT_IDcompute_slot_mappingcompute_slot_mapping_start_idxis_block_tables_empty)merge_attn_states)get_flash_attn_version)ColumnParallelLinear
LinearBaseUnquantizedLinearMethod)MultiModalPlaceholderMap)current_platform)
HAS_TRITON)async_tensor_h2dcdivmake_tensor_with_pad
round_down)triton_attention)flash_attn_varlen_funcTF)ModelInputForGPUBuilderc                   @   s   e Zd ZedefddZeded fddZeded fdd	Zeded
 fddZ	ede
de
de
de
dee
df f
ddZedejdejdejddfddZedeej dejddfddZedee
 fddZdS ) MLACommonBackendreturnc                   C   s   dS )NZ
TRITON_MLA r-   r-   r-   n/home/app/PaddleOCR-VL-test/.venv_paddleocr/lib/python3.10/site-packages/vllm/attention/backends/mla/common.pyget_name      zMLACommonBackend.get_namer   c                   C      t S N)MLACommonMetadatar-   r-   r-   r.   get_metadata_cls   r0   z!MLACommonBackend.get_metadata_clsMLACommonMetadataBuilderc                   C   r1   r2   )r5   r-   r-   r-   r.   get_builder_cls   r0   z MLACommonBackend.get_builder_clsMLACommonStatec                   C   r1   r2   )r7   r-   r-   r-   r.   get_state_cls   r0   zMLACommonBackend.get_state_cls
num_blocks
block_sizenum_kv_heads	head_size.c                 C   s
   | ||fS r2   r-   )r9   r:   r;   r<   r-   r-   r.   get_kv_cache_shape  s   
z#MLACommonBackend.get_kv_cache_shapesrc_kv_cachedst_kv_cache
src_to_dstNc                 C   s   t | || d S r2   )opsswap_blocks)r>   r?   r@   r-   r-   r.   rB     s   zMLACommonBackend.swap_blocks	kv_cachessrc_to_distsc                 C   s   t | | d S r2   )rA   Zcopy_blocks_mla)rC   rD   r-   r-   r.   copy_blocks  s   zMLACommonBackend.copy_blocksc                   C   s   dgS )Ni@  r-   r-   r-   r-   r.   get_supported_head_sizes  s   z)MLACommonBackend.get_supported_head_sizes)__name__
__module____qualname__staticmethodstrr/   r   r4   r6   r8   intr   r=   torchTensorrB   r   rE   rF   r-   r-   r-   r.   r+      sR    
r+   Tr3   )boundc                   @   sz   e Zd Zdd ZedefddZdefddZ		dded
ede	fddZ
		dd
efddZ		dd
efddZdd ZdS )r7   c                 C   s   || _ d| _|j}|j| _|j}|j| _|j| _| js| jr>ttd| jj	 d|j
 |j d| _| j|j
|j ks<J d S d S )NF      i   )runner_is_graph_capturingscheduler_configmodel_configcache_configchunked_prefill_enabledenable_prefix_cachingminmaxZmax_model_lenZmax_num_seqsr:   context_chunk_workspace_size)selfrS   rU   rW   r-   r-   r.   __init__&  s,   
zMLACommonState.__init__max_batch_sizec                 c   s    d| _ tj|fttj| jjd| _tj|tj	| jjd| _
t| jjj| jjd| _tj|ftj| jjd| _d V  d| _ | `| `
| `| `d S )NTdtypedevice)rb   F)rT   rM   fullr   longrS   rb   _graph_slot_mappingZonesint32_graph_seq_lens
from_numpygraph_block_tablesto_graph_block_tableszerosZ
_positions)r]   r_   r-   r-   r.   graph_captureD  s4   zMLACommonState.graph_capture
batch_sizec                 C   s   | j sJ | | jS r2   )rT   	__class__rS   )r]   rn   r-   r-   r.   graph_clone^  s   
zMLACommonState.graph_cloneFis_encoder_decoder_modelr,   c                 C   s   | j sJ | jjjdi dd ddddddddd	|d
| jd | dd d| jd | ddddddd| jjdd dd dd d| jd | d| jj	 }|r^t
d|S )N"multi_modal_placeholder_index_mapsenable_kv_scales_calculationFuse_cuda_graphTnum_prefillsr   num_prefill_tokensnum_decode_tokensslot_mappingseq_lensseq_lens_tensormax_query_len   max_decode_query_lenmax_prefill_seq_lenmax_decode_seq_lenquery_start_locseq_start_loccontext_lens_tensorblock_tableshead_dim3MLACommonState does not support encoder/decoder yetr-   )rT   rS   attn_backendmake_metadatare   rg   Zmax_seq_len_to_capturerk   rV   get_head_sizeNotImplementedError)r]   rn   rq   attn_metadatar-   r-   r.   $graph_capture_get_metadata_for_batchb  sV   
	
z3MLACommonState.graph_capture_get_metadata_for_batchc                 C   s&   |j |jj|jjd}|rtd|S )N)rx   rz   r   r   )rx   decode_metadatarz   r   r   )r]   r   rq   input_buffersr-   r-   r.   get_graph_input_buffers  s   z&MLACommonState.get_graph_input_buffersc                 C   s<   |d j |jjdd |d j |jjdd |rtdd S )Nrz   T)non_blockingr   z3TritonMLAState does not support encoder/decoder yet)Zcopy_r   rz   r   r   )r]   r   r   rq   r-   r-   r.   prepare_graph_input_buffers  s   z*MLACommonState.prepare_graph_input_buffersc                 C   s\   | j s| jr,t| ds%|jd usJ tj| j| j f| jj	|jj
d| _| j|j_d S d S )Ncontext_chunk_workspacer`   )rX   rY   hasattrinput_tokensrM   emptyr\   rV   r   ra   rb   r   r   )r]   Zmodel_inputr-   r-   r.   begin_forward  s   

zMLACommonState.begin_forwardN)F)rG   rH   rI   r^   r   rL   rm   rp   boolrO   r   r   r   r   r-   r-   r-   r.   r7   $  s,    
"

r7   c                   @   sj  e Zd ZU dZeed< eee  ed< ee	j
 ed< eed< eed< ee	j
 ed< ee	j
 ed< d	Zee ed
< d	Zee ed< d	Zee	j
 ed< d	Zee	j
 ed< d	Zee ed< d	Zee ed< eed< d	Zee ed< dZeed< d	Zee	j
 ed< d	Zee	j
 ed< d	Zeee  ed< d	Zeee  ed< d	Zee	j
 ed< dd Zedd Zedd Zd	S )r3   a  Metadata for MLACommon. 
    
    NOTE: Please read the comment at the top of the file before trying to 
    understand this class

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    rt   ry   rz   r~   r   r   r   Nr{   r}   r   r   _cached_prefill_metadata_cached_decode_metadatarv   r   Fis_profile_runcontext_chunk_cu_seq_lenscontext_chunk_startscontext_chunk_seq_totcontext_chunk_max_seq_lensr   c                 C   s@   t  }| jd ur| j|vrtd| dd| j dd S d S )NzOnly z are supported for head_dim,z
 received .)r+   rF   r   
ValueError)r]   Zsupported_head_sizesr-   r-   r.   __post_init__  s   
zMLACommonMetadata.__post_init__c                 C   s  | j dkrd S | jd ur| jS | jd usJ | jd usJ | jd u r$d n	| jd | j d  }| jd u r5d n| jd | j }| jd u rDd n| jd | j  }| jd u rSd n| jd | j  }| jd u rbd n	| jd | j d  }| jd u rsd n| jd | j  }| j	d u rd n| j	d | j  }| j
di ddd| j d| jddd|d	d d
dd|d|d| jd| jddddd|d|d|d|d| jd| jd| jd| jd| jd| j| _| jS )Nr   r|   rt   Fru   rv   rw   rx   rr   rs   ry   rz   r{   r~   r}   r   r   r   r   r   r   r   r   r   r   r   r-   )ru   r   ry   rz   r   rx   rv   r   r   r   ro   r{   r~   r   r   r   r   r   r   )r]   r   rx   ry   rz   r   r   r   r-   r-   r.   prefill_metadata  s   


	
z"MLACommonMetadata.prefill_metadatac                 C   s  | j dkrd S | jd ur| jS | jd usJ | jd u rd n| j| jd  }| jd u r,d n| j| jd  }| jd u r;d n| j| jd  }| jdi d| jddddd| j d|dd dd	d
d d|d| j	d| j
ddd| jd| jd ur| j| jd  | j| j  nd d| jd ur| j| jd  nd dd d|d| jd| j| _| jS dd d|d| jd| j| _| jS )Nr   rt   ru   rv   rw   rx   rr   rs   Fry   rz   r}   r{   r~   r   r   r   r   r   r   r   r-   )rw   r   rz   rx   rv   ru   r   ro   rt   r}   r{   r   r   r   r   r   )r]   rx   rz   r   r-   r-   r.   r   H  s   


	



z!MLACommonMetadata.decode_metadata)rG   rH   rI   __doc__r   __annotations__r   r   rL   rM   rN   r{   r}   r   r   r   r   r   r   r   r   r   r   r   r   r   propertyr   r   r-   r-   r-   r.   r3     s8   
 
:c                   @   s   e Zd ZU dZg Zeee  ed< dddZdd Z	d	d
de
de
fddZdedeee  dejfddZdee dee dedefddZdS )r5   j
    NOTE: Please read the comment at the top of the file before trying to 
    understand this class
    BLOCK_TABLE_EXTENDERinput_builderr*   c                 C   sf   || _ |j| _|j| _|j| _| jjj| _| jjj| _| js!| jr1| j jj}|j	| _	| jj| _
d S d S r2   )r   rS   sliding_windowr:   rU   rX   rW   rY   
attn_stater\   	page_size)r]   r   r   r-   r-   r.   r^     s   
z!MLACommonMetadataBuilder.__init__c                 C   sD   g | _ g | _g | _g | _g | _tt| _d| _d| _	d| _
d| _d S )Nr   F)rx   prefill_seq_lenscontext_lensr   curr_seq_lensr   r!   Zmultimodal_placeholder_mapsru   rv   rw   Zhas_prefix_cache_hit)r]   r-   r-   r.   prepare  s   
z MLACommonMetadataBuilder.prepare
inter_dataz,ModelInputForGPUBuilder.InterDataForSeqGrouprX   prefix_cache_hitc              
   C   s*  |j }|j}t|jdd |jD |j|j|j|j|j	D ]v\}}}}	}
}}| j
| |rB|  jd7  _|  j|7  _| j
| n|  j|
7  _| j
|	 g }|rX|| }n|s\|sr|durr|dkri|| }n	|| | d }| j
| t|}t||
|| j}t|| j||||| j|j qdS )zAdd a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.
        c                 S   s   g | ]}t |qS r-   )len).0tr-   r-   r.   
<listcomp>  s    z;MLACommonMetadataBuilder._add_seq_group.<locals>.<listcomp>r|   Nr   )	is_promptr   zipZseq_idsr   Zorig_seq_lensry   
query_lensr   Zcurr_sliding_window_blocksappendru   rv   r   rw   r   r   r   r   r   rx   r:   )r]   r   rX   r   r   r   Zseq_idZ	token_lenZseq_lenZcurr_seq_lenZ	query_lenZcontext_lenZcurr_sliding_window_blockblock_tabler   Z	start_idxr-   r-   r.   _add_seq_group  sL   


z'MLACommonMetadataBuilder._add_seq_groupnum_seqsr   r,   c           	      C   s   | j jj\}}||ksJ | j jd | }t|D ]#\}}|r<t|}||kr0|||d |f< q|d | ||d |f< qt|j| j jddS )NT)rb   r   )	rS   ri   shape	enumerater   rM   rh   rj   rb   )	r]   r   r   r_   Z
max_blocksri   ir   r9   r-   r-   r.   _get_graph_runner_block_tables  s    


z7MLACommonMetadataBuilder._get_graph_runner_block_tablesry   r   cuda_graph_pad_sizern   c           #      C   s  t dd | jjD }| jjD ]}| || jj| q| jj}|dk}t|}	|| jd }
t	|
dkr9t|
}nd}t| j
dd}t| jdd}| j}tt|dd}tt|dd}t	|}|r| jtg|  | j| jj|  || j }| || j}n
t| jdtj|d	}|	dksJ d
||dusJ t| jtj|| jj}t|tj|| jj}t| jtj|| jj}t|tj|| jj}t|tj|| jj}d}d}d}d}| js| j r| jdkr|dur|d| j  dkr|d| j dk! " }| j#| }t$|| j%}|dksJ t&| |}tj'||tjd(d)d| j| }t*|d| j (d|| }|| j+dd} | j,dd-tj}!tj.|tj|d(d}"tj/|"|!gdd}| jddj01 }| j!dd1 }t|| j#ksJ | jj2j3d'i d|d| jd|d| jd|ddddd|d|d|	d|d|d|d|d|d|d |d!| jj45 d"| jj6d#|d$|d%|d&|S )(a  Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        c                 S   s   g | ]}|j qS r-   )r   )r   r   r-   r-   r.   r     s    z2MLACommonMetadataBuilder.build.<locals>.<listcomp>Nr   r|   )default)initial)padra   rb   zquery_lens: {}rb   ra   )rZ   dimr`   rt   ru   rx   rv   rw   rr   rs   Fry   rz   r{   r}   r~   r   r   r   r   r   r   r   r   r   r   r   r-   )7anyr   Zinter_data_listr   rX   rS   rb   r[   ru   r   r   r   rw   listr   rx   extendr   r   ro   r   rv   r   r&   rM   rL   formatr$   r   Z
pin_memoryrd   rf   rY   sumitemr\   r'   r   r%   Zarange	unsqueezeexpandrZ   clampZcumsumrj   rl   catvaluestolistr   r   rV   r   Zin_profile_run)#r]   ry   r   r   rn   r   r   rb   Zuse_captured_graphr{   Zdecode_query_lensr}   r~   r   rw   r   r   r   r   r   rz   Zslot_mapping_tensorZquery_start_loc_tensorZseq_start_loc_tensorr   r   r   r   Znum_prefills_with_contextZmax_context_chunkZ
num_chunksZ
chunk_endsZchunk_seq_lensZ_context_chunk_cu_seq_lenszeror-   r-   r.   build  s  






	
zMLACommonMetadataBuilder.buildN)r   r*   )rG   rH   rI   r   r   r   rL   r   r^   r   r   r   r   rM   rN   r   r   r-   r-   r-   r.   r5   z  s0   
 

5

r5   c                $   @   sT  e Zd ZdZdededededeee  dee ded	ee d
edee dee dededededede	ddf$ddZ
dd Zdd ZdejfddZdejdejd efd!d"Zdejd#ejd$ejdejd edejfd%d&Zed'ejd(ejdejd edejf
d)d*Z		d2d+edejd,ejd$ejd-ejd ed.eej d/eej dejfd0d1ZdS )3MLACommonImplr   	num_headsr<   scaler;   alibi_slopesr   kv_cache_dtypelogits_soft_cap	attn_typekv_sharing_target_layer_nameq_lora_rankkv_lora_rankqk_nope_head_dimqk_rope_head_dimqk_head_dim
v_head_dim	kv_b_projr,   Nc                 C   s   |
d urt d|| _|| _t|| _|| _|| _|| _|| _|| _	|| _
|| _|| _|| _t| _t| _t | _| jd urFtjt| jd| _| jd u pX| jdkoWt d dk | _d S )NzKV sharing not supported in V0.)Z
fa_version   r   	   )r   r   r<   floatr   r;   r   r   r   r   r   r   r   r   r(   triton_fa_funcr)   r   Zvllm_flash_attn_version	functoolspartialr"   Zget_device_capability_pad_v)r]   r   r<   r   r;   r   r   r   r   r   r   r   r   r   r   r   r   r   r-   r-   r.   r^     s4   




zMLACommonImpl.__init__c           
      K   s   |}| j rtjjj|d|jd |jd  gdd}tr:tjr:|s:| 	|||d |d |d |d |d |d |d }nt
rK| jd|||||d	|}n| jd|||||d
|}d }	t|trd|^}}	|rr|	d uslJ ||	d fS |S )Nr   r   )valuecu_seqlens_qcu_seqlens_kmax_seqlen_qmax_seqlen_kcausal)qkvreturn_softmax_lsesoftmax_scale)r   r   r   Zreturn_attn_probsr   r-   )r   rM   nnZ
functionalr   r   is_hipr   ZVLLM_USE_TRITON_FLASH_ATTNr   
is_vllm_far)   
isinstancetuple)
r]   r   r   r   r   r   kwargsZmaybe_padded_vZattn_outrestr-   r-   r.    _flash_attn_varlen_diff_headdims  s^   


z.MLACommonImpl._flash_attn_varlen_diff_headdimsc                 C   sD   | d| j| jdd}t|| j}|ddd| j| j S )Nr   r   r|   )	viewr   r   	transposerM   bmmW_UVZreshaper   )r]   xr-   r-   r.   
_v_up_proj  s   zMLACommonImpl._v_up_proj	act_dtypec                    s   dd dt f fdd}|| jj}|j| j| j| j| j  fks;J d|jd| jd| jd	| jd
| j
|| j| j| j| j }|j	| j| jgdd\}}|
dd| _|ddd| _d S )Nc                 S   s<   d}|D ]}t | |rt| |  S qtd|  d| d)N)weightZqweightZweight_packedzLayer 'z&' has no recognized weight attribute: r   )r   getattrAttributeError)layerZWEIGHT_NAMESattrr-   r-   r.   get_layer_weight  s   
zEMLACommonImpl.process_weights_after_loading.<locals>.get_layer_weightr	  c                    sD   t | jtstj| j | jd}| jj| |d d}~|jS | j	S )Nr`   )Zbias)
r   Zquant_methodr    rM   eyeZinput_size_per_partitionrb   applyrO   r  )r	  r  Zdequant_weightsr  r  r-   r.   get_and_maybe_dequant_weights  s   zRMLACommonImpl.process_weights_after_loading.<locals>.get_and_maybe_dequant_weightszkv_b_proj_weight.shape=z, self.kv_lora_rank=z, self.num_heads=z, self.qk_nope_head_dim=z, self.v_head_dim=r   r   r   r|      )r   r   rO   r   r   r   r   r   r   splitr   r  ZpermuteW_UK_T)r]   r  r  Zkv_b_proj_weightZW_UKr  r-   r  r.   process_weights_after_loading  s6   	


z+MLACommonImpl.process_weights_after_loadingr   kv_c_and_k_pe_cacher   c                 C   s  |j }|d us	J |jd usJ |jd usJ |jd usJ |jd us%J |jd us,J d }t|j}|jd us:J |j}t|D ]}|j| }	t	j
|||j|j| |j|j| d |d |	 dd | jf }
|d |	 d| jd f d}| |
d d| j| j| j }|j| j| jgdd\}}tj||g |jd d dR fdd}| j||||j|j| |j|j| | jddd	
\}}|d u r|}|}qAt|}t|}t||||||d
 |}|}qA||fS )N)Z	src_cachedstr   Zcu_seq_lensrn   Z
seq_starts.r|   r   r   r   FT
r   r   r   r   r   r   r   r   r   r   )output
output_lseprefix_output
prefix_lsesuffix_output
suffix_lse)r   r   r   r   r   r   r   r   rangerA   Zgather_cacher   ru   r   r   r   r   r   r   r   r  rM   r   r   r   r   r   r{   r   
empty_liker   )r]   r   r  r   r   r  ZitersZ	workspacer   tokskv_c_normedk_pekv_nopek_noper   r   Zattn_outputZattn_softmax_lser  Z
output_tmpZoutput_lse_tmpr-   r-   r.   _compute_prefill_context8  s   


	

$

z&MLACommonImpl._compute_prefill_contextr   r!  c                 C   s  |j }|d us	J |jd uo|j dk}| |d d| j| j| j }|j| j| jgdd\}	}
t	j
|	|g |	jd d dR fdd}| j|||
|j|j|j|j| jd|d
}|rz|\}}| |||\}}t	|}t|||||d | jr|dd |
jd f }|jdd	S )
Nr   r   r   Tr  )r  r  r  r  r  .)Z	start_dim)r   r   r[   r   r   r   r   r   r  rM   r   r   r   r   r   r~   r   r$  r  r   r   flatten)r]   r   r   r!  r  r   r   Zhas_contextr"  r#  r   r   r  r  r  Zcontext_outputZcontext_lser-   r-   r.   _forward_prefill  sP   	

,
	zMLACommonImpl._forward_prefillql_nopeq_pec                 C   s   t r2   )r   )r]   r(  r)  r  r   r-   r-   r.   _forward_decode  s   zMLACommonImpl._forward_decoder	  
k_c_normedkv_cacher  output_scalec	                 C   s  |d urt d|d urt d|jr.|jd ur.tj|jjd | j| j| j f|j	|j
d}	|jd u}
|jd u}|j}|d| j| j}||d  }|d | }|d | }|d | }| dkrutj||d||j | j|jd tj|j|j | j| j |j	|j
d}|r| ||||||d |< |
r|j| j| jgdd\}}|dd}t|| j}|dd}| ||||||d < |S )	Nz+output is not yet supported for MLAImplBasez>fused output quantization is not yet supported for MLAImplBaser   r   r   r|   )r   r   r   ) r   r   r   rM   r   r   r   r   r   rb   ra   r   r   rv   r   r   ZnumelrA   Zconcat_and_cache_mlaZsqueezerx   r&  r   Z_k_scalerw   r'  r  r   r   r  r  r*  )r]   r	  r   r+  r!  r,  r   r  r-  _Z
has_decodeZhas_prefillrv   decode_qZ	prefill_qZprefill_k_peZprefill_k_c_normedZdecode_q_nopeZdecode_q_peZdecode_ql_noper-   r-   r.   forward  sr   



	

zMLACommonImpl.forward)NN)rG   rH   rI   r   rL   r   r   r   rK   r   r^   r   r  rM   ra   r  rN   r3   r$  r'  r   rO   r*  r   r0  r-   r-   r-   r.   r     s    
	

8;3
Q
8	
r   )Kr   r   abcr   collectionsr   
contextlibr   dataclassesr   	itertoolsr   typingr   r   r	   r
   r   r   r   r   r   rM   Zvllmr   rA   r   Z vllm.attention.backends.abstractr   r   r   r   r   r   Zvllm.attention.backends.utilsr   r   r   r   Z$vllm.attention.ops.merge_attn_statesr   Zvllm.attention.utils.fa_utilsr   Z!vllm.model_executor.layers.linearr   r   r    Zvllm.multimodalr!   Zvllm.platformsr"   Zvllm.triton_utilsr#   Z
vllm.utilsr$   r%   r&   r'   Z)vllm.attention.ops.triton_flash_attentionr(   Zvllm.vllm_flash_attnr)   r   ImportErrorZ
flash_attnZvllm.worker.model_runnerr*   Zis_rocmr   r+   rO   r7   r3   r5   r   r-   r-   r-   r.   <module>   sd    <, /  J  