_warnings.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. from __future__ import annotations
  2. from contextlib import (
  3. contextmanager,
  4. nullcontext,
  5. )
  6. import inspect
  7. import re
  8. import sys
  9. from typing import (
  10. TYPE_CHECKING,
  11. Literal,
  12. cast,
  13. )
  14. import warnings
  15. from pandas.compat import PY311
  16. if TYPE_CHECKING:
  17. from collections.abc import (
  18. Generator,
  19. Sequence,
  20. )
@contextmanager
def assert_produces_warning(
    expected_warning: type[Warning] | bool | tuple[type[Warning], ...] | None = Warning,
    filter_level: Literal[
        "error", "ignore", "always", "default", "module", "once"
    ] = "always",
    check_stacklevel: bool = True,
    raise_on_extra_warnings: bool = True,
    match: str | None = None,
) -> Generator[list[warnings.WarningMessage], None, None]:
    """
    Context manager for running code expected to either raise a specific warning,
    multiple specific warnings, or not raise any warnings. Verifies that the code
    raises the expected warning(s), and that it does not raise any other unexpected
    warnings. It is basically a wrapper around ``warnings.catch_warnings``.

    Parameters
    ----------
    expected_warning : {Warning, False, tuple[Warning, ...], None}, default Warning
        The type of warning expected. ``Warning`` is the base class for all
        warnings, so the default accepts any warning. To accept several
        warning types, pass them as a tuple; a caught warning whose category
        is a subclass of any of them counts as expected. To check that no
        warning is emitted, specify ``False`` or ``None``.
    filter_level : str, default "always"
        Action passed to ``warnings.simplefilter`` while the block runs.
        Specifies whether warnings are ignored, displayed, or turned
        into errors. Valid values are:

        * "error" - turns matching warnings into exceptions
        * "ignore" - discard the warning
        * "always" - always emit a warning
        * "default" - print the warning the first time it is generated
          from each location
        * "module" - print the warning the first time it is generated
          from each module
        * "once" - print the warning the first time it is generated

    check_stacklevel : bool, default True
        If True, assert that every expected warning is attributed to the file
        of the code that entered this context manager, i.e. that it was
        issued with a ``stacklevel`` pointing at the caller rather than at
        the line inside the function that issued it.
    raise_on_extra_warnings : bool, default True
        Whether extra warnings not of the type `expected_warning` should
        cause the test to fail.
    match : str, optional
        Regular expression searched (``re.search``) in the message of each
        expected warning; at least one of them must match.

    Yields
    ------
    list[warnings.WarningMessage]
        The list that ``warnings.catch_warnings(record=True)`` records into.

    Examples
    --------
    >>> import warnings
    >>> with assert_produces_warning():
    ...     warnings.warn(UserWarning())
    ...
    >>> with assert_produces_warning(False):  # doctest: +ELLIPSIS
    ...     warnings.warn(RuntimeWarning())
    ...
    Traceback (most recent call last):
        ...
    AssertionError: Caused unexpected warning(s): [('RuntimeWarning', ...)]

    >>> with assert_produces_warning(UserWarning):
    ...     warnings.warn(RuntimeWarning())
    Traceback (most recent call last):
        ...
    AssertionError: Did not see expected warning of class 'UserWarning'

    .. warning:: This is *not* thread-safe.
    """
    # pytest convention: hide this frame from failure tracebacks.
    __tracebackhide__ = True
    with warnings.catch_warnings(record=True) as w:
        # Scoped to the catch_warnings block; the previous filters are
        # restored when it exits.
        warnings.simplefilter(filter_level)
        try:
            yield w
        finally:
            # Runs even if the body raised, so a missing expected warning is
            # still reported.
            if expected_warning:
                # Typing only: a tuple of warning classes also flows through here.
                expected_warning = cast(type[Warning], expected_warning)
                # NOTE: keep this a direct call from the generator frame;
                # _assert_raised_with_correct_stacklevel walks a fixed number
                # of frames up from there to find the caller.
                _assert_caught_expected_warning(
                    caught_warnings=w,
                    expected_warning=expected_warning,
                    match=match,
                    check_stacklevel=check_stacklevel,
                )
            if raise_on_extra_warnings:
                _assert_caught_no_extra_warnings(
                    caught_warnings=w,
                    expected_warning=expected_warning,
                )
  102. def maybe_produces_warning(warning: type[Warning], condition: bool, **kwargs):
  103. """
  104. Return a context manager that possibly checks a warning based on the condition
  105. """
  106. if condition:
  107. return assert_produces_warning(warning, **kwargs)
  108. else:
  109. return nullcontext()
  110. def _assert_caught_expected_warning(
  111. *,
  112. caught_warnings: Sequence[warnings.WarningMessage],
  113. expected_warning: type[Warning],
  114. match: str | None,
  115. check_stacklevel: bool,
  116. ) -> None:
  117. """Assert that there was the expected warning among the caught warnings."""
  118. saw_warning = False
  119. matched_message = False
  120. unmatched_messages = []
  121. for actual_warning in caught_warnings:
  122. if issubclass(actual_warning.category, expected_warning):
  123. saw_warning = True
  124. if check_stacklevel:
  125. _assert_raised_with_correct_stacklevel(actual_warning)
  126. if match is not None:
  127. if re.search(match, str(actual_warning.message)):
  128. matched_message = True
  129. else:
  130. unmatched_messages.append(actual_warning.message)
  131. if not saw_warning:
  132. raise AssertionError(
  133. f"Did not see expected warning of class "
  134. f"{repr(expected_warning.__name__)}"
  135. )
  136. if match and not matched_message:
  137. raise AssertionError(
  138. f"Did not see warning {repr(expected_warning.__name__)} "
  139. f"matching '{match}'. The emitted warning messages are "
  140. f"{unmatched_messages}"
  141. )
  142. def _assert_caught_no_extra_warnings(
  143. *,
  144. caught_warnings: Sequence[warnings.WarningMessage],
  145. expected_warning: type[Warning] | bool | tuple[type[Warning], ...] | None,
  146. ) -> None:
  147. """Assert that no extra warnings apart from the expected ones are caught."""
  148. extra_warnings = []
  149. for actual_warning in caught_warnings:
  150. if _is_unexpected_warning(actual_warning, expected_warning):
  151. # GH#38630 pytest.filterwarnings does not suppress these.
  152. if actual_warning.category == ResourceWarning:
  153. # GH 44732: Don't make the CI flaky by filtering SSL-related
  154. # ResourceWarning from dependencies
  155. if "unclosed <ssl.SSLSocket" in str(actual_warning.message):
  156. continue
  157. # GH 44844: Matplotlib leaves font files open during the entire process
  158. # upon import. Don't make CI flaky if ResourceWarning raised
  159. # due to these open files.
  160. if any("matplotlib" in mod for mod in sys.modules):
  161. continue
  162. if PY311 and actual_warning.category == EncodingWarning:
  163. # EncodingWarnings are checked in the CI
  164. # pyproject.toml errors on EncodingWarnings in pandas
  165. # Ignore EncodingWarnings from other libraries
  166. continue
  167. extra_warnings.append(
  168. (
  169. actual_warning.category.__name__,
  170. actual_warning.message,
  171. actual_warning.filename,
  172. actual_warning.lineno,
  173. )
  174. )
  175. if extra_warnings:
  176. raise AssertionError(f"Caused unexpected warning(s): {repr(extra_warnings)}")
  177. def _is_unexpected_warning(
  178. actual_warning: warnings.WarningMessage,
  179. expected_warning: type[Warning] | bool | tuple[type[Warning], ...] | None,
  180. ) -> bool:
  181. """Check if the actual warning issued is unexpected."""
  182. if actual_warning and not expected_warning:
  183. return True
  184. expected_warning = cast(type[Warning], expected_warning)
  185. return bool(not issubclass(actual_warning.category, expected_warning))
