utils.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519
  1. from __future__ import annotations
  2. import contextlib
  3. import functools
  4. import hashlib
  5. import os
  6. import re
  7. import sys
  8. import textwrap
  9. from dataclasses import is_dataclass
  10. from enum import auto, Enum
  11. from pathlib import Path
  12. from pprint import pformat
  13. from typing import Any, Generic, TYPE_CHECKING, TypeVar
  14. from typing_extensions import assert_never, Self
  15. from torchgen.code_template import CodeTemplate
  16. if TYPE_CHECKING:
  17. from argparse import Namespace
  18. from collections.abc import Callable, Iterable, Iterator, Sequence
  19. TORCHGEN_ROOT = Path(__file__).absolute().parent
  20. REPO_ROOT = TORCHGEN_ROOT.parent
  21. # Many of these functions share logic for defining both the definition
  22. # and declaration (for example, the function signature is the same), so
  23. # we organize them into one function that takes a Target to say which
  24. # code we want.
  25. #
  26. # This is an OPEN enum (we may add more cases to it in the future), so be sure
  27. # to explicitly specify with Literal[Target.XXX] or Literal[Target.XXX, Target.YYY]
  28. # what targets are valid for your use.
  29. class Target(Enum):
  30. # top level namespace (not including at)
  31. DEFINITION = auto()
  32. DECLARATION = auto()
  33. # TORCH_LIBRARY(...) { ... }
  34. REGISTRATION = auto()
  35. # namespace { ... }
  36. ANONYMOUS_DEFINITION = auto()
  37. # namespace cpu { ... }
  38. NAMESPACED_DEFINITION = auto()
  39. NAMESPACED_DECLARATION = auto()
  40. # Matches "foo" in "foo, bar" but not "foobar". Used to search for the
  41. # occurrence of a parameter in the derivative formula
  42. IDENT_REGEX = r"(^|\W){}($|\W)"
  43. # TODO: Use a real parser here; this will get bamboozled
  44. def split_name_params(schema: str) -> tuple[str, list[str]]:
  45. m = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
  46. if m is None:
  47. raise RuntimeError(f"Unsupported function schema: {schema}")
  48. name, _, params = m.groups()
  49. return name, params.split(", ")
  50. T = TypeVar("T")
  51. S = TypeVar("S")
  52. # These two functions purposely return generators in analogy to map()
  53. # so that you don't mix up when you need to list() them
  54. # Map over function that may return None; omit Nones from output sequence
  55. def mapMaybe(func: Callable[[T], S | None], xs: Iterable[T]) -> Iterator[S]:
  56. for x in xs:
  57. r = func(x)
  58. if r is not None:
  59. yield r
  60. # Map over function that returns sequences and cat them all together
  61. def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
  62. for x in xs:
  63. yield from func(x)
  64. # Conveniently add error context to exceptions raised. Lets us
  65. # easily say that an error occurred while processing a specific
  66. # context.
  67. @contextlib.contextmanager
  68. def context(msg_fn: Callable[[], str]) -> Iterator[None]:
  69. try:
  70. yield
  71. except Exception as e:
  72. # TODO: this does the wrong thing with KeyError
  73. msg = msg_fn()
  74. msg = textwrap.indent(msg, " ")
  75. msg = f"{e.args[0]}\n{msg}" if e.args else msg
  76. e.args = (msg,) + e.args[1:]
  77. raise
  78. @functools.cache
  79. def _read_template(template_fn: str) -> CodeTemplate:
  80. return CodeTemplate.from_file(template_fn)
  81. # String hash that's stable across different executions, unlike builtin hash
  82. def string_stable_hash(s: str) -> int:
  83. sha1 = hashlib.sha1(s.encode("latin1"), usedforsecurity=False).digest()
  84. return int.from_bytes(sha1, byteorder="little")
  85. # A small abstraction for writing out generated files and keeping track
  86. # of what files have been written (so you can write out a list of output
  87. # files)
  88. class FileManager:
  89. def __init__(
  90. self,
  91. install_dir: str | Path,
  92. template_dir: str | Path,
  93. dry_run: bool,
  94. ) -> None:
  95. self.install_dir = Path(install_dir)
  96. self.template_dir = Path(template_dir)
  97. self.files: set[Path] = set()
  98. self.dry_run = dry_run
  99. @property
  100. def filenames(self) -> frozenset[str]:
  101. return frozenset({file.as_posix() for file in self.files})
  102. def _write_if_changed(self, filename: str | Path, contents: str) -> None:
  103. file = Path(filename)
  104. old_contents: str | None = None
  105. try:
  106. old_contents = file.read_text(encoding="utf-8")
  107. except OSError:
  108. pass
  109. if contents != old_contents:
  110. # Create output directory if it doesn't exist
  111. file.parent.mkdir(parents=True, exist_ok=True)
  112. file.write_text(contents, encoding="utf-8")
  113. # Read from template file and replace pattern with callable (type could be dict or str).
  114. def substitute_with_template(
  115. self,
  116. template_fn: str | Path,
  117. env_callable: Callable[[], str | dict[str, Any]],
  118. ) -> str:
  119. assert not Path(template_fn).is_absolute(), (
  120. f"template_fn must be relative: {template_fn}"
  121. )
  122. template_path = self.template_dir / template_fn
  123. env = env_callable()
  124. if isinstance(env, dict):
  125. if "generated_comment" not in env:
  126. generator_default = TORCHGEN_ROOT / "gen.py"
  127. try:
  128. generator = Path(
  129. sys.modules["__main__"].__file__ or generator_default
  130. ).absolute()
  131. except (KeyError, AttributeError):
  132. generator = generator_default.absolute()
  133. try:
  134. generator_path = generator.relative_to(REPO_ROOT).as_posix()
  135. except ValueError:
  136. generator_path = generator.name
  137. env = {
  138. **env, # copy the original dict instead of mutating it
  139. "generated_comment": (
  140. "@" + f"generated by {generator_path} from {template_fn}"
  141. ),
  142. }
  143. template = _read_template(template_path)
  144. substitute_out = template.substitute(env)
  145. # Ensure an extra blank line between the class/function definition
  146. # and the docstring of the previous class/function definition.
  147. # NB: It is generally not recommended to have docstrings in pyi stub
  148. # files. But if there are any, we need to ensure that the file
  149. # is properly formatted.
  150. return re.sub(
  151. r'''
  152. (""")\n+ # match triple quotes
  153. (
  154. (\s*@.+\n)* # match decorators if any
  155. \s*(class|def) # match class/function definition
  156. )
  157. ''',
  158. r"\g<1>\n\n\g<2>",
  159. substitute_out,
  160. flags=re.VERBOSE,
  161. )
  162. if isinstance(env, str):
  163. return env
  164. assert_never(env)
  165. def write_with_template(
  166. self,
  167. filename: str | Path,
  168. template_fn: str | Path,
  169. env_callable: Callable[[], str | dict[str, Any]],
  170. ) -> None:
  171. filename = Path(filename)
  172. assert not filename.is_absolute(), f"filename must be relative: {filename}"
  173. file = self.install_dir / filename
  174. assert file not in self.files, f"duplicate file write {file}"
  175. self.files.add(file)
  176. if not self.dry_run:
  177. substitute_out = self.substitute_with_template(
  178. template_fn=template_fn,
  179. env_callable=env_callable,
  180. )
  181. self._write_if_changed(filename=file, contents=substitute_out)
  182. def write(
  183. self,
  184. filename: str | Path,
  185. env_callable: Callable[[], str | dict[str, Any]],
  186. ) -> None:
  187. self.write_with_template(filename, filename, env_callable)
  188. def write_sharded(
  189. self,
  190. filename: str | Path,
  191. items: Iterable[T],
  192. *,
  193. key_fn: Callable[[T], str],
  194. env_callable: Callable[[T], dict[str, list[str]]],
  195. num_shards: int,
  196. base_env: dict[str, Any] | None = None,
  197. sharded_keys: set[str],
  198. ) -> None:
  199. self.write_sharded_with_template(
  200. filename,
  201. filename,
  202. items,
  203. key_fn=key_fn,
  204. env_callable=env_callable,
  205. num_shards=num_shards,
  206. base_env=base_env,
  207. sharded_keys=sharded_keys,
  208. )
  209. def write_sharded_with_template(
  210. self,
  211. filename: str | Path,
  212. template_fn: str | Path,
  213. items: Iterable[T],
  214. *,
  215. key_fn: Callable[[T], str],
  216. env_callable: Callable[[T], dict[str, list[str]]],
  217. num_shards: int,
  218. base_env: dict[str, Any] | None = None,
  219. sharded_keys: set[str],
  220. ) -> None:
  221. file = Path(filename)
  222. assert not file.is_absolute(), f"filename must be relative: {filename}"
  223. everything: dict[str, Any] = {"shard_id": "Everything"}
  224. shards: list[dict[str, Any]] = [
  225. {"shard_id": f"_{i}"} for i in range(num_shards)
  226. ]
  227. all_shards = [everything] + shards
  228. if base_env is not None:
  229. for shard in all_shards:
  230. shard.update(base_env)
  231. for key in sharded_keys:
  232. for shard in all_shards:
  233. if key in shard:
  234. assert isinstance(shard[key], list), (
  235. "sharded keys in base_env must be a list"
  236. )
  237. shard[key] = shard[key].copy()
  238. else:
  239. shard[key] = []
  240. def merge_env(into: dict[str, list[str]], from_: dict[str, list[str]]) -> None:
  241. for k, v in from_.items():
  242. assert k in sharded_keys, f"undeclared sharded key {k}"
  243. into[k] += v
  244. if self.dry_run:
  245. # Dry runs don't write any templates, so incomplete environments are fine
  246. items = ()
  247. for item in items:
  248. key = key_fn(item)
  249. sid = string_stable_hash(key) % num_shards
  250. env = env_callable(item)
  251. merge_env(shards[sid], env)
  252. merge_env(everything, env)
  253. for shard in all_shards:
  254. shard_id = shard["shard_id"]
  255. self.write_with_template(
  256. file.with_stem(f"{file.stem}{shard_id}"),
  257. template_fn,
  258. lambda: shard,
  259. )
  260. # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
  261. self.files.discard(self.install_dir / file.with_stem(f"{file.stem}Everything"))
  262. def write_outputs(self, variable_name: str, filename: str | Path) -> None:
  263. """Write a file containing the list of all outputs which are generated by this script."""
  264. content = "\n".join(
  265. (
  266. "set(",
  267. variable_name,
  268. # Use POSIX paths to avoid invalid escape sequences on Windows
  269. *(f' "{file.as_posix()}"' for file in sorted(self.files)),
  270. ")",
  271. )
  272. )
  273. self._write_if_changed(filename, content)
  274. def template_dir_for_comments(self) -> str:
  275. """
  276. This needs to be deterministic. The template dir is an absolute path
  277. that varies across builds. So, just use the path relative to this file,
  278. which will point to the codegen source but will be stable.
  279. """
  280. return os.path.relpath(self.template_dir, os.path.dirname(__file__))
  281. # Helper function to generate file manager
  282. def make_file_manager(
  283. options: Namespace,
  284. install_dir: str | Path | None = None,
  285. ) -> FileManager:
  286. template_dir = os.path.join(options.source_path, "templates")
  287. install_dir = install_dir if install_dir else options.install_dir
  288. return FileManager(
  289. install_dir=install_dir,
  290. template_dir=template_dir,
  291. dry_run=options.dry_run,
  292. )
  293. # Helper function to create a pretty representation for dataclasses
  294. def dataclass_repr(
  295. obj: Any,
  296. indent: int = 0,
  297. width: int = 80,
  298. ) -> str:
  299. return pformat(obj, indent, width)
  300. def _format_dict(
  301. attr: dict[Any, Any],
  302. indent: int,
  303. width: int,
  304. curr_indent: int,
  305. ) -> str:
  306. curr_indent += indent + 3
  307. dict_repr = []
  308. for k, v in attr.items():
  309. k_repr = repr(k)
  310. v_str = (
  311. pformat(v, indent, width, curr_indent + len(k_repr))
  312. if is_dataclass(v)
  313. else repr(v)
  314. )
  315. dict_repr.append(f"{k_repr}: {v_str}")
  316. return _format(dict_repr, indent, width, curr_indent, "{", "}")
  317. def _format_list(
  318. attr: list[Any] | set[Any] | tuple[Any, ...],
  319. indent: int,
  320. width: int,
  321. curr_indent: int,
  322. ) -> str:
  323. curr_indent += indent + 1
  324. list_repr = [
  325. pformat(l, indent, width, curr_indent) if is_dataclass(l) else repr(l)
  326. for l in attr
  327. ]
  328. start, end = ("[", "]") if isinstance(attr, list) else ("(", ")")
  329. return _format(list_repr, indent, width, curr_indent, start, end)
  330. def _format(
  331. fields_str: list[str],
  332. indent: int,
  333. width: int,
  334. curr_indent: int,
  335. start: str,
  336. end: str,
  337. ) -> str:
  338. delimiter, curr_indent_str = "", ""
  339. # if it exceed the max width then we place one element per line
  340. if len(repr(fields_str)) >= width:
  341. delimiter = "\n"
  342. curr_indent_str = " " * curr_indent
  343. indent_str = " " * indent
  344. body = f", {delimiter}{curr_indent_str}".join(fields_str)
  345. return f"{start}{indent_str}{body}{end}"
  346. class NamespaceHelper:
  347. """A helper for constructing the namespace open and close strings for a nested set of namespaces.
  348. e.g. for namespace_str torch::lazy,
  349. prologue:
  350. namespace torch {
  351. namespace lazy {
  352. epilogue:
  353. } // namespace lazy
  354. } // namespace torch
  355. """
  356. def __init__(
  357. self,
  358. namespace_str: str,
  359. entity_name: str = "",
  360. max_level: int = 2,
  361. ) -> None:
  362. # cpp_namespace can be a colon joined string such as torch::lazy
  363. cpp_namespaces = namespace_str.split("::")
  364. assert len(cpp_namespaces) <= max_level, (
  365. f"Codegen doesn't support more than {max_level} level(s) of custom namespace. Got {namespace_str}."
  366. )
  367. self.cpp_namespace_ = namespace_str
  368. self.prologue_ = "\n".join([f"namespace {n} {{" for n in cpp_namespaces])
  369. self.epilogue_ = "\n".join(
  370. [f"}} // namespace {n}" for n in reversed(cpp_namespaces)]
  371. )
  372. self.namespaces_ = cpp_namespaces
  373. self.entity_name_ = entity_name
  374. @staticmethod
  375. def from_namespaced_entity(
  376. namespaced_entity: str,
  377. max_level: int = 2,
  378. ) -> NamespaceHelper:
  379. """
  380. Generate helper from nested namespaces as long as class/function name. E.g.: "torch::lazy::add"
  381. """
  382. names = namespaced_entity.split("::")
  383. entity_name = names[-1]
  384. namespace_str = "::".join(names[:-1])
  385. return NamespaceHelper(
  386. namespace_str=namespace_str, entity_name=entity_name, max_level=max_level
  387. )
  388. @property
  389. def prologue(self) -> str:
  390. return self.prologue_
  391. @property
  392. def epilogue(self) -> str:
  393. return self.epilogue_
  394. @property
  395. def entity_name(self) -> str:
  396. return self.entity_name_
  397. # Only allow certain level of namespaces
  398. def get_cpp_namespace(self, default: str = "") -> str:
  399. """
  400. Return the namespace string from joining all the namespaces by "::" (hence no leading "::").
  401. Return default if namespace string is empty.
  402. """
  403. return self.cpp_namespace_ if self.cpp_namespace_ else default
  404. class OrderedSet(Generic[T]):
  405. storage: dict[T, None]
  406. def __init__(self, iterable: Iterable[T] | None = None) -> None:
  407. if iterable is None:
  408. self.storage = {}
  409. else:
  410. self.storage = dict.fromkeys(iterable)
  411. def __contains__(self, item: T) -> bool:
  412. return item in self.storage
  413. def __iter__(self) -> Iterator[T]:
  414. return iter(self.storage.keys())
  415. def update(self, items: OrderedSet[T]) -> None:
  416. self.storage.update(items.storage)
  417. def add(self, item: T) -> None:
  418. self.storage[item] = None
  419. def copy(self) -> OrderedSet[T]:
  420. ret: OrderedSet[T] = OrderedSet()
  421. ret.storage = self.storage.copy()
  422. return ret
  423. @staticmethod
  424. def union(*args: OrderedSet[T]) -> OrderedSet[T]:
  425. ret = args[0].copy()
  426. for s in args[1:]:
  427. ret.update(s)
  428. return ret
  429. def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]:
  430. return OrderedSet.union(self, other)
  431. def __ior__(self, other: OrderedSet[T]) -> Self:
  432. self.update(other)
  433. return self
  434. def __eq__(self, other: object) -> bool:
  435. if isinstance(other, OrderedSet):
  436. return self.storage == other.storage
  437. else:
  438. return set(self.storage.keys()) == other