_shape_functions.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474
  1. # mypy: allow-untyped-defs
  2. import math
  3. from typing import Any, Callable, Dict, List, Optional, Tuple, Union
  4. number = Union[int, float]
  5. # flake8: noqa
  6. ###
  7. # There are generated files that depend on this file
  8. # To re-generate, please run from the root of the repo:
  9. # python torchgen/shape_functions/gen_jit_shape_functions.py
  10. # How to test:
  11. # After regenerating files, compile PyTorch.
  12. # Then run: ./build/bin/test_jit --gtest_filter=TestShapeGraphLinting.Basic
  13. # If you have enabled opinfo testing for the op, also run:
  14. # python test/test_ops_jit.py TestJitCPU.test_variant_consistency_jit_[FAILING_OP]_cpu_float32
  15. # to reproduce errors from opinfo tests.
  16. # Example PR: https://github.com/pytorch/pytorch/pull/80860/files
  17. ####
  18. import torch
  19. def broadcast(a: list[int], b: list[int]):
  20. dimsA = len(a)
  21. dimsB = len(b)
  22. ndim = max(dimsA, dimsB)
  23. expandedSizes: list[int] = []
  24. for i in range(ndim):
  25. offset = ndim - 1 - i
  26. dimA = dimsA - 1 - offset
  27. dimB = dimsB - 1 - offset
  28. sizeA = a[dimA] if (dimA >= 0) else 1
  29. sizeB = b[dimB] if (dimB >= 0) else 1
  30. if sizeA != sizeB and sizeA != 1 and sizeB != 1:
  31. # TODO: only assertion error is bound in C++ compilation right now
  32. raise AssertionError(
  33. f"The size of tensor a {sizeA} must match the size of tensor b ({sizeB}) at non-singleton dimension {i}"
  34. )
  35. expandedSizes.append(sizeB if sizeA == 1 else sizeA)
  36. return expandedSizes
  37. def broadcast_three(a: list[int], b: list[int], c: list[int]):
  38. return broadcast(broadcast(a, b), c)
  39. def broadcast_one_three(a: list[int], b: Any, c: list[int]):
  40. return broadcast(a, c)
  41. def adaptive_avg_pool2d(self: list[int], out: list[int]):
  42. assert len(out) == 2
  43. assert len(self) == 3 or len(self) == 4
  44. for i in range(1, len(self)):
  45. assert self[i] != 0
  46. shape: list[int] = []
  47. for i in range(0, len(self) - 2):
  48. shape.append(self[i])
  49. for elem in out:
  50. shape.append(elem)
  51. return shape
  52. def _copy(self: list[int]):
  53. out: list[int] = []
  54. for elem in self:
  55. out.append(elem)
  56. return out
  57. def unary(self: list[int]):
  58. return _copy(self)
  59. def broadcast_inplace(a: list[int], b: list[int]):
  60. dimsA = len(a)
  61. dimsB = len(b)
  62. if dimsB > dimsA:
  63. raise AssertionError(
  64. f"The dims of tensor b ({dimsB}) must be less than or equal to the dims of tensor a ({dimsA}) "
  65. )
  66. for dimA in range(dimsA):
  67. dimB = dimsB - dimsA + dimA
  68. sizeA = a[dimA]
  69. sizeB = b[dimB] if (dimB >= 0) else 1
  70. if sizeA != sizeB and sizeB != 1:
  71. # TODO: only assertion error is bound in C++ compilation right now
  72. raise AssertionError(
  73. "The size of tensor a {} must match the size of tensor b ("
  74. "{}) at non-singleton dimension {}".format(sizeA, sizeB, dimA)
  75. )
  76. return _copy(a)
  77. def expand(self: list[int], sizes: list[int]):
  78. assert len(sizes) >= len(self)
  79. ndim = len(sizes)
  80. tensor_dim = len(self)
  81. if ndim == 0:
  82. return _copy(sizes)
  83. out: list[int] = []
  84. for i in range(ndim):
  85. offset = ndim - 1 - i
  86. dim = tensor_dim - 1 - offset
  87. size = self[dim] if dim >= 0 else 1
  88. targetSize = sizes[i]
  89. if targetSize == -1:
  90. assert dim >= 0
  91. targetSize = size
  92. if size != targetSize:
  93. assert size == 1
  94. size = targetSize
  95. out.append(size)
  96. return out
  97. def expand_one_unused(self: list[int], sizes: list[int], inp0: Any):
  98. return expand(self, sizes)
  99. def infer_size_impl(shape: list[int], numel: int) -> list[int]:
  100. newsize = 1
  101. infer_dim: Optional[int] = None
  102. for dim in range(len(shape)):
  103. if shape[dim] == -1:
  104. if infer_dim is not None:
  105. raise AssertionError("only one dimension can be inferred")
  106. infer_dim = dim
  107. elif shape[dim] >= 0:
  108. newsize *= shape[dim]
  109. else:
  110. raise AssertionError("invalid shape dimensions")
  111. if not (
  112. numel == newsize
  113. or (infer_dim is not None and newsize > 0 and numel % newsize == 0)
  114. ):
  115. raise AssertionError("invalid shape")
  116. out = _copy(shape)
  117. if infer_dim is not None:
  118. out[infer_dim] = numel // newsize
  119. return out
  120. def numel(sizes: list[int]):
  121. numel = 1
  122. for elem in sizes:
  123. numel *= elem
  124. return numel
  125. def view(self: list[int], sizes: list[int]):
  126. return infer_size_impl(sizes, numel(self))
  127. def view_one_unused(self: list[int], sizes: list[int], *, implicit: bool = False):
  128. return view(self, sizes)
  129. def sum_mean_dim(
  130. self: list[int], opt_dims: Optional[list[int]], keep_dim: bool, dt: Any
  131. ):
  132. out: list[int] = []
  133. if opt_dims is None or len(opt_dims) == 0:
  134. dims: list[int] = list(range(len(self)))
  135. else:
  136. dims = opt_dims
  137. for idx in range(len(self)):
  138. is_mean_dim: bool = False
  139. for reduce_dim in dims:
  140. if idx == maybe_wrap_dim(reduce_dim, len(self)):
  141. is_mean_dim = True
  142. if is_mean_dim:
  143. if keep_dim:
  144. out.append(1)
  145. else:
  146. out.append(self[idx])
  147. return out
  148. def max_dim(self: list[int], dim: int, keep_dim: bool):
  149. out = sum_mean_dim(self, [dim], keep_dim, None)
  150. return out, out
  151. # note: python already rounds down towards negative infinity on integer division, special arithmetic not needed
  152. def div_rtn(x: int, y: int):
  153. return x // y
  154. def pooling_output_shape_pad_lr(
  155. inputSize: int,
  156. kernelSize: int,
  157. pad_l: int,
  158. pad_r: int,
  159. stride: int,
  160. dilation: int,
  161. ceil_mode: bool,
  162. ):
  163. outputSize = (
  164. div_rtn(
  165. inputSize
  166. + pad_l
  167. + pad_r
  168. - dilation * (kernelSize - 1)
  169. - 1
  170. + (stride - 1 if ceil_mode else 0),
  171. stride,
  172. )
  173. + 1
  174. )
  175. if ceil_mode:
  176. if (outputSize - 1) * stride >= inputSize + pad_l:
  177. outputSize = outputSize - 1
  178. return outputSize
  179. def pooling_output_shape(
  180. inputSize: int,
  181. kernelSize: int,
  182. pad_l: int,
  183. stride: int,
  184. dilation: int,
  185. ceil_mode: bool,
  186. ):
  187. assert stride != 0, "stride should not be zeero"
  188. return pooling_output_shape_pad_lr(
  189. inputSize, kernelSize, pad_l, pad_l, stride, dilation, ceil_mode
  190. )
  191. def pool2d_shape_check(
  192. input: list[int],
  193. kH: int,
  194. kW: int,
  195. dH: int,
  196. dW: int,
  197. padH: int,
  198. padW: int,
  199. dilationH: int,
  200. dilationW: int,
  201. nInputPlane: int,
  202. inputHeight: int,
  203. inputWidth: int,
  204. outputHeight: int,
  205. outputWidth: int,
  206. ):
  207. ndim = len(input)
  208. assert kW > 0 and kH > 0
  209. assert dW > 0 and dH > 0
  210. assert dilationH > 0 and dilationW > 0
  211. valid_dims = input[1] != 0 and input[2] != 0
  212. assert (
  213. ndim == 3
  214. and input[0] != 0
  215. and valid_dims
  216. or (ndim == 4 and valid_dims and input[3] != 0)
  217. )
  218. assert kW // 2 >= padW and kH // 2 >= padH
  219. assert outputWidth >= 1 and outputHeight >= 1
  220. def max_pool2d(
  221. input: list[int],
  222. kernel_size: list[int],
  223. stride: list[int],
  224. padding: list[int],
  225. dilation: list[int],
  226. ceil_mode: bool,
  227. ):
  228. assert len(kernel_size) == 1 or len(kernel_size) == 2, (
  229. "max_pool2d: kernel_size must either be a single int, or a tuple of two ints"
  230. )
  231. kH = kernel_size[0]
  232. kW = kH if len(kernel_size) == 1 else kernel_size[1]
  233. assert len(stride) == 0 or len(stride) == 1 or len(stride) == 2, (
  234. "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
  235. )
  236. dH = kH if len(stride) == 0 else stride[0]
  237. if len(stride) == 0:
  238. dW = kW
  239. elif len(stride) == 1:
  240. dW = dH
  241. else:
  242. dW = stride[1]
  243. assert len(padding) == 1 or len(padding) == 2, (
  244. "max_pool2d: padding must either be a single int, or a tuple of two ints"
  245. )
  246. padH = padding[0]
  247. padW = padH if len(padding) == 1 else padding[1]
  248. assert len(dilation) == 1 or len(dilation) == 2, (
  249. "max_pool2d: dilation must be either a single int, or a tuple of two ints"
  250. )
  251. dilationH = dilation[0]
  252. dilationW = dilationH if len(dilation) == 1 else dilation[1]
  253. assert len(input) == 3 or len(input) == 4
  254. nbatch = input[-4] if len(input) == 4 else 1
  255. nInputPlane = input[-3]
  256. inputHeight = input[-2]
  257. inputWidth = input[-1]
  258. outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
  259. outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)
  260. pool2d_shape_check(
  261. input,
  262. kH,
  263. kW,
  264. dH,
  265. dW,
  266. padH,
  267. padW,
  268. dilationH,
  269. dilationW,
  270. nInputPlane,
  271. inputHeight,
  272. inputWidth,
  273. outputHeight,
  274. outputWidth,
  275. )
  276. if len(input) == 3:
  277. return [nInputPlane, outputHeight, outputWidth]
  278. else:
  279. return [nbatch, nInputPlane, outputHeight, outputWidth]
  280. def max_pool2d_with_indices(
  281. input: list[int],
  282. kernel_size: list[int],
  283. stride: list[int],
  284. padding: list[int],
  285. dilation: list[int],
  286. ceil_mode: bool,
  287. ):
  288. out = max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
  289. return (out, out)
  290. def upsample_nearest2d(
  291. input: list[int],
  292. output_size: Optional[list[int]],
  293. scale_factors: Optional[list[float]],
  294. ):
  295. out: list[int] = []
  296. out.append(input[0])
  297. out.append(input[1])
  298. if scale_factors is None and output_size is None:
  299. assert 0, "Either output_size or scale_factors must be presented"
  300. if output_size is not None:
  301. assert scale_factors is None, (
  302. "Must specify exactly one of output_size and scale_factors"
  303. )
  304. assert len(output_size) == 2
  305. out.append(output_size[0])
  306. out.append(output_size[1])
  307. if scale_factors is not None:
  308. assert output_size is None, (
  309. "Must specify exactly one of output_size and scale_factors"
  310. )
  311. assert len(scale_factors) == 2
  312. out.append(int(input[2] * scale_factors[0]))
  313. out.append(int(input[3] * scale_factors[1]))
  314. return out
  315. def mm(self: list[int], mat2: list[int]):
  316. assert len(self) == 2, "self must be a matrix"
  317. assert len(mat2) == 2, "mat2 must be a matrix"
  318. assert self[1] == mat2[0]
  319. return [self[0], mat2[1]]
  320. def dot(self: list[int], tensor: list[int]):
  321. assert len(self) == 1 and len(tensor) == 1
  322. assert self[0] == tensor[0]
  323. out: list[int] = []
  324. return out
  325. def mv(self: list[int], vec: list[int]):
  326. assert len(self) == 2 and len(vec) == 1
  327. assert self[1] == vec[0]
  328. # TODO: return self
  329. return [self[0]]
  330. def unsqueeze(li: list[int], dim: int):
  331. dim = maybe_wrap_dim(dim, len(li) + 1)
  332. out = _copy(li)
  333. out.insert(dim, 1)
  334. return out
  335. def squeeze_nodim(li: list[int]):
  336. out: list[int] = []
  337. for i in range(len(li)):
  338. if li[i] != 1:
  339. out.append(li[i])
  340. return out
  341. def squeeze(li: list[int], dim: int):
  342. out: list[int] = []
  343. wrapped_dim = maybe_wrap_dim(dim, len(li))
  344. for i in range(len(li)):
  345. if i == wrapped_dim:
  346. if li[i] != 1:
  347. out.append(li[i])
  348. else:
  349. out.append(li[i])
  350. return out
  351. def squeeze_dims(li: list[int], dims: list[int]):
  352. if len(dims) == 0:
  353. return li
  354. wrapped_dims = _copy(dims)
  355. for i in range(len(dims)):
  356. wrapped_dims[i] = maybe_wrap_dim(wrapped_dims[i], len(li))
  357. result: list[int] = []
  358. for i in range(len(li)):
  359. if li[i] == 1:
  360. if i not in wrapped_dims:
  361. result.append(li[i])
  362. else:
  363. result.append(li[i])
  364. return result
  365. def index_select(self: list[int], dim: int, index: list[int]):
  366. dim = maybe_wrap_dim(dim, len(self))
  367. numel = multiply_integers(index)
  368. assert len(index) <= 1
  369. assert dim == 0 or dim < len(self)
  370. result_size: list[int] = []
  371. for i in range(len(self)):
  372. if dim == i:
  373. result_size.append(numel)
  374. else:
  375. result_size.append(self[i])
  376. return result_size
  377. def embedding(
  378. weight: list[int],
  379. indices: list[int],
  380. padding_idx: int = -1,
  381. scale_grad_by_freq: bool = False,
  382. sparse: bool = False,
  383. ):
  384. assert len(weight) == 2
  385. if len(indices) == 1:
  386. return index_select(weight, 0, indices)
  387. size = _copy(indices)
  388. size.append(weight[1])
  389. return size
  390. def max_int():
  391. return 9223372036854775807
  392. def slice(
  393. self: list[int], dim: int, start: Optional[int], end: Optional[int], step: int
  394. ):
  395. ndim = len(self)
  396. assert ndim != 0
  397. dim = maybe_wrap_dim(dim, ndim)
  398. start_val = start if start is not None else 0
  399. end_val = end if end is not None else max_int()
  400. assert step > 0
  401. if start_val == max_int():
  402. start_val = 0
  403. if start_val < 0:
  404. start_val += self[dim]
  405. if end_val < 0:
  406. end_val += self[dim]
  407. if start_val < 0:
  408. start_val = 0
  409. elif start_val > self[dim]:
  410. start_val = self[dim]
  411. if end_val < start_val:
  412. end_val = start_val
  413. elif end_val >= self[dim]:
  414. end_val = self[dim]
  415. slice_len = end_val - start_val
  416. out = _copy(self)
  417. out[dim] = (slice_len + step - 1) // step
  418. return out
  419. def check_cat_no_zero_dim(tensors: list[list[int]]):
  420. for tensor in tensors:
  421. assert len(tensor) > 0
  422. def legacy_cat_wrap_dim(dim: int, tensor_sizes: list[list[int]]):
  423. out_dim: Optional[int] = None
  424. for size in tensor_sizes:
  425. if not (len(size) == 1 and size[0] == 0):
  426. if out_dim is None:
  427. out_dim = maybe_wrap_dim(dim, len(size))
  428. if out_dim is None:
  429. out_dim = dim
  430. return out_dim
  431. def should_skip(tensor: list[int]):
  432. return numel(tensor) == 0 and len(tensor) == 1
  433. def check_cat_shape_except_dim(
  434. first: list[int], second: list[int], dimension: int, index: int
  435. ):
  436. first_dims = len(first)
  437. second_dims = len(second)
  438. assert first_dims == second_dims, "Tensors must have same number of dimensions"
  439. for dim in range(0, first_dims):
  440. if dim != dimension:
  441. assert first[dim] == second[dim], (
  442. "Sizes of tensors must match except in dimension"
  443. )
  444. def cat(tensors: list[list[int]], dim: int):
  445. check_cat_no_zero_dim(tensors)
  446. dim = legacy_cat_wrap_dim(dim, tensors)
  447. assert len(tensors) > 0
  448. not_skipped_tensor: Optional[list[int]] = None
  449. for tensor in tensors:
  450. if not should_skip(tensor):
  451. not_skipped_tensor = tensor
  452. if not_skipped_tensor is None:
  453. return [0]
  454. cat_dim_size = 0
  455. for i in range(len(tensors)):
  456. tensor = tensors[i]
  457. if not should_skip(tensor):
  458. check_cat_shape_except_dim(not_skipped_tensor, tensor, dim, i)
  459. cat_dim_size = cat_dim_size + tensor[dim]
  460. result_size = _copy(not_skipped_tensor)
  461. result_size[dim] = cat_dim_size
  462. return result_size
  463. def stack(tensors: list[list[int]], dim: int):
  464. unsqueezed_tensors: list[list[int]] = []
  465. for tensor in tensors:
  466. unsqueezed = unsqueeze(tensor, dim)
  467. unsqueezed_tensors.append(unsqueezed)
  468. return cat(unsqueezed_tensors, dim)
  469. def select(self: list[int], dim: int, index: int):
  470. ndim = len(self)
  471. assert ndim != 0
  472. dim = maybe_wrap_dim(dim, ndim)
  473. size = self[dim]
  474. assert not (index < -size or index >= size)
  475. if index < 0:
  476. index += size
  477. out: list[int] = []
  478. for i in range(ndim):
  479. if i != dim:
  480. out.append(self[i])
  481. return out
  482. def matmul(tensor1: list[int], tensor2: list[int]):
  483. dim_tensor1 = len(tensor1)
  484. dim_tensor2 = len(tensor2)
  485. if dim_tensor1 == 1 and dim_tensor2 == 1:
  486. return dot(tensor1, tensor2)
  487. elif dim_tensor1 == 2 and dim_tensor2 == 1:
  488. return mv(tensor1, tensor2)
  489. elif dim_tensor1 == 1 and dim_tensor2 == 2:
  490. return squeeze(mm(unsqueeze(tensor1, 0), tensor2), 0)
  491. elif dim_tensor1 == 2 and dim_tensor2 == 2:
  492. return mm(tensor1, tensor2)
  493. elif dim_tensor1 >= 1 and dim_tensor2 >= 1:
  494. # We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list);
  495. # we track m1 vs m2 separately even though they must match for nicer error messages
  496. n = tensor1[-2] if dim_tensor1 > 1 else 1
  497. batch_tensor1: list[int] = []
  498. # TODO: handling of slice
  499. for i in range(dim_tensor1 - 2):
  500. batch_tensor1.append(tensor1[i])
  501. p = tensor2[-1]
  502. batch_tensor2: list[int] = []
  503. # TODO: handling of slice
  504. for i in range(dim_tensor2 - 2):
  505. batch_tensor2.append(tensor2[i])
  506. # expand the batch portion (i.e. cut off matrix dimensions and expand rest)
  507. expand_batch_portion = broadcast(batch_tensor1, batch_tensor2)
  508. # todo: copy ?
  509. output_shape = expand_batch_portion
  510. if dim_tensor1 > 1:
  511. output_shape.append(n)
  512. if dim_tensor2 > 1:
  513. output_shape.append(p)
  514. return output_shape
  515. else:
  516. assert False, "both arguments to matmul need to be at least 1D"
  517. def t(self: list[int]):
  518. assert len(self) <= 2
  519. self_len = len(self)
  520. if self_len == 0:
  521. out: list[int] = []
  522. return out
  523. elif self_len == 1:
  524. return [self[0]]
  525. else:
  526. return [self[1], self[0]]
  527. def transpose(self: list[int], dim0: int, dim1: int):
  528. ndims = len(self)
  529. dim0 = maybe_wrap_dim(dim0, ndims)
  530. dim1 = maybe_wrap_dim(dim1, ndims)
  531. if dim0 == dim1:
  532. return _copy(self)
  533. out: list[int] = []
  534. for i in range(ndims):
  535. if i == dim0:
  536. out.append(self[dim1])
  537. elif i == dim1:
  538. out.append(self[dim0])
  539. else:
  540. out.append(self[i])
  541. return out
  542. def linear(input: list[int], weight: list[int], bias: Optional[list[int]]):
  543. out = matmul(input, t(weight))
  544. if bias is not None:
  545. assert broadcast(bias, out) == out
  546. return out
  547. def addmm(self: list[int], mat1: list[int], mat2: list[int], beta: Any, alpha: Any):
  548. return broadcast(self, mm(mat1, mat2))
  549. def check_non_negative(array: list[int]) -> bool:
  550. # TODO: look into rewriting with early return and getting loop unrolling to fire
  551. non_negative = False
  552. for val in array:
  553. if val < 0:
  554. non_negative = True
  555. return non_negative
  556. def check_shape_forward(
  557. input: list[int],
  558. weight_sizes: list[int],
  559. bias: Optional[list[int]],
  560. stride: list[int],
  561. padding: list[int],
  562. dilation: list[int],
  563. groups: int,
  564. ):
  565. k = len(input)
  566. weight_dim = len(weight_sizes)
  567. # TODO: assertions could be expanded with the error messages
  568. assert not check_non_negative(padding)
  569. assert not check_non_negative(stride)
  570. assert weight_dim == k
  571. assert weight_sizes[0] >= groups
  572. assert (weight_sizes[0] % groups) == 0
  573. # only handling not transposed
  574. assert input[1] == weight_sizes[1] * groups
  575. assert bias is None or (len(bias) == 1 and bias[0] == weight_sizes[0])
  576. for i in range(2, k):
  577. assert (input[i] + 2 * padding[i - 2]) >= (
  578. dilation[i - 2] * (weight_sizes[i] - 1) + 1
  579. )
  580. # this is not handling transposed convolution yet
  581. def conv_output_size(
  582. input_size: list[int],
  583. weight_size: list[int],
  584. bias: Optional[list[int]],
  585. stride: list[int],
  586. padding: list[int],
  587. dilation: list[int],
  588. groups: int,
  589. ):
  590. check_shape_forward(
  591. input_size, weight_size, bias, stride, padding, dilation, groups
  592. )
  593. has_dilation = len(dilation) > 0
  594. dim = len(input_size)
  595. output_size: list[int] = []
  596. input_batch_size_dim = 0
  597. weight_output_channels_dim = 0
  598. output_size.append(input_size[input_batch_size_dim])
  599. output_size.append(weight_size[weight_output_channels_dim])
  600. for d in range(2, dim):
  601. dilation_ = dilation[d - 2] if has_dilation else 1
  602. kernel = dilation_ * (weight_size[d] - 1) + 1
  603. output_size.append(
  604. (input_size[d] + (2 * padding[d - 2]) - kernel) // stride[d - 2] + 1
  605. )
  606. return output_size
  607. def conv1d(
  608. input: list[int],
  609. weight: list[int],
  610. bias: Optional[list[int]],
  611. stride: list[int],
  612. padding: list[int],
  613. dilation: list[int],
  614. groups: int,
  615. ):
  616. assert len(weight) == 3
  617. assert len(input) == 3
  618. return conv_output_size(input, weight, bias, stride, padding, dilation, groups)
  619. def conv2d(
  620. input: list[int],
  621. weight: list[int],
  622. bias: Optional[list[int]],
  623. stride: list[int],
  624. padding: list[int],
  625. dilation: list[int],
  626. groups: int,
  627. ):
  628. assert len(weight) == 4
  629. assert len(input) == 4
  630. return conv_output_size(input, weight, bias, stride, padding, dilation, groups)
  631. def conv_backwards(
  632. grad_output: list[int],
  633. input: list[int],
  634. weight: list[int],
  635. biases: Optional[list[int]],
  636. ):
  637. # Bias gradient is always generated regardess of if biases is supplied
  638. return _copy(input), _copy(weight), [grad_output[1]]
  639. def conv_transpose2d_input(
  640. input: list[int],
  641. weight: list[int],
  642. bias: Optional[list[int]] = None,
  643. stride: Optional[list[int]] = None,
  644. padding: Optional[list[int]] = None,
  645. output_padding: Optional[list[int]] = None,
  646. groups: int = 1,
  647. dilation: Optional[list[int]] = None,
  648. ) -> list[int]:
  649. if stride is None:
  650. stride = [1, 1]
  651. if padding is None:
  652. padding = [0, 0]
  653. if output_padding is None:
  654. output_padding = [0, 0]
  655. if dilation is None:
  656. dilation = [1, 1]
  657. has_dilation = len(dilation) > 0
  658. dim = len(input)
  659. output_size: list[int] = []
  660. input_batch_size_dim = 0
  661. weight_output_channels_dim = 1
  662. output_size.append(input[input_batch_size_dim])
  663. output_size.append(weight[weight_output_channels_dim] * groups)
  664. for d in range(2, dim):
  665. dilation_ = dilation[d - 2] if has_dilation else 1
  666. kernel = dilation_ * (weight[d] - 1)
  667. output_size.append(
  668. (input[d] - 1) * stride[d - 2]
  669. - 2 * padding[d - 2]
  670. + kernel
  671. + output_padding[d - 2]
  672. + 1
  673. )
  674. return output_size
  675. def conv_forwards(
  676. input: list[int],
  677. weight: list[int],
  678. bias: Optional[list[int]],
  679. stride: list[int],
  680. padding: list[int],
  681. dilation: list[int],
  682. transposed: bool,
  683. output_padding: list[int],
  684. groups: int,
  685. ) -> list[int]:
  686. has_dilation = len(dilation) > 0
  687. has_output_padding = len(output_padding) > 0
  688. dim = len(input)
  689. output_size: list[int] = []
  690. input_batch_size_dim = 0
  691. weight_output_channels_dim = 1 if transposed else 0
  692. output_size.append(input[input_batch_size_dim])
  693. if transposed:
  694. output_size.append(weight[weight_output_channels_dim] * groups)
  695. else:
  696. output_size.append(weight[weight_output_channels_dim])
  697. for d in range(2, dim):
  698. dilation_ = dilation[d - 2] if has_dilation else 1
  699. output_padding_ = output_padding[d - 2] if has_output_padding else 0
  700. if transposed:
  701. kernel = dilation_ * (weight[d] - 1)
  702. output_size.append(
  703. (input[d] - 1) * stride[d - 2]
  704. - 2 * padding[d - 2]
  705. + kernel
  706. + output_padding_
  707. + 1
  708. )
  709. else:
  710. kernel = dilation_ * (weight[d] - 1) + 1
  711. output_size.append(
  712. (input[d] + (2 * padding[d - 2]) - kernel) // stride[d - 2] + 1
  713. )
  714. return output_size
  715. def _conv_forwards(
  716. input: list[int],
  717. weight: list[int],
  718. bias: Optional[list[int]],
  719. stride: list[int],
  720. padding: list[int],
  721. dilation: list[int],
  722. transposed: bool,
  723. output_padding: list[int],
  724. groups: int,
  725. benchmark: bool,
  726. deterministic: bool,
  727. cudnn_enabled: bool,
  728. allow_tf32: bool,
  729. ) -> list[int]:
  730. return conv_forwards(
  731. input,
  732. weight,
  733. bias,
  734. stride,
  735. padding,
  736. dilation,
  737. transposed,
  738. output_padding,
  739. groups,
  740. )
  741. def batch_norm(
  742. input: list[int],
  743. weight: Optional[list[int]],
  744. bias: Optional[list[int]],
  745. running_mean: Optional[list[int]],
  746. running_var: Optional[list[int]],
  747. training: bool,
  748. momentum: float,
  749. eps: float,
  750. cudnn_enabled: bool,
  751. ):
  752. out: list[int] = []
  753. for elem in input:
  754. out.append(elem)
  755. return out
  756. def conv3d(
  757. input: list[int],
  758. weight: list[int],
  759. bias: Optional[list[int]],
  760. stride: list[int],
  761. padding: list[int],
  762. dilation: list[int],
  763. groups: int,
  764. ):
  765. assert len(weight) == 5
  766. assert len(input) == 5
  767. return conv_output_size(input, weight, bias, stride, padding, dilation, groups)
  768. def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
  769. if dim_post_expr <= 0:
  770. assert wrap_scalar
  771. dim_post_expr = 1
  772. min = -dim_post_expr
  773. max = dim_post_expr - 1
  774. assert not (dim < min or dim > max)
  775. if dim < 0:
  776. dim += dim_post_expr
  777. return dim
  778. def zero_dim_tensor(input: Any):
  779. out: list[int] = []
  780. return out
  781. def multiply_integers(li: list[int]):
  782. out = 1
  783. for elem in li:
  784. out = out * elem
  785. return out
  786. def arange_end(end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any):
  787. assert end >= 0
  788. return [int(math.ceil(end))]
  789. def arange_start(
  790. start: number, end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any
  791. ):
  792. assert end >= 0
  793. assert end >= start
  794. return [int(math.ceil(end - start))]
  795. def arange_start_step(
  796. start: number, end: number, step: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any
  797. ):
  798. assert step != 0
  799. if step < 0:
  800. assert start >= end
  801. else:
  802. assert end >= start
  803. return [int(math.ceil((end - start) / step))]
  804. def permute(input: list[int], dims: list[int]):
  805. assert len(input) == len(dims)
  806. ndim = len(dims)
  807. seen_dims: list[int] = []
  808. newSizes: list[int] = []
  809. for i in range(ndim):
  810. dim = maybe_wrap_dim(dims[i], ndim)
  811. seen_dims.append(dim)
  812. newSizes.append(input[dim])
  813. for i in range(1, ndim):
  814. for j in range(i):
  815. assert seen_dims[i] != seen_dims[j]
  816. return newSizes
