o
    )i                     @  sv  d Z ddlmZ ddlmZ ddlmZmZmZ ddl	Z	ddl
mZmZmZ ddlmZmZmZ ddlmZ ddlmZ dd	lmZmZmZ dd
lmZmZ ddlmZ ddlm Z m!Z! ddl"m#Z# ddl$m%Z% ddl&m'Z'm(Z(m)Z)m*Z*m+Z+m,Z,m-Z- ddl.m/Z/ dZ0ee1Z2G dd deZ3eG dd dZ4G dd de(e4 Z5G dd deZ6										d9d:d7d8Z7dS );z Attention layer with FlashInfer.    )annotations)	dataclass)ClassVarOptionalUnionN)"BatchDecodeWithPagedKVCacheWrapper#BatchPrefillWithPagedKVCacheWrapper!MultiLevelCascadeAttentionWrapper)_get_range_bufget_seq_lens!trtllm_batch_decode_with_kv_cache)"trtllm_batch_context_with_kv_cache)AttentionBackendAttentionImplAttentionType)CUDAGraphMode
VllmConfig)init_logger)cdivis_pin_memory_available)use_trtllm_attention)use_cascade_attention)AttentionCGSupportAttentionMetadataBuilderCommonAttentionMetadataget_kv_cache_layoutget_per_layer_parametersinfer_global_hyperparameterssplit_decodes_and_prefills)AttentionSpeci   c                   @  s   e Zd ZU dZded< ed)ddZed*d	d
Zed+ddZe	d,ddZ
e	d-ddZe	d.ddZe	d/ddZe	d0d d!Ze	d1d"d#Ze	d2d&d'Zd(S )3FlashInferBackendTboolaccept_output_bufferreturnlist[torch.dtype]c                 C  s   t jt jgS N)torchfloat16Zbfloat16cls r*   q/home/app/PaddleOCR-VL-test/.venv_paddleocr/lib/python3.10/site-packages/vllm/v1/attention/backends/flashinfer.pyget_supported_dtypes-   s   z&FlashInferBackend.get_supported_dtypes	list[int]c                 C  s   g dS )N)@         r*   r(   r*   r*   r+   get_supported_head_sizes1   s   z*FlashInferBackend.get_supported_head_sizes	head_sizeintNonec                 C  s<   |   }||vr| jd}td| d| d| dd S )NBackendz
Head size z is not supported by z. Supported head sizes are: zg. Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use FlexAttention backend which supports all head sizes.)r1   __name__removesuffix
ValueError)r)   r2   Zsupported_head_sizes	attn_typer*   r*   r+   validate_head_size6   s   z$FlashInferBackend.validate_head_sizestrc                   C  s   dS )NZFLASHINFER_VLLM_V1r*   r*   r*   r*   r+   get_nameA      zFlashInferBackend.get_nametype[FlashInferImpl]c                   C     t S r%   )FlashInferImplr*   r*   r*   r+   get_impl_clsE   r=   zFlashInferBackend.get_impl_clstype[FlashInferMetadata]c                   C  r?   r%   )FlashInferMetadatar*   r*   r*   r+   get_metadata_clsI   r=   z"FlashInferBackend.get_metadata_clstype[FlashInferMetadataBuilder]c                   C  r?   r%   )FlashInferMetadataBuilderr*   r*   r*   r+   get_builder_clsM   r=   z!FlashInferBackend.get_builder_cls
num_blocks
block_sizenum_kv_headstuple[int, ...]c                 C  s   | d|||fS N   r*   )rH   rI   rJ   r2   r*   r*   r+   get_kv_cache_shapeQ   s   z$FlashInferBackend.get_kv_cache_shapec                  C  s6   t  } | dkrd}|S | dkrd}|S td|  d)NZNHD)r      rM         HND)r   rO   rP   rM   rQ   zUnknown cache layout format .)r   r8   )Zcache_layoutstride_orderr*   r*   r+   get_kv_cache_stride_orderZ   s   z+FlashInferBackend.get_kv_cache_stride_orderkv_cache_dtypetorch.dtypec                 C  s*   | dv rt jS | dkrt jS td|  )N)fp8Zfp8_e4m3Zfp8_e5m2zUnrecognized FP8 dtype: )r&   Zfloat8_e4m3fnZfloat8_e5m2r8   )rV   r*   r*   r+   get_fp8_dtype_for_flashinferg   s
   z.FlashInferBackend.get_fp8_dtype_for_flashinferN)r#   r$   )r#   r-   )r2   r3   r#   r4   )r#   r;   )r#   r>   )r#   rB   )r#   rE   )
