fake_impls.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331
  1. # mypy: ignore-errors
  2. import functools
  3. import itertools
  4. import math
  5. import operator
  6. import sys
  7. from functools import reduce
  8. from typing import Callable, Union
  9. import torch
  10. import torch._custom_op
  11. import torch._logging
  12. import torch._prims_common as utils
  13. from torch._dispatch.python import no_python_dispatcher
  14. from torch._ops import OpOverload
  15. from torch._prims_common import (
  16. elementwise_dtypes,
  17. ELEMENTWISE_TYPE_PROMOTION_KIND,
  18. is_boolean_dtype,
  19. is_contiguous,
  20. is_contiguous_for_memory_format_or_false,
  21. is_contiguous_or_false,
  22. is_float_dtype,
  23. is_integer_dtype,
  24. make_contiguous_strides_for,
  25. )
  26. from torch._subclasses.fake_tensor import (
  27. DataDependentOutputException,
  28. DynamicOutputShapeException,
  29. FakeTensor,
  30. in_kernel_invocation_manager,
  31. run_fallback_kernel,
  32. UnsupportedOperatorException,
  33. )
  34. from torch.fx.operator_schemas import normalize_function
  35. from torch.utils._stats import count_label
# Shorthand for the pytree helpers used below to flatten and map nested values.
pytree = torch.utils._pytree

__all__ = [
    "op_implementations_checks",
    "get_fast_op_impls",
    "stride_incorrect_op",
    "has_meta",
]

# Exact-match registry: OpOverload -> fake implementation.
# Populated by `register_op_impl` when it is given an OpOverload.
op_implementations_dict = {}
# Predicate registry: (check, implementation) pairs appended by
# `register_op_impl` when it is given a callable check.
op_implementations_checks = []

# Shorthand for the aten operator namespace.
aten = torch._ops.ops.aten
  46. def ordered_set(*items):
  47. return dict.fromkeys(items, True)
  48. # This function indicates if the backend device
  49. # supports non-contiguous tensors
  50. def is_noncontiguous_supported(device):
  51. return device.type != "hpu"
