from typing import List, Optional

import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.prototype.models.depth.stereo.raft_stereo import grid_sample, make_coords_grid


def make_gaussian_kernel(kernel_size: int, sigma: float) -> torch.Tensor:
    """Build a square 2D Gaussian kernel normalized by the sum of its entries.

    Args:
        kernel_size: number of taps along each side of the kernel
        sigma: standard deviation of the Gaussian

    Returns:
        float32 tensor of shape (kernel_size, kernel_size)
    """
    # tap positions measured from the kernel center; the same offsets serve both axes
    half_extent = (kernel_size - 1) / 2
    offsets = torch.arange(kernel_size, dtype=torch.float32) - half_extent
    rows, cols = torch.meshgrid(offsets, offsets, indexing="ij")

    exponent = (rows**2 + cols**2) / (2 * sigma**2)
    unnormalized = torch.exp(-exponent)
    return unnormalized / unnormalized.sum()


def _sequence_loss_fn(
    flow_preds: List[Tensor],
    flow_gt: Tensor,
    valid_flow_mask: Optional[Tensor],
    gamma: Tensor,
    max_flow: int = 256,
    exclude_large: bool = False,
    weights: Optional[Tensor] = None,
):
    """Exponentially weighted L1 loss defined over a sequence of flow predictions.

    Later predictions receive larger weights: prediction ``i`` out of ``n`` is weighted by
    ``gamma ** (n - 1 - i)``.

    Args:
        flow_preds: list of flow predictions, each of shape (batch_size, C, H, W)
        flow_gt: ground truth flow of shape (batch_size, C, H, W)
        valid_flow_mask: optional mask of valid pixels of shape (batch_size, H, W)
        gamma: base of the exponential weighting, must be lower than 1
        max_flow: ground truth flows whose norm reaches this value are masked out
            when ``exclude_large`` is set
        exclude_large: whether to mask out pixels with a large ground truth flow
        weights: optional cached per-prediction weights returned by a previous call

    Returns:
        Tuple ``(flow_loss, weights)`` with the scalar loss and the per-prediction weights
        that were used, so the caller can pass them back in on the next call.
    """
    torch._assert(
        gamma < 1,
        "sequence_loss: `gamma` must be lower than 1, but got {}".format(gamma),
    )

    if exclude_large:
        # exclude invalid pixels and extremely large displacements
        flow_norm = torch.sum(flow_gt**2, dim=1).sqrt()
        within_range = flow_norm < max_flow
        valid_flow_mask = within_range if valid_flow_mask is None else valid_flow_mask & within_range

    stacked_preds = torch.stack(flow_preds)  # shape = (num_flow_updates, batch_size, 2, H, W)

    abs_diff = (stacked_preds - flow_gt).abs()
    if valid_flow_mask is not None:
        # (batch_size, H, W) -> (1, batch_size, 1, H, W) so it broadcasts over updates and channels
        abs_diff = abs_diff * valid_flow_mask.unsqueeze(1).unsqueeze(0)

    # NOTE: the mean runs over every element, masked-out pixels included (they contribute zeros)
    abs_diff = abs_diff.mean(dim=(1, 2, 3, 4))
    num_predictions = stacked_preds.shape[0]

    # allocating on CPU and moving to device during run-time can force
    # an unwanted GPU synchronization that produces a large overhead.
    # A cached tensor is only reusable when it has the right length AND lives on the
    # same device as the predictions; otherwise the product below would fail.
    if weights is None or len(weights) != num_predictions or weights.device != stacked_preds.device:
        weights = gamma ** torch.arange(
            num_predictions - 1, -1, -1, device=stacked_preds.device, dtype=stacked_preds.dtype
        )

    flow_loss = (abs_diff * weights).sum()
    return flow_loss, weights


