diff --git a/test/dynamo/cpython/3_13/test_unittest/test_assertions.py b/test/dynamo/cpython/3_13/test_unittest/test_assertions.py index 1dec947ea76..5a8c2a9d3af 100644 --- a/test/dynamo/cpython/3_13/test_unittest/test_assertions.py +++ b/test/dynamo/cpython/3_13/test_unittest/test_assertions.py @@ -1,3 +1,54 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import torch +import torch._dynamo.test_case +import unittest +from torch.testing._internal.common_utils import run_tests + + +__TestCase = torch._dynamo.test_case.CPythonTestCase + + +# redirect import statements +import sys +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", +) + +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + import datetime import warnings import weakref @@ -6,7 +57,7 @@ from test.support import gc_collect from itertools import product -class Test_Assertions(unittest.TestCase): +class Test_Assertions(__TestCase): def test_AlmostEqual(self): self.assertAlmostEqual(1.00000001, 1.0) self.assertNotAlmostEqual(1.0000001, 1.0) @@ -141,12 +192,13 @@ class Test_Assertions(unittest.TestCase): self.fail('assertNotRegex should have failed.') -class TestLongMessage(unittest.TestCase): +class TestLongMessage(__TestCase): """Test that the individual asserts honour longMessage. This actually tests all the message behaviour for asserts that use longMessage.""" def setUp(self): + super().setUp() class TestableTestFalse(unittest.TestCase): longMessage = False failureException = self.failureException @@ -414,4 +466,4 @@ class TestLongMessage(unittest.TestCase): if __name__ == "__main__": - unittest.main() + run_tests()