# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

"""
This module provides MLIR Arith Dialect helper functions
"""

import array
import numpy as np

from ..common import *
from ..._mlir import ir  # type: ignore
from ..._mlir.extras import types as T  # type: ignore
from ..._mlir.dialects import arith, nvgpu, math, builtin  # type: ignore

from .lru_cache_ir import lru_cache_ir

# =============================================================================
# Arith Dialect Helper functions
# =============================================================================


def recast_type(src_type, res_elem_type) -> ir.Type:
    """
    Return a type with the same container kind as ``src_type`` but with
    ``res_elem_type`` as its element type.

    Vector, ranked tensor, unranked tensor and memref types are rebuilt around
    the new element type; every other type is treated as a scalar, in which
    case ``res_elem_type`` itself is returned.
    """
    if isinstance(src_type, T.VectorType):
        if not src_type.scalable:
            return T.vector(*src_type.shape, res_elem_type)
        # Scalable vectors carry their scalability information over.
        return T.vector(
            *src_type.shape,
            res_elem_type,
            scalable=src_type.scalable,
            scalable_dims=src_type.scalable_dims,
        )

    if isinstance(src_type, T.RankedTensorType):
        return T.RankedTensorType.get(
            element_type=res_elem_type, shape=src_type.shape, strides=src_type.strides
        )

    if isinstance(src_type, T.UnrankedTensorType):
        return T.UnrankedTensorType.get(element_type=res_elem_type)

    if isinstance(src_type, T.MemRefType):
        return T.MemRefType.get(
            element_type=res_elem_type, shape=src_type.shape, strides=src_type.strides
        )

    # Not a known container: the element type is the whole type.
    return res_elem_type


def is_scalar(ty) -> bool:
    """Return True unless ``ty`` is a vector, ranked/unranked tensor or memref type."""
    container_types = (
        T.VectorType,
        T.RankedTensorType,
        T.UnrankedTensorType,
        T.MemRefType,
    )
    return not isinstance(ty, container_types)


def element_type(ty) -> ir.Type:
    """Return ``ty`` itself for scalar types, otherwise the container's element type."""
    return ty if is_scalar(ty) else ty.element_type


def is_narrow_precision(ty) -> bool:
    """Return True when ``ty`` is one of the 8-, 6- or 4-bit float types listed below."""
    narrow_type_factories = (
        T.f8E8M0FNU,
        T.f8E4M3FN,
        T.f8E4M3,
        T.f8E5M2,
        T.f8E4M3B11FNUZ,
        T.f4E2M1FN,
        T.f6E3M2FN,
        T.f6E2M3FN,
    )
    # The types are built on every call and compared by set membership.
    return ty in {make_type() for make_type in narrow_type_factories}


def is_float_type(ty) -> bool:
    """
    Return whether ``ty`` is a floating-point type.

    Extends ``arith._is_float_type`` with the narrow-precision types as well as
    bf16 and tf32.
    """
    upstream_verdict = arith._is_float_type(ty)
    if upstream_verdict:
        return upstream_verdict

    # TODO-upstream: prediction is not correct. Patch here and fix in upstream later
    narrow_verdict = is_narrow_precision(ty)
    if narrow_verdict:
        return narrow_verdict

    return ty in (T.bf16(), T.tf32())


def truncf_to_narrow(res_ty, src, loc, ip):
    """
    Truncate ``src`` to the narrow-precision type ``res_ty`` with
    ``nvgpu.cvt_fptrunc``.

    ``RoundingMode.RP`` is used when the result element type is E8M0 and
    ``RoundingMode.RN`` for every other result type.
    """
    targets_e8m0 = element_type(res_ty) == T.f8E8M0FNU()
    rnd = nvgpu.RoundingMode.RP if targets_e8m0 else nvgpu.RoundingMode.RN
    return nvgpu.cvt_fptrunc(res_ty, src, rnd=rnd, loc=loc, ip=ip)