def movedim(self: list[int], source: list[int], destination: list[int]) -> list[int]:
    """Shape function for movedim with integer-list source/destination.

    Builds a full permutation ``order`` in which ``order[destination[i]]`` is
    ``source[i]`` and the remaining positions are filled with the remaining
    source dimensions in ascending order, then applies it via ``permute``.

    Args:
        self: Input shape.
        source: Dimensions to move (negative indices are wrapped).
        destination: Target position for each entry of ``source``; indexed
            in lock-step with ``source``.

    Returns:
        The permuted shape. For inputs with at most one dimension the
        original ``self`` list is returned unchanged (not a copy).
    """
    self_dim = len(self)
    # Nothing can be moved in a 0-d or 1-d shape.
    if self_dim <= 1:
        return self
    normalized_src: list[int] = []
    normalized_dst: list[int] = []
    # Wrap negative indices; iteration count is driven by len(source).
    for i in range(len(source)):
        normalized_src.append(maybe_wrap_dim(source[i], self_dim))
        normalized_dst.append(maybe_wrap_dim(destination[i], self_dim))
    # order[d] == -1 means output position d has not been assigned yet.
    order = [-1 for i in range(self_dim)]
    # These start as the identity and get -1 written over every dimension
    # that is explicitly moved, leaving only the untouched ones.
    src_dims = [i for i in range(self_dim)]
    dst_dims = [i for i in range(self_dim)]

    for i in range(len(source)):
        order[normalized_dst[i]] = normalized_src[i]
        src_dims[normalized_src[i]] = -1
        dst_dims[normalized_dst[i]] = -1

    # Collect the source dims / destination slots that were not consumed
    # above, preserving ascending order.
    source_dims: list[int] = []
    destination_dims: list[int] = []
    for ele in src_dims:
        if ele != -1:
            source_dims.append(ele)
    for ele in dst_dims:
        if ele != -1:
            destination_dims.append(ele)

    # Pair the leftover slots with the leftover source dims in order.
    rest_dim = self_dim - len(source)
    for i in range(rest_dim):
        order[destination_dims[i]] = source_dims[i]
    # permute() also validates that `order` contains no duplicates.
    return permute(self, order)
  845. def flatten(input: list[int], start_dim: int, end_dim: int):
  846. start_dim = maybe_wrap_dim(start_dim, len(input))
  847. end_dim = maybe_wrap_dim(end_dim, len(input))
  848. assert start_dim <= end_dim
  849. if len(input) == 0:
  850. return [1]
  851. if start_dim == end_dim:
  852. # TODO: return self
  853. out: list[int] = []
  854. for elem in input:
  855. out.append(elem)
  856. return out
  857. slice_numel = 1
  858. for i in range(start_dim, end_dim + 1):
  859. slice_numel *= input[i]
  860. # TODO: use slicing when slice optimization has landed
  861. # slice_numel = multiply_integers(input[start_dim:end_dim - start_dim + 1])
  862. shape: list[int] = []
  863. for i in range(start_dim):
  864. shape.append(input[i])
  865. shape.append(slice_numel)
  866. for i in range(end_dim + 1, len(input)):
  867. shape.append(input[i])
  868. return shape
  869. def nonzero_lower_bound(input: list[int]):
  870. return [0, len(input)]
  871. def nonzero_upper_bound(input: list[int]):
  872. return [numel(input), len(input)]
  873. def _reduce_along_dim(self: list[int], dim: int, keepdim: bool):
  874. dim = maybe_wrap_dim(dim, len(self))
  875. out: list[int] = []
  876. for i, self_dim in enumerate(self):
  877. if i == dim:
  878. if keepdim:
  879. out.append(1)
  880. else:
  881. out.append(self_dim)
  882. return out
  883. def argmax(
  884. self: list[int], dim: Optional[int] = None, keepdim: bool = False
  885. ) -> list[int]:
  886. if dim is None:
  887. return []
  888. return _reduce_along_dim(self, dim, keepdim)
  889. def bmm(self: list[int], mat2: list[int]) -> list[int]:
  890. assert len(self) == 3, "bmm only supports 3D tensors"
  891. assert len(mat2) == 3, "bmm only supports 3D tensors"
  892. assert self[0] == mat2[0], "mismatching batch dimension"
  893. assert self[2] == mat2[1], "mismatching contracting dimension"
  894. return [self[0], self[1], mat2[2]]
  895. def _shape_as_tensor(self: list[int]) -> list[int]:
  896. return [len(self)]
  897. def topk(self: list[int], k: int, dim: int = -1) -> tuple[list[int], list[int]]:
  898. if len(self) == 0:
  899. result: list[int] = []
  900. else:
  901. assert k <= self[dim], (
  902. f"k ({k}) is too big for dimension {dim} of size {self[dim]}"
  903. )
  904. result = _copy(self)
  905. result[dim] = k
  906. return result, result
  907. def nll_loss_forward(
  908. self: list[int], target: list[int], weight: Optional[list[int]], reduction: int
  909. ) -> tuple[list[int], list[int]]:
  910. # This is taken shamelessly from the meta function in LossNLL.cpp
  911. self_dim = len(self)
  912. target_dim = len(target)
  913. assert 0 < self_dim <= 2
  914. assert target_dim <= 1
  915. no_batch_dim = self_dim == 1 and target_dim == 0
  916. assert no_batch_dim or (self[0] == target[0])
  917. n_classes = self[-1]
  918. scalar_shape: list[int] = []
  919. assert weight is None or (len(weight) == 1 and weight[0] == n_classes)
  920. if reduction == 0 and self_dim == 2:
  921. reduction_shape = [self[0]]
  922. else:
  923. reduction_shape = scalar_shape
  924. return reduction_shape, scalar_shape
  925. def native_layer_norm(
  926. input: list[int], normalized_shape: list[int]
  927. ) -> tuple[list[int], list[int], list[int]]:
  928. reduction_shape: list[int] = []
  929. num_unreduced_dimensions = len(input) - len(normalized_shape)
  930. assert num_unreduced_dimensions >= 0
  931. for i in range(num_unreduced_dimensions):
  932. reduction_shape.append(input[i])
  933. for i in range(num_unreduced_dimensions, len(input)):
  934. reduction_shape.append(1)
  935. return _copy(input), reduction_shape, reduction_shape
  936. def native_batch_norm(
  937. input: list[int],
  938. weight: Optional[list[int]],
  939. bias: Optional[list[int]],
  940. running_mean: Optional[list[int]],
  941. running_var: Optional[list[int]],
  942. training: bool,
  943. ) -> tuple[list[int], list[int], list[int]]:
  944. if training:
  945. _size = [input[1]]
  946. else:
  947. _size = [0]
  948. return _copy(input), _size, _size
  949. def _batch_norm_with_update(
  950. input: list[int],
  951. weight: Optional[list[int]],
  952. bias: Optional[list[int]],
  953. running_mean: Optional[list[int]],
  954. running_var: Optional[list[int]],
  955. ) -> tuple[list[int], list[int], list[int], list[int]]:
  956. _size = [input[1]]
  957. return _copy(input), _size, _size, [0]
  958. def cross_entropy_loss(
  959. self: list[int],
  960. target: list[int],
  961. weight: Optional[list[int]] = None,
  962. reduction: int = 1,
  963. ignore_index: int = -100,
  964. label_smoothing: float = 0.0,
  965. ) -> list[int]:
  966. result_shape = nll_loss_forward(self, target, weight, reduction)[0]
  967. return result_shape
  968. """
Currently deferring the enabling of this, as part of the proposal to suspend
adding ops.
There are currently cases in the test case where this is being called
in the SSA opinfo tests with unexpected values (e.g. a list of two ints, see the first
opinfo test). The behavior of index is significantly dependent on the inputs.
This could be an error with how we are matching up shape functions, or that this
function needs to just implement everything.
  976. def index_Tensor(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
  977. assert len(indices) <= len(self), "More indices than dimensions to index"
  978. broadcasted_shape: List[int] = []
  979. for index_tensor_shape in indices:
  980. if index_tensor_shape is not None:
  981. broadcasted_shape = broadcast(broadcasted_shape, index_tensor_shape)
  982. return broadcasted_shape
  983. """
