# Owner(s): ["oncall: distributed"]

import unittest
from collections import deque, OrderedDict
from contextlib import ContextDecorator, contextmanager, nullcontext
from copy import deepcopy
from functools import partial

import torch
import torch.nn as nn
from torch.distributed._composable import checkpoint
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils.checkpoint import CheckpointError


class MemoryDelta(ContextDecorator):
    """Record how much active CUDA memory changes across a ``with`` block.

    The ``active_bytes.all.current`` allocator statistic is sampled on entry
    and on exit; ``delta()`` returns the difference. For non-CUDA devices both
    samples are 0, so the delta is always 0.

    Args:
        device: Device whose allocator statistics are sampled.
    """

    def __init__(self, device: torch.device):
        self.device: torch.device = device
        self.active_memory_enter: int = 0
        self.active_memory_exit: int = 0

    def _active_bytes(self) -> int:
        """Return the current active-bytes statistic for ``self.device``."""
        if self.device.type != "cuda":
            return 0
        # Query the stored device explicitly instead of relying on whichever
        # CUDA device happens to be current. ``.get`` covers the case where
        # the statistics dict does not contain the key.
        stats = torch.cuda.memory_stats(self.device)
        return stats.get("active_bytes.all.current", 0)

    def __enter__(self):
        self.active_memory_enter = self._active_bytes()
        return self

    def __exit__(self, *exc):
        # Returns None, so an exception raised inside the block propagates.
        self.active_memory_exit = self._active_bytes()

    def delta(self) -> int:
        """Return the exit sample minus the entry sample, in bytes."""
        return self.active_memory_exit - self.active_memory_enter


class ToyModel(nn.Module):
    """Small MLP: ``l1`` (Linear 100->100) followed by ``seq`` (ReLU, Linear, ReLU)."""

    def __init__(self) -> None:
        super().__init__()
        self.l1 = nn.Linear(100, 100)
        # Build the tail layers in order, then wrap them in one Sequential so
        # callers can address the whole tail as ``self.seq``.
        tail_layers = [nn.ReLU(), nn.Linear(100, 100), nn.ReLU()]
        self.seq = nn.Sequential(*tail_layers)

    def forward(self, x):
        """Apply ``l1`` and then the ``seq`` tail to ``x``."""
        hidden = self.l1(x)
        return self.seq(hidden)


class RandomModel(nn.Module):
    """Model whose forward pass draws a fresh random matrix on every call."""

    def __init__(self) -> None:
        super().__init__()
        self.p = nn.Parameter(torch.randn(100, 100))

    def forward(self, x):
        """Return ``x @ (p @ noise)`` where ``noise`` is sampled per call."""
        # The noise is created on the parameter's device, so the output
        # depends on the RNG state at call time.
        noise = torch.randn(100, 100, device=self.p.device)
        mixed = torch.matmul(self.p, noise)
        return torch.matmul(x, mixed)