# Overloads that build a new tensor from an existing one (the ``*_like`` and
# ``new_*`` families). `constructors` is registered for every entry and, for
# these ops, takes the default output device from the "input" argument.
_like_tensor_constructors = ordered_set(
    aten.empty_like.default,
    aten.empty_like.out,
    aten.full_like.default,
    aten.full_like.out,
    aten.ones_like.default,
    aten.ones_like.out,
    aten.rand_like.default,
    aten.rand_like.out,
    aten.randn_like.default,
    aten.randn_like.out,
    aten.randint_like.default,
    aten.randint_like.Tensor,
    aten.randint_like.Tensor_out,
    aten.randint_like.out,
    aten.randint_like.low_dtype,
    aten.randint_like.low_dtype_out,
    aten.zeros_like.default,
    aten.zeros_like.out,
    aten.new_empty.default,
    aten.new_empty.out,
    aten.new_empty_strided.default,
    aten.new_empty_strided.out,
    aten.new_full.default,
    aten.new_full.out,
    aten.new_zeros.default,
    aten.new_zeros.out,
    aten.new_ones.default,
    aten.new_ones.out,
)
# Ordered set of operator overloads grouped under this name.
# NOTE(review): presumably ops whose device is not supplied through a
# ``device=`` keyword argument; no consumer is visible in this part of the
# file, so confirm against the call sites.
_device_not_kwarg_ops = ordered_set(
    aten._resize_output_.default,
    aten._nested_tensor_from_tensor_list.default,
    aten._nested_tensor_from_tensor_list.out,
    aten.pin_memory.default,
    aten.to.device,
    aten.to.prim_Device,
    aten.is_pinned.default,
    aten._pin_memory.default,
    aten._pin_memory.out,
    aten._resize_output.default,
    aten._resize_output.out,
)
# this op is never actually used
# `constructors` asserts that the op it is handling is not in this tuple.
_non_kwarg_device_constructors = (aten._list_to_tensor,)
  97. def contains_tensor_types(type):
  98. tensor_type = torch._C.TensorType.get()
  99. return type.isSubtypeOf(tensor_type) or any(
  100. contains_tensor_types(e) for e in type.containedTypes()
  101. )
  102. @functools.cache
  103. def _is_tensor_constructor(func: OpOverload):
  104. assert isinstance(func, OpOverload)
  105. schema = func._schema
  106. if any(contains_tensor_types(arg.type) for arg in schema.arguments):
  107. return False
  108. # TODO: no real reason to restrict multiple outputs
  109. return (
  110. len(schema.returns) == 1 and schema.returns[0].type is torch._C.TensorType.get()
  111. )
  112. def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]):
  113. def impl_decorator(op_impl):
  114. if isinstance(run_impl_check, OpOverload):
  115. assert run_impl_check not in op_implementations_dict, (
  116. f"duplicate registration: {run_impl_check}"
  117. )
  118. op_implementations_dict[run_impl_check] = op_impl
  119. elif isinstance(run_impl_check, (list, tuple)):
  120. for op in run_impl_check:
  121. register_op_impl(op)(op_impl)
  122. else:
  123. assert callable(run_impl_check)
  124. op_implementations_checks.append((run_impl_check, op_impl))
  125. return op_impl
  126. return impl_decorator
  127. def _is_op_registered_to_fake_rule(op):
  128. return op in op_implementations_dict
  129. def _deregister_op_impl(op):
  130. if op in op_implementations_dict:
  131. del op_implementations_dict[op]
  132. for check, impl in op_implementations_checks:
  133. if check is op:
  134. op_implementations_checks.remove((check, impl))
  135. break
  136. @register_op_impl(op_implementations_dict.__contains__)
  137. def dispatch_to_op_implementations_dict(fake_mode, func, *args, **kwargs):
  138. return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
  139. @register_op_impl(_is_tensor_constructor)
  140. @register_op_impl([*_like_tensor_constructors])
  141. def constructors(fake_mode, func, *args, **kwargs):
  142. assert func not in _non_kwarg_device_constructors
  143. _, new_kwargs = normalize_function(
  144. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  145. )
  146. if "names" in kwargs:
  147. raise UnsupportedOperatorException(
  148. "torch.compile doesn't support named tensors"
  149. )
  150. if func in _like_tensor_constructors:
  151. default_device = new_kwargs["input"].device
  152. # TODO: file issue
  153. args = (new_kwargs.pop("input"),)
  154. else:
  155. # cpu is default device if none is specified
  156. default_device = torch.device("cpu")
  157. args = ()
  158. out_device = new_kwargs.pop("device", None)
  159. out_device = out_device if out_device is not None else default_device
  160. new_kwargs["device"] = torch.device("meta")
  161. # _like constructors have fake tensor inputs (maybe this causes the non-like
  162. # to fail? hmmm)
  163. with in_kernel_invocation_manager(fake_mode):
  164. r = func(*args, **new_kwargs)
  165. return FakeTensor(fake_mode, r, out_device)
  166. @register_op_impl(aten.is_pinned.default)
  167. def non_kwarg_is_pinned(fake_mode, func, *args, **kwargs):
  168. _, new_kwargs = normalize_function(
  169. func, args, kwargs, normalize_to_only_use_kwargs=True
  170. )
  171. inp = new_kwargs.pop("input")
  172. # we'll ignore device argument because it is deprecated and not
  173. # actually used by is_pinned.
  174. with in_kernel_invocation_manager(fake_mode):
  175. r = func(inp)
  176. return r
  177. @register_op_impl(aten.to.prim_Device)
  178. @register_op_impl(aten.to.device)
  179. def non_kwarg_to(fake_mode, func, *args, **kwargs):
  180. _, new_kwargs = normalize_function(
  181. func, args, kwargs, normalize_to_only_use_kwargs=True
  182. )
  183. input_device = new_kwargs["device"]
  184. out_device = input_device if input_device else new_kwargs["input"].device
  185. new_kwargs["device"] = torch.device("meta")
  186. inp = new_kwargs.pop("input")
  187. with in_kernel_invocation_manager(fake_mode):
  188. r = func(inp, **new_kwargs)
  189. # TODO: I think this does the wrong thing if r is inp
  190. return fake_mode.fake_tensor_converter.from_meta_and_device(
  191. fake_mode, r, out_device
  192. )
  193. def stride_incorrect_op(op):
  194. return False
  195. # These operators have meta implementations with incorrect strides
  196. @register_op_impl(stride_incorrect_op)
  197. def wordaround_stride_incorrect_op(fake_mode, func, *args, **kwargs):
  198. # This is a workaround for meta implementations with incorrect strides
  199. def is_symbolic(x):
  200. if isinstance(x, FakeTensor):
  201. return x._has_symbolic_sizes_strides
  202. if isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool)):
  203. return True
  204. return False
  205. # For static shapes, we can fall back to eager for the real strides
  206. if fake_mode.allow_fallback_kernels:
  207. require_dynamic = any(
  208. is_symbolic(x) for x in itertools.chain(args, kwargs.values())
  209. )
  210. if not require_dynamic:
  211. flat_args, args_spec = pytree.tree_flatten((args, kwargs))
  212. return run_fallback_kernel(fake_mode, func, flat_args, args_spec, None)
  213. raise UnsupportedOperatorException(func)
  214. # Dont default to default device handling,
  215. # since the device of `the_template` is ignored
  216. @register_op_impl(aten.resize_as_.default)
  217. def resize_as_(fake_mode, func, *args, **kwargs):
  218. with in_kernel_invocation_manager(fake_mode):
  219. return func(*args, **kwargs)
  220. @register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default)
  221. def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
  222. # TODO: remove me
  223. return constructors(fake_mode, func, *args, **kwargs)
  224. # index.Tensor data-dependent in only some conditions
  225. @register_op_impl(
  226. lambda func: torch.Tag.dynamic_output_shape in func.tags
  227. and func
  228. not in [aten.index.Tensor, aten.nonzero.default, aten.repeat_interleave.Tensor]
  229. )
  230. def dyn_shape(fake_mode, func, *args, **kwargs):
  231. raise DynamicOutputShapeException(func)
  232. def _unique(
  233. fake_mode,
  234. func,
  235. arg,
  236. dim,
  237. sorted=True,
  238. return_inverse=False,
  239. return_counts=False,
  240. *,
  241. unique_consecutive=False,
  242. ):
  243. if (
  244. fake_mode.shape_env is None
  245. or not fake_mode.shape_env.allow_dynamic_output_shape_ops
  246. ):
  247. # Without symints/symfloats, cannot handle this
  248. raise DynamicOutputShapeException(func)
  249. nnz = arg.unique_consecutive_memo if unique_consecutive else arg.unique_memo
  250. # Do not use a memo for unique_dim
  251. if dim is not None or nnz is None:
  252. # Avoid importing sympy at a module level
  253. from torch.fx.experimental.symbolic_shapes import (
  254. _constrain_range_for_size,
  255. has_free_symbols,
  256. )
  257. if not has_free_symbols(arg.numel()) and arg.numel() == 0:
  258. # If numel is zero, then the output size must be zero.
  259. # In this case, we must not allocate an unbacked SymInt,
  260. # because if we do, it will immediately get refined to
  261. # zero, but this will be inconsistent with size oblivious
  262. # tests (which will continue to claim that the unbacked
  263. # symint cannot equal zero). We could also unconditionally
  264. # allocate an unbacked SymInt and not refine its range,
  265. # but this seems more precise.
  266. nnz = 0
  267. else:
  268. nnz = fake_mode.shape_env.create_unbacked_symint()
  269. maxval = sys.maxsize - 1
  270. numel = arg.numel() if dim is None else arg.size(dim)
  271. if not has_free_symbols(numel):
  272. maxval = int(numel)
  273. _constrain_range_for_size(nnz, max=maxval)
  274. if dim is None:
  275. if unique_consecutive:
  276. arg.unique_consecutive_memo = nnz
  277. else:
  278. arg.unique_memo = nnz
  279. if dim is None:
  280. ret = [arg.new_empty((nnz,))]
  281. else:
  282. ret = [arg.new_empty(*arg.shape[:dim], nnz, *arg.shape[dim + 1 :])]
  283. return_if_dim_and_cpu = dim is not None and arg.fake_device == torch.device("cpu")
  284. if return_inverse or return_if_dim_and_cpu:
  285. inverse = arg.new_empty(arg.shape if dim is None else (arg.shape[dim],))
  286. else:
  287. inverse = arg.new_empty(0)
  288. ret.append(inverse)
  289. if return_counts or return_if_dim_and_cpu:
  290. counts = arg.new_empty(ret[0].shape if dim is None else (ret[0].shape[dim],))
  291. else:
  292. counts = arg.new_empty(0)
  293. ret.append(counts)
  294. return tuple(ret)
  295. @register_op_impl(aten._unique2.default)
  296. def unique2(
  297. fake_mode, func, arg, sorted=True, return_inverse=False, return_counts=False
  298. ):
  299. return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts)
  300. @register_op_impl(aten.select.int)
  301. def meta_select(fake_mode, func, self, dim, index):
  302. from torch.fx.experimental.symbolic_shapes import guard_or_false
  303. if self.is_sparse:
  304. return NotImplemented
  305. ndim = self.dim()
  306. torch._check_index(
  307. ndim != 0,
  308. lambda: "select() cannot be applied to a 0-dim tensor.",
  309. )
  310. dim = dim if dim >= 0 else dim + ndim
  311. size = self.size(dim)
  312. new_size = list(self.size())
  313. new_stride = list(self.stride())
  314. new_storage_offset = None
  315. if guard_or_false(index >= 0):
  316. new_storage_offset = self.storage_offset() + index * new_stride[dim]
  317. elif guard_or_false(index < 0):
  318. new_storage_offset = self.storage_offset() + (index + size) * new_stride[dim]
  319. if new_storage_offset is None:
  320. if fake_mode.shape_env is None or (
  321. not fake_mode.shape_env.allow_scalar_outputs
  322. and not fake_mode.allow_scalar_outputs
  323. ):
  324. raise DataDependentOutputException(func)
  325. # index is data-dependent, we do not know which index we are accessing it could be index or index+size!
  326. # we assign a new data-dependent symbol for the storage offset.
  327. new_storage_offset = fake_mode.shape_env.create_unbacked_symint()
  328. del new_size[dim]
  329. del new_stride[dim]
  330. assert new_storage_offset is not None
  331. return self.as_strided(new_size, new_stride, new_storage_offset)
  332. @register_op_impl(aten.unique_dim.default)
  333. def unique_dim(
  334. fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
  335. ):
  336. return _unique(
  337. fake_mode,
  338. func,
  339. arg,
  340. # normalize dim to be non-negative
  341. dim if dim >= 0 else dim % max(arg.ndim, 1),
  342. sorted,
  343. return_inverse,
  344. return_counts,
  345. )
  346. @register_op_impl(aten.unique_consecutive.default)
  347. def _(fake_mode, func, arg, return_inverse=False, return_counts=False, dim=None):
  348. return _unique(
  349. fake_mode,
  350. func,
  351. arg,
  352. dim,
  353. False,
  354. return_inverse,
  355. return_counts,
  356. unique_consecutive=True,
  357. )
  358. # This function is python match of computeStride_impl in TensorUtils.cpp
  359. def _compute_stride(old_shape, old_stride, new_shape, size_oblivious=False):
  360. from torch.fx.experimental.symbolic_shapes import (
  361. guard_or_false,
  362. guard_or_true,
  363. sym_eq,
  364. )
  365. def maybe_guard_or_false(x):
  366. if size_oblivious:
  367. return guard_or_false(x)
  368. return x
  369. def maybe_guard_or_true(x):
  370. if size_oblivious:
  371. return guard_or_true(x)
  372. return x
  373. if len(old_shape) == 0:
  374. return [1] * len(new_shape)
  375. numel = reduce(operator.mul, old_shape, 1)
  376. zero_numel = maybe_guard_or_false(numel == 0)
  377. if zero_numel and maybe_guard_or_false(sym_eq(old_shape, new_shape)):
  378. return old_stride
  379. new_stride = [0] * len(new_shape)
  380. if zero_numel:
  381. for view_d in range(len(new_shape) - 1, -1, -1):
  382. if view_d == len(new_shape) - 1:
  383. new_stride[view_d] = 1
  384. else:
  385. new_stride[view_d] = (
  386. max(new_shape[view_d + 1], 1) * new_stride[view_d + 1]
  387. )
  388. return new_stride
  389. view_d = len(new_shape) - 1
  390. chunk_base_stride = old_stride[-1]
  391. tensor_numel = 1
  392. view_numel = 1
  393. for tensor_d in range(len(old_shape) - 1, -1, -1):
  394. tensor_numel *= old_shape[tensor_d]
  395. if tensor_d == 0 or (
  396. maybe_guard_or_true(old_shape[tensor_d - 1] != 1)
  397. and maybe_guard_or_true(
  398. old_stride[tensor_d - 1] != tensor_numel * chunk_base_stride
  399. )
  400. ):
  401. while view_d >= 0 and (
  402. maybe_guard_or_true(view_numel < tensor_numel)
  403. or maybe_guard_or_false(new_shape[view_d] == 1)
  404. ):
  405. new_stride[view_d] = view_numel * chunk_base_stride
  406. view_numel *= new_shape[view_d]
  407. view_d -= 1
  408. if maybe_guard_or_true(view_numel != tensor_numel):
  409. return None
  410. if tensor_d > 0:
  411. chunk_base_stride = old_stride[tensor_d - 1]
  412. tensor_numel = 1
  413. view_numel = 1
  414. if view_d != -1:
  415. return None
  416. return new_stride
  417. def _view_has_unbacked_input(a, shape):
  418. from torch.fx.experimental.symbolic_shapes import has_hint
  419. shape = utils.extract_shape_from_varargs(shape, validate=False)
  420. return (
  421. any(not has_hint(s) for s in a.size())
  422. or any(not has_hint(s) for s in a.stride())
  423. or any(not has_hint(s) for s in shape)
  424. )
  425. def _view_unbacked_meta(a, shape, size_oblivious_enabled=True):
  426. from torch._prims import view_of
  427. from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_eq
  428. # Creates a valid shape
  429. shape = utils.extract_shape_from_varargs(shape, validate=False)
  430. # Reshape may be given a shape with a -1 length
  431. # This indicates that the dimension's length should be inferred
  432. shape = utils.infer_size(shape, a.numel())
  433. # Special-cases reshaping zero dim tensors
  434. if a.ndim == 0:
  435. _a = a
  436. for length in shape:
  437. torch._check(length == 1)
  438. _a = torch._refs.unsqueeze(_a, -1)
  439. if _a is a:
  440. return view_of(a)
  441. else:
  442. return _a
  443. # Special-cases reshaping to zero dim tensors
  444. if len(shape) == 0:
  445. _a = a
  446. for length in a.shape:
  447. torch._check(length == 1)
  448. _a = torch._refs.squeeze(_a, -1)
  449. if _a is a:
  450. return view_of(a)
  451. else:
  452. return _a
  453. shape_numel = reduce(operator.mul, shape, 1)
  454. torch._check(
  455. a.numel() == shape_numel,
  456. lambda: f"Could not reshape a tensor with shape {a.shape} as a tensor with shape {shape}!",
  457. )
  458. if len(shape) == len(a.shape) and guard_or_false(sym_eq(shape, a.shape)):
  459. return view_of(a)
  460. if is_contiguous_or_false(a) if size_oblivious_enabled else is_contiguous(a):
  461. strides = make_contiguous_strides_for(shape)
  462. return a.as_strided(shape, strides)
  463. new_strides = _compute_stride(
  464. a.size(), a.stride(), shape, size_oblivious=size_oblivious_enabled
  465. )
  466. if new_strides is not None:
  467. return a.as_strided(shape, new_strides)
  468. # If we fail to do size oblivious view, and backed_size_oblivious was on,
  469. # then we redo everything by looking at hints and guarding instead of failing.
  470. # Also if the expression has unbacked symbols, then we run again with size_oblivious_enabled=False
  471. # to throw a data dependent error.
  472. if size_oblivious_enabled and (
  473. torch.fx.experimental._config.backed_size_oblivious
  474. or _view_has_unbacked_input(a, shape)
  475. ):
  476. return _view_unbacked_meta(a, shape, size_oblivious_enabled=False)
  477. msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!"
  478. raise ValueError(msg)
  479. @register_op_impl(aten.view.default)
  480. @register_op_impl(aten._unsafe_view.default)
  481. def _view_meta(fake_mode, func, a, *shape):
  482. if torch.fx.experimental._config.backed_size_oblivious or _view_has_unbacked_input(
  483. a, shape
  484. ):
  485. return _view_unbacked_meta(a, shape)
  486. else:
  487. return torch._refs._reshape_view_helper(a, *shape, allow_copy=False)
  488. @register_op_impl(aten.view_copy.default)
  489. def _view_meta_copy(fake_mode, func, a, *shape, out=None):
  490. result = _view_meta(fake_mode, func, a, *shape)
  491. if out is not None:
  492. return result
  493. return pytree.tree_map(
  494. lambda x: x.clone(memory_format=torch.contiguous_format),
  495. result,
  496. )
  497. @register_op_impl(aten.repeat_interleave.Tensor)
  498. def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
  499. if output_size is None:
  500. if (
  501. fake_mode.shape_env is None
  502. or not fake_mode.shape_env.allow_dynamic_output_shape_ops
  503. ):
  504. raise DynamicOutputShapeException(func)
  505. output_size = fake_mode.shape_env.create_unbacked_symint()
  506. # Avoid importing sympy at a module level
  507. from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size
  508. _constrain_range_for_size(output_size)
  509. # TODO: consider a memo
  510. return repeats.new_empty(output_size)
  511. @register_op_impl(torch.ops.aten.item.default)
  512. @register_op_impl(torch.ops.aten._local_scalar_dense.default)
  513. def local_scalar_dense(fake_mode, func, arg):
  514. if (r := arg.item_memo) is not None:
  515. return r
  516. if fake_mode.shape_env is None or (
  517. not fake_mode.shape_env.allow_scalar_outputs
  518. and not fake_mode.allow_scalar_outputs
  519. ):
  520. # Without symints/symfloats, cannot handle this
  521. raise DataDependentOutputException(func)
  522. if is_float_dtype(arg.dtype):
  523. r = fake_mode.shape_env.create_unbacked_symfloat()
  524. elif is_integer_dtype(arg.dtype):
  525. r = fake_mode.shape_env.create_unbacked_symint()
  526. elif is_boolean_dtype(arg.dtype):
  527. r = fake_mode.shape_env.create_unbacked_symbool()
  528. else:
  529. raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")
  530. arg.item_memo = r
  531. return r
  532. @register_op_impl(torch.ops.aten.nonzero_numpy.default)
  533. def nonzero_numpy(fake_mode, func, arg):
  534. return torch.ops.aten.nonzero.default(arg).unbind(1)
  535. @register_op_impl(torch.ops.aten.nonzero.default)
  536. def nonzero(fake_mode, func, arg):
  537. if (
  538. fake_mode.shape_env is None
  539. or not fake_mode.shape_env.allow_dynamic_output_shape_ops
  540. ):
  541. # Without symints/symfloats, cannot handle this
  542. raise DynamicOutputShapeException(func)
  543. if (nnz := arg.nonzero_memo) is None:
  544. # Avoid importing sympy at a module level
  545. from torch.fx.experimental.symbolic_shapes import (
  546. _constrain_range_for_size,
  547. has_free_symbols,
  548. )
  549. from torch.utils._sympy.numbers import IntInfinity
  550. from torch.utils._sympy.value_ranges import bound_sympy
  551. if not has_free_symbols(arg.numel()) and arg.numel() == 0:
  552. # If numel is zero, then the output size must be zero.
  553. # In this case, we must not allocate an unbacked SymInt,
  554. # because if we do, it will immediately get refined to
  555. # zero, but this will be inconsistent with size oblivious
  556. # tests (which will continue to claim that the unbacked
  557. # symint cannot equal zero). We could also unconditionally
  558. # allocate an unbacked SymInt and not refine its range,
  559. # but this seems more precise.
  560. nnz = 0
  561. else:
  562. nnz = fake_mode.shape_env.create_unbacked_symint()
  563. maxval = sys.maxsize - 1
  564. if not has_free_symbols(arg.numel()):
  565. maxval = int(arg.numel())
  566. else:
  567. prod_node = math.prod(arg.shape).node
  568. prod_range = bound_sympy(
  569. prod_node.expr, prod_node.shape_env.var_to_range
  570. )
  571. if isinstance(prod_range.upper, IntInfinity):
  572. maxval = sys.maxsize - 1
  573. else:
  574. maxval = prod_range.upper
  575. _constrain_range_for_size(nnz, max=maxval)
  576. arg.nonzero_memo = nnz
  577. return arg.new_empty_strided((nnz, arg.dim()), (1, nnz), dtype=torch.int64)
  578. @register_op_impl(torch.ops.aten._padded_dense_to_jagged_forward.default)
  579. def _padded_dense_to_jagged_forward(fake_mode, func, padded, offsets, total_L=None):
  580. # only one jagged dim is supported for now
  581. assert len(offsets) == 1
  582. if not total_L:
  583. if (
  584. fake_mode.shape_env is None
  585. or not fake_mode.shape_env.allow_dynamic_output_shape_ops
  586. ):
  587. # Without symints/symfloats, cannot handle this
  588. raise DynamicOutputShapeException(func)
  589. total_L = fake_mode.shape_env.create_unbacked_symint()
  590. maxval = sys.maxsize - 1
  591. # Avoid importing sympy at a module level
  592. from torch.fx.experimental.symbolic_shapes import (
  593. _constrain_range_for_size,
  594. has_free_symbols,
  595. )
  596. if not has_free_symbols(padded.numel()):
  597. maxval = int(padded.numel())
  598. _constrain_range_for_size(total_L, min=0, max=maxval)
  599. output_shape = (total_L, *padded.shape[2:])
  600. return padded.new_empty(output_shape)
  601. @register_op_impl(torch.ops.aten.masked_select.default)
  602. def masked_select(fake_mode, func, self, mask):
  603. if (
  604. fake_mode.shape_env is None
  605. or not fake_mode.shape_env.allow_dynamic_output_shape_ops
  606. ):
  607. # Without symints/symfloats, cannot handle this
  608. raise DynamicOutputShapeException(func)
  609. nnz = fake_mode.shape_env.create_unbacked_symint()
  610. # see nonzero for commentary
  611. maxval = sys.maxsize - 1
  612. # Avoid importing sympy at a module level
  613. from torch.fx.experimental.symbolic_shapes import (
  614. _constrain_range_for_size,
  615. has_free_symbols,
  616. )
  617. from torch.utils._sympy.numbers import IntInfinity
  618. from torch.utils._sympy.value_ranges import bound_sympy
  619. # If num elements is expressed symbolically, calculate
  620. # the concrete value based on upper bounds. Otherwise,
  621. # we can set max val directly.
  622. if not has_free_symbols(self.numel()):
  623. num_elements = int(self.numel())
  624. else:
  625. prod_node = math.prod(self.shape).node
  626. prod_range = bound_sympy(prod_node.expr, prod_node.shape_env.var_to_range)
  627. if isinstance(prod_range.upper, IntInfinity):
  628. num_elements = sys.maxsize - 1
  629. else:
  630. num_elements = prod_range.upper
  631. if num_elements > 2:
  632. maxval = num_elements
  633. _constrain_range_for_size(nnz, max=maxval)
  634. return self.new_empty((nnz,))
  635. @register_op_impl(torch.ops.aten._assert_tensor_metadata.default)
  636. def assert_tensor_metadata(
  637. fake_mode,
  638. func,
  639. t,
  640. sizes=None,
  641. strides=None,
  642. dtype=None,
  643. *,
  644. device=None,
  645. layout=None,
  646. ) -> None:
  647. if sizes is not None:
  648. assert t.size() == sizes, (
  649. f"Tensor sizes mismatch! Expected: {sizes}, Got: {t.size()}"
  650. )
  651. if strides is not None:
  652. assert t.stride() == strides, (
  653. f"Tensor strides mismatch! Expected: {strides}, Got: {t.stride()}"
  654. )
  655. if dtype is not None:
  656. assert t.dtype == dtype, (
  657. f"Tensor dtype mismatch! Expected: {dtype}, Got: {t.dtype}"
  658. )
  659. if layout is not None:
  660. assert t.layout == layout, (
  661. f"Tensor layout mismatch! Expected: {layout}, Got: {t.layout()}"
  662. )
  663. if device is not None:
  664. assert t.device == device, (
  665. f"Tensor device mismatch! Expected: {device}, Got: {t.device}"
  666. )
  667. # NB: this must be ordered after local_scalar_dense
  668. @register_op_impl(lambda func: torch.Tag.data_dependent_output in func.tags)
  669. def data_dep(fake_mode, func, *args, **kwargs):
  670. raise DataDependentOutputException(func)
  671. # Bool Indices get Expanded as Masks
  672. # See: IndexingUtils.h:expandTensors
  673. def check_no_bool_index_tensors(func, self, indices):
  674. for index in indices:
  675. if index is not None and index.dtype in (torch.bool, torch.uint8):
  676. raise DynamicOutputShapeException(func)
  677. def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs):
  678. _, new_kwargs = normalize_function(
  679. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  680. )
  681. out_device = new_kwargs["input"].device
  682. with in_kernel_invocation_manager(fake_mode):
  683. out = func(*args, **kwargs)
  684. if not is_noncontiguous_supported(out_device):
  685. out = out.new_empty(out.shape)
  686. if out is new_kwargs["input"]:
  687. return out # copy_
  688. return FakeTensor(fake_mode, out, out_device)
