# Owner(s): ["module: dynamo"]
import unittest
import unittest.mock as mock

import torch
import torch._dynamo.test_case
import torch._functorch.config
import torch.utils.checkpoint
from torch._dynamo.testing import (
    AotEagerAndRecordGraphs,
    EagerAndRecordGraphs,
    normalize_gm,
)
from torch._higher_order_ops.schema import find_hop_schema
from torch.testing._internal.common_utils import instantiate_parametrized_tests
from torch.testing._internal.inductor_utils import HAS_CUDA


requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")


def normalize_graph(gm):
    return normalize_gm(gm.print_readable(print_output=False))


class InvokeQuantTest(torch._higher_order_ops.BaseHOP):
    def __init__(self):
        super().__init__("invoke_quant_test")

    def __call__(self, subgraph, *operands, scheme):
        return super().__call__(subgraph, *operands, scheme=scheme)


# Module-level instance of the test HOP. Each test in this file calls this
# object; it is captured as ``torch.ops.higher_order.invoke_quant_test`` in the
# recorded graphs.
invoke_quant_test = InvokeQuantTest()


class BaseHOPTest(torch._dynamo.test_case.TestCase):
    """Tests for the ``invoke_quant_test`` ``BaseHOP`` subclass.

    Covers:
    - Dynamo capture of the HOP into a nested subgraph.
    - Schema generation via ``find_hop_schema``, with and without input
      mutation and with pytree inputs/outputs.
    - Non-tensor operands (``None`` and ``int``).
    - Auto-functionalization of input mutation.
    - AOTAutograd forward/backward graphs.
    - The aliasing/mutation errors raised during tracing.
    - Eager invocation.

    Expected graphs are compared verbatim, so the traced ``inner`` functions
    must keep their exact parameter names.
    """

    # TODO: flip to False later, we're landing a refactor PR and don't want to merge conflict
    @torch._dynamo.config.patch(assume_static_by_default=True)
    def test_dynamo(self):
        """Dynamo captures the HOP as one graph with a nested ``subgraph_0``."""

        def inner(x, y):
            return (x @ y).sin().cos()

        x = torch.randn(3, 3, requires_grad=True)
        y = torch.randn(3, 3, requires_grad=True)

        backend = EagerAndRecordGraphs()

        @torch.compile(backend=backend)
        def f(x, y):
            return invoke_quant_test(inner, x, y, scheme="nf4")

        out = f(x, y)
        self.assertEqual(out, inner(x, y))

        # assertEqual instead of a bare assert so the check survives ``python -O``.
        self.assertEqual(len(backend.graphs), 1)
        self.assertExpectedInline(
            normalize_graph(backend.graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]"):
        l_x_ = L_x_
        l_y_ = L_y_

        subgraph_0 = self.subgraph_0
        invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_x_, l_y_, scheme = 'nf4');  subgraph_0 = l_x_ = l_y_ = None
        getitem: "f32[3, 3]" = invoke_quant_test[0];  invoke_quant_test = None
        return (getitem,)

    class subgraph_0(torch.nn.Module):
        def forward(self, l_x_: "f32[3, 3]", l_y_: "f32[3, 3]"):
            matmul: "f32[3, 3]" = l_x_ @ l_y_;  l_x_ = l_y_ = None
            sin: "f32[3, 3]" = matmul.sin();  matmul = None
            cos: "f32[3, 3]" = sin.cos();  sin = None
            return (cos,)