def extf_from_narrow(res_ty, src, loc, ip):
    """
    Extend the narrow-precision value ``src`` to ``res_ty``.

    The conversion goes through a 16-bit intermediate produced by
    ``nvgpu.cvt_fpext``: bf16 when the source element type is E8M0, f16
    otherwise. The intermediate is then brought to the requested type:

    - intermediate already has the requested element type: returned as is
      (``arith.extf`` requires a strictly wider result, so emitting it here
      would build an invalid op);
    - requested element type is wider: a plain ``arith.extf``;
    - requested element type has the same or a smaller width (e.g. bf16 vs
      f16, or another narrow type): delegated to ``cvtf``, which knows how to
      route those conversions.

    :param res_ty: Requested result type (same container kind as ``src.type``).
    :param src: Value with a narrow-precision element type.
    :param loc: Location forwarded to the emitted ops.
    :param ip: Insertion point forwarded to the emitted ops.
    :return: Value converted to the requested type.
    """
    src_elem_ty = element_type(src.type)
    res_elem_ty = element_type(res_ty)

    # When source type is E8M0, temporary element type has to be bf16
    tmp_elem_ty = T.bf16() if src_elem_ty == T.f8E8M0FNU() else T.f16()
    tmp_ty = recast_type(src.type, tmp_elem_ty)

    # narrow -> bf16/f16
    tmp = nvgpu.cvt_fpext(tmp_ty, src, loc=loc, ip=ip)

    # bf16/f16 -> target type
    if res_elem_ty == tmp_elem_ty:
        return tmp
    if res_elem_ty.width > tmp_elem_ty.width:
        return arith.extf(res_ty, tmp, loc=loc, ip=ip)
    # Same-width or narrower target: `arith.extf` is not applicable.
    return cvtf(tmp, res_elem_ty, loc=loc, ip=ip)


def bitcast(src, res_elem_type, *, loc=None, ip=None):
    """Emit ``arith.bitcast`` of ``src`` to the same container with element type ``res_elem_type``."""
    return arith.bitcast(
        recast_type(src.type, res_elem_type), src, loc=loc, ip=ip
    )


def cvtf(src, res_elem_type, *, loc=None, ip=None):
    """
    Convert the floating-point value ``src`` to element type ``res_elem_type``.

    Returns ``src`` unchanged when the element types already match. TF32 on
    either side is routed through f32 using i32 as the carrier type. Otherwise
    the result is produced by an extension when the result type is wider, or
    by a truncation when it is not (f16 <-> bf16 goes through f32).
    """
    src_elem_type = element_type(src.type)
    if res_elem_type == src_elem_type:
        return src

    res_type = recast_type(src.type, res_elem_type)

    # Treat TF32 as F32 and use i32 as intermediate data
    # TODO-upstream: update arith to support tf32 <-> f32 conversion
    if src_elem_type == T.tf32():
        # tf32 -> i32
        i32_type = recast_type(src.type, T.i32())
        as_i32 = builtin.unrealized_conversion_cast(
            [i32_type], [src], loc=loc, ip=ip
        )
        # i32 -> f32
        as_f32 = bitcast(as_i32, T.f32(), loc=loc, ip=ip)
        # f32 -> X with `cvtf` recursively
        return cvtf(as_f32, res_elem_type, loc=loc, ip=ip)

    if res_elem_type == T.tf32():
        # X -> f32 with `cvtf` recursively
        as_f32 = cvtf(src, T.f32(), loc=loc, ip=ip)
        # f32 -> i32
        as_i32 = bitcast(as_f32, T.i32(), loc=loc, ip=ip)
        # i32 -> tf32
        return builtin.unrealized_conversion_cast(
            [res_type], [as_i32], loc=loc, ip=ip
        )

    # Widening conversions
    if res_elem_type.width > src_elem_type.width:
        if is_narrow_precision(src_elem_type):
            return extf_from_narrow(res_type, src, loc, ip)
        return arith.extf(res_type, src, loc=loc, ip=ip)

    # Same-width or narrowing conversions
    # TODO-upstream: update arith to support bf16 <-> f16 conversion?
    swaps_f16_and_bf16 = (
        src_elem_type == T.f16() and res_elem_type == T.bf16()
    ) or (src_elem_type == T.bf16() and res_elem_type == T.f16())
    if swaps_f16_and_bf16:
        # f16/bf16 -- extf -> f32 -- truncf -> bf16/f16
        f32_type = recast_type(src.type, T.f32())
        widened = arith.extf(f32_type, src, loc=loc, ip=ip)
        return arith.truncf(res_type, widened, loc=loc, ip=ip)

    # f16, f32, ... -> {f8, f6, f4}
    if is_narrow_precision(res_elem_type):
        return truncf_to_narrow(res_type, src, loc, ip)

    return arith.truncf(res_type, src, loc=loc, ip=ip)


