o
    pi(                     @  s  d dl mZ d dlmZmZmZ d dlZd dlmZ ddlm	Z	 ddl
mZ er:d dlmZ d d	lmZ d d
lmZ d dlmZ d dlmZ g Zeh dddd	 dVdWddZG dd deZG dd deZ	dXdYd$d%Zed&d'hd(d)d	*			dZd[d.d/ZG d0d1 d1ejZd\d7d8Zd]d9d:Zd]d;d<Z	dXdYd=d%Zed&d'hd>d?ddd@d^dBdCZed&d'hdDdEddd@d^dFdGZG dHdI dIeZ ed&d'hdJdKd		d_dd@d`dPdQZ!ed&d'hdRdSd		d_dd@d`dTdUZ"dS )a    )annotations)TYPE_CHECKINGAny
NamedTupleN)_C_ops   )Variable)in_dynamic_mode)Sequence)Tensor)Size2)nn)ForbidKeywordsDecorator>   axisZnum_or_sectionsnamexzpaddle.compat.splitzpaddle.splitZillegal_keys	func_nameZcorrect_nametensorr   split_size_or_sectionsint | Sequence[int]dimintreturntuple[Tensor, ...]c                   s   fdd}ddd}t  ttfr:t D ]#\}}d}t |tr)t|d}n|}|dk r9td	| d
| qt rt |trG|d}|t	| j
 dksTJ d|dk r_|t	| j
 n|}t  ttfrtj rt D ]\}}	t |	tr |   |< qrnt  tstdt  dt  tr dksJ d| || j
| t  trtt|  |S tt|  |S tt|  |S t |tjjrtdt |trt	| j
| dksJ d|dk rt	| j
| n|}| j
}
t  tttfstdt  trA dksJ d| || j
| t  tr8tj r/tj  tt|  |S tt|  |S t |tr[|
| dkr[t	 |
| ks[J dtj rhtj  tt|  |S )a	  
    (PyTorch Compatible API) Split the input tensor into multiple sub-Tensors.

    Args:
        tensor (Tensor): A N-D Tensor. The data type is bool, bfloat16, float16, float32, float64, uint8, int8, int32 or int64.
        split_size_or_sections (int|list|tuple):
            If split_size_or_sections is an integer type, then tensor will be split into equally sized chunks (if possible).
            Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size.
            If split_size_or_sections is a list, then tensor will be split into len(split_size_or_sections) chunks with sizes
            in dim according to split_size_or_sections. Negative inputs are not allowed. For example: for a dim with 9 channels,
            [2, 3, -1] will not be interpreted as [2, 3, 4], but will be rejected and an exception will be thrown.
        dim (int|Tensor, optional): The dim along which to split, it can be a integer or a ``0-D Tensor``
            with shape [] and data type  ``int32`` or ``int64``.
            If :math::`dim < 0`, the dim to split along is :math:`rank(x) + dim`. Default is 0.
    Returns:
        tuple(Tensor), The tuple of segmented Tensors.

    Note:
        This is a pytorch compatible API that follows the function signature and behavior of torch.split.
        To use the original split of paddle, please consider `paddle.split`

    Examples:

        .. code-block:: python

            >>> import paddle

            >>> # x is a Tensor of shape [3, 8, 5]
            >>> x = paddle.rand([3, 8, 5])

            >>> out0, out1, out2 = paddle.compat.split(x, split_size_or_sections=3, dim=1)
            >>> print(out0.shape)
            [3, 3, 5]
            >>> print(out1.shape)
            [3, 3, 5]
            >>> print(out2.shape)
            [3, 2, 5]

            >>> out0, out1, out2 = paddle.compat.split(x, split_size_or_sections=[1, 2, 5], dim=1)
            >>> print(out0.shape)
            [3, 1, 5]
            >>> print(out1.shape)
            [3, 2, 5]
            >>> print(out2.shape)
            [3, 5, 5]

            >>> # dim is negative, the real dim is (rank(x) + dim)=1
            >>> out0, out1, out2 = paddle.compat.split(x, split_size_or_sections=3, dim=-2)
            >>> print(out0.shape)
            [3, 3, 5]
            >>> print(out1.shape)
            [3, 3, 5]
            >>> print(out2.shape)
            [3, 2, 5]
    c                   s@   |  }|  }|dkr|S  fddt |D }|| |S )Nr   c                   s   g | ]} qS  r   ).0_r   r   [/home/app/PaddleOCR-VL/.venv_paddleocr/lib/python3.10/site-packages/paddle/tensor/compat.py
<listcomp>o   s    z/split.<locals>.GetSplitSize.<locals>.<listcomp>)rangeappend)Z
split_sizeZshape_on_dimZremaining_numZnum_complete_sectionsectionsr   r   r   GetSplitSizei   s   