# Operator namespaces that is_builtin() treats as built-in.
_is_builtin_namespaces = ordered_set("aten", "prims", "prim")


def is_builtin(op):
    """Return whether ``op.namespace`` is one of "aten", "prims" or "prim"."""
    return op.namespace in _is_builtin_namespaces
  692. def has_meta(func):
  693. return torch._C._dispatch_has_computed_kernel_for_dispatch_key(func.name(), "Meta")
  694. # These are for the `torch._foreach_...` ops like `torch._foreach_add`.
  695. @register_op_impl(
  696. lambda func: is_builtin(func)
  697. and func.name().startswith("aten::_foreach_")
  698. and has_meta(func)
  699. )
  700. def foreach_run_and_map_input_device(fake_mode, func, *args, **kwargs):
  701. tensor_lists = [
  702. arg
  703. for arg in itertools.chain(args, kwargs.values())
  704. if isinstance(arg, (list, tuple))
  705. and len(arg)
  706. and isinstance(arg[0], torch.Tensor)
  707. ]
  708. try:
  709. with in_kernel_invocation_manager(fake_mode):
  710. out_meta = func(*args, **kwargs)
  711. except NotImplementedError:
  712. return NotImplemented
  713. if not out_meta:
  714. return out_meta
  715. assert tensor_lists
  716. out_fake = []
  717. for i, meta_t in enumerate(out_meta):
  718. device, _ = FakeTensor._find_common_device(func, [tl[i] for tl in tensor_lists])
  719. out_fake.append(
  720. fake_mode.fake_tensor_converter.from_meta_and_device(
  721. fake_mode, meta_t, device
  722. )
  723. )
  724. return out_fake
  725. # Dont default to default device handling,
  726. # Since op can take in non-zero sized cpu
  727. # index tensors with cuda self
  728. @register_op_impl(aten.index.Tensor)
  729. def index_tensor(fake_mode, func, *args, **kwargs):
  730. from torch._meta_registrations import meta_index_Tensor
  731. _, new_kwargs = normalize_function(
  732. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  733. )
  734. out_device = new_kwargs["input"].device
  735. # ensure nonzero call goes to fake tensor
  736. with fake_mode:
  737. out = meta_index_Tensor(*args, **kwargs)
  738. return out.to(out_device)
