o
    0 iZ                     @   sz  d Z ddlZddlmZ ddlmZ ddlmZ ddlmZ ddlmZ ddl	Z	zddl
ZdZW n ey;   d	ZY nw ee	je	jfZee	jfZed
dddZeddddZdd Zdd ZeddddZdd Zdd Zdd Zejd d!d"d#dd$Zejd%d&d'd(dd$Z	d@d)d*Z ed+d,d-d.Z!G d/d0 d0e"Z#d1d2 Z$d3d4 Z%d5d6 Z&d7d8 Z'dAd:d;Z(d<d= Z)d>d? Z*dS )Bz*Indexing mixin for sparse matrix classes.
    N)_core)
isspmatrix)spmatrix)device)runtimeTFzT d, S ind, int32 minorzraw T answerz+if (ind == minor) atomicAdd(&answer[0], d);Z#cupyx_scipy_sparse_compress_getitemz"T real, T imag, S ind, int32 minorz$raw T answer_real, raw T answer_imagzo
    if (ind == minor) {
    atomicAdd(&answer_real[0], real);
    atomicAdd(&answer_imag[0], imag);
    }
    Z+cupyx_scipy_sparse_compress_getitem_complexc           
      C   sT   |||d  }t |d t |d }}|| }||| }| || }	|	||fS )aE  Return a submatrix of the input sparse matrix by slicing major axis.

    Args:
        Ax (cupy.ndarray): data array from input sparse matrix
        Aj (cupy.ndarray): indices array from input sparse matrix
        Ap (cupy.ndarray): indptr array from input sparse matrix
        start (int): starting index of major axis
        stop (int): ending index of major axis

    Returns:
        Bx (cupy.ndarray): data array of output sparse matrix
        Bj (cupy.ndarray): indices array of output sparse matrix
        Bp (cupy.ndarray): indptr array of output sparse matrix

       r   )int)
AxAjApstartstopZstart_offsetZstop_offsetBpBjBx r   e/home/app/PaddleOCR-VL-test/.venv_paddleocr/lib/python3.10/site-packages/cupyx/scipy/sparse/_index.py_get_csr_submatrix_major_axis+   s   
r   c           
      C   sn   ||k||k @ }t j|jd |jd}d|d< ||dd< t j||d || }|| | }| | }	|	||fS )aE  Return a submatrix of the input sparse matrix by slicing minor axis.

    Args:
        Ax (cupy.ndarray): data array from input sparse matrix
        Aj (cupy.ndarray): indices array from input sparse matrix
        Ap (cupy.ndarray): indptr array from input sparse matrix
        start (int): starting index of minor axis
        stop (int): ending index of minor axis

    Returns:
        Bx (cupy.ndarray): data array of output sparse matrix
        Bj (cupy.ndarray): indices array of output sparse matrix
        Bp (cupy.ndarray): indptr array of output sparse matrix

    r   dtyper   Nout)cupyemptysizer   cumsum)
r
   r   r   r   r   maskZmask_sumr   r   r   r   r   r   _get_csr_submatrix_minor_axisD   s   
r   zNint32 out_rows, raw I rows, raw int32 Ap, raw int32 Aj, raw T Ax, raw int32 Bpzint32 Bj, T BxaD  
    const I row = rows[out_rows];

    // Look up starting offset
    const I starting_output_offset = Bp[out_rows];
    const I output_offset = i - starting_output_offset;
    const I starting_input_offset = Ap[row];

    Bj = Aj[starting_input_offset + output_offset];
    Bx = Ax[starting_input_offset + output_offset];
Z$cupyx_scipy_sparse_csr_row_index_kerc           
      C   sx   t |}t j|jd |jd}d|d< t j|| |dd d t|d }t||}t||||| |\}}	|	||fS )a  Populate indices and data arrays from the given row index
    Args:
        Ax (cupy.ndarray): data array from input sparse matrix
        Aj (cupy.ndarray): indices array from input sparse matrix
        Ap (cupy.ndarray): indptr array from input sparse matrix
        rows (cupy.ndarray): index array of rows to populate
    Returns:
        Bx (cupy.ndarray): data array of output sparse matrix
        Bj (cupy.ndarray): indices array of output sparse matrix
        Bp (cupy.ndarray): indptr array for output sparse matrix
    r   r   r   Nr   r   )	r   diffr   r   r   r   r	   _csr_indptr_to_coo_rows_csr_row_index_ker)
