context.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. from __future__ import annotations
  2. import contextlib
  3. import functools
  4. import operator
  5. import os
  6. import shutil
  7. import subprocess
  8. import sys
  9. import tempfile
  10. import urllib.request
  11. import warnings
  12. from typing import Iterator
  13. if sys.version_info < (3, 12):
  14. from backports import tarfile
  15. else:
  16. import tarfile
  17. @contextlib.contextmanager
  18. def pushd(dir: str | os.PathLike) -> Iterator[str | os.PathLike]:
  19. """
  20. >>> tmp_path = getfixture('tmp_path')
  21. >>> with pushd(tmp_path):
  22. ... assert os.getcwd() == os.fspath(tmp_path)
  23. >>> assert os.getcwd() != os.fspath(tmp_path)
  24. """
  25. orig = os.getcwd()
  26. os.chdir(dir)
  27. try:
  28. yield dir
  29. finally:
  30. os.chdir(orig)
  31. @contextlib.contextmanager
  32. def tarball(
  33. url, target_dir: str | os.PathLike | None = None
  34. ) -> Iterator[str | os.PathLike]:
  35. """
  36. Get a tarball, extract it, yield, then clean up.
  37. >>> import urllib.request
  38. >>> url = getfixture('tarfile_served')
  39. >>> target = getfixture('tmp_path') / 'out'
  40. >>> tb = tarball(url, target_dir=target)
  41. >>> import pathlib
  42. >>> with tb as extracted:
  43. ... contents = pathlib.Path(extracted, 'contents.txt').read_text(encoding='utf-8')
  44. >>> assert not os.path.exists(extracted)
  45. """
  46. if target_dir is None:
  47. target_dir = os.path.basename(url).replace('.tar.gz', '').replace('.tgz', '')
  48. # In the tar command, use --strip-components=1 to strip the first path and
  49. # then
  50. # use -C to cause the files to be extracted to {target_dir}. This ensures
  51. # that we always know where the files were extracted.
  52. os.mkdir(target_dir)
  53. try:
  54. req = urllib.request.urlopen(url)
  55. with tarfile.open(fileobj=req, mode='r|*') as tf:
  56. tf.extractall(path=target_dir, filter=strip_first_component)
  57. yield target_dir
  58. finally:
  59. shutil.rmtree(target_dir)
  60. def strip_first_component(
  61. member: tarfile.TarInfo,
  62. path,
  63. ) -> tarfile.TarInfo:
  64. _, member.name = member.name.split('/', 1)
  65. return member
  66. def _compose(*cmgrs):
  67. """
  68. Compose any number of dependent context managers into a single one.
  69. The last, innermost context manager may take arbitrary arguments, but
  70. each successive context manager should accept the result from the
  71. previous as a single parameter.
  72. Like :func:`jaraco.functools.compose`, behavior works from right to
  73. left, so the context manager should be indicated from outermost to
  74. innermost.
  75. Example, to create a context manager to change to a temporary
  76. directory:
  77. >>> temp_dir_as_cwd = _compose(pushd, temp_dir)
  78. >>> with temp_dir_as_cwd() as dir:
  79. ... assert os.path.samefile(os.getcwd(), dir)
  80. """
  81. def compose_two(inner, outer):
  82. def composed(*args, **kwargs):
  83. with inner(*args, **kwargs) as saved, outer(saved) as res:
  84. yield res
  85. return contextlib.contextmanager(composed)
  86. return functools.reduce(compose_two, reversed(cmgrs))
