# Owner(s): ["module: dynamo"]

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._functorch
from torch._dynamo.precompile_context import PrecompileContext
from torch._functorch import config as functorch_config
from torch._functorch._aot_autograd.autograd_cache import (
    BundledAOTAutogradCacheArtifact,
)
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal.inductor_utils import GPU_TYPE, requires_triton


@functorch_config.patch({"enable_autograd_cache": True})
@functorch_config.patch(
    {"bundled_autograd_cache": True}
)  # Requires bundledaotautograd cache for now
class PrecompileContextTests(InductorTestCase):
    def setUp(self):
        """
        Reset all counters and caches before each unit test
        """
        super().setUp()
        # Clear PrecompileContext cache artifacts
        PrecompileContext.clear()

    @requires_triton()
    def test_basic(self):
        """
        Test that after torch.compile, PrecompileContext._new_cache_artifacts length is 1
        """

        def simple_function(x):
            return x.sin() + x.cos()

        compiled_fn = torch.compile(simple_function)

        # Run the compiled function
        x = torch.randn(10, device=GPU_TYPE, requires_grad=True)
        result = compiled_fn(x)
        result.sum().backward()
        # Check that PrecompileContext._new_cache_artifacts_by_key has length 1
        self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 1)

        self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0)
        result = PrecompileContext.serialize()
        assert result is not None
        serialized, cache_info = result
        self.assertEqual(len(cache_info.precompile_aot_autograd_artifacts), 1)

        artifacts = PrecompileContext.deserialize(serialized)
        assert artifacts is not None
        deserialized = artifacts["precompile_aot_autograd"]
        assert len(deserialized) == 1
        entry = deserialized[0]
        assert isinstance(entry, BundledAOTAutogradCacheArtifact)
        entry = entry.after_deserialization()
        # Now that we've serialized, there should be no new cache artifacts
        self.assertEqual(
            len(PrecompileContext._new_cache_artifacts["precompile_aot_autograd"]), 0
        )

    @requires_triton()
    def test_serialize_by_key(self):
        """
        Test that after torch.compile, PrecompileContext._new_cache_artifacts length is 1
        """

        def simple_function(x):
            return x.sin() + x.cos()

        compiled_fn = torch.compile(simple_function)

        # Run the compiled function
        x = torch.randn(10, device=GPU_TYPE, requires_grad=True)
        result = compiled_fn(x)
        result.sum().backward()
        # Check that PrecompileContext._new_cache_artifacts_by_key has length 1
        # TODO: the key right now is the AOTAutogradCacheKey, but will be backend_id once
        # we have torch._dynamo.package implemented
        self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 1)
        key = next(iter(PrecompileContext._new_cache_artifacts_by_key.keys()))
        result = PrecompileContext.serialize_artifact_by_key(key)
        assert isinstance(result, BundledAOTAutogradCacheArtifact)
        self.assertEqual(result.key, key)

        self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0)
        result = PrecompileContext.serialize()
        assert result is not None
        _, cache_info = result
        self.assertEqual(len(cache_info.precompile_aot_autograd_artifacts), 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