def _assert_raised_with_correct_stacklevel(
    actual_warning: warnings.WarningMessage,
) -> None:
    """
    Assert that ``actual_warning`` is attributed to the caller's file.

    Walks four frames up from this function to reach the frame of the code
    that used ``assert_produces_warning``. It then compares that frame's
    source file with the filename recorded on the warning. A mismatch is
    reported as a warning issued with the wrong ``stacklevel``.
    """
    # Walk frames by hand instead of calling inspect.stack(), which is slow; see
    # https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow
    frame = inspect.currentframe()
    # NOTE(review): the depth 4 presumably matches the chain
    #   this function <- _assert_caught_expected_warning
    #   <- assert_produces_warning generator <- contextlib's __exit__ <- caller.
    # Adding any wrapper in that chain changes which frame is reached.
    for _ in range(4):
        frame = frame.f_back  # type: ignore[union-attr]
    try:
        caller_filename = inspect.getfile(frame)  # type: ignore[arg-type]
    finally:
        # Drop the frame reference promptly to avoid keeping a reference
        # cycle alive. See note in
        # https://docs.python.org/3/library/inspect.html#inspect.Traceback
        del frame
    msg = (
        "Warning not set with correct stacklevel. "
        f"File where warning is raised: {actual_warning.filename} != "
        f"{caller_filename}. Warning message: {actual_warning.message}"
    )
    # NOTE(review): a plain ``assert`` is skipped when Python runs with -O.
    assert actual_warning.filename == caller_filename, msg