zsplit.<locals>.GetSplitSizer   r   r   c                 S  sF   t | }t|tr|| k s||krtd| d| d| | | S )Nz:(InvalidArgument) The dim is expected to be in range of [-, z), but got )len
isinstancer   
ValueError)shaper   Zshape_ranger   r   r   GetShapeOnDimInRangeu   s   
z#split.<locals>.GetShapeOnDimInRanger   zWpaddle.compat.split expects split_sizes have only non-negative entries, but got size = z on dim z(rank(x) + dim) must >= 0zjThe type of 'split_size_or_sections' in split must be int, list or tuple in imperative mode, but received .z.split_size_or_sections must be greater than 0.zv'dim' is not allowed to be a pir.Value in a static graph: 
pir.Value can not be used for indexing python lists/tuples.z\The type of 'split_size_or_sections' in split must be int, list or tuple in imperative mode.zClen(split_size_or_sections) must not be more than input.shape[dim].N)r   r   r   r   )r'   listtuple	enumerater   r   itemr(   r	   r&   r)   paddleutilsZ_contain_var	TypeErrortyper   splitZsplit_with_numpirValueZget_int_tensor_list)r   r   r   r$   r*   iZsection_sizeZ	shape_valindexr/   Zinput_shaper   r   r   r4   )   s   @
	








