# Owner(s): ["module: dynamo"]

import contextlib
import sys

import torch
import torch._dynamo.config
import torch._dynamo.test_case
import torch._functorch.config
import torch.nn
import torch.utils.checkpoint
from torch._dynamo.bytecode_transformation import Instruction
from torch._dynamo.symbolic_convert import SpeculationLog, SpeculationLogDivergence
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    make_dynamo_test,
    parametrize,
)


class CustomException(Exception):
    """Plain user-defined exception with no behavior beyond ``Exception``."""


class CustomExceptionMeta(type):
    # Metaclass whose instance check accepts every object.
    def __instancecheck__(cls, instance):
        # Unconditionally True, so isinstance(obj, C) holds for any obj when C
        # uses this metaclass.
        return True


class CustomExceptionWithInstanceCheck(Exception, metaclass=CustomExceptionMeta):
    """Exception whose ``isinstance`` checks are governed by ``CustomExceptionMeta``."""


class CustomExceptionWithArgs(Exception):
    # Exception with a custom initializer: one required and one optional argument.
    def __init__(self, a, b=None):
        # NOTE(review): super().__init__() is not called; this initializer only
        # stores the two attributes below.
        self.a = a
        self.b = b


class MyException(OSError):
    """User-defined exception deriving from ``OSError``."""