class SequenceLoss(nn.Module):
    """Module wrapper around `_sequence_loss_fn` that caches the per-prediction weights."""

    def __init__(self, gamma: float = 0.8, max_flow: int = 256, exclude_large_flows: bool = False) -> None:
        """
        Args:
            gamma: base of the exponential weights applied across the sequence of predictions
            max_flow: flow magnitude from which ground truth pixels are excluded
            exclude_large_flows: whether pixels with large ground truth flow are excluded
        """

        super().__init__()
        # gamma is stored as a one-element tensor registered as a module buffer
        self.register_buffer("gamma", torch.tensor([gamma]))
        self.max_flow = max_flow
        self.excluding_large = exclude_large_flows
        # per-prediction weights from the last call; rebuilt lazily by the loss function
        self._weights = None

    def forward(self, flow_preds: List[Tensor], flow_gt: Tensor, valid_flow_mask: Optional[Tensor]) -> Tensor:
        """
        Args:
            flow_preds: list of flow predictions of shape (batch_size, C, H, W)
            flow_gt: ground truth flow of shape (batch_size, C, H, W)
            valid_flow_mask: mask of valid flow pixels of shape (batch_size, H, W)

        Returns:
            the loss value produced by `_sequence_loss_fn`
        """
        loss, self._weights = _sequence_loss_fn(
            flow_preds=flow_preds,
            flow_gt=flow_gt,
            valid_flow_mask=valid_flow_mask,
            gamma=self.gamma,
            max_flow=self.max_flow,
            exclude_large=self.excluding_large,
            weights=self._weights,
        )
        return loss

    def set_gamma(self, gamma: float) -> None:
        """Overwrite gamma in place and drop the cached weights derived from the old value."""
        self._weights = None
        self.gamma.fill_(gamma)


def _ssim_loss_fn(
    source: Tensor,
    reference: Tensor,
    kernel: Tensor,
    eps: float = 1e-8,
    c1: float = 0.01**2,
    c2: float = 0.03**2,
    use_padding: bool = False,
) -> Tensor:
    """Structural similarity loss (``1 - SSIM``) between two image batches.

    Args:
        source: tensor of shape (batch_size, C, H, W)
        reference: tensor with the same shape as ``source``
        kernel: 2D window used for the local statistics; it is applied to every channel
            independently and converted to the dtype and device of ``source``
        eps: constant added to the denominator to avoid a division by zero
        c1: stabilizing constant of the luminance term
        c2: stabilizing constant of the contrast/structure term
        use_padding: whether to reflect-pad the inputs before filtering

    Returns:
        tensor of shape (batch_size,) holding ``1 - mean(SSIM map)`` for every sample
    """
    # ref: Algorithm section: https://en.wikipedia.org/wiki/Structural_similarity
    # ref: Alternative implementation: https://kornia.readthedocs.io/en/latest/_modules/kornia/metrics/ssim.html#ssim

    torch._assert(
        source.ndim == reference.ndim == 4,
        "SSIM: `source` and `reference` must be 4-dimensional tensors",
    )

    torch._assert(
        source.shape == reference.shape,
        "SSIM: `source` and `reference` must have the same shape, but got {} and {}".format(
            source.shape, reference.shape
        ),
    )

    B, C, H, W = source.shape
    # conv2d needs the window to share dtype and device with its input; a window built
    # in float32 would otherwise fail on e.g. float64 or half precision images
    kernel = kernel.to(device=source.device, dtype=source.dtype)
    # one copy of the window per channel, used as a depthwise (groups=C) filter
    kernel = kernel.unsqueeze(0).unsqueeze(0).repeat(C, 1, 1, 1)
    if use_padding:
        # the same amount, derived from the window height, is used on all four sides
        pad_size = kernel.shape[2] // 2
        source = F.pad(source, (pad_size, pad_size, pad_size, pad_size), "reflect")
        reference = F.pad(reference, (pad_size, pad_size, pad_size, pad_size), "reflect")

    # local means
    mu1 = F.conv2d(source, kernel, groups=C)
    mu2 = F.conv2d(reference, kernel, groups=C)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # local second moments
    mu_img1_sq = F.conv2d(source.pow(2), kernel, groups=C)
    mu_img2_sq = F.conv2d(reference.pow(2), kernel, groups=C)
    mu_img1_mu2 = F.conv2d(source * reference, kernel, groups=C)

    # local variances and covariance: E[x^2] - E[x]^2 and E[xy] - E[x]E[y]
    sigma1_sq = mu_img1_sq - mu1_sq
    sigma2_sq = mu_img2_sq - mu2_sq
    sigma12 = mu_img1_mu2 - mu1_mu2

    numerator = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
    denominator = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
    ssim = numerator / (denominator + eps)

    # doing 1 - ssim because we want to maximize the ssim
    return 1 - ssim.mean(dim=(1, 2, 3))