r4   c                   @     e Zd ZU ded< ded< dS )SortRetTyper   valuesindicesN__name__
__module____qualname____annotations__r   r   r   r   r:         
 r:   c                   @  r9   )MinMaxRetTyper   r;   r<   Nr=   r   r   r   r   rC      rB   rC   Fout-Tensor | tuple[Tensor, Tensor] | list[Tensor]expect_multipleboolc                 C     | d u rd S t  std|rKt| ttfrt| dkr&tdt|  dt| d tj	r6t| d tj	sItdt| d  dt| d  d	d S t| tj	s[td
t|  dd S NzlUsing `out` static graph CINN backend is currently not supported. Directly return the tensor tuple instead.
r   z-Expected a list or tuple of two tensors, got z	 instead.r      z-Expected Tensor type in the tuple/list, got (r%   z
) instead.zExpected a Tensor, got 
r	   RuntimeErrorr'   r-   r,   r&   r2   r3   r0   r   rD   rF   r   r   r   _check_out_status   ,    rN   r   r   zpaddle.compat.sortzpaddle.sortinput
descendingstablec                 C  sT   t |dd t| |||\}}|dur$t||d  t||d  t||dS )a5
  

    Sorts the input along the given dimension, and returns the sorted output and indices tensor. The default sort algorithm is ascending, if you want the sort algorithm to be descending, you must set the :attr:`descending` as True.

    Args:
        input (Tensor): An input N-D Tensor with type float32, float64, int16,
            int32, int64, uint8, float16, bfloat16
        dim (int, optional): Dimension to compute indices along. The effective range
            is [-R, R), where R is Rank(x). when dim<0, it works the same way
            as dim+R. Default is -1.
        descending (bool, optional) : Descending is a flag, if set to true,
            algorithm will sort by descending order, else sort by
            ascending order. Default is false.
        stable (bool, optional): Whether to use stable sorting algorithm or not.
            When using stable sorting algorithm, the order of equivalent elements
            will be preserved. Default is False.
        out (tuple, optional) : the output tuple/list of (Tensor, Tensor) that
            can be optionally given to be used as output buffers

    Returns:
        SortRetType, a named tuple which contains `values` and `indices`, can be accessed through either indexing
        (e.g. `result[0]` for values and `result[1]` for indices), or by `result.values` & `result.indices`

    Examples:

    .. code-block:: python

            >>> import paddle

            >>> x = paddle.to_tensor([[5,8,9,5],
            ...                       [0,0,1,7],
            ...                       [6,9,2,4]],
            ...                      dtype='float32')
            >>> out1 = paddle.compat.sort(input=x, dim=-1)
            >>> out2 = paddle.compat.sort(x, 1, descending=True)
            >>> out1
            SortRetType(values=Tensor(shape=[3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
                   [[5., 5., 8., 9.],
                    [0., 0., 1., 7.],
                    [2., 4., 6., 9.]]), indices=Tensor(shape=[3, 4], dtype=int64, place=Place(cpu), stop_gradient=True,
                   [[0, 3, 1, 2],
                    [0, 1, 2, 3],
                    [2, 3, 0, 1]]))
            >>> out2
            SortRetType(values=Tensor(shape=[3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
                   [[9., 8., 5., 5.],
                    [7., 1., 0., 0.],
                    [9., 6., 4., 2.]]), indices=Tensor(shape=[3, 4], dtype=int64, place=Place(cpu), stop_gradient=True,
                   [[2, 1, 0, 3],
                    [3, 2, 0, 1],
                    [1, 0, 3, 2]]))
    T)rF   Nr   rJ   r;   r<   )rN   r   Zargsortr0   assignr:   )rQ   r   rR   rS   rD   Zoutputsr<   r   r   r   sort  s   @rV   c                      sh   e Zd ZU dZded< ded< ded< ded< eh ddd	d
			dd fddZdddZ  ZS )Unfolda  
    A compatible version of paddle.nn.Unfold:

    The keyword arguments are in non-plural forms, example: `kernel_size` instead of `kernel_sizes`. `padding` restricts the size of the input to be 1(int) or 2, Size4 is not allowed.

    All the input parameters allow `Tensor` or `pir.Value` as inputs, and will be converted to lists. Other aspects are the same. To use a more input-flexible version of Unfold, please refer to `paddle.nn.Unfold`.

    Args:
        kernel_size(int|list|tuple|Tensor): The size of convolution kernel, should be [k_h, k_w]
            or an integer k treated as [k, k].
        stride(int|list|tuple|Tensor, optional): The strides, should be [stride_h, stride_w]
            or an integer stride treated as [sride, stride]. For default, strides will be [1, 1].
        padding(int|list|tuple|Tensor, optional): The paddings of each dimension, should be
            a single integer or [padding_h, padding_w]. If [padding_h, padding_w] was given, it will expanded to
            [padding_h, padding_w, padding_h, padding_w]. If an integer padding was given,
            [padding, padding, padding, padding] will be used. By default, paddings will be 0.
        dilation(int|list|tuple|Tensor, optional): The dilations of convolution kernel, should be
            [dilation_h, dilation_w], or an integer dilation treated as [dilation, dilation].
            For default, it will be [1, 1].

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> x = paddle.randn((100, 3, 224, 224))
            >>> unfold = paddle.compat.Unfold(kernel_size=[3, 3])
            >>> result = unfold(x)
            >>> print(result.shape)
            [100, 27, 49284]
    r   kernel_sizes	dilationspaddingsstrides>   rX   rZ   rY   r[   zpaddle.compat.Unfoldzpaddle.nn.Unfoldr   rJ   r   kernel_sizedilationpaddingstrider   Nonec                   s   t  |||| d S N)super__init__)selfr\   r]   r^   r_   	__class__r   r   rc   o  s   zUnfold.__init__rQ   r   c                 C  s@   ddd}t jj||| j|| j|| jdd|| j| jdS )NFc                 S  sv   | }t  rt| tjjtjfr|  }nt| ttt	fs t