def fptoi(src, signed: Union[bool, None], res_elem_type, *, loc=None, ip=None):
    """
    Convert the floating-point value ``src`` to the integer element type
    ``res_elem_type``.

    ``arith.fptosi`` is emitted when ``signed`` is truthy, ``arith.fptoui``
    otherwise (including ``signed=None``). tf32 and bf16 sources are first
    converted to f32.
    """
    res_type = recast_type(src.type, res_elem_type)

    # TODO-upstream: update arith to support this kind of conversion
    if element_type(src.type) in (T.tf32(), T.bf16()):
        src = cvtf(src, T.f32(), loc=loc, ip=ip)

    convert = arith.fptosi if signed else arith.fptoui
    return convert(res_type, src, loc=loc, ip=ip)


def itofp(src, signed: Union[bool, None], res_elem_type, *, loc=None, ip=None):
    """
    Convert the integer value ``src`` to the floating-point element type
    ``res_elem_type``.

    ``arith.sitofp`` is emitted only when ``signed`` is truthy and the source
    element is wider than one bit; otherwise ``arith.uitofp`` is used. For
    tf32 and bf16 results the integer is first converted to f32 and then
    passed through ``cvtf``.
    """
    requested_type = recast_type(src.type, res_elem_type)

    # TODO-upstream: update arith to support this kind of conversion
    if res_elem_type in (T.tf32(), T.bf16()):
        emitted_type = recast_type(src.type, T.f32())
    else:
        emitted_type = requested_type

    treat_as_signed = signed and element_type(src.type).width > 1
    convert = arith.sitofp if treat_as_signed else arith.uitofp
    converted = convert(emitted_type, src, loc=loc, ip=ip)

    if requested_type == emitted_type:
        return converted

    # Finish the f32 detour by converting to the originally requested type.
    return cvtf(converted, element_type(requested_type), loc=loc, ip=ip)


def int_to_int(a, dst_elem_type, *, loc=None, ip=None):
    """
    Convert the integer value ``a`` to the DSL integer type ``dst_elem_type``.

    ``a`` must expose ``signed`` and ``type``; ``dst_elem_type`` must expose
    ``signed``, ``width`` and ``mlir_type``. Equal widths return ``a``
    unchanged. Narrowing always uses ``arith.trunci``; widening uses
    ``arith.extsi`` only when both sides are signed and the source is wider
    than one bit, and ``arith.extui`` in every other case.
    """
    src_signed = a.signed
    dst_signed = dst_elem_type.signed
    src_width = element_type(a.type).width
    dst_width = dst_elem_type.width

    dst_mlir_type = recast_type(a.type, dst_elem_type.mlir_type)

    if dst_width == src_width:
        return a

    widening = dst_width > src_width

    # Signed -> Unsigned
    if src_signed and not dst_signed:
        resize = arith.extui if widening else arith.trunci
        return resize(dst_mlir_type, a, loc=loc, ip=ip)

    # Same signedness
    if src_signed == dst_signed:
        if not widening:
            return arith.trunci(dst_mlir_type, a, loc=loc, ip=ip)
        if src_signed and src_width > 1:
            return arith.extsi(dst_mlir_type, a, loc=loc, ip=ip)
        return arith.extui(dst_mlir_type, a, loc=loc, ip=ip)

    # Unsigned -> Signed
    if widening:
        return arith.extui(dst_mlir_type, a, loc=loc, ip=ip)

    # Truncate to the target width first, then reinterpret as signed when the
    # destination is signed.
    truncated = arith.trunci(dst_mlir_type, a, loc=loc, ip=ip)
    if dst_signed:
        return arith.bitcast(dst_mlir_type, truncated, loc=loc, ip=ip)
    return truncated