""",  # NOQA: B950
        )

    def test_schema_gen_single_return(self):
        """A single tensor output produces a one-return schema with a kw-only ``scheme``."""

        def inner(x, y):
            return (x @ y).sin().cos()

        x = torch.randn(3, 3, requires_grad=False)
        y = torch.randn(3, 3, requires_grad=False)

        backend = EagerAndRecordGraphs()

        @torch.compile(backend=backend)
        def f(x, y):
            return invoke_quant_test(inner, x, y, scheme="nf4")

        out = f(x.clone(), y)
        self.assertEqual(out, inner(x.clone(), y))
        schemas = find_hop_schema(backend.graphs[0], invoke_quant_test)
        self.assertEqual(len(schemas), 1)
        self.assertExpectedInline(
            str(schemas[0]),
            """invoke_quant_test(Any subgraph, Tensor arg0, Tensor arg1, *, str scheme="nf4") -> ((Tensor))""",  # noqa: B950
        )

    def test_schema_gen_pytree_in_out(self):
        """A list input and a nested output are flattened into tensor args/returns."""

        def inner(x_y):
            x, y = x_y
            return [
                (x @ y).sin().cos(),
                (x + y, x - y),
                {"out": (x @ y,)},
            ]

        # Same inputs as the mutation variant of this test; only y requires
        # grad. (This test itself does not mutate x.)
        x = torch.randn(3, 3, requires_grad=False)
        y = torch.randn(3, 3, requires_grad=True)

        backend = EagerAndRecordGraphs()

        @torch.compile(backend=backend)
        def f(x, y):
            return invoke_quant_test(inner, [x, y], scheme="nf4")

        out = f(x.clone(), y)
        self.assertEqual(out, inner([x.clone(), y]))
        schemas = find_hop_schema(backend.graphs[0], invoke_quant_test)
        self.assertEqual(len(schemas), 1)
        self.assertExpectedInline(
            str(schemas[0]),
            """invoke_quant_test(Any subgraph, Tensor arg0, Tensor arg1, *, str scheme="nf4") -> (Tensor, Tensor, Tensor, Tensor)""",  # noqa: B950
        )

    def test_schema_gen_single_return_with_mutation(self):
        """With input mutation enabled, both mutated operands get ``(aN!)`` annotations."""

        def inner(x, y):
            x.add_(1)
            y.mul_(-1)
            return (x @ y).sin().cos()

        x = torch.randn(3, 3, requires_grad=False)
        y = torch.randn(3, 3, requires_grad=False)

        backend = EagerAndRecordGraphs()

        def f(x, y):
            return invoke_quant_test(inner, x, y, scheme="nf4")

        with mock.patch(
            "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
            True,
        ):
            torch.compile(f, backend=backend, fullgraph=True)(x.clone(), y)
        self.assertEqual(len(backend.graphs), 1)
        self.assertExpectedInline(
            normalize_graph(backend.graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]"):
        l_x_ = L_x_
        l_y_ = L_y_

        subgraph_0 = self.subgraph_0
        invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_x_, l_y_, scheme = 'nf4');  subgraph_0 = l_x_ = l_y_ = None
        getitem: "f32[3, 3]" = invoke_quant_test[0];  invoke_quant_test = None
        return (getitem,)

    class subgraph_0(torch.nn.Module):
        def forward(self, l_x_: "f32[3, 3]", l_y_: "f32[3, 3]"):
            add_: "f32[3, 3]" = l_x_.add_(1);  add_ = None

            mul_: "f32[3, 3]" = l_y_.mul_(-1);  mul_ = None

            matmul: "f32[3, 3]" = l_x_ @ l_y_;  l_x_ = l_y_ = None
            sin: "f32[3, 3]" = matmul.sin();  matmul = None
            cos: "f32[3, 3]" = sin.cos();  sin = None
            return (cos,)
""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(find_hop_schema(backend.graphs[0], invoke_quant_test)[0]),
            """invoke_quant_test(Any subgraph, Tensor(a1!) arg0, Tensor(a2!) arg1, *, str scheme="nf4") -> ((Tensor))""",
        )

    def test_schema_gen_pytree_in_out_with_mutation(self):
        """Only the mutated leaf of a pytree input is marked as mutated in the schema."""

        def inner(x_y):
            x, y = x_y
            x.add_(1)
            return [
                (x @ y).sin().cos(),
                (x + y, x - y),
                {"out": (x @ y,)},
            ]

        # make x not require grad because we want to inplace mutate it
        x = torch.randn(3, 3, requires_grad=False)
        y = torch.randn(3, 3, requires_grad=True)

        bk = EagerAndRecordGraphs()

        def f(x, y):
            return invoke_quant_test(inner, [x, y], scheme="nf4")

        with (
            mock.patch(
                "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
                True,
            ),
            torch.no_grad(),
        ):
            torch.compile(f, backend=bk, fullgraph=True)(x.clone(), y)

        self.assertEqual(len(bk.graphs), 1)
        self.assertExpectedInline(
            normalize_graph(bk.graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]"):
        l_x_ = L_x_
        l_y_ = L_y_

        subgraph_0 = self.subgraph_0
        invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_x_, l_y_, scheme = 'nf4');  subgraph_0 = l_x_ = l_y_ = None
        getitem: "f32[3, 3]" = invoke_quant_test[0]
        getitem_1: "f32[3, 3]" = invoke_quant_test[1]
        getitem_2: "f32[3, 3]" = invoke_quant_test[2]
        getitem_3: "f32[3, 3]" = invoke_quant_test[3];  invoke_quant_test = None
        return (getitem, getitem_1, getitem_2, getitem_3)

    class subgraph_0(torch.nn.Module):
        def forward(self, l_x_: "f32[3, 3]", l_y_: "f32[3, 3]"):
            add_: "f32[3, 3]" = l_x_.add_(1);  add_ = None

            matmul: "f32[3, 3]" = l_x_ @ l_y_
            sin: "f32[3, 3]" = matmul.sin();  matmul = None
            child: "f32[3, 3]" = sin.cos();  sin = None

            child_1: "f32[3, 3]" = l_x_ + l_y_
            child_2: "f32[3, 3]" = l_x_ - l_y_

            child_3: "f32[3, 3]" = l_x_ @ l_y_;  l_x_ = l_y_ = None
            return (child, child_1, child_2, child_3)
""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(find_hop_schema(bk.graphs[0], invoke_quant_test)[0]),
            """invoke_quant_test(Any subgraph, Tensor(a1!) arg0, Tensor arg1, *, str scheme="nf4") -> (Tensor, Tensor, Tensor, Tensor)""",  # noqa: B950
        )

    def test_none_input(self):
        """A ``None`` operand is not a graph input; the subgraph keeps only the ``None`` branch."""

        def inner(x, y):
            if x is not None:
                return y.sin()
            return y.cos()

        backend = EagerAndRecordGraphs()

        @torch.compile(backend=backend, fullgraph=True)
        def f(x, y):
            return invoke_quant_test(inner, x, y, scheme="nf4")

        x = None
        y = torch.randn(3, 4)
        out = f(x, y)
        self.assertEqual(out, inner(x, y))
        self.assertExpectedInline(
            normalize_graph(backend.graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_y_: "f32[3, 4]"):
        l_y_ = L_y_

        subgraph_0 = self.subgraph_0
        invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_y_, scheme = 'nf4');  subgraph_0 = l_y_ = None
        getitem: "f32[3, 4]" = invoke_quant_test[0];  invoke_quant_test = None
        return (getitem,)

    class subgraph_0(torch.nn.Module):
        def forward(self, l_y_: "f32[3, 4]"):
            cos: "f32[3, 4]" = l_y_.cos();  l_y_ = None
            return (cos,)
""",
        )

    def test_int_input(self):
        """A Python int operand is inlined into the subgraph as a constant."""

        def inner(x, y):
            return x + y

        backend = EagerAndRecordGraphs()

        @torch.compile(backend=backend, fullgraph=True)
        def f(x, y):
            return invoke_quant_test(inner, x, y, scheme="nf4")

        x = 1
        y = torch.randn(3, 4)
        out = f(x, y)
        self.assertEqual(out, inner(x, y))
        self.assertExpectedInline(
            normalize_graph(backend.graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_y_: "f32[3, 4]"):
        l_y_ = L_y_

        subgraph_0 = self.subgraph_0
        invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_y_, scheme = 'nf4');  subgraph_0 = l_y_ = None
        getitem: "f32[3, 4]" = invoke_quant_test[0];  invoke_quant_test = None
        return (getitem,)

    class subgraph_0(torch.nn.Module):
        def forward(self, l_y_: "f32[3, 4]"):
            add: "f32[3, 4]" = 1 + l_y_;  l_y_ = None
            return (add,)
