#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from __future__ import annotations

import subprocess

from cutlass_library import DataTypeTag

from cutlass.backend.evt.ir.dag_ir import DAGIR


_COLOR_MAP = {
    "load": '"AliceBlue"',
    "compute": "LemonChiffon1",
    "accumulator": "LightGrey",
    "store": "PowderBlue",
    "layout": "lightseagreen",
    "dag": "darkorange"
}


class EVTGraphDrawer:
    """
    Visualize a EVT DAGIR with graphviz
    """
    def __init__(
        self,
        graph: DAGIR,
        name: str
    ):
        self._name = name
        self._dot_graphs = {}

        self._dot_graphs[name] = self._to_dot(graph, name)

    def _get_node_style(self, node):
        template = {
            "shape": "record",
            "fillcolor": "#CAFFE3",
            "style": '"filled,rounded"',
            "fontcolor": "#000000",
        }
        if node.op in _COLOR_MAP:
            template["fillcolor"] = _COLOR_MAP[node.op]
        else:
            raise NotImplementedError("unknown node op")
        if node.disabled:
            template["fontcolor"] = "grey"
            template["fillcolor"] = "white"
        return template

    def _get_node_label(self, node):
        label = "{" + f"name={node.name}|op={node.op}"
        if node.op == "layout":
            label += f"|fn={node.fn.__name__}"
            for key in node.kwargs:
                label += f"|{key}={node.kwargs[key]}"
        if node.underlying_impl is not None:
            label += f"|impl={type(node.underlying_impl).__name__}"
            if node.op == "load":
                label += f"|element_output={DataTypeTag[node.underlying_impl.element]}"
            elif node.op == "compute":
                label += f"|element_compute={DataTypeTag[node.underlying_impl.element_compute]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
            elif node.op == "store":
                label += f"|element_store={DataTypeTag[node.underlying_impl.element]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
            elif node.op == "dag":
                label += f"|element_output={DataTypeTag[node.underlying_impl.element_output]}"
        if node.tensor is not None:
            shape = node.tensor.shape
            stride = node.tensor.stride
            label += f"|shape={shape}|stride={stride}"

        if hasattr(node, "store_tensor"):
            if node.store_tensor is not None:
                store_shape = node.store_tensor.shape
                store_stride = node.store_tensor.stride
                label += f"|store_shape={store_shape}|stride_stride={store_stride}"

        label += "}"
        return label

    def _to_dot(
        self,
        graph: DAGIR,
        name: str
    ):
        import pydot
        dot_graph = pydot.Dot(name, randir="TB")
        for node in graph.nodes_meta:
            style = self._get_node_style(node)
            label = self._get_node_label(node)
            dot_node = pydot.Node(
                node.name, label=label, **style
            )
            dot_graph.add_node(dot_node)
            if node.op == "dag":
                dot_subgraph = self._to_dot(node.subgraph, name=node.name)
                self._dot_graphs[node.name] = dot_subgraph

        # Add edges
        for src, dst in graph.edges:
            weight = graph.get_edge_weight(src, dst)
            dot_graph.add_edge(pydot.Edge(src, dst, label=weight))

        return dot_graph

    def get_dot_graph(self) -> pydot.Dot:
        return [(key, self.get_dot_graph_by_name(key)) for key in self._dot_graphs.keys()]

    def get_dot_graph_by_name(self, name) -> pydot.Dot:
        return self._dot_graphs[name]

    def get_main_dot_graph(self) -> pydot.Dot:
        return self._dot_graphs[self._name]
