# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

"""
This module provides helper functions that are generated by the preprocessor.
The preprocessor reads through Python's AST and changes the input code.
"""

from typing import Callable, Iterator, Optional, overload

from .utils.logger import log
from .common import *

from ._mlir_helpers.arith import ArithValue

class Executor:
    """
    The Executor class handles dynamic and compile-time (constexpr) execution
    of "for" loops, "if-else-elif" statements and "while" loops.

    The dynamic paths are delegated to callbacks registered through
    ``set_functions``; the constexpr paths are run directly by the Python
    interpreter.

    Methods:
        set_functions:  Assigns the functions for checking dynamic expressions
                        and for handling dynamic loops and conditionals.

        for_dynamic: Generates MLIR for OP
        for_constexpr: Executes a for loop at JIT compile-time
        for_execute: Decides whether to execute the loop at compile-time or generate MLIR for OP based on the provided bounds.

        if_dynamic: Generates MLIR if OP
        if_constexpr: Executes a if at JIT compile-time by python interpreter
        if_execute: Decides whether to execute the if statement at compile-time or generate MLIR if OP based on the predicate.

        while_dynamic: Forwards a while loop to the registered dynamic callback
        while_constexpr: Executes a while loop at JIT compile-time by python interpreter
        while_execute: Executes the while loop at compile-time only when explicitly requested, otherwise forwards it to the dynamic callback.
    """

    def __init__(self):
        # All callbacks stay unset until set_functions() is called; the
        # *_execute methods assert that the ones they need are present.
        self._is_dynamic_expression = None
        self._loop_execute_range_dynamic = None
        self._if_dynamic = None
        self._while_dynamic = None

    def set_functions(
        self,
        is_dynamic_expression: Callable,
        loop_execute_range_dynamic: Callable,
        if_dynamic: Callable,
        while_dynamic: Callable,
    ):
        """Register the callbacks used by the dynamic execution paths.

        Args:
            is_dynamic_expression: Called with a single value; a truthy result
                marks the value as dynamic (not a compile-time constant).
            loop_execute_range_dynamic: Callback invoked by ``for_dynamic``.
            if_dynamic: Callback invoked by ``if_dynamic``.
            while_dynamic: Callback invoked by ``while_dynamic``.
        """
        self._is_dynamic_expression = is_dynamic_expression
        self._loop_execute_range_dynamic = loop_execute_range_dynamic
        self._if_dynamic = if_dynamic
        self._while_dynamic = while_dynamic

    @staticmethod
    def convert_to_list(x):
        """This function is used to convert x to a list.
        If x is None, return an empty list.
        If x is not a list, return a list containing x.
        Otherwise, return x itself.
        """
        if x is None:
            return []
        if not isinstance(x, list):
            return [x]
        return x

    @staticmethod
    def converge_ret_val(res):
        """This function is used to converge res (the return value) of the function.
        If res is None, return None.
        If res is a list and has only one element, return the element.
        Otherwise, return res itself.
        """
        if res is None:
            return res
        elif isinstance(res, list) and len(res) == 1:
            return res[0]
        return res

    def for_dynamic(
        self,
        func: Callable,
        start,
        stop,
        step,
        used_args: list,
        iter_args: list,
        iter_arg_names: list,
        unroll=-1,
        unroll_full=False,
    ):
        """Forward a for loop to the registered dynamic-loop callback.

        All arguments are passed through unchanged, in the same order, to the
        ``loop_execute_range_dynamic`` callback, and its result is returned.

        The ``unroll`` / ``unroll_full`` defaults mirror those of
        ``for_execute``.
        """
        log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
        return self._loop_execute_range_dynamic(
            func,
            start,
            stop,
            step,
            used_args,
            iter_args,
            iter_arg_names,
            unroll,
            unroll_full,
        )

    @staticmethod
    def for_constexpr(
        func: Callable,
        start: int,
        stop: int,
        step: int,
        used_args: list,
        iter_args: list,
    ):
        """Run a for loop in the Python interpreter over ``range(start, stop, step)``.

        ``func`` is called once per iteration as
        ``func(i, *used_args, *carried)``, where ``carried`` starts as
        ``iter_args`` and is replaced by the (list-normalized) return value of
        the previous iteration.

        Returns:
            The final carried values passed through ``converge_ret_val``
            (a single-element list is unwrapped to its element).
        """
        log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
        loop_results = iter_args
        log().debug("iter_args [%s]", iter_args)
        for i in range(start, stop, step):
            # Log the values actually fed into this iteration, not the
            # initial iter_args, which never change.
            log().debug("i  [%s] iter_args  [%s]", i, loop_results)
            loop_results = func(i, *used_args, *loop_results)
            log().debug("loop_results  [%s]", loop_results)
            # The body may return None, a bare value or a list; normalize so
            # the result can be unpacked into the next iteration.
            loop_results = Executor.convert_to_list(loop_results)

        log().debug("done loop_results [%s]", loop_results)
        return Executor.converge_ret_val(loop_results)

    def for_execute(
        self,
        func,
        start,
        stop,
        step,
        used_args=[],
        iter_args=[],
        iter_arg_names=[],
        unroll=-1,
        unroll_full=False,
        is_range_constexpr=None,
    ):
        """Run a for loop at compile-time or forward it to the dynamic path.

        Args:
            func: Loop body.
            start, stop, step: Loop bounds.
            used_args, iter_args, iter_arg_names, unroll, unroll_full:
                Forwarded to ``for_constexpr`` / ``for_dynamic``.
            is_range_constexpr: ``None`` infers the mode from the bounds
                (constexpr only when none of them is dynamic); a truthy value
                requires constexpr bounds; any other value forces the dynamic
                path.

        Raises:
            DSLRuntimeError: If constexpr execution is requested but a bound
                is a dynamic expression.
        """
        assert (
            self._loop_execute_range_dynamic and self._is_dynamic_expression
        ), "Functions must be set before execution."
        log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
        any_dynamic_expression = (
            self._is_dynamic_expression(start)
            or self._is_dynamic_expression(stop)
            or self._is_dynamic_expression(step)
        )

        if is_range_constexpr is None:
            # No explicit request: infer the mode from the bounds.
            is_range_constexpr = not any_dynamic_expression
        elif is_range_constexpr and any_dynamic_expression:
            # Ensure bounds are compile-time constants for constexpr execution
            raise DSLRuntimeError(
                "Loop bounds must be constexpr (compile-time constants)"
            )

        if is_range_constexpr:
            return self.for_constexpr(func, start, stop, step, used_args, iter_args)

        # MLIR generation
        return self.for_dynamic(
            func,
            start,
            stop,
            step,
            used_args,
            iter_args,
            iter_arg_names,
            unroll,
            unroll_full,
        )

    def if_dynamic(
        self,
        pred,
        then_block: Callable,
        else_block: Optional[Callable] = None,
        used_args=[],
        yield_args=[],
        yield_arg_names=[],
    ):
        """Forward an if statement, unchanged, to the registered ``if_dynamic`` callback."""
        return self._if_dynamic(
            pred, then_block, else_block, used_args, yield_args, yield_arg_names
        )

    @staticmethod
    def if_constexpr(
        pred,
        then_block: Callable,
        else_block: Optional[Callable] = None,
        used_args=[],
        yield_args=[],
    ):
        """Evaluate an if statement in the Python interpreter.

        Calls ``then_block(*used_args, *yield_args)`` when ``pred`` is truthy,
        otherwise ``else_block`` with the same arguments if one is given.

        Returns:
            The selected block's result passed through ``converge_ret_val``,
            or ``None`` when ``pred`` is falsy and there is no ``else_block``.
        """
        if pred:
            log().debug(" running then block [%s]", yield_args)
            res = then_block(*used_args, *yield_args)
            log().debug("result [%s]", res)
            return Executor.converge_ret_val(res)
        elif else_block is not None:
            log().debug("running else [%s]", yield_args)
            res = else_block(*used_args, *yield_args)
            log().debug("result [%s]", res)
            return Executor.converge_ret_val(res)
        # Falsy predicate and no else branch: nothing to run.
        return None

    def if_execute(
        self,
        pred,
        then_block: Callable,
        else_block: Optional[Callable] = None,
        used_args=[],
        yield_args=[],
        yield_arg_names=[],
        if_constexpr=None,
    ):
        """Run an if statement at compile-time or forward it to the dynamic path.

        Args:
            pred: The predicate.
            then_block, else_block, used_args, yield_args, yield_arg_names:
                Forwarded to ``if_constexpr`` / ``if_dynamic``.
            if_constexpr: ``None`` infers the mode from ``pred`` (constexpr
                when it is not dynamic); a truthy value requires a constexpr
                predicate; any other value forces the dynamic path.

        Raises:
            DSLRuntimeError: If constexpr execution is requested but ``pred``
                is a dynamic expression.
        """
        assert (
            self._if_dynamic and self._is_dynamic_expression
        ), "Functions must be set before execution."

        is_if_constexpr = not self._is_dynamic_expression(pred)
        if if_constexpr is None:
            # No explicit request: infer the mode from the predicate.
            if_constexpr = is_if_constexpr
        elif if_constexpr and not is_if_constexpr:
            # Ensure the predicate is a compile-time constant for constexpr execution
            raise DSLRuntimeError(
                "If predicate must be constexpr (compile-time constants)"
            )

        if if_constexpr:
            return self.if_constexpr(
                pred, then_block, else_block, used_args, yield_args
            )

        # MLIR generation
        return self.if_dynamic(
            pred, then_block, else_block, used_args, yield_args, yield_arg_names
        )

    def while_dynamic(
        self,
        while_before_block: Callable,
        while_after_block: Callable,
        used_args=[],
        yield_args=[],
        yield_arg_names=[],
    ):
        """Forward a while loop, unchanged, to the registered ``while_dynamic`` callback."""
        return self._while_dynamic(
            while_before_block,
            while_after_block,
            used_args,
            yield_args,
            yield_arg_names,
        )

    @staticmethod
    def while_constexpr(
        while_before_block,
        while_after_block,
        used_args=[],
        yield_args=[],
    ):
        """Run a while loop in the Python interpreter.

        ``while_before_block(*used_args, *carried)`` must return a
        ``(cond, results)`` pair. While ``cond`` is truthy,
        ``while_after_block(*used_args, *results)`` is called and its
        (list-normalized) return value is fed back into ``while_before_block``.

        Returns:
            The results from the last ``while_before_block`` call passed
            through ``converge_ret_val``.
        """
        log().debug("while_constexpr begin %s", while_before_block.__qualname__)
        cond, loop_results = while_before_block(*used_args, *yield_args)
        while cond:
            loop_results = Executor.convert_to_list(loop_results)
            log().debug("calling while_after [%s], [%s]", used_args, loop_results)
            loop_results = while_after_block(*used_args, *loop_results)
            log().debug("while after [%s]", loop_results)
            loop_results = Executor.convert_to_list(loop_results)
            log().debug("calling while_before [%s], [%s]", used_args, loop_results)
            cond, loop_results = while_before_block(*used_args, *loop_results)
            log().debug(
                "while_before cond, results [%s], [%s]", cond, loop_results
            )

        log().debug("while_constexpr results %s", loop_results)
        return Executor.converge_ret_val(loop_results)

    def while_execute(
        self,
        pred,
        while_before_block: Callable,
        while_after_block: Callable,
        used_args=[],
        yield_args=[],
        yield_arg_names=[],
        while_constexpr=None,
    ):
        """Run a while loop at compile-time or forward it to the dynamic path.

        Unlike ``for_execute`` and ``if_execute``, no mode is inferred here:
        the loop is executed at compile-time only when ``while_constexpr`` is
        truthy; otherwise (including ``None``) it goes to ``while_dynamic``.

        Raises:
            DSLRuntimeError: If constexpr execution is requested but ``pred``
                is a dynamic expression.
        """
        assert (
            self._while_dynamic and self._is_dynamic_expression
        ), "Functions must be set before execution."

        is_while_constexpr = not self._is_dynamic_expression(pred)

        # Ensure the predicate is a compile-time constant for constexpr execution
        if while_constexpr:
            if not is_while_constexpr:
                raise DSLRuntimeError(
                    "While predicate must be constexpr (compile-time constants)"
                )
            return self.while_constexpr(
                while_before_block, while_after_block, used_args, yield_args
            )

        # MLIR generation
        return self.while_dynamic(
            while_before_block,
            while_after_block,
            used_args,
            yield_args,
            yield_arg_names,
        )