# Can take mixed meta/non-meta arguments; the meta registration
# will roughly do the right thing even when given real devices
@register_op_impl(aten._embedding_bag.default)
def embedding_bag(fake_mode, func, *args, **kwargs):
    """Fake implementation of ``aten._embedding_bag.default``.

    Forwards the caller's arguments unchanged to ``meta_embedding_bag`` with
    ``fake_mode`` active and returns its result. ``func`` is not read.
    """
    # Imported inside the function rather than at module scope.
    from torch._meta_registrations import meta_embedding_bag

    with fake_mode:
        return meta_embedding_bag(*args, **kwargs)
# takes in multiple devices, so don't fall back to the default device handling
@register_op_impl(aten._unsafe_index_put.default)
@register_op_impl(aten.copy.default)
@register_op_impl(aten.copy_.default)
@register_op_impl(aten.slice_scatter.default)
def multi_device_op_default(fake_mode, func, *args, **kwargs):
    """Delegate to ``run_and_return_new_tensor_of_input_device``.

    The variadic ``args``/``kwargs`` are passed on as a tuple and a dict.
    """
    return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
  753. # same with multi_device_op_default, but return the input
  754. @register_op_impl(aten.copy.out)
  755. @register_op_impl(aten.slice_scatter.out)
  756. def multi_device_op_out(fake_mode, func, *args, **kwargs):
  757. with in_kernel_invocation_manager(fake_mode):
  758. func(*args, **kwargs)
  759. _, new_kwargs = normalize_function(
  760. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  761. )
  762. return new_kwargs["input"]
  763. @register_op_impl(aten.index_put.default)
  764. @register_op_impl(aten.index_put_.default)
  765. def index_put_impl(fake_mode, func, *args, **kwargs):
  766. _, new_kwargs = normalize_function(
  767. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  768. )
  769. values = new_kwargs["values"]
  770. self_device = new_kwargs["input"].fake_device
  771. torch._check(
  772. self_device == values.fake_device or (values.ndim == 0 and values.numel() == 1),
  773. lambda: f"Mismatching {func} device between self ({self_device}) and values ({values.device})",
  774. )
  775. out = run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
  776. if func is aten.index_put_.default:
  777. return new_kwargs["input"]
  778. else:
  779. return out