class MultiOutputModel(nn.Module):
    """Two-weight model that returns a pair of tensors from ``forward``.

    Args:
        device: Device on which both 100x100 weight parameters are created.
    """

    def __init__(self, device: torch.device):
        super().__init__()
        shape = (100, 100)
        self.w1 = nn.Parameter(torch.randn(shape, device=device))
        self.w2 = nn.Parameter(torch.randn(shape, device=device))

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(sin(z), cos(z))`` with ``z = relu(x @ w1) @ w2``."""
        hidden = nn.functional.relu(torch.matmul(x, self.w1))
        z = torch.matmul(hidden, self.w2)
        return torch.sin(z), torch.cos(z)


class MultiInputModel(nn.Module):
    """Single-weight model whose ``forward`` consumes a pair of tensors.

    Args:
        device: Device on which the 100x100 weight parameter is created.
    """

    def __init__(self, device: torch.device):
        super().__init__()
        self.w = nn.Parameter(torch.randn((100, 100), device=device))

    def forward(self, xs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """Return ``relu((xs[0] + xs[1]) @ w)``; ``xs`` must hold exactly two items."""
        assert len(xs) == 2, f"Expects 2 args but got {len(xs)}"
        first, second = xs
        projected = torch.matmul(first + second, self.w)
        return nn.functional.relu(projected)


class TestCheckpoint(TestCase):
    """Tests for the composable ``checkpoint`` API applied to ``nn.Module`` instances."""

    def _get_graph_size(self, out: torch.Tensor) -> int:
        """Count the autograd functions reachable from ``out.grad_fn``.

        Visited nodes are not tracked, so a function reachable through several
        paths is counted once per path. Returns 0 when ``out`` has no
        ``grad_fn``.
        """
        if out.grad_fn is None:
            return 0

        q = deque([out.grad_fn])
        num_functions = 0
        while q:
            fn = q.pop()
            num_functions += 1
            for next_fn, _ in fn.next_functions:
                if next_fn:
                    q.append(next_fn)

        return num_functions

    def _test_tensor_only(
        self,
        net: nn.Module,
        x: torch.Tensor,
    ) -> None:
        """Compare a plain run of ``net`` against a copy with ``net.seq`` checkpointed.

        Parameter gradients of both runs must match. On CUDA inputs the
        checkpointed forward must also show a smaller active-memory delta.
        """
        x1 = x.clone()
        x2 = x.clone()
        x1.requires_grad = True
        x2.requires_grad = True

        net1 = net
        net2 = deepcopy(net)

        # no checkpoint
        with MemoryDelta(x.device) as mem1:
            loss1 = net1(x1).sum()
        loss1.backward()

        # with checkpoint
        checkpoint(net2.seq)
        with MemoryDelta(x.device) as mem2:
            loss2 = net2(x2).sum()
        loss2.backward()

        if x.is_cuda:
            # assertLess reports both values on failure.
            self.assertLess(mem2.delta(), mem1.delta())

        for p1, p2 in zip(net1.parameters(), net2.parameters()):
            self.assertEqual(p1.grad, p2.grad)

    def test_tensor_only_cpu(self):
        """Run the tensor-only comparison with a CPU model and input."""
        x = torch.randn(20, 100)
        net = ToyModel()
        self._test_tensor_only(net, x)

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_tensor_only_gpu(self):
        """Run the tensor-only comparison on ``cuda:0``, including the memory check."""
        x = torch.randn(20, 100, device="cuda:0")
        net = ToyModel().to("cuda:0")
        self._test_tensor_only(net, x)

    def test_random_cpu(self):
        """Gradients must match a plain run when the forward pass draws random numbers."""
        x1 = torch.randn(20, 100, requires_grad=True)
        x2 = x1.clone()

        net1 = RandomModel()
        net2 = deepcopy(net1)

        # Start both forward passes from the same CPU RNG state so they draw
        # the same random matrix.
        cpu_rng_state = torch.get_rng_state()
        net1(x1).sum().backward()
        torch.set_rng_state(cpu_rng_state)
        checkpoint(net2)(x2).sum().backward()

        for p1, p2 in zip(net1.parameters(), net2.parameters()):
            self.assertEqual(p1.grad, p2.grad)

    def test_multi_args(self):
        """
        Tests checkpoint for modules with multiple output args and hence
        multiple backward function input args.
        """
        device = torch.device("cpu")
        net1 = nn.Sequential(
            MultiOutputModel(device),
            MultiInputModel(device),
            MultiOutputModel(device),
            MultiInputModel(device),
        )
        net2 = deepcopy(net1)
        # Only the multi-output stages are checkpointed.
        checkpoint(net2[0])
        checkpoint(net2[2])
        x1 = torch.randn(20, 100, requires_grad=True)
        x2 = x1.clone()
        net1(x1).sum().backward()
        net2(x2).sum().backward()
        for p1, p2 in zip(net1.parameters(), net2.parameters()):
            self.assertEqual(p1.grad, p2.grad)

    def test_clears_state_on_error_in_forward(self):
        """``_ac_generator`` must be reset to None after an error in forward or recomputation."""

        class MyModel(torch.nn.Module):
            def __init__(self, raise_in_recomp):
                super().__init__()
                self.fwd_count = 0
                self.raise_in_recomp = raise_in_recomp
                self.a = torch.nn.Linear(2, 2)

            def forward(self, x):
                if self.raise_in_recomp and self.fwd_count == 1:
                    raise RuntimeError("foo")
                else:
                    if not self.raise_in_recomp:
                        # raise in the first forward
                        raise RuntimeError("foo")
                    self.fwd_count += 1
                    return self.a(x)

        m = MyModel(raise_in_recomp=True)
        m_seq = torch.nn.Sequential(OrderedDict({"m": m}))
        checkpoint(m_seq.m)
        inp = torch.randn(1, 2)
        out = m_seq(inp).sum()
        # Should raise in forward recomputation
        with self.assertRaisesRegex(RuntimeError, "foo"):
            out.backward()

        # Check that _ac_generator is cleared out
        self.assertIsNone(checkpoint.state(m)._ac_generator)

        m = MyModel(raise_in_recomp=False)
        checkpoint(m)
        inp = torch.randn(1, 2)
        # Should raise in first forward
        with self.assertRaises(RuntimeError):
            m(inp)

        self.assertIsNone(checkpoint.state(m)._ac_generator)

    def test_checkpoint_kwargs(self):
        """Exercise the keyword arguments accepted by composable ``checkpoint``.

        Covers the rejected ``use_reentrant=True`` and unknown kwargs, custom
        ``context_fn`` pairs, ``determinism_check`` and ``preserve_rng_state``.
        """

        class MyModel(torch.nn.Module):
            def __init__(self, raise_exp: bool, change_shape_in_recomp: bool):
                super().__init__()
                self.fwd_count = 0
                self.raise_exp = raise_exp
                self.change_shape_in_recomp = change_shape_in_recomp
                self.a = torch.nn.Linear(2, 2)

            def forward(self, x):
                # "foo" is raised on the first call, "bar" on the second
                # (the recomputation), unless a context clears raise_exp.
                if self.raise_exp and self.fwd_count == 0:
                    raise RuntimeError("foo")
                if self.raise_exp and self.fwd_count == 1:
                    raise RuntimeError("bar")
                if self.change_shape_in_recomp and self.fwd_count == 1:
                    x.relu_()
                random_tensor = torch.randn(1, 2)
                x = self.a(x + random_tensor)
                self.fwd_count += 1
                return x

        m = MyModel(True, False)
        m0, m1, m2, m3 = (deepcopy(m) for _ in range(4))

        # composable checkpoint does not support use_reentrant=True
        with self.assertRaisesRegex(
            NotImplementedError,
            "use_reentrant=True is not supported in composable checkpoint. "
            "Please use torch.utils.checkpoint.checkpoint instead.",
        ):
            checkpoint(m, use_reentrant=True)

        # check giving an unsupported kwarg
        with self.assertRaisesRegex(ValueError, "Unexpected keyword arguments: foo"):
            checkpoint(m0, foo="bar")

        handled_fwd_exp = False
        handled_recomp_exp = False

        @contextmanager
        def fwd_ctx(mod: MyModel):
            # Disables the model's raising while active and records that the
            # context was exited.
            nonlocal handled_fwd_exp
            try:
                mod.raise_exp = False
                yield
            finally:
                handled_fwd_exp = True
                mod.raise_exp = True

        @contextmanager
        def recomp_ctx(mod: MyModel):
            # Same as fwd_ctx, but records into the recomputation flag.
            nonlocal handled_recomp_exp
            try:
                mod.raise_exp = False
                yield
            finally:
                handled_recomp_exp = True
                mod.raise_exp = True

        # Test different context functions
        x = torch.randn(1, 2, requires_grad=True)
        checkpoint(
            m1, context_fn=lambda: (partial(fwd_ctx, m1)(), partial(recomp_ctx, m1)())
        )
        m1(x.clone()).sum().backward()
        self.assertEqual((handled_fwd_exp, handled_recomp_exp), (True, True))

        checkpoint(m2, context_fn=lambda: (nullcontext(), partial(recomp_ctx, m2)()))
        with self.assertRaisesRegex(RuntimeError, "foo"):
            m2(x.clone())

        handled_fwd_exp = False  # Reset flag
        checkpoint(m3, context_fn=lambda: (partial(fwd_ctx, m3)(), nullcontext()))
        with self.assertRaisesRegex(RuntimeError, "bar"):
            m3(x.clone()).sum().backward()
        self.assertEqual(handled_fwd_exp, True)

        # Test determinism check failure
        m4 = MyModel(False, True)
        m5 = deepcopy(m4)
        # Determinism check should not throw an error,
        # but autograd should throw a RuntimeError
        checkpoint(m4, determinism_check="none")
        with self.assertRaises(RuntimeError):
            m4(x.clone()).sum().backward()

        # Determinism check should throw a CheckpointError
        checkpoint(m5, determinism_check="default")
        with self.assertRaises(CheckpointError):
            m5(x.clone()).sum().backward()

        # Test preserving random state
        m6 = MyModel(False, False)
        m7, m8 = (deepcopy(m6) for _ in range(2))
        checkpoint(m7, preserve_rng_state=False)
        checkpoint(m8, preserve_rng_state=True)

        for mi in (m6, m7, m8):
            # Reseed between forward and backward so any recomputation sees a
            # different RNG state unless checkpoint restores the original one.
            torch.manual_seed(42)
            loss = mi(x.clone()).sum()
            torch.manual_seed(41)
            loss.backward()
        # check that m6 and m7 have at least one different grad
        # The gradients are materialised into lists: two generator objects
        # never compare equal, which would make this assertion pass vacuously.
        self.assertNotEqual(
            [p1.grad for p1 in m6.parameters()],
            [p2.grad for p2 in m7.parameters()],
        )
        # check that m6 and m8 have identical grads
        for p1, p2 in zip(m6.parameters(), m8.parameters()):
            self.assertEqual(p1.grad, p2.grad)


# When this file is executed as a script, hand control to run_tests
# (imported from torch.testing._internal.common_utils).
if __name__ == "__main__":
    run_tests()