# =============================================================================
# Decorator
# =============================================================================

# Module-level Executor instance shared by the helper functions and classes
# defined below in this module.
executor = Executor()


def loop_selector(
    start,
    stop,
    step,
    used_args=[],
    iter_args=[],
    iter_arg_names=[],
    unroll=-1,
    unroll_full=False,
    constexpr=None,
):
    """Build a decorator that runs the decorated loop body via the executor.

    ``Integer`` bounds are first replaced by their ``ir_value()``; other
    values are used as given. The returned decorator calls
    ``executor.for_execute`` with the loop body and all arguments captured
    here, and returns its result.

    Args:
        start, stop, step: Loop bounds.
        used_args, iter_args, iter_arg_names, unroll, unroll_full:
            Forwarded unchanged to ``executor.for_execute``.
        constexpr: Forwarded as ``is_range_constexpr``.

    Returns:
        A callable taking the loop body function.
    """
    log().debug(
        "start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] constexpr [%s]",
        start,
        stop,
        step,
        used_args,
        iter_args,
        unroll,
        unroll_full,
        constexpr,
    )
    # Imported at call time rather than at module import time.
    from .typing import Integer

    def _maybe_upcast(value):
        # Only Integer wrappers are converted; everything else passes through.
        if isinstance(value, Integer):
            value = value.ir_value()

        return value

    start = _maybe_upcast(start)
    stop = _maybe_upcast(stop)
    step = _maybe_upcast(step)

    def ir_loop(func):
        return executor.for_execute(
            func,
            start,
            stop,
            step,
            used_args,
            iter_args,
            iter_arg_names,
            unroll,
            unroll_full,
            constexpr,
        )

    return ir_loop