@register_op_impl(aten._nested_tensor_from_tensor_list.default)
@register_op_impl(aten._nested_tensor_from_tensor_list.out)
@register_op_impl(aten._nested_view_from_buffer.default)
@register_op_impl(aten._nested_view_from_buffer_copy.default)
def nested_tensors_unsupported(fake_mode, func, *args, **kwargs):
    """Reject the nested-tensor construction ops this is registered for.

    Raises:
        UnsupportedOperatorException: always, with a fixed message.
    """
    raise UnsupportedOperatorException(
        "torch.compile does not support strided NestedTensor"
    )
# Registered for every op in _device_not_kwarg_ops except the ones listed
# in the exclusion tuple below.
@register_op_impl(
    [
        x
        for x in _device_not_kwarg_ops
        if x
        not in (
            # these are already registered elsewhere
            aten.is_pinned.default,
            aten.to.device,
            aten.to.prim_Device,
            aten._nested_tensor_from_tensor_list.default,
            aten._nested_tensor_from_tensor_list.out,
        )
    ]
)
def nyi(fake_mode, func, *args, **kwargs):
    """Fail with "NYI" for ops that are members of ``_device_not_kwarg_ops``.

    Raises:
        AssertionError: when ``func`` is in ``_device_not_kwarg_ops``.
    """
    # NOTE(review): every op in the registration list above comes from
    # _device_not_kwarg_ops, so this presumably always fires when dispatched.
    assert func not in _device_not_kwarg_ops, f"NYI: {func}"
