# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

from typing import Type, Tuple
from enum import Enum

from cutlass.utils.layout import LayoutEnum
from cutlass.cutlass_dsl import (
    Float16,
    BFloat16,
    Float8E5M2,
    Float8E4M3FN,
    Numeric,
    NumericMeta,
    dsl_user_op,
)

import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu.common import CopyUniversalOp
from cutlass.cute.nvgpu.warp import StMatrix8x8x16bOp
from cutlass.cute.nvgpu.warpgroup import (
    MmaF16BF16Op,
    MmaF8Op,
    OperandMajorMode,
    OperandSource,
)

@dsl_user_op
def sm90_get_smem_store_op(
    layout_d: LayoutEnum,
    elem_ty_d: Type[Numeric],
    elem_ty_acc: Type[Numeric],
    *,
    loc=None,
    ip=None,
) -> cute.CopyAtom:
    """Build the copy atom used to store output tensor D to shared memory.

    A ``StMatrix8x8x16bOp`` based atom is chosen when the element type of D is
    16 bits wide; every other width falls back to a ``CopyUniversalOp`` atom.

    :param layout_d: Layout enum of the output tensor D. Only its
        ``is_m_major_c()`` result is used.
    :type layout_d: LayoutEnum
    :param elem_ty_d: Element type of the output tensor D. Its bit width
        selects the store op, and it is the value type of the returned atom.
    :type elem_ty_d: type[Numeric]
    :param elem_ty_acc: Element type of the accumulator. It is type-checked
        but does not influence which atom is selected.
    :type elem_ty_acc: type[Numeric]
    :param loc: Forwarded unchanged to ``cute.make_copy_atom``.
    :param ip: Forwarded unchanged to ``cute.make_copy_atom``.

    :return: Copy atom wrapping either ``StMatrix8x8x16bOp`` or
        ``CopyUniversalOp``.
    :rtype: cute.CopyAtom

    :raises TypeError: If ``elem_ty_d`` or ``elem_ty_acc`` is not a Numeric
        type (i.e. not an instance of ``NumericMeta``).
    """
    # Both type arguments must be Numeric classes; D is checked before the
    # accumulator so the first offending argument is the one reported.
    for arg_name, arg_ty in (("elem_ty_d", elem_ty_d), ("elem_ty_acc", elem_ty_acc)):
        if not isinstance(arg_ty, NumericMeta):
            raise TypeError(f"{arg_name} must be a Numeric, but got {arg_ty}")

    m_major = layout_d.is_m_major_c()

    if elem_ty_d.width != 16:
        # No matrix-store op is selected here for non-16-bit outputs.
        store_op = CopyUniversalOp()
    else:
        # NOTE(review): the arguments are presumably (transpose, num_matrices);
        # confirm against the StMatrix8x8x16bOp definition.
        store_op = StMatrix8x8x16bOp(m_major, 4)

    return cute.make_copy_atom(store_op, elem_ty_d, loc=loc, ip=ip)


class SmemCapacity(Enum):
    """Shared-memory capacity budgets, in bytes, keyed by architecture."""

    # 228 KiB minus 1 KiB, expressed in bytes.
    # NOTE(review): the 1 KiB is presumably the per-block reserved region;
    # confirm against the CUDA documentation for compute capability 9.0.
    SM90_SMEM_CAPACITY_BYTES = 228 * 1024 - 1024


# Lookup table from architecture tag (e.g. "sm90") to the shared-memory
# capacity in bytes, taken from the corresponding SmemCapacity member.
SMEM_CAPACITY = {"sm90": SmemCapacity.SM90_SMEM_CAPACITY_BYTES.value}

