# operator_schemas.py
# mypy: allow-untyped-defs
import enum
import inspect
import numbers
import types
import typing
import warnings
from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING

import torch
from torch._jit_internal import boolean_dispatched
from torch._ops import OpOverload, OpOverloadPacket

from ._compatibility import compatibility


if TYPE_CHECKING:
    from .node import Argument
  15. __all__ = [
  16. "ArgsKwargsPair",
  17. "check_for_mutable_operation",
  18. "get_signature_for_torch_op",
  19. "create_type_hint",
  20. "type_matches",
  21. "normalize_function",
  22. "normalize_module",
  23. ]
  24. @compatibility(is_backward_compatible=False)
  25. class ArgsKwargsPair(NamedTuple):
  26. """
  27. Simple named tuple for wrapping args/kwargs pairs.
  28. """
  29. args: tuple[Any, ...]
  30. kwargs: dict[str, Any]
  31. _manual_overrides: dict[Callable, list[inspect.Signature]] = {}
  32. def _nonzero_schemas():
  33. signatures = []
  34. def nonzero(self):
  35. pass
  36. signatures.append(inspect.signature(nonzero))
  37. def nonzero(self, *, as_tuple: bool): # type: ignore[no-redef]
  38. pass
  39. signatures.append(inspect.signature(nonzero))
  40. return signatures
  41. _manual_overrides[torch.nonzero] = _nonzero_schemas()
  42. class _FakeGlobalNamespace:
  43. def __getattr__(self, name):
  44. if name == "torch":
  45. return torch
  46. raise RuntimeError("Expected a torch namespace lookup")
  47. _type_eval_globals = {
  48. "Tensor": torch.Tensor,
  49. "Device": torch.device,
  50. "Layout": torch.layout,
  51. "number": numbers.Number,
  52. "Future": torch.jit.Future,
  53. "AnyEnumType": enum.Enum,
  54. "QScheme": torch.qscheme,
  55. "__torch__": _FakeGlobalNamespace(),
  56. "NoneType": type(None),
  57. "Storage": torch.UntypedStorage,
  58. "t": typing.TypeVar("t"),
  59. }
  60. for k in dir(typing):
  61. _type_eval_globals[k] = getattr(typing, k)
  62. def _torchscript_type_to_python_type(ts_type: "torch._C.JitType") -> Any:
  63. """
  64. Convert a TorchScript type to a Python type (including subtypes) via
  65. eval'ing the annotation_str. _type_eval_globals sets up expressions
  66. like "List" and "Future" to map to actual types (typing.List and jit.Future)
  67. """
  68. return eval(ts_type.annotation_str, _type_eval_globals)
  69. def _torchscript_schema_to_signature_impl(
  70. ts_schema: torch._C.FunctionSchema,
  71. ) -> inspect.Signature:
  72. from inspect import Parameter
  73. parameters: list[Parameter] = []
  74. for arg in ts_schema.arguments:
  75. arg_type = _torchscript_type_to_python_type(arg.type)
  76. default = arg.default_value if arg.has_default_value() else Parameter.empty
  77. # TODO: Figure out if this is safe. It seems like when generating the type signatures for
  78. # PythonArgParser, we emit signatures with `input` instead of `self` as the first tensor
  79. # argument name. Downstream, if someone converts that positional argument to a keyword
  80. # argument, the name mismatch will break things, so here we're going to normalize the
  81. # name to "input"
  82. name = arg.name if arg.name != "self" else "input"
  83. kind = (
  84. Parameter.KEYWORD_ONLY
  85. if arg.kwarg_only
  86. else Parameter.POSITIONAL_OR_KEYWORD
  87. )
  88. # "from" is a keyword therefore it must be a POSITIONAL_ONLY argument
  89. if name == "from":
  90. assert kind == Parameter.POSITIONAL_OR_KEYWORD
  91. # ParameterKind type is internal implementation detail to inspec package
  92. # which makes it hard to do type annotation
  93. kind = Parameter.POSITIONAL_ONLY # type: ignore[assignment]
  94. # This renders all previous arguments to positional only
  95. for idx, p in enumerate(parameters):
  96. assert p.kind == Parameter.POSITIONAL_OR_KEYWORD
  97. parameters[idx] = Parameter(
  98. name=p.name,
  99. kind=Parameter.POSITIONAL_ONLY,
  100. default=p.default,
  101. annotation=p.annotation,
  102. )
  103. parameters.append(
  104. Parameter(name=name, kind=kind, default=default, annotation=arg_type)
  105. )
  106. return_types = [
  107. _torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns
  108. ]
  109. if len(return_types) == 0:
  110. return_type = None
  111. elif len(return_types) == 1:
  112. return_type = return_types[0]
  113. else:
  114. return_type = tuple(return_types)
  115. return inspect.Signature(parameters, return_annotation=return_type)
  116. _SCHEMA_TO_SIGNATURE_CACHE: dict[tuple[str, str], inspect.Signature] = {}
  117. def _torchscript_schema_to_signature(
  118. ts_schema: torch._C.FunctionSchema,
  119. ) -> inspect.Signature:
  120. # Cached as it's called in the hot path of FakeTensor dispatch
  121. cache_key = ts_schema.name, ts_schema.overload_name
  122. cache_val = _SCHEMA_TO_SIGNATURE_CACHE.get(cache_key)
  123. if cache_val is not None:
  124. return cache_val
  125. res = _torchscript_schema_to_signature_impl(ts_schema)
  126. _SCHEMA_TO_SIGNATURE_CACHE[cache_key] = res
  127. return res
  128. @compatibility(is_backward_compatible=False)
  129. def check_for_mutable_operation(
  130. target: Callable, args: tuple["Argument", ...], kwargs: dict[str, "Argument"]
  131. ):
  132. signatures, schemas = get_signature_for_torch_op(target, return_schemas=True)
  133. if signatures and schemas:
  134. matched_schemas = []
  135. # Iterate through all of the schema until we find one that matches
  136. # If one matches, populate `new_args_and_kwargs` with the new args/kwargs
  137. # values. If none matches, `new_args_and_kwargs` will be None
  138. for candidate_signature, schema in zip(signatures, schemas):
  139. try:
  140. candidate_signature.bind(*args, **kwargs)
  141. matched_schemas.append((candidate_signature, schema))
  142. except TypeError:
  143. continue
  144. def throw_if_mutable(schema):
  145. if schema.is_mutable:
  146. raise RuntimeError(
  147. f"Tried to trace mutable operation {schema}. FX only supports functional "
  148. f"code, so operations that mutate operands in-place (e.g. via `out` arguments) "
  149. f"are not supported"
  150. )
  151. if len(matched_schemas) == 0:
  152. # Did not match any schema. Cannot check for mutation
  153. pass
  154. elif len(matched_schemas) == 1:
  155. # Matched exactly one schema, unambiguous
  156. _, schema_to_check = matched_schemas[0]
  157. throw_if_mutable(schema_to_check)
  158. else:
  159. # Ambiguous schema match. Since mutability checking is best effort,
  160. # do nothing.
  161. pass
  162. @compatibility(is_backward_compatible=False)
  163. def get_signature_for_torch_op(op: Callable, return_schemas: bool = False):
  164. """
  165. Given an operator on the `torch` namespace, return a list of `inspect.Signature`
  166. objects corresponding to the overloads of that op.. May return `None` if a signature
  167. could not be retrieved.
  168. Args:
  169. op (Callable): An operator on the `torch` namespace to look up a signature for
  170. Returns:
  171. Optional[List[inspect.Signature]]: A list of signatures for the overloads of this
  172. operator, or None if the operator signatures could not be retrieved. If
  173. return_schemas=True, returns a tuple containing the optional Python signatures
  174. and the optional TorchScript Function signature
  175. """
  176. if isinstance(op, OpOverload):
  177. schemas = [op._schema]
  178. elif isinstance(op, OpOverloadPacket):
  179. schemas = [getattr(op, overload)._schema for overload in op.overloads()]
  180. else:
  181. override = _manual_overrides.get(op)
  182. if override:
  183. return (override, None) if return_schemas else None
  184. aten_fn = torch.jit._builtins._find_builtin(op)
  185. if aten_fn is None:
  186. return (None, None) if return_schemas else None
  187. schemas = torch._C._jit_get_schemas_for_operator(aten_fn)
  188. signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
  189. return (signatures, schemas) if return_schemas else signatures
  190. @compatibility(is_backward_compatible=False)
  191. def create_type_hint(x):
  192. """
  193. Produces a type hint for the given argument.
  194. The :func:`create_type_hint` looks for a type hint compatible with the input argument `x`.
  195. If `x` is a `list` or `tuple`, it looks for an object in the list whose type is a superclass
  196. of the rest, and uses that as `base_type` for the `List` or `Tuple` to be returned.
  197. If no such object is found, it defaults to `List[Any]`.
  198. If `x` is neither a `list` nor a `tuple`, it returns `x`.
  199. """
  200. try:
  201. if isinstance(x, (list, tuple)):
  202. # todo(chilli): Figure out the right way for mypy to handle this
  203. if isinstance(x, list):
  204. def ret_type(x):
  205. return list[x] # type: ignore[valid-type]
  206. else:
  207. def ret_type(x):
  208. return tuple[x, ...] # type: ignore[valid-type]
  209. if len(x) == 0:
  210. return ret_type(Any)
  211. base_type = x[0]
  212. for t in x:
  213. if issubclass(t, base_type):
  214. continue
  215. elif issubclass(base_type, t):
  216. base_type = t
  217. else:
  218. return ret_type(Any)
  219. return ret_type(base_type)
  220. except Exception:
  221. # We tried to create a type hint for list but failed.
  222. warnings.warn(
  223. f"We were not able to successfully create type hint from the type {x}"
  224. )
  225. return x
  226. @compatibility(is_backward_compatible=False)
  227. def type_matches(signature_type: Any, argument_type: Any):
  228. sig_origin_type = getattr(signature_type, "__origin__", signature_type)
  229. if signature_type is argument_type:
  230. return True
  231. # Union types in signature. Given type needs to match one of the
  232. # contained types in the Union
  233. if sig_origin_type is typing.Union and signature_type != argument_type:
  234. sig_contained = signature_type.__args__
  235. return any(type_matches(c, argument_type) for c in sig_contained)
  236. if getattr(signature_type, "__origin__", None) is list:
  237. sig_el_type = signature_type.__args__[0]
  238. # int can be promoted to list[int]
  239. if argument_type is int and sig_el_type is int:
  240. return True
  241. if not inspect.isclass(sig_el_type):
  242. warnings.warn(
  243. f"Does not support nested parametric types, got {signature_type}. Please file a bug."
  244. )
  245. return False
  246. if getattr(argument_type, "__origin__", None) is list:
  247. return issubclass(argument_type.__args__[0], sig_el_type)
  248. def is_homogeneous_tuple(t):
  249. if getattr(t, "__origin__", None) is not tuple:
  250. return False
  251. contained = t.__args__
  252. if t.__args__ == ((),): # Tuple[()].__args__ == ((),) for some reason
  253. return True
  254. return all((c is Ellipsis) or issubclass(c, sig_el_type) for c in contained)
  255. # Tuple[T] is accepted for List[T] parameters
  256. return is_homogeneous_tuple(argument_type)
  257. # Dtype is an int in schemas
  258. if signature_type is int and argument_type is torch.dtype:
  259. return True
  260. if signature_type is numbers.Number and argument_type in {int, float}:
  261. return True
  262. if inspect.isclass(argument_type) and inspect.isclass(signature_type):
  263. return issubclass(argument_type, signature_type)
  264. return False
  265. @compatibility(is_backward_compatible=False)
  266. def normalize_function(
  267. target: Callable,
  268. args: tuple[Any, ...],
  269. kwargs: Optional[dict[str, Any]] = None,
  270. arg_types: Optional[tuple[Any]] = None,
  271. kwarg_types: Optional[dict[str, Any]] = None,
  272. normalize_to_only_use_kwargs: bool = False,
  273. ) -> Optional[ArgsKwargsPair]:
  274. """
  275. Returns normalized arguments to PyTorch functions. This means that
  276. `args/kwargs` will be matched up to the functional's
  277. signature and return exclusively kwargs in positional order if
  278. `normalize_to_only_use_kwargs` is True.
  279. Also populates default values. Does not support positional-only
  280. parameters or varargs parameters (*args, **kwargs). Does not support modules.
  281. May require `arg_types` and `kwarg_types` in order to disambiguate overloads.
  282. Args:
  283. target (Callable): Function that we are normalizing
  284. args (Tuple[Any]): Tuple of args to the function
  285. kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function
  286. arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args
  287. kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs
  288. normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
  289. Returns:
  290. Returns normalized_args_and_kwargs, or `None` if not successful.
  291. """
  292. if kwargs is None:
  293. kwargs = {}
  294. new_args_and_kwargs = None
  295. if (
  296. not isinstance(target, types.BuiltinFunctionType)
  297. and not (isinstance(target, (OpOverloadPacket, OpOverload)))
  298. and hasattr(target, "_op")
  299. ):
  300. # ExecuTorch's EdgeOpOverload are a wrapper around PyTorch's OpOverload,
  301. # so we can unwrap it here to get its schema
  302. # Can't import EdgeOpOverload directly because of a circular dependency,
  303. # so checking for "_op" existing is the next best thing.
  304. target = target._op
  305. # Repeat the condition after checking for the inner _op field.
  306. if not isinstance(target, types.BuiltinFunctionType) and not (
  307. isinstance(target, (OpOverloadPacket, OpOverload))
  308. ):
  309. target_for_analysis = target
  310. if target in boolean_dispatched:
  311. # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have
  312. # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false`
  313. # branches of the dispatch have exactly the same signature. If they do, use the `true`
  314. # branch signature for analysis. Otherwise, leave this un-normalized
  315. assert not isinstance(target, str)
  316. dispatched = boolean_dispatched[target]
  317. if_true, if_false = dispatched["if_true"], dispatched["if_false"]
  318. if (
  319. inspect.signature(if_true).parameters
  320. != inspect.signature(if_false).parameters
  321. ):
  322. return None
  323. target_for_analysis = if_true
  324. assert callable(target_for_analysis)
  325. sig = inspect.signature(inspect.unwrap(target_for_analysis))
  326. new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(
  327. sig, args, kwargs, normalize_to_only_use_kwargs
  328. )
  329. else:
  330. assert callable(target)
  331. torch_op_schemas = get_signature_for_torch_op(target)
  332. matched_schemas = []
  333. if torch_op_schemas:
  334. # Iterate through all of the schema until we find one that matches
  335. # If one matches, populate `new_args_and_kwargs` with the new args/kwargs
  336. # values. If none matches, `new_args_and_kwargs` will be None
  337. for candidate_signature in torch_op_schemas:
  338. try:
  339. candidate_signature.bind(*args, **kwargs)
  340. matched_schemas.append(candidate_signature)
  341. except TypeError:
  342. continue
  343. if len(matched_schemas) == 0:
  344. # Did not match any schema. Cannot normalize
  345. pass
  346. elif len(matched_schemas) == 1:
  347. # Matched exactly one schema, unambiguous
  348. new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(
  349. matched_schemas[0], args, kwargs, normalize_to_only_use_kwargs
  350. )
  351. else:
  352. if arg_types is not None or kwarg_types is not None:
  353. arg_types = arg_types if arg_types else cast(tuple[Any], ())
  354. kwarg_types = kwarg_types if kwarg_types else {}
  355. for candidate_signature in torch_op_schemas:
  356. sig_matches = True
  357. try:
  358. bound_types = candidate_signature.bind(
  359. *arg_types, **kwarg_types
  360. )
  361. for arg_name, arg_type in bound_types.arguments.items():
  362. param = candidate_signature.parameters[arg_name]
  363. sig_matches = sig_matches and type_matches(
  364. param.annotation, arg_type
  365. )
  366. except TypeError:
  367. sig_matches = False
  368. if sig_matches:
  369. new_args_and_kwargs = (
  370. _args_kwargs_to_normalized_args_kwargs(
  371. candidate_signature,
  372. args,
  373. kwargs,
  374. normalize_to_only_use_kwargs,
  375. )
  376. )
  377. break
  378. else:
  379. # Matched more than one schema. In this situation, the caller must provide the types of
  380. # the arguments of the overload they expect.
  381. schema_printouts = "\n".join(
  382. str(schema) for schema in matched_schemas
  383. )
  384. raise RuntimeError(
  385. f"Tried to normalize arguments to {torch.typename(target)} but "
  386. f"the schema match was ambiguous! Please provide argument types to "
  387. f"the normalize_arguments() call. Available schemas:\n{schema_printouts}"
  388. )
  389. return new_args_and_kwargs
  390. @compatibility(is_backward_compatible=False)
  391. def normalize_module(
  392. root: torch.nn.Module,
  393. target: str,
  394. args: tuple[Any],
  395. kwargs: Optional[dict[str, Any]] = None,
  396. normalize_to_only_use_kwargs: bool = False,
  397. ) -> Optional[ArgsKwargsPair]:
  398. """
  399. Returns normalized arguments to PyTorch modules. This means that
  400. `args/kwargs` will be matched up to the functional's
  401. signature and return exclusively kwargs in positional order if
  402. `normalize_to_only_use_kwargs` is True.
  403. Also populates default values. Does not support positional-only
  404. parameters or varargs parameters (*args, **kwargs).
  405. Args:
  406. root (nn.Module): root module upon which we query modules
  407. target (Callable): Function that we are normalizing
  408. args (Tuple[Any]): Tuple of args to the function
  409. kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function
  410. normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
  411. Returns:
  412. Returns normalized_args_and_kwargs, or `None` if not successful.
  413. """
  414. try:
  415. submod = root.get_submodule(target)
  416. except AttributeError as e:
  417. raise RuntimeError(
  418. f"Tried to normalize node with target {target} but root did not "
  419. f"have that target!"
  420. ) from e
  421. if hasattr(submod.__class__, "__name__"):
  422. classname = submod.__class__.__name__
  423. if getattr(torch.nn, classname, None) == submod.__class__:
  424. sig = inspect.signature(inspect.unwrap(submod.forward))
  425. if kwargs is None:
  426. kwargs = {}
  427. new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(
  428. sig, args, kwargs, normalize_to_only_use_kwargs
  429. )
  430. return new_args_and_kwargs
  431. return None
  432. def _args_kwargs_to_normalized_args_kwargs(
  433. sig: inspect.Signature,
  434. args: tuple[Any, ...],
  435. kwargs: dict[str, Any],
  436. normalize_to_only_use_kwargs: bool,
  437. ) -> Optional[ArgsKwargsPair]:
  438. """
  439. Given a call target, args, and kwargs, return the arguments normalized into
  440. an ArgsKwargsPair, or None if the type signature is not supported by
  441. this normalization.
  442. Args:
  443. sig (inspect.Signature): Signature object for the target
  444. args (Tuple): Arguments that appear at the callsite for `target`
  445. kwargs (Dict): Keyword arguments that appear at the callsite for `target`
  446. normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
  447. Returns:
  448. Optional[ArgsKwargsPair]: Normalized args and kwargs for `target`, or `None` if
  449. this target is not supported.
  450. """
  451. # Don't currently support positional-only
  452. # or varargs (*args, **kwargs) signatures
  453. supported_parameter_types = {
  454. inspect.Parameter.POSITIONAL_OR_KEYWORD,
  455. inspect.Parameter.KEYWORD_ONLY,
  456. }
  457. if any(p.kind not in supported_parameter_types for p in sig.parameters.values()):
  458. # Add an exception for one signature, which is common for random/uniform, i.e.:
  459. # Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None
  460. # `from` is Python keyword and as such functions with that signature should have
  461. # positional-only args, but at the same time they could be dispatched as kwargs
  462. if list(sig.parameters.keys()) != ["input", "from", "to", "generator"]:
  463. return None
  464. bound_args = sig.bind(*args, **kwargs)
  465. bound_args.apply_defaults()
  466. new_kwargs: dict[str, Any] = {}
  467. new_args: list[Any] = []
  468. for i, param in enumerate(sig.parameters):
  469. if not normalize_to_only_use_kwargs and i < len(args):
  470. new_args.append(bound_args.arguments[param])
  471. else:
  472. new_kwargs[param] = bound_args.arguments[param]
  473. return ArgsKwargsPair(tuple(new_args), new_kwargs)