@register_op_impl([aten.convolution.default, aten.convolution_backward.default])
def conv(fake_mode, func, *args, **kwargs):
    """Fake implementation shared by ``aten.convolution`` and its backward.

    The call is normalized to keyword-only form, an output memory format is
    chosen (or left as ``None``), and ``func`` is then run under
    ``in_kernel_invocation_manager``. Every non-``None`` result is wrapped in
    a ``FakeTensor`` on the ``fake_device`` of the ``input`` argument.

    Returns:
        For ``aten.convolution.default`` a single converted output; otherwise
        a 3-tuple built from ``out[0]``, ``out[1]`` and ``out[2]``, where the
        third element is wrapped without any memory-format conversion.
    """
    # From here on ``kwargs`` is the normalized, keyword-only argument dict.
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    device = kwargs["input"].fake_device
    # need to re-enable mode so the tensors report fake device
    with fake_mode:
        # if the input is unsqueezed is done in Convolution.cpp we get segfault
        k = kwargs["weight"].ndim
        batch = kwargs["input"].shape[0]

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import has_hint

        if not has_hint(batch):
            # TODO: We can make this a little more faithful with best effort
            # channels last detection (but only if it's statically obvious!)
            mem_fmt = None
        elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
            # 3-dim weight on an input that is neither mkldnn nor xpu:
            # no memory-format conversion is applied.
            mem_fmt = None
        else:
            if func is aten.convolution.default:
                conv_backend = torch._C._select_conv_backend(**kwargs)
            else:
                # Backward variant: the arguments are spelled out one by one,
                # with ``bias=None`` and ``bias_sizes`` passed explicitly.
                conv_backend = torch._C._select_conv_backend(
                    kwargs["input"],
                    kwargs["weight"],
                    bias=None,
                    stride=kwargs["stride"],
                    padding=kwargs["padding"],
                    dilation=kwargs["dilation"],
                    transposed=kwargs["transposed"],
                    output_padding=kwargs["output_padding"],
                    groups=kwargs["groups"],
                    bias_sizes=kwargs["bias_sizes"],
                )
            mem_fmt = torch._C._conv_determine_backend_memory_format(
                kwargs["input"], kwargs["weight"], conv_backend
            )

    def convert(t, mem_fmt):
        # Pass ``None`` through; otherwise optionally convert the memory
        # format and wrap the tensor as a FakeTensor on ``device``.
        if t is None:
            return t
        if mem_fmt is not None:
            t = t.to(memory_format=mem_fmt)
        return FakeTensor(fake_mode, t, device)

    with in_kernel_invocation_manager(fake_mode):
        out = func(**kwargs)

        if func is aten.convolution.default:
            return convert(out, mem_fmt)
        else:
            return (
                convert(out[0], mem_fmt),
                convert(out[1], mem_fmt),
                convert(out[2], None),
            )
  859. @register_op_impl(torch.ops.aten.bincount.default)
  860. def bincount(fake_mode, func, inputs, weights=None, minlength=0):
  861. if (
  862. fake_mode.shape_env is None
  863. or not fake_mode.shape_env.allow_dynamic_output_shape_ops
  864. ):
  865. # Without symints/symfloats, cannot handle this
  866. raise DynamicOutputShapeException(func)
  867. new_size = fake_mode.shape_env.create_unbacked_symint()
  868. from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size
  869. _constrain_range_for_size(new_size)
  870. torch._check(new_size >= minlength)
  871. return inputs.new_empty(new_size)
  872. @register_op_impl(torch.ops.aten._pack_padded_sequence.default)
  873. def _pack_padded_sequence(fake_mode, func, inputs, lengths, batch_first):
  874. if (
  875. fake_mode.shape_env is None
  876. or not fake_mode.shape_env.allow_dynamic_output_shape_ops
  877. ):
  878. # Without symints/symfloats, cannot handle this
  879. raise DynamicOutputShapeException(func)
  880. new_batch_size = fake_mode.shape_env.create_unbacked_symint()
  881. from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size
  882. _constrain_range_for_size(new_batch_size)
  883. if not batch_first:
  884. # Inputs should have shape (batch_size, seq_len, *)
  885. inputs = inputs.transpose(0, 1)
  886. res_size = inputs.shape[1:]
  887. packed_data = inputs.new_empty(res_size)
  888. batch_size = inputs.new_empty((new_batch_size,))
  889. return (packed_data, batch_size)
  890. FAST_OP_IMPLEMENTATIONS = {}
  891. # Unlike register_op_impl, these don't do the slow iteration for
  892. # run_impl_check, and these run BEFORE decompositions
  893. def register_fast_op_impl(func: OpOverload):
  894. def impl_decorator(op_impl):
  895. FAST_OP_IMPLEMENTATIONS[func] = op_impl
  896. return op_impl
  897. return impl_decorator
  898. # infer_size_impl in ExpandUtils
  899. def infer_size(a, b):
  900. from torch.fx.experimental.symbolic_shapes import guard_or_false
  901. dimsA = len(a)
  902. dimsB = len(b)
  903. ndim = max(dimsA, dimsB)
  904. expandedSizes = [0] * ndim
  905. for i in range(ndim - 1, -1, -1):
  906. offset = ndim - 1 - i
  907. dimA = dimsA - 1 - offset
  908. dimB = dimsB - 1 - offset
  909. sizeA = a[dimA] if dimA >= 0 else 1
  910. sizeB = b[dimB] if dimB >= 0 else 1
  911. # NB: It is very important to test for broadcasting, before testing
  912. # sizeA == sizeB. This is because the broadcasting tests are likely
  913. # to be statically known (in particular, if sizeA/sizeB is unbacked
  914. # but size-like, we will unsoundly assume they never equal 1), but
  915. # the sizeA == sizeB test may not be statically known. However, once
  916. # we have established that no broadcasting is happening, the
  917. # sizeA == sizeB is now expect_true and we can defer it as a runtime
  918. # assert (this works because Python will return the terminal
  919. # expression of an or statement as-is, without bool()'ing it; if this
  920. # were not the case, we'd need to write this using torch.sym_or() or
  921. # something like that).
  922. torch._check(
  923. guard_or_false(sizeA == 1) or guard_or_false(sizeB == 1) or sizeA == sizeB,
  924. lambda: f"The size of tensor a ({sizeA}) "
  925. f"must match the size of tensor b ({sizeB}) "
  926. f"at non-singleton dimension {i})",
  927. )
  928. expandedSizes[i] = sizeB if guard_or_false(sizeA == 1) else sizeA
  929. return tuple(expandedSizes)
