# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing

from paddle.distribution import distribution, independent, transform


class TransformedDistribution(distribution.Distribution):
    r"""
    Applies a sequence of Transforms to a base distribution.

    Args:
        base (Distribution): The base distribution.
        transforms (Sequence[Transform]): A sequence of ``Transform`` .

    Examples:

        .. code-block:: python

            import paddle
            from paddle.distribution import transformed_distribution

            d = transformed_distribution.TransformedDistribution(
                paddle.distribution.Normal(0., 1.),
                [paddle.distribution.AffineTransform(paddle.to_tensor(1.), paddle.to_tensor(2.))]
            )

            print(d.sample([10]))
            # Tensor(shape=[10], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [-0.10697651,  3.33609009, -0.86234951,  5.07457638,  0.75925219,
            #         -4.17087793,  2.22579336, -0.93845034,  0.66054249,  1.50957513])
            print(d.log_prob(paddle.to_tensor(0.5)))
            # Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        -1.64333570)
    """

    def __init__(self, base, transforms):
        if not isinstance(base, distribution.Distribution):
            raise TypeError(
                f"Expected type of 'base' is Distribution, but got {type(base)}."
            )
        if not isinstance(transforms, typing.Sequence):
            raise TypeError(
                f"Expected type of 'transforms' is Sequence[Transform] or Chain, but got {type(transforms)}."
            )
        if not all(isinstance(t, transform.Transform) for t in transforms):
            raise TypeError("All element of transforms must be Transform type.")

        chain = transform.ChainTransform(transforms)
        base_shape = base.batch_shape + base.event_shape
        self._base = base
        self._transforms = transforms
        if not transforms:
            super().__init__(base.batch_shape, base.event_shape)
            return
        if len(base.batch_shape + base.event_shape) < chain._domain.event_rank:
            raise ValueError(
                f"'base' needs to have shape with size at least {chain._domain.event_rank}, bug got {len(base_shape)}."
            )
        if chain._domain.event_rank > len(base.event_shape):
            base = independent.Independent(
                (base, chain._domain.event_rank - len(base.event_shape))
            )

        transformed_shape = chain.forward_shape(
            base.batch_shape + base.event_shape
        )
        transformed_event_rank = chain._codomain.event_rank + max(
            len(base.event_shape) - chain._domain.event_rank, 0
        )
        super().__init__(
            transformed_shape[
                : len(transformed_shape) - transformed_event_rank
            ],
            transformed_shape[
                len(transformed_shape) - transformed_event_rank :
            ],
        )

    def sample(self, shape=()):
        """Sample from ``TransformedDistribution``.

        Args:
            shape (Sequence[int], optional): The sample shape. Defaults to ().

        Returns:
            [Tensor]: The sample result.
        """
        x = self._base.sample(shape)
        for t in self._transforms:
            x = t.forward(x)
        return x

    def rsample(self, shape=()):
        """Reparameterized sample from ``TransformedDistribution``.

        Args:
            shape (Sequence[int], optional): The sample shape. Defaults to ().

        Returns:
            [Tensor]: The sample result.
        """
        x = self._base.rsample(shape)
        for t in self._transforms:
            x = t.forward(x)
        return x

    def log_prob(self, value):
        """The log probability evaluated at value.

        Args:
            value (Tensor): The value to be evaluated.

        Returns:
            Tensor: The log probability.
        """
        log_prob = 0.0
        y = value
        event_rank = len(self.event_shape)
        for t in reversed(self._transforms):
            x = t.inverse(y)
            event_rank += t._domain.event_rank - t._codomain.event_rank
            log_prob = log_prob - _sum_rightmost(
                t.forward_log_det_jacobian(x), event_rank - t._domain.event_rank
            )
            y = x
        log_prob += _sum_rightmost(
            self._base.log_prob(y), event_rank - len(self._base.event_shape)
        )
        return log_prob


def _sum_rightmost(value, n):
    return value.sum(list(range(-n, 0))) if n > 0 else value
