# Owner(s): ["module: inductor"]

import dataclasses

from torch.testing._internal.common_utils import TestCase
from torch.utils._appending_byte_serializer import (
    AppendingByteSerializer,
    BytesReader,
    BytesWriter,
)


class TestAppendingByteSerializer(TestCase):
    def test_write_and_read_int(self) -> None:
        def int_serializer(writer: BytesWriter, i: int) -> None:
            writer.write_uint64(i)

        def int_deserializer(reader: BytesReader) -> int:
            return reader.read_uint64()

        s = AppendingByteSerializer(serialize_fn=int_serializer)

        data = [1, 2, 3, 4]
        s.extend(data)
        self.assertListEqual(
            data,
            AppendingByteSerializer.to_list(
                s.to_bytes(), deserialize_fn=int_deserializer
            ),
        )

        data2 = [8, 9, 10, 11]
        s.extend(data2)
        self.assertListEqual(
            data + data2,
            AppendingByteSerializer.to_list(
                s.to_bytes(), deserialize_fn=int_deserializer
            ),
        )

    def test_write_and_read_class(self) -> None:
        @dataclasses.dataclass(frozen=True, eq=True)
        class Foo:
            x: int
            y: str
            z: bytes

            @staticmethod
            def serialize(writer: BytesWriter, cls: "Foo") -> None:
                writer.write_uint64(cls.x)
                writer.write_str(cls.y)
                writer.write_bytes(cls.z)

            @staticmethod
            def deserialize(reader: BytesReader) -> "Foo":
                x = reader.read_uint64()
                y = reader.read_str()
                z = reader.read_bytes()
                return Foo(x, y, z)

        a = Foo(5, "ok", bytes([15]))
        b = Foo(10, "lol", bytes([25]))

        s = AppendingByteSerializer(serialize_fn=Foo.serialize)
        s.append(a)
        self.assertListEqual(
            [a],
            AppendingByteSerializer.to_list(
                s.to_bytes(), deserialize_fn=Foo.deserialize
            ),
        )

        s.append(b)
        self.assertListEqual(
            [a, b],
            AppendingByteSerializer.to_list(
                s.to_bytes(), deserialize_fn=Foo.deserialize
            ),
        )

    def test_checksum(self) -> None:
        writer = BytesWriter()
        writer.write_str("test")
        b = writer.to_bytes()
        b = bytearray(b)
        b[0:1] = b"\x00"
        b = bytes(b)

        with self.assertRaisesRegex(
            RuntimeError, r"Bytes object is corrupted, checksum does not match.*"
        ):
            BytesReader(b)


if __name__ == "__main__":
    # Allow running this file directly: hand control to the test runner
    # exported by torch._inductor.test_case.
    from torch._inductor.test_case import run_tests as _run_inductor_tests

    _run_inductor_tests()
