utils.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529
  1. # mypy: allow-untyped-defs
  2. import dataclasses
  3. import inspect
  4. import sys
  5. from collections.abc import Iterable, Iterator
  6. from typing import Any, Callable, Literal, Optional, overload, Union
  7. import torch
  8. import torch.utils._pytree as pytree
  9. from torch import _C, _utils_internal
  10. from torch._ops import OpOverload
  11. @dataclasses.dataclass
  12. class Kernel:
  13. """Models a (function, source location)"""
  14. func: Callable
  15. source: str
  16. def __call__(self, *args, **kwargs):
  17. return self.func(*args, **kwargs)
  18. class RegistrationHandle:
  19. """Does something when someone calls .destroy() on it"""
  20. def __init__(self, on_destroy: Callable):
  21. self._on_destroy = on_destroy
  22. def destroy(self) -> None:
  23. self._on_destroy()
  24. def get_source(stacklevel: int) -> str:
  25. """Get a string that represents the caller.
  26. Example: "/path/to/foo.py:42"
  27. Use stacklevel=1 to get the caller's source
  28. Use stacklevel=2 to get the caller's caller's source
  29. etc.
  30. """
  31. frame = inspect.getframeinfo(sys._getframe(stacklevel))
  32. source = f"{frame.filename}:{frame.lineno}"
  33. return source
  34. def parse_namespace(qualname: str) -> tuple[str, str]:
  35. splits = qualname.split("::")
  36. if len(splits) != 2:
  37. raise ValueError(
  38. f"Expected `qualname` to be of the form "
  39. f'"namespace::name", but got {qualname}. '
  40. f"The qualname passed to the torch.library APIs must consist "
  41. f"of a namespace and a name, e.g. aten::sin"
  42. )
  43. return splits[0], splits[1]
  44. def lookup_op(qualname: str) -> OpOverload:
  45. namespace, name = parse_namespace(qualname)
  46. if "." in name:
  47. name, overload = name.split(".")
  48. else:
  49. overload = "default"
  50. ns = getattr(torch.ops, namespace)
  51. packet = getattr(ns, name)
  52. return getattr(packet, overload)
  53. def is_builtin(op: OpOverload) -> bool:
  54. assert isinstance(op, OpOverload)
  55. return op.namespace in {"aten", "prim", "prims"}
  56. def is_functional_schema(schema: Any) -> bool:
  57. """Check if the schema is functional.
  58. An operator is functional if:
  59. - it does not mutate any of its inputs
  60. - it does not return a view on any of its inputs
  61. - it has at least one return
  62. """
  63. def is_functional(schema):
  64. if schema.is_mutable:
  65. return False
  66. rets = schema.returns
  67. is_non_mutating_view = len(rets) > 0 and any(
  68. r.alias_info is not None and not r.alias_info.is_write for r in rets
  69. )
  70. if is_non_mutating_view:
  71. return False
  72. if not schema.returns:
  73. return False
  74. return True
  75. if isinstance(schema, torch._C.FunctionSchema):
  76. return is_functional(schema)
  77. # Lazy import because not all PyTorch builds have torchgen
  78. from torchgen.model import FunctionSchema
  79. if isinstance(schema, str):
  80. schema = FunctionSchema.parse(schema)
  81. assert isinstance(schema, FunctionSchema)
  82. return is_functional(schema)
  83. # should be torch._C.JitType but that annotation is busted
  84. def is_tensorlist_like_type(typ: Any) -> bool:
  85. return (
  86. typ == _C.ListType(_C.TensorType.get())
  87. or typ == _C.ListType(_C.OptionalType(_C.TensorType.get()))
  88. or typ == _C.OptionalType(_C.ListType(_C.TensorType.get()))
  89. or typ == _C.OptionalType(_C.ListType(_C.OptionalType(_C.TensorType.get())))
  90. )
  91. # should be torch._C.JitType but that annotation is busted
  92. def is_tensor_like_type(typ: Any) -> bool:
  93. return typ == _C.TensorType.get() or typ == _C.OptionalType(_C.TensorType.get())
  94. def mutates_and_returns_first_arg(op: OpOverload):
  95. """Check if an op is an inplace aten op, i.e. it mutates and returns the first arg.
  96. TODO: torchgen/model.py's FunctionSchema.parse is the source of truth for this,
  97. but not all PyTorch builds have torchgen (due to the yaml dependency being weird).
  98. Figure this out.
  99. Example: add_(Tensor(a!) x, Tensor y) -> Tensor(a)
  100. """
  101. if op.namespace != "aten":
  102. return False
  103. schema = op._schema
  104. if not len(schema.returns) == 1:
  105. return False
  106. if schema.returns[0].alias_info is None:
  107. return False
  108. alias_set = schema.returns[0].alias_info.after_set
  109. if len(alias_set) != 1:
  110. return False
  111. loc = next(iter(alias_set))
  112. if len(schema.arguments) < 1:
  113. return False
  114. first_arg = schema.arguments[0]
  115. if first_arg.alias_info is None:
  116. return False
  117. if not first_arg.alias_info.is_write:
  118. return False
  119. alias_set = first_arg.alias_info.after_set
  120. if len(alias_set) != 1:
  121. return False
  122. if loc != next(iter(alias_set)):
  123. return False
  124. for arg in schema.arguments[1:]:
  125. if arg.alias_info is not None:
  126. return False
  127. return True
  128. def fill_defaults(schema, args, kwargs):
  129. new_args = []
  130. new_kwargs = {}
  131. for i in range(len(schema.arguments)):
  132. info = schema.arguments[i]
  133. if info.kwarg_only:
  134. if info.name in kwargs:
  135. new_kwargs[info.name] = kwargs[info.name]
  136. else:
  137. new_kwargs[info.name] = info.default_value
  138. else:
  139. if i < len(args):
  140. new_args.append(args[i])
  141. else:
  142. new_args.append(info.default_value)
  143. return tuple(new_args), new_kwargs
  144. def zip_schema(
  145. schema: _C.FunctionSchema, args: tuple[Any, ...], kwargs: dict[str, Any]
  146. ) -> Iterable[tuple[_C.Argument, Any]]:
  147. """zips schema.arguments and (args, kwargs) together.
  148. Assumes that (args, kwargs) were the inputs to some torch._ops.OpOverload:
  149. that is, (args, kwargs) must be bindable to the schema (args, kwargs).
  150. """
  151. assert len(schema.arguments) >= len(args) + len(kwargs)
  152. for i in range(len(schema.arguments)):
  153. info = schema.arguments[i]
  154. if info.kwarg_only:
  155. if info.name in kwargs:
  156. yield info, kwargs[info.name]
  157. continue
  158. if i >= len(args):
  159. if not info.kwarg_only and info.name in kwargs:
  160. yield info, kwargs[info.name]
  161. # args that are equal to their default values are not populated
  162. # if they are followed by args that are equal to their defaults.
  163. # Skip these.
  164. continue
  165. yield info, args[i]
  166. return
  167. def hop_schema_from_fx_node(node):
  168. from torchgen.gen_schema_utils import FunctionSchemaGen
  169. hop = node.target
  170. if not isinstance(hop, torch._ops.HigherOrderOperator):
  171. raise RuntimeError("fx_node's target must be a hop.")
  172. def _collect_example_val(node):
  173. meta_val = node.meta.get("val", None)
  174. if meta_val is None:
  175. assert node.op == "get_attr"
  176. meta_val = getattr(node.graph.owning_module, node.target)
  177. return meta_val
  178. example_inputs = []
  179. for arg in node.args:
  180. if isinstance(arg, (torch.fx.Node, torch.fx.node.Node)):
  181. example_inputs.append(_collect_example_val(arg))
  182. elif isinstance(
  183. arg, (torch.fx.immutable_collections.immutable_list, list, tuple)
  184. ):
  185. example_inputs.append([_collect_example_val(x) for x in arg])
  186. else:
  187. raise RuntimeError(f"Unsupported arg type {type(arg)}")
  188. # Bound the arguments to make sure number of inputs are correct
  189. bound_args: inspect.BoundArguments = inspect.signature(hop.__call__).bind(
  190. *example_inputs
  191. )
  192. # We treat example_output as a single value in return. This is to differentiate 1. return a single val
  193. # vs 2. return a tuple with one element.
  194. example_output = _collect_example_val(node)
  195. return FunctionSchemaGen.from_example(
  196. hop._name, tuple(bound_args.arguments.items()), (list(example_output),)
  197. )
  198. def can_generate_trivial_fake_impl(op: OpOverload) -> bool:
  199. assert isinstance(op, OpOverload)
  200. if is_builtin(op):
  201. # We control the built-ins. These may (in rare cases)
  202. # do input metadata mutation (which we have banned on custom ops)
  203. return False
  204. schema = op._schema
  205. # It's suspicious if the op is not mutable but returns nothing, so we return False out of an abundance of caution
  206. if not schema.is_mutable:
  207. return False
  208. if len(schema.returns) > 0:
  209. return False
  210. # If the op returns nothing, then it has a trivial fake impl.
  211. return True
  212. def requires_set_python_module() -> bool:
  213. """If an op was defined in C++ and extended from Python using the
  214. torch.library APIs, returns if we require that there have been a
  215. m.set_python_module("mylib.ops") call from C++ that associates
  216. the C++ op with a python module.
  217. """
  218. return getattr(_utils_internal, "REQUIRES_SET_PYTHON_MODULE", True)
  219. def handle_dispatch_mode(curr_mode, op_overload, *args, **kwargs):
  220. assert isinstance(curr_mode, torch.utils._python_dispatch.TorchDispatchMode)
  221. args_flattened, _ = torch.utils._pytree.tree_flatten((args, kwargs.values()))
  222. # TODO: need to double check the semantics of the "types" argument to torch_dispatch.
  223. # It's generated in PyInterpreter.cpp, but seems to be generated in two places,
  224. # where in one case we only include tensors with the python key, and in another
  225. # we include **all** tensors.
  226. overload_types = [
  227. type(a)
  228. for a in args_flattened
  229. if isinstance(a, torch.Tensor)
  230. and torch._C._dispatch_keys(a).has(torch._C.DispatchKey.Python)
  231. ]
  232. # TODO: check that I got these args correct (in C++, we pass in "0000"??)
  233. return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
  234. def has_kwarg_only_args(schema: _C.FunctionSchema):
  235. return any(a.kwarg_only for a in schema.arguments)
  236. def has_kwarg_only_tensors(schema: _C.FunctionSchema):
  237. for a in schema.arguments:
  238. if not (is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type)):
  239. continue
  240. if not a.kwarg_only:
  241. continue
  242. return True
  243. return False
  244. def has_tensor_arg(schema: _C.FunctionSchema) -> bool:
  245. """
  246. Given a schema, returns True if the schema has a Tensor arg.
  247. A Tensor arg is any arg with a type annotation that might involve Tensor.
  248. """
  249. return any(
  250. (is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type))
  251. for a in schema.arguments
  252. )
  253. def get_device_arg_index(schema: _C.FunctionSchema) -> Union[int, None]:
  254. """
  255. Given a schema, returns the id of the `device: torch.device` argument.
  256. If it does not exist, returns None.
  257. """
  258. for index, arg in enumerate(schema.arguments):
  259. if arg.type is _C.DeviceObjType.get() and arg.name == "device":
  260. return index
  261. return None
  262. def iter_tensors(
  263. args: tuple[Any], kwargs: dict[str, Any], allowed_nesting: int = 1
  264. ) -> Iterator[torch.Tensor]:
  265. def check(arg):
  266. if isinstance(arg, torch.Tensor):
  267. yield arg
  268. elif allowed_nesting > 0 and isinstance(arg, (tuple, list)):
  269. yield from iter_tensors(tuple(arg), {}, allowed_nesting - 1)
  270. for arg in args:
  271. yield from check(arg)
  272. for kwarg in kwargs.values():
  273. yield from check(kwarg)
  274. def check_aliasing_constraint(name, prev, result, get_module=lambda: "???"):
  275. """
  276. custom operators' outputs must not alias any inputs or other outputs.
  277. """
  278. storages = {id(t.untyped_storage()) for t in prev if isinstance(t, torch.Tensor)}
  279. tuple_result = result
  280. if not isinstance(result, tuple):
  281. tuple_result = (result,)
  282. for tensor in iter_tensors(tuple_result, {}):
  283. key = id(tensor.untyped_storage())
  284. if id(tensor.untyped_storage()) in storages:
  285. raise RuntimeError(
  286. f"{name} (with implementation in {get_module()}): "
  287. f"The output of this custom operator (1) must not "
  288. f"also be an input to this custom operator and "
  289. f"(2) may not alias any inputs to this custom operator "
  290. f"or other returns. "
  291. f"The most common way to trigger this error is if "
  292. f"we have y = custom_op(x) and y and x are the same Tensor. "
  293. f"Please instead return a clone of the offending output "
  294. f"tensor(s) (e.g. return x.clone()) or refactor the custom "
  295. f"operator to not return y."
  296. )
  297. storages.add(key)
  298. def _c_check_aliasing_constraint(name, args, kwargs, result, get_module=lambda: "???"):
  299. """
  300. custom operators' outputs must not have any aliases
  301. This version uses C++ implementation for perf.
  302. Only List container is supported.
  303. Tensors in Lists with not only Tensors are checked.
  304. """
  305. tuple_result = result
  306. if not isinstance(result, tuple):
  307. tuple_result = (result,)
  308. if _C._any_output_is_alias_to_input_or_output(args, kwargs, tuple_result):
  309. raise RuntimeError(
  310. f"{name} (with implementation in {get_module()}): "
  311. f"The output of this custom operator (1) must not "
  312. f"also be an input to this custom operator and "
  313. f"(2) may not alias any inputs to this custom operator "
  314. f"or other returns. "
  315. f"The most common way to trigger this error is if "
  316. f"we have y = custom_op(x) and y and x are the same Tensor. "
  317. f"Please instead return a clone of the offending output "
  318. f"tensor(s) (e.g. return x.clone()) or refactor the custom "
  319. f"operator to not return y."
  320. )
  321. class MutationChecker:
  322. """
  323. Check if an operator mutated its arguments.
  324. Usage:
  325. checker = MutationChecker(op, flat_args, args_spec)
  326. op(*args, **kwargs)
  327. checker.check()
  328. """
  329. def __init__(self, op, flat_args, args_spec):
  330. self.op = op
  331. self.args_spec = args_spec
  332. self.flat_args = flat_args
  333. self.real_pre_hashes = [
  334. hash_tensor(a) if isinstance(a, torch.Tensor) else None for a in flat_args
  335. ]
  336. def check(self):
  337. real_post_hashes = [
  338. hash_tensor(a) if isinstance(a, torch.Tensor) else None
  339. for a in self.flat_args
  340. ]
  341. was_mutated = [
  342. not torch.equal(pre, post)
  343. and not (pre.isnan().all() and post.isnan().all())
  344. if isinstance(pre, torch.Tensor) and isinstance(post, torch.Tensor)
  345. else None
  346. for pre, post in zip(self.real_pre_hashes, real_post_hashes)
  347. ]
  348. was_mutated_args, was_mutated_kwargs = pytree.tree_unflatten(
  349. was_mutated, self.args_spec
  350. )
  351. for info, was_mutated in zip_schema(
  352. self.op._schema, was_mutated_args, was_mutated_kwargs
  353. ):
  354. def check_one(info, was_mutated):
  355. if info.is_write == was_mutated:
  356. return
  357. raise RuntimeError(
  358. f"{self.op._name}: for argument '{info.name}': the operator's schema "
  359. f"{self.op._schema} specified that "
  360. f"the operator {'mutates' if info.is_write else 'does not mutate'} "
  361. f"the argument, but this seems to be empirically wrong. "
  362. f"Please make the schema and operator behavior consistent. "
  363. f"You can specify that an operator mutates a Tensor by "
  364. f"e.g. changing its schema type from 'Tensor name' to 'Tensor(a!) name'"
  365. f"(use different identifiers (a, b, c, ...) for different Tensors)"
  366. )
  367. if is_tensor_like_type(info.type):
  368. check_one(info, was_mutated)
  369. elif is_tensorlist_like_type(info.type):
  370. was_any_mutated = False if was_mutated is None else any(was_mutated)
  371. check_one(info, was_any_mutated)
  372. def hash_tensor(t: torch.Tensor) -> torch.Tensor:
  373. """Some inexpensive hash. Used as a quick and dirty indicator for tensor mutation"""
  374. return t.detach().float().mean()
  375. def has_fake_kernel(op: torch._ops.OpOverload) -> bool:
  376. """If an operator (that stays alive until FakeTensorMode) has a Fake kernel.
  377. Don't use this if the operator decomposes before FakeTensorMode.
  378. """
  379. if can_generate_trivial_fake_impl(op):
  380. return True
  381. name = op._name
  382. if torch._C._dispatch_has_kernel_for_dispatch_key(
  383. name, "CompositeImplicitAutograd"
  384. ):
  385. return True
  386. opdef = torch._library.custom_ops._maybe_get_opdef(name)
  387. if opdef is None:
  388. # the non-torch.library.custom_op path
  389. if torch._C._dispatch_has_kernel_for_dispatch_key(
  390. name, "CompositeExplicitAutograd"
  391. ):
  392. return True
  393. entry = torch._library.simple_registry.singleton.find(name)
  394. if entry.fake_impl.kernel is not None:
  395. return True
  396. if torch._C._dispatch_has_kernel_for_dispatch_key(name, "Meta"):
  397. return True
  398. else:
  399. # the torch.library.custom_op path
  400. if opdef._abstract_fn is not None:
  401. return True
  402. return False
  403. def mutated_args_kwargs(schema: _C.FunctionSchema) -> tuple[list[int], list[str]]:
  404. idxs = []
  405. keys = []
  406. for i, info in enumerate(schema.arguments):
  407. if info.alias_info is not None and info.alias_info.is_write:
  408. if info.kwarg_only:
  409. keys.append(info.name)
  410. else:
  411. idxs.append(i)
  412. return idxs, keys
  413. tags_by_priority = [
  414. _C.Tag.needs_exact_strides,
  415. _C.Tag.needs_contiguous_strides,
  416. _C.Tag.needs_fixed_stride_order,
  417. _C.Tag.flexible_layout,
  418. ]
  419. # Case 1: with_default=True (or omitted). Return type is guaranteed to be a Tag.
  420. @overload
  421. def get_layout_constraint_tag(
  422. fn: Any, *, with_default: Literal[True] = True
  423. ) -> _C.Tag: ...
  424. # Case 2: with_default=False. Return type can be a Tag or None.
  425. @overload
  426. def get_layout_constraint_tag(
  427. fn: Any, *, with_default: Literal[False]
  428. ) -> Optional[_C.Tag]: ...
  429. def get_layout_constraint_tag(fn, *, with_default=True):
  430. for tag in tags_by_priority:
  431. if tag in fn.tags:
  432. return tag
  433. if with_default:
  434. if is_builtin(fn):
  435. return _C.Tag.flexible_layout
  436. import torch._functorch
  437. from torch._functorch import config
  438. return getattr(torch._C.Tag, config.custom_op_default_layout_constraint)
  439. return None