def if_selector(pred, used_args=[], yield_args=[]):
    """Build a decorator that immediately calls the decorated function.

    A ``Numeric`` predicate is replaced by its ``value`` attribute; any other
    predicate is used as given. The returned decorator calls
    ``func(pred, *used_args, *yield_args)`` and returns its result.
    """
    log().debug("pred [%s] used_args [%s] yield_args [%s]", pred, used_args, yield_args)
    # Handle Numeric types here?

    from .typing import Numeric

    predicate = pred.value if isinstance(pred, Numeric) else pred

    def ir_if(func):
        # Argument lists are unpacked when the decorator is applied.
        return func(predicate, *used_args, *yield_args)

    return ir_if


def while_selector(pred, used_args=[], yield_args=[]):
    """Build a decorator that immediately calls the decorated function.

    The returned callable invokes ``func(pred, *used_args, *yield_args)`` and
    returns whatever ``func`` returns.
    """

    def invoke(func):
        # Order: predicate first, then the used arguments, then the yielded ones.
        positional = [pred, *used_args, *yield_args]
        return func(*positional)

    return invoke


def while_executor(
    pred,
    while_before_block: Callable,
    while_after_block: Callable,
    used_args=[],
    yield_args=[],
    yield_arg_names=[],
    constexpr=None,
):
    """Module-level entry point that delegates to ``executor.while_execute``.

    Every argument is forwarded unchanged; ``constexpr`` is passed as the
    ``while_constexpr`` parameter. The executor's result is returned.
    """
    result = executor.while_execute(
        pred,
        while_before_block,
        while_after_block,
        used_args=used_args,
        yield_args=yield_args,
        yield_arg_names=yield_arg_names,
        while_constexpr=constexpr,
    )
    return result