r
   r   r   rowsZrow_nnzr   nnzout_rowsr   r   r   r   r   _csr_row_indexq   s   


r%   c                 C   sb   ddl m} tj| tjd}t }tj	r| dkrt
d|||jj| |jd |jj|j |S )Nr   )cusparser   z@hipSPARSE currently cannot handle sparse matrices with null ptrsr   )Zcupy_backends.cuda.libsr&   r   r   numpyZint32r   Zget_cusparse_handler   Zis_hip
ValueErrorZxcsr2coodataZptrr   ZCUSPARSE_INDEX_BASE_ZERO)r#   r   r&   r$   handler   r   r   r       s   r    c           
      C   s   t j| |d} t j||d}t || g}t ||}| | }|| }|| }t j|jdd}	t||||	|jd d ||	 ||	 ||	 fS )z;Find the unique indices for each row and keep only the lastr   boolr   r   )r   asarraystackZlexsortZastypeZonesr   _unique_mask_kern)
ijxZ	idx_dtypeZstackedorderZindptr_insertsZindices_insertsZdata_insertsr   r   r   r   _select_last_indices   s   
r4   zqraw I insert_indices, raw T insert_values, raw I insertion_indptr,
        raw I Ap, raw I Aj, raw T Ax, raw I Bpzraw I Bj, raw T Bxa  

        const I input_row_start = Ap[i];
        const I input_row_end = Ap[i+1];
        const I input_count = input_row_end - input_row_start;

        const I insert_row_start = insertion_indptr[i];
        const I insert_row_end = insertion_indptr[i+1];
        const I insert_count = insert_row_end - insert_row_start;

        I input_offset = 0;
        I insert_offset = 0;

        I output_n = Bp[i];

        I cur_existing_index = -1;
        T cur_existing_value = -1;

        I cur_insert_index = -1;
        T cur_insert_value = -1;

        if(input_offset < input_count) {
            cur_existing_index = Aj[input_row_start+input_offset];
            cur_existing_value = Ax[input_row_start+input_offset];
        }

        if(insert_offset < insert_count) {
            cur_insert_index = insert_indices[insert_row_start+insert_offset];
            cur_insert_value = insert_values[insert_row_start+insert_offset];
        }


        for(I jj = 0; jj < input_count + insert_count; jj++) {

            // if we have both available, use the lowest one.
            if(input_offset < input_count &&
               insert_offset < insert_count) {

                if(cur_existing_index < cur_insert_index) {
                    Bj[output_n] = cur_existing_index;
                    Bx[output_n] = cur_existing_value;

                    ++input_offset;

                    if(input_offset < input_count) {
                        cur_existing_index = Aj[input_row_start+input_offset];
                        cur_existing_value = Ax[input_row_start+input_offset];
                    }


                } else {
                    Bj[output_n] = cur_insert_index;
                    Bx[output_n] = cur_insert_value;

                    ++insert_offset;
                    if(insert_offset < insert_count) {
                        cur_insert_index =
                            insert_indices[insert_row_start+insert_offset];
                        cur_insert_value =
                            insert_values[insert_row_start+insert_offset];
                    }
                }

            } else if(input_offset < input_count) {
                Bj[output_n] = cur_existing_index;
                Bx[output_n] = cur_existing_value;

                ++input_offset;
                if(input_offset < input_count) {
                    cur_existing_index = Aj[input_row_start+input_offset];
                    cur_existing_value = Ax[input_row_start+input_offset];
                }

            } else {
                    Bj[output_n] = cur_insert_index;
                    Bx[output_n] = cur_insert_value;

                    ++insert_offset;
                    if(insert_offset < insert_count) {
                        cur_insert_index =
                            insert_indices[insert_row_start+insert_offset];
                        cur_insert_value =
                            insert_values[insert_row_start+insert_offset];
                    }
            }

            output_n++;
        }
    Z1cupyx_scipy_sparse_csr_copy_existing_indices_kern)Z	no_returnz#raw I rows, raw I cols, raw I orderzraw bool maskaT  
    I cur_row = rows[i];
    I next_row = rows[i+1];

    I cur_col = cols[i];
    I next_col = cols[i+1];

    I cur_order = order[i];
    I next_order = order[i+1];

    if(cur_row == next_row && cur_col == next_col) {
        if(cur_order < next_order)
            mask[i] = false;
        else
            mask[i+1] = false;
    }
    Z#cupyx_scipy_sparse_unique_mask_kernc                 C   sD   ||dk   | 7  < ||dk   |7  < t | ||||||||jd	S )a  Populate data array for a set of rows and columns
    Args
        n_row : total number of rows in input array
        n_col : total number of columns in input array
        Ap : indptr array for input sparse matrix
        Aj : indices array for input sparse matrix
        Ax : data array for input sparse matrix
        Bi : array of rows to extract from input sparse matrix
        Bj : array of columns to extract from input sparse matrix
    Returns
        Bx : data array for output sparse matrix
    r   r,   )_csr_sample_values_kernr   )Zn_rowZn_colr   r   r
   ZBir   Znot_found_valr   r   r   _csr_sample_values'  s   r6   zWI n_row, I n_col, raw I Ap, raw I Aj, raw T Ax,
    raw I Bi, raw I Bj, I not_found_valzraw T Bxat  
    const I j = Bi[i]; // sample row
    const I k = Bj[i]; // sample column
    const I row_start = Ap[j];
    const I row_end   = Ap[j+1];
    T x = 0;
    bool val_found = false;
    for(I jj = row_start; jj < row_end; jj++) {
        if (Aj[jj] == k) {
            x += Ax[jj];
            val_found = true;
        }
    }
    Bx[i] = val_found ? x : not_found_val;
