utils.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568
  1. from __future__ import annotations
  2. import contextlib
  3. import functools
  4. import hashlib
  5. import os
  6. import re
  7. import sys
  8. import textwrap
  9. from dataclasses import fields, is_dataclass
  10. from enum import auto, Enum
  11. from pathlib import Path
  12. from typing import Any, Callable, Generic, Literal, NoReturn, TYPE_CHECKING, TypeVar
  13. from typing_extensions import assert_never, deprecated, Self
  14. from torchgen.code_template import CodeTemplate
  15. if TYPE_CHECKING:
  16. from argparse import Namespace
  17. from collections.abc import Iterable, Iterator, Sequence
# Absolute path of the directory that contains this module.
TORCHGEN_ROOT = Path(__file__).absolute().parent
# Parent directory of TORCHGEN_ROOT (presumably the repository checkout root;
# it is used to shorten generator paths in "generated by" comments).
REPO_ROOT = TORCHGEN_ROOT.parent
# Many of these functions share logic for defining both the definition
# and declaration (for example, the function signature is the same), so
# we organize them into one function that takes a Target to say which
# code we want.
#
# This is an OPEN enum (we may add more cases to it in the future), so be sure
# to explicitly specify with Literal[Target.XXX] or Literal[Target.XXX, Target.YYY]
# what targets are valid for your use.
class Target(Enum):
    """Selects which flavour of code a shared generator function should emit."""

    # top level namespace (not including at)
    DEFINITION = auto()
    DECLARATION = auto()
    # TORCH_LIBRARY(...) { ... }
    REGISTRATION = auto()
    # namespace { ... }
    ANONYMOUS_DEFINITION = auto()
    # namespace cpu { ... }
    NAMESPACED_DEFINITION = auto()
    NAMESPACED_DECLARATION = auto()
# Regex *template*: the "{}" placeholder is meant to be filled in with an
# identifier (presumably via str.format -- verify against callers). The
# resulting pattern matches the identifier only when it is delimited by
# non-word characters or by the start/end of the string, i.e. it matches
# "foo" in "foo, bar" but not in "foobar". Used to search for the
# occurrence of a parameter in the derivative formula.
IDENT_REGEX = r"(^|\W){}($|\W)"
  42. # TODO: Use a real parser here; this will get bamboozled
  43. def split_name_params(schema: str) -> tuple[str, list[str]]:
  44. m = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
  45. if m is None:
  46. raise RuntimeError(f"Unsupported function schema: {schema}")
  47. name, _, params = m.groups()
  48. return name, params.split(", ")
# Generic type variables used in the signatures of the helpers in this module:
# T is the element type of an input iterable, S the element type produced
# from it (see mapMaybe / concatMap).
T = TypeVar("T")
S = TypeVar("S")
  51. # These two functions purposely return generators in analogy to map()
  52. # so that you don't mix up when you need to list() them
  53. # Map over function that may return None; omit Nones from output sequence
  54. def mapMaybe(func: Callable[[T], S | None], xs: Iterable[T]) -> Iterator[S]:
  55. for x in xs:
  56. r = func(x)
  57. if r is not None:
  58. yield r
  59. # Map over function that returns sequences and cat them all together
  60. def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
  61. for x in xs:
  62. yield from func(x)
  63. # Conveniently add error context to exceptions raised. Lets us
  64. # easily say that an error occurred while processing a specific
  65. # context.
  66. @contextlib.contextmanager
  67. def context(msg_fn: Callable[[], str]) -> Iterator[None]:
  68. try:
  69. yield
  70. except Exception as e:
  71. # TODO: this does the wrong thing with KeyError
  72. msg = msg_fn()
  73. msg = textwrap.indent(msg, " ")
  74. msg = f"{e.args[0]}\n{msg}" if e.args else msg
  75. e.args = (msg,) + e.args[1:]
  76. raise
