utils.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647
  1. # mypy: allow-untyped-defs
  2. import dataclasses
  3. import inspect
  4. import sys
  5. from collections.abc import Callable, Iterable, Iterator
  6. from typing import Any, Literal, Optional, overload, Union
  7. import torch
  8. import torch.utils._pytree as pytree
  9. import torchgen
  10. from torch import _C, _utils_internal
  11. from torch._ops import OpOverload
  12. @dataclasses.dataclass
  13. class Kernel:
  14. """Models a (function, source location)"""
  15. func: Callable
  16. source: str
  17. def __call__(self, *args, **kwargs):
  18. return self.func(*args, **kwargs)
  19. class RegistrationHandle:
  20. """Does something when someone calls .destroy() on it"""
  21. def __init__(self, on_destroy: Callable):
  22. self._on_destroy = on_destroy
  23. def destroy(self) -> None:
  24. self._on_destroy()
  25. def get_source(stacklevel: int) -> str:
  26. """Get a string that represents the caller.
  27. Example: "/path/to/foo.py:42"
  28. Use stacklevel=1 to get the caller's source
  29. Use stacklevel=2 to get the caller's caller's source
  30. etc.
  31. """
  32. frame = inspect.getframeinfo(sys._getframe(stacklevel))
  33. source = f"{frame.filename}:{frame.lineno}"
  34. return source
  35. def parse_namespace(qualname: str) -> tuple[str, str]:
  36. splits = qualname.split("::")
  37. if len(splits) != 2:
  38. raise ValueError(
  39. f"Expected `qualname` to be of the form "
  40. f'"namespace::name", but got {qualname}. '
  41. f"The qualname passed to the torch.library APIs must consist "
  42. f"of a namespace and a name, e.g. aten::sin"
  43. )
  44. return splits[0], splits[1]
  45. def lookup_op(qualname: str) -> OpOverload:
  46. namespace, name = parse_namespace(qualname)
  47. if "." in name:
  48. name, overload = name.split(".")
  49. else:
  50. overload = "default"
  51. ns = getattr(torch.ops, namespace)
  52. packet = getattr(ns, name)
  53. return getattr(packet, overload)
  54. def is_builtin(op: OpOverload) -> bool:
  55. assert isinstance(op, OpOverload)
  56. return op.namespace in {"aten", "prim", "prims"}
  57. def is_functional_schema(schema: Any, *, allow_valid_view: bool = False) -> bool:
  58. """Check if the schema is functional.
  59. An operator is functional if:
  60. - it does not mutate any of its inputs
  61. - If no view are allowed
  62. - it does not return a view on any of its inputs
  63. - If valid views are allowed
  64. - it is not a view or a view with a single input Tensor and single output Tensor
  65. - it has at least one return
  66. """
  67. def is_functional(schema):
  68. if schema.is_mutable:
  69. return False
  70. rets = schema.returns
  71. is_non_mutating_view = len(rets) > 0 and any(
  72. r.alias_info is not None and not r.alias_info.is_write for r in rets
  73. )
  74. num_tensor_inputs = 0
  75. num_tensor_outputs = 0
  76. if isinstance(schema, torch.FunctionSchema):
  77. for arg in schema.arguments:
  78. if isinstance(arg.type, torch.TensorType):
  79. num_tensor_inputs += 1
  80. for ret in schema.returns:
  81. if isinstance(ret.type, torch.TensorType):
  82. num_tensor_outputs += 1
  83. elif isinstance(schema, torchgen.model.FunctionSchema):
  84. for argument in schema.arguments.flat_non_out:
  85. if argument.type.is_tensor_like():
  86. num_tensor_inputs += 1
  87. for ret_arg in schema.returns:
  88. if ret_arg.type.is_tensor_like():
  89. num_tensor_outputs += 1
  90. if is_non_mutating_view:
  91. return allow_valid_view and (
  92. num_tensor_inputs == 1 and num_tensor_outputs == 1
  93. )
  94. if not schema.returns:
  95. return False
  96. return True
  97. if isinstance(schema, torch._C.FunctionSchema):
  98. return is_functional(schema)
  99. # Lazy import because not all PyTorch builds have torchgen
  100. from torchgen.model import FunctionSchema
  101. if isinstance(schema, str):
  102. schema = FunctionSchema.parse(schema)
  103. assert isinstance(schema, FunctionSchema)
  104. return is_functional(schema)
  105. # should be torch._C.JitType but that annotation is busted
  106. def is_tensorlist_like_type(typ: Any) -> bool:
  107. return (
  108. typ == _C.ListType(_C.TensorType.get())
  109. or typ == _C.ListType(_C.OptionalType(_C.TensorType.get()))
  110. or typ == _C.OptionalType(_C.ListType(_C.TensorType.get()))
  111. or typ == _C.OptionalType(_C.ListType(_C.OptionalType(_C.TensorType.get())))
  112. )
  113. # should be torch._C.JitType but that annotation is busted
  114. def is_tensor_like_type(typ: Any) -> bool:
  115. return typ == _C.TensorType.get() or typ == _C.OptionalType(_C.TensorType.get())
  116. def mutates_and_returns_first_arg(op: OpOverload):
  117. """Check if an op is an inplace aten op, i.e. it mutates and returns the first arg.
  118. TODO: torchgen/model.py's FunctionSchema.parse is the source of truth for this,
  119. but not all PyTorch builds have torchgen (due to the yaml dependency being weird).
  120. Figure this out.
  121. Example: add_(Tensor(a!) x, Tensor y) -> Tensor(a)
  122. """
  123. if op.namespace != "aten":
  124. return False
  125. schema = op._schema
  126. if len(schema.returns) != 1:
  127. return False
  128. if schema.returns[0].alias_info is None:
  129. return False
  130. alias_set = schema.returns[0].alias_info.after_set
  131. if len(alias_set) != 1:
  132. return False
  133. loc = next(iter(alias_set))
  134. if len(schema.arguments) < 1:
  135. return False
  136. first_arg = schema.arguments[0]
  137. if first_arg.alias_info is None:
  138. return False
  139. if not first_arg.alias_info.is_write:
  140. return False
  141. alias_set = first_arg.alias_info.after_set
  142. if len(alias_set) != 1:
  143. return False
  144. if loc != next(iter(alias_set)):
  145. return False
  146. for arg in schema.arguments[1:]:
  147. if arg.alias_info is not None:
  148. return False
  149. return True
  150. def fill_defaults(schema, args, kwargs):
  151. new_args = []
  152. new_kwargs = {}
  153. for i in range(len(schema.arguments)):
  154. info = schema.arguments[i]
  155. if info.kwarg_only:
  156. if info.name in kwargs:
  157. new_kwargs[info.name] = kwargs[info.name]
  158. else:
  159. new_kwargs[info.name] = info.default_value
  160. else:
  161. if i < len(args):
  162. new_args.append(args[i])
  163. else:
  164. new_args.append(info.default_value)
  165. return tuple(new_args), new_kwargs
  166. def zip_schema(
  167. schema: _C.FunctionSchema, args: tuple[Any, ...], kwargs: dict[str, Any]
  168. ) -> Iterable[tuple[_C.Argument, Any]]:
  169. """zips schema.arguments and (args, kwargs) together.
  170. Assumes that (args, kwargs) were the inputs to some torch._ops.OpOverload:
  171. that is, (args, kwargs) must be bindable to the schema (args, kwargs).
  172. """
  173. assert len(schema.arguments) >= len(args) + len(kwargs)
  174. for i in range(len(schema.arguments)):
  175. info = schema.arguments[i]
  176. if info.kwarg_only:
  177. if info.name in kwargs:
  178. yield info, kwargs[info.name]
  179. continue
  180. if i >= len(args):
  181. if not info.kwarg_only and info.name in kwargs:
  182. yield info, kwargs[info.name]
  183. # args that are equal to their default values are not populated
  184. # if they are followed by args that are equal to their defaults.
  185. # Skip these.
  186. continue
  187. yield info, args[i]
  188. return
  189. def hop_schema_from_fx_node(node):
  190. from torchgen.gen_schema_utils import FunctionSchemaGen
  191. hop = node.target
  192. if not isinstance(hop, torch._ops.HigherOrderOperator):
  193. raise RuntimeError("fx_node's target must be a hop.")
  194. def _collect_example_val(node):
  195. meta_val = node.meta.get("val", None)
  196. if meta_val is None:
  197. assert node.op == "get_attr"
  198. meta_val = getattr(node.graph.owning_module, node.target)
  199. return meta_val
  200. example_inputs = []
  201. for arg in node.args:
  202. if isinstance(arg, (torch.fx.Node, torch.fx.node.Node)):
  203. example_inputs.append(_collect_example_val(arg))
  204. elif isinstance(
  205. arg, (torch.fx.immutable_collections.immutable_list, list, tuple)
  206. ):
  207. example_inputs.append([_collect_example_val(x) for x in arg])
  208. else:
  209. raise RuntimeError(f"Unsupported arg type {type(arg)}")
  210. # Bound the arguments to make sure number of inputs are correct
  211. bound_args: inspect.BoundArguments = inspect.signature(hop.__call__).bind(
  212. *example_inputs
  213. )
  214. # We treat example_output as a single value in return. This is to differentiate 1. return a single val
  215. # vs 2. return a tuple with one element.
  216. example_output = _collect_example_val(node)
  217. return FunctionSchemaGen.from_example(
  218. hop._name, tuple(bound_args.arguments.items()), (list(example_output),)
  219. )
  220. def can_generate_trivial_fake_impl(op: OpOverload) -> bool:
  221. assert isinstance(op, OpOverload)
  222. if is_builtin(op):
  223. # We control the built-ins. These may (in rare cases)
  224. # do input metadata mutation (which we have banned on custom ops)
  225. return False
  226. schema = op._schema
  227. # It's suspicious if the op is not mutable but returns nothing, so we return False out of an abundance of caution
  228. if not schema.is_mutable:
  229. return False
  230. if len(schema.returns) > 0:
  231. return False
  232. # If the op returns nothing, then it has a trivial fake impl.
  233. return True
  234. def requires_set_python_module() -> bool:
  235. """If an op was defined in C++ and extended from Python using the
  236. torch.library APIs, returns if we require that there have been a
  237. m.set_python_module("mylib.ops") call from C++ that associates
  238. the C++ op with a python module.
  239. """
  240. return getattr(_utils_internal, "REQUIRES_SET_PYTHON_MODULE", True)
  241. def handle_dispatch_mode(curr_mode, op_overload, *args, **kwargs):
  242. assert isinstance(curr_mode, torch.utils._python_dispatch.TorchDispatchMode)
  243. args_flattened, _ = torch.utils._pytree.tree_flatten((args, kwargs.values()))
  244. # TODO: need to double check the semantics of the "types" argument to torch_dispatch.
  245. # It's generated in PyInterpreter.cpp, but seems to be generated in two places,
  246. # where in one case we only include tensors with the python key, and in another
  247. # we include **all** tensors.
  248. overload_types = [
  249. type(a)
  250. for a in args_flattened
  251. if isinstance(a, torch.Tensor)
  252. and torch._C._dispatch_keys(a).has(torch._C.DispatchKey.Python)
  253. ]
  254. # TODO: check that I got these args correct (in C++, we pass in "0000"??)
  255. return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
  256. def has_kwarg_only_args(schema: _C.FunctionSchema):
  257. return any(a.kwarg_only for a in schema.arguments)
  258. def has_kwarg_only_tensors(schema: _C.FunctionSchema):
  259. for a in schema.arguments:
  260. if not (is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type)):
  261. continue
  262. if not a.kwarg_only:
  263. continue
  264. return True
  265. return False
  266. def has_tensor_arg(schema: _C.FunctionSchema) -> bool:
  267. """
  268. Given a schema, returns True if the schema has a Tensor arg.
  269. A Tensor arg is any arg with a type annotation that might involve Tensor.
  270. """
  271. return any(
  272. (is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type))
  273. for a in schema.arguments
  274. )
  275. def get_device_arg_index(schema: _C.FunctionSchema) -> Union[int, None]:
  276. """
  277. Given a schema, returns the id of the `device: torch.device` argument.
  278. If it does not exist, returns None.
  279. """
  280. for index, arg in enumerate(schema.arguments):
  281. if arg.type is _C.DeviceObjType.get() and arg.name == "device":
  282. return index
  283. return None
  284. def iter_tensors(
  285. args: tuple[Any], kwargs: dict[str, Any], allowed_nesting: int = 1
  286. ) -> Iterator[torch.Tensor]:
  287. def check(arg):
  288. if isinstance(arg, torch.Tensor):
  289. yield arg
  290. elif allowed_nesting > 0 and isinstance(arg, (tuple, list)):
  291. yield from iter_tensors(tuple(arg), {}, allowed_nesting - 1)
  292. for arg in args:
  293. yield from check(arg)
  294. for kwarg in kwargs.values():
  295. yield from check(kwarg)
  296. def check_aliasing_constraint(name, prev, result, get_module=lambda: "???"):
  297. """
  298. custom operators' outputs must not alias any inputs or other outputs.
  299. """
  300. storages = {t.untyped_storage()._cdata for t in prev if isinstance(t, torch.Tensor)}
  301. tuple_result = result
  302. if not isinstance(result, tuple):
  303. tuple_result = (result,)
  304. for tensor in iter_tensors(tuple_result, {}):
  305. key = tensor.untyped_storage()._cdata
  306. if tensor.untyped_storage()._cdata in storages:
  307. raise RuntimeError(
  308. f"{name} (with implementation in {get_module()}): "
  309. f"The output of this custom operator (1) must not "
  310. f"also be an input to this custom operator and "
  311. f"(2) may not alias any inputs to this custom operator "
  312. f"or other returns. "
  313. f"The most common way to trigger this error is if "
  314. f"we have y = custom_op(x) and y and x are the same Tensor. "
  315. f"Please instead return a clone of the offending output "
  316. f"tensor(s) (e.g. return x.clone()) or refactor the custom "
  317. f"operator to not return y."
  318. )
  319. storages.add(key)
  320. def _c_check_aliasing_constraint(name, args, kwargs, result, get_module=lambda: "???"):
  321. """
  322. custom operators' outputs must not have any aliases
  323. This version uses C++ implementation for perf.
  324. Only List container is supported.
  325. Tensors in Lists with not only Tensors are checked.
  326. """
  327. tuple_result = result
  328. if not isinstance(result, tuple):
  329. tuple_result = (result,)
  330. if _C._any_output_is_alias_to_input_or_output(args, kwargs, tuple_result):
  331. raise RuntimeError(
  332. f"{name} (with implementation in {get_module()}): "
  333. f"The output of this custom operator (1) must not "
  334. f"also be an input to this custom operator and "
  335. f"(2) may not alias any inputs to this custom operator "
  336. f"or other returns. "
  337. f"The most common way to trigger this error is if "
  338. f"we have y = custom_op(x) and y and x are the same Tensor. "
  339. f"Please instead return a clone of the offending output "
  340. f"tensor(s) (e.g. return x.clone()) or refactor the custom "
  341. f"operator to not return y."
  342. )
  343. class MutationChecker:
  344. """
  345. Check if an operator mutated its arguments.
  346. Usage:
  347. checker = MutationChecker(op, flat_args, args_spec)
  348. op(*args, **kwargs)
  349. checker.check()
  350. """
  351. def __init__(self, op, flat_args, args_spec):
  352. self.op = op
  353. self.args_spec = args_spec
  354. self.flat_args = flat_args
  355. self.real_pre_hashes = [
  356. hash_tensor(a) if isinstance(a, torch.Tensor) else None for a in flat_args
  357. ]
  358. def check(self):
  359. real_post_hashes = [
  360. hash_tensor(a) if isinstance(a, torch.Tensor) else None
  361. for a in self.flat_args
  362. ]
  363. was_mutated = [
  364. not torch.equal(pre, post)
  365. and not (pre.isnan().all() and post.isnan().all())
  366. if isinstance(pre, torch.Tensor) and isinstance(post, torch.Tensor)
  367. else None
  368. for pre, post in zip(self.real_pre_hashes, real_post_hashes)
  369. ]
  370. was_mutated_args, was_mutated_kwargs = pytree.tree_unflatten(
  371. was_mutated, self.args_spec
  372. )
  373. for info, was_mutated in zip_schema(
  374. self.op._schema, was_mutated_args, was_mutated_kwargs
  375. ):
  376. def check_one(info, was_mutated):
  377. if info.is_write == was_mutated:
  378. return
  379. raise RuntimeError(
  380. f"{self.op._name}: for argument '{info.name}': the operator's schema "
  381. f"{self.op._schema} specified that "
  382. f"the operator {'mutates' if info.is_write else 'does not mutate'} "
  383. f"the argument, but this seems to be empirically wrong. "
  384. f"Please make the schema and operator behavior consistent. "
  385. f"You can specify that an operator mutates a Tensor by "
  386. f"e.g. changing its schema type from 'Tensor name' to 'Tensor(a!) name'"
  387. f"(use different identifiers (a, b, c, ...) for different Tensors)"
  388. )
  389. if is_tensor_like_type(info.type):
  390. check_one(info, was_mutated)
  391. elif is_tensorlist_like_type(info.type):
  392. was_any_mutated = False if was_mutated is None else any(was_mutated)
  393. check_one(info, was_any_mutated)
  394. def hash_tensor(t: torch.Tensor) -> torch.Tensor:
  395. """Some inexpensive hash. Used as a quick and dirty indicator for tensor mutation"""
  396. return t.detach().float().mean()
  397. def has_fake_kernel(op: torch._ops.OpOverload) -> bool:
  398. """If an operator (that stays alive until FakeTensorMode) has a Fake kernel.
  399. Don't use this if the operator decomposes before FakeTensorMode.
  400. """
  401. if can_generate_trivial_fake_impl(op):
  402. return True
  403. name = op._name
  404. if torch._C._dispatch_has_kernel_for_dispatch_key(
  405. name, "CompositeImplicitAutograd"
  406. ):
  407. return True
  408. opdef = torch._library.custom_ops._maybe_get_opdef(name)
  409. if opdef is None:
  410. # the non-torch.library.custom_op path
  411. if torch._C._dispatch_has_kernel_for_dispatch_key(
  412. name, "CompositeExplicitAutograd"
  413. ):
  414. return True
  415. entry = torch._library.simple_registry.singleton.find(name)
  416. if entry.fake_impl.kernel is not None:
  417. return True
  418. if torch._C._dispatch_has_kernel_for_dispatch_key(name, "Meta"):
  419. return True
  420. else:
  421. # the torch.library.custom_op path
  422. if opdef._abstract_fn is not None:
  423. return True
  424. return False
  425. def mutated_args_kwargs(schema: _C.FunctionSchema) -> tuple[list[int], list[str]]:
  426. idxs = []
  427. keys = []
  428. for i, info in enumerate(schema.arguments):
  429. if info.alias_info is not None and info.alias_info.is_write:
  430. if info.kwarg_only:
  431. keys.append(info.name)
  432. else:
  433. idxs.append(i)
  434. return idxs, keys
  435. tags_by_priority = [
  436. _C.Tag.needs_exact_strides,
  437. _C.Tag.needs_contiguous_strides,
  438. _C.Tag.needs_fixed_stride_order,
  439. _C.Tag.flexible_layout,
  440. ]
  441. # Case 1: with_default=True (or omitted). Return type is guaranteed to be a Tag.
  442. @overload
  443. def get_layout_constraint_tag(
  444. fn: Any, *, with_default: Literal[True] = True
  445. ) -> _C.Tag: ...
  446. # Case 2: with_default=False. Return type can be a Tag or None.
  447. @overload
  448. def get_layout_constraint_tag(
  449. fn: Any, *, with_default: Literal[False]
  450. ) -> Optional[_C.Tag]: ...
  451. def get_layout_constraint_tag(fn, *, with_default=True):
  452. for tag in tags_by_priority:
  453. if tag in fn.tags:
  454. return tag
  455. if with_default:
  456. if is_builtin(fn):
  457. return _C.Tag.flexible_layout
  458. import torch._functorch
  459. from torch._functorch import config
  460. return getattr(torch._C.Tag, config.custom_op_default_layout_constraint)
  461. return None
  462. # List of random functions that should be treated as impure
  463. _RANDOM_FUNCTIONS = {
  464. torch.rand,
  465. torch.randn,
  466. torch.randint,
  467. torch.randperm,
  468. torch.rand_like,
  469. torch.randn_like,
  470. torch.randint_like,
  471. torch.normal,
  472. torch.poisson,
  473. torch.bernoulli,
  474. torch.multinomial,
  475. }
  476. def is_impure(
  477. op: Callable,
  478. *,
  479. args: Optional[tuple[Any, ...]] = None,
  480. kwargs: Optional[dict[str, Any]] = None,
  481. impure_random: bool = True,
  482. ) -> bool:
  483. """
  484. An operator is impure if it:
  485. - Mutates its inputs (has a mutable schema)
  486. - Has nondeterministic/random behavior that mutates RNG state
  487. - Is explicitly marked as effectful via torch.library._register_effectful_op
  488. Args:
  489. op: The operator to check (function, OpOverload, HigherOrderOperator, etc.)
  490. args: Optional arguments that would be passed to the callable
  491. kwargs: Optional keyword arguments that would be passed to the callable
  492. impure_random: Whether to treat random operations as impure (default: True)
  493. Returns:
  494. bool: True if the callable has side effects, False otherwise
  495. """
  496. # Import here to avoid circular dependencies
  497. from torch._higher_order_ops.effects import _get_effect
  498. from torch.fx.node import _side_effectful_functions
  499. if isinstance(op, torch._ops.OpOverload):
  500. schema = getattr(op, "_schema", None)
  501. if schema is not None and schema.is_mutable:
  502. return True
  503. if op in _side_effectful_functions:
  504. return True
  505. if _get_effect(op) is not None:
  506. return True
  507. if isinstance(op, torch._ops.HigherOrderOperator):
  508. if op in (
  509. torch.ops.higher_order.auto_functionalized,
  510. torch.ops.higher_order.auto_functionalized_v2,
  511. ):
  512. # Check if the auto-functionalized operator (the first argument) is
  513. # side-effectful
  514. if args and len(args) > 0:
  515. return args[0] in _side_effectful_functions
  516. if _get_effect(op) is not None:
  517. return True
  518. return False
  519. # Impure since it mutates RNG state
  520. if impure_random and getattr(op, "_nondeterministic_seeded", False):
  521. return True
  522. # Handle Python random functions that don't have _nondeterministic_seeded
  523. # but still affect global RNG state (issue #151524)
  524. # These should be impure regardless of impure_random setting to maintain
  525. # consistency between eager and compiled execution
  526. # All random operations are impure to ensure consistent behavior
  527. # between eager and compiled execution, regardless of generator usage
  528. if op in _RANDOM_FUNCTIONS:
  529. return True
  530. schema = getattr(op, "_schema", None)
  531. if schema is not None and schema.is_mutable:
  532. return True
  533. if op in _side_effectful_functions:
  534. return True
  535. return False