# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

"""
This module provides a class that compiles generated IR using MLIR's PassManager
and executes it using MLIR's ExecutionEngine.

"""

from typing import Sequence, Optional, Tuple
import os
import sys
import inspect
from .common import DSLRuntimeError

# Append the directory containing this file to ``sys.path`` so that modules
# located next to it can be resolved by the import system.
# NOTE(review): presumably needed by code importing siblings by bare name —
# confirm this import-time side effect is still required.
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)

from .._mlir import ir


# =============================================================================
# Compiler Class
# =============================================================================


class CompilationError(RuntimeError):
    """Error raised when running the compilation pipeline fails.

    When ``nvvm_error`` is provided, ``str()``/``repr()`` render a colored,
    multi-line NVVM report (current settings, truncated IR context and
    suggested fixes). Otherwise they render the plain ``message``.

    Args:
        message: The underlying error text. Stored in ``args`` and used as the
            string form when no NVVM error is attached.
        nvvm_error: NVVM-specific error text; its presence selects the
            detailed report.
        ir_context: IR excerpt shown in the report.
        cuda_toolkit: CUDA toolkit path shown in the report ("Not Set" when
            empty or ``None``).
        arch: Target architecture shown in the report.
    """

    # ANSI escape sequences used to colorize the report.
    RED = "\033[91m"
    YELLOW = "\033[93m"
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    BOLD = "\033[1m"
    RESET = "\033[0m"

    def __init__(
        self,
        message: str,
        nvvm_error: Optional[str] = None,
        ir_context: Optional[str] = None,
        cuda_toolkit: Optional[str] = None,
        arch: Optional[str] = None,
    ):
        self.nvvm_error = nvvm_error
        self.ir_context = ir_context
        self.cuda_toolkit = cuda_toolkit
        self.arch = arch
        # Hand the real message to the base class: ``_format_error`` falls
        # back to ``self.args[0]`` when there is no NVVM error, so passing an
        # empty string here would make the exception print as nothing at all.
        super().__init__(message)
        # Pre-compute the text once; __str__ and __repr__ both return it.
        self._formatted_error = self._format_error()

    def __str__(self) -> str:
        """Return the pre-formatted error text."""
        return self._formatted_error

    def __repr__(self) -> str:
        """Return the pre-formatted error text (same as ``str()``)."""
        return self._formatted_error

    def _format_error(self) -> str:
        """Build the display text: plain message, or the NVVM report."""
        if not self.nvvm_error:
            return str(self.args[0])

        return f"""NVVM Compilation Error:
----------------------

{self.BLUE}⚙️  Current Settings:{self.RESET}
{self.BOLD}- CUDA Toolkit Path: {self.cuda_toolkit or "Not Set"}
- Target Architecture: {self.arch}{self.RESET}

IR Context (truncated):
{self.ir_context}

{self.YELLOW}💡 Possible Solutions:{self.RESET}
{self.GREEN}1. Check if CUDA_TOOLKIT_PATH is set correctly
2. Verify target architecture ({self.arch}) is supported by your CUDA toolkit
3. Make sure CUDA toolkit version matches the target architecture requirements{self.RESET}"""