""",
        )

    def test_auto_functionalize(self):
        """Under AOTAutograd an input-mutating HOP is wrapped in ``auto_functionalized_v2``."""

        def inner(x, y):
            x.add_(1)
            return x + y

        backend = AotEagerAndRecordGraphs()

        def f(x, y):
            return invoke_quant_test(inner, x, y, scheme="nf4")

        x = torch.randn(3, 3, requires_grad=False)
        x_clone = x.clone()
        y = torch.randn(3, 3, requires_grad=True)
        with (
            mock.patch(
                "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
                True,
            ),
            torch.no_grad(),
        ):
            compiled_out = torch.compile(f, backend=backend, fullgraph=True)(x, y)
        # The in-place add on x must be written back to the caller's tensor.
        self.assertEqual(x, x_clone + 1)
        self.assertEqual(compiled_out, x_clone + y + 1)
        self.assertEqual(len(backend.fw_graphs), 1)
        self.assertExpectedInline(
            normalize_graph(backend.fw_graphs[0]),
            """\
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]"):
        auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0
        _tree_spec_constant0 = self._tree_spec_constant0
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.invoke_quant_test, subgraph = auto_functionalized_subgraph_0, arg1 = arg1_1, scheme = 'nf4', _arg0_base_index = 0, _all_bases = [arg0_1], _op_schema = _tree_spec_constant0);  auto_functionalized_subgraph_0 = arg1_1 = _tree_spec_constant0 = None
        getitem: "f32[3, 3]" = auto_functionalized_v2[0]
        getitem_1: "f32[3, 3]" = auto_functionalized_v2[1];  auto_functionalized_v2 = None
        copy_: "f32[3, 3]" = torch.ops.aten.copy_.default(arg0_1, getitem_1);  arg0_1 = getitem_1 = copy_ = None
        return (getitem,)

    class auto_functionalized_subgraph_0(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]"):
            add: "f32[3, 3]" = torch.ops.aten.add.Tensor(arg0_1, 1)
            add_1: "f32[3, 3]" = torch.ops.aten.add.Tensor(add, arg1_1);  arg1_1 = None
            copy_: "f32[3, 3]" = torch.ops.aten.copy_.default(arg0_1, add);  arg0_1 = add = copy_ = None
            return (add_1,)
