# Owner(s): ["oncall: distributed"]

import os
import pickle
from io import BytesIO
from typing import cast

import torch
import torch.distributed as dist
from torch.distributed._serialization import _streaming_load, _streaming_save
from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor
from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase


# Name of the environment variable that TestSerialization.setUp forces to "0"
# (per the comment there, to disable debug asserts while the tests run).
DEBUG_ENV = "TORCH_SERIALIZATION_DEBUG"


class MyClass:
    """Minimal picklable Python object used as a non-tensor payload in the tests.

    Instances compare equal when their ``a`` attributes are equal. Because
    ``__eq__`` is defined without ``__hash__``, instances are unhashable.
    """

    def __init__(self, a: int) -> None:
        self.a = a

    def __eq__(self, other: object) -> bool:
        # Defer to the other operand for foreign types instead of raising
        # AttributeError on a missing ``a`` attribute.
        if not isinstance(other, MyClass):
            return NotImplemented
        return self.a == other.a

    def __repr__(self) -> str:
        # Makes assertEqual failure messages show the payload value.
        return f"{type(self).__name__}(a={self.a!r})"


class TestSerialization(TestCase):
    """Round-trip tests for ``_streaming_save`` / ``_streaming_load``.

    Each test serializes an object into an in-memory ``BytesIO`` buffer,
    rewinds the buffer, loads it back and compares the result with the
    original object.
    """

    def setUp(self) -> None:
        """Disable debug asserts for the test, remembering the prior setting."""
        super().setUp()
        # disable debug asserts
        self._old_debug = os.environ.get(DEBUG_ENV)
        os.environ[DEBUG_ENV] = "0"

    def tearDown(self) -> None:
        """Restore the environment variable to its state before ``setUp``."""
        if self._old_debug is None:
            # The variable was not set before the test: remove it so the "0"
            # written by setUp does not leak into the rest of the process.
            os.environ.pop(DEBUG_ENV, None)
        else:
            os.environ[DEBUG_ENV] = self._old_debug
        super().tearDown()

    @staticmethod
    def _save(obj: object) -> BytesIO:
        """Serialize ``obj`` with ``_streaming_save`` and return the rewound buffer."""
        buffer = BytesIO()
        _streaming_save(obj, buffer)
        buffer.seek(0)
        return buffer

    def _round_trip(self, obj: object, **load_kwargs):
        """Save ``obj`` and load it back, forwarding ``load_kwargs`` to the loader."""
        return _streaming_load(self._save(obj), **load_kwargs)

    def test_scalar_tensor(self) -> None:
        """A 0-dim int32 tensor survives a save/load round trip."""
        state_dict = {"scalar": torch.tensor(42, dtype=torch.int32)}

        result = self._round_trip(state_dict)
        torch.testing.assert_close(result, state_dict)

    def test_strided_tensor(self) -> None:
        """A non-contiguous view (every other row/column) round-trips."""
        base_tensor = torch.arange(16, dtype=torch.float32).reshape(4, 4)
        state_dict = {"strided": base_tensor[::2, ::2]}

        result = self._round_trip(state_dict)
        torch.testing.assert_close(result, state_dict)

    def test_tensor_with_offset(self) -> None:
        """Views that start past the first element of their base round-trip."""
        state_dict = {
            "offset": torch.arange(10, dtype=torch.float64)[2:],
            "strided": torch.arange(10, dtype=torch.float64)[2::2],
        }

        result = self._round_trip(state_dict)
        torch.testing.assert_close(result, state_dict)

    def test_nested_tensors(self) -> None:
        """Tensors stored inside a nested dict round-trip."""
        tensor1 = torch.tensor([1, 2, 3], dtype=torch.int32)
        tensor2 = torch.tensor([[1.5, 2.5], [3.5, 4.5]], dtype=torch.float64)
        state_dict = {"nested": {"tensor1": tensor1, "tensor2": tensor2}}

        result = self._round_trip(state_dict)
        torch.testing.assert_close(result, state_dict)

    def test_various_data_types(self) -> None:
        """float32, int16, bool and uint16 tensors all round-trip."""
        state_dict = {
            "float32": torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32),
            "int16": torch.tensor([1, 2, 3], dtype=torch.int16),
            "bool": torch.tensor([True, False, True], dtype=torch.bool),
            # Integer values above 255 so that both bytes of each uint16
            # element carry data (bool literals would only produce 0 and 1).
            "uint16": torch.tensor([1, 256, 65535], dtype=torch.uint16),
        }

        result = self._round_trip(state_dict)
        torch.testing.assert_close(result, state_dict)

    def test_dtensor(self) -> None:
        """A DTensor round-trips with its local tensor and ``_spec`` intact."""
        dist.init_process_group(
            backend="gloo", rank=0, world_size=1, store=dist.HashStore()
        )
        # Destroy the group even if an assertion below fails, so it does not
        # leak into other tests running in the same process.
        self.addCleanup(dist.destroy_process_group)

        device_mesh = DeviceMesh("cpu", 1)
        dtensor = distribute_tensor(torch.randn(4, 4), device_mesh, [])

        result = cast(DTensor, self._round_trip(dtensor))
        torch.testing.assert_close(result.to_local(), dtensor.to_local())
        self.assertEqual(result._spec, dtensor._spec)

    def test_python_object(self) -> None:
        """An arbitrary Python object loads back when ``weights_only=False``."""
        state_dict = {
            "obj": MyClass(42),
        }

        result = self._round_trip(state_dict, weights_only=False)
        self.assertEqual(result, state_dict)

    def test_str_utf8(self) -> None:
        """A non-ASCII string value round-trips unchanged."""
        state_dict = {
            "obj": "Ü",
        }

        result = self._round_trip(state_dict)
        self.assertEqual(result, state_dict)

    def test_weights_only(self) -> None:
        """Loading a custom class is rejected unless ``weights_only`` is disabled."""
        state_dict = {
            "obj": MyClass(42),
        }
        file = self._save(state_dict)

        with self.assertRaisesRegex(pickle.UnpicklingError, "not an allowed global"):
            _streaming_load(file)

        # The failed load above may have consumed part of the stream; rewind
        # so the second check does not depend on the leftover position.
        file.seek(0)
        with self.assertRaisesRegex(RuntimeError, "explicit pickle_module"):
            _streaming_load(file, weights_only=True, pickle_module=pickle)

    @requires_cuda
    def test_cuda(self) -> None:
        """A CUDA tensor round-trips and is loaded back onto the same device."""
        device = torch.device("cuda:0")
        state_dict = {"scalar": torch.tensor(42, dtype=torch.float, device=device)}

        result = self._round_trip(state_dict)
        torch.testing.assert_close(result, state_dict)
        self.assertEqual(result["scalar"].device, device)


# Run the tests via run_tests() when this file is executed as a script.
if __name__ == "__main__":
    run_tests()
