normalize.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. # mypy: allow-untyped-defs
  2. import operator
  3. from typing import Any, Callable, Optional
  4. import torch
  5. import torch.fx
  6. import torch.fx as fx
  7. from torch.fx import Proxy, Transformer
  8. from torch.fx.node import Argument, map_aggregate, Node, Target
  9. from torch.fx.operator_schemas import (
  10. create_type_hint,
  11. normalize_function,
  12. normalize_module,
  13. )
  14. from .schema_type_annotation import AnnotateTypesWithSchema
  15. class NormalizeArgs(Transformer):
  16. """
  17. Normalize arguments to Python targets. This means that
  18. `args/kwargs` will be matched up to the module/functional's
  19. signature and rewritten to exclusively kwargs in positional order
  20. if `normalize_to_only_use_kwargs` is true. Also populates default
  21. values. Does not support positional-only parameters or varargs
  22. parameters (*args, **kwargs).
  23. If the nodes have 'type' metadata, it will use it to disambiguate
  24. overloads. Otherwise, it will throw an error.
  25. Example usage:
  26. m = torchvision.models.resnet18()
  27. traced = torch.fx.symbolic_trace(m)
  28. traced = NormalizeArgs(traced).transform()
  29. """
  30. def __init__(
  31. self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True
  32. ):
  33. super().__init__(module)
  34. self.node_map: dict[Proxy, Node] = {}
  35. self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs
  36. def run_node(self, n: Node) -> Any:
  37. args, kwargs = self.fetch_args_kwargs_from_env(n)
  38. def get_type(arg):
  39. if isinstance(arg, fx.Node):
  40. return n.meta["type"] if "type" in n.meta else None
  41. return type(arg)
  42. arg_types = map_aggregate(n.args, get_type)
  43. assert isinstance(arg_types, tuple)
  44. arg_types = tuple([create_type_hint(i) for i in arg_types])
  45. kwarg_types = {k: get_type(v) for k, v in kwargs.items()}
  46. if n.op == "call_function":
  47. out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types)
  48. else:
  49. out = super().run_node(n)
  50. if n.op != "output":
  51. self.node_map[out] = n
  52. out.node.meta = n.meta
  53. out.node.type = n.type
  54. return out
  55. def call_function(
  56. self,
  57. target: Target,
  58. args: tuple[Argument, ...],
  59. kwargs: dict[str, Any],
  60. arg_types: Optional[tuple[Any, ...]] = None,
  61. kwarg_types: Optional[dict[str, Any]] = None,
  62. ):
  63. assert callable(target)
  64. new_args_and_kwargs = normalize_function(
  65. target,
  66. args, # type: ignore[arg-type]
  67. kwargs,
  68. arg_types, # type: ignore[arg-type]
  69. kwarg_types,
  70. self.normalize_to_only_use_kwargs,
  71. )
  72. if new_args_and_kwargs:
  73. new_args, new_kwargs = new_args_and_kwargs
  74. return self.tracer.create_proxy(
  75. "call_function", target, new_args, new_kwargs
  76. )
  77. else:
  78. return super().call_function(target, args, kwargs)
  79. def call_module(
  80. self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
  81. ):
  82. assert isinstance(target, str)
  83. new_args_and_kwargs = normalize_module(
  84. self.module,
  85. target,
  86. args, # type: ignore[arg-type]
  87. kwargs,
  88. self.normalize_to_only_use_kwargs,
  89. )
  90. if new_args_and_kwargs:
  91. new_args, new_kwargs = new_args_and_kwargs
  92. return super().call_module(target, new_args, new_kwargs)
  93. else:
  94. return super().call_module(target, args, kwargs)
  95. class NormalizeOperators(AnnotateTypesWithSchema):
  96. """
  97. Normalize callsites that are different ways of "spelling" the same
  98. invocation into a single, canonical call. Currently supports:
  99. 1. Normalize operators (e.g. operator.add) to the `torch` ops they
  100. ultimately invoke (e.g. torch.add) when it is possible to statically
  101. reason that
  102. Example usage:
  103. m = torchvision.models.resnet18()
  104. traced = torch.fx.symbolic_trace(m)
  105. traced = NormalizeOperators(traced).transform()
  106. """
  107. binary_magic_method_remap: dict[
  108. Callable[[Any, Any], Any], Callable[[Any, Any], Any]
  109. ] = {
  110. torch.add: operator.add,
  111. torch.mul: operator.mul,
  112. torch.sub: operator.sub,
  113. torch.div: operator.truediv,
  114. torch.floor_divide: operator.floordiv,
  115. torch.remainder: operator.mod,
  116. torch.eq: operator.eq,
  117. torch.ne: operator.ne,
  118. torch.lt: operator.lt,
  119. torch.le: operator.le,
  120. torch.gt: operator.gt,
  121. torch.ge: operator.ge,
  122. }
  123. def call_function(
  124. self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
  125. ):
  126. # Normalize operators according to the magic methods implemented on tensors here:
  127. # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950
  128. assert callable(target)
  129. if target in self.binary_magic_method_remap:
  130. if len(args) != 2:
  131. return super().call_function(target, args, kwargs)
  132. lhs, rhs = args
  133. return super().call_function(
  134. target=self.binary_magic_method_remap[target],
  135. args=(lhs, rhs),
  136. kwargs={},
  137. )
  138. return super().call_function(target, args, kwargs)