# Owner(s): ["module: inductor"]

import sympy

import torch
from torch._inductor.codegen.block_analysis import BlockPatternMatcher
from torch._inductor.utils import sympy_dot
from torch._inductor.virtualized import V
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)
from torch.testing._internal.inductor_utils import dummy_graph
from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing


# Some useful symbols
x, y = sympy.symbols("x y")


@instantiate_parametrized_tests
class BlockAnalysisTest(TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        # Create a GraphLowering, so we can access V.graph.
        cls.graph = dummy_graph()

    @parametrize(
        "stride,symbol,expr",
        [
            (5, x, Identity(5 * x)),
            (4, y, 4 * Identity(y)),
            (3, x, Identity(3) * x),
        ],
    )
    def test_affine_identity(self, stride: int, symbol: sympy.Symbol, expr: sympy.Expr):
        # Test that we can handle an identity expression in affine indexing.
        matched_stride = BlockPatternMatcher.match_affine_block_expr(expr, symbol)
        self.assertEqual(matched_stride, stride)

    @parametrize(
        "dims,strides,symbol,expr",
        [
            (
                (2, 4),
                (4, 1),
                x,
                4 * FloorDiv(Identity(x), 4) + ModularIndexing(x, 1, 4),
            ),
            (
                (3, 9),
                (5, 2),
                x,
                5 * FloorDiv(x, 9) + 2 * ModularIndexing(Identity(x), 1, 9),
            ),
            ((2, 7), (1, 1), x, Identity(FloorDiv(x, 7) + ModularIndexing(x, 1, 7))),
        ],
    )
    def test_mod_div_identity(
        self,
        dims: tuple[int],
        strides: tuple[int],
        symbol: sympy.Symbol,
        expr: sympy.Expr,
    ):
        # Test that we can handle an identity expression in modular indexing.
        numel = int(torch.prod(torch.Tensor(dims)))
        num_dims = len(dims)
        with V.set_graph_handler(self.graph):
            match_result = BlockPatternMatcher.match_mod_div_block_expr(
                expr, symbol, numel, num_dims
            )

        # Check the matched block dimensions.
        self.assertNotEqual(match_result, None)
        matched_dims, matched_strides, matched_block_index_exprs = match_result
        self.assertEqual(matched_dims, dims)
        self.assertEqual(matched_strides, strides)

    @parametrize(
        "symbol,expr,subexpr",
        [
            (x, Identity(x), x),
            (x, Identity(x + 5), x),
            (y, Identity(x + 2 * y) + 5, 2 * y),
        ],
    )
    def test_subexpr_identity(
        self,
        symbol: sympy.Symbol,
        expr: sympy.Expr,
        subexpr: sympy.Expr,
    ):
        matched_subexpr = BlockPatternMatcher.get_subexpr_involving_symbol(expr, symbol)
        self.assertEqual(matched_subexpr, subexpr)

    def test_index_with_dynamic_shapes(self):
        s0 = sympy.var("s0", integer=True)
        s1 = sympy.var("s1", integer=True)

        dims = [s1, sympy.Integer(3)]
        num_dims = len(dims)
        numel = dims[0] * dims[1]
        strides = [sympy.Integer(3) * s0, sympy.Integer(1)]
        block_index_exprs = [
            FloorDiv(y, sympy.Integer(3)),
            ModularIndexing(y, sympy.Integer(1), sympy.Integer(3)),
        ]
        index = sympy_dot(strides, block_index_exprs)

        with V.set_graph_handler(self.graph):
            match = BlockPatternMatcher.match_mod_div_block_expr(
                index, y, numel, num_dims
            )
            sizevars = V.graph.sizevars
            for expected, actual in zip((dims, strides, block_index_exprs), match):
                assert isinstance(expected, (list, tuple)) and isinstance(
                    actual, (list, tuple)
                )
                for expected_expr, actual_expr in zip(expected, actual):
                    assert isinstance(expected_expr, sympy.Expr) and isinstance(
                        actual_expr, sympy.Expr
                    )
                    self.assertTrue(
                        sizevars.statically_known_equals(
                            sizevars.remove_precomputed_replacements(expected_expr),
                            sizevars.remove_precomputed_replacements(actual_expr),
                        )
                    )


if __name__ == "__main__":
    run_tests()