class ExceptionTests(torch._dynamo.test_case.TestCase):
    def test_exception(self):
        # An exception raised inside `try` and caught by a broad
        # `except Exception` must give the same result compiled as in eager.
        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                raise NotImplementedError
            except Exception:
                # Always taken: the raise above is unconditional.
                x = torch.sigmoid(x)

            return x

        x = torch.randn(4)
        ref = fn(x)
        # fullgraph=True: the try/except has to be handled without a graph break.
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_exception2(self):
        # Same shape as test_exception, but the handler matches against a
        # tuple of exception types.
        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                raise NotImplementedError
            except (NotImplementedError, AttributeError):
                # Matches via the first member of the tuple.
                x = torch.sigmoid(x)

            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_exception3(self):
        # Multiple `except` clauses plus `finally`: the first clause does not
        # match, the second does, and `finally` always runs afterwards.
        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                raise NotImplementedError("Not implemented")
            except AssertionError:
                # Not taken: the raised type is NotImplementedError.
                x = torch.sigmoid(x)
            except NotImplementedError:
                x = torch.cos(x)
            finally:
                x = torch.cos(x)

            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_exception4(self):
        # try/except executed repeatedly inside a loop, with an early return
        # from the loop body.
        def fn(x):
            for i in range(10):
                if i == 5:
                    # Exits after five raise/catch iterations (i = 0..4).
                    return x
                try:
                    x = torch.sin(x)
                    raise NotImplementedError
                except Exception:
                    x = torch.sigmoid(x)

            # Not reached: the loop always returns at i == 5.
            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_exception_with_another_exception(self):
        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                raise NotImplementedError("Not implemented")
            except NotImplementedError:
                x = torch.sigmoid(x)
                try:
                    x = torch.cos(x)
                    raise AssertionError
                except AssertionError:
                    x = torch.cos(x)

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_autocast_with_exception(self):
        # An exception raised from an autograd.Function's forward while inside
        # a torch.autocast block must propagate to the enclosing `except`.
        class Optimizer(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # Always raises; the output is never produced.
                raise NotImplementedError("Not implemented")

            @staticmethod
            def backward(ctx, grad_out):
                return grad_out

        # No backend= is given, so torch.compile's default backend is used.
        @torch.compile
        def f(x: torch.Tensor):
            try:
                with torch.autocast(device_type="cpu", dtype=None):
                    Optimizer.apply(x)
            except NotImplementedError:
                return x + 1

        inp = torch.ones(3)
        out = f(inp)
        self.assertTrue(torch.equal(out, inp + 1))

    @make_dynamo_test
    def test_isinstance_CustomException(self):
        # isinstance() on exception classes, with and without a metaclass that
        # overrides __instancecheck__.
        # NOTE(review): make_dynamo_test comes from
        # torch.testing._internal.common_utils; presumably it runs this body
        # under Dynamo - confirm there.
        assert isinstance(CustomException, type)
        assert not isinstance(CustomException(), type)
        C = CustomExceptionWithInstanceCheck
        # The metaclass's __instancecheck__ returns True for anything, so even
        # the class object itself counts as an instance of C.
        assert isinstance(C, C)
        assert isinstance(C(), C)

    @make_dynamo_test
    def test_propagate_exception_inside_ctx_manager(self):
        # An exception replaced inside one context manager must be the one seen
        # by the surrounding try/except, even inside another context manager.
        @contextlib.contextmanager
        def cm():
            # Replaces any exception raised in the with-body by ValueError.
            try:
                yield
            except BaseException:
                raise ValueError  # noqa: B904

        @contextlib.contextmanager
        def nothing():
            # Pass-through manager: try/finally with an empty finally.
            try:
                yield
            finally:
                pass

        z = 0
        with nothing():
            try:
                with cm():
                    raise IndexError
            except ValueError:
                # Expected branch: cm() turned the IndexError into ValueError.
                z = 1
            except IndexError:
                z = 2
            assert z == 1

    def test_exception_else(self):
        # try/except/else where nothing is raised: the `else` branch runs and
        # the `except` branch does not.
        def gn(x):
            return torch.cos(x)

        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                x = gn(x)
            except Exception:
                x = torch.sigmoid(x)
            else:
                x = torch.cos(x)

            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    @make_dynamo_test
    def test_raise_match(self):
        # Builds a five-deep __context__ chain by hand and checks its order.
        # The expected chain, outermost first, is a -> b -> c -> d -> e (see
        # the asserts at the end).
        a = AttributeError
        b = BytesWarning
        c = ConnectionError
        d = DeprecationWarning
        e = Exception

        def fn(a, b):
            # Raises `a`, then raises `b` from the finally block, so `b` is the
            # exception that leaves the function.
            try:
                raise a
            finally:
                raise b

        def fix_exc_context(frame_exc, new_exc, old_exc):
            # slightly change from ExitStack.fix_exc_context function
            # Walk new_exc's __context__ chain: stop without changes on None or
            # old_exc; on reaching frame_exc, splice old_exc in at that point.
            while 1:
                exc_context = new_exc.__context__
                if exc_context is None or exc_context is old_exc:
                    return
                if exc_context is frame_exc:
                    break
                new_exc = exc_context
            new_exc.__context__ = old_exc

        @contextlib.contextmanager
        def ctx():
            try:
                yield
            finally:
                # Exception info current when the with-body exits; both names
                # start out pointing at the same tuple.
                frame_exc = prev_exc = sys.exc_info()
                args = [(d, c), (b, a)]
                for x, y in args:
                    try:
                        fn(x, y)
                    except BaseException:
                        new_exc = sys.exc_info()
                        fix_exc_context(frame_exc[1], new_exc[1], prev_exc[1])
                        prev_exc = new_exc

                try:
                    # Save the context before re-raising, then restore it on
                    # the way out.
                    fixed_ctx = prev_exc[1].__context__
                    raise prev_exc[1]
                except BaseException:
                    prev_exc[1].__context__ = fixed_ctx
                    raise

        # NOTE(review): if nothing is raised out of the with-block, none of the
        # asserts below run and the test passes vacuously.
        try:
            with ctx():
                raise e
        except Exception as exc:
            assert isinstance(exc, a)
            assert isinstance(exc.__context__, b)
            assert isinstance(exc.__context__.__context__, c)
            assert isinstance(exc.__context__.__context__.__context__, d)
            assert isinstance(exc.__context__.__context__.__context__.__context__, e)

    # TODO(anijain2305) - does not work with fullgraph=True
    def test_exception_with_another_exception2(self):
        def gn(x):
            try:
                x = torch.cos(x)
                raise NotImplementedError("Not implemented")
            except NotImplementedError:
                x = torch.sigmoid(x)
                raise

        def fn(x):
            try:
                x = torch.cos(x)
                gn(x)
            except Exception:
                pass
            return x

        x = torch.randn(4)
        fn(x)
        # Cant use fullgraph=True because RERAISE is not supported
        opt_fn = torch.compile(fn, backend="eager")
        opt_fn(x)

    def test_exception_with_ctx_manager(self):
        # An exception raised inside a `with torch.no_grad()` block must leave
        # the context manager and reach the enclosing `except`.
        def fn(x):
            x = torch.cos(x)
            try:
                with torch.no_grad():
                    x = torch.sin(x)
                    raise NotImplementedError("Not implemented")
            except NotImplementedError:
                x = torch.sigmoid(x)
            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_exception_raised_from_child(self):
        # An exception raised in a called function must be caught by the
        # caller's handler.
        def gn():
            raise NotImplementedError("foo")

        def fn(x):
            x = torch.cos(x)
            try:
                x = torch.sin(x)
                gn()
                # Skipped: gn() always raises.
                x = torch.sin(x)
            except Exception:
                x = torch.sigmoid(x)

            return x

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_dynamo_undo_kw_names(self):
        # A call made with a keyword argument raises; the handler then performs
        # further method calls (d.items(), b.sum()).
        # NOTE(review): the test name suggests this guards against keyword-name
        # state leaking past the failed call - confirm against Dynamo internals.
        def g(x, k=None):
            if k:
                raise TypeError("error")
            return x.sin()

        def fn(x):
            d = {"a": x}
            try:
                g(x, k=True)
            except Exception:
                # Always taken because k=True; `y` is only defined here.
                y = 0
                for _, b in d.items():  # noqa: PERF102
                    y += b.sum()
            return y

        x = torch.randn(2, 3)
        expected = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        got = opt_fn(x)
        self.assertEqual(expected, got)

    def test_raise_custom_exception(self):
        # Raising a locally defined exception class must select the handler for
        # that class rather than the broader `except Exception`.
        class Exc(Exception):
            pass

        @torch.compile(backend="eager", fullgraph=True)
        def fn(t):
            try:
                raise Exc
            except Exc:
                return t.sin()
            except Exception:
                return t.cos()

        t = torch.randn(2)
        y = fn(t)
        # sin() identifies the `except Exc` branch.
        self.assertEqual(y, t.sin())

    def test_raise_custom_exception_with_args(self):
        # Arguments passed when raising a custom exception must be readable
        # from `e.args` in the handler.
        class Exc(Exception):
            pass

        @torch.compile(backend="eager", fullgraph=True)
        def fn(t):
            try:
                raise Exc(1, 2.0)
            except Exc as e:
                return t.sin() + e.args[0] + e.args[1]
            except Exception:
                return t.cos()

        t = torch.randn(2)
        y = fn(t)
        # 1 and 2.0 are the two values passed to Exc above.
        self.assertEqual(y, t.sin() + 1 + 2.0)

    def test_nn_module_getattr(self):
        class A:
            def __init__(self) -> None:
                self._b = 20

            def __getattr__(self, name):
                fixed_name = "_" + name
                if fixed_name in self.__dict__:
                    return self.__dict__[fixed_name]
                raise AttributeError(f"{name} absent")

        class B(A):
            def __init__(self) -> None:
                self.a = 10

            def __getattr__(self, name):
                try:
                    return super().__getattr__(name)
                except AttributeError:
                    return 30

        obj = B()

        def fn(x):
            return x * obj.a * obj.b * obj.c

        x = torch.ones(4)
        ref = fn(x)
        print(ref)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
    def test_custom_getattr_on_module_exception(self):
        # An nn.Module whose __getattr__ catches the AttributeError from
        # nn.Module.__getattr__ and serves an alias for one name.
        class Foo(torch.nn.Module):
            def __init__(self, a=3):
                # NOTE(review): the `a` argument is not used in this body.
                super().__init__()
                self.register_parameter("a", torch.nn.Parameter(torch.ones(4) * 2))

            def __getattr__(self, name):
                try:
                    return super().__getattr__(name)  # defer to nn.Module's logic
                except AttributeError:
                    # "a_copy" is an alias for parameter "a"; any other missing
                    # name re-raises the original error.
                    if name == "a_copy":
                        return self.a
                    raise

            def forward(self, x):
                return x * self.a * self.a_copy

        mod = Foo()
        opt_mod = torch.compile(mod, backend="eager", fullgraph=True)

        x = torch.ones(4)
        self.assertEqual(mod(x), opt_mod(x))

    def test_attribute_error_from_getattr(self):
        # hasattr() on an object whose __getattr__ raises AttributeError must
        # evaluate to False in compiled code as it does in eager.
        class Mock:
            def __init__(self):
                self.a = 5

            def __getattr__(self, name):
                if name != "a":
                    raise AttributeError("missing")
                return self.__dict__["a"]

        mock = Mock()

        def fn(x):
            # "b" is not "a", so Mock.__getattr__ raises and the sin() branch
            # is the one taken.
            if hasattr(mock, "b"):
                return torch.cos(x)
            return torch.sin(x)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_stop_iteration(self):
        # StopIteration raised by next() must be catchable in traced code.
        def zip_longest(*iterables, fillvalue=None):
            # Local helper, not itertools.zip_longest: it builds one flat list
            # in round-robin order and returns as soon as any iterator is
            # exhausted, after appending a single fillvalue.
            # Get the iterators for each iterable
            iterators = [iter(it) for it in iterables]

            result = []
            while True:
                for it in iterators:
                    try:
                        value = next(it)
                    except StopIteration:
                        result.append(fillvalue)
                        return result
                    result.append(value)

        def fn(x, y):
            # Tensor op whose result is discarded.
            torch.cos(torch.randn(4))
            return tuple(zip_longest(x, y))

        x = [1, 2, 3, 4]
        y = [10, 11, 12]

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        ref = fn(x, y)
        res = opt_fn(x, y)
        self.assertEqual(ref, res)

    def test_nn_reraise(self):
        # An exception raised from a module's forward must propagate to the
        # caller of the compiled function, and the recorded compilation
        # metrics must mention it.
        class M(torch.nn.Module):
            def forward(self, x):
                raise ValueError("woof")
                # Unreachable: the raise above is unconditional.
                return x + 2

        m = M()
        # Forward pre-hook that does nothing.
        m.register_forward_pre_hook(lambda m, go: None)

        # Start from empty metrics so metrics[0] belongs to this compilation.
        torch._dynamo.utils.clear_compilation_metrics()
        opt_call = torch.compile(lambda x: m(x), backend="eager")
        self.assertRaises(ValueError, lambda: opt_call(torch.randn(3)))
        metrics = torch._dynamo.utils.get_compilation_metrics()
        self.assertIn("Observed exception", metrics[0].fail_reason)

    def test_key_error(self):
        # KeyError from a missing dict key must be catchable in traced code.
        def fn(x, d):
            try:
                a = d["b"]
            except KeyError:
                a = 2
            return x * a

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        # "b" is absent, so the handler's value (2) is used.
        d = {"a": 1}
        ref = fn(x, d)
        res = opt_fn(x, d)
        self.assertEqual(ref, res)

    def test_atrribute_error(self):
        # AttributeError from reading a missing attribute must be catchable.
        # NOTE(review): the method name misspells "attribute"; left unchanged
        # because renaming would change the test ID.
        class Mock:
            def __init__(self):
                self.a = 1

        mock = Mock()

        def fn(x):
            try:
                c = 2
                # Bare attribute read; Mock has no `b`, so this raises.
                mock.b
            except AttributeError:
                c = 3
            return torch.sin(x) * c

        # Compiled without fullgraph=True, unlike most tests in this class.
        opt_fn = torch.compile(fn, backend="eager")
        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_raise_from_None(self):
        # `raise ... from None` inside a handler, with the new exception caught
        # by a caller.
        # Inspired from os.environ
        class MyMapping:
            def __init__(self, d):
                self._d = d

            def __getitem__(self, key):
                try:
                    value = self._d[key]
                except KeyError:
                    raise KeyError(key) from None
                return value

        d = MyMapping({"a": 10, "b": 20})

        def mapping_get(obj, key, value=None):
            # Returns `value` when the lookup raises KeyError.
            try:
                return obj.__getitem__(key)
            except KeyError:
                return value

        def fn(x, d, key):
            x = torch.sin(x + 1)
            return x, mapping_get(d, key)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        x = torch.rand(2, 3)
        # "m" is not in the mapping, so the second tuple element is None.
        ref = fn(x, d, "m")
        res = opt_fn(x, d, "m")
        self.assertEqual(ref[0], res[0])
        self.assertEqual(ref[1], res[1])

    @make_dynamo_test
    def test_raise_from_None_2(self):
        # `raise X from None` must leave __cause__ as None and set
        # __suppress_context__ to True.
        def fn():
            try:
                raise ValueError
            except Exception:
                raise TypeError from None

        # NOTE(review): if fn() does not raise, neither assert runs and the
        # test passes vacuously.
        try:
            fn()
        except TypeError as e:
            assert e.__cause__ is None
            assert e.__suppress_context__ is True

    @make_dynamo_test
    def test_raise_from_other(self):
        # `raise X from e` must record `e` as __cause__ and set
        # __suppress_context__ to True.
        def fn():
            try:
                raise ValueError
            except Exception as e:
                raise TypeError from e

        # NOTE(review): if fn() does not raise, neither assert runs and the
        # test passes vacuously.
        try:
            fn()
        except TypeError as e:
            assert isinstance(e.__cause__, ValueError)
            assert e.__suppress_context__ is True

    @make_dynamo_test
    def test_reraise_first_exc(self):
        # A bare `raise` after a nested, fully handled exception must re-raise
        # the outer exception (ZeroDivisionError), not the inner ValueError.
        def fn():
            try:
                raise ZeroDivisionError
            except ZeroDivisionError:
                try:
                    raise ValueError
                except ValueError:
                    pass
                raise

        try:
            fn()
        except ZeroDivisionError:
            pass
        # No exception is being handled once the outer handler has finished.
        assert sys.exc_info()[0] is None

    @make_dynamo_test
    def test_ensure_exception_is_active_after_try_except_block(self):
        # After several nested try/except blocks complete inside a handler,
        # the bare `raise` must still re-raise the handler's own exception.
        try:
            try:
                raise ZeroDivisionError
            except ZeroDivisionError:
                for exc in (KeyError, IndexError):
                    # Each inner exception is raised and fully handled.
                    try:
                        raise exc
                    except exc:
                        pass
                raise
        except ZeroDivisionError:
            pass
        # No exception is being handled at this point.
        assert sys.exc_info()[0] is None

    @make_dynamo_test
    def test_ensure_exception_is_active_inside_try_except_block(self):
        # Exceptions raised while a handler is running must carry that
        # handler's exception as their __context__.
        try:
            try:
                raise ZeroDivisionError
            except ZeroDivisionError:
                for exc in (KeyError, IndexError):
                    try:
                        raise exc
                    except exc as e:
                        # The outer ZeroDivisionError is still being handled.
                        assert isinstance(e.__context__, ZeroDivisionError)
                raise
        except ZeroDivisionError:
            pass
        # No exception is being handled at this point.
        assert sys.exc_info()[0] is None

    @make_dynamo_test
    def test_handle_all_exceptions(self):
        # An exception thrown into a generator is caught there and replaced by
        # a different one that the caller then handles.
        def cm():
            # Generator (it contains `yield`); no contextmanager decorator is
            # applied here.
            try:
                yield 1
            except ValueError:
                try:
                    raise TypeError
                finally:
                    pass

        try:
            gen = cm()
            next(gen)
            # The generator's handler turns this ValueError into TypeError.
            gen.throw(ValueError)
        except TypeError:
            pass
        # No exception is being handled at this point.
        assert sys.exc_info()[0] is None

    @make_dynamo_test
    def test_reraise(self):
        # A bare `raise` in a handler re-raises the same exception to the
        # enclosing handler.
        try:
            try:
                raise ValueError
            except ValueError:  # noqa: TRY203
                raise
        except ValueError:
            pass
        # No exception is being handled at this point.
        assert sys.exc_info()[0] is None

    @make_dynamo_test
    def test_raise_finally_simple(self):
        # An exception raised inside try/finally, itself nested in an `except`
        # handler, must propagate out of the function.
        def fn():
            try:
                raise ValueError
            except ValueError:
                try:
                    raise TypeError
                finally:
                    pass

        try:
            fn()
        except TypeError:
            pass
        # No exception is being handled at this point.
        assert sys.exc_info()[0] is None

    def test_reconstruct___context__(self):
        # An exception object created inside a compiled function and returned
        # from it must keep the __context__ and __cause__ assigned there.
        @torch.compile(backend="eager", fullgraph=True)
        def fn(t):
            v = ValueError(1, 2, 3)
            v.__context__ = TypeError()
            v.__cause__ = RuntimeError()
            return t.sin(), v

        t = torch.randn(2)
        y, v = fn(t)
        self.assertEqual(y, t.sin())
        self.assertIsInstance(v, ValueError)
        self.assertIsInstance(v.__context__, TypeError)
        self.assertIsInstance(v.__cause__, RuntimeError)
        # Assigning __cause__ sets __suppress_context__ to True in Python.
        self.assertTrue(v.__suppress_context__)

    def test_reconstruct_exception_2(self):
        # A caught exception returned from a compiled function must keep the
        # __context__ set by raising and the __cause__ assigned afterwards.
        @torch.compile(backend="eager", fullgraph=True)
        def fn(t):
            try:
                raise ValueError(1, 2, 3)
            except Exception:
                try:
                    # Raised while the ValueError is being handled, so the
                    # ValueError is expected as __context__ (asserted below).
                    raise TypeError(4, 5) from None
                except Exception as e:
                    # Overwrites the None cause set by `from None`.
                    e.__cause__ = RuntimeError(6, 7)
                    return t.sin(), e

        t = torch.randn(2)
        y, v = fn(t)
        self.assertEqual(y, t.sin())
        self.assertIsInstance(v, TypeError)
        self.assertIsInstance(v.__context__, ValueError)
        self.assertIsInstance(v.__cause__, RuntimeError)

    def test_raise_GeneratorExit(self):
        # Handler selection must follow the class hierarchy: `except Exception`
        # is skipped and `except BaseException` is taken.
        # GeneratorExit does not inherit from Exception
        @torch.compile(backend="eager", fullgraph=True)
        def fn(t):
            try:
                raise GeneratorExit
            except Exception:
                return t.sin()
            except BaseException:
                return t.cos()

        t = torch.randn(2)
        y = fn(t)
        # cos() identifies the BaseException branch.
        self.assertEqual(y, t.cos())

    def test_speculation_exception(self):
        log = SpeculationLog()
        log.next("fake", 555, "fake", Instruction(1, "fake", 1, 1))
        log.restart()
        with self.assertRaises(SpeculationLogDivergence):
            log.next("bad", 58, "bad", Instruction(2, "different", 2, 2))

    def test_dict_pop(self):
        # dict.pop() on a missing key raises KeyError (handler branch); on a
        # present key the `else` branch runs. Both paths are exercised below.
        # Pattern from inspect.bind
        def fn(dt, x):
            try:
                dt.pop("b")
            except KeyError:
                return torch.sin(x)
            else:
                return torch.cos(x)

        d = {"a": 1}
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        x = torch.randn(4)
        # Missing key: KeyError path.
        self.assertEqual(fn(d, x), opt_fn(d, x))
        # Present key: `else` path. A fresh dict is built for each call because
        # pop() removes "b".
        self.assertEqual(fn({"a": 1, "b": 2}, x), opt_fn({"a": 1, "b": 2}, x))

    def test_block_stack_cleanup(self):
        # try/except inside a loop where some iterations raise and one does not.
        # NOTE(review): the name suggests this guards Dynamo's block-stack
        # handling across iterations - confirm against Dynamo internals.
        params = {
            "a": 3,
            "b": 4,
            "c": 5,
        }

        # Only "c" is present, so the lookups for "a" and "b" raise KeyError.
        dt = {
            "c": 5,
        }

        def fn(x):
            for name in params:
                try:
                    x = x * dt[name]
                except KeyError:
                    x = x * torch.sin(x)
            return x

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        self.assertEqual(fn(x), opt_fn(x))

    def test_set_cause_with_arg(self):
        # Assigning __cause__ on an exception passed in as an argument must be
        # visible on the caller's object after the compiled call.
        @torch.compile(backend="eager", fullgraph=True)
        def fn(t, err):
            err.__cause__ = ValueError()
            return t.sin()

        t = torch.randn(2)
        e = TypeError("abcd")
        fn(t, e)
        self.assertIsInstance(e.__cause__, ValueError)

    def test_set_cause_with_arg_error(self):
        # Assigning a non-exception value to __cause__ must raise TypeError out
        # of the compiled function.
        @torch.compile(backend="eager", fullgraph=True)
        def fn(t, err):
            # 2 is neither None nor an exception instance.
            err.__cause__ = 2
            return t.sin()

        t = torch.randn(2)
        e = TypeError("abcd")
        with self.assertRaisesRegex(TypeError, "exception cause must be"):
            fn(t, e)

    @parametrize(
        "ex",
        [TypeError, CustomException],
        name_fn=lambda x: x.__name__,
    )
    @make_dynamo_test
    def test_set___cause__(self, ex):
        # For a builtin and a user-defined exception type: `raise ... from None`
        # keeps the handled exception as __context__, leaves __cause__ as None
        # and sets __suppress_context__.
        def fn():
            try:
                raise ex
            except ex:
                raise TypeError from None

        # NOTE(review): if fn() does not raise, none of the asserts run and the
        # test passes vacuously.
        try:
            fn()
        except TypeError as e:
            assert isinstance(e.__context__, ex)
            assert e.__cause__ is None
            assert e.__suppress_context__ is True

    @parametrize(
        "ex",
        [RuntimeError, CustomException],
        name_fn=lambda x: x.__name__,
    )
    @make_dynamo_test
    def test_set___cause___error(self, ex):
        # Assigning a non-exception value to __cause__ on a caught exception
        # must raise TypeError with the exact message checked below.
        def fn():
            try:
                raise ex
            except Exception as e:
                # Expected to raise TypeError, so the bare `raise` on the next
                # line is not expected to run.
                e.__cause__ = 2
                raise

        # Flag proving that the TypeError branch was actually taken.
        z = 0

        try:
            fn()
        except TypeError as e:
            z = 1
            assert e.args == (
                "exception cause must be None or derive from BaseException",
            )
        except Exception:
            # Any other exception type is a failure.
            raise AssertionError from None

        assert z == 1

    def test_user_defined_exception_variable(self):
        # Raising the module-level CustomException must skip the non-matching
        # `except ValueError` and land in `except CustomException`.
        @torch.compile(backend="eager", fullgraph=True)
        def fn(t):
            z = 0
            try:
                raise CustomException
            except ValueError:
                z = 1
            except CustomException:
                z = 2
            # The check happens inside the compiled function.
            assert z == 2
            return t.sin()

        t = torch.randn(2)
        # The return value is not inspected; the assert inside fn is the check.
        fn(t)

    def test_user_defined_exception_with_args(self):
        # Raising an exception whose custom __init__ takes positional and
        # keyword arguments must land in the matching handler.
        @torch.compile(backend="eager", fullgraph=True)
        def fn(t):
            z = 0
            try:
                raise CustomExceptionWithArgs(2, b=3)
            except ValueError:
                z = 1
            except CustomExceptionWithArgs:
                z = 2
            # The check happens inside the compiled function.
            assert z == 2

        # NOTE(review): fn neither uses `t` nor returns a value, unlike
        # test_user_defined_exception_variable.
        t = torch.randn(2)
        fn(t)

    @make_dynamo_test
    def test_raise_set___context__(self):
        # Two independent raise/catch sequences: neither exception may pick up
        # a __context__, since no other exception is being handled when it is
        # raised.
        try:
            raise TypeError
        except TypeError as e:
            exc = e

        assert exc.__context__ is None

        # The first handler has finished, so the TypeError must not become the
        # context of this ValueError.
        try:
            raise ValueError
        except ValueError as e:
            exc2 = e

        assert exc2.__context__ is None


# Process the @parametrize-decorated tests on the class (helper from
# torch.testing._internal.common_utils).
instantiate_parametrized_tests(ExceptionTests)


if __name__ == "__main__":
    # torch._dynamo.test_case is imported at the top of the file, so the runner
    # is reached through the module instead of a local from-import.
    torch._dynamo.test_case.run_tests()