if TYPE_CHECKING:
    # A little trick from https://github.com/python/mypy/issues/6366
    # for getting mypy to do exhaustiveness checking
    # TODO: put this somewhere else, maybe
    #
    # This definition is only seen by static type checkers (TYPE_CHECKING is
    # false at runtime). There it redefines the assert_never name imported
    # from typing_extensions at the top of the file, marking it deprecated.
    # At runtime the typing_extensions implementation stays bound to the name.
    @deprecated("Use typing_extensions.assert_never instead")
    def assert_never(x: NoReturn) -> NoReturn:  # type: ignore[misc] # noqa: F811
        raise AssertionError(f"Unhandled type: {type(x).__name__}")
  84. @functools.cache
  85. def _read_template(template_fn: str) -> CodeTemplate:
  86. return CodeTemplate.from_file(template_fn)
  87. # String hash that's stable across different executions, unlike builtin hash
  88. def string_stable_hash(s: str) -> int:
  89. sha1 = hashlib.sha1(s.encode("latin1"), usedforsecurity=False).digest()
  90. return int.from_bytes(sha1, byteorder="little")
  91. # A small abstraction for writing out generated files and keeping track
  92. # of what files have been written (so you can write out a list of output
  93. # files)
  94. class FileManager:
  95. def __init__(
  96. self,
  97. install_dir: str | Path,
  98. template_dir: str | Path,
  99. dry_run: bool,
  100. ) -> None:
  101. self.install_dir = Path(install_dir)
  102. self.template_dir = Path(template_dir)
  103. self.files: set[Path] = set()
  104. self.dry_run = dry_run
  105. @property
  106. def filenames(self) -> frozenset[str]:
  107. return frozenset({file.as_posix() for file in self.files})
  108. def _write_if_changed(self, filename: str | Path, contents: str) -> None:
  109. file = Path(filename)
  110. old_contents: str | None = None
  111. try:
  112. old_contents = file.read_text(encoding="utf-8")
  113. except OSError:
  114. pass
  115. if contents != old_contents:
  116. # Create output directory if it doesn't exist
  117. file.parent.mkdir(parents=True, exist_ok=True)
  118. file.write_text(contents, encoding="utf-8")
  119. # Read from template file and replace pattern with callable (type could be dict or str).
  120. def substitute_with_template(
  121. self,
  122. template_fn: str | Path,
  123. env_callable: Callable[[], str | dict[str, Any]],
  124. ) -> str:
  125. assert not Path(template_fn).is_absolute(), (
  126. f"template_fn must be relative: {template_fn}"
  127. )
  128. template_path = self.template_dir / template_fn
  129. env = env_callable()
  130. if isinstance(env, dict):
  131. if "generated_comment" not in env:
  132. generator_default = TORCHGEN_ROOT / "gen.py"
  133. try:
  134. generator = Path(
  135. sys.modules["__main__"].__file__ or generator_default
  136. ).absolute()
  137. except (KeyError, AttributeError):
  138. generator = generator_default.absolute()
  139. try:
  140. generator_path = generator.relative_to(REPO_ROOT).as_posix()
  141. except ValueError:
  142. generator_path = generator.name
  143. env = {
  144. **env, # copy the original dict instead of mutating it
  145. "generated_comment": (
  146. "@" + f"generated by {generator_path} from {template_fn}"
  147. ),
  148. }
  149. template = _read_template(template_path)
  150. substitute_out = template.substitute(env)
  151. # Ensure an extra blank line between the class/function definition
  152. # and the docstring of the previous class/function definition.
  153. # NB: It is generally not recommended to have docstrings in pyi stub
  154. # files. But if there are any, we need to ensure that the file
  155. # is properly formatted.
  156. return re.sub(
  157. r'''
  158. (""")\n+ # match triple quotes
  159. (
  160. (\s*@.+\n)* # match decorators if any
  161. \s*(class|def) # match class/function definition
  162. )
  163. ''',
  164. r"\g<1>\n\n\g<2>",
  165. substitute_out,
  166. flags=re.VERBOSE,
  167. )
  168. if isinstance(env, str):
  169. return env
  170. assert_never(env)
  171. def write_with_template(
  172. self,
  173. filename: str | Path,
  174. template_fn: str | Path,
  175. env_callable: Callable[[], str | dict[str, Any]],
  176. ) -> None:
  177. filename = Path(filename)
  178. assert not filename.is_absolute(), f"filename must be relative: {filename}"
  179. file = self.install_dir / filename
  180. assert file not in self.files, f"duplicate file write {file}"
  181. self.files.add(file)
  182. if not self.dry_run:
  183. substitute_out = self.substitute_with_template(
  184. template_fn=template_fn,
  185. env_callable=env_callable,
  186. )
  187. self._write_if_changed(filename=file, contents=substitute_out)
  188. def write(
  189. self,
  190. filename: str | Path,
  191. env_callable: Callable[[], str | dict[str, Any]],
  192. ) -> None:
  193. self.write_with_template(filename, filename, env_callable)
  194. def write_sharded(
  195. self,
  196. filename: str | Path,
  197. items: Iterable[T],
  198. *,
  199. key_fn: Callable[[T], str],
  200. env_callable: Callable[[T], dict[str, list[str]]],
  201. num_shards: int,
  202. base_env: dict[str, Any] | None = None,
  203. sharded_keys: set[str],
  204. ) -> None:
  205. self.write_sharded_with_template(
  206. filename,
  207. filename,
  208. items,
  209. key_fn=key_fn,
  210. env_callable=env_callable,
  211. num_shards=num_shards,
  212. base_env=base_env,
  213. sharded_keys=sharded_keys,
  214. )
  215. def write_sharded_with_template(
  216. self,
  217. filename: str | Path,
  218. template_fn: str | Path,
  219. items: Iterable[T],
  220. *,
  221. key_fn: Callable[[T], str],
  222. env_callable: Callable[[T], dict[str, list[str]]],
  223. num_shards: int,
  224. base_env: dict[str, Any] | None = None,
  225. sharded_keys: set[str],
  226. ) -> None:
  227. file = Path(filename)
  228. assert not file.is_absolute(), f"filename must be relative: {filename}"
  229. everything: dict[str, Any] = {"shard_id": "Everything"}
  230. shards: list[dict[str, Any]] = [
  231. {"shard_id": f"_{i}"} for i in range(num_shards)
  232. ]
  233. all_shards = [everything] + shards
  234. if base_env is not None:
  235. for shard in all_shards:
  236. shard.update(base_env)
  237. for key in sharded_keys:
  238. for shard in all_shards:
  239. if key in shard:
  240. assert isinstance(shard[key], list), (
  241. "sharded keys in base_env must be a list"
  242. )
  243. shard[key] = shard[key].copy()
  244. else:
  245. shard[key] = []
  246. def merge_env(into: dict[str, list[str]], from_: dict[str, list[str]]) -> None:
  247. for k, v in from_.items():
  248. assert k in sharded_keys, f"undeclared sharded key {k}"
  249. into[k] += v
  250. if self.dry_run:
  251. # Dry runs don't write any templates, so incomplete environments are fine
  252. items = ()
  253. for item in items:
  254. key = key_fn(item)
  255. sid = string_stable_hash(key) % num_shards
  256. env = env_callable(item)
  257. merge_env(shards[sid], env)
  258. merge_env(everything, env)
  259. for shard in all_shards:
  260. shard_id = shard["shard_id"]
  261. self.write_with_template(
  262. file.with_stem(f"{file.stem}{shard_id}"),
  263. template_fn,
  264. lambda: shard,
  265. )
  266. # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
  267. self.files.discard(self.install_dir / file.with_stem(f"{file.stem}Everything"))
  268. def write_outputs(self, variable_name: str, filename: str | Path) -> None:
  269. """Write a file containing the list of all outputs which are generated by this script."""
  270. content = "\n".join(
  271. (
  272. "set(",
  273. variable_name,
  274. # Use POSIX paths to avoid invalid escape sequences on Windows
  275. *(f' "{file.as_posix()}"' for file in sorted(self.files)),
  276. ")",
  277. )
  278. )
  279. self._write_if_changed(filename, content)
  280. def template_dir_for_comments(self) -> str:
  281. """
  282. This needs to be deterministic. The template dir is an absolute path
  283. that varies across builds. So, just use the path relative to this file,
  284. which will point to the codegen source but will be stable.
  285. """
  286. return os.path.relpath(self.template_dir, os.path.dirname(__file__))
  287. # Helper function to generate file manager
  288. def make_file_manager(
  289. options: Namespace,
  290. install_dir: str | Path | None = None,
  291. ) -> FileManager:
  292. template_dir = os.path.join(options.source_path, "templates")
  293. install_dir = install_dir if install_dir else options.install_dir
  294. return FileManager(
  295. install_dir=install_dir,
  296. template_dir=template_dir,
  297. dry_run=options.dry_run,
  298. )
  299. # Helper function to create a pretty representation for dataclasses
  300. def dataclass_repr(
  301. obj: Any,
  302. indent: int = 0,
  303. width: int = 80,
  304. ) -> str:
  305. # built-in pprint module support dataclasses from python 3.10
  306. if sys.version_info >= (3, 10):
  307. from pprint import pformat
  308. return pformat(obj, indent, width)
  309. return _pformat(obj, indent=indent, width=width)
  310. def _pformat(
  311. obj: Any,
  312. indent: int,
  313. width: int,
  314. curr_indent: int = 0,
  315. ) -> str:
  316. assert is_dataclass(obj), f"obj should be a dataclass, received: {type(obj)}"
  317. class_name = obj.__class__.__name__
  318. # update current indentation level with class name
  319. curr_indent += len(class_name) + 1
  320. fields_list = [(f.name, getattr(obj, f.name)) for f in fields(obj) if f.repr]
  321. fields_str = []
  322. for name, attr in fields_list:
  323. # update the current indent level with the field name
  324. # dict, list, set and tuple also add indent as done in pprint
  325. _curr_indent = curr_indent + len(name) + 1
  326. if is_dataclass(attr):
  327. str_repr = _pformat(attr, indent, width, _curr_indent)
  328. elif isinstance(attr, dict):
  329. str_repr = _format_dict(attr, indent, width, _curr_indent)
  330. elif isinstance(attr, (list, set, tuple)):
  331. str_repr = _format_list(attr, indent, width, _curr_indent)
  332. else:
  333. str_repr = repr(attr)
  334. fields_str.append(f"{name}={str_repr}")
  335. indent_str = curr_indent * " "
  336. body = f",\n{indent_str}".join(fields_str)
  337. return f"{class_name}({body})"
  338. def _format_dict(
  339. attr: dict[Any, Any],
  340. indent: int,
  341. width: int,
  342. curr_indent: int,
  343. ) -> str:
  344. curr_indent += indent + 3
  345. dict_repr = []
  346. for k, v in attr.items():
  347. k_repr = repr(k)
  348. v_str = (
  349. _pformat(v, indent, width, curr_indent + len(k_repr))
  350. if is_dataclass(v)
  351. else repr(v)
  352. )
  353. dict_repr.append(f"{k_repr}: {v_str}")
  354. return _format(dict_repr, indent, width, curr_indent, "{", "}")
  355. def _format_list(
  356. attr: list[Any] | set[Any] | tuple[Any, ...],
  357. indent: int,
  358. width: int,
  359. curr_indent: int,
  360. ) -> str:
  361. curr_indent += indent + 1
  362. list_repr = [
  363. _pformat(l, indent, width, curr_indent) if is_dataclass(l) else repr(l)
  364. for l in attr
  365. ]
  366. start, end = ("[", "]") if isinstance(attr, list) else ("(", ")")
  367. return _format(list_repr, indent, width, curr_indent, start, end)
  368. def _format(
  369. fields_str: list[str],
  370. indent: int,
  371. width: int,
  372. curr_indent: int,
  373. start: str,
  374. end: str,
  375. ) -> str:
  376. delimiter, curr_indent_str = "", ""
  377. # if it exceed the max width then we place one element per line
  378. if len(repr(fields_str)) >= width:
  379. delimiter = "\n"
  380. curr_indent_str = " " * curr_indent
  381. indent_str = " " * indent
  382. body = f", {delimiter}{curr_indent_str}".join(fields_str)
  383. return f"{start}{indent_str}{body}{end}"
  384. class NamespaceHelper:
  385. """A helper for constructing the namespace open and close strings for a nested set of namespaces.
  386. e.g. for namespace_str torch::lazy,
  387. prologue:
  388. namespace torch {
  389. namespace lazy {
  390. epilogue:
  391. } // namespace lazy
  392. } // namespace torch
  393. """
  394. def __init__(
  395. self,
  396. namespace_str: str,
  397. entity_name: str = "",
  398. max_level: int = 2,
  399. ) -> None:
  400. # cpp_namespace can be a colon joined string such as torch::lazy
  401. cpp_namespaces = namespace_str.split("::")
  402. assert len(cpp_namespaces) <= max_level, (
  403. f"Codegen doesn't support more than {max_level} level(s) of custom namespace. Got {namespace_str}."
  404. )
  405. self.cpp_namespace_ = namespace_str
  406. self.prologue_ = "\n".join([f"namespace {n} {{" for n in cpp_namespaces])
  407. self.epilogue_ = "\n".join(
  408. [f"}} // namespace {n}" for n in reversed(cpp_namespaces)]
  409. )
  410. self.namespaces_ = cpp_namespaces
  411. self.entity_name_ = entity_name
  412. @staticmethod
  413. def from_namespaced_entity(
  414. namespaced_entity: str,
  415. max_level: int = 2,
  416. ) -> NamespaceHelper:
  417. """
  418. Generate helper from nested namespaces as long as class/function name. E.g.: "torch::lazy::add"
  419. """
  420. names = namespaced_entity.split("::")
  421. entity_name = names[-1]
  422. namespace_str = "::".join(names[:-1])
  423. return NamespaceHelper(
  424. namespace_str=namespace_str, entity_name=entity_name, max_level=max_level
  425. )
  426. @property
  427. def prologue(self) -> str:
  428. return self.prologue_
  429. @property
  430. def epilogue(self) -> str:
  431. return self.epilogue_
  432. @property
  433. def entity_name(self) -> str:
  434. return self.entity_name_
  435. # Only allow certain level of namespaces
  436. def get_cpp_namespace(self, default: str = "") -> str:
  437. """
  438. Return the namespace string from joining all the namespaces by "::" (hence no leading "::").
  439. Return default if namespace string is empty.
  440. """
  441. return self.cpp_namespace_ if self.cpp_namespace_ else default
  442. class OrderedSet(Generic[T]):
  443. storage: dict[T, Literal[None]]
  444. def __init__(self, iterable: Iterable[T] | None = None) -> None:
  445. if iterable is None:
  446. self.storage = {}
  447. else:
  448. self.storage = dict.fromkeys(iterable)
  449. def __contains__(self, item: T) -> bool:
  450. return item in self.storage
  451. def __iter__(self) -> Iterator[T]:
  452. return iter(self.storage.keys())
  453. def update(self, items: OrderedSet[T]) -> None:
  454. self.storage.update(items.storage)
  455. def add(self, item: T) -> None:
  456. self.storage[item] = None
  457. def copy(self) -> OrderedSet[T]:
  458. ret: OrderedSet[T] = OrderedSet()
  459. ret.storage = self.storage.copy()
  460. return ret
  461. @staticmethod
  462. def union(*args: OrderedSet[T]) -> OrderedSet[T]:
  463. ret = args[0].copy()
  464. for s in args[1:]:
  465. ret.update(s)
  466. return ret
  467. def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]:
  468. return OrderedSet.union(self, other)
  469. def __ior__(self, other: OrderedSet[T]) -> Self:
  470. self.update(other)
  471. return self
  472. def __eq__(self, other: object) -> bool:
  473. if isinstance(other, OrderedSet):
  474. return self.storage == other.storage
  475. else:
  476. return set(self.storage.keys()) == other