# Context manager that applies ``tarball`` first and hands the directory
# it yields to ``pushd``, so the extracted directory is the working
# directory for the duration of the block.
tarball_cwd = _compose(pushd, tarball)
  88. @contextlib.contextmanager
  89. def tarball_context(*args, **kwargs):
  90. warnings.warn(
  91. "tarball_context is deprecated. Use tarball or tarball_cwd instead.",
  92. DeprecationWarning,
  93. stacklevel=2,
  94. )
  95. pushd_ctx = kwargs.pop('pushd', pushd)
  96. with tarball(*args, **kwargs) as tball, pushd_ctx(tball) as dir:
  97. yield dir
  98. def infer_compression(url):
  99. """
  100. Given a URL or filename, infer the compression code for tar.
  101. >>> infer_compression('http://foo/bar.tar.gz')
  102. 'z'
  103. >>> infer_compression('http://foo/bar.tgz')
  104. 'z'
  105. >>> infer_compression('file.bz')
  106. 'j'
  107. >>> infer_compression('file.xz')
  108. 'J'
  109. """
  110. warnings.warn(
  111. "infer_compression is deprecated with no replacement",
  112. DeprecationWarning,
  113. stacklevel=2,
  114. )
  115. # cheat and just assume it's the last two characters
  116. compression_indicator = url[-2:]
  117. mapping = dict(gz='z', bz='j', xz='J')
  118. # Assume 'z' (gzip) if no match
  119. return mapping.get(compression_indicator, 'z')
  120. @contextlib.contextmanager
  121. def temp_dir(remover=shutil.rmtree):
  122. """
  123. Create a temporary directory context. Pass a custom remover
  124. to override the removal behavior.
  125. >>> import pathlib
  126. >>> with temp_dir() as the_dir:
  127. ... assert os.path.isdir(the_dir)
  128. ... _ = pathlib.Path(the_dir).joinpath('somefile').write_text('contents', encoding='utf-8')
  129. >>> assert not os.path.exists(the_dir)
  130. """
  131. temp_dir = tempfile.mkdtemp()
  132. try:
  133. yield temp_dir
  134. finally:
  135. remover(temp_dir)
  136. @contextlib.contextmanager
  137. def repo_context(url, branch=None, quiet=True, dest_ctx=temp_dir):
  138. """
  139. Check out the repo indicated by url.
  140. If dest_ctx is supplied, it should be a context manager
  141. to yield the target directory for the check out.
  142. """
  143. exe = 'git' if 'git' in url else 'hg'
  144. with dest_ctx() as repo_dir:
  145. cmd = [exe, 'clone', url, repo_dir]
  146. if branch:
  147. cmd.extend(['--branch', branch])
  148. devnull = open(os.path.devnull, 'w')
  149. stdout = devnull if quiet else None
  150. subprocess.check_call(cmd, stdout=stdout)
  151. yield repo_dir
  152. def null():
  153. """
  154. A null context suitable to stand in for a meaningful context.
  155. >>> with null() as value:
  156. ... assert value is None
  157. This context is most useful when dealing with two or more code
  158. branches but only some need a context. Wrap the others in a null
  159. context to provide symmetry across all options.
  160. """
  161. warnings.warn(
  162. "null is deprecated. Use contextlib.nullcontext",
  163. DeprecationWarning,
  164. stacklevel=2,
  165. )
  166. return contextlib.nullcontext()
  167. class ExceptionTrap:
  168. """
  169. A context manager that will catch certain exceptions and provide an
  170. indication they occurred.
  171. >>> with ExceptionTrap() as trap:
  172. ... raise Exception()
  173. >>> bool(trap)
  174. True
  175. >>> with ExceptionTrap() as trap:
  176. ... pass
  177. >>> bool(trap)
  178. False
  179. >>> with ExceptionTrap(ValueError) as trap:
  180. ... raise ValueError("1 + 1 is not 3")
  181. >>> bool(trap)
  182. True
  183. >>> trap.value
  184. ValueError('1 + 1 is not 3')
  185. >>> trap.tb
  186. <traceback object at ...>
  187. >>> with ExceptionTrap(ValueError) as trap:
  188. ... raise Exception()
  189. Traceback (most recent call last):
  190. ...
  191. Exception
  192. >>> bool(trap)
  193. False
  194. """
  195. exc_info = None, None, None
  196. def __init__(self, exceptions=(Exception,)):
  197. self.exceptions = exceptions
  198. def __enter__(self):
  199. return self
  200. @property
  201. def type(self):
  202. return self.exc_info[0]
  203. @property
  204. def value(self):
  205. return self.exc_info[1]
  206. @property
  207. def tb(self):
  208. return self.exc_info[2]
  209. def __exit__(self, *exc_info):
  210. type = exc_info[0]
  211. matches = type and issubclass(type, self.exceptions)
  212. if matches:
  213. self.exc_info = exc_info
  214. return matches
  215. def __bool__(self):
  216. return bool(self.type)
  217. def raises(self, func, *, _test=bool):
  218. """
  219. Wrap func and replace the result with the truth
  220. value of the trap (True if an exception occurred).
  221. First, give the decorator an alias to support Python 3.8
  222. Syntax.
  223. >>> raises = ExceptionTrap(ValueError).raises
  224. Now decorate a function that always fails.
  225. >>> @raises
  226. ... def fail():
  227. ... raise ValueError('failed')
  228. >>> fail()
  229. True
  230. """
  231. @functools.wraps(func)
  232. def wrapper(*args, **kwargs):
  233. with ExceptionTrap(self.exceptions) as trap:
  234. func(*args, **kwargs)
  235. return _test(trap)
  236. return wrapper
  237. def passes(self, func):
  238. """
  239. Wrap func and replace the result with the truth
  240. value of the trap (True if no exception).
  241. First, give the decorator an alias to support Python 3.8
  242. Syntax.
  243. >>> passes = ExceptionTrap(ValueError).passes
  244. Now decorate a function that always fails.
  245. >>> @passes
  246. ... def fail():
  247. ... raise ValueError('failed')
  248. >>> fail()
  249. False
  250. """
  251. return self.raises(func, _test=operator.not_)
  252. class suppress(contextlib.suppress, contextlib.ContextDecorator):
  253. """
  254. A version of contextlib.suppress with decorator support.
  255. >>> @suppress(KeyError)
  256. ... def key_error():
  257. ... {}['']
  258. >>> key_error()
  259. """
  260. class on_interrupt(contextlib.ContextDecorator):
  261. """
  262. Replace a KeyboardInterrupt with SystemExit(1)
  263. >>> def do_interrupt():
  264. ... raise KeyboardInterrupt()
  265. >>> on_interrupt('error')(do_interrupt)()
  266. Traceback (most recent call last):
  267. ...
  268. SystemExit: 1
  269. >>> on_interrupt('error', code=255)(do_interrupt)()
  270. Traceback (most recent call last):
  271. ...
  272. SystemExit: 255
  273. >>> on_interrupt('suppress')(do_interrupt)()
  274. >>> with __import__('pytest').raises(KeyboardInterrupt):
  275. ... on_interrupt('ignore')(do_interrupt)()
  276. """
  277. def __init__(self, action='error', /, code=1):
  278. self.action = action
  279. self.code = code
  280. def __enter__(self):
  281. return self
  282. def __exit__(self, exctype, excinst, exctb):
  283. if exctype is not KeyboardInterrupt or self.action == 'ignore':
  284. return
  285. elif self.action == 'error':
  286. raise SystemExit(self.code) from excinst
  287. return self.action == 'suppress'