test_case.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. """Testing utilities for Dynamo, providing a specialized TestCase class and test running functionality.
  2. This module extends PyTorch's testing framework with Dynamo-specific testing capabilities.
  3. It includes:
  4. - A custom TestCase class that handles Dynamo-specific setup/teardown
  5. - Test running utilities with dependency checking
  6. - Automatic reset of Dynamo state between tests
  7. - Proper handling of gradient mode state
  8. """
  9. import contextlib
  10. import importlib
  11. import inspect
  12. import logging
  13. import os
  14. import re
  15. import sys
  16. import unittest
  17. from typing import Any, Callable, Union
  18. import torch
  19. import torch.testing
  20. from torch._dynamo import polyfills
  21. from torch._logging._internal import trace_log
  22. from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
  23. IS_WINDOWS,
  24. TEST_WITH_CROSSREF,
  25. TEST_WITH_TORCHDYNAMO,
  26. TestCase as TorchTestCase,
  27. )
  28. from . import config, reset, utils
  # Logger for this module; TestCase.tearDown emits a warning through it when a
  # test leaves autograd's grad mode changed.
  log: logging.Logger = logging.getLogger(__name__)
  30. def run_tests(needs: Union[str, tuple[str, ...]] = ()) -> None:
  31. from torch.testing._internal.common_utils import run_tests
  32. if TEST_WITH_TORCHDYNAMO or TEST_WITH_CROSSREF:
  33. return # skip testing
  34. if (
  35. not torch.xpu.is_available()
  36. and IS_WINDOWS
  37. and os.environ.get("TORCHINDUCTOR_WINDOWS_TESTS", "0") == "0"
  38. ):
  39. return
  40. if isinstance(needs, str):
  41. needs = (needs,)
  42. for need in needs:
  43. if need == "cuda":
  44. if not torch.cuda.is_available():
  45. return
  46. else:
  47. try:
  48. importlib.import_module(need)
  49. except ImportError:
  50. return
  51. run_tests()
  52. class TestCase(TorchTestCase):
  53. _exit_stack: contextlib.ExitStack
  54. @classmethod
  55. def tearDownClass(cls) -> None:
  56. cls._exit_stack.close()
  57. super().tearDownClass()
  58. @classmethod
  59. def setUpClass(cls) -> None:
  60. super().setUpClass()
  61. cls._exit_stack = contextlib.ExitStack() # type: ignore[attr-defined]
  62. cls._exit_stack.enter_context( # type: ignore[attr-defined]
  63. config.patch(
  64. raise_on_ctx_manager_usage=True,
  65. suppress_errors=False,
  66. log_compilation_metrics=False,
  67. ),
  68. )
  69. def setUp(self) -> None:
  70. self._prior_is_grad_enabled = torch.is_grad_enabled()
  71. super().setUp()
  72. reset()
  73. utils.counters.clear()
  74. self.handler = logging.NullHandler()
  75. trace_log.addHandler(self.handler)
  76. def tearDown(self) -> None:
  77. trace_log.removeHandler(self.handler)
  78. for k, v in utils.counters.items():
  79. print(k, v.most_common())
  80. reset()
  81. utils.counters.clear()
  82. super().tearDown()
  83. if self._prior_is_grad_enabled is not torch.is_grad_enabled():
  84. log.warning("Running test changed grad mode")
  85. torch.set_grad_enabled(self._prior_is_grad_enabled)
  86. def assertEqual(self, x: Any, y: Any, *args: Any, **kwargs: Any) -> None: # type: ignore[override]
  87. if (
  88. config.debug_disable_compile_counter
  89. and isinstance(x, utils.CompileCounterInt)
  90. or isinstance(y, utils.CompileCounterInt)
  91. ):
  92. return
  93. return super().assertEqual(x, y, *args, **kwargs)
  94. # assertExpectedInline might also need to be disabled for wrapped nested
  95. # graph break tests
  class CPythonTestCase(TestCase):
      """
      Test class for CPython tests located in ``test/dynamo/cpython/<py_version>/*``
      (for example ``test/dynamo/cpython/3_13/test_math.py``).

      This class enables specific features that are disabled by default, such as
      tracing through unittest methods.
      """

      # Holds the config patches entered in setUpClass; closed in tearDownClass.
      _stack: contextlib.ExitStack
      # NOTE(review): presumably consumed by the parent TestCase machinery to run
      # these tests in nopython mode -- confirm.
      dynamo_strict_nopython = True

      # Restore original unittest methods to simplify tracing CPython test cases.
      assertEqual = unittest.TestCase.assertEqual  # type: ignore[assignment]
      assertNotEqual = unittest.TestCase.assertNotEqual  # type: ignore[assignment]
      assertTrue = unittest.TestCase.assertTrue
      assertFalse = unittest.TestCase.assertFalse
      assertIs = unittest.TestCase.assertIs
      assertIsNot = unittest.TestCase.assertIsNot
      assertIsNone = unittest.TestCase.assertIsNone
      assertIsNotNone = unittest.TestCase.assertIsNotNone
      assertIn = unittest.TestCase.assertIn
      assertNotIn = unittest.TestCase.assertNotIn
      assertIsInstance = unittest.TestCase.assertIsInstance
      assertNotIsInstance = unittest.TestCase.assertNotIsInstance
      assertAlmostEqual = unittest.TestCase.assertAlmostEqual
      assertNotAlmostEqual = unittest.TestCase.assertNotAlmostEqual
      assertGreater = unittest.TestCase.assertGreater
      assertGreaterEqual = unittest.TestCase.assertGreaterEqual
      assertLess = unittest.TestCase.assertLess
      assertLessEqual = unittest.TestCase.assertLessEqual
      assertRegex = unittest.TestCase.assertRegex
      assertNotRegex = unittest.TestCase.assertNotRegex
      assertCountEqual = unittest.TestCase.assertCountEqual
      assertMultiLineEqual = polyfills.assert_multi_line_equal
      assertSequenceEqual = polyfills.assert_sequence_equal
      assertListEqual = unittest.TestCase.assertListEqual
      assertTupleEqual = unittest.TestCase.assertTupleEqual
      assertSetEqual = unittest.TestCase.assertSetEqual
      assertDictEqual = polyfills.assert_dict_equal
      assertRaises = unittest.TestCase.assertRaises
      assertRaisesRegex = unittest.TestCase.assertRaisesRegex
      assertWarns = unittest.TestCase.assertWarns
      assertWarnsRegex = unittest.TestCase.assertWarnsRegex
      assertLogs = unittest.TestCase.assertLogs
      fail = unittest.TestCase.fail
      failureException = unittest.TestCase.failureException

      @staticmethod
      def _cpython_version_from_path(path: str) -> Union[tuple[int, int], None]:
          """Return ``(major, minor)`` from a ``dynamo/cpython/<M>_<mm>`` segment of *path*.

          Returns ``None`` when *path* contains no such segment. Shared by
          ``setUpClass`` and ``_dynamo_test_key`` so both agree on the version,
          and built with ``os.path`` so it works with the platform's separator.
          """
          prefix = os.path.join("dynamo", "cpython") + os.path.sep
          m = re.search(re.escape(prefix) + r"(\d)_(\d{2})", path)
          if m is None:
              return None
          return int(m.group(1)), int(m.group(2))

      def compile_fn(
          self,
          fn: Callable[..., Any],
          backend: Union[str, Callable[..., Any]],
          nopython: bool,
      ) -> Callable[..., Any]:
          """Wrap the current test method with Dynamo and return *fn* unchanged.

          ``nopython`` is forwarded as ``error_on_graph_break``.
          """
          # We want to compile only the test function, excluding any setup code
          # from unittest
          method = getattr(self, self._testMethodName)
          method = torch._dynamo.optimize(backend, error_on_graph_break=nopython)(method)
          setattr(self, self._testMethodName, method)
          return fn

      def _dynamo_test_key(self) -> str:
          """Return ``CPython<ver>-<file>-<parent key>`` for versioned CPython tests.

          Falls back to the parent key when the test file does not live under a
          ``dynamo/cpython/<M>_<mm>`` directory.
          """
          suffix = super()._dynamo_test_key()
          test_path = inspect.getfile(self.__class__)
          version = self._cpython_version_from_path(test_path)
          if version is None:
              return suffix
          major, minor = version
          # "3_13" -> "313": minor keeps its two-digit form from the directory name.
          py_ver = f"{major}{minor:02d}"
          test_file = os.path.basename(test_path).split(".")[0]
          return f"CPython{py_ver}-{test_file}-{suffix}"

      @classmethod
      def tearDownClass(cls) -> None:
          # Undo the enable_trace_unittest patch, and always chain to the parent
          # (which closes its own config patches) even if closing raises.
          try:
              cls._stack.close()
          finally:
              super().tearDownClass()

      @classmethod
      def setUpClass(cls) -> None:
          # Skip the whole class unless its file lives under a versioned
          # dynamo/cpython/<M>_<mm> directory and the running interpreter is at
          # least that Python version.
          test_path = inspect.getfile(cls)
          test_py_ver = cls._cpython_version_from_path(test_path)
          if test_py_ver is None:
              raise unittest.SkipTest(
                  f"Test requires a specific Python version but not found in path {test_path}"
              )
          py_ver = sys.version_info[:2]
          if py_ver < test_py_ver:
              expected = ".".join(map(str, test_py_ver))
              got = ".".join(map(str, py_ver))
              raise unittest.SkipTest(
                  f"Test requires Python {expected} but got Python {got}"
              )
          super().setUpClass()
          cls._stack = contextlib.ExitStack()  # type: ignore[attr-defined]
          cls._stack.enter_context(  # type: ignore[attr-defined]
              config.patch(
                  enable_trace_unittest=True,
              ),
          )