diff --git a/test/dynamo/cpython/3_13/test_sort.py b/test/dynamo/cpython/3_13/test_sort.py index 2a7cfb7affa..d661ae544b9 100644 --- a/test/dynamo/cpython/3_13/test_sort.py +++ b/test/dynamo/cpython/3_13/test_sort.py @@ -1,3 +1,54 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import torch +import torch._dynamo.test_case +import unittest +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import run_tests + +__TestCase = CPythonTestCase + + +# redirect import statements +import sys +import importlib.abc + +redirect_imports = ( + "test.mapping_tests", + "test.typinganndata", + "test.test_grammar", + "test.test_math", + "test.test_iter", + "test.typinganndata.ann_module", +) + +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one + if fullname in redirect_imports: + try: + # Attempt to import the standalone module + name = fullname.removeprefix("test.") + r = importlib.import_module(name) + # Redirect the module in sys.modules + sys.modules[fullname] = r + # Return a module spec from the found module + return importlib.util.find_spec(name) + except ImportError: + return None + return None + +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + + +# ======= END DYNAMO PATCH ======= + from test import support import random import unittest @@ -39,7 +90,7 @@ def check(tag, expected, raw, compare=None): nerrors += 1 return -class TestBase(unittest.TestCase): +class TestBase(__TestCase): def testStressfully(self): # Try a variety of sizes at and around powers of 2, and at powers of 10. sizes = [0] @@ -151,7 +202,7 @@ class TestBase(unittest.TestCase): self.assertEqual(forced, native) #============================================================================== -class TestBugs(unittest.TestCase): +class TestBugs(__TestCase): def test_bug453523(self): # bug 453523 -- list.sort() crasher. @@ -188,7 +239,7 @@ class TestBugs(unittest.TestCase): #============================================================================== -class TestDecorateSortUndecorate(unittest.TestCase): +class TestDecorateSortUndecorate(__TestCase): def test_decorated(self): data = 'The quick Brown fox Jumped over The lazy Dog'.split() @@ -309,7 +360,7 @@ def check_against_PyObject_RichCompareBool(self, L): self.assertIs(opt, ref) #note: not assertEqual! We want to ensure *identical* behavior. -class TestOptimizedCompares(unittest.TestCase): +class TestOptimizedCompares(__TestCase): def test_safe_object_compare(self): heterogeneous_lists = [[0, 'foo'], [0.0, 'foo'], @@ -408,4 +459,4 @@ class TestOptimizedCompares(unittest.TestCase): #============================================================================== if __name__ == "__main__": - unittest.main() + run_tests()