def make_fast_binary_impl(
    slow_ref, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
):
    """Build a fast-path fake implementation for an elementwise op.

    The returned ``fast_binary_impl(mode, *args, **kwargs)`` tries to work out
    the output shape, dtype, device and memory format directly from the
    positional operands and returns a ``FakeTensor`` backed by an empty
    ``meta`` tensor. Whenever one of its checks cannot settle the answer it
    falls back to calling ``slow_ref(*args, **kwargs)`` with ``mode`` active.

    Args:
        slow_ref: reference implementation used for the fallback path.
        type_promotion_kind: passed to ``elementwise_dtypes`` when the
            operands do not all share a single tensor dtype.
    """

    def fast_binary_impl(mode, *args, **kwargs):
        def slow(msg):
            # Record why the fast path was abandoned, then run the reference.
            count_label(f"slow {msg}")
            with mode:
                return slow_ref(*args, **kwargs)

        count_label("attempt fast")

        # Fast path (based off of TensorIterator fast path).
        # Unfortunately, there is no way to easily deduplicate
        # this with either the TensorIterator C++ implementation
        # (which we don't want to SymIntify, and also the algorithm
        # here is slightly different from TensorIterator to allow
        # for broadcasting), nor the PrimTorch implementation
        # (which does not actually implement a fast path.)

        # Only positional arguments are inspected by the fast path; kwargs
        # are forwarded to ``slow_ref`` alone.
        operands = args

        # compute_shape
        final_shape = None
        for op in operands:
            # Non-tensor operands contribute an empty shape.
            shape = op.shape if isinstance(op, torch.Tensor) else ()
            if final_shape is None:
                final_shape = shape
            # TODO: Minor optimization: track if the shapes
            # were equal so you can skip the equality check
            # below if unnecessary
            final_shape = infer_size(final_shape, shape)
        assert final_shape is not None

        from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_eq

        # Do some extra safety checks to see if the output
        # stride is obvious
        for op in operands:
            if (
                isinstance(op, torch.Tensor)
                and len(op.shape) == len(final_shape)
                # take the slow path if result is not determined.
                and guard_or_false(sym_eq(op.shape, final_shape))
            ):
                break
        else:
            # if we never break in the for loop above we take the slow path.
            return slow("both tensors nontrivially broadcast")

        # compute_types
        cpu = torch.device("cpu")
        common_device = cpu
        common_dtype = None
        has_different_input_dtypes = False
        for op in operands:
            if not isinstance(op, torch.Tensor):
                # Use elementwise_dtypes for the tricky case
                has_different_input_dtypes = True
                continue
            # The first tensor whose device type is not "cpu" sets the
            # common device.
            if common_device == cpu and not op.device.type == "cpu":
                common_device = op.device
            # Slightly simplified here as target_dtype cannot vary
            if common_dtype is None:
                common_dtype = op.dtype
            elif common_dtype != op.dtype:
                has_different_input_dtypes = True

        if has_different_input_dtypes:
            # compute promotion
            # TODO: we don't need the compute type
            _, common_dtype = elementwise_dtypes(
                *operands, type_promotion_kind=type_promotion_kind
            )

        # check all tensors on same device
        # cpu scalars are assumed to be allowed
        current_cpu_scalars_on_non_cpu = 0
        max_cpu_scalars_on_non_cpu = 1  # hard coded atm
        for op in operands:
            if not isinstance(op, torch.Tensor):
                continue
            if common_device != cpu and op.dim() == 0 and op.device == cpu:
                # A 0-dim cpu tensor next to a non-cpu common device is
                # tolerated up to the limit above.
                if current_cpu_scalars_on_non_cpu >= max_cpu_scalars_on_non_cpu:
                    return slow("error")
                current_cpu_scalars_on_non_cpu += 1
            elif op.device != common_device:
                return slow("error")

        # compute_fast_setup_type
        definitely_contiguous = True
        definitely_channels_last = True

        # TODO: is_non-overlapping_and_dense not bound from Python
        # no inplace, no out, everything defined

        # When is_noncontiguous_supported() is false for the common device,
        # the per-operand checks are skipped and both flags stay True, so the
        # contiguous branch below is taken.
        if is_noncontiguous_supported(common_device):
            for op in operands:
                if not isinstance(op, torch.Tensor):
                    continue
                definitely_contiguous = (
                    definitely_contiguous
                    and is_contiguous_for_memory_format_or_false(
                        op, memory_format=torch.contiguous_format
                    )
                )
                definitely_channels_last = (
                    definitely_channels_last
                    and is_contiguous_for_memory_format_or_false(
                        op, memory_format=torch.channels_last
                    )
                )
        if definitely_contiguous:
            # do contiguous
            count_label("fast is_contiguous")
            return FakeTensor(
                mode,
                torch.empty(
                    final_shape,
                    dtype=common_dtype,
                    device="meta",
                    memory_format=torch.contiguous_format,
                ),
                device=common_device,
            )
        if definitely_channels_last:
            count_label("fast channels_last")
            # do channels last
            return FakeTensor(
                mode,
                torch.empty(
                    final_shape,
                    dtype=common_dtype,
                    device="meta",
                    memory_format=torch.channels_last,
                ),
                device=common_device,
            )

        return slow("no contiguity match")

    return fast_binary_impl
  1057. # disable the python dispatcher to avoid decomposing detach() further
  1058. # (proxy_mode should still decompose detach() though)
  1059. def fast_detach(fake_mode, x, include_real=False):
  1060. with no_python_dispatcher(), in_kernel_invocation_manager(fake_mode):
  1061. out = torch.ops.aten.detach.default(x)
  1062. if include_real:
  1063. return FakeTensor(fake_mode, out, x.device, real_tensor=x.real_tensor)
  1064. return FakeTensor(fake_mode, out, x.device)
  1065. @functools.cache
  1066. def get_fast_op_impls():
  1067. import torch._refs
  1068. register_fast_op_impl(torch.ops.aten.add.Tensor)(
  1069. make_fast_binary_impl(torch._refs.add)
  1070. )
  1071. register_fast_op_impl(torch.ops.aten.sub.Tensor)(
  1072. make_fast_binary_impl(torch._refs.sub)
  1073. )
  1074. register_fast_op_impl(torch.ops.aten.mul.Tensor)(
  1075. make_fast_binary_impl(torch._refs.mul)
  1076. ) # type: ignore[has-type]
  1077. register_fast_op_impl(torch.ops.aten.div.Tensor)(
  1078. make_fast_binary_impl(
  1079. torch._refs.div,
  1080. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  1081. )
  1082. )
  1083. register_fast_op_impl(torch.ops.aten.detach.default)(fast_detach)
  1084. return FAST_OP_IMPLEMENTATIONS