# Owner(s): ["module: inductor"]
import contextlib
import sys
import unittest

import torch
from torch._inductor import config
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    MACOS_VERSION,
    parametrize,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, RUN_CPU, RUN_GPU


try:
    try:
        from . import test_torchinductor
    except ImportError:
        import test_torchinductor  # @manual=fbcode//caffe2/test/inductor:test_inductor-library
except unittest.SkipTest:
    if __name__ == "__main__":
        sys.exit(0)
    raise

# Pull the shared inductor test-harness pieces out of test_torchinductor under
# short local names; everything below this point refers to them directly.
(
    TestCase,
    check_model,
    check_model_gpu,
    skip_if_cpp_wrapper,
    copy_tests,
    define_custom_op_for_test,
) = (
    test_torchinductor.TestCase,
    test_torchinductor.check_model,
    test_torchinductor.check_model_gpu,
    test_torchinductor.skip_if_cpp_wrapper,
    test_torchinductor.copy_tests,
    test_torchinductor.define_custom_op_for_test,
)


@instantiate_parametrized_tests
class CommonTemplate:
    """Device-agnostic inductor tests for inputs/intermediates that are not
    aligned to the start of their storage.

    This is not a TestCase itself: copy_tests() at the bottom of this file
    copies these methods into concrete TestCase subclasses that supply
    ``self.device`` and ``self.common`` (the model-checking helper).
    """

    def test_unaligned_input(self):
        def fn(x):
            return torch.nn.functional.relu(x)

        # Slicing off the first element gives a graph input whose data
        # pointer sits one element (storage_offset=1) past its storage start.
        x = torch.randn(1024 + 16, device=self.device)[1:-15]
        # TODO (malfet): Investigate failures on MacOS-14
        with (
            contextlib.nullcontext()
            if self.device != "mps" or MACOS_VERSION >= 15.0
            else self.assertRaises(AssertionError)
        ):
            self.common(fn, (x,), check_lowp=False)

    def test_unaligned_input_2d(self):
        def fn(x):
            return torch.nn.functional.relu(x)

        # 2-D variant: each row view starts one element into the storage row.
        x = torch.randn(1024, 1024 + 16, device=self.device)[:, 1:-15]
        self.common(fn, (x,), check_lowp=False)

    def test_alignment_without_custom_op(self):
        # The unaligned slice is produced *inside* the graph (by a regular
        # aten op) rather than being a graph input.
        def fn(x):
            a = torch.nn.functional.relu(x)
            b = (3 * a)[1:-15]
            c = torch.cos(b)
            return c

        x = torch.randn(1024 + 16, device=self.device)
        self.common(fn, (x,), check_lowp=False)

    @config.patch(implicit_fallbacks=True)
    def test_no_align_for_custom_op(self):
        # The custom op returns a slice with storage_offset=1; the downstream
        # cos must not assume its input buffer is aligned.
        def slice1d(x):
            return (3 * x)[1:-15]

        def slice1d_meta(x):
            return torch.empty_like(x)[1:-15]

        define_custom_op_for_test("slice1d", slice1d, slice1d_meta)

        def fn(x):
            a = torch.nn.functional.relu(x)
            b = torch.ops.test.slice1d(a)
            c = torch.cos(b)
            return c

        x = torch.randn(1024 + 16, device=self.device)
        self.common(fn, (x,), check_lowp=False)

    @config.patch(implicit_fallbacks=True)
    def test_no_align_for_custom_op_2d(self):
        # 2-D version of test_no_align_for_custom_op (slice on the last dim).
        def slice2d(x):
            return (3 * x)[..., 1:-15]

        def slice2d_meta(x):
            return torch.empty_like(x)[..., 1:-15]

        define_custom_op_for_test("slice2d", slice2d, slice2d_meta)

        def fn(x):
            a = torch.nn.functional.relu(x)
            b = torch.ops.test.slice2d(a)
            c = torch.cos(b)
            return c

        x = torch.randn(1024, 1024 + 16, device=self.device)
        self.common(fn, (x,), check_lowp=False)

    @config.patch(implicit_fallbacks=True, alignment_asserts=True)
    @skip_if_cpp_wrapper(
        "Inductor does not generate alignment assertion for cpp_wrapper right now"
    )
    def test_incorrect_meta_for_custom_op_2d(self):
        # The meta function claims storage offset 0 ([..., 0:-16]) while the
        # real kernel returns offset 1 ([..., 1:-15]). With alignment_asserts
        # enabled, the generated code is expected to catch the mismatch.
        def slice2d(x):
            return (3 * x)[..., 1:-15]

        def slice2d_meta(x):
            return torch.empty_like(x)[..., 0:-16]

        define_custom_op_for_test("slice2d_incorrect_meta", slice2d, slice2d_meta)

        def fn(x):
            a = torch.nn.functional.relu(x)
            b = torch.ops.test.slice2d_incorrect_meta(a)
            c = torch.cos(b)
            return c

        x = torch.randn(1024, 1024 + 16, device=self.device)

        expected_error = "Expect the tensor to be 16 bytes aligned. Fail due to storage_offset=1 itemsize=4"
        with self.assertRaisesRegex(AssertionError, expected_error):
            self.common(fn, (x,), check_lowp=False)

    def test_slice(self):
        def f(x):
            return x[1:] + 1

        x = torch.randn(1025, device=self.device)
        self.common(f, (x,))

    def test_view_dtype_slice(self):
        # Reinterpret bf16 pairs as float32 first, then slice one float32 off.
        def f(x):
            return x.view(dtype=torch.float32)[1:] + 1

        x = torch.randn(1025 * 2, dtype=torch.bfloat16, device=self.device)
        self.common(f, (x,), reference_in_float=False)

    @parametrize(
        "size",
        (
            # wrapper for size = 128: https://gist.github.com/shunting314/88f1e72957b9fc5e9826aaa346a0e652
            # ptx: https://gist.github.com/shunting314/eb657ee8821eef9f0685b7b91e2ad5c2
            # the ptx file uses ld.global.b32 to load input buffer
            128,
            # wrapper for size = 1024: https://gist.github.com/shunting314/d7f64e1f52f6b1e2ec25e1a51052ce43
            # ptx: https://gist.github.com/shunting314/a24ff7563bb6b04523d11b119ab0f2b2
            # the ptx file uses ld.global.v2.b32 to load input buffer
            1024,
            # wrapper for size = 1024 * 1024: https://gist.github.com/shunting314/016b95cf0b6e9a75c25f5c9d5ed0a2ba
            # ptx: https://gist.github.com/shunting314/360112a4893c759b114c12fc99958297
            # the ptx file uses ld.global.v4.b32 to load input buffer
            1024 * 1024,
        ),
    )
    def test_slice_view_dtype(self, size):
        # Offset of the float32 view, in float32 elements. Two bf16 elements
        # make up one float32, so the bf16 slice starts at 2 * offset.
        offset = 1

        def f(x):
            return x[2 * offset :].view(dtype=torch.float32) + 1

        x = torch.randn((size + offset) * 2, dtype=torch.bfloat16, device=self.device)
        self.common(f, (x,), reference_in_float=False)

    def test_Q4_K_dequantization(self):
        """
        Test the alignment issue for Q4_K dequantization.
        """

        QK_K = 256
        K_SCALE_SIZE = 12
        # Bytes per block as laid out by split_block_dims below:
        # 2 (fp16 d) + 2 (fp16 dmin) + K_SCALE_SIZE (packed scales)
        # + QK_K // 2 (two 4-bit values per byte) == 144.
        Q4_K_TYPE_SIZE = 2 + 2 + K_SCALE_SIZE + QK_K // 2

        def get_scale_min(scales):
            # Unpack the 12 packed scale bytes into 8 scales and 8 mins per block.
            n_blocks = scales.shape[0]
            scales = scales.view(torch.uint8)
            scales = scales.reshape((n_blocks, 3, 4))

            d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)

            sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
            mins = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)

            return (sc.reshape((n_blocks, 8)), mins.reshape((n_blocks, 8)))

        def split_block_dims(blocks, *args):
            # Split each block row into the given leading byte widths plus
            # the remainder.
            n_max = blocks.shape[1]
            dims = [*args, n_max - sum(args)]
            return torch.split(blocks, dims, dim=1)

        def dequantize_blocks_Q4_K(blocks, block_size, type_size):
            n_blocks = blocks.shape[0]

            # The splits produce views at non-zero byte offsets, which are then
            # reinterpreted as float16; this is the unaligned access under test.
            d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
            d = d.view(torch.float16)
            dmin = dmin.view(torch.float16)

            sc, m = get_scale_min(scales)

            d = (d * sc).reshape((n_blocks, -1, 1))
            dm = (dmin * m).reshape((n_blocks, -1, 1))

            # Extract the low and high nibble of every quant byte.
            qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor(
                [0, 4], device=d.device, dtype=torch.uint8
            ).reshape((1, 1, 2, 1))
            qs = (qs & 0x0F).reshape((n_blocks, -1, 32))

            return (d * qs - dm).reshape((n_blocks, QK_K))

        n_rows, row_bytes = 18432, 1728
        data = torch.randint(
            0, 16, (n_rows, row_bytes), device=self.device, dtype=torch.uint8
        )

        def dequantize(data):
            block_size, type_size = QK_K, Q4_K_TYPE_SIZE
            rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
            n_blocks = rows.numel() // type_size
            blocks = rows.reshape((n_blocks, type_size))
            blocks = dequantize_blocks_Q4_K(blocks, block_size, type_size)
            # Every row of row_bytes bytes holds row_bytes // type_size blocks,
            # and each block expands to block_size values (18432 x 3072 here).
            return blocks.reshape(n_rows, row_bytes // type_size * block_size)

        self.common(dequantize, (data,), check_lowp=False, atol=1e-3, rtol=1e-3)


if RUN_CPU:
    # Concrete CPU suite: CommonTemplate's tests are copied in (tagged "cpu")
    # and run through check_model on the "cpu" device.
    class CpuTests(TestCase):
        device = "cpu"
        common = check_model

    copy_tests(CommonTemplate, CpuTests, "cpu")

if RUN_GPU:
    # Concrete GPU suite: CommonTemplate's tests are copied in (tagged with
    # GPU_TYPE) and run through check_model_gpu on that device.
    class GPUTests(TestCase):
        device = GPU_TYPE
        common = check_model_gpu

    copy_tests(CommonTemplate, GPUTests, GPU_TYPE)

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # Only launch the runner when at least one backend suite was defined above.
    if RUN_GPU or RUN_CPU:
        run_tests()