Z)cupyx_scipy_sparse_csr_sample_values_kernc                   @   s   e Zd ZdZdd Zdd Zdd Zdd	 Zd
d Zdd Z	dd Z
dd Zdd Zdd Zdd Zdd Zdd Zdd Zdd Zd d! Zd"d# Zd$d% Zd&d' Zd(d) Zd*S )+
IndexMixinzS
    This class provides common dispatching and validation logic for indexing.
    c                 C   s  t rtjtjdk rtd| |\}}t|t	r@t|t	r&| 
||S t|tr1| ||S |jdkr<| ||S tdt|trxt|t	rP| ||S t|tri|td krc||krc|  S | ||S |jdkrt| ||S td|jdkrt|t	r| ||S t|tr| ||S n5t|t	r| ||S t|trtd|jd dkr|jdks|jd dkr| |d d df | S t||\}}|j|jkrtd|jdkr| jt|j| jdS | ||S )Nz1.4.0z,Sparse __getitem__() requires Scipy >= 1.4.0r   zindex results in >2 dimensionsr   'number of row and column indices differr   ) scipy_availabler'   libZNumpyVersionscipy__version__NotImplementedError_parse_indices
isinstance_int_scalar_types_get_intXintslice_get_intXslicendim_get_intXarray
IndexError_get_sliceXintcopy_get_sliceXslice_get_sliceXarray_get_arrayXint_get_arrayXsliceshape_get_columnXarrayZravelr   broadcast_arraysr   	__class__Z
atleast_2dr   _get_arrayXarray)selfkeyrowcolr   r   r   __getitem__Z  sZ   













&
zIndexMixin.__getitem__c           
      C   s  |  |\}}t|tr.t|tr.tj|| jd}|jdkr"td| |||j	d  d S t|t
