fake_utils.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. # mypy: ignore-errors
  2. import functools
  3. import warnings
  4. from typing import Any, Callable, Union
  5. import torch
  6. import torch.utils._pytree as pytree
  7. from torch._ops import OpOverload
  8. from torch._subclasses.fake_tensor import (
  9. FakeTensor,
  10. FakeTensorMode,
  11. MetadataMismatchError,
  12. tree_flatten_only,
  13. UnsupportedFakeTensorException,
  14. )
  15. from torch.utils._python_dispatch import TorchDispatchMode
  16. aten = torch._ops.ops.aten
  17. def outputs_alias_inputs(outputs, inputs):
  18. input_storages = {
  19. inp._typed_storage()._cdata
  20. for inp in tree_flatten_only(torch.Tensor, inputs)
  21. if torch._C._has_storage(inp)
  22. }
  23. return any(
  24. torch._C._has_storage(out) and out._typed_storage()._cdata in input_storages
  25. for out in tree_flatten_only(torch.Tensor, outputs)
  26. )
  27. def outputs_are_inputs(outputs, inputs):
  28. input_ids = {id(inp) for inp in tree_flatten_only(torch.Tensor, inputs)}
  29. return any(id(out) in input_ids for out in tree_flatten_only(torch.Tensor, outputs))
  30. def output_alias_each_other(outputs):
  31. storages = set()
  32. for out in tree_flatten_only(torch.Tensor, outputs):
  33. if not torch._C._has_storage(out):
  34. continue
  35. stor = out._typed_storage()._cdata
  36. if stor in storages:
  37. return True
  38. storages.add(stor)
  39. return False
  40. def _check_alias_info(context, real_out, real_in, fake_out, fake_in):
  41. r_aliasing = outputs_alias_inputs(real_out, real_in)
  42. f_aliasing = outputs_alias_inputs(fake_out, fake_in)
  43. if r_aliasing != f_aliasing:
  44. raise MetadataMismatchError(
  45. f"{context} mismatch in outputs_alias_inputs check {f_aliasing} != {r_aliasing}"
  46. )
  47. r_identity_eq = outputs_are_inputs(real_out, real_in)
  48. f_identity_eq = outputs_are_inputs(fake_out, fake_in)
  49. if r_identity_eq != f_identity_eq:
  50. raise MetadataMismatchError(
  51. f"{context} mismatch in outputs_are_inputs check {f_identity_eq} != {r_identity_eq}"
  52. )
  53. r_output_alias_each_other = output_alias_each_other(real_out)
  54. f_output_alias_each_other = output_alias_each_other(fake_out)
  55. if r_output_alias_each_other != f_output_alias_each_other:
  56. raise MetadataMismatchError(
  57. f"{context} mismatch in outputs_alias_each_other check "
  58. f"{f_output_alias_each_other} != {r_output_alias_each_other}"
  59. )
  60. def is_sdpa_error(func, idx, e):
  61. if (
  62. (
  63. func is aten._scaled_dot_product_flash_attention.default
  64. or func is aten._flash_attention_forward.default
  65. )
  66. and idx in (6, 7)
  67. and "Devices" in repr(e)
  68. ):
  69. return True
  70. if (
  71. (
  72. func is aten._scaled_dot_product_efficient_attention.default
  73. or func is aten._efficient_attention_forward.default
  74. )
  75. and idx in (2, 3)
  76. and "Devices" in repr(e)
  77. ):
  78. return True
  79. if (
  80. func is aten._scaled_dot_product_cudnn_attention.default
  81. and idx in (6, 7)
  82. and "Devices" in repr(e)
  83. ):
  84. return True
  85. return False
  86. def try_convert_fake_to_real(
  87. ten_list: list[Union[FakeTensor, Any]],
  88. ) -> list[Union[FakeTensor, torch.Tensor, Any]]:
  89. """
  90. Attempt to convert fake tensors to a corresponding real tensor with the correct underlying storage by looking up
  91. the FakeTensorMode meta to real storage mapping. On failure to find the storage mapping, the FakeTensor will
  92. remain in the list.
  93. Note: this is not currently optimized (makes copies of the meta converter internal dictionaries)
  94. """
  95. fake_tensor = next(
  96. (item for item in ten_list if isinstance(item, FakeTensor)), None
  97. )
  98. if fake_tensor is None:
  99. return ten_list
  100. fake_mode = fake_tensor.fake_mode
  101. meta_converter = fake_mode.fake_tensor_converter.meta_converter
  102. desc = meta_converter.describer
  103. storage_to_key = {v: k for k, v in meta_converter.storage_memo.items()}
  104. key_to_real_storage = {v: k for k, v in desc.lookup_storage.items()}
  105. out = []
  106. for t in ten_list:
  107. if not isinstance(t, FakeTensor) or not t.layout == torch.strided:
  108. out.append(t)
  109. continue
  110. key = storage_to_key.get(t.untyped_storage())
  111. real_storage = None if key is None else key_to_real_storage.get(key)
  112. if real_storage is None:
  113. out.append(t)
  114. continue
  115. unhinted = False
  116. def map_symint(s):
  117. nonlocal unhinted
  118. if not isinstance(s, torch.SymInt):
  119. return s
  120. unhinted = unhinted if not unhinted else s.node.has_hint()
  121. return s.node.hint
  122. stor_offset = map_symint(t.storage_offset())
  123. size = [map_symint(s) for s in t.shape]
  124. stride = [map_symint(s) for s in t.stride()]
  125. if unhinted:
  126. out.append(t)
  127. continue
  128. new_tensor = torch.empty(
  129. [],
  130. dtype=t.dtype,
  131. device=t.device,
  132. )
  133. new_tensor.set_(
  134. real_storage,
  135. storage_offset=stor_offset,
  136. size=size,
  137. stride=stride,
  138. )
  139. out.append(new_tensor.clone())
  140. return out
  141. def _check_fake_real_tensors(
  142. real_out: torch.Tensor,
  143. fake_out: FakeTensor,
  144. context="",
  145. sizes=True,
  146. strides=False,
  147. storage_offset=True,
  148. requires_grad=True,
  149. ):
  150. if requires_grad:
  151. if real_out.requires_grad != fake_out.requires_grad:
  152. raise MetadataMismatchError(
  153. f"{context} mismatched requires_grad-ness of outputs. "
  154. f"This usually means that you have added autograd support "
  155. f"for your operator at a dispatch key other than Autograd, "
  156. f"which will lead to problems"
  157. )
  158. if torch._C._has_storage(real_out):
  159. r_offset = real_out.storage_offset()
  160. f_offset = fake_out.storage_offset()
  161. if r_offset != f_offset:
  162. raise MetadataMismatchError(f"{context} mismatched storage offset")
  163. torch._prims.utils.compare_tensor_meta(
  164. real_out,
  165. fake_out,
  166. check_sizes=sizes,
  167. check_strides=strides,
  168. allow_rhs_unbacked=True,
  169. )
  170. class CrossRefFakeMode(TorchDispatchMode):
  171. def __init__(
  172. self,
  173. ignore_op_fn: Union[Callable[[OpOverload], bool], None] = None,
  174. *,
  175. check_strides=True,
  176. check_aliasing=True,
  177. only_check_ops_with_meta=True,
  178. ):
  179. super().__init__()
  180. self.ignore_op_fn = (
  181. ignore_op_fn if ignore_op_fn is not None else lambda fn: False
  182. )
  183. self.check_strides = check_strides
  184. self.check_aliasing = check_aliasing
  185. self.only_check_ops_with_meta = only_check_ops_with_meta
  186. def __torch_dispatch__(self, func, types, args=(), kwargs=None):
  187. kwargs = kwargs or {}
  188. fake_r = None
  189. # empty_like excluded for now due to sparse complex
  190. # aten._to_dense.default this one is getting called with csc
  191. if (
  192. func
  193. not in (
  194. aten.lift_fresh.default,
  195. aten.lift_fresh_copy.default,
  196. aten.set_.source_Storage_storage_offset,
  197. )
  198. and not self.ignore_op_fn(func)
  199. and (
  200. not self.only_check_ops_with_meta
  201. or torch._subclasses.fake_impls.has_meta(func)
  202. )
  203. and torch.Tag.dynamic_output_shape not in func.tags
  204. and torch.Tag.inplace_view not in func.tags
  205. and torch.Tag.data_dependent_output not in func.tags
  206. ):
  207. # Do not import symbolic_shapes at the top of the module as it imports sympy and that's slow
  208. from torch.fx.experimental.symbolic_shapes import ShapeEnv
  209. try:
  210. # TODO: enable_python_dispatcher() here
  211. with FakeTensorMode(shape_env=ShapeEnv()) as fake_mode:
  212. fake_args, fake_kwargs = pytree.tree_map_only(
  213. torch.Tensor,
  214. functools.partial(fake_mode.from_tensor, static_shapes=True),
  215. (args, kwargs),
  216. )
  217. with warnings.catch_warnings():
  218. fake_r = func(*fake_args, **fake_kwargs)
  219. except UnsupportedFakeTensorException:
  220. pass
  221. context = (
  222. f"When comparing the output of {func} on FakeTensor and concrete Tensors, "
  223. f"found"
  224. )
  225. r = func(*args, **kwargs)
  226. if fake_r is not None:
  227. r_flat = pytree.tree_leaves(r)
  228. f_flat = pytree.tree_leaves(fake_r)
  229. assert len(f_flat) == len(r_flat), (
  230. f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}"
  231. )
  232. if self.check_aliasing:
  233. _check_alias_info(
  234. context, r, (args, kwargs), fake_r, (fake_args, fake_kwargs)
  235. )
  236. for idx, (r_out, f_out) in enumerate(
  237. zip(pytree.tree_leaves(r), pytree.tree_leaves(fake_r))
  238. ):
  239. r_is_ten = isinstance(r_out, torch.Tensor)
  240. assert r_is_ten == isinstance(f_out, torch.Tensor), (
  241. f"{context} mismatched number of tensor outputs"
  242. )
  243. if r_is_ten:
  244. try:
  245. _check_fake_real_tensors(
  246. r_out,
  247. f_out,
  248. sizes=True,
  249. strides=self.check_strides,
  250. storage_offset=True,
  251. requires_grad=True,
  252. )
  253. except Exception as e:
  254. if is_sdpa_error(func, idx, e):
  255. continue
  256. error_message = (
  257. f"{context} mismatched tensor metadata: {e}"
  258. if len(r_flat) == 1
  259. else f"{context} mismatched tensor metadata for output[{idx}]: {e}"
  260. )
  261. raise MetadataMismatchError(error_message) from e
  262. return r