# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import copy

from model_registry import MLPModule

import torch
from torch.distributed.pipelining._backward import (
    stage_backward,
    stage_backward_input,
    stage_backward_weight,
)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase


# d_hid: feature width given to MLPModule and used for the inputs and targets.
# batch_size: leading (row) dimension of every input/target tensor.
d_hid, batch_size = 512, 256


class StageBackwardTests(TestCase):
    """Compare the pipelining backward helpers against plain autograd.

    Each test does the following:

    - builds an ``MLPModule`` stage and a deep-copied reference;
    - drives the stage through the helper(s) under test;
    - drives the reference through ``loss.backward()``;
    - compares the resulting input and/or parameter gradients.
    """

    def _make_stage_and_reference(self, device):
        """Return ``(mod, ref_mod, loss_fn)`` for one test.

        ``mod`` is the stage module on ``device``. ``ref_mod`` is a deep copy
        with identical initial weights. ``loss_fn`` is a sum-reduced MSE loss.
        """
        mod = MLPModule(d_hid).to(device)
        # deepcopy preserves parameter values and device, so no extra .to().
        ref_mod = copy.deepcopy(mod)
        loss_fn = torch.nn.MSELoss(reduction="sum")
        return mod, ref_mod, loss_fn

    def _make_stage_input(self, device):
        """Return a fresh random ``(batch_size, d_hid)`` input that requires grad.

        As in a pipeline stage, the inputs to the stage require gradients.
        """
        return torch.randn(batch_size, d_hid, device=device).requires_grad_(True)

    def _assert_param_grads_close(self, mod, ref_mod):
        """Assert that every parameter ``.grad`` of ``mod`` matches ``ref_mod``.

        The failure message names the offending parameter instead of
        printing the full gradient tensors.
        """
        for name, p in mod.named_parameters():
            ref_p = ref_mod.get_parameter(name)
            self.assertIsNotNone(p.grad, f"Gradient missing for {name}")
            torch.testing.assert_close(
                p.grad,
                ref_p.grad,
                # Bind ``name`` now; assert_close calls this with its own message.
                msg=lambda m, name=name: f"Gradient test failed for {name}: {m}",
            )

    def test_stage_backward(self, device):
        """``stage_backward`` on a loss matches ``loss.backward()``.

        Both the input gradient and the parameter gradients are compared.
        """
        mod, ref_mod, loss_fn = self._make_stage_and_reference(device)
        x = self._make_stage_input(device)
        target = torch.randn(batch_size, d_hid, device=device)
        # Same values as ``x``, but an independent autograd leaf.
        ref_x = x.detach().requires_grad_(True)

        # Forward and backward in stage manner
        loss = loss_fn(mod(x), target)
        grad_inputs = stage_backward(
            stage_output=loss,
            output_grads=None,
            input_values=(x,),
        )

        # Reference: plain autograd
        loss_fn(ref_mod(ref_x), target).backward()

        torch.testing.assert_close(grad_inputs[0], ref_x.grad)
        self._assert_param_grads_close(mod, ref_mod)

    def test_stage_backward_input(self, device):
        """``stage_backward_input`` yields input grads but leaves weight grads unset."""
        mod, ref_mod, loss_fn = self._make_stage_and_reference(device)
        x = self._make_stage_input(device)
        target = torch.randn(batch_size, d_hid, device=device)
        ref_x = x.detach().requires_grad_(True)

        # Forward, then backward of loss with respect to inputs only
        loss = loss_fn(mod(x), target)
        dinputs, _param_groups = stage_backward_input(
            stage_outputs_or_loss=(loss,),
            output_grads=None,
            input_values=[x],
            weights=mod.parameters(),
        )

        # Reference: plain autograd
        loss_fn(ref_mod(ref_x), target).backward()

        torch.testing.assert_close(x.grad, ref_x.grad)
        torch.testing.assert_close(dinputs[0], ref_x.grad)
        for p in mod.parameters():
            # The weight gradients must not have been updated.
            self.assertIsNone(p.grad)

    def test_stage_backward_weight(self, device):
        """``stage_backward_input`` then ``stage_backward_weight`` match full autograd."""
        mod, ref_mod, loss_fn = self._make_stage_and_reference(device)
        x = self._make_stage_input(device)
        target = torch.randn(batch_size, d_hid, device=device)
        ref_x = x.detach().requires_grad_(True)

        # Forward, then backward of loss with respect to inputs
        loss = loss_fn(mod(x), target)
        _dinputs, param_groups = stage_backward_input(
            stage_outputs_or_loss=(loss,),
            output_grads=None,
            input_values=[x],
            weights=mod.parameters(),
        )

        # Backward of loss with respect to weights
        stage_backward_weight(mod.parameters(), param_groups, retain_graph=True)

        # Reference: plain autograd
        loss_fn(ref_mod(ref_x), target).backward()

        self._assert_param_grads_close(mod, ref_mod)

    def test_stage_backward_weight_multiple_iters(self, device):
        """Split input/weight backward over several microbatches matches autograd."""
        mod, ref_mod, loss_fn = self._make_stage_and_reference(device)
        inputs = [self._make_stage_input(device) for _ in range(10)]
        target = torch.randn(batch_size, d_hid, device=device)
        ref_inputs = [x.detach().requires_grad_(True) for x in inputs]

        # Both sides are expected to accumulate parameter .grad over all iterations.
        for x in inputs:
            loss = loss_fn(mod(x), target)
            _dinputs, param_groups = stage_backward_input(
                stage_outputs_or_loss=(loss,),
                output_grads=None,
                input_values=[x],
                weights=mod.parameters(),
            )
            # Backward of loss with respect to weights
            stage_backward_weight(mod.parameters(), param_groups)

        # Reference: plain autograd
        for ref_x in ref_inputs:
            loss_fn(ref_mod(ref_x), target).backward()

        self._assert_param_grads_close(mod, ref_mod)


# Device types for which device-specific variants of the generic tests are
# instantiated (filtered through ``only_for``).
devices = ["cpu", "cuda", "hpu", "xpu"]
instantiate_device_type_tests(
    StageBackwardTests,
    globals(),
    only_for=devices,
)

if __name__ == "__main__":
    run_tests()