rEtj|| jd  d d d f }nt|}t|t
rntj|| jd  d d d f }|jdkrm|d d d f }nt|}t||\}}|j|jkrtdt|r|jdkr|d  }|d  }|jd dko|jd dk}|jd dko|jd dk}|s|jd |jd kr|s|jd |jd kstd|jdkrd S |jdd}|  | ||| d S tj|| jd}t||\}}	|jdkrd S ||j}| ||| d S )	Nr   r   z&Trying to assign a sequence to an itemr   r8   zshape mismatch in assignmentT)rH   )r>   r?   r@   r   r-   r   r   r(   _set_intXintZflatrB   ZarangeindicesrM   Z
atleast_1drD   rO   rF   r   ZtocooZsum_duplicates_set_arrayXarray_sparseZreshape_set_arrayXarray)
rR   rS   r2   rT   rU   r0   r1   Zbroadcast_rowZbroadcast_col_r   r   r   __setitem__  sV   


$

"




zIndexMixin.__setitem__c                 C   s.   t |tjtjfr|jdkr|jdkrdS dS )Nr   r   TF)r?   r   ndarrayr'   rD   r   )rR   indexr   r   r   
_is_scalar  s   zIndexMixin._is_scalarc                 C   s   | j \}}t|\}}| |r| }| |r| }t|tr)t||d}nt|ts4| ||}t|trCt||d}||fS t|tsN| ||}||fS )NrT   column)	rM   _unpack_indexr_   itemr?   r@   _normalize_indexrB   
_asindices)rR   rS   MNrT   rU   r   r   r   r>     s    






zIndexMixin._parse_indicesc              
   C   sN   zt j|| jjd}W n tttfy   tdw |jdvr#td|| S )a  Convert `idx` to a valid index for an axis with a given length.
        Subclasses that need special validation can override this method.

        idx is assumed to be at least a 1-dimensional array-like, but can
        have no more than 2 dimensions.
        r   zinvalid index)r      zIndex dimension must be <= 2)	r   r-   rX   r   r(   	TypeErrorMemoryErrorrF   rD   )rR   idxlengthr2   r   r   r   rd     s   
zIndexMixin._asindicesc                 C   s&   | j \}}t||d}| |tdS )zReturn a copy of row i of the matrix, as a (1 x n) row vector.

        Args:
            i (integer): Row

        Returns:
            cupyx.scipy.sparse.spmatrix: Sparse matrix with single row
        r^   N)rM   rc   rC   rB   rR   r0   re   rf   r   r   r   getrow     
	zIndexMixin.getrowc                 C   s&   | j \}}t||d}| td|S )zReturn a copy of column i of the matrix, as a (m x 1) column vector.

        Args:
            i (integer): Column

        Returns:
            cupyx.scipy.sparse.spmatrix: Sparse matrix with single column
        r^   N)rM   rc   rG   rB   rl   r   r   r   getcol  rn   zIndexMixin.getcolc                 C      t  Nr=   rR   rT   rU   r   r   r   rA        zIndexMixin._get_intXintc                 C   rp   rq   rr   rs   r   r   r   rE     rt   zIndexMixin._get_intXarrayc                 C   rp   rq   rr   rs   r   r   r   rC     rt   zIndexMixin._get_intXslicec                 C   rp   rq   rr   rs   r   r   r   rG     rt   zIndexMixin._get_sliceXintc                 C   rp   rq   rr   rs   r   r   r   rI     rt   zIndexMixin._get_sliceXslicec                 C   rp   rq   rr   rs   r   r   r   rJ     rt   zIndexMixin._get_sliceXarrayc                 C   rp   rq   rr   rs   r   r   r   rK     rt   zIndexMixin._get_arrayXintc                 C   rp   rq   rr   rs   r   r   r   rL      rt   zIndexMixin._get_arrayXslicec                 C   rp   rq   rr   rs   r   r   r   rN   #  rt   zIndexMixin._get_columnXarrayc                 C   rp   rq   rr   rs   r   r   r   rQ   &  rt   zIndexMixin._get_arrayXarrayc                 C   rp   rq   rr   rR   rT   rU   r2   r   r   r   rW   )  rt   zIndexMixin._set_intXintc                 C   rp   rq   rr   ru   r   r   r   rZ   ,  rt   zIndexMixin._set_arrayXarrayc                 C   s6   t j| | jd}t ||\}}| ||| d S )Nr   )r   r-   Ztoarrayr   rO   rZ   )rR   rT   rU   r2   r[   r   r   r   rY   /  s   z"IndexMixin._set_arrayXarray_sparseN)__name__