rH   r3   rI   r3   rJ   r3   r2   r3   r#   rK   )r#   rK   )rV   r;   r#   rW   )r6   
__module____qualname__r"   __annotations__classmethodr,   r1   r:   staticmethodr<   rA   rD   rG   rN   rU   rY   r*   r*   r*   r+   r    )   s,   
 
r    c                   @  s:  e Zd ZU ded< ded< ded< ded< ded< ded< ded	< ded
< ded< ded< ded< ded< ded< ded< ded< ded< ded< ded< ded< ded< ded< ded< ded< dZded< dZded< dZded < dZded!< dZd"ed#< dZ	d$ed%< dZ
d&ed'< dZded(< dZded)< d*d+ ZdS ),rC   r3   num_actual_tokenstorch.Tensorqo_indptr_cpupaged_kv_indptr_cpupaged_kv_indicespaged_kv_last_page_len_cpunum_qo_headsrJ   head_dim	page_sizerW   kv_data_typeq_data_typeslot_mapping	max_q_lenmax_seq_lenseq_lensblock_table_tensorr!   prefill_use_trtllmdecode_use_trtllmnum_decodesnum_decode_tokensnum_prefillsnum_prefill_tokensuse_cascadeNOptional[torch.Tensor]shared_qo_indptr_cpushared_kv_page_indptr_cpushared_kv_page_indices_cpushared_kv_last_page_len_cpuz-Optional[BatchPrefillWithPagedKVCacheWrapper]prefill_wrapperz,Optional[BatchDecodeWithPagedKVCacheWrapper]decode_wrapperz+Optional[MultiLevelCascadeAttentionWrapper]cascade_wrapperqo_indptr_gpupaged_kv_indptr_gpuc                 C  s   | j d urt| j  d S d S r%   )rf   r    r:   selfr*   r*   r+   __post_init__   s   
z FlashInferMetadata.__post_init__)r6   rZ   r[   r\   rw   rx   ry   rz   r{   r|   r}   r~   r   r   r*   r*   r*   r+   rC   q   sD   
 
rC   c                   @  s   e Zd ZU ejZded< dZded< d-ddZdd Z	dd Z
	d.d/ddZdd Zd0dd Z	d.d1d&d'Zd2d(d)Zd3d*d+Zd,S )4rF   zClassVar[AttentionCGSupport]cudagraph_supportrO   zClassVar[int]reorder_batch_thresholdkv_cache_specr   layer_names	list[str]vllm_configr   devicetorch.devicec           	      C  sJ  || _ || _|j| _|| _d | _d | _d | _|j| _t|j	j
| jj}|jj}|| }| jj tjk| _| jrCi | _t|| jj| _d | _tt||t| _tj|d tj| j d| _tj|tj| j d| _ tj|tj| j d| _!t" }tj|d tjd|d| _#tj|tjd|d| _$tj|tjd|d| _%tj&|tj| j d| _'d S )NrO   dtyper   cpu)r   r   
pin_memory)(r   r   cache_configr   _workspace_buffer_prefill_wrapper_decode_wrapperZcompilation_configr   model_configZmax_model_lenrI   Zscheduler_configZmax_num_seqsZcudagraph_modeZdecode_moder   ZFULLenable_cuda_graph_decode_wrappers_cudagraphminZmax_capture_size_decode_cudagraph_max_bs_cascade_wrapperr   r   r@   global_hyperparametersr&   zerosint32paged_kv_indptrrc   paged_kv_last_page_lenr   rb   Zpaged_kv_indices_cpurd   Zarangeblock_table_arange)	r   r   r   r   r   Zmax_num_pages_per_reqZmax_num_reqsZmax_num_pagesr   r*   r*   r+   __init__   sv   


z"FlashInferMetadataBuilder.__init__c                 C  s&   | j d u rtjttj| jd| _ | j S )Nr   )r   r&   r    FLASHINFER_WORKSPACE_BUFFER_SIZEZuint8r   r   r*   r*   r+   _get_workspace_buffer   s   
z/FlashInferMetadataBuilder._get_workspace_bufferc                 C  s"   | j d u rt|  t | _ | j S r%   )r   r   r   r   r   r*   r*   r+   _get_prefill_wrapper  s
   

z.FlashInferMetadataBuilder._get_prefill_wrapperF
batch_sizer3   use_cudagraphr!   c           
   	   C  s   |r