def make_trivial_tiled_mma(
    a_dtype: Type[Numeric],
    b_dtype: Type[Numeric],
    a_leading_mode: OperandMajorMode,
    b_leading_mode: OperandMajorMode,
    acc_dtype: Type[Numeric],
    atom_layout_mnk: Tuple[int, int, int],
    tiler_mn: Tuple[int, int],
    a_source: OperandSource = OperandSource.SMEM,
) -> cute.TiledMma:
    """Make a tiled MMA from the operand data types, major modes and MMA tile shape.

    The MMA op is chosen from the operand types: ``MmaF16BF16Op`` when A is
    Float16/BFloat16 (B must then be the same type), or ``MmaF8Op`` when both A
    and B are 8-bit float types (they may differ from each other). By default,
    the MMA atom is created with SMEM operand source for A; pass ``a_source``
    to override it.

    :param a_dtype: Data type of operand A.
    :type a_dtype: type[Numeric]
    :param b_dtype: Data type of operand B.
    :type b_dtype: type[Numeric]
    :param a_leading_mode: Major mode of operand A, forwarded to the MMA op.
    :type a_leading_mode: warpgroup.OperandMajorMode
    :param b_leading_mode: Major mode of operand B, forwarded to the MMA op.
    :type b_leading_mode: warpgroup.OperandMajorMode
    :param acc_dtype: Data type of the accumulator.
    :type acc_dtype: type[Numeric]
    :param atom_layout_mnk: An integer tuple describing the tiling of the atom
        across threads.
    :type atom_layout_mnk: Tuple[int, int, int]
    :param tiler_mn: The (M, N) extent of the MMA instruction shape. The K
        extent is appended here: 16 for Float16/BFloat16 inputs and 32 for
        8-bit float inputs.
    :type tiler_mn: Tuple[int, int]
    :param a_source: Where operand A is read from. Defaults to
        ``OperandSource.SMEM``, which reproduces the previous fixed behavior.
    :type a_source: warpgroup.OperandSource

    :return: A tiled MMA built from the selected atom and ``atom_layout_mnk``.
    :rtype: cute.TiledMma

    :raises TypeError: If A is Float16/BFloat16 and B has a different type, or
        if the (A, B) type combination is not supported.
    """

    if a_dtype in {Float16, BFloat16}:
        # MmaF16BF16Op takes a single input type for both operands, so A and B
        # must be exactly the same type. Identical types necessarily share a
        # width, so no separate width comparison is required.
        if cutlass.const_expr(a_dtype != b_dtype):
            raise TypeError(f"Type mismatch: {a_dtype} != {b_dtype}")

        mma_op = MmaF16BF16Op(
            a_dtype,
            acc_dtype,
            (*tiler_mn, 16),
            a_source,
            a_leading_mode,
            b_leading_mode,
        )
    elif a_dtype in {Float8E4M3FN, Float8E5M2} and b_dtype in {
        Float8E4M3FN,
        Float8E5M2,
    }:
        # The two 8-bit float operand types are passed independently, so mixed
        # E4M3/E5M2 inputs are accepted.
        mma_op = MmaF8Op(
            a_dtype,
            b_dtype,
            acc_dtype,
            (*tiler_mn, 32),
            a_source,
            a_leading_mode,
            b_leading_mode,
        )
    else:
        raise TypeError(f"unsupported a_dtype and b_dtype, got {a_dtype} and {b_dtype}")

    return cute.make_tiled_mma(cute.make_mma_atom(mma_op), atom_layout_mnk)

def get_smem_layout_atom(
    layout: LayoutEnum,
    element_type: Type[Numeric],
    major_mode_size: int,
    *,
    loc=None,
    ip=None,
):
    """Choose the shared memory layout atom kind for a tensor.

    The extent of the major mode is converted to bits
    (``major_mode_size * element_type.width``) and tested, widest first,
    against spans of 1024, 512 and 256 bits. The first span that divides it
    evenly selects the SW128, SW64 or SW32 kind respectively; if none does, the
    INTER kind is returned. The MN-prefixed kinds are used when
    ``layout.sm90_mma_major_mode()`` equals ``OperandMajorMode.MN``, and the
    K-prefixed kinds otherwise.

    :param layout: Layout enum of the tensor
    :type layout: LayoutEnum
    :param element_type: Data type of the elements; only its ``width`` is read
    :type element_type: type[cutlass.Numeric]
    :param major_mode_size: Size of the major mode dimension, in elements.
        Must be a multiple of 8 (checked with an ``assert``).
    :type major_mode_size: int
    :param loc: Accepted for interface uniformity; not used by this function.
    :param ip: Accepted for interface uniformity; not used by this function.

    :return: Selected shared memory layout atom kind
    :rtype: cute.nvgpu.warpgroup.SmemLayoutAtomKind
    """
    assert major_mode_size % 8 == 0

    extent_bits = major_mode_size * element_type.width
    mn_major = layout.sm90_mma_major_mode() == OperandMajorMode.MN

    kinds = cute.nvgpu.warpgroup.SmemLayoutAtomKind
    # Candidates are ordered from the widest contiguous span to the narrowest
    # so that the first match is the largest one that fits.
    if mn_major:
        candidates = (
            (1024, kinds.MN_SW128),
            (512, kinds.MN_SW64),
            (256, kinds.MN_SW32),
        )
        fallback = kinds.MN_INTER
    else:
        candidates = (
            (1024, kinds.K_SW128),
            (512, kinds.K_SW64),
            (256, kinds.K_SW32),
        )
        fallback = kinds.K_INTER

    for span_bits, kind in candidates:
        if extent_bits % span_bits == 0:
            return kind
    return fallback