__module____qualname____doc__rV   r\   r_   r>   rd   rm   ro   rA   rE   rC   rG   rI   rJ   rK   rL   rN   rQ   rW   rZ   rY   r   r   r   r   r7   U  s,    43r7   c                 C   s   t r	t| tjjS dS )NF)r9   r?   r;   sparser   )r^   r   r   r   _try_is_scipy_spmatrix6  s   r{   c                 C   s&  t | ttjtjfst| r| jdkr| jjdkr| 	 S t
| } t | trEt| dkr1| \}}n:t| dkrA| d td}}n*tdt| }|du rU| td}}n|jdk rbt|tdfS |jdkrk|	 S t|sst|rwtdt|}t|}|durt|}|durt|}||fS )a   Parse index. Always return a tuple of the form (row, col).
    Valid type for row/col is integer, slice, or array of integers.

    Returns:
          resulting row & col indices : single integer, slice, or
          array of integers. If row & column indices are supplied
          explicitly, they are used as the major/minor indices.
          If only one index is supplied, the minor index is
          assumed to be all (e.g., [maj, :]).
    rg   br   r   Nzinvalid number of indiceszoIndexing with sparse matrices is not supported except boolean indexing where matrix and index are equal shapes.)r?   r   r   r]   r'   r{   rD   r   kindZnonzero_eliminate_ellipsistuplelenrB   rF   _compatible_boolean_index_boolean_index_to_arrayr   )r^   rT   rU   rj   Zbool_rowZbool_colr   r   r   ra   <  s@   




ra   c                 C   s   | t u rtdtdfS t| ts| S t| D ]\}}|t u r#|} nq| S t| dkr4tdtdfS t| dkr\|dkrT| d t u rLtdtdfS td| d fS | d tdfS g }| |d d D ]}|t urq|| qf|t| }tdd| }| d| tdf|  t| S )z6Process indices with Ellipsis. Returns modified index.Nr   rg   r   )EllipsisrB   r?   r   	enumerater   appendmax)r^   r1   vZfirst_ellipsistailndZnslicer   r   r   r~   r  s4   

"r~   c                 C   s6   | | k s	| |krt d|| | dk r| |7 } | S )Nz{} ({}) out of ranger   )rF   format)r2   dimnamer   r   r   rc     s
   rc   rg   c                 C   sZ   |dk rdS zt | dkr| d nd}W n
 ty   Y dS w t|tr&dS t||d S )zQReturns True if first element of the incompatible
    array type is boolean.
    r   Nr   T)r   rh   r?   _bool_scalar_types_first_element_bool)rj   Zmax_dimfirstr   r   r   r     s   
r   c                 C   s8   t | dr| jjdkr| S dS t| rtj| ddS dS )ztReturns a boolean index array that can be converted to
    integer array. Returns None if no such array exists.
    rD   r|   r+   r   N)hasattrr   r}   r   r   r-   rj   r   r   r   r     s   
r   c                 C   s0   | j dkr	tdtj| | jd} t| d S )Nr   zinvalid index shaper   r   )rD   rF   r   arrayr   wherer   r   r   r   r     s   
r   )r   )rg   )+ry   r   r   Zcupyx.scipy.sparse._baser   r   Z	cupy.cudar   r   r'   Zscipy.sparser;   r9   ImportErrorr	   integerint_r@   r+   Zbool_r   ZElementwiseKernelZ_compress_getitem_kernZ_compress_getitem_complex_kernr   r   r!   r%   r    r4   Z_insert_many_populate_arraysr/   r6   r5   objectr7   r{   ra   r~   rc   r   r   r   r   r   r   r   <module>   s    

X_
 b6$