| j |d }n| j}|d u rb| jj| jj}| jj| jj}tj	p+|| dk}|rB| j
d |d  }| j}| jd | }	nd }d }d }	t|  t ||||	|d}|r_|| j |< |S || _|S )NrQ   rO   )Zuse_cuda_graphZpaged_kv_indptr_bufferZpaged_kv_indices_bufferZpaged_kv_last_page_len_bufferuse_tensor_cores)r   getr   r   r   get_num_attention_headsparallel_configZget_num_kv_headsenvsZ"VLLM_FLASHINFER_FORCE_TENSOR_CORESr   rc   r   r   r   r   )
r   r   r   r|   re   rJ   r   r   rc   r   r*   r*   r+   _get_decode_wrapper  sN   


z-FlashInferMetadataBuilder._get_decode_wrapperc                 C  s$   | j d u rtd|  t | _ | j S rL   )r   r	   r   r   r   r*   r*   r+   _get_cascade_wrapper7  s
   
z.FlashInferMetadataBuilder._get_cascade_wrapperattn_metadatarC   c           
      C  st  |j r:|  |_|jj|j|jg|j|jg|j|j	g|j
|jg|j|j|j|jd| jj| jj| jj|j|jd d S |j}|j}|dkr|}|  |_|j|d  jd |d ks[J |j|d  jd |d kskJ |j|d  jd |ksyJ |j|d  |j|  }|j|d  }|js|jj|||j	|j|d  |j|j|j|jd| jj| jj| jj|j|jd n|| j|_|| j|_|dkr6|dk}| j o|o|| j!k}|r| j"#|}	| jd| d|	  $|jd  | j||	 $d n|}	| %|	||_&|j's8t(|j&| jd |	d  |j	| jd |	 |j|j|j|jd| jj| jj| jj|j|jd d S d S d S )NT)Zcausalsm_scalewindow_leftlogits_soft_capri   rh   r   rO   NONE)pos_encoding_moder   r   r   ri   rh   ))ru   r   r}   planrw   ra   rx   rb   ry   rc   rz   rd   re   rJ   rf   rg   r   r   r   r   ri   rh   rs   rq   r   r{   shapero   tor   r~   r   r   r   r   Zpad_for_cudagraphZfill_r   r|   rp   fast_plan_decode)
r   r   rs   rq   Zprefill_startra   rb   Zpure_decoder   Znum_input_tokensr*   r*   r+   _plan=  s   

 





zFlashInferMetadataBuilder._plancommon_prefix_lencommon_attn_metadatar   
fast_buildr#   c           &      C  s  |j }|j}t|\}}}}	| jj}
|j}|j }|j}|j}|j	}||
 d |
 }|dk}|rt||
 dks9J ||
 }t
jd|gt
jdd}t
jd|gt
jdd}|dd |f }t
j|
gt
jdd}|d d |d f }||8 }nd }d }d }d }| }|j| jdd}| jd | d|dk }t
|}| jd | }t
j|d d d |f ||d t
j|dt
j| jdd|  d ||
 }t
j|dkt
|
|| jd | d | jj}|d	rt|}n| jj}| jj | jj!}| jj"} | jj#}!| j$j%}"|d	 ot&|	|||| |!|"}#t&||||| |!|"}$t'd%i d
|d|j(d| jd d|  d|d| jd | d|d| d|!d|
d|d| jjjd|j)d|d|d|d|d|#d|$d|d|d|d|	d |d!|d"|d#|d$|}%| *|% |%S )&NrO   r   r   r   Tnon_blocking)out)dimr   r   rX   r_   ra   rb   rc   rd   re   rJ   rf   rg   rh   ri   rj   rk   rl   rm   rn   ro   rp   rq   rr   rs   rt   ru   rw   rx   ry   rz   r*   )+num_reqsr_   r   r   rI   max_query_lenseq_lens_cpumaxrm   rn   r&   tensorr   r   r   r   Z	unsqueezesumrc   Zmasked_selectZcumsumrb   whererd   r   cache_dtype
startswithr    rY   r   r   r   r   r   rJ   r2   r   	has_sinksr   rC   Zquery_start_loc_cpurj   r   )&r   r   r   r   r   r_   rq   rs   rr   rt   rg   rk   rl   rm   r   rn   Zblock_table_bounds_cpuru   Znum_common_kv_blocksrw   rx   ry   rz   Zmax_num_blocksZblock_table_boundsmaskZnum_actual_pagesrc   rd   r   rV   re   rJ   rf   r   ro   rp   r   r*   r*   r+   build  s  








	


