# Owner(s): ["module: functorch"]
import json
import tempfile
import zipfile
from pathlib import Path

import torch
import torch._dynamo
import torch._functorch
import torch._inductor
import torch._inductor.decomposition
from torch._higher_order_ops.torchbind import CallTorchBind, enable_torchbind_tracing
from torch._inductor import aot_compile, ir
from torch._inductor.package import package_aoti
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.inductor_utils import GPU_TYPE, requires_gpu
from torch.testing._internal.torchbind_impls import (
    _empty_tensor_queue,
    init_torchbind_implementations,
)


class TestTorchbind(TestCase):
    def setUp(self):
        """Run the base-class setup, then initialize the torchbind test implementations."""
        super().setUp()
        # The tests in this class use torch.classes._TorchScriptTesting.* and
        # torch.ops._TorchScriptTesting.*; presumably this call is what makes
        # them available (see torch.testing._internal.torchbind_impls).
        init_torchbind_implementations()

    def get_dummy_exported_model(self):
        """
        Returns the ExportedProgram, example inputs, and result from calling the
        eager model with those inputs
        """

        class M(torch.nn.Module):
            def forward(self, x):
                return x + 1

        m = M()
        inputs = (torch.ones(2, 3),)
        orig_res = m(*inputs)

        ep = torch.export.export(m, inputs, strict=False)

        return ep, inputs, orig_res, m

    def get_exported_model(self):
        """
        Build and export a module that holds a torchbind ``_Foo`` object.

        Returns a 4-tuple ``(ep, inputs, orig_res, m)``: the ExportedProgram,
        the example inputs, the result of calling the eager model with those
        inputs, and the eager model itself.
        """

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
                self.b = torch.randn(2, 3)

            # NOTE: test_torchbind_aot_compile asserts on the serialized form
            # of the three torchbind calls below, so keep them as written.
            def forward(self, x):
                x = x + self.b
                a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x)
                y = a[0] + a[1]
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
                c = self.attr.add_tensor(x)
                return x + b + c

        model = M()
        example_inputs = (torch.ones(2, 3),)
        eager_out = model(*example_inputs)

        # We can't directly torch.compile because dynamo doesn't trace ScriptObjects yet
        # NOTE(review): test_torchbind_compile does call torch.compile on this
        # module directly, so the remark above may be stale -- confirm.
        with enable_torchbind_tracing():
            exported = torch.export.export(model, example_inputs, strict=False)

        return exported, example_inputs, eager_out, model

    def test_torchbind_inductor(self):
        """Compile the exported module with ``torch._inductor.compile`` and compare to eager."""
        exported, example_inputs, eager_out, _ = self.get_exported_model()
        compiled_fn = torch._inductor.compile(exported.module(), example_inputs)

        compiled_out = compiled_fn(*example_inputs)
        self.assertTrue(torch.allclose(eager_out, compiled_out))

    def test_torchbind_compile_symint(self):
        """``torch.compile`` a module calling ``takes_foo_tensor_return`` and compare to eager."""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(2, 3)

            def forward(self, x):
                return torch.ops._TorchScriptTesting.takes_foo_tensor_return(
                    self.attr, x
                )

        model = M()
        example_inputs = (torch.ones(2, 3),)
        eager_out = model(*example_inputs)

        compiled_model = torch.compile(model, backend="inductor")
        compiled_out = compiled_model(*example_inputs)
        self.assertTrue(torch.allclose(eager_out, compiled_out))

    def test_torchbind_compile(self):
        """``torch.compile`` the eager torchbind model directly and compare to eager."""
        _, example_inputs, eager_out, eager_model = self.get_exported_model()
        compiled_model = torch.compile(eager_model, backend="inductor")
        compiled_out = compiled_model(*example_inputs)
        self.assertTrue(torch.allclose(eager_out, compiled_out))

    def test_torchbind_get_buf_bytes(self):
        """Check ``ir.TorchBindObject.get_buf_bytes`` for several torchbind objects."""
        # A _Foo(10, 20) object is expected to report zero bytes.
        foo_obj = torch.classes._TorchScriptTesting._Foo(10, 20)
        foo_buf = ir.TorchBindObject(name="a", value=foo_obj)
        self.assertEqual(foo_buf.get_buf_bytes(), 0)

        # A _ContainsTensor wrapping a 2x3 tensor: 6 elements at 4 bytes each
        # (torch.randn yields float32 under the default dtype).
        tensor = torch.randn(2, 3)
        holder = torch.classes._TorchScriptTesting._ContainsTensor(tensor)
        holder_buf = ir.TorchBindObject(name="b", value=holder)
        self.assertEqual(holder_buf.get_buf_bytes(), 2 * 3 * 4)

        # An empty tensor queue reports zero bytes ...
        queue = _empty_tensor_queue()
        queue_buf = ir.TorchBindObject(name="q", value=queue)
        self.assertEqual(queue_buf.get_buf_bytes(), 0)

        # ... and the same IR object reports the pushed tensor's bytes after a
        # push on the underlying queue.
        queue.push(torch.ones(2, 3))
        self.assertEqual(queue_buf.get_buf_bytes(), 2 * 3 * 4)

    def test_torchbind_hop_schema(self):
        """``CallTorchBind.schema`` for ``_Foo.add``: one int argument, int return."""
        expected_schema = "call_torchbind(__torch__.torch.classes._TorchScriptTesting._Foo _0, str method, int _1) -> int _0"

        foo_obj = torch.classes._TorchScriptTesting._Foo(10, 20)
        foo_buf = ir.TorchBindObject(name="foo", value=foo_obj)
        actual_schema = CallTorchBind.schema(foo_buf, "add")

        self.assertEqual(str(actual_schema), expected_schema)

    def test_torchbind_config_not_generated(self):
        """
        ``custom_objs_config.json`` must not be generated when it would be empty.

        AOT-compiles the dummy model (which uses no torchbind objects) and
        checks that none of the produced file paths is the custom-object config.
        """
        ep, inputs, _, _ = self.get_dummy_exported_model()
        aoti_files = aot_compile(
            ep.module(), inputs, options={"aot_inductor.package": True}
        )
        for path in aoti_files:
            # assertFalse plus a message names the offending file on failure,
            # instead of the opaque "False is not true".
            self.assertFalse(
                path.endswith("/custom_objs_config.json"),
                f"unexpected custom object config generated: {path}",
            )

    def test_torchbind_hop_schema_no_input(self):
        """``CallTorchBind.schema`` for ``_TensorQueue.pop``: no extra arguments, Tensor return."""
        expected_schema = "call_torchbind(__torch__.torch.classes._TorchScriptTesting._TensorQueue _0, str method) -> Tensor _0"

        queue = _empty_tensor_queue()
        queue_buf = ir.TorchBindObject(name="q", value=queue)
        actual_schema = CallTorchBind.schema(queue_buf, "pop")

        self.assertEqual(str(actual_schema), expected_schema)

    def test_torchbind_hop_schema_no_output(self):
        """``CallTorchBind.schema`` for ``_TensorQueue.push``: Tensor argument, NoneType return."""
        expected_schema = "call_torchbind(__torch__.torch.classes._TorchScriptTesting._TensorQueue _0, str method, Tensor _1) -> NoneType _0"

        queue = _empty_tensor_queue()
        queue_buf = ir.TorchBindObject(name="q", value=queue)
        actual_schema = CallTorchBind.schema(queue_buf, "push")

        self.assertEqual(str(actual_schema), expected_schema)

    def test_torchbind_aot_compile(self):
        """
        AOT-compile the torchbind model and check the generated artifacts.

        Verifies that:
          * ``custom_objs_config.json`` maps ``_torchbind_obj0`` to
            ``custom_obj_0``;
          * the extern-kernel JSON lists the three torchbind calls made in the
            model's ``forward`` with the expected inputs and outputs;
          * ``package_aoti`` puts both the config and the serialized object
            into the resulting archive.
        """
        ep, inputs, _, _ = self.get_exported_model()
        aoti_files = aot_compile(
            ep.module(), inputs, options={"aot_inductor.package": True}
        )

        # Locate the three artifacts of interest among the generated paths.
        # Branch order matters: custom_objs_config.json also ends in ".json",
        # so it has to be matched before the generic extern-JSON branch.
        custom_objs_config = None
        custom_obj_0 = None
        extern_json = None
        for path in aoti_files:
            if path.endswith("/custom_objs_config.json"):
                custom_objs_config = path
            elif path.endswith("/custom_obj_0"):
                custom_obj_0 = path
            elif path.endswith(".json") and "metadata" not in path:
                extern_json = path

        self.assertIsNotNone(custom_objs_config, "custom_objs_config.json not generated")
        self.assertIsNotNone(custom_obj_0, "custom_obj_0 not generated")
        self.assertIsNotNone(extern_json, "extern kernel JSON not generated")

        with open(custom_objs_config) as config_file:
            config_data = json.load(config_file)
        self.assertEqual(config_data, {"_torchbind_obj0": "custom_obj_0"})

        # Small builders for the expected extern-kernel JSON; every node and
        # input shares the same envelope, so only the varying parts are
        # spelled out below.
        foo_fqn = "__torch__.torch.classes._TorchScriptTesting._Foo"

        def tensor_ref(buf):
            """Serialized reference to the tensor buffer named ``buf``."""
            return {"as_tensor": {"name": buf}}

        def foo_input(name):
            """Input entry carrying the shared ``_Foo`` custom object."""
            return {
                "name": name,
                "arg": {
                    "as_custom_obj": {
                        "name": "_torchbind_obj0",
                        "class_fqn": foo_fqn,
                    }
                },
                "kind": 1,
            }

        def tensor_input(name, buf):
            """Input entry carrying the tensor buffer named ``buf``."""
            return {"name": name, "arg": tensor_ref(buf), "kind": 1}

        def extern_node(name, target, node_inputs, output_bufs):
            """One serialized extern node with tensor-only outputs."""
            return {
                "name": name,
                "node": {
                    "target": target,
                    "inputs": node_inputs,
                    "outputs": [tensor_ref(buf) for buf in output_bufs],
                    "metadata": {},
                    "is_hop_single_tensor_return": None,
                },
            }

        expected_extern = {
            "nodes": [
                extern_node(
                    "buf3",
                    "_TorchScriptTesting::takes_foo_tuple_return",
                    [foo_input("foo"), tensor_input("x", "buf2")],
                    ["buf4", "buf5"],
                ),
                extern_node(
                    "buf7",
                    "_TorchScriptTesting::takes_foo",
                    [foo_input("foo"), tensor_input("x", "buf6")],
                    ["buf8"],
                ),
                extern_node(
                    "buf9",
                    "call_torchbind",
                    [
                        foo_input("_0"),
                        {
                            "name": "method",
                            "arg": {"as_string": "add_tensor"},
                            "kind": 1,
                        },
                        tensor_input("_1", "buf2"),
                    ],
                    ["buf10"],
                ),
            ]
        }

        with open(extern_json) as extern_file:
            extern_data = json.load(extern_file)
        self.assertEqual(extern_data, expected_extern)

        # Test that the files are packaged
        with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
            package_path = package_aoti(f.name, aoti_files)

            with zipfile.ZipFile(package_path, "r") as zip_ref:
                all_files = zip_ref.namelist()

            self.assertTrue(all_files, "packaged archive is empty")
            base_folder = all_files[0].split("/")[0]
            model_dir = Path(base_folder) / "data" / "aotinductor" / "model"
            constants_dir = Path(base_folder) / "data" / "constants"

            # Zip member names always use "/" separators, so compare against
            # the POSIX form of the path rather than str(Path), which would
            # use backslashes on Windows.
            self.assertIn(
                (model_dir / "custom_objs_config.json").as_posix(), all_files
            )
            self.assertIn((constants_dir / "custom_obj_0").as_posix(), all_files)

    def test_torchbind_aoti(self):
        """Round-trip the torchbind model through AOTI compile/package/load and compare to eager."""
        exported, example_inputs, eager_out, _ = self.get_exported_model()

        package_path = torch._inductor.aoti_compile_and_package(exported)
        loaded_model = torch._inductor.aoti_load_package(package_path)

        self.assertEqual(loaded_model(*example_inputs), eager_out)

    @torch._inductor.config.patch("aot_inductor.use_runtime_constant_folding", True)
    def test_torchbind_aot_compile_constant_folding(self):
        """Same AOTI round trip as above, with ``aot_inductor.use_runtime_constant_folding`` on."""
        exported, example_inputs, eager_out, _ = self.get_exported_model()

        package_path = torch._inductor.aoti_compile_and_package(exported)
        loaded_model = torch._inductor.aoti_load_package(package_path)

        self.assertEqual(loaded_model(*example_inputs), eager_out)

    def test_torchbind_list_return_aot_compile(self):
        """AOTI round trip for a module whose torchbind op returns a list of tensors."""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_list_return(self.attr, x)
                y = a[0] + a[1] + a[2]
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
                return x + b

        model = M()
        example_inputs = (torch.ones(2, 3),)
        eager_out = model(*example_inputs)

        # We can't directly torch.compile because dynamo doesn't trace ScriptObjects yet
        with enable_torchbind_tracing():
            exported = torch.export.export(model, example_inputs, strict=False)

        package_path = torch._inductor.aoti_compile_and_package(exported)
        loaded_model = torch._inductor.aoti_load_package(package_path)

        self.assertEqual(loaded_model(*example_inputs), eager_out)

    def test_torchbind_queue(self):
        """AOTI round trip for a module that pushes to and pops from a tensor queue."""

        class Foo(torch.nn.Module):
            def __init__(self, tq) -> None:
                super().__init__()
                self.tq = tq

            def forward(self, x):
                self.tq.push(x.cos())
                self.tq.push(x.sin())
                # TODO: int return types are not supported in the fallback
                # kernel yet, so the tq.size() terms below stay disabled.
                x_cos = self.tq.pop()  # + self.tq.size()
                x_sin = self.tq.pop()  # - self.tq.size()
                return x_sin, x_cos

        example_inputs = (torch.randn(3, 2),)

        # Eager reference, run against its own queue.
        eager_model = Foo(_empty_tensor_queue())
        eager_out = eager_model(*example_inputs)

        # The exported module gets a separate, fresh queue.
        export_model = Foo(_empty_tensor_queue())

        # We can't directly torch.compile because dynamo doesn't trace ScriptObjects yet
        with enable_torchbind_tracing():
            exported = torch.export.export(export_model, example_inputs, strict=False)

        package_path = torch._inductor.aoti_compile_and_package(exported)
        loaded_model = torch._inductor.aoti_load_package(package_path)

        self.assertEqual(loaded_model(*example_inputs), eager_out)

    @requires_gpu()
    @torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
    @torch._inductor.config.patch("graph_partition", True)
    def test_torchbind_compile_gpu_op_symint_graph_partition(self):
        """
        ``torch.compile`` a module that moves a torchbind op result to the GPU.

        Runs with ``graph_partition`` and ``capture_dynamic_output_shape_ops``
        enabled and compares the compiled result with the eager one.
        """

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(2, 3)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_tensor_return(self.attr, x)
                a_gpu = a.to(device=GPU_TYPE)
                return a_gpu + 1

        model = M()
        example_inputs = (torch.ones(2, 3),)
        eager_out = model(*example_inputs)

        compiled_model = torch.compile(model, backend="inductor")
        compiled_out = compiled_model(*example_inputs)
        self.assertTrue(torch.allclose(eager_out, compiled_out))

    def test_torchbind_input_aot_compile(self):
        """``aot_compile`` must raise ``UserError`` when a torchbind object is a graph input."""

        class M(torch.nn.Module):
            def forward(self, x, y):
                return torch.ops._TorchScriptTesting.takes_foo_list_return(x, y)

        model = M()
        foo_obj = torch.classes._TorchScriptTesting._Foo(10, 20)
        example_inputs = (foo_obj, torch.ones(2, 3))

        # We can't directly torch.compile because dynamo doesn't trace ScriptObjects yet
        with enable_torchbind_tracing():
            exported = torch.export.export(model, example_inputs, strict=False)

        from torch._dynamo.exc import UserError

        compile_options = {"aot_inductor.package": True}
        with self.assertRaisesRegex(
            UserError,
            "TorchBind object inputs are not supported in AOTInductor",
        ):
            aot_compile(exported.module(), example_inputs, options=compile_options)


# Allow running this test file directly as a script.
if __name__ == "__main__":
    run_tests()