class Compiler:
    """Compile MLIR modules with a pass pipeline and wrap them for JIT execution.

    Args:
        passmanager: Object exposing ``PassManager.parse(pipeline)``; the
            parsed pass manager must provide ``enable_verifier`` and ``run``.
        execution_engine: Object exposing an ``ExecutionEngine`` callable
            accepting ``(module, opt_level=..., shared_libs=...)``.
    """

    def __init__(self, passmanager, execution_engine):
        self.passmanager = passmanager
        self.execution_engine = execution_engine

    def __call__(self, module, *args, **kwargs):
        """Convenience application method; forwards everything to :meth:`compile`.

        ``compile`` requires a ``pipeline`` argument, so it (and any optional
        settings) must be supplied here as well.
        """
        self.compile(module, *args, **kwargs)

    def _process_error(self, error_msg: str) -> Tuple[Optional[str], Optional[str]]:
        """Extract the NVVM error text and a truncated IR excerpt from a message.

        Returns:
            ``(nvvm_error, ir_msg)``. ``nvvm_error`` is ``None`` unless the
            message contains ``"NVVM_ERROR"``; ``ir_msg`` is ``""`` unless the
            message also contains a ``"see current operation:"`` section.
        """
        nvvm_error = None
        ir_msg = ""

        if "NVVM_ERROR" in error_msg:
            # Prefer the text following the libNVVM log marker; without the
            # marker the whole message is treated as the NVVM error.
            nvvm_error = (
                error_msg.split("libNVVM extra log:")[1].strip()
                if "libNVVM extra log:" in error_msg
                else error_msg
            )

            if "see current operation:" in error_msg:
                # Text after the marker is the IR dump.
                ir_section = error_msg.split("see current operation:")[1].strip()
                # Cut everything from the translation-failure line onwards.
                ir_section = ir_section.split("error: unknown: Failed translating")[
                    0
                ].strip()

                # Long dumps are shortened to the first and last five lines.
                ir_lines = ir_section.split("\n")
                if len(ir_lines) > 10:
                    ir_msg = "\n".join(ir_lines[:5] + ["  ..."] + ir_lines[-5:])
                else:
                    ir_msg = ir_section

        return nvvm_error, ir_msg

    def compile(
        self,
        module,
        pipeline: str,
        cuda_toolkit: str = "",
        arch: str = "",
        enable_verifier=False,
    ):
        """Compile ``module`` by parsing ``pipeline`` and running it on the module.

        Args:
            module: Object whose ``operation`` attribute is passed to the pass
                manager's ``run``.
            pipeline: Pipeline string handed to ``PassManager.parse``.
            cuda_toolkit: Toolkit path, used only in the NVVM error report.
            arch: Target architecture, used only in the NVVM error report.
            enable_verifier: Forwarded to the pass manager's ``enable_verifier``.

        Raises:
            CompilationError: If the failure message is recognized as an NVVM
                error; the original exception is chained as the cause.
            Exception: Any other failure is re-raised unchanged.
        """
        try:
            pm = self.passmanager.PassManager.parse(pipeline)
            pm.enable_verifier(enable_verifier)
            pm.run(module.operation)
        except Exception as e:
            error_msg = str(e)
            nvvm_error, ir_msg = self._process_error(error_msg)

            if nvvm_error:
                raise CompilationError(
                    error_msg,
                    nvvm_error=nvvm_error,
                    ir_context=ir_msg,
                    cuda_toolkit=cuda_toolkit,
                    arch=arch,
                ) from e
            # Bare ``raise`` re-raises the active exception without adding a
            # redundant traceback entry for this line.
            raise

    def jit(self, module, opt_level: int = 2, shared_libs: Sequence[str] = ()):
        """Wrap ``module`` in an execution engine and return it."""
        return self.execution_engine.ExecutionEngine(
            module, opt_level=opt_level, shared_libs=shared_libs
        )

    def compile_and_jit(
        self,
        module,
        pipeline: str,
        shared_libs: Sequence[str] = (),
        opt_level: int = 2,
        cuda_toolkit: str = "",
        arch: str = "",
    ):
        """Run :meth:`compile` on ``module`` and then return :meth:`jit` of it."""
        self.compile(
            module,
            pipeline,
            cuda_toolkit,
            arch,
        )
        return self.jit(module, opt_level, shared_libs)


def compile(func, *args, **kwargs):
    """Compile a jit-decorated callable without using the cache.

    ``compile_only=True`` and ``no_cache=True`` are forced into ``kwargs``.
    The callable is normalized to its underlying plain function (bound methods
    and callable instances contribute their instance as the first positional
    argument), any ``__wrapped__`` layer is unwrapped, and the function's
    ``_dsl_object`` is then used to preprocess and invoke it.

    Args:
        func: A function, bound method, or instance whose ``__call__`` is a
            regular method.
        *args: Positional arguments forwarded to the DSL object's ``_func``.
        **kwargs: Keyword arguments forwarded to the DSL object's ``_func``.

    Returns:
        Whatever ``func._dsl_object._func(...)`` returns.

    Raises:
        DSLRuntimeError: If ``func`` is ``None``, not callable, of an
            unsupported callable kind, or lacks a ``_dsl_object`` attribute.
    """
    if func is None:
        raise DSLRuntimeError("Function is not set or invalid.")

    if not callable(func):
        raise DSLRuntimeError("Object is not callable.")

    kwargs["compile_only"] = True
    kwargs["no_cache"] = True

    if inspect.isfunction(func):
        # Regular function: nothing to normalize.
        pass
    elif inspect.ismethod(func):
        # Bound method: pass the instance explicitly and use the raw function.
        args = [func.__self__, *args]
        func = func.__func__
    else:
        # Callable instance: use the function behind its bound ``__call__``.
        # Builtins, classes and similar callables expose a ``__call__`` with
        # no ``__func__``; report those as unsupported instead of letting an
        # AttributeError escape.
        call_impl = getattr(getattr(func, "__call__", None), "__func__", None)
        if call_impl is None:
            raise DSLRuntimeError(
                "Invalid function type, only function, method and module are "
                f"supported, but got {func!r}"
            )
        args = [func, *args]
        func = call_impl

    # If it's a wrapped function created by jit decorator, get the original function
    if hasattr(func, "__wrapped__"):
        func = func.__wrapped__

    if not hasattr(func, "_dsl_object"):
        raise DSLRuntimeError("Function is not decorated with jit decorator.")

    fcn_ptr = func._dsl_object._preprocess_and_execute(func)
    return func._dsl_object._func(fcn_ptr, *args, **kwargs)
