contexts.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. from __future__ import annotations
  2. from contextlib import contextmanager
  3. import os
  4. from pathlib import Path
  5. import tempfile
  6. from typing import (
  7. IO,
  8. TYPE_CHECKING,
  9. Any,
  10. )
  11. import uuid
  12. from pandas._config import using_copy_on_write
  13. from pandas.compat import (
  14. PYPY,
  15. WARNING_CHECK_DISABLED,
  16. )
  17. from pandas.errors import ChainedAssignmentError
  18. from pandas import set_option
  19. from pandas.io.common import get_handle
  20. if TYPE_CHECKING:
  21. from collections.abc import Generator
  22. from pandas._typing import (
  23. BaseBuffer,
  24. CompressionOptions,
  25. FilePath,
  26. )
  27. @contextmanager
  28. def decompress_file(
  29. path: FilePath | BaseBuffer, compression: CompressionOptions
  30. ) -> Generator[IO[bytes], None, None]:
  31. """
  32. Open a compressed file and return a file object.
  33. Parameters
  34. ----------
  35. path : str
  36. The path where the file is read from.
  37. compression : {'gzip', 'bz2', 'zip', 'xz', 'zstd', None}
  38. Name of the decompression to use
  39. Returns
  40. -------
  41. file object
  42. """
  43. with get_handle(path, "rb", compression=compression, is_text=False) as handle:
  44. yield handle.handle
  45. @contextmanager
  46. def set_timezone(tz: str) -> Generator[None, None, None]:
  47. """
  48. Context manager for temporarily setting a timezone.
  49. Parameters
  50. ----------
  51. tz : str
  52. A string representing a valid timezone.
  53. Examples
  54. --------
  55. >>> from datetime import datetime
  56. >>> from dateutil.tz import tzlocal
  57. >>> tzlocal().tzname(datetime(2021, 1, 1)) # doctest: +SKIP
  58. 'IST'
  59. >>> with set_timezone('US/Eastern'):
  60. ... tzlocal().tzname(datetime(2021, 1, 1))
  61. ...
  62. 'EST'
  63. """
  64. import time
  65. def setTZ(tz) -> None:
  66. if hasattr(time, "tzset"):
  67. if tz is None:
  68. try:
  69. del os.environ["TZ"]
  70. except KeyError:
  71. pass
  72. else:
  73. os.environ["TZ"] = tz
  74. time.tzset()
  75. orig_tz = os.environ.get("TZ")
  76. setTZ(tz)
  77. try:
  78. yield
  79. finally:
  80. setTZ(orig_tz)
  81. @contextmanager
  82. def ensure_clean(
  83. filename=None, return_filelike: bool = False, **kwargs: Any
  84. ) -> Generator[Any, None, None]:
  85. """
  86. Gets a temporary path and agrees to remove on close.
  87. This implementation does not use tempfile.mkstemp to avoid having a file handle.
  88. If the code using the returned path wants to delete the file itself, windows
  89. requires that no program has a file handle to it.
  90. Parameters
  91. ----------
  92. filename : str (optional)
  93. suffix of the created file.
  94. return_filelike : bool (default False)
  95. if True, returns a file-like which is *always* cleaned. Necessary for
  96. savefig and other functions which want to append extensions.
  97. **kwargs
  98. Additional keywords are passed to open().
  99. """
  100. folder = Path(tempfile.gettempdir())
  101. if filename is None:
  102. filename = ""
  103. filename = str(uuid.uuid4()) + filename
  104. path = folder / filename
  105. path.touch()
  106. handle_or_str: str | IO = str(path)
  107. encoding = kwargs.pop("encoding", None)
  108. if return_filelike:
  109. kwargs.setdefault("mode", "w+b")
  110. if encoding is None and "b" not in kwargs["mode"]:
  111. encoding = "utf-8"
  112. handle_or_str = open(path, encoding=encoding, **kwargs)
  113. try:
  114. yield handle_or_str
  115. finally:
  116. if not isinstance(handle_or_str, str):
  117. handle_or_str.close()
  118. if path.is_file():
  119. path.unlink()
  120. @contextmanager
  121. def with_csv_dialect(name: str, **kwargs) -> Generator[None, None, None]:
  122. """
  123. Context manager to temporarily register a CSV dialect for parsing CSV.
  124. Parameters
  125. ----------
  126. name : str
  127. The name of the dialect.
  128. kwargs : mapping
  129. The parameters for the dialect.
  130. Raises
  131. ------
  132. ValueError : the name of the dialect conflicts with a builtin one.
  133. See Also
  134. --------
  135. csv : Python's CSV library.
  136. """
  137. import csv
  138. _BUILTIN_DIALECTS = {"excel", "excel-tab", "unix"}
  139. if name in _BUILTIN_DIALECTS:
  140. raise ValueError("Cannot override builtin dialect.")
  141. csv.register_dialect(name, **kwargs)
  142. try:
  143. yield
  144. finally:
  145. csv.unregister_dialect(name)
  146. @contextmanager
  147. def use_numexpr(use, min_elements=None) -> Generator[None, None, None]:
  148. from pandas.core.computation import expressions as expr
  149. if min_elements is None:
  150. min_elements = expr._MIN_ELEMENTS
  151. olduse = expr.USE_NUMEXPR
  152. oldmin = expr._MIN_ELEMENTS
  153. set_option("compute.use_numexpr", use)
  154. expr._MIN_ELEMENTS = min_elements
  155. try:
  156. yield
  157. finally:
  158. expr._MIN_ELEMENTS = oldmin
  159. set_option("compute.use_numexpr", olduse)
  160. def raises_chained_assignment_error(warn=True, extra_warnings=(), extra_match=()):
  161. from pandas._testing import assert_produces_warning
  162. if not warn:
  163. from contextlib import nullcontext
  164. return nullcontext()
  165. if (PYPY or WARNING_CHECK_DISABLED) and not extra_warnings:
  166. from contextlib import nullcontext
  167. return nullcontext()
  168. elif (PYPY or WARNING_CHECK_DISABLED) and extra_warnings:
  169. return assert_produces_warning(
  170. extra_warnings,
  171. match="|".join(extra_match),
  172. )
  173. else:
  174. if using_copy_on_write():
  175. warning = ChainedAssignmentError
  176. match = (
  177. "A value is trying to be set on a copy of a DataFrame or Series "
  178. "through chained assignment"
  179. )
  180. else:
  181. warning = FutureWarning # type: ignore[assignment]
  182. # TODO update match
  183. match = "ChainedAssignmentError"
  184. if extra_warnings:
  185. warning = (warning, *extra_warnings) # type: ignore[assignment]
  186. return assert_produces_warning(
  187. warning,
  188. match="|".join((match, *extra_match)),
  189. )
  190. def assert_cow_warning(warn=True, match=None, **kwargs):
  191. """
  192. Assert that a warning is raised in the CoW warning mode.
  193. Parameters
  194. ----------
  195. warn : bool, default True
  196. By default, check that a warning is raised. Can be turned off by passing False.
  197. match : str
  198. The warning message to match against, if different from the default.
  199. kwargs
  200. Passed through to assert_produces_warning
  201. """
  202. from pandas._testing import assert_produces_warning
  203. if not warn or WARNING_CHECK_DISABLED:
  204. from contextlib import nullcontext
  205. return nullcontext()
  206. if not match:
  207. match = "Setting a value on a view"
  208. return assert_produces_warning(FutureWarning, match=match, **kwargs)