# Owner(s): ["module: dynamo"]

# ruff: noqa: TRY002

import itertools
import types
import unittest
import weakref
from collections import defaultdict, namedtuple, OrderedDict
from typing import Any

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._functorch.config
import torch.nn
import torch.utils.checkpoint
from torch._dynamo.testing import same
from torch._dynamo.utils import dict_items


class SimpleDict(dict):
    """A plain ``dict`` subclass with no overrides, used to exercise Dynamo's
    handling of user-defined dict subclasses."""


class DictTests(torch._dynamo.test_case.TestCase):
    """Dynamo tests for ``dict`` and dict-like objects.

    Covers dict subclasses, ``OrderedDict``, ``defaultdict``,
    ``WeakKeyDictionary``, ``MappingProxyType`` and ``__dict__``. Each test
    compiles a function with ``torch.compile`` and compares it against eager
    execution. Where relevant it also checks guard and recompilation behavior
    (e.g. key order, lazy key guards) through ``CompileCounter`` or
    ``error_on_recompile``.
    """

    def test_dict_subclass_instantiation(self):
        def fn(x):
            sd = SimpleDict(x=5)
            return sd["x"] * x

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_dict_subclass_local_mutation(self):
        def fn(x):
            sd = SimpleDict(x=5)
            z = sd["x"] * x
            sd["x"] = 10
            return z * sd["x"]

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_dict_subclass_local_with_non_dict_method(self):
        # Checks that add_1 method is inlined
        class MethodDict(dict):
            def add_1(self, x):
                return x + 1

        def fn(x):
            sd = MethodDict(x=5)
            z = sd["x"] * x
            sd["x"] = 10
            return sd.add_1(z * sd["x"])

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_dict_contains(self):
        sd = dict()
        sd[2] = 5
        sd[4] = 10

        def fn(x):
            if 1 in sd:
                x = x * 2
            else:
                x = x * 3
            return x

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

        # Adding key 1 flips the result of `1 in sd`, which forces a recompilation.
        sd[1] = 15
        self.assertEqual(fn(x), opt_fn(x))

        # Ensure no recompilation: changing the value of an existing key does
        # not change the traced program.
        sd[2] = 10
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            self.assertEqual(fn(x), opt_fn(x))

    def test_dict_subclass_methods_fallback_readonly(self):
        sd = SimpleDict()
        sd[2] = 5
        sd[4] = 10
        # check that regular attr accesses work well
        sd.attr = 4

        def fn(x):
            for value in sd.values():
                x = x * value
            for key in sd.keys():
                x = x * key
            for k, v in sd.items():
                x = x * k
                x = x * v
            # for k in sd:
            #     x = x * k

            if 1 in sd:
                x = x * 2
            else:
                x = x * 3

            x = x * sd.get(2, 0)
            x = x * sd.get(3, 4)
            x = len(sd) * x
            x = x * sd.attr
            return x

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

        # Ensure a recompilation: the new key changes len(sd) and the iteration results.
        sd[6] = 15
        self.assertEqual(fn(x), opt_fn(x))

    def test_dict_subclass_instantiation_return(self):
        def fn(x):
            sd = SimpleDict(x=5 * x)
            sd["y"] = 10
            return sd

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(type(ref), type(res))
        self.assertEqual(ref["x"], res["x"])
        self.assertEqual(ref["y"], res["y"])

    def test_dict_subclass_methods_fallback_mutation(self):
        def fn(sd, x):
            for value in sd.values():
                x = x * value
            sd[6] = 14
            for key in sd.keys():
                x = x * key
            for k, v in sd.items():
                x = x * k
                x = x * v
            # for k in sd:
            #     x = x * k

            if 1 in sd:
                x = x * 2
            else:
                x = x * 3

            x = x * sd.get(2, 0)
            x = x * sd.get(3, 4)
            x = len(sd) * x
            return x

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        sd1 = SimpleDict()
        sd1[2] = 5
        sd1[4] = 10

        sd2 = SimpleDict()
        sd2[2] = 5
        sd2[4] = 10
        self.assertEqual(sd1, sd2)

        # The compiled run must apply the same in-place mutation (sd[6] = 14)
        # as the eager run.
        self.assertEqual(fn(sd1, x), opt_fn(sd2, x))
        self.assertEqual(sd1, sd2)

    def test_dict_subclass_setitem(self):
        class SetItemDict(dict):
            def __setitem__(self, key, value):
                super().__setitem__(key, value + 1)

        def fn(x):
            sd = SetItemDict(x=5 * x)
            sd["y"] = 10
            return sd

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(type(ref), type(res))
        self.assertEqual(ref["x"], res["x"])
        self.assertEqual(ref["y"], res["y"])

    def test_custom_iter_dict(self):
        class ReversedDict(dict):
            def __iter__(self):
                return reversed(list(self.keys()))

        d = {
            "foo": 1,
            "bar": 2,
        }

        d = ReversedDict(d)

        @torch.compile(backend="eager")
        def fn(x, d):
            # Forces side effects attribute reapplication logic
            d.sample = 1
            d["baz"] = 4
            return x * d["foo"] * d["bar"]

        fn(torch.randn(4), d)
        # This is intentional because the dict is mutated, so we will have a recompilation.
        fn(torch.randn(4), d)
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            fn(torch.randn(4), d)

    def test_custom_keys_iter_dict(self):
        class ReversedDict(dict):
            def keys(self):
                return ["bar", "foo"]

        d = {
            "foo": 1,
            "bar": 2,
        }

        d = ReversedDict(d)

        @torch.compile(backend="eager")
        def fn(x, d):
            return x * d["foo"] * d["bar"]

        fn(torch.randn(4), d)
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            fn(torch.randn(4), d)

    def test_dict_guard_on_keys_order(self):
        d = {
            2: 4,
            3: 5,
        }

        cnts = torch._dynamo.testing.CompileCounter()

        def fn(x, d):
            for key, value in d.items():
                x = x * key + value
            return x

        opt_fn = torch.compile(fn, backend=cnts)
        opt_fn(torch.randn(4), d)
        opt_fn(torch.randn(4), d)
        # No recompilation
        self.assertEqual(cnts.frame_count, 1)

        # move 2 to the end
        d[2] = d.pop(2)

        x = torch.randn(4)
        res = opt_fn(x, d)
        # Check recompilation
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(res, fn(x, d))

    def test_dict_guard_on_keys_order2(self):
        d = {
            2: 4,
            3: 5,
        }

        cnts = torch._dynamo.testing.CompileCounter()

        def fn(x, d):
            for key in d:
                value = d[key]
                x = x * key + value
            return x

        opt_fn = torch.compile(fn, backend=cnts)
        opt_fn(torch.randn(4), d)
        opt_fn(torch.randn(4), d)
        # No recompilation
        self.assertEqual(cnts.frame_count, 1)

        # move 2 to the end
        d[2] = d.pop(2)

        x = torch.randn(4)
        res = opt_fn(x, d)
        # Check recompilation
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(res, fn(x, d))

    def test_ordered_dict_reordered_keys(self):
        d = OrderedDict()
        d[2] = 4
        d[3] = 5
        d.move_to_end(2)

        cnts = torch._dynamo.testing.CompileCounter()

        def fn(x, d):
            y = 0
            for idx, (key, value) in enumerate(d.items()):
                if idx == 0:
                    y += torch.sin(x * value)
                else:
                    y += torch.cos(x * value)
            return y

        opt_fn = torch.compile(fn, backend=cnts)
        x = torch.randn(4)
        self.assertEqual(opt_fn(x, d), fn(x, d))

    def test_ordered_dict_subclass_reordered_keys(self):
        class ODSubclass(OrderedDict):
            def keys(self):
                return super().keys()

        d = ODSubclass()
        d[2] = 4
        d[3] = 5
        d.move_to_end(2)

        cnts = torch._dynamo.testing.CompileCounter()

        def fn(x, d):
            y = 0
            for idx, (key, value) in enumerate(d.items()):
                if idx == 0:
                    y += torch.sin(x * value)
                else:
                    y += torch.cos(x * value)
            return y

        opt_fn = torch.compile(fn, backend=cnts)
        x = torch.randn(4)
        self.assertEqual(opt_fn(x, d), fn(x, d))

    def test_lazy_key_guarding(self):
        d = {"a": 2, "b": 3, "c": 5}

        def fn(x):
            return x * d["a"]

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

        # Since key c was not used, it should not lead to a recompilation
        d.pop("c")
        d["d"] = 10

        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)

    def test_lazy_key_non_const_guarding(self):
        d = {
            list: 2,
            dict: 3,
            OrderedDict: 5,
            namedtuple: 7,
        }

        def fn(x):
            return x * d[list]

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

        # Since key `dict` was not used, removing it (and adding an unrelated
        # key) should not lead to a recompilation
        d.pop(dict)
        d[defaultdict] = 10

        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)

    def test_dict_mutation_side_effect(self):
        def fn(d):
            d["c"] = d["a"] + d.pop("b")
            return d

        args1 = {"a": torch.randn(10), "b": torch.randn(10)}
        args2 = dict(args1)
        # Eager reference: fn mutates and returns its argument.
        self.assertIs(fn(args1), args1)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        self.assertIs(opt_fn(args2), args2)
        self.assertTrue(same(args1, args2))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 1)

    def test_dict_copy_alias(self):
        @torch.compile(backend="eager", fullgraph=True)
        def run(x, d0):
            d1 = d0.copy()
            d1[0] = 1
            return x + 1, d1

        d0 = {}
        res, d1 = run(torch.zeros(1), d0)
        self.assertTrue(same(res, torch.ones(1)))
        # Mutating the copy must not leak into the original dict.
        self.assertEqual(d0, {})
        self.assertEqual(d1, {0: 1})

    def test_dict_subclass_get_method(self):
        class dotdict(dict):
            """dot.notation access to dictionary attributes"""

            __getattr__ = dict.get
            __setattr__ = dict.__setitem__
            __delattr__ = dict.__delitem__

        config = dotdict({"a": 1, "b": 2})

        def fn(x):
            x2 = x * 2  # noqa: F841
            x3 = x * config.get("a", 3)
            return x3

        x = torch.randn(2)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_dict_order_keys(self):
        def fn(d):
            c = 0
            for v in d.values():
                c += v
            return c

        args1 = {}
        args1["a"] = torch.rand(10)
        args1["b"] = torch.rand(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        self.assertEqual(fn(args1), opt_fn(args1))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

        # A different order of keys recompiles
        args2 = {}
        args2["b"] = args1["b"]
        args2["a"] = args1["a"]
        self.assertEqual(fn(args2), opt_fn(args2))
        self.assertEqual(cnts.frame_count, 2)
        # Extra calls don't recompile
        opt_fn(args2)
        self.assertEqual(cnts.frame_count, 2)

    def test_dict_namedtuple(self):
        def fn(d):
            if namedtuple in d:
                return d[3] * 2
            else:
                return d[3] * 3

        args1 = {namedtuple: None, 3: torch.randn(3)}
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        self.assertEqual(fn(args1), opt_fn(args1))
        self.assertEqual(cnts.frame_count, 1)
        # Test a failing namedtuple guard
        args2 = {2: None, 3: torch.randn(3)}
        self.assertEqual(fn(args2), opt_fn(args2))
        self.assertEqual(cnts.frame_count, 2)

    def test_dict_order_keys_tensors(self):
        def fn(d, x):
            return d[x] + 3

        args1 = {}
        x = torch.randn(10)
        y = torch.randn(10)
        z = torch.randn(10)
        args1[x] = y
        args1[3] = z

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        self.assertEqual(fn(args1, x), opt_fn(args1, x))
        self.assertEqual(cnts.frame_count, 1)

        # Calling again doesn't recompile (same id and key order)
        opt_fn(args1, x)
        self.assertEqual(cnts.frame_count, 1)
        args2 = {}
        args2[3] = z
        args2[x] = y

        # Different order recompiles
        self.assertEqual(fn(args2, x), opt_fn(args2, x))
        self.assertEqual(cnts.frame_count, 2)

    def test_dict_order_keys_modules(self):
        def fn(d, x):
            return d[x](torch.ones(2, 2))

        args1 = {}
        x = torch.nn.Linear(2, 2)
        y = torch.nn.Linear(2, 2)
        z = torch.nn.Linear(2, 2)
        args1[x] = y
        args1[3] = z

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        self.assertEqual(fn(args1, x), opt_fn(args1, x))
        self.assertEqual(cnts.frame_count, 1)

        # Calling again doesn't recompile (same id and key order)
        opt_fn(args1, x)
        self.assertEqual(cnts.frame_count, 1)
        args2 = {}
        args2[3] = z
        args2[x] = y

        # Different order recompiles
        self.assertEqual(fn(args2, x), opt_fn(args2, x))
        self.assertEqual(cnts.frame_count, 2)

    def test_contains_dunder_dict(self):
        class UserDefined:
            def __init__(self) -> None:
                self.a = 3
                self.b = 5

            def run(self, x):
                if "a" in self.__dict__:
                    x = x * self.a
                if "b" in self.__dict__:
                    x = x * self.b
                # Attribute created mid-function must be visible via __dict__.
                self.c = 7
                if "c" in self.__dict__:
                    x = x * self.c
                return x * self.__dict__.get("a") * self.__dict__.get("z", 2)

        obj = UserDefined()

        def fn(x):
            return obj.run(x)

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_contains_module_dunder_dict(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = 1
                self.bar = 2
                self.baz = 3

            def forward(self, x):
                if "foo" in self.__dict__:
                    return x * self.bar
                return x * self.baz

        mod = MyModule()
        x = torch.randn(10)
        opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
        self.assertEqual(mod(x), opt_mod(x))

    def test_update_dunder_dict(self):
        class UserDefined:
            def run(self, x):
                self.__dict__["a"] = 10
                return x * self.a + self.__dict__["a"]

        obj1 = UserDefined()
        obj2 = UserDefined()

        def fn(x, obj):
            return obj.run(x)

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        ref = fn(x, obj1)
        res = opt_fn(x, obj2)
        self.assertEqual(ref, res)
        # Make sure only `a` is updated.
        self.assertEqual(obj1.__dict__, obj2.__dict__)

    def test_update_module_dunder_dict(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                self.__dict__["a"] = 10
                return x * self.a + self.__dict__["a"]

        mod = MyModule()
        x = torch.randn(10)
        opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
        self.assertEqual(mod(x), opt_mod(x))

    def test_dict_reconstruct_keeps_original_order(self):
        def fn():
            modules = OrderedDict([("act", torch.nn.ReLU())])
            module_dict = torch.nn.ModuleDict(modules)

            next_modules = {"fc4": torch.nn.Linear(5, 6), "act3": torch.nn.Sigmoid()}
            modules.update(next_modules.items())
            module_dict.update(next_modules)
            return modules, module_dict

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        modules, module_dict = opt_fn()

        self.assertEqual(len(module_dict), len(modules))
        # Both containers must list the same module objects in the same order.
        for k1, m2 in zip(modules, module_dict.children()):
            self.assertIs(modules[k1], m2)

    def test_dict_subclass_initialization_in_graph(self):
        for super_class in (
            OrderedDict,
            dict,
        ):

            class CustomDict(super_class):
                def __new__(cls, *args, **kwargs):
                    return super().__new__(cls, *args, **kwargs)

                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)

            def fn(x):
                c = CustomDict()
                c["key"] = x
                assert "key" in c
                return c["key"] + 1

            opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

            x = torch.rand(4)
            self.assertEqual(fn(x), opt_fn(x))

    def test_dict_list_values(self):
        def inner_fn(args):
            return [x[1].shape for x in args]

        @torch.compile(backend="eager")
        def fn(tensors):
            return inner_fn(zip(itertools.count(), tensors["args"]))

        # Smoke test: both calls must compile and run without error.
        fn({"args": [torch.ones(5, 5), torch.ones(5, 6), torch.ones(5, 7)]})
        fn({"args": [torch.ones(5, 5)]})

    def test_dict_iter(self):
        class MyMod(torch.nn.Module):
            def forward(self, x):
                z = {"my": 1, "const": 2, "dict": 3, "variable": 4}
                tot = 0
                for key in z:
                    tot += z[key]

                return tot

        x = torch.tensor([0])
        model = MyMod()
        opt_model = torch.compile(model, backend="eager", fullgraph=True)
        y = opt_model(x)

        self.assertEqual(y, 10)

    def test_dict_subclass_contains(self):
        # pattern from huggingface
        class ClassInstantier(OrderedDict):
            pass

        @torch.compile(fullgraph=True, backend="eager")
        def f(x, d):
            if "key1" in d:
                x = x + 2
            if "key2" in d:
                x = x + 4
            x = x + 8
            return x

        result = f(torch.ones(8), ClassInstantier({"key1": torch.ones(8)}))
        self.assertTrue(same(result, torch.full([8], 11.0)))

        result = f(torch.ones(8), ClassInstantier({"key2": torch.ones(8)}))
        self.assertTrue(same(result, torch.full([8], 13.0)))

    def test_dict_tag_guard(self):
        class Foo:
            def __init__(self) -> None:
                self.scalar = 10

        def fn(d, x):
            return d["a"] * d["b"] * d["c"].scalar * x

        foo = Foo()

        d = {"a": 2, "b": 3, "c": foo}

        opt_fn = torch.compile(fn, backend="eager")
        inp = torch.randn(3, 3)
        self.assertEqual(fn(d, inp), opt_fn(d, inp))

        d["a"] = 4
        self.assertEqual(fn(d, inp), opt_fn(d, inp))

        # Mutating an attribute of an object stored in the dict must be
        # observed; the result only matches eager if recompilation happens.
        foo.scalar = 12
        self.assertEqual(fn(d, inp), opt_fn(d, inp))

    def test_empty_dict_recompilation(self):
        def fn(d, x):
            if d:
                return torch.cos(x)
            return torch.sin(x)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        self.assertEqual(fn({}, x), opt_fn({}, x))
        self.assertEqual(fn({"a": 1}, x), opt_fn({"a": 1}, x))

    def test_udf_dict_reconstruction(self):
        class MyDict(dict):
            pass

        def fn(x, klass):
            x = x * 2
            sc_dict = dict.__new__(klass)
            sc_dict["x"] = x
            if isinstance(sc_dict, MyDict):
                sc_dict.attr = 3
            return sc_dict

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        ref = fn(x, MyDict)
        res = opt_fn(x, MyDict)
        self.assertEqual(ref, res)
        self.assertIsInstance(res, MyDict)
        self.assertEqual(ref.attr, res.attr)

        ref = fn(x, dict)
        res = opt_fn(x, dict)
        self.assertEqual(ref, res)
        self.assertIsInstance(res, dict)

    def test_weakref_dict(self):
        states = weakref.WeakKeyDictionary()

        mod1 = torch.nn.Module()
        mod2 = torch.nn.Module()

        states[mod1] = 2
        states[mod2] = 3

        def fn(x):
            if mod1 in states:
                x = torch.sin(x)
            if mod2 in states:
                x = torch.cos(x)
            return x

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        self.assertEqual(fn(x), opt_fn(x))

    def test_fn_id(self):
        def fn(x, f):
            d = {id(f): 3}
            return x * d[id(f)]

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)

        def nothing():
            pass

        f = nothing
        self.assertEqual(fn(x, f), opt_fn(x, f))

    def test_mapping_proxy_for_local(self):
        def fn(x):
            d = {"a": 2, "b": 3, "c": 5 * x}
            mp = types.MappingProxyType(d)
            y = torch.sin(x * mp["a"])
            for k, v in mp.items():  # noqa: PERF102
                y += torch.cos(x * v)
            return mp

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertIs(type(res), types.MappingProxyType)

    def test_mapping_proxy_for_nonlocal(self):
        d = {"a": 2, "b": 3, "c": 5}

        def fn(x):
            mp = types.MappingProxyType(d)
            y = torch.sin(x * mp["a"])
            for k, v in mp.items():  # noqa: PERF102
                y += torch.cos(x * v)
            d["d"] = 4
            return mp

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertIs(type(res), types.MappingProxyType)

        # check update to d is reflected in res
        d["e"] = 5
        self.assertEqual(d["e"], res["e"])

    def test_mapping_proxy_existing(self):
        d = {"a": 2, "b": 3, "c": 5}

        def fn(x, mp):
            y = torch.sin(x * mp["a"])
            for k, v in mp.items():  # noqa: PERF102
                y += torch.cos(x * v)
            if isinstance(mp, types.MappingProxyType):
                y *= 2
            return y

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        mp = types.MappingProxyType(d)
        ref = fn(x, mp)
        res = opt_fn(x, mp)
        self.assertEqual(ref, res)

        # Mutating the underlying dict must be visible through the proxy.
        d["a"] = 3
        ref = fn(x, mp)
        res = opt_fn(x, mp)
        self.assertEqual(ref, res)

        d.pop("b")
        ref = fn(x, mp)
        res = opt_fn(x, mp)
        self.assertEqual(ref, res)

    def test_dict_construction_from_mapping_proxy(self):
        d = {"a": 2, "b": 3, "c": 5}

        def fn(x, mp):
            d = dict(mp)
            y = torch.sin(x * d["a"])
            return y

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        mp = types.MappingProxyType(d)
        ref = fn(x, mp)
        res = opt_fn(x, mp)
        self.assertEqual(ref, res)

    def test_mapping_proxy_existing_mutation(self):
        d = {"a": 2, "b": 3, "c": 5}

        mp = types.MappingProxyType(d)

        def fn(x):
            d["d"] = 4
            y = torch.sin(x * mp["d"])
            return y

        opt_fn = torch.compile(fn, backend="eager")
        x = torch.randn(4)
        ref = torch.sin(x * 4)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(d.keys(), mp.keys())

    def test_mapping_proxy_existing_local_mutation(self):
        d = {"a": 2, "b": 3, "c": 5}

        mp = types.MappingProxyType(d)

        def fn(x):
            # Dynamo should not cause a graph break here because it knows that
            # the existing proxy can't point to this new dict
            other_dict = {}
            other_dict["d"] = 4
            y = torch.sin(x * mp["c"])
            return y

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        ref = torch.sin(x * mp["c"])
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(d.keys(), mp.keys())

    def test_move_to_end(self):
        def fn(x):
            d = OrderedDict({"a": torch.cos(x), "b": 3, "c": 5})
            d.move_to_end("a")
            return d

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        self.assertEqual(["b", "c", "a"], list(opt_fn(x).keys()))
        self.assertEqual(fn(x), opt_fn(x))

    def test_overridden_get_item(self):
        class MyDict(dict):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.calls = 0

            def __getitem__(self, key):
                self.calls += 1
                return super().__getitem__(key) + 1

        def fn(x, d):
            d["d"] = 4
            return x * d["a"] + d["b"] + d["c"] + d["d"]

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        d1 = MyDict({"a": 2, "b": 3, "c": 5})
        ref = fn(x, d1)

        d2 = MyDict({"a": 2, "b": 3, "c": 5})
        res = opt_fn(x, d2)
        self.assertEqual(ref, res)
        # The user-defined __getitem__ must run the same number of times.
        self.assertEqual(d1.calls, d2.calls)

    def test_items_type(self):
        def fn():
            d = dict({"a": 1, "b": "2", "c": torch.tensor(3)})  # noqa: C418
            return d.items()

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        ref = fn()
        res = opt_fn()
        self.assertEqual(ref, res)
        self.assertEqual(type(res), dict_items)

    def test_builtin_or_with_invalid_types(self):
        args = (
            1,  # int
            1.0,  # float
            "a",  # str
            (1, 2),  # tuple
            [1, 2],  # list
        )

        @torch.compile(backend="eager", fullgraph=True)
        def fn(b: Any):
            a = {"one": torch.ones(1)}
            return a | b

        from torch._dynamo.exc import InternalTorchDynamoError

        for arg in args:
            with self.assertRaisesRegex(
                InternalTorchDynamoError, "unsupported operand type"
            ):
                _ = fn(arg)

    def test_builtin_or_with_diff_keys(self):
        def f():
            a = {"one": torch.ones(1)}
            b = {"two": torch.ones(2)}
            return a, b, a | b, b | a, a.__or__(b), b.__or__(a)

        opt_f = torch.compile(f, backend="eager", fullgraph=True)
        self.assertEqual(f(), opt_f())

    def test_builtin_or_with_same_keys(self):
        def f():
            a = {"one": torch.ones(1), "two": torch.ones(2)}
            b = {"one": torch.ones(1), "three": torch.ones(3)}
            return a, b, a | b, b | a, a.__or__(b), b.__or__(a)

        opt_f = torch.compile(f, backend="eager", fullgraph=True)
        self.assertEqual(f(), opt_f())

    def test_builtin_ior_(self):
        def f():
            a = {"one": torch.ones(1)}
            b = {"two": torch.ones(2)}
            a |= b
            return a, b

        opt_f = torch.compile(f, backend="eager", fullgraph=True)
        self.assertEqual(f(), opt_f())

    def test_newly_constructed_default_dict(self):
        def f(x):
            d = defaultdict(list)
            d[0] = 42
            return x + 1, d

        x = torch.ones(2)
        ref = f(x)
        res = torch.compile(f, backend="eager", fullgraph=True)(x)

        self.assertEqual(ref, res)


if __name__ == "__main__":
    # Executed only when this file is run directly: hand off to the
    # Dynamo test-case harness to run the tests defined above.
    torch._dynamo.test_case.run_tests()
