#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Layout manipulation nodes and implementations

Layout nodes (permute, reshape) change the layout of intermediate nodes in the epilogue visitor graph.
"""

from copy import deepcopy

from cutlass_library import LayoutType
from pycute import product, flatten

import cutlass
from cutlass.backend.evt.ir.layout_algorithm import _list_to_tuple, _tuple_to_list
from cutlass.backend.evt.ir.node import NodeBase
from cutlass.backend.evt.ir.tensor import Tensor


class PermutationImpl:
    """
    Detailed implementation and helper functions for permutation

    The permutation is described by ``indices``: output dimension ``j`` is
    input dimension ``indices[j]`` (see :meth:`shape_propagation`).
    ``inverse_indices`` is the inverse permutation.
    """
    def __init__(self, node) -> None:
        assert "indices" in node.kwargs.keys()
        self.indices = list(node.kwargs["indices"])
        self.inverse_indices = self.get_inverse_indices(self.indices)

    def get_inverse_impl(self):
        """
        Return a copy of this implementation that applies the inverse permutation.
        """
        inverse_impl = deepcopy(self)
        # Copy the lists so the forward and inverse impls never share list objects
        inverse_impl.indices = list(self.inverse_indices)
        inverse_impl.inverse_indices = list(self.indices)
        return inverse_impl

    def update(self, shape):
        """
        Extend the permutation so that it acts on ``len(shape)`` dimensions.

        Extra leading (broadcast) dimensions stay in place, and the existing
        permutation is shifted onto the trailing dimensions. For example,
        indices ``[1, 0]`` updated for a rank-4 shape become ``[0, 1, 3, 2]``.

        :param shape: the (possibly broadcast) shape the permutation must cover
        :raises ValueError: if ``shape`` has fewer dimensions than the permutation
        """
        num_extra = len(shape) - len(self.indices)
        if num_extra < 0:
            raise ValueError(
                f"Cannot update permutation {self.indices} to fewer dimensions: {shape}")
        # Build new lists instead of mutating in place: the old lists may be shared
        self.indices = list(range(num_extra)) + [idx + num_extra for idx in self.indices]
        self.inverse_indices = self.get_inverse_indices(self.indices)

    def get_inverse_indices(self, indices):
        """
        Get the indices for inverse permutation
        """
        num_dim = len(indices)
        inverse_indices = [0] * num_dim
        for i in range(num_dim):
            inverse_indices[indices[i]] = i
        return inverse_indices

    def shape_propagation(self, input_node_meta):
        """
        Return the output shape obtained by permuting the input tensor's shape.
        """
        input_shape = input_node_meta.tensor.shape
        return tuple(input_shape[idx] for idx in self.indices)

    def broadcast(self, shape, node_meta: "NodeBase"):
        """
        Broadcast the inputs based on current shape

        :param shape: current (possibly broadcast) output shape of the permute
        :param node_meta: input node whose tensor is broadcast to the matching
            un-permuted shape
        """
        self.update(shape)
        inverse_shape = tuple([shape[idx] for idx in self.inverse_indices])
        node_meta.tensor.broadcast(inverse_shape)

    def apply_to_user(self, usr_meta: "NodeBase"):
        """
        Propagate the permutation to the users of the current nodes

        Permutes the user's ``tensor`` (and ``store_tensor`` when present and
        not None) by ``inverse_indices``.
        """
        usr_meta.tensor.permute(self.inverse_indices)
        if getattr(usr_meta, "store_tensor", None) is not None:
            usr_meta.store_tensor.permute(self.inverse_indices)

    def apply_to_input(self, input_meta: "NodeBase"):
        """
        Propagate the permutation to inputs of the current nodes

        Permutes the input's ``tensor`` (and ``store_tensor`` when present and
        not None) by ``indices``.
        """
        input_meta.tensor.permute(self.indices)
        if getattr(input_meta, "store_tensor", None) is not None:
            input_meta.store_tensor.permute(self.indices)


class ReshapeImpl:
    """
    Detailed implementation and helper functions for reshape

    ``output_shape`` comes from the node's ``new_shape`` kwarg. ``input_shape``
    is recorded by :meth:`shape_propagation` and refreshed by :meth:`broadcast`.
    """
    def __init__(self, node) -> None:
        self.node = node
        assert "new_shape" in node.kwargs.keys()
        self.output_shape = _list_to_tuple(node.kwargs["new_shape"])

    def get_inverse_impl(self):
        """
        Return a copy of this implementation that maps ``output_shape`` back to ``input_shape``.

        ``input_shape`` is only set by :meth:`shape_propagation` (or
        :meth:`broadcast`), so this must be called after shape propagation.
        """
        inverse_impl = deepcopy(self)
        inverse_impl.output_shape = self.input_shape
        inverse_impl.input_shape = self.output_shape
        return inverse_impl

    def shape_propagation(self, input_node_meta):
        """
        Record the input tensor's shape and return the reshaped (output) shape as a tuple.
        """
        self.input_shape = input_node_meta.tensor.shape
        return _list_to_tuple(self.output_shape)

    def broadcast(self, shape, node_meta: "NodeBase"):
        """
        Broadcast the inputs based on current shape.

        :param shape: the current (possibly broadcast) output shape of this reshape
        :param node_meta: the input node; its tensor is reshaped and broadcast so
            that reshaping it to ``shape`` is valid
        :raises NotImplementedError: if an output dimension is neither unchanged
            nor broadcast from a unit dimension
        """
        # Step 1: find the common refinement of input/output shapes and express
        # both shapes as groups of that refinement.
        flatten_split_shape = self.infer_split(flatten(self.input_shape), flatten(self.output_shape))
        split_input_shape = self.infer_merge(flatten_split_shape, self.input_shape)
        split_output_shape = self.infer_merge(flatten_split_shape, self.output_shape)

        # Step 2: extra leading dims of the broadcast shape become leading unit
        # dims on both sides.
        num_new_dims = len(shape) - len(split_output_shape)
        if num_new_dims > 0:
            prefix = [1] * num_new_dims
            split_output_shape = prefix + split_output_shape
            split_input_shape = prefix + split_input_shape

        # Step 3: one broadcast factor per flattened element of the output
        broadcast_factor = []
        for dim, old_dim in zip(shape, split_output_shape):
            if not isinstance(dim, list):
                dim = [dim,]
            if not isinstance(old_dim, list):
                old_dim = [old_dim,]
            if product(tuple(dim)) == product(tuple(old_dim)):
                # Not broadcast along this dimension
                broadcast_factor += [1] * len(old_dim)
            elif product(tuple(old_dim)) == 1:
                # Broadcast from a unit dim. old_dim may be a group of unit dims;
                # put the factor on its innermost element so that exactly one
                # factor is still emitted per flattened element.
                assert len(dim) == 1
                broadcast_factor += [1] * (len(old_dim) - 1) + [dim[0]]
            else:
                raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {dim}")

        # Step 4: apply the factors to the matching flattened elements of the input
        factor_idx = 0
        broadcast_split_input_shape = []
        for dim in split_input_shape:
            if isinstance(dim, list):
                new_dim = []
                for d in dim:
                    new_dim.append(d * broadcast_factor[factor_idx])
                    factor_idx += 1
                broadcast_split_input_shape.append(new_dim)
            else:
                broadcast_split_input_shape.append(dim * broadcast_factor[factor_idx])
                factor_idx += 1
        broadcast_split_input_shape = _list_to_tuple(broadcast_split_input_shape)
        node_meta.tensor.reshape(_list_to_tuple(split_input_shape))
        node_meta.tensor.broadcast(broadcast_split_input_shape)
        # Last reshape op to clean up: collapse each group back into one dim
        broadcast_input_shape = tuple([product(dim) for dim in broadcast_split_input_shape])
        node_meta.tensor.reshape(broadcast_input_shape)
        # Update the input shape and output shape
        self.input_shape = _list_to_tuple(node_meta.tensor.shape)
        self.output_shape = _list_to_tuple(shape)

    def apply_to_user(self, user_meta: "NodeBase"):
        """
        Propagate the reshape to user nodes

        Reshapes the user's ``tensor`` (and ``store_tensor`` when present and
        not None) to ``input_shape``.
        """
        user_meta.tensor.reshape(tuple(self.input_shape))
        if getattr(user_meta, "store_tensor", None) is not None:
            user_meta.store_tensor.reshape(tuple(self.input_shape))

    def apply_to_input(self, input_meta: "NodeBase"):
        """
        Propagate the reshape to input nodes

        Reshapes the input's ``tensor`` (and ``store_tensor`` when present and
        not None) to ``output_shape``.
        """
        input_meta.tensor.reshape(tuple(self.output_shape))
        if getattr(input_meta, "store_tensor", None) is not None:
            input_meta.store_tensor.reshape(tuple(self.output_shape))

    #
    # Helper functions
    #

    def infer_split(self, input_shape, output_shape):
        """
        Infer the flattened split shape that can be merged into both input_shape and output_shape.

        Works recursively from the last dimension: equal trailing dims are kept,
        and a trailing dim that is a multiple of the other side's trailing dim is
        split, pushing the quotient back onto that side.

        :raises ValueError: if the element counts do not match
        :raises NotImplementedError: if trailing dims do not divide each other
        """
        input_shape = _tuple_to_list(input_shape)
        output_shape = _tuple_to_list(output_shape)
        if len(input_shape) == 0 and len(output_shape) == 0:
            return []
        if len(input_shape) == 0:
            if product(tuple(output_shape)) != 1:
                raise ValueError("Invalid reshape size")
            else:
                return output_shape
        if len(output_shape) == 0:
            if product(tuple(input_shape)) != 1:
                raise ValueError("Invalid reshape size")
            else:
                return input_shape
        # This is done recursively by only processing the last dimension each time
        old_dim = input_shape[-1]
        new_dim = output_shape[-1]
        # Exact match
        if old_dim == new_dim:
            return self.infer_split(input_shape[:-1], output_shape[:-1]) + [new_dim,]
        # Needs split
        if old_dim > new_dim and old_dim % new_dim == 0:
            residual = old_dim // new_dim
            return self.infer_split(input_shape[:-1] + [residual,], output_shape[:-1]) + [new_dim,]
        # Needs merge
        if old_dim < new_dim and new_dim % old_dim == 0:
            residual = new_dim // old_dim
            return self.infer_split(input_shape[:-1], output_shape[:-1] + [residual,]) + [old_dim,]

        raise NotImplementedError(f"Unsupported split: {input_shape} -> {output_shape}")

    def infer_merge(self, flatten_shape, shape):
        """
        Group the entries of ``flatten_shape`` so that they line up with ``shape``.

        Each result entry is an int (exact match with one flattened entry) or a
        list of consecutive flattened entries whose product is the corresponding
        dim of ``shape``. Leading flattened entries that ``shape`` does not
        consume must multiply to 1. They are folded into the outermost group so
        that every flattened entry is accounted for. ``broadcast`` relies on
        this to keep its broadcast factors aligned.

        :raises NotImplementedError: if ``shape`` cannot be formed by grouping ``flatten_shape``
        """
        flatten_shape = _tuple_to_list(flatten_shape)
        shape = _tuple_to_list(shape)
        idx_flat = len(flatten_shape) - 1
        merged_shape = []
        for dim in reversed(shape):
            if idx_flat < 0:
                raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")
            # Exact match
            if dim == flatten_shape[idx_flat]:
                merged_shape.append(dim)
                idx_flat -= 1
            # need group
            elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
                residual = dim
                group = []
                while residual > 1:
                    if idx_flat < 0 or residual % flatten_shape[idx_flat] != 0:
                        raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")
                    group.append(flatten_shape[idx_flat])
                    residual = residual // flatten_shape[idx_flat]
                    idx_flat -= 1
                merged_shape.append(group[::-1])
            else:
                raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")

        merged_shape = merged_shape[::-1]

        # Fold unconsumed leading (unit) entries into the outermost group
        if idx_flat >= 0 and merged_shape:
            leftover = flatten_shape[:idx_flat + 1]
            if product(tuple(leftover)) != 1:
                raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")
            outermost = merged_shape[0]
            if not isinstance(outermost, list):
                outermost = [outermost,]
            merged_shape[0] = leftover + outermost

        return merged_shape


class LayoutNode(NodeBase):
    """
    Layout manipulation nodes

    Wraps a single layout op (``permute`` or ``reshape``). The op-specific
    logic lives in ``underlying_impl``, an instance of the class that
    ``fn_to_impl`` maps ``fn.__name__`` to.
    """
    # Name of the wrapped function -> implementation class
    fn_to_impl = {
        "permute": PermutationImpl,
        "reshape": ReshapeImpl
    }

    def __init__(self, name: str, fn, kwargs: dict) -> None:
        super().__init__(name)
        self.op = "layout"
        self.fn = fn
        # kwargs must be in place before the impl is built: impls read node.kwargs
        self.kwargs = kwargs
        impl_cls = self.fn_to_impl[self.fn.__name__]
        self.underlying_impl = impl_cls(self)

    def get_inverse_node(self):
        """
        Return a deep copy of this node whose impl applies the inverse layout op.
        """
        inverse = deepcopy(self)
        inverse.underlying_impl = self.underlying_impl.get_inverse_impl()
        return inverse

    def shape_propagation(self, input_node_metas):
        """
        Build this node's output tensor (row-major) from its single input.

        Does nothing and returns None if the tensor has already been set.
        """
        if self._tensor is None:
            assert len(input_node_metas) == 1, "Layout node can only have one input node"
            propagated_shape = self.underlying_impl.shape_propagation(input_node_metas[0])
            self._tensor = Tensor(
                element=self.element_output,
                shape=propagated_shape,
                layout_tag=LayoutType.RowMajor,
            )
            return super().shape_propagation(input_node_metas)

    def type_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        Layout nodes do not change the element type: element_output is taken
        from the single input node.
        """
        assert len(input_node_metas) == 1, "Layout node can only have one input node"
        (source,) = input_node_metas
        self.element_output = source.element_output

    def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        Propagate the broadcast in the reversed topological order

        Each input node is broadcast through the underlying layout op so that it
        matches this node's current tensor shape.

        :raises RuntimeError: if this node's tensor has not been set yet
        """
        if self.tensor is None:
            raise RuntimeError(f"The tensor of node {self.name} is unknown.")
        current_shape = self.tensor.shape
        for input_meta in input_node_metas:
            self.underlying_impl.broadcast(current_shape, input_meta)

    def apply_to_user(self, usr_meta: NodeBase):
        """
        Propagate this node's layout op (permute/reshape) to a user node.
        """
        self.underlying_impl.apply_to_user(usr_meta)

    def apply_to_input(self, input_meta: NodeBase):
        """
        Propagate this node's layout op (permute/reshape) to an input node.
        """
        self.underlying_impl.apply_to_input(input_meta)
