meta_tracer.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. # mypy: allow-untyped-defs
  2. import builtins
  3. import functools
  4. import warnings
  5. from typing import Any, Callable, Optional, Union
  6. import torch
  7. import torch.fx
  8. def embedding_override(self, input):
  9. return torch.empty(*input.shape, self.weight.shape[-1], device="meta")
  10. def nn_layernorm_override(self, input):
  11. return input
  12. def torch_relu_override(x):
  13. return x
  14. def torch_nn_relu_override(self, x):
  15. return x
  16. def functional_relu_override(x, inplace=False):
  17. assert not inplace, "dont support inplace functional.relu for metatensor analysis"
  18. return x
  19. def torch_where_override(condition, x, y):
  20. # torch.where returns the broadcasted tensor of condition, x, and y,
  21. # so hack it by using addition
  22. return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
  23. def torch_abs_override(input, *, out=None):
  24. assert out is None, "Dont support in-place abs for MetaTensor analysis"
  25. return input
  26. manual_meta_overrides: dict[Callable, Callable] = {
  27. torch.nn.Embedding: embedding_override,
  28. torch.nn.LayerNorm: nn_layernorm_override,
  29. torch.relu: torch_relu_override,
  30. torch.nn.functional.relu: functional_relu_override,
  31. torch.nn.ReLU: torch_nn_relu_override,
  32. torch.where: torch_where_override,
  33. torch.abs: torch_abs_override,
  34. }
  35. def gen_constructor_wrapper(target):
  36. @functools.wraps(target)
  37. def wrapper(*args, **kwargs):
  38. proxy = None
  39. def check_has_proxy(v):
  40. if isinstance(v, torch.fx.Proxy):
  41. nonlocal proxy
  42. proxy = v
  43. torch.fx.node.map_aggregate(args, check_has_proxy)
  44. torch.fx.node.map_aggregate(kwargs, check_has_proxy)
  45. if proxy is not None:
  46. return proxy.tracer.create_proxy("call_function", target, args, kwargs)
  47. else:
  48. return target(*args, **kwargs)
  49. return wrapper, target
  50. class MetaProxy(torch.fx.Proxy):
  51. def install_tensor_meta(self, tensor_meta):
  52. self._tensor_meta = tensor_meta
  53. def size(self, dim=None):
  54. if hasattr(self, "_tensor_meta") and self._tensor_meta is not None:
  55. return self._tensor_meta.size(*[dim] if dim else [])
  56. return self.tracer.create_proxy(
  57. "call_method", "size", (self, dim) if dim else (self,), {}
  58. )
  59. def dim(self):
  60. if hasattr(self, "_tensor_meta") and self._tensor_meta is not None:
  61. return self._tensor_meta.dim()
  62. return self.tracer.create_proxy("call_method", "dim", (self,), {})
  63. @property
  64. def shape(self):
  65. if hasattr(self, "_tensor_meta") and self._tensor_meta is not None:
  66. return self._tensor_meta.shape
  67. return self.tracer.create_proxy(
  68. "call_function", builtins.getattr, (self, "shape"), {}
  69. )
  70. @property
  71. def dtype(self):
  72. if hasattr(self, "_tensor_meta") and self._tensor_meta is not None:
  73. return self._tensor_meta.dtype
  74. return self.tracer.create_proxy(
  75. "call_function", builtins.getattr, (self, "dtype"), {}
  76. )
  77. @property
  78. def device(self):
  79. # Hack so we can track when devices are used. During meta-tensor propagation,
  80. # replace these values with a constant 'meta'
  81. return MetaDeviceAttribute(self, "device")
  82. def __getattr__(self, k):
  83. if k == "_tensor_meta":
  84. return self.__getattribute__(k)
  85. # note: not added to the graph yet, if this is a method call
  86. # we peephole optimize to the method invocation
  87. return MetaAttribute(self, k)
  88. class MetaAttribute(MetaProxy):
  89. def __init__(self, root, attr: str):
  90. self.root = root
  91. self.attr = attr
  92. self.tracer = root.tracer
  93. self._node = None
  94. @property
  95. def node(self): # type: ignore[override]
  96. # the node for attributes is added lazily, since most will just be method calls
  97. # which do not rely on the getitem call
  98. if self._node is None:
  99. self._node = self.tracer.create_proxy(
  100. "call_function", getattr, (self.root, self.attr), {}
  101. ).node
  102. return self._node
  103. def __call__(self, *args, **kwargs):
  104. return self.tracer.create_proxy(
  105. "call_method", self.attr, (self.root,) + args, kwargs
  106. )
  107. class MetaDeviceAttribute(MetaAttribute):
  108. pass
  109. def proxys_to_metas(v):
  110. if isinstance(v, MetaDeviceAttribute):
  111. return "meta"
  112. if isinstance(v, torch.fx.Proxy):
  113. assert isinstance(v, MetaProxy), f"Expected MetaProxy but got {type(v)}"
  114. assert hasattr(v, "_tensor_meta"), "MetaProxy does not have an associated meta"
  115. return v._tensor_meta
  116. return v
  117. class MetaTracer(torch.fx.Tracer):
  118. allow_insert_stateless_mods: bool = True
  119. _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"]
  120. def create_proxy(
  121. self,
  122. kind,
  123. target,
  124. args,
  125. kwargs,
  126. name=None,
  127. type_expr=None,
  128. proxy_factory_fn=None,
  129. ):
  130. rv = super().create_proxy(
  131. kind, target, args, kwargs, name, type_expr, proxy_factory_fn
  132. )
  133. if kind == "placeholder" and target in self.meta_args:
  134. rv.install_tensor_meta(self.meta_args[target])
  135. return rv
  136. if target in self.orig_fns:
  137. # NOTE: tensor constructors in PyTorch define the `device` argument as
  138. # *kwargs-only*. That is why this works. If you add methods to
  139. # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
  140. # this will break and you will likely see issues where we cannot infer
  141. # the size of the output.
  142. if "device" in kwargs:
  143. kwargs["device"] = "meta"
  144. try:
  145. args_metas = torch.fx.node.map_aggregate(args, proxys_to_metas)
  146. kwargs_metas = torch.fx.node.map_aggregate(kwargs, proxys_to_metas)
  147. if kind == "call_function":
  148. meta_target = manual_meta_overrides.get(target, target)
  149. meta_out = meta_target(*args_metas, **kwargs_metas)
  150. elif kind == "call_method":
  151. meta_target = getattr(args_metas[0], target) # type: ignore[index]
  152. meta_out = meta_target(*args_metas[1:], **kwargs_metas) # type: ignore[index]
  153. elif kind == "call_module":
  154. assert hasattr(self, "orig_forward")
  155. self._disable_module_getattr = True
  156. try:
  157. mod = self.root.get_submodule(target)
  158. mod_type = type(mod)
  159. if mod_type in manual_meta_overrides:
  160. meta_out = manual_meta_overrides[mod_type](
  161. mod, *args_metas, **kwargs_metas
  162. ) # type: ignore[misc, arg-type]
  163. else:
  164. meta_out = self.orig_forward(*args_metas, **kwargs_metas)
  165. finally:
  166. self._disable_module_getattr = False
  167. elif kind == "get_attr":
  168. self._disable_module_getattr = True
  169. try:
  170. attr_itr = self.root
  171. atoms = target.split(".")
  172. for atom in atoms:
  173. attr_itr = getattr(attr_itr, atom)
  174. assert isinstance(attr_itr, torch.Tensor)
  175. meta_out = attr_itr.to(device="meta")
  176. finally:
  177. self._disable_module_getattr = False
  178. else:
  179. return rv
  180. # TODO
  181. assert isinstance(rv, torch.fx.Proxy), "Dont support composite output yet"
  182. rv.install_tensor_meta(meta_out)
  183. except Exception as e:
  184. warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
  185. return rv
  186. def getattr(self, attr, attr_val, parameter_proxy_cache):
  187. if getattr(self, "_disable_module_getattr", False):
  188. return attr_val
  189. else:
  190. return super().getattr(attr, attr_val, parameter_proxy_cache)
  191. def call_module(self, m, forward, args, kwargs):
  192. self.orig_forward = forward
  193. return super().call_module(m, forward, args, kwargs)
  194. def _insert_module_as_submodule(self, mod: torch.nn.Module) -> str:
  195. """
  196. Helper method which tries to insert a module that was not declared as submodule.
  197. """
  198. idx = 0
  199. mod_name = mod.__class__.__name__.lower()
  200. path = f"{mod_name}_{idx}"
  201. while hasattr(self.root, path):
  202. path = f"{mod_name}_{idx}"
  203. idx += 1
  204. self.root.add_module(path, mod)
  205. return path
  206. def path_of_module(self, mod: torch.nn.Module) -> str:
  207. try:
  208. return super().path_of_module(mod)
  209. except NameError:
  210. if (
  211. self.allow_insert_stateless_mods
  212. and len(list(mod.parameters())) == 0
  213. and len(list(mod.buffers())) == 0
  214. ):
  215. path = self._insert_module_as_submodule(mod)
  216. self.prev_module = path
  217. return path
  218. raise
  219. def proxy(self, node):
  220. return MetaProxy(node, self)
  221. def trace(self, root, meta_args: dict[str, torch.Tensor], concrete_args=None): # type: ignore[override]
  222. assert isinstance(meta_args, dict)
  223. self.meta_args = meta_args
  224. self.patched_torch_methods = {
  225. target: gen_constructor_wrapper(getattr(torch, target))
  226. for target in self._TORCH_METHODS_TO_PATCH
  227. }
  228. self.orig_fns = set()
  229. for name, (wrapper, orig) in self.patched_torch_methods.items():
  230. setattr(torch, name, wrapper)
  231. self.orig_fns.add(orig)
  232. try:
  233. graph = super().trace(root, concrete_args)
  234. graph._tracer_extras = {"meta_args": meta_args}
  235. return graph
  236. finally:
  237. for name, (_, orig) in self.patched_torch_methods.items():
  238. setattr(torch, name, orig)
  239. def symbolic_trace(
  240. root: Union[torch.nn.Module, Callable[..., Any]],
  241. meta_args: Optional[dict[str, torch.Tensor]] = None,
  242. concrete_args: Optional[dict[str, Any]] = None,
  243. ) -> torch.fx.GraphModule:
  244. tracer = MetaTracer()
  245. graph = tracer.trace(root, meta_args, concrete_args) # type: ignore[arg-type]
  246. name = (
  247. root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
  248. )
  249. gm = torch.fx.GraphModule(tracer.root, graph, name)
  250. return gm