def if_executor(
    pred,
    then_block: Callable,
    else_block: Optional[Callable] = None,
    used_args=[],
    yield_args=[],
    yield_arg_names=[],
    constexpr=None,
):
    """Module-level entry point that delegates to ``executor.if_execute``.

    Every argument is forwarded unchanged; ``constexpr`` is passed as the
    ``if_constexpr`` parameter. The executor's result is returned.
    """
    result = executor.if_execute(
        pred,
        then_block,
        else_block,
        used_args=used_args,
        yield_args=yield_args,
        yield_arg_names=yield_arg_names,
        if_constexpr=constexpr,
    )
    return result


# =============================================================================
# Range
# =============================================================================


class range_dynamic:
    """Marker type for dynamic ranges.

    Constructing it directly always raises ``DSLRuntimeError``: per the error
    message, occurrences are expected to be preprocessed to IR before they
    are ever executed. The overloads only describe the accepted call shapes.
    """

    @overload
    def __new__(cls, stop, unroll=0, unroll_full=False):
        ...

    @overload
    def __new__(cls, start, stop, step, unroll=0, unroll_full=False):
        ...

    def __new__(cls, *args, **kwargs):
        # Reaching this point means the call was not rewritten beforehand.
        message = "range_dynamic should be always preprocessed to IR"
        raise DSLRuntimeError(message)