""",  # noqa: B950
        )

    @torch._dynamo.config.patch(assume_static_by_default=True)
    def test_aot_eager(self):
        """AOTAutograd emits forward and backward graphs that both call the HOP."""

        def inner(x, y):
            return (x @ y).sin_().cos()

        x = torch.randn(3, 3, requires_grad=True)
        y = torch.randn(3, 3, requires_grad=True)

        backend = AotEagerAndRecordGraphs()

        @torch.compile(backend=backend)
        def f(x, y):
            return invoke_quant_test(inner, x, y, scheme="nf4")

        # y doubles as the upstream gradient for both the compiled and eager runs.
        compiled_out = f(x, y)
        result = torch.autograd.grad(compiled_out, x, grad_outputs=y)
        eager_out = inner(x, y)
        expected = torch.autograd.grad(eager_out, x, grad_outputs=y)
        self.assertEqual(result, expected)

        self.assertEqual(len(backend.fw_graphs), 1)
        self.assertExpectedInline(
            normalize_graph(backend.fw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[3, 3]", primals_2: "f32[3, 3]"):
        subgraph0 = self.subgraph0
        invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph0, primals_1, primals_2, scheme = 'nf4');  subgraph0 = None
        getitem: "f32[3, 3]" = invoke_quant_test[0];  invoke_quant_test = None
        return (getitem, primals_1, primals_2)

    class subgraph0(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]"):
            mm: "f32[3, 3]" = torch.ops.aten.mm.default(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
            sin: "f32[3, 3]" = torch.ops.aten.sin.default(mm);  mm = None
            cos: "f32[3, 3]" = torch.ops.aten.cos.default(sin);  sin = None
            return (cos,)
""",  # NOQA: B950
        )

        self.assertEqual(len(backend.bw_graphs), 1)
        self.assertExpectedInline(
            normalize_graph(backend.bw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[3, 3]", primals_2: "f32[3, 3]", tangents_1: "f32[3, 3]"):
        subgraph1 = self.subgraph1
        invoke_quant_test_1 = torch.ops.higher_order.invoke_quant_test(subgraph1, primals_1, primals_2, tangents_1, scheme = 'nf4');  subgraph1 = primals_1 = primals_2 = tangents_1 = None
        getitem_1: "f32[3, 3]" = invoke_quant_test_1[0]
        getitem_2: "f32[3, 3]" = invoke_quant_test_1[1];  invoke_quant_test_1 = None
        return (getitem_1, getitem_2)

    class subgraph1(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]", arg2_1: "f32[3, 3]"):
            mm: "f32[3, 3]" = torch.ops.aten.mm.default(arg0_1, arg1_1)
            clone: "f32[3, 3]" = torch.ops.aten.clone.default(mm)
            sin: "f32[3, 3]" = torch.ops.aten.sin.default(mm);  mm = None
            sin_1: "f32[3, 3]" = torch.ops.aten.sin.default(sin);  sin = None
            neg: "f32[3, 3]" = torch.ops.aten.neg.default(sin_1);  sin_1 = None
            mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(arg2_1, neg);  arg2_1 = neg = None
            cos_1: "f32[3, 3]" = torch.ops.aten.cos.default(clone);  clone = None
            mul_1: "f32[3, 3]" = torch.ops.aten.mul.Tensor(mul, cos_1);  mul = cos_1 = None
            t: "f32[3, 3]" = torch.ops.aten.t.default(arg0_1);  arg0_1 = None
            mm_1: "f32[3, 3]" = torch.ops.aten.mm.default(t, mul_1);  t = None
            t_1: "f32[3, 3]" = torch.ops.aten.t.default(arg1_1);  arg1_1 = None
            mm_2: "f32[3, 3]" = torch.ops.aten.mm.default(mul_1, t_1);  mul_1 = t_1 = None
            return (mm_2, mm_1)
""",  # NOQA: B950
        )

    def test_aliasing_mutation_error(self):
        """Without patching ``supports_input_mutation``, aliasing or mutating an input raises."""

        def inner(x, y):
            return x

        def inner2(x, y):
            x.sin_()
            return x + y

        x = torch.randn(3, 3)
        y = torch.randn(3, 3)

        @torch.compile(backend="eager", fullgraph=True)
        def f(inner, x, y):
            return invoke_quant_test(inner, x, y, scheme="nf4")

        with self.assertRaisesRegex(
            RuntimeError, "Encountered aliasing during higher order op tracing"
        ):
            f(inner, x, y)

        with self.assertRaisesRegex(
            RuntimeError,
            "Encountered input mutation during higher order op tracing",
        ):
            f(inner2, x, y)

    def test_eager_call(self):
        """Eager calls reject a plain function but accept an ``fx.GraphModule``."""

        def inner(x, y):
            return x + y

        x = torch.randn(3, 3)
        y = torch.randn(3, 3)

        with self.assertRaisesRegex(RuntimeError, "torch.fx.GraphModule"):
            invoke_quant_test(inner, x, y, scheme="nf4")

        from functorch import make_fx

        result = make_fx(inner)(x, y)
        # smoke test
        invoke_quant_test(result, x, y, scheme="nf4")


# Presumably expands any @parametrize-decorated tests on BaseHOPTest into
# concrete test methods. NOTE(review): no test in this class is currently
# parametrized, so this is a no-op today — kept for future parametrized tests.
instantiate_parametrized_tests(BaseHOPTest)


if __name__ == "__main__":
    # Hand off to Dynamo's test runner, reached through the
    # torch._dynamo.test_case module imported at the top of the file.
    torch._dynamo.test_case.run_tests()