# Short alias for the type of a compiled TorchScript function; used in the
# registry annotations below.
ScriptFn = torch._C.ScriptFunction
# Operator schema string -> scripted shape function.
# Populated by add_shape_compute_mapping().
shape_compute_graph_mapping: dict[str, ScriptFn] = {}
# Operator schema string -> (lower-bound function, upper-bound function).
# Populated by add_bounded_compute_mapping().
bounded_compute_graph_mapping: dict[str, tuple[ScriptFn, ScriptFn]] = {}
# Cache of already-scripted functions keyed by the original Python callable,
# so process_func() scripts each callable only once.
script_func_map: dict[Callable, ScriptFn] = {}
  988. def process_func(func: Callable):
  989. if func not in script_func_map:
  990. scripted_func = torch.jit.script(func)
  991. torch._C._jit_pass_inline(scripted_func.graph)
  992. for _ in range(2):
  993. torch._C._jit_pass_peephole(scripted_func.graph)
  994. torch._C._jit_pass_constant_propagation(scripted_func.graph)
  995. script_func_map[func] = scripted_func
  996. return script_func_map[func]
  997. def add_shape_compute_mapping(operator_schema: str, func: Callable):
  998. global shape_compute_graph_mapping
  999. shape_compute_graph_mapping[operator_schema] = process_func(func)
  1000. def add_bounded_compute_mapping(
  1001. operator_schema: str, lower_bound_func: Callable, upper_bound_func: Callable
  1002. ):
  1003. # Adds a shape compute function for both upper and lower bounds
  1004. fns = (process_func(lower_bound_func), process_func(upper_bound_func))
  1005. bounded_compute_graph_mapping[operator_schema] = fns
