# Owner(s): ["module: inductor"]

import math
import os

import torch
import torch._inductor.config as inductor_config
import torch.nn.functional as F
from torch._dynamo.utils import rmse, same
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_LINUX,
    parametrize,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA


# DO_PERF_TEST=1: use GPT-2-sized shapes in do_test_acc_and_perf and print
# eager vs. compiled timings. It also implies USE_LARGE_INPUT.
DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1"
# USE_LARGE_INPUT=1: use the large (32768 x 50257) shapes in the fp64
# accuracy-comparison tests without doing any benchmarking.
USE_LARGE_INPUT = DO_PERF_TEST or os.environ.get("USE_LARGE_INPUT") == "1"


def _prepare_softmax(x, dim):
    xmax = x.amax(dim=dim, keepdim=True)
    xsum = (x - xmax).exp().sum(dim=dim, keepdim=True)
    return xmax, xsum


class TestOnlineSoftmax(TestCase):
    """Tests for Inductor's online softmax codegen.

    With online softmax enabled, the max and sum-of-exp reductions are fused,
    so a looped (log_)softmax kernel needs 2 reduction loops instead of 3.
    These tests check both the generated code and the numerical accuracy.
    """

    def do_test_acc_and_perf(self, op):
        """Check that ``torch.compile`` of ``op(x, dim=-1)`` matches eager mode.

        ``x`` is a bf16 tensor on GPU_TYPE. When DO_PERF_TEST is set, a large
        GPT-2-like shape is used and eager/compiled timings are printed.
        """
        if DO_PERF_TEST:
            N = 32 * 1024
            V = 50304  # padded version for gpt2
        else:
            N, V = 1024, 2048  # small values to avoid OOM in CI

        def f(x):
            return op(x, dim=-1)

        x = torch.randn(N, V, dtype=torch.bfloat16, device=GPU_TYPE)
        opt_f = torch.compile(f)
        expected = f(x)
        actual = opt_f(x)

        self.assertTrue(same(expected, actual, tol=1e-2))

        if DO_PERF_TEST:
            from triton.testing import do_bench

            eager_ms = do_bench(lambda: f(x))
            opt_ms = do_bench(lambda: opt_f(x))
            print(f"{eager_ms=}")
            print(f"{opt_ms=}")

    def test_softmax(self):
        self.do_test_acc_and_perf(torch.softmax)

    def test_log_softmax(self):
        self.do_test_acc_and_perf(torch.log_softmax)

    @inductor_config.patch(use_fast_math=True)
    def test_prepare_softmax_perf(self):
        self.do_test_acc_and_perf(_prepare_softmax)

    def get_softmax_wrapper(self, V=50304, use_log_softmax=False, device=GPU_TYPE):
        """Compile (log_)softmax over the last dim and return its generated code.

        The input is a bf16 tensor of shape ``(32 * 1024, V)`` on ``device``.
        Returns the first source string captured by ``run_and_get_code``.
        """
        N = 32 * 1024

        @torch.compile
        def f(x):
            if use_log_softmax:
                return torch.log_softmax(x, dim=-1)
            else:
                return torch.softmax(x, dim=-1)

        x = torch.randn(N, V, dtype=torch.bfloat16, device=device)
        # Only the generated code is of interest here; discard the output.
        _, source_codes = run_and_get_code(f, x)
        return source_codes[0]

    def test_codegen_3pass_softmax_due_to_disable(self):
        with inductor_config.patch(online_softmax=False):
            wrapper_code = self.get_softmax_wrapper()

        self.assertEqual(wrapper_code.count("for r0_offset in"), 3)

    @parametrize("V", [2048, 50304])
    @parametrize("use_log_softmax", [False, True])
    def test_codegen_online_softmax(self, use_log_softmax, V):
        wrapper_code = self.get_softmax_wrapper(use_log_softmax=use_log_softmax, V=V)

        self.assertEqual(wrapper_code.count("for r0_offset in"), 2)

    def test_no_online_softmax_for_cpu(self):
        code = self.get_softmax_wrapper(V=2048, device="cpu")

        # CPU needs an explicit loop across different rows.
        # On GPU, this is parallelized by the hardware.
        self.assertEqual(code.count("for(int64_t"), 4)

    def test_codegen_softmax_persistent_reduction(self):
        """
        Persistent reduction has no for loops.
        """
        wrapper_code = self.get_softmax_wrapper(V=1024)
        self.assertEqual(wrapper_code.count("for r0_offset in"), 0)

    @inductor_config.patch("triton.persistent_reductions", False)
    def test_sdpa(self):
        """
        Make sure online softmax here does not conflict with the sdpa
        patterns.
        """
        q, k, v = (
            torch.randn((4, 2, 16, 32), device=GPU_TYPE, dtype=torch.bfloat16)
            for _ in range(3)
        )

        def f(q, k, v):
            return (
                torch.matmul(q, k.transpose(-2, -1))
                .div(math.sqrt(k.shape[-1]))
                .softmax(dim=-1)
                .matmul(v)
            )

        opt_f = torch.compile(f)
        ref = f(q, k, v)
        act, (code,) = run_and_get_code(opt_f, q, k, v)
        self.assertTrue(torch.allclose(ref, act, atol=1e-2, rtol=1e-2))
        # The attention pattern must still be rewritten into an SDPA call.
        self.assertIn("aten._scaled_dot_product_", code)

    @parametrize("nrow", [2, 2048])
    @parametrize("dim", [-1, 0, 1])
    def test_prepare_softmax(self, dim, nrow):
        x = torch.randn(nrow, 2048, dtype=torch.bfloat16, device=GPU_TYPE)
        act, (code,) = run_and_get_code(torch.compile(_prepare_softmax), x, dim)
        ref = _prepare_softmax(x, dim)
        self.assertTrue(same(ref, act, tol=1e-2))

        if nrow == 2048 and dim == 0:
            # split reduction is triggered. We have multiple kernels
            self.assertGreaterEqual(code.count("def triton"), 2)
        else:
            if nrow == 2 and dim == 0:
                # persistent reduction triggered
                expected_num_loop = 0
            else:
                # A single loop due to online softmax
                expected_num_loop = 1
            self.assertEqual(code.count("for r0_offset in"), expected_num_loop)

    def test_split_reduction(self):
        """
        We don't split online_softmax_reduce for now. Check
        'Split online_softmax_reduce' note in the code.

        When a split is promising, we fall back for now.

        This is just a manual example rather than something we
        see in practice.
        """
        # tensor shape to trigger split reduction
        x = torch.randn(1, 2**20, dtype=torch.bfloat16, device=GPU_TYPE)
        ref = torch.softmax(x, dim=-1)
        act, (code,) = run_and_get_code(torch.compile(torch.softmax), x, dim=-1)
        self.assertTrue(torch.allclose(ref, act, atol=1e-3, rtol=1e-3))
        self.assertGreaterEqual(code.count("def triton"), 2)
        self.assertNotIn("online_softmax_reduce", code)

    @parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
    def test_prepare_softmax_acc_with_fp64(self, dtype):
        if USE_LARGE_INPUT:
            M, N = 32768, 50257
        else:
            M, N = 1024, 2048

        x = torch.randn(M, N, device=GPU_TYPE, dtype=dtype)

        ref_fp64 = _prepare_softmax(x.to(dtype=torch.float64), dim=-1)
        ref = _prepare_softmax(x, dim=-1)
        res, (code,) = run_and_get_code(torch.compile(_prepare_softmax), x, dim=-1)
        self.assertIn("online_softmax_reduce", code)

        # Max should be exactly equal: the compiled max matches eager, and the
        # low-precision eager max upcast to fp64 matches the fp64 reference.
        self.assertEqual(ref[0], res[0])
        self.assertEqual(ref[0].to(dtype=torch.float64), ref_fp64[0])

        ref_error = rmse(ref_fp64[1], ref[1]).item()
        res_error = rmse(ref_fp64[1], res[1]).item()

        # My local tests even show a smaller res_error:
        #   ref_error=2.1065, res_error=2.1028
        # for bf16
        #   ref_error=0.2611, res_error=0.2609
        # for fp16
        #   ref_error=0.0001, res_error=0.0001
        # for fp32
        print(f"{ref_error=:.4f}, {res_error=:.4f}")

        # Is this good enough to make CI stable
        self.assertLess(res_error, ref_error + 0.1)

    @parametrize("fn", [torch.log_softmax, torch.softmax])
    @parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
    def test_softmax_acc_with_fp64(self, dtype, fn):
        if USE_LARGE_INPUT:
            M, N = 32768, 50257
        else:
            M, N = 1024, 2048

        x = torch.randn(M, N, device=GPU_TYPE, dtype=dtype)

        ref_fp64 = fn(x.to(dtype=torch.float64), dim=-1)
        ref = fn(x, dim=-1)
        res, (code,) = run_and_get_code(torch.compile(fn), x, dim=-1)
        self.assertIn("online_softmax_reduce", code)

        ref_error = rmse(ref_fp64, ref).item()
        res_error = rmse(ref_fp64, res).item()

        # For torch.softmax,
        # I get almost 0 for ref_error/res_error for all 3 dtypes. It's because
        # each value is very small since each row adds up to 1.0
        #
        # For torch.log_softmax
        #   ref_error=0.0180399032, res_error=0.0180399031
        # for bf16
        #   ref_error=0.0022548872, res_error=0.0022548872
        # for fp16
        #   ref_error=0.0000003744, res_error=0.0000003748
        # for fp32
        print(f"{ref_error=:.10f}, {res_error=:.10f}")

        # Is this good enough to make CI stable
        self.assertLess(res_error, ref_error + 0.1)

    def test_softmin(self):
        """
        The rnumel==1 kind of reduction should be unrolled.
        """

        def f(x):
            return F.softmin(x, dim=0)

        x = torch.randn(1, device=GPU_TYPE)
        ref = f(x)
        act, (code,) = run_and_get_code(torch.compile(f), x)
        self.assertTrue(torch.allclose(ref, act))
        self.assertNotIn("online_softmax_reduce", code)

    def test_causal_mask(self):
        """Softmax over rows masked with '-inf' (causal/tril mask) must not
        produce NaNs and must match eager."""

        def f(x):
            return x.softmax(dim=-1)

        x = torch.randn(2048, 2048, device=GPU_TYPE)
        mask = torch.tril(torch.ones(2048, 2048, device=GPU_TYPE))
        x.masked_fill_(mask == 0, float("-inf"))

        ref = f(x)
        act = torch.compile(f)(x)
        self.assertFalse(ref.isnan().any())
        self.assertFalse(act.isnan().any())
        self.assertTrue(torch.allclose(ref, act))

    def test_tb_speech_transformer_attn(self):
        """
        This is an example extracted from speech_transformer.
        Since online softmax uses the max from partial elements of an entire
        row, if the input contains '-inf', it's possible that the
        max of those partial elements is '-inf' even if the entire row
        has non-'-inf' values. In this case, online softmax would need to
        do things like 'float(-inf) - float(-inf)', which becomes 'nan'.
        We fixed this by interpreting 'float(-inf) - float(-inf)' as 0
        if we found both operands are 'float(-inf)'.
        """
        torch.manual_seed(1337)

        def f(x, mask):
            x = torch.where(mask, float("-inf"), x)
            xmax = x.amax(dim=-1, keepdim=True)
            xsum = (x - xmax).exp().sum(dim=-1, keepdim=True)
            return xsum

        x = torch.randn(8, 10, 22, 204, device=GPU_TYPE)
        mask = torch.randint(0, 2, (10, 204), device=GPU_TYPE) == 0
        mask = mask.view(1, 10, 1, 204)

        ref = f(x, mask)
        act = torch.compile(f)(x, mask)
        self.assertFalse(ref.isnan().any())
        self.assertFalse(act.isnan().any())
        self.assertTrue(torch.allclose(ref, act))


# Expand the @parametrize-decorated methods into concrete test cases.
instantiate_parametrized_tests(TestOnlineSoftmax)

if __name__ == "__main__":
    # These tests exercise GPU (Triton) codegen, so they only run when
    # executed directly on a Linux host with CUDA available.
    if HAS_CUDA and IS_LINUX:
        run_tests()