class range_constexpr:
    """A range whose arguments must all be compile-time constants.

    Accepts ``(stop)``, ``(start, stop)`` or ``(start, stop, step)``, with the
    same defaults as the built-in ``range`` (``start=0``, ``step=1``).
    Iteration follows ``range`` semantics for both positive and negative
    steps.

    Raises:
        DSLRuntimeError: If the argument count is not 1-3, if any argument is
            a dynamic expression according to the module-level executor, or
            if ``step`` is zero.
    """

    def __init__(self, *args):
        if len(args) == 1:
            self.start = 0
            self.stop = args[0]
            self.step = 1
        elif len(args) == 2:
            self.start, self.stop = args
            self.step = 1
        elif len(args) == 3:
            self.start, self.stop, self.step = args
        else:
            raise DSLRuntimeError(
                "range_constexpr supports up to 3 arguments (start, stop, step)"
            )
        # Ensure the arguments are compile-time constants (if required)
        for arg_name, arg_value in [
            ("step", self.step),
            ("start", self.start),
            ("stop", self.stop),
        ]:
            if executor._is_dynamic_expression(arg_value):
                raise DSLRuntimeError(
                    f"`range_constexpr` requires `constexpr` (non-IR Values) for all arguments, "
                    f"but `{arg_name}` is not. If the arguments are dynamic, use `range`; the DSL "
                    f"will handle them during runtime. ",
                    suggestion="Use `range` instead of `range_constexpr`.",
                )
        # Checked only after the constexpr validation above, so the comparison
        # is never applied to a dynamic value. A zero step would otherwise
        # make iteration loop forever.
        if self.step == 0:
            raise DSLRuntimeError("range_constexpr step must not be zero")

    def __iter__(self) -> Iterator[int]:
        current = self.start
        if self.step > 0:
            while current < self.stop:
                yield current
                current += self.step
        else:
            # Negative step counts down towards (and excluding) stop.
            while current > self.stop:
                yield current
                current += self.step


# =============================================================================
# If expressions
# =============================================================================


def const_expr(expression):
    """Return ``expression`` unchanged if it is a compile-time constant.

    Raises:
        DSLRuntimeError: If the module-level executor reports ``expression``
            as a dynamic expression.
    """
    if not executor._is_dynamic_expression(expression):
        return expression

    hints = {
        "const_expr": "Accepts only constexpr (compile-time constant)",
        "If your expression depends on dynamic values": "Avoid marking it as `const_expr()`",
        "If the expression could be either dynamic or constexpr": "Omit explicit `const_expr()` marker; the DSL will infer the correct handling automatically",
    }
    raise DSLRuntimeError(
        f"The function `const_expr({expression})` received a dynamic expression (non compile-time constant).",
        context=hints,
    )


def dynamic_expr(expression):
    """Always raise ``DSLRuntimeError``; ``expression`` is not inspected.

    Per the error message, calls to this function are expected to be
    preprocessed to IR rather than executed.
    """
    message = "dynamic_expr should be always preprocessed to IR"
    raise DSLRuntimeError(message)


# =============================================================================
# Assertion & casting
# =============================================================================


def assert_executor(test, msg=None):
    """Assert ``test`` when it is (or can be converted to) a Python value.

    If the module-level executor reports ``test`` as a dynamic expression, a
    ``Numeric`` is first converted with ``test.to(bool)``; the converted value
    is then asserted as usual.

    Args:
        test: The condition to assert.
        msg: Optional assertion message.

    Raises:
        AssertionError: If the (possibly converted) ``test`` is falsy.
        DSLRuntimeError: If ``test`` is dynamic and is not a ``Numeric``, or
            its conversion to ``bool`` fails.
    """
    from .typing import Numeric

    # Implicit convert dynamic expression to bool is not allowed
    # So here explicitly do a None check
    if test is not None and executor._is_dynamic_expression(test):
        converted = False
        if isinstance(test, Numeric):
            try:
                test = test.to(bool)
                converted = True
            except Exception:
                # Conversion failures are reported as DSLRuntimeError below.
                # KeyboardInterrupt / SystemExit are deliberately not caught.
                converted = False
        if not converted:
            raise DSLRuntimeError(
                "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
                suggestion = "Please replace with runtime assert."
            )

    assert test, msg


def bool_cast(value):
    """Convert a compile-time value to ``bool``.

    Raises:
        DSLRuntimeError: If the module-level executor reports ``value`` as a
            dynamic expression.
    """
    if not executor._is_dynamic_expression(value):
        return bool(value)

    raise DSLRuntimeError(
        "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
        suggestion = "Please explicitly convert to boolean with expressions like comparision."
    )