# =============================================================================
# Arith Ops Emitter Helpers
#   - assuming type of lhs and rhs match each other
#   - op name matches python module operator
# =============================================================================


def _cast(res_elem_ty, src, is_signed=None, *, loc=None, ip=None):
    """
    Cast ``src`` to element type ``res_elem_ty`` while keeping its shape.

    This is a simplified interface to the upstream op builders. Casts are
    element-wise and cannot change the shape, so instead of
        arith.truncf(T.vector(shape, new_type), src)

    callers only name the new element type
        _cast(new_type, src)

    ``src`` may be an ``ir.Value`` or an object providing ``ir_value()`` whose
    class provides ``mlir_type``. ``is_signed`` selects signed or unsigned
    variants of the integer conversions. Raises ``DSLRuntimeError`` when the
    element types are neither float nor integer-like.
    """
    if not isinstance(src, ir.Value):
        src_ty = type(src).mlir_type
        src = src.ir_value()
    else:
        src_ty = src.type

    src_elem_ty = element_type(src_ty)
    if src_elem_ty == res_elem_ty:
        return src

    src_is_float = is_float_type(src_elem_ty)
    res_is_float = is_float_type(res_elem_ty)
    src_is_int = arith._is_integer_like_type(src_elem_ty)
    res_is_int = arith._is_integer_like_type(res_elem_ty)

    if src_is_float and res_is_float:
        # float-to-float
        return cvtf(src, res_elem_ty, loc=loc, ip=ip)

    if src_is_int and res_is_int:
        # int-to-int: truncate unless the result is strictly wider
        if src_elem_ty.width >= res_elem_ty.width:
            cast_op = arith.trunci
        elif is_signed:
            cast_op = arith.extsi
        else:
            cast_op = arith.extui
        return cast_op(recast_type(src_ty, res_elem_ty), src, loc=loc, ip=ip)

    if src_is_float and res_is_int:
        return fptoi(src, is_signed, res_elem_ty, loc=loc, ip=ip)

    if src_is_int and res_is_float:
        return itofp(src, is_signed, res_elem_ty, loc=loc, ip=ip)

    raise DSLRuntimeError(
        f"cast from {src_elem_ty} to {res_elem_ty} is not supported"
    )


@lru_cache_ir()
def const(value, ty=None, *, loc=None, ip=None):
    """
    Generates dynamic expression for constant values.

    ``value`` may be a Python scalar, a ``Numeric`` (unwrapped to its
    ``value``), a NumPy array, or an already dynamic expression. ``ty`` may be
    ``None`` (the type is inferred from ``value``), a ``NumericMeta`` class or
    an ``ir.Type``. Raises ``DSLNotImplemented`` for unsupported values or
    types.
    """
    from ..typing import Numeric, NumericMeta
    from ..dsl import is_dynamic_expression, _numpy_type_to_mlir_type

    if isinstance(value, Numeric):
        value = value.value

    # Early return for values that are already dynamic expressions.
    # NOTE(review): the first clause tests `value.type` against itself, so it
    # does not depend on `ty` -- confirm whether a check against `ty` was meant.
    if is_dynamic_expression(value) and (
        value.type.isinstance(value.type) or T.bool().isinstance(value.type)
    ):
        return value

    if ty is None:
        # Infer the type from the Python value. `bool` is tested before `int`
        # because `bool` is a subclass of `int`.
        if isinstance(value, float):
            ty = T.f32()
        elif isinstance(value, bool):
            ty = T.bool()
        elif isinstance(value, int):
            ty = T.i32()
        elif isinstance(value, np.ndarray):
            ty = T.vector(*value.shape, _numpy_type_to_mlir_type(value.dtype))
            value = array.array(value.dtype.kind, value.flatten().tolist())
        else:
            raise DSLNotImplemented(f"{type(value)} is not supported")
    elif isinstance(ty, NumericMeta):
        ty = ty.mlir_type
    elif not isinstance(ty, ir.Type):
        raise DSLNotImplemented(f"type {ty} is not supported")
    elif ir.RankedTensorType.isinstance(ty) or ir.VectorType.isinstance(ty):
        # Shaped result: broadcast the scalar as a splat attribute.
        elem_ty = ty.element_type
        attr_builder = (
            ir.IntegerAttr if isinstance(elem_ty, ir.IntegerType) else ir.FloatAttr
        )
        value = ir.DenseElementsAttr.get_splat(ty, attr_builder.get(elem_ty, value))
    elif arith._is_float_type(ty) and isinstance(value, (bool, int)):
        # Integer literal for a float type
        value = float(value)
    elif arith._is_integer_like_type(ty) and isinstance(value, float):
        # Float literal for an integer type
        value = int(value)

    return arith.constant(ty, value, loc=loc, ip=ip)