class SSIM(nn.Module):
    """Module wrapper around `_ssim_loss_fn` that owns a fixed Gaussian window."""

    def __init__(
        self,
        kernel_size: int = 11,
        max_val: float = 1.0,
        sigma: float = 1.5,
        eps: float = 1e-12,
        use_padding: bool = True,
    ) -> None:
        """SSIM loss function.

        Args:
            kernel_size: side length of the Gaussian window
            max_val: constant scaling factor applied to the stabilizing constants
            sigma: standard deviation of the Gaussian window
            eps: small constant guarding the division
            use_padding: whether to pad the inputs so that every pixel gets a score
        """
        super().__init__()

        self.kernel_size = kernel_size
        self.sigma = sigma
        self.max_val = max_val
        self.eps = eps
        self.use_padding = use_padding

        # stabilizing constants of the SSIM formula, scaled by `max_val`
        self.c1 = (0.01 * max_val) ** 2
        self.c2 = (0.03 * max_val) ** 2

        # the window is registered as a module buffer
        self.register_buffer("gaussian_kernel", make_gaussian_kernel(kernel_size, sigma))

    def forward(self, source: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """
        Args:
            source: source image of shape (batch_size, C, H, W)
            reference: reference image of shape (batch_size, C, H, W)

        Returns:
            SSIM loss of shape (batch_size,)
        """
        window = self.gaussian_kernel
        return _ssim_loss_fn(
            source,
            reference,
            window,
            eps=self.eps,
            c1=self.c1,
            c2=self.c2,
            use_padding=self.use_padding,
        )


def _smoothness_loss_fn(img_gx: Tensor, img_gy: Tensor, val_gx: Tensor, val_gy: Tensor):
    """Gradient-weighted smoothness term.

    The absolute image gradients are scaled by ``exp(-mean_c |value gradient|)`` and averaged
    over the last three dimensions.

    Args:
        img_gx: horizontal image gradients of shape (..., C, H, W)
        img_gy: vertical image gradients, same layout as ``img_gx``
        val_gx: horizontal gradients of the values, same or broadcastable shape to ``img_gx``
        val_gy: vertical gradients of the values, same layout as ``val_gx``

    Returns:
        tensor holding one smoothness value per leading index
    """
    # ref: https://github.com/nianticlabs/monodepth2/blob/b676244e5a1ca55564eb5d16ab521a48f823af31/layers.py#L202
    # NOTE(review): the referenced implementation weights the value (disparity) gradients by
    # exp(-|image gradient|); here the roles are the other way around -- confirm this is intended.

    torch._assert(
        img_gx.ndim >= 3,
        "smoothness_loss: `img_gx` must be at least 3-dimensional tensor of shape (..., C, H, W)",
    )

    torch._assert(
        img_gx.ndim == val_gx.ndim,
        "smoothness_loss: `img_gx` and `depth_gx` must have the same dimensionality, but got {} and {}".format(
            img_gx.ndim, val_gx.ndim
        ),
    )

    # every pair of sizes must either match or contain a 1 (i.e. be broadcastable)
    for img_size, val_size in zip(img_gx.shape, val_gx.shape):
        torch._assert(
            img_size == val_size or img_size == 1 or val_size == 1,
            "smoothness_loss: `img_gx` and `depth_gx` must have either the same shape or broadcastable shape, but got {} and {}".format(
                img_gx.shape, val_gx.shape
            ),
        )

    # -3 is channel dimension
    edge_weight_x = torch.exp(-val_gx.abs().mean(dim=-3, keepdim=True))
    edge_weight_y = torch.exp(-val_gy.abs().mean(dim=-3, keepdim=True))

    weighted_x = (img_gx * edge_weight_x).abs()
    weighted_y = (img_gy * edge_weight_y).abs()

    return (weighted_x + weighted_y).mean(dim=(-3, -2, -1))


class SmoothnessLoss(nn.Module):
    """Smoothness loss computed from forward differences of the images and of the values."""

    def __init__(self) -> None:
        super().__init__()

    def _x_gradient(self, img: Tensor) -> Tensor:
        """Forward difference along the last (width) dimension; the output keeps the input shape."""
        full_shape = img.shape
        needs_flatten = img.ndim > 4
        if needs_flatten:
            # collapse every leading dimension into one batch dimension before padding
            img = img.reshape(-1, *full_shape[-3:])

        # replicating the last column makes the final difference zero and preserves the width
        extended = F.pad(img, (0, 1, 0, 0), mode="replicate")
        diff = extended[..., :, :-1] - extended[..., :, 1:]
        return diff.reshape(full_shape) if needs_flatten else diff

    def _y_gradient(self, x: torch.Tensor) -> torch.Tensor:
        """Forward difference along the second to last (height) dimension; the output keeps the input shape."""
        full_shape = x.shape
        needs_flatten = x.ndim > 4
        if needs_flatten:
            # collapse every leading dimension into one batch dimension before padding
            x = x.reshape(-1, *full_shape[-3:])

        # replicating the last row makes the final difference zero and preserves the height
        extended = F.pad(x, (0, 0, 0, 1), mode="replicate")
        diff = extended[..., :-1, :] - extended[..., 1:, :]
        return diff.reshape(full_shape) if needs_flatten else diff

    def forward(self, images: Tensor, vals: Tensor) -> Tensor:
        """
        Args:
            images: tensor of shape (D1, D2, ..., DN, C, H, W)
            vals: tensor of shape (D1, D2, ..., DN, 1, H, W)

        Returns:
            smoothness loss of shape (D1, D2, ..., DN)
        """
        image_grads = (self._x_gradient(images), self._y_gradient(images))
        value_grads = (self._x_gradient(vals), self._y_gradient(vals))
        return _smoothness_loss_fn(*image_grads, *value_grads)


def _flow_sequence_consistency_loss_fn(
    flow_preds: List[Tensor],
    gamma: float = 0.8,
    resize_factor: float = 0.25,
    rescale_factor: float = 0.25,
    rescale_mode: str = "bilinear",
    weights: Optional[Tensor] = None,
):
    """Loss function defined over sequence of flow predictions.

    Penalizes the squared difference between consecutive predictions, with difference ``i``
    out of ``n`` weighted by ``gamma ** (n - 1 - i)``.

    Args:
        flow_preds: list of flow predictions, each of shape (batch_size, C, H, W)
        gamma: base of the exponential weighting across the differences
        resize_factor: spatial scale factor used to resize the predictions
        rescale_factor: factor multiplied into the resized flow values, must be at most 1;
            a falsy value (e.g. 0) skips both the resizing and the rescaling
        rescale_mode: interpolation mode handed to ``F.interpolate``
        weights: optional cached per-difference weights returned by a previous call

    Returns:
        Tuple ``(flow_loss, weights)`` with the scalar loss and the weights that were used,
        so the caller can pass them back in on the next call.
    """

    # Simplified version of ref: https://arxiv.org/pdf/2006.11242.pdf
    # In the original paper, an additional refinement network is used to refine a flow prediction.
    # Each step performed by the recurrent module in Raft or CREStereo is a refinement step using a delta_flow update.
    # which should be consistent with the previous step. In this implementation, we simplify the overall loss
    # term and ignore left-right consistency loss or photometric loss which can be treated separately.

    torch._assert(
        rescale_factor <= 1.0,
        "sequence_consistency_loss: `rescale_factor` must be less than or equal to 1, but got {}".format(
            rescale_factor
        ),
    )

    flow_preds = torch.stack(flow_preds)  # shape = (num_flow_updates, batch_size, 2, H, W)
    N, B, C, H, W = flow_preds.shape

    # rescale flow predictions to account for bilinear upsampling artifacts
    if rescale_factor:
        # `F.interpolate` rejects an explicit `align_corners` for the non-interpolating
        # modes (e.g. "nearest", "area"), so it is only requested where it is supported
        align_corners = True if rescale_mode in ("linear", "bilinear", "bicubic", "trilinear") else None
        resized = F.interpolate(
            flow_preds.reshape(N * B, C, H, W),
            scale_factor=resize_factor,
            mode=rescale_mode,
            align_corners=align_corners,
        )
        resized = resized * rescale_factor
        # split the merged leading dimension back into (num_flow_updates, batch_size)
        flow_preds = resized.reshape(N, B, C, *resized.shape[-2:])

    # force the next prediction to be similar to the previous prediction
    sq_diff = (flow_preds[1:] - flow_preds[:-1]).square()
    sq_diff = sq_diff.mean(dim=(1, 2, 3, 4))

    num_predictions = flow_preds.shape[0] - 1  # because we are comparing differences
    # a cached tensor is only reusable when it has the right length AND lives on the
    # same device as the predictions
    if weights is None or len(weights) != num_predictions or weights.device != flow_preds.device:
        weights = gamma ** torch.arange(num_predictions - 1, -1, -1, device=flow_preds.device, dtype=flow_preds.dtype)

    flow_loss = (sq_diff * weights).sum()
    return flow_loss, weights


class FlowSequenceConsistencyLoss(nn.Module):
    """Module wrapper around `_flow_sequence_consistency_loss_fn` that caches the weights."""

    def __init__(
        self,
        gamma: float = 0.8,
        resize_factor: float = 0.25,
        rescale_factor: float = 0.25,
        rescale_mode: str = "bilinear",
    ) -> None:
        """
        Args:
            gamma: base of the exponential weighting across consecutive differences
            resize_factor: spatial scale factor used to resize the predictions
            rescale_factor: factor multiplied into the resized flow values
            rescale_mode: interpolation mode used for the resizing
        """
        super().__init__()
        # gamma is kept as a plain Python float (it is not a buffer)
        self.gamma = gamma
        self.resize_factor = resize_factor
        self.rescale_factor = rescale_factor
        self.rescale_mode = rescale_mode
        # weights from the last call; rebuilt lazily by the loss function
        self._weights = None

    def forward(self, flow_preds: List[Tensor]) -> Tensor:
        """
        Args:
            flow_preds: list of tensors of shape (batch_size, C, H, W)

        Returns:
            sequence consistency loss reduced to a single scalar tensor
        """
        loss, weights = _flow_sequence_consistency_loss_fn(
            flow_preds,
            gamma=self.gamma,
            resize_factor=self.resize_factor,
            rescale_factor=self.rescale_factor,
            rescale_mode=self.rescale_mode,
            weights=self._weights,
        )
        self._weights = weights
        return loss

    def set_gamma(self, gamma: float) -> None:
        """Replace gamma and drop the cached weights derived from the old value."""
        # gamma is a float here, so it is reassigned rather than filled in place
        self.gamma = gamma
        # reset the cached scale factor
        self._weights = None


def _psnr_loss_fn(source: torch.Tensor, target: torch.Tensor, max_val: float) -> torch.Tensor:
    """Peak signal-to-noise ratio, in decibels, between ``source`` and ``target``.

    Args:
        source: tensor of shape (..., C, H, W)
        target: tensor with the same shape as ``source``
        max_val: maximum value of the input domain

    Returns:
        tensor with the last three dimensions reduced away
    """
    torch._assert(
        source.shape == target.shape,
        "psnr_loss: source and target must have the same shape, but got {} and {}".format(source.shape, target.shape),
    )

    # ref https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
    squared_error = (source - target).pow(2)
    mse = squared_error.mean(dim=(-3, -2, -1))
    peak_to_noise = max_val**2 / mse
    return 10 * torch.log10(peak_to_noise)


class PSNRLoss(nn.Module):
    """Negated PSNR, so that minimizing the loss maximizes the PSNR."""

    def __init__(self, max_val: float = 256) -> None:
        """
        Args:
            max_val: maximum value of the input tensor, i.e. the upper end of the domain
                the inputs are expected to live in.

        """
        super().__init__()
        self.max_val = max_val

    def forward(self, source: Tensor, target: Tensor) -> Tensor:
        """
        Args:
            source: tensor of shape (D1, D2, ..., DN, C, H, W)
            target: tensor of shape (D1, D2, ..., DN, C, H, W)

        Returns:
            psnr loss of shape (D1, D2, ..., DN)
        """

        psnr = _psnr_loss_fn(source, target, self.max_val)
        # multiply by -1 as we want to maximize the psnr
        return -1 * psnr


class FlowPhotoMetricLoss(nn.Module):
    """Photometric loss mixing an L1 term and an SSIM term between `source` and a resampled `reference`."""

    def __init__(
        self,
        ssim_weight: float = 0.85,
        ssim_window_size: int = 11,
        ssim_max_val: float = 1.0,
        ssim_sigma: float = 1.5,
        ssim_eps: float = 1e-12,
        ssim_use_padding: bool = True,
        max_displacement_ratio: float = 0.15,
    ) -> None:
        """
        Args:
            ssim_weight: weight of the SSIM term; the L1 term is weighted by ``1 - ssim_weight``
            ssim_window_size: kernel size forwarded to the SSIM module
            ssim_max_val: ``max_val`` forwarded to the SSIM module
            ssim_sigma: ``sigma`` forwarded to the SSIM module
            ssim_eps: ``eps`` forwarded to the SSIM module
            ssim_use_padding: ``use_padding`` forwarded to the SSIM module
            max_displacement_ratio: fraction of the image width/height from which a flow
                component is masked out
        """
        super().__init__()

        self._ssim_loss = SSIM(
            kernel_size=ssim_window_size,
            max_val=ssim_max_val,
            sigma=ssim_sigma,
            eps=ssim_eps,
            use_padding=ssim_use_padding,
        )

        self._L1_weight = 1 - ssim_weight
        self._SSIM_weight = ssim_weight
        self._max_displacement_ratio = max_displacement_ratio

    def forward(
        self,
        source: Tensor,
        reference: Tensor,
        flow_pred: Tensor,
        valid_mask: Optional[Tensor] = None,
    ):
        """
        Args:
            source: tensor of shape (B, C, H, W)
            reference: tensor of shape (B, C, H, W)
            flow_pred: tensor of shape (B, 2, H, W)
            valid_mask: tensor of shape (B, H, W) or None

        Returns:
            photometric loss averaged over the batch (a scalar tensor)

        """
        torch._assert(
            source.ndim == 4,
            "FlowPhotoMetricLoss: source must have 4 dimensions, but got {}".format(source.ndim),
        )
        torch._assert(
            reference.ndim == source.ndim,
            "FlowPhotoMetricLoss: source and other must have the same number of dimensions, but got {} and {}".format(
                source.ndim, reference.ndim
            ),
        )
        # the rank is validated before `shape[1]` is read, so a tensor with too few
        # dimensions reports the intended message instead of an IndexError
        torch._assert(
            flow_pred.ndim == 4,
            "FlowPhotoMetricLoss: flow_pred must have 4 dimensions, but got {}".format(flow_pred.ndim),
        )
        torch._assert(
            flow_pred.shape[1] == 2,
            "FlowPhotoMetricLoss: flow_pred must have 2 channels, but got {}".format(flow_pred.shape[1]),
        )

        B, C, H, W = source.shape
        flow_channels = flow_pred.shape[1]

        # flow channel 0 is bounded by the last dimension (W), channel 1 by the second to last (H)
        max_displacements = [
            int(self._max_displacement_ratio * source.shape[-1 - dim]) for dim in range(flow_channels)
        ]

        # mask out all pixels that have larger flow than the max flow allowed
        # NOTE(review): the comparison is signed, so large negative components are kept --
        # confirm whether the absolute value was intended.
        max_flow_mask = torch.logical_and(
            *[flow_pred[:, dim, :, :] < max_displacements[dim] for dim in range(flow_channels)]
        )

        if valid_mask is not None:
            valid_mask = torch.logical_and(valid_mask, max_flow_mask).unsqueeze(1)
        else:
            valid_mask = max_flow_mask.unsqueeze(1)

        # sample `reference` at the grid positions displaced by the predicted flow
        # (presumably a pixel coordinate grid -- see `make_coords_grid` / `grid_sample`)
        grid = make_coords_grid(B, H, W, device=str(source.device))
        resampled_grids = grid - flow_pred
        resampled_grids = resampled_grids.permute(0, 2, 3, 1)
        resampled_source = grid_sample(reference, resampled_grids, mode="bilinear")

        # compute SSIM loss
        ssim_loss = self._ssim_loss(resampled_source * valid_mask, source * valid_mask)
        l1_loss = (resampled_source * valid_mask - source * valid_mask).abs().mean(axis=(-3, -2, -1))
        loss = self._L1_weight * l1_loss + self._SSIM_weight * ssim_loss

        return loss.mean()