zFlashInferMetadataBuilder.buildc                 C  s*   |}|j |jksJ dd|_| d|S )z
        This method builds the metadata for full cudagraph capture.
        Currently, only decode is supported for full cudagraphs with FlashInfer.
        zrFlashInfer only supports decode-only full CUDAGraph capture. Make sure all cudagraph capture sizes <= max_num_seq.rO   r   )r   r_   r   r   )r   r   mr*   r*   r+   build_for_cudagraph_capture>  s   z5FlashInferMetadataBuilder.build_for_cudagraph_capturec                 O  s$   | j j| jjjkrdS t|i |S )NF)r   r   r   r   r   )r   argskwargsr*   r*   r+   r   N  s   z/FlashInferMetadataBuilder.use_cascade_attentionN)r   r   r   r   r   r   r   r   )F)r   r3   r   r!   )r   rC   )r   r3   r   r   r   r!   r#   rC   )r   r   )r#   r!   )r6   rZ   r[   r   ZUNIFORM_SINGLE_TOKEN_DECODEr   r\   r   r   r   r   r   r   r   r   r   r   r*   r*   r*   r+   rF      s"   
 

>,
  
rF   c                   @  s2   e Zd Zdejddfd&ddZ		d'd(d$d%ZdS ))r@   N	num_headsr3   r2   scalefloatrJ   alibi_slopesOptional[list[float]]sliding_windowOptional[int]rV   r;   r   Optional[float]r9   r   kv_sharing_target_layer_namesinksrv   r#   r4   c                 C  s   || _ || _t|| _|| _|d urtj|tjd}|| _|d u r%d| _	n|d df| _	|| _
|| _|
| _| j | j | _|	tjkrEtdd | _|d urf|jd |kratd| d|jd  d|| _d S d S )	Nr   )r   r   rO   r   zaEncoder self-attention and encoder/decoder cross-attention are not implemented for FlashInferImplzWSinks must have the same number of heads as the number of heads in the layer. Expected z
, but got rS   )r   r2   r   r   rJ   r&   r   Zfloat32r   r   rV   r   r   Znum_queries_per_kvr   DECODERNotImplementedErrorr   r   r8   )r   r   r2   r   rJ   r   r   rV   r   r9   r   r   r*   r*   r+   r   X  s8   