def _dispatch_to_rhs_r_op(op):
    """Decorator that defers to the right-hand side's reflected operation.

    The wrapped operator runs only when ``other`` is an ``ArithValue`` or a
    Python ``int``/``float``/``bool``. For any other operand the wrapper
    returns ``NotImplemented`` so that Python tries the operand's own
    reflected method (e.g. ``other.__rmul__``).
    """

    def wrapper(self, other, **kwargs):
        handled_here = isinstance(other, ArithValue) or isinstance(
            other, (int, float, bool)
        )
        if not handled_here:
            # allows to call other.__rmul__
            return NotImplemented

        return op(self, other, **kwargs)

    return wrapper


def _binary_op(op):
    """
    Decorator to check if the 'other' argument is an ArithValue.
    If not, returns NotImplemented.
    """

    def wrapper(self, other, **kwargs):
        # When reach this point, `self` must be cast to base `ArithValue` type
        if isinstance(other, (int, float, bool)):
            other = const(other, self.type).with_signedness(self.signed)

        # Call the original function
        # If sub-class doesn't implement overloaded arithmetic, cast to base class
        return op(self, other, **kwargs)

    return wrapper


# Operator overloading
@ir.register_value_caster(ir.Float4E2M1FNType.static_typeid)
@ir.register_value_caster(ir.Float6E2M3FNType.static_typeid)
@ir.register_value_caster(ir.Float6E3M2FNType.static_typeid)
@ir.register_value_caster(ir.Float8E4M3FNType.static_typeid)
@ir.register_value_caster(ir.Float8E4M3B11FNUZType.static_typeid)
@ir.register_value_caster(ir.Float8E5M2Type.static_typeid)
@ir.register_value_caster(ir.Float8E4M3Type.static_typeid)
@ir.register_value_caster(ir.Float8E8M0FNUType.static_typeid)
@ir.register_value_caster(ir.BF16Type.static_typeid)
@ir.register_value_caster(ir.F16Type.static_typeid)
@ir.register_value_caster(ir.FloatTF32Type.static_typeid)
@ir.register_value_caster(ir.F32Type.static_typeid)
@ir.register_value_caster(ir.F64Type.static_typeid)
@ir.register_value_caster(ir.IntegerType.static_typeid)
@ir.register_value_caster(ir.VectorType.static_typeid)
@ir.register_value_caster(ir.RankedTensorType.static_typeid)
class ArithValue(ir.Value):
    """Overloads operators for MLIR's Arith dialects binary operations.

    Besides the wrapped value, each instance records:

    - ``is_float``: whether the element type is a floating-point type;
    - ``signed``: the DSL-level signedness used to pick between signed and
      unsigned integer ops (falsy means unsigned).

    Binary operators accept another ``ArithValue`` or a Python
    ``int``/``float``/``bool`` (promoted to a constant of ``self.type``). All
    operator methods take optional keyword-only ``loc`` and ``ip`` arguments
    that are forwarded to the emitted ops.
    """

    def __init__(self, v, signed: Union[bool, None] = None):
        # NOTE(review): `self.type` is read here before `super().__init__` has
        # run -- confirm this integer path is valid for the bindings in use.
        if isinstance(v, int):
            v = arith.constant(self.type, v)
        super().__init__(v)

        elem_ty = element_type(self.type)
        # Use this module's `is_float_type` rather than the raw upstream
        # predicate so that narrow-precision and tf32 values (for which this
        # class is registered) are not routed to the integer ops.
        self.is_float = is_float_type(elem_ty)
        # arith dialect consider `1` in `i1` as `-1`, treat it as unsigned for DSL
        self.signed = signed and elem_ty.width > 1

    def with_signedness(self, signed: Union[bool, None]):
        """Return a new wrapper of the same class around this value with ``signed``."""
        return type(self)(self, signed)

    def __neg__(self, *, loc=None, ip=None):
        """Emit ``-self``: ``negf`` for floats, ``0 - self`` for integers."""
        if self.type == T.bool():
            raise TypeError(
                "Negation, the operator `-` is not supported for boolean type"
            )

        if self.is_float:
            return arith.negf(self, loc=loc, ip=ip)
        else:
            c0 = arith.constant(self.type, 0, loc=loc, ip=ip)
            return arith.subi(c0, self, loc=loc, ip=ip)

    @_binary_op
    def __pow__(self, other, *, loc=None, ip=None) -> "ArithValue":
        """Emit ``self ** other``, choosing the math op from operand floatness."""
        if self.is_float and other.is_float:
            return math.powf(self, other, loc=loc, ip=ip)
        elif self.is_float and not other.is_float:
            return math.fpowi(self, other, loc=loc, ip=ip)
        elif not self.is_float and other.is_float:
            # integer base with float exponent: compute in f32
            lhs = itofp(self, self.signed, T.f32(), loc=loc, ip=ip)
            rhs = cvtf(other, T.f32(), loc=loc, ip=ip)
            return math.powf(lhs, rhs, loc=loc, ip=ip)
        elif not self.is_float and not other.is_float:
            return math.ipowi(self, other, loc=loc, ip=ip)
        else:
            raise DSLNotImplemented(f"Unsupported '{self} ** {other}'")

    @_binary_op
    def __rpow__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__pow__(self, loc=loc, ip=ip)

    # arith operators

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __add__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.addf(self, other, loc=loc, ip=ip)
        else:
            return arith.addi(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __sub__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.subf(self, other, loc=loc, ip=ip)
        else:
            return arith.subi(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __mul__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.mulf(self, other, loc=loc, ip=ip)
        else:
            return arith.muli(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __truediv__(self, other, *, loc=None, ip=None) -> "ArithValue":
        """Emit true division; integer operands are converted to f32 first."""
        if self.is_float:
            return arith.divf(self, other, loc=loc, ip=ip)
        else:
            lhs = itofp(self, self.signed, T.f32(), loc=loc, ip=ip)
            rhs = itofp(other, other.signed, T.f32(), loc=loc, ip=ip)
            return arith.divf(lhs, rhs, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __floordiv__(self, other, *, loc=None, ip=None) -> "ArithValue":
        """Emit floor division: ``floor(divf)``, ``floordivsi`` or ``divui``."""
        if self.is_float:
            q = arith.divf(self, other, loc=loc, ip=ip)
            return math.floor(q, loc=loc, ip=ip)
        elif self.signed:
            return arith.floordivsi(self, other, loc=loc, ip=ip)
        else:
            return arith.divui(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __mod__(self, other, *, loc=None, ip=None) -> "ArithValue":
        # NOTE(review): signed integers use `remsi` here while `__floordiv__`
        # uses `floordivsi` -- confirm the pairing is intended for negative
        # operands.
        if self.is_float:
            return arith.remf(self, other, loc=loc, ip=ip)
        elif self.signed:
            return arith.remsi(self, other, loc=loc, ip=ip)
        else:
            return arith.remui(self, other, loc=loc, ip=ip)

    # Reflected arithmetic operators: `other` has been promoted to a constant
    # by `_binary_op`, so the forward operator of `other` is reused.

    @_binary_op
    def __radd__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__add__(self, loc=loc, ip=ip)

    @_binary_op
    def __rsub__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__sub__(self, loc=loc, ip=ip)

    @_binary_op
    def __rmul__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__mul__(self, loc=loc, ip=ip)

    @_binary_op
    def __rtruediv__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__truediv__(self, loc=loc, ip=ip)

    @_binary_op
    def __rfloordiv__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__floordiv__(self, loc=loc, ip=ip)

    @_binary_op
    def __rmod__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__mod__(self, loc=loc, ip=ip)

    # Comparison operators (comparison doesn't have right-hand-side variants)
    @_dispatch_to_rhs_r_op
    @_binary_op
    def __lt__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.cmpf(arith.CmpFPredicate.OLT, self, other, loc=loc, ip=ip)
        elif self.signed:
            return arith.cmpi(arith.CmpIPredicate.slt, self, other, loc=loc, ip=ip)
        else:
            return arith.cmpi(arith.CmpIPredicate.ult, self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __le__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.cmpf(arith.CmpFPredicate.OLE, self, other, loc=loc, ip=ip)
        elif self.signed:
            return arith.cmpi(arith.CmpIPredicate.sle, self, other, loc=loc, ip=ip)
        else:
            return arith.cmpi(arith.CmpIPredicate.ule, self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __eq__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.cmpf(arith.CmpFPredicate.OEQ, self, other, loc=loc, ip=ip)
        else:
            return arith.cmpi(arith.CmpIPredicate.eq, self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __ne__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            # In Python, bool(float("nan")) is True, so use unordered comparison here
            return arith.cmpf(arith.CmpFPredicate.UNE, self, other, loc=loc, ip=ip)
        else:
            return arith.cmpi(arith.CmpIPredicate.ne, self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __gt__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.cmpf(arith.CmpFPredicate.OGT, self, other, loc=loc, ip=ip)
        elif self.signed:
            return arith.cmpi(arith.CmpIPredicate.sgt, self, other, loc=loc, ip=ip)
        else:
            return arith.cmpi(arith.CmpIPredicate.ugt, self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __ge__(self, other, *, loc=None, ip=None) -> "ArithValue":
        if self.is_float:
            return arith.cmpf(arith.CmpFPredicate.OGE, self, other, loc=loc, ip=ip)
        elif self.signed:
            return arith.cmpi(arith.CmpIPredicate.sge, self, other, loc=loc, ip=ip)
        else:
            return arith.cmpi(arith.CmpIPredicate.uge, self, other, loc=loc, ip=ip)

    # Unary operators
    def __invert__(self, *, loc=None, ip=None) -> "ArithValue":
        """Emit bitwise NOT as ``self ^ -1``."""
        # Forward `loc`/`ip` to both ops so they honour the caller's location
        # and insertion point, like every other operator of this class.
        all_ones = arith.constant(self.type, -1, loc=loc, ip=ip)
        return arith.xori(self, all_ones, loc=loc, ip=ip)

    # Bitwise operations
    @_dispatch_to_rhs_r_op
    @_binary_op
    def __and__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return arith.andi(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __or__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return arith.ori(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __xor__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return arith.xori(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __rshift__(self, other, *, loc=None, ip=None) -> "ArithValue":
        # `shrsi` for signed values, `shrui` otherwise
        if self.signed:
            return arith.shrsi(self, other, loc=loc, ip=ip)
        else:
            return arith.shrui(self, other, loc=loc, ip=ip)

    @_dispatch_to_rhs_r_op
    @_binary_op
    def __lshift__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return arith.shli(self, other, loc=loc, ip=ip)

    @_binary_op
    def __rand__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return arith.andi(other, self, loc=loc, ip=ip)

    @_binary_op
    def __ror__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return arith.ori(other, self, loc=loc, ip=ip)

    @_binary_op
    def __rxor__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return arith.xori(other, self, loc=loc, ip=ip)

    @_binary_op
    def __rrshift__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__rshift__(self, loc=loc, ip=ip)

    @_binary_op
    def __rlshift__(self, other, *, loc=None, ip=None) -> "ArithValue":
        return other.__lshift__(self, loc=loc, ip=ip)

    def __hash__(self):
        # Defining `__eq__` would otherwise make instances unhashable, so the
        # base class hash is restored explicitly.
        return super().__hash__()

    def __str__(self):
        # Show this class's name in place of the base class's name.
        return super().__str__().replace(ir.Value.__name__, ArithValue.__name__)

    def __repr__(self):
        return self.__str__()


def _min(lhs, rhs, *, loc=None, ip=None):
    """
    This function provides a unified interface for building arith min

    Assuming the operands have the same type. When neither operand is a
    dynamic expression the Python ``min`` is returned directly. When exactly
    one operand is a Python constant it is materialized as a constant of the
    other operand's type.

    For integer-like types the signed or unsigned op is chosen from the
    signedness of the dynamic operand. A constant materialized here carries no
    signedness of its own, so it must not be the one consulted: otherwise a
    Python constant on the left would always select the unsigned op.

    :param lhs: Dynamic expression or Python number.
    :param rhs: Dynamic expression or Python number.
    :param loc: Location forwarded to the emitted ops.
    :param ip: Insertion point forwarded to the emitted ops.
    :return: Python number (both static) or the result of the emitted min op.
    """
    from ..dsl import is_dynamic_expression

    lhs_is_dynamic = is_dynamic_expression(lhs)
    rhs_is_dynamic = is_dynamic_expression(rhs)

    if not lhs_is_dynamic and not rhs_is_dynamic:
        return min(lhs, rhs)

    # Operand whose DSL signedness decides between the signed/unsigned op.
    signedness_source = lhs if lhs_is_dynamic else rhs

    if not lhs_is_dynamic:
        lhs = arith.constant(rhs.type, lhs, loc=loc, ip=ip)
    elif not rhs_is_dynamic:
        rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip)

    if arith._is_integer_like_type(lhs.type):
        if signedness_source.signed:
            return arith.minsi(lhs, rhs, loc=loc, ip=ip)
        return arith.minui(lhs, rhs, loc=loc, ip=ip)

    return arith.minimumf(lhs, rhs, loc=loc, ip=ip)


def _max(lhs, rhs, *, loc=None, ip=None):
    """
    This function provides a unified interface for building arith max

    Assuming the operands have the same type. When neither operand is a
    dynamic expression the Python ``max`` is returned directly. When exactly
    one operand is a Python constant it is materialized as a constant of the
    other operand's type.

    For integer-like types the signed or unsigned op is chosen from the
    signedness of the dynamic operand. A constant materialized here carries no
    signedness of its own, so it must not be the one consulted: otherwise a
    Python constant on the left would always select the unsigned op.

    :param lhs: Dynamic expression or Python number.
    :param rhs: Dynamic expression or Python number.
    :param loc: Location forwarded to the emitted ops.
    :param ip: Insertion point forwarded to the emitted ops.
    :return: Python number (both static) or the result of the emitted max op.
    """
    from ..dsl import is_dynamic_expression

    lhs_is_dynamic = is_dynamic_expression(lhs)
    rhs_is_dynamic = is_dynamic_expression(rhs)

    if not lhs_is_dynamic and not rhs_is_dynamic:
        return max(lhs, rhs)

    # Operand whose DSL signedness decides between the signed/unsigned op.
    signedness_source = lhs if lhs_is_dynamic else rhs

    if not lhs_is_dynamic:
        lhs = arith.constant(rhs.type, lhs, loc=loc, ip=ip)
    elif not rhs_is_dynamic:
        rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip)

    if arith._is_integer_like_type(lhs.type):
        if signedness_source.signed:
            return arith.maxsi(lhs, rhs, loc=loc, ip=ip)
        return arith.maxui(lhs, rhs, loc=loc, ip=ip)

    return arith.maximumf(lhs, rhs, loc=loc, ip=ip)