d|r9t|ttfr9t|dkr9tdt| d|S )Nz^paddle.compat.Unfold does not allow paddle.Tensor or pir.Value as inputs in static graph mode.r   zOThe `padding` field of paddle.compat.Unfold can only have size 1 or 2, now len=z2. 
Did you mean to use paddle.nn.Unfold() instead?)r	   r'   r0   r5   r6   r   tolistr,   r-   r   r2   r&   r(   )r   
size_checkresr   r   r   to_list_if_necessary~  s   
z,Unfold.forward.<locals>.to_list_if_necessaryT)rh   )rX   r[   rZ   rY   r   F)r   Z
functionalZunfoldrX   r[   rZ   rY   r   )rd   rQ   rj   r   r   r   forward}  s   
zUnfold.forward)rJ   r   rJ   )
r\   r   r]   r   r^   r   r_   r   r   r`   )rQ   r   r   r   )	r>   r?   r@   __doc__rA   r   rc   rl   __classcell__r   r   re   r   rW   J  s    
 	rW   r   strargsr   kwargsc                   sV  d fdd	fdd}d }d}t  }|t  }|dkr% |dkrX|dkr2 \}}n|dkr? d	 }|d
}n|d}|d
}|d u sTt|ttjjfrW n/|dkr|rc d	 }ndv rld }ndv rd }t|ttjjfs |d u r |d urt|ttjjfst|turd dt| d||fS )N c                   s\   dd  D }| dd  D  d|}d d|  d| d d	 d
}t|S )Nc                 S  s   g | ]}t |jqS r   r3   r>   )r   vr   r   r   r      s    zO_min_max_param_checker.<locals>.invalid_arguments_exception.<locals>.<listcomp>c                 S  s$   g | ]\}}| d t |j qS )=rs   )r   krt   r   r   r   r      s   $ r%   z%Invalid arguments for `paddle.compat.z`:
zGot: (paddle.Tensor input, z;), but expect one of:
 - (input: paddle.Tensor) for reduce_zL on all dims.
 - (input: paddle.Tensor, other: paddle.Tensor) -> see paddle.zOimum
 - (input: paddle.Tensor, int dim (cannot be None), bool keepdim = False)
)extenditemsjoinr2   )Zerror_prefixZ	type_strs	signature	error_msg)rp   r   rq   r   r   invalid_arguments_exception  s   
z;_min_max_param_checker.<locals>.invalid_arguments_exceptionc                   s*   d }z|  }W |S  t y     d w ra   )KeyError)keyri   )r|   rq   r   r   try_get_keys  s   
z,_min_max_param_checker.<locals>.try_get_keysFr   rJ   r   keepdimr   otherzBThe second input must be int or Tensor or implicit None in compat.z, but received z.
)rr   )r&   r'   r   r0   r5   r6   r3   r   )r   rp   rq   r   dim_or_otherr   num_argsZtotal_arg_numr   )rp   r   r|   rq   r   _min_max_param_checker  sP   




r   c                 C  sL   | j }|tjks|tjks|tjks|tjkr$| js"td| ddS dS )z@Prevent integral input tensor type to have `stop_gradient=False`zTensors with integral type: 'z' should stop gradient.N)dtyper0   Zint32int64Zuint8int16stop_gradientr2   rQ   Zin_dtyper   r   r   _min_max_tensor_allow_grad  s   




r   c                 C  s8   | j }|tjks|tjks|tjkrtd| ddS )zPpaddle.min/argmin(max/argmax), paddle.take_along_axis reject the following typesz*Non-CUDA GPU placed Tensor does not have 'zZ' op registered.
Paddle support following DataTypes: int32, int64, float64, float32, uint8N)r   r0   Zfloat16Zbfloat16r   r2   r   r   r   r   _min_max_allow_cpu_composite  s   



r   c                 C  rH   rI   rK   rM   r   r   r   rN     rO   zpaddle.compat.minz
paddle.min)rD   Tensor | MinMaxRetTypec                O    t | tjjtjfstdt| j dt|  t	dg|R i |\}}d}|du r7t
|d t| }not |trt
|d | jrt rx| j sxt|  tj| |dd}tj| ||d}|rit||d	}n=t|j|d|j|dd	}n.t| ||d\}	}
d|
_t|	|
d	}nt| tjg tj| jd
d	}nt
|d t| |}|durt |trt|j|d  t|j|d  |S t|| |S )a  

    Computes the minimum of tensor elements. There are mainly 3 cases (functionalities):

    1. paddle.compat.min(input: Tensor): reduce min over all dims, return a single value Tensor
    2. paddle.compat.min(input: Tensor, dim: int (cannot be None), keepdim=False): reduce min over the given dim,
        returns a named tuple MinMaxRetType(values: Tensor, indices: Tensor)
    3. paddle.compat.min(input: Tensor, other: Tensor): see `paddle.minimum`

    Special warning: the gradient behavior is NOT well-documented by PyTorch, the actual behavior should be:

    1. Case 1: the same as `min`
    2. Case 2: NOT evenly distributing the gradient for equal minimum elements! PyTorch actually only propagates to the elements with indices,
        for example: Tensor([1, 1, 1]) -> min(..., dim=0) -> values=Tensor(0, ...), indices=Tensor(0), the gradient for input tensor won't be
        Tensor([1/3, 1/3, 1/3]) as stated in their documentation, but will be Tensor([1, 0, 0]). This API implements a similar backward kernel.
    3. Case 3: the same as `minimum`

    Args:
        input (Tensor): A tensor, the data type is bfloat16, float16, float32, float64, int32, int64 on GPU.
            uint8, int32, int64, float32, float64 are allowed on CPU.
        dim (int, optional): The dim along which the minimum is computed.
            If this is not specified: see case 1, note that: `None` cannot be passed to this (TypeError will be thrown)
            compute the minimum over all elements of `input` and return a Tensor with a single element,
            otherwise must be in the range :math:`[-input.ndim, input.ndim)`.
            If :math:`dim < 0`, the axis to reduce is :math:`input.ndim + dim`.
            Warning: if `dim` is specified, execute static graph will throw exceptions
            when not on a GPU device, since max_with_index is not implemented for non-GPU devices
        keepdim (bool, optional): Whether to reserve the reduced dimension in the
            output Tensor. The result tensor will have one fewer dimension
            than the `input` unless :attr:`keepdim` is true, default
            value is False. Note that if `dim` does not appear in neither (`*args`) or (`**kwargs`), this parameter cannot be passed alone
        other (Tensor, optional): the other tensor to perform `paddle.minimum` with. This Tensor should
            have the same or broadcast-able shape as the `input`. Note that (`dim` & `keepdim`) and `other` are mutually exclusive
            meaning that trying to composite both will result in TypeError
        out (Tensor|tuple[Tensor, Tensor], optional): the output Tensor or tuple of (Tensor, int64 Tensor) that can be optionally
            given to be used as output buffers. For case 1 and 3 out is just a Tensor, while for case 2 we expect a tuple


    Returns:
        - For case 1. A single value Tensor (0-dim)
        - For case 2. A named tuple MinMaxRetType(values: Tensor, indices: Tensor), `values` has the same data type as the `input`,
            while indices is always an int64 Tensor, with exactly the same shape as `values`.
            MinMaxRetType can be used (indexed, packed, unpacked) in the same way as a regular tuple
        - For case 3. See `paddle.minimum` (:ref:`api_paddle_minimum`)


    Examples:
        .. code-block:: python

            >>> import paddle

            >>> # data_x is a Tensor with shape [2, 4]
            >>> # the axis is a int element
            >>> x = paddle.to_tensor([[0.2, 0.3, 0.5, 0.9],
            ...                       [0.1, 0.2, 0.6, 0.7]],
            ...                       dtype='float64', stop_gradient=False)
            >>> # Case 1: reduce over all dims
            >>> result1 = paddle.compat.min(x)
            >>> result1
            Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=False,
            0.10000000)

            >>> # Case 2: reduce over specified dim
            >>> x.clear_grad()
            >>> result2 = paddle.compat.min(x, dim=1)
            >>> result2
            MinMaxRetType(values=Tensor(shape=[2], dtype=float64, place=Place(gpu:0), stop_gradient=False,
                [0.20000000, 0.10000000]), indices=Tensor(shape=[2], dtype=int64, place=Place(gpu:0), stop_gradient=True,
                [0, 0]))
            >>> result2[0].backward()
            >>> x.grad
            Tensor(shape=[2, 4], dtype=float64, place=Place(gpu:0), stop_gradient=False,
                [[1., 0., 0., 0.],
                 [1., 0., 0., 0.]])

            >>> # Case 3: equivalent to `paddle.minimum`
            >>> x.clear_grad()
            >>> y = paddle.to_tensor([[0.5, 0.4, 0.1, 0.2],
            ...                       [0.3, 0.1, 0.6, 0.7]],
            ...                       dtype='float64', stop_gradient=False)
            >>> result3 = paddle.compat.min(x, y)
            >>> result3
            Tensor(shape=[2, 4], dtype=float64, place=Place(gpu:0), stop_gradient=False,
                [[0.20000000, 0.30000000, 0.10000000, 0.20000000],
                 [0.10000000, 0.10000000, 0.60000000, 0.70000000]])
    9input should be a tensor, but got an instance with type ''minNFTr   r   r   rT   r   Zdevicer   rJ   )r'   r0   r5   r6   r   r2   r3   r>   r   r   rN   r   r   ndimr	   placeis_gpu_placer   Zargmintake_along_axisrC   squeeze_r   Zmin_with_indexr   zerosr   minimumrU   r;   r<   rQ   rD   rp   rq   r   r   retr<   r;   valsZindsr   r   r   r     sZ   a







r   zpaddle.compat.maxz
paddle.maxc                O  r   )a  

    Computes the maximum of tensor elements. There are mainly 3 cases (functionalities):

    1. paddle.compat.max(input: Tensor): reduce max over all dims, return a single value Tensor
    2. paddle.compat.max(input: Tensor, dim: int (cannot be None), keepdim=False): reduce max over the given dim,
        returns a named tuple MinMaxRetType(values: Tensor, indices: Tensor)
    3. paddle.compat.max(input: Tensor, other: Tensor): see `paddle.maximum`

    Special warning: the gradient behavior is NOT well-documented by PyTorch, the actual behavior should be:

    1. Case 1: the same as `max`
    2. Case 2: NOT evenly distributing the gradient for equal maximum elements! PyTorch actually only propagates to the elements with indices,
        for example: Tensor([1, 1, 1]) -> max(..., dim=0) -> values=Tensor(0, ...), indices=Tensor(0), the gradient for input tensor won't be
        Tensor([1/3, 1/3, 1/3]) as stated in their documentation, but will be Tensor([1, 0, 0]). This API implements a similar backward kernel.
    3. Case 3: the same as `maximum`

    Args:
        input (Tensor): A tensor, the data type is bfloat16, float16, float32, float64, int32, int64 on GPU.
            uint8, int32, int64, float32, float64 are allowed on CPU.
        dim (int, optional): The dim along which the maximum is computed.
            If this is not specified: see case 1, note that: `None` cannot be passed to this (TypeError will be thrown)
            compute the maximum over all elements of `input` and return a Tensor with a single element,
            otherwise must be in the range :math:`[-input.ndim, input.ndim)`.
            If :math:`dim < 0`, the axis to reduce is :math:`input.ndim + dim`.
            Warning: if `dim` is specified, execute static graph will throw exceptions
            when not on a GPU device, since max_with_index is not implemented for non-GPU devices
        keepdim (bool, optional): Whether to reserve the reduced dimension in the
            output Tensor. The result tensor will have one fewer dimension
            than the `input` unless :attr:`keepdim` is true, default
            value is False. Note that if `dim` does not appear in neither (`*args`) or (`**kwargs`), this parameter cannot be passed alone
        other (Tensor, optional): the other tensor to perform `paddle.maximum` with. This Tensor should
            have the same or broadcast-able shape as the `input`. Note that (`dim` & `keepdim`) and `other` are mutually exclusive
            meaning that trying to composite both will result in TypeError
        out (Tensor|tuple[Tensor, Tensor], optional): the output Tensor or tuple of (Tensor, int64 Tensor) that can be optionally
            given to be used as output buffers. For case 1 and 3 out is just a Tensor, while for case 2 we expect a tuple


    Returns:
        - For case 1. A single value Tensor (0-dim)
        - For case 2. A named tuple MinMaxRetType(values: Tensor, indices: Tensor), `values` has the same data type as the `input`,
            while indices is always an int64 Tensor, with exactly the same shape as `values`.
            MinMaxRetType can be used (indexed, packed, unpacked) in the same way as a regular tuple
        - For case 3. See `paddle.maximum` (:ref:`api_paddle_maximum`)


    Examples:
        .. code-block:: python

            >>> import paddle

            >>> # data_x is a Tensor with shape [2, 4]
            >>> # the axis is a int element
            >>> x = paddle.to_tensor([[0.2, 0.3, 0.5, 0.9],
            ...                       [0.1, 0.2, 0.6, 0.7]],
            ...                       dtype='float64', stop_gradient=False)
            >>> # Case 1: reduce over all dims
            >>> result1 = paddle.compat.max(x)
            >>> result1
            Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=False,
            0.90000000)

            >>> # Case 2: reduce over specified dim
            >>> x.clear_grad()
            >>> result2 = paddle.compat.max(x, dim=1)
            >>> result2
            MinMaxRetType(values=Tensor(shape=[2], dtype=float64, place=Place(gpu:0), stop_gradient=False,
                [0.90000000, 0.70000000]), indices=Tensor(shape=[2], dtype=int64, place=Place(gpu:0), stop_gradient=True,
                [3, 3]))
            >>> result2[0].backward()
            >>> x.grad
            Tensor(shape=[2, 4], dtype=float64, place=Place(gpu:0), stop_gradient=False,
                [[0., 0., 0., 1.],
                 [0., 0., 0., 1.]])

            >>> # Case 3: equivalent to `paddle.maximum`
            >>> x.clear_grad()
            >>> y = paddle.to_tensor([[0.5, 0.4, 0.1, 0.2],
            ...                       [0.3, 0.1, 0.6, 0.7]],
            ...                       dtype='float64', stop_gradient=False)
            >>> result3 = paddle.compat.max(x, y)
            >>> result3
            Tensor(shape=[2, 4], dtype=float64, place=Place(gpu:0), stop_gradient=False,
                [[0.50000000, 0.40000000, 0.50000000, 0.90000000],
                 [0.30000000, 0.20000000, 0.60000000, 0.70000000]])
    r   r   maxNFTr   r   rT   r   r   rJ   )r'   r0   r5   r6   r   r2   r3   r>   r   r   rN   r   r   r   r	   r   r   r   Zargmaxr   rC   r   r   Zmax_with_indexr   r   r   maximumrU   r;   r<   r   r   r   r   r     sZ   a







r   c                   @  r9   )MedianRetTyper   r;   r<   Nr=   r   r   r   r   r   F  rB   r   zpaddle.compat.medianzpaddle.median
int | Noner   %tuple[Tensor, Tensor] | Tensor | NoneTensor | MedianRetTypec                C  s   |du r t |d tj| ||dd}|durt|| |S |S t |d tj| ||dd\}}|durNt||d  t||d  t|d |d dS t||dS )	a+  
    Returns the median of the values in input.

    Args:
        input (Tensor): The input tensor.
        dim (int|None, optional): The dimension to reduce. If None, computes the median over all elements. Default is None.
        keepdim (bool, optional): Whether the output tensor has dim retained or not. Default is False.
        out (Tensor|tuple[Tensor, Tensor], optional): If provided, the result will be written into this tensor.
            For global median (dim=None), out must be a single tensor.
            For median along a dimension (dim specified, including dim=-1), out must be a tuple of two tensors (values, indices).

    Returns:
        Tensor|MedianRetType: If dim is None, returns a single tensor. If dim is specified (including dim=-1),
        returns a named tuple MedianRetType(values: Tensor, indices: Tensor).

    Examples:
        .. code-block:: python

            >>> import paddle

            >>> x = paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
            >>> result = paddle.compat.median(x)
            >>> print(result)
            Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True, 5)

            >>> ret = paddle.compat.median(x, dim=1)
            >>> print(ret.values)
            Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True, [2, 5, 8])
            >>> print(ret.indices)
            Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True, [1, 1, 1])

            >>> # Using out parameter
            >>> out_values = paddle.zeros([3], dtype='int64')
            >>> out_indices = paddle.zeros([3], dtype='int64')
            >>> paddle.compat.median(x, dim=1, out=(out_values, out_indices))
            >>> print(out_values)
            Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True, [2, 5, 8])
    NFr   r   r   modeTr   rJ   rT   )rN   r0   medianrU   r   rQ   r   r   rD   resultr;   r<   r   r   r   r   K  s    2


r   zpaddle.compat.nanmedianzpaddle.nanmedianc                C  s   |du r t |d tj| ||dd}|durt|| |S |S t |d tj| ||dd\}}t|t|}|durWt||d  t||d  t|d |d dS t||dS )	a  
    Returns the median of the values in input, ignoring NaN values.

    Args:
        input (Tensor): The input tensor.
        dim (int|None, optional): The dimension to reduce. If None, computes the nanmedian over all elements. Default is None.
        keepdim (bool, optional): Whether the output tensor has dim retained or not. Default is False.
        out (Tensor|tuple[Tensor, Tensor], optional): If provided, the result will be written into this tensor.
            For global nanmedian (dim=None), out must be a single tensor.
            For nanmedian along a dimension (dim specified, including dim=-1), out must be a tuple of two tensors (values, indices).

    Returns:
        Tensor|MedianRetType: The median values, ignoring NaN. If dim is None, returns a single tensor. If dim is specified (including dim=-1),
        returns a named tuple MedianRetType(values: Tensor, indices: Tensor).

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> import numpy as np

            >>> x = paddle.to_tensor([[1, float('nan'), 3], [4, 5, 6], [float('nan'), 8, 9]], dtype='float32')
            >>> result = paddle.compat.nanmedian(x)
            >>> print(result)
            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, 5.0)

            >>> ret = paddle.compat.nanmedian(x, dim=1)
            >>> print(ret.values)
            Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True, [1.0, 5.0, 8.0])
            >>> print(ret.indices)
            Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True, [0, 1, 1])

            >>> # Using out parameter
            >>> out_values = paddle.zeros([3], dtype='float32')
            >>> out_indices = paddle.zeros([3], dtype='int64')
            >>> paddle.compat.nanmedian(x, dim=1, out=(out_values, out_indices))
            >>> print(out_values)
            Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True, [1.0, 5.0, 8.0])
    NFr   r   Tr   rJ   rT   )rN   r0   	nanmedianrU   r   Z
zeros_liker   r   r   r   r   r     s"   3


r   )r   )r   r   r   r   r   r   r   r   rk   )rD   rE   rF   rG   )rP   FFN)
rQ   r   r   r   rR   rG   rS   rG   r   r:   )r   ro   rp   r   rq   r   )rQ   r   )
rQ   r   rp   r   rD   rE   rq   r   r   r   )NF)
rQ   r   r   r   r   rG   rD   r   r   r   )#
__future__r   typingr   r   r   r0   r   Zbase.frameworkr   Z	frameworkr	   collections.abcr
   r   Zpaddle._typingr   r   Zpaddle.utils.decorator_utilsr   __all__r4   r:   rC   rN   rV   rW   r   r   r   r   r   r   r   r   r   r   r   r   <module>   s    0C
O
D
  @