zFlashInferImpl.__init__layertorch.nn.Modulequeryr`   keyvaluekv_cacher   rC   outputoutput_scalec	                 C  s  |dusJ d|durt d|du r|S |j}	| jdu rMtjj|||dddf |dddf |j| j|j	|j
 | jdrMt| j}
||
}| jdurW| jd nd}|d|	 }|}|d|	 }|jr}|jdusqJ ||j|| |S |j}|j}t }|j| }|dkr;|j}||d }|jd |ksJ |dusJ |js|jsJ |j|ksJ |j| jpdksJ |j| j ksJ |j|||j!|j"||d d	 n`|# }|j$}|j%|d }|j&|d }t' d
ksJ |( sJ |( sJ |( sJ |( sJ |( sJ t)||||||j*|j+|j!| j  |j"|j,|j-|j.|| j/||d d |dkr|j0}|d| }|jd |ksSJ |dusZJ |j1s|j|ksfJ |j| jpmdksrJ |j| j ks{J |j|||j!|j"|d| d	 |S |# }|j$}|j%d| }|j&d| }t' d
ksJ |( sJ |( sJ |( sJ |( sJ |( sJ t2||||||j+|j!| j  |j"|| j/|d| d |S )a   Forward pass with FlashInfer.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape -
            # NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
            # HND: [num_blocks, 2,  num_kv_heads, block_size, head_size]


            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        NzOutput tensor must be provided.zAfused output quantization is not yet supported for FlashInferImplr   rO   rX   r           )Zk_scaleZv_scaler   rR   )r   r   workspace_bufferblock_tablesrm   rk   Z
max_kv_len
bmm1_scale
bmm2_scaler   Zcum_seq_lens_qZcum_seq_lens_kvr   r   r   )r   r   r   r   rm   rl   r   r   r   r   r   )3r   r_   r   r&   opsZ_C_cache_opsZreshape_and_cache_flashrj   rV   Z_k_scaleZ_v_scaler   r    rY   viewr   ru   r}   copy_runrr   rt   rU   Zpermuter{   r   ro   Z_causal_window_left_logits_soft_capr   	_sm_scaler   Z_k_scale_floatZ_v_scale_float
contiguous_float_workspace_bufferrn   rm   r   Zis_contiguousr   rk   rl   rs   r~   r   r   r|   rp   r   )r   r   r   r   r   r   r   r   r   r_   Ztorch_dtyper   Zoutput_paddedrr   rt   rT   Zkv_cache_permuter{   Zprefill_queryr   Zblock_tables_prefillZseq_lens_prefillr|   Zdecode_queryZblock_tables_decodeZseq_lens_decoder*   r*   r+   forward  s  







	




$

zFlashInferImpl.forward)r   r3   r2   r3   r   r   rJ   r3   r   r   r   r   rV   r;   r   r   r9   r   r   r   r   rv   r#   r4   )NN)r   r   r   r`   r   r`   r   r`   r   r`   r   rC   r   rv   r   rv   r#   r`   )r6   rZ   r[   r   r   r   r   r*   r*   r*   r+   r@   V  s    7r@   r   r   r'   T
indptr_cpur`   indiceslast_page_len_cpure   r3   rJ   rf   rg   r   r;   r   r   r   ri   !Optional[Union[str, torch.dtype]]rh   	data_typer   
rope_scale
rope_thetar   r!   r#   r4   c                 C  sT  | j r	t| ddr#| |||||||||	|
||||||| d| _dS | j s*J dt|}|
du r4d}
|durE|du r>|}|du rD|}n|du rKd}|du rQ|}t|tr[tt|n|}t|trgtt|n|}| jrst	|d d	}|| j
krtd
|| j
t|t| jkrtd| jj|dd | jj|dd |}|}| jrt|||}z| j| j| j| j||||||||| j ||d| _W nO ty } ztd| |d}~ww z%| j| j| j| j|||||| j |	|
||tjd|dtjd|d| _W n ty } ztd| |d}~ww || _|	| _|
| _|| _|| _|| _dS )ag  
    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
    cudagraph capture/replay, while the no cudagraph version turns back
    to the original plan.
    using original plan after passing host-side buffers:
    - only host-to-device copy of indptr and last_page_len buffers
    Modifications for cudagraph:
    - only host-to-device copy of indptr and last_page_len buffers.
    - avoid device-to-device copy of indices buffer.

    Part of the code get inspiration from the original plan from FlashInfer repo
    and the implementation of fast_decode_plan for FlashInfer in SGlang repo.
    vllm_first_callTFNzShould be cudagraph only herer   r'   rO   r   zThe batch size should be fixed in cudagraph mode, the runtime batch size {} mismatches the batch size set during initialization {}zHThe size of indices should be less than or equal to the allocated bufferr   zError in tensor core plan: r   r   zError in standard plan: ) Zis_cuda_graph_enabledgetattrr   r   len
isinstancer;   r&   r   r
   Z_fixed_batch_sizer8   formatZ_paged_kv_indices_bufZ_paged_kv_indptr_bufr   Z_paged_kv_last_page_len_bufr   Z_cached_moduler   Z_int_workspace_bufferZ _pin_memory_int_workspace_bufferZ
_plan_info	ExceptionRuntimeErroremptyZ_pos_encoding_moder   r   r   Z_rope_scaleZ_rope_theta)r   r   r   r   re   rJ   rf   rg   r   r   r   ri   rh   r   r   r   r   r   r   Zqo_indptr_hostZindptr_hostZlast_page_len_hostZkv_lens_arr_hoster*   r*   r+   r   N  s   $





r   )
r   r   Nr'   NNNNNT)$r   r`   r   r`   r   r`   re   r3   rJ   r3   rf   r3   rg   r3   r   r;   r   r3   r   r   ri   r   rh   r   r   r   r   r   r   r   r   r   r   r!   r#   r4   )8__doc__
__future__r   dataclassesr   typingr   r   r   r&   Z
flashinferr   r   r	   Zflashinfer.decoder
   r   r   Zflashinfer.prefillr   Z	vllm.envsr   Z vllm.attention.backends.abstractr   r   r   Zvllm.configr   r   Zvllm.loggerr   Z
vllm.utilsr   r   Zvllm.utils.flashinferr   Z%vllm.v1.attention.backends.flash_attnr   Z vllm.v1.attention.backends.utilsr   r   r   r   r   r   r   Zvllm.v1.kv_cache_interfacer   r   r6   loggerr    rC   rF   r@   r   r*   r*   r*   r+   <module>   sP   $HG      