#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Base class for Python EVT Frontend
"""

from typing import Union

from cutlass_library import DataType
from cutlass.backend.evt.ir import (
    ComputeNode,
    DAGIR,
    LayoutNode,
    LoadNode,
    StoreNode,
)
from cutlass.backend.evt.passes import (
    EVTGraphDrawer,
    EVTPassManager,
    GetSmemSize,
    PassDAG2Tree,
    PassGetArgumentType,
    PassGetImpl,
    PassFixElementD,
    PassLayoutManipulateElimination,
    PassPreprocessRed,
    PassShapeTypePropagation,
)
from cutlass.backend.utils import device_cc
from cutlass.epilogue.evt_ops import permute, reshape
from cutlass.utils.datatypes import library_type


class EVTFrontendBase:
    """
    Base class for Python EVT frontends.

    A frontend populates ``self.dag_ir`` through the ``add_*`` helpers below
    (normally from within an overridden :meth:`parse`) and then runs the pass
    pipeline over it in :meth:`trace`.
    """

    # Layout-manipulation ops, keyed by name
    layout_fns = {
        "permute": permute,
        "reshape": reshape
    }

    def __init__(self, cc, element_compute=DataType.f32, additional_passes=None, **kwargs) -> None:
        """
        :param cc: compute capability the epilogue is built for
        :param element_compute: compute data type; converted with ``library_type``
        :param additional_passes: extra passes appended after the default pipeline (optional)
        :param kwargs: accepted for subclasses; not used by this base class
        """
        self.cc = cc
        self.element_compute = library_type(element_compute)
        self.dag_ir = DAGIR(self.cc, self.element_compute)
        # Counters used to generate unique default names for compute/layout nodes
        self.compute_cnt = 0
        self.layout_cnt = 0

        # Use None as the default instead of a shared mutable list; copy the
        # caller's sequence so the pipeline never aliases it.
        extra_passes = [] if additional_passes is None else list(additional_passes)

        self.pass_manager = EVTPassManager(
            self.dag_ir,
            [
                PassPreprocessRed,
                PassGetArgumentType,
                PassShapeTypePropagation,
                PassLayoutManipulateElimination,
                PassGetImpl,
                PassDAG2Tree,
                PassFixElementD
            ] + extra_passes)

        # Only cc 80 gets a default stage count here; otherwise it stays None
        # until assigned through the `epilogue_stages` setter.
        self._epilogue_stages = 1 if self.cc == 80 else None

    @property
    def epilogue_stages(self):
        """
        Number of epilogue stages (1 by default when cc == 80, otherwise None until set)
        """
        return self._epilogue_stages

    @epilogue_stages.setter
    def epilogue_stages(self, stages):
        self._epilogue_stages = stages

    def parse(self, *args, **kwargs):
        """
        Build the DAG IR from the inputs. Must be overridden by frontend subclasses.
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError("The 'parse' function must be overloaded in frontend class")

    def trace(self, *args, **kwargs):
        """
        Parse the inputs, run the pass pipeline, and copy the results of
        interest from the DAG IR onto this frontend.
        All arguments are forwarded to :meth:`parse`.
        """
        # Parse the input
        self.parse(*args, **kwargs)

        # Run the passes
        self.pass_manager()
        # Set the epilogue type
        self.epilogue_thread_type = self.dag_ir.epilogue_thread_type
        # The C/D argument types are only read from the DAG IR when cc == 90
        if self.cc == 90:
            self.arg_c_type = self.dag_ir.arg_c_type
            self.arg_d_type = self.dag_ir.arg_d_type
        self.reduction_names = self.dag_ir.reduction_names

    #
    # Helper functions for DAG IR manipulation
    #

    def add_node(self, node):
        """
        Add a node to the DAG IR
        """
        self.dag_ir.add_node(node)

    def add_edge(self, src, tgt, weight=0):
        """
        Add an edge from {src} to {tgt} in the DAG IR
        :param src: source node
        :param tgt: target node
        :param weight: edge weight, forwarded to the DAG IR (default 0)
        """
        self.dag_ir.add_edge(src, tgt, weight=weight)

    def set_tensor(self, node_name, example):
        """
        Add an example tensor to node {node_name} in the DAG IR
        """
        meta = self.dag_ir.get_node_meta(node_name)
        meta.tensor = {"tensor": example}

    def set_store_tensor(self, node_name, example):
        """
        Add an example tensor as the store tensor of node {node_name} in the DAG IR
        """
        meta = self.dag_ir.get_node_meta(node_name)
        meta.store_tensor = {"tensor": example}

    def mark_output(self, node_name):
        """
        Mark a store node as output
        :raises ValueError: if node {node_name} is not a StoreNode
        """
        meta = self.dag_ir.get_node_meta(node_name)
        if not isinstance(meta, StoreNode):
            raise ValueError(
                "Only StoreNodes can be marked as output. "
                f"Got {type(meta).__name__}: {node_name}")
        meta.is_output = True

    # Add node with specific type

    def add_load_node(self, name, example):
        """
        Add a Load node to DAG IR
        :param name: name of the loaded variable
        :type name: str
        :param example: example input
        :type example: np.ndarray|torch.Tensor|cupy.ndarray|float
        :raises ValueError: if name or example is None, or if the example for
            'accum' is neither rank-2 nor rank-3
        """
        if name is None:
            raise ValueError("Name is not provided.")
        if example is None:
            raise ValueError(f"Example input for {name} is not provided.")
        load_node = LoadNode(name)
        load_node.tensor = {"tensor": example}
        # Special logics for accumulator
        if name == "accum":
            rank = load_node.tensor.rank
            if rank == 2:
                # Prepend a unit dimension so a rank-2 accumulator becomes rank-3
                new_shape = (1, *load_node.tensor.shape)
                load_node.tensor.broadcast(new_shape)
            elif rank != 3:
                raise ValueError(f"Expect example inputs for 'accum' be a rank-2 or rank-3 tensor. Got {load_node.tensor.shape}.")
        self.add_node(load_node)

    def add_imm(self, value: Union[float, int]) -> str:
        """
        Add an immediate scalar value to DAG IR
        :param value: the value of the immediate scalar
        :type value: float
        :return: the name of the load node holding the immediate
        :raises ValueError: if value cannot be converted to float
        """
        try:
            value = float(value)
        except (TypeError, ValueError, OverflowError) as err:
            # Only conversion failures are translated; KeyboardInterrupt and
            # SystemExit are no longer swallowed.
            raise ValueError(f"{type(value).__name__} cannot be converted to float.") from err

        # The node name is derived from the value, with '.' replaced by '_'
        name = f"imm_{value}".replace('.', '_')
        load_node = LoadNode(name)
        load_node.tensor = {"tensor": value, "is_constant": True}
        self.add_node(load_node)
        return name

    def add_compute_node(self, op, name=None):
        """
        Add a compute node.
        :param op: the computation op
        :param name: the node name (optional)
        :type name: str
        :return: the name of the compute node
        """
        if name is None:
            # Auto-generated names are numbered by a per-frontend counter
            name = f"compute_{self.compute_cnt}"
            self.compute_cnt += 1
        compute_node = ComputeNode(
            name=name, fn=op,
            element_output=self.element_compute,
            element_compute=self.element_compute)
        self.add_node(compute_node)
        return compute_node.name

    def add_layout_node(self, op, kwargs, name=None):
        """
        Add a layout node.
        :param op: the layout op
        :type op: evt_ops
        :param kwargs: keyword arguments of the layout op, stored on the node
        :param name: the node name (optional)
        :type name: str
        :return: the name of the layout node
        """
        if name is None:
            # Auto-generated names are numbered by a per-frontend counter
            name = f"layout_{self.layout_cnt}"
            self.layout_cnt += 1
        layout_node = LayoutNode(name=name, fn=op, kwargs=kwargs)
        self.add_node(layout_node)
        return layout_node.name

    def add_store_node(self, name):
        """
        Add a Store node to DAG IR
        :param name: name of the stored variable
        """
        store_node = StoreNode(name)
        self.add_node(store_node)

    #
    # Visualize the DAG IR
    #

    def visualize(self, name="dag_ir"):
        """
        Visualize the dag ir with svg file(s), written to the current directory
        :param name: the name of the graph
        :raises RuntimeError: if the graphs cannot be generated or written
        """
        drawer = EVTGraphDrawer(self.dag_ir, name)
        try:
            # A distinct loop variable avoids shadowing the `name` parameter
            for graph_name, graph in drawer.get_dot_graph():
                graph.write_svg(f"./{graph_name}.svg")
        except Exception as err:
            # Chain the underlying failure so the real cause stays visible when
            # it is something other than a missing 'dot' executable.
            raise RuntimeError(
                "'dot' is not found in path. GraphDrawer is disabled. "
                "Please install it with 'sudo apt-get install graphviz'."
            ) from err

    #
    # Get shared memory size
    #

    def get_smem_size(self, tile_description):
        """
        Get the shared memory size of the epilogue
        :param tile_description: tile description passed to the GetSmemSize pass
        :return: the size computed by GetSmemSize for this DAG IR
        """
        return GetSmemSize(self.dag_ir)(tile_description)