# ---------------------------------------------------------------------------
# Shape function registrations.
# Each call maps an operator schema string to the Python function that
# computes its output shape; the function is scripted at registration time.
# ---------------------------------------------------------------------------

# Shape-preserving ops, creation ops and scalar-to-tensor conversions.
add_shape_compute_mapping(
    "aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)",
    unary,
)
add_shape_compute_mapping(
    "aten::rsub.Tensor(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", unary
)
add_shape_compute_mapping(
    "aten::dropout(Tensor input, float p, bool train) -> Tensor", unary
)
add_shape_compute_mapping(
    "aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor",
    adaptive_avg_pool2d,
)
add_shape_compute_mapping(
    "prim::NumToTensor.Scalar(Scalar a) -> Tensor", zero_dim_tensor
)
add_shape_compute_mapping("prim::NumToTensor.bool(bool a) -> Tensor", zero_dim_tensor)
add_shape_compute_mapping(
    "aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)",
    unary,
)
add_shape_compute_mapping(
    "aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))",
    unary,
)
# arange overloads: one shape function per (end) / (start, end) /
# (start, end, step) signature.
add_shape_compute_mapping(
    "aten::arange(Scalar end, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)",
    arange_end,
)
add_shape_compute_mapping(
    "aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor",
    arange_start,
)
add_shape_compute_mapping(
    "aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor",
    arange_start_step,
)
# Squeeze / unsqueeze / slicing / selection.
add_shape_compute_mapping("aten::squeeze(Tensor(a) self) -> Tensor(a)", squeeze_nodim)
add_shape_compute_mapping(
    "aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)", squeeze
)
add_shape_compute_mapping(
    "aten::squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)", squeeze_dims
)
add_shape_compute_mapping(
    "aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)", unsqueeze
)
add_shape_compute_mapping(
    "aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)",
    slice,
)
add_shape_compute_mapping(
    "aten::select.int(Tensor(a) self, int dim, int index) -> Tensor(a)", select
)
add_shape_compute_mapping(
    "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor", index_select
)
# Normalization, softmax and embedding ops.
add_shape_compute_mapping(
    "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, "
    "float eps=1e-05, bool cudnn_enable=True) -> Tensor",
    unary,
)
add_shape_compute_mapping(
    "aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", unary
)
add_shape_compute_mapping(
    "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor",
    unary,
)
add_shape_compute_mapping(
    "aten::embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)",
    unary,
)
add_shape_compute_mapping(
    "aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor",
    embedding,
)
# Matrix products and linear layers.
add_shape_compute_mapping("aten::mm(Tensor self, Tensor mat2) -> Tensor", mm)
add_shape_compute_mapping("aten::dot(Tensor self, Tensor tensor) -> Tensor", dot)
add_shape_compute_mapping("aten::mv(Tensor self, Tensor vec) -> Tensor", mv)
add_shape_compute_mapping("aten::matmul(Tensor self, Tensor other) -> Tensor", matmul)
add_shape_compute_mapping(
    "aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor", linear
)
# Pooling.
add_shape_compute_mapping(
    "aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
    max_pool2d,
)
add_shape_compute_mapping(
    "aten::max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)",
    max_pool2d_with_indices,
)
# Transposition.
add_shape_compute_mapping("aten::t(Tensor(a) self) -> Tensor(a)", t)
add_shape_compute_mapping(
    "aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)", transpose
)
# Convolutions (forward, backward, transposed) and batch_norm.
add_shape_compute_mapping(
    "aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor",
    conv1d,
)
add_shape_compute_mapping(
    "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
    conv2d,
)
add_shape_compute_mapping(
    "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
    batch_norm,
)
add_shape_compute_mapping(
    "aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor",
    conv3d,
)
add_shape_compute_mapping(
    "aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, int[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)",
    conv_backwards,
)
add_shape_compute_mapping(
    "aten::convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor",
    conv_forwards,
)
add_shape_compute_mapping(
    "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor",
    _conv_forwards,
)
add_shape_compute_mapping(
    "aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int groups=1, int[2] dilation=1) -> Tensor",
    conv_transpose2d_input,
)
add_shape_compute_mapping(
    "aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)",
    flatten,
)
# Concatenation, permutation, view and expand.
add_shape_compute_mapping("aten::cat(Tensor[] tensors, int dim=0) -> Tensor", cat)
add_shape_compute_mapping("aten::stack(Tensor[] tensors, int dim=0) -> Tensor", stack)
add_shape_compute_mapping(
    "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", permute
)
add_shape_compute_mapping(
    "aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)",
    movedim,
)
add_shape_compute_mapping("aten::view(Tensor(a) self, int[] size) -> Tensor(a)", view)
add_shape_compute_mapping(
    "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", expand
)
add_shape_compute_mapping(
    "aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)",
    expand_one_unused,
)
# Reductions.
add_shape_compute_mapping(
    "aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
    sum_mean_dim,
)
add_shape_compute_mapping(
    "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
    sum_mean_dim,
)
add_shape_compute_mapping(
    "aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)",
    max_dim,
)
add_shape_compute_mapping(
    "aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor
)
add_shape_compute_mapping(
    "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor
)
add_shape_compute_mapping(
    "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor",
    addmm,
)
add_shape_compute_mapping(
    "aten::upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)",
    upsample_nearest2d,
)
# Quantization ops.
add_shape_compute_mapping(
    "aten::quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor",
    unary,
)
add_shape_compute_mapping(
    "aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor",
    unary,
)
add_shape_compute_mapping("aten::dequantize(Tensor self) -> Tensor", unary)
add_shape_compute_mapping(
    "quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc",
    broadcast,
)
add_shape_compute_mapping(
    "aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor", argmax
)
add_shape_compute_mapping("aten::bmm(Tensor self, Tensor mat2) -> Tensor", bmm)
add_shape_compute_mapping(
    "aten::_shape_as_tensor(Tensor self) -> Tensor", _shape_as_tensor
)
add_shape_compute_mapping(
    "aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)",
    topk,
)
# Losses and native normalization ops (multi-output).
add_shape_compute_mapping(
    "aten::nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight)",
    nll_loss_forward,
)
add_shape_compute_mapping(
    "aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)",
    native_layer_norm,
)
add_shape_compute_mapping(
    "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
    native_batch_norm,
)
add_shape_compute_mapping(
    "aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
    native_batch_norm,
)
# NOTE(review): this `.no_stats` schema lists the same arguments as the
# entry directly above, including running_mean/running_var — verify it
# against the operator's declared schema.
add_shape_compute_mapping(
    "aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
    native_batch_norm,
)
# NOTE(review): unlike the other schemas registered here, this one has no
# namespace prefix (e.g. `aten::`) — confirm that this is intentional.
add_shape_compute_mapping(
    "_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)",
    _batch_norm_with_update,
)
add_shape_compute_mapping(
    "aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor",
    cross_entropy_loss,
)
# Registration for aten::index.Tensor is intentionally left disabled:
# add_shape_compute_mapping("aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor", index_Tensor)

# TODO: migrate over all of symbolic_shape_registry_util.cpp
# These are duplicated here so that the functions will be serialized
add_shape_compute_mapping(
    "aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor",
    broadcast_three,
)
add_shape_compute_mapping(
    "aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor",
    broadcast_one_three,
)
add_shape_compute_mapping(
    "aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)",
    broadcast_inplace,
)

# quantized_conv_prepack TODO

# Shape Compute Fn with upper and lower bounds
# (registered as a (lower, upper) pair in bounded_compute_graph_mapping).
add_bounded_compute_mapping(
    "aten::nonzero(Tensor self) -> (Tensor)", nonzero_lower_bound, nonzero_upper_bound
)
  1249. # quantized_conv_prepack TODO
  1250. # Shape Compute Fn with upper and lower bounds
  1251. add_bounded_compute_mapping(
  1252. "aten::nonzero(Tensor self) -> (Tensor)", nonzero_lower_bound, nonzero_upper_bound
  1253. )