reference.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581
  1. # mypy: allow-untyped-defs
  2. import math
  3. import operator
  4. from typing import Union
  5. import sympy
  6. import torch
  7. from torch.utils._sympy.functions import (
  8. _keep_float,
  9. BitwiseFn_bitwise_and,
  10. BitwiseFn_bitwise_or,
  11. FloatPow,
  12. FloatTrueDiv,
  13. FloorDiv,
  14. IntTrueDiv,
  15. Max,
  16. Min,
  17. Mod,
  18. OpaqueUnaryFn_exp,
  19. OpaqueUnaryFn_log,
  20. OpaqueUnaryFn_log2,
  21. OpaqueUnaryFn_sqrt,
  22. PowByNatural,
  23. RoundDecimal,
  24. RoundToInt,
  25. ToFloat,
  26. TruncToInt,
  27. )
  28. # The sympy interpretation of operators. It will also sometimes work with
  29. # plain int/float, but if you do certain operations you will get out a
  30. # sympy.Basic in the end. If you want the Python/FX traceable interpretation,
  31. # check PythonReferenceAnalysis.
  32. # NB: For magic methods this needs to use normal magic methods
  33. # so that test_magic_methods works
  34. class ReferenceAnalysis:
  35. @staticmethod
  36. def constant(c, dtype):
  37. return sympy.sympify(c)
  38. @staticmethod
  39. def or_(a, b):
  40. return a | b
  41. @staticmethod
  42. def and_(a, b):
  43. return a & b
  44. @staticmethod
  45. def eq(a, b):
  46. if isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr):
  47. return sympy.Eq(a, b)
  48. return a == b
  49. @classmethod
  50. def ne(cls, a, b):
  51. return cls.not_(cls.eq(a, b))
  52. @staticmethod
  53. def lt(a, b):
  54. return a < b
  55. @staticmethod
  56. def gt(a, b):
  57. return a > b
  58. @staticmethod
  59. def le(a, b):
  60. return a <= b
  61. @staticmethod
  62. def ge(a, b):
  63. return a >= b
  64. @staticmethod
  65. def not_(a):
  66. assert not isinstance(a, bool)
  67. return ~a
  68. @staticmethod
  69. def reciprocal(x):
  70. return FloatTrueDiv(1.0, x)
  71. @staticmethod
  72. def square(x):
  73. return PowByNatural(x, 2)
  74. @staticmethod
  75. def trunc_to_int(x, dtype):
  76. return TruncToInt(x)
  77. @staticmethod
  78. def ceil_to_int(x, dtype):
  79. return sympy.ceiling(x)
  80. @staticmethod
  81. def floor_to_int(x, dtype):
  82. return sympy.floor(x)
  83. @staticmethod
  84. def floor(x):
  85. return _keep_float(sympy.floor)(x)
  86. @staticmethod
  87. def ceil(x):
  88. return _keep_float(sympy.ceiling)(x)
  89. @staticmethod
  90. def to_dtype(x, dtype):
  91. if dtype == torch.float64:
  92. return ToFloat(x)
  93. raise NotImplementedError(f"to_dtype {dtype} NYI")
  94. @staticmethod
  95. def mod(x, y):
  96. return Mod(x, y)
  97. @staticmethod
  98. def abs(x):
  99. return abs(x)
  100. @staticmethod
  101. def neg(x):
  102. return -x
  103. @staticmethod
  104. def truediv(a, b):
  105. return FloatTrueDiv(a, b)
  106. @staticmethod
  107. def int_truediv(a, b):
  108. return IntTrueDiv(a, b)
  109. @staticmethod
  110. def floordiv(a, b):
  111. return FloorDiv(a, b)
  112. @staticmethod
  113. def truncdiv(a, b):
  114. raise NotImplementedError("TODO: truncdiv")
  115. @staticmethod
  116. def add(a, b):
  117. return _keep_float(operator.add)(a, b)
  118. @classmethod
  119. def sym_sum(cls, args):
  120. return sympy.Add(*args)
  121. @staticmethod
  122. def mul(a, b):
  123. return _keep_float(operator.mul)(a, b)
  124. @staticmethod
  125. def sub(a, b):
  126. return _keep_float(operator.sub)(a, b)
  127. @staticmethod
  128. def exp(x):
  129. return OpaqueUnaryFn_exp(x)
  130. @staticmethod
  131. def log(x):
  132. return OpaqueUnaryFn_log(x)
  133. @staticmethod
  134. def log2(x):
  135. return OpaqueUnaryFn_log2(x)
  136. @staticmethod
  137. def sqrt(x):
  138. return OpaqueUnaryFn_sqrt(x)
  139. @staticmethod
  140. def pow(a, b):
  141. return _keep_float(FloatPow)(a, b)
  142. @staticmethod
  143. def pow_by_natural(a, b):
  144. return PowByNatural(a, b)
  145. @staticmethod
  146. def minimum(a, b):
  147. return Min(a, b)
  148. @staticmethod
  149. def maximum(a, b):
  150. return Max(a, b)
  151. @staticmethod
  152. def round_to_int(a, dtype):
  153. return RoundToInt(a)
  154. @staticmethod
  155. def round_decimal(a, b):
  156. return RoundDecimal(a, b)
  157. @staticmethod
  158. def bitwise_and(a, b):
  159. return BitwiseFn_bitwise_and(a, b)
  160. @staticmethod
  161. def bitwise_or(a, b):
  162. return BitwiseFn_bitwise_or(a, b)
  163. # Unlike ReferenceAnalysis, does NOT sympify; instead, it works with plain
  164. # Python types and is FX traceable. Inheritance here is purely for code
  165. # sharing (TODO: consider splitting out a BaseReferenceAnalysis).
  166. class PythonReferenceAnalysis(ReferenceAnalysis):
  167. @staticmethod
  168. def constant(c, dtype):
  169. if dtype is torch.int64:
  170. return int(c)
  171. elif dtype is torch.double:
  172. return float(c)
  173. elif dtype is torch.bool:
  174. return bool(c)
  175. else:
  176. raise AssertionError(f"unrecognized dtype {dtype}")
  177. @staticmethod
  178. def not_(a):
  179. return torch.sym_not(a)
  180. @classmethod
  181. def sym_sum(cls, args):
  182. if len(args) == 0:
  183. return 0
  184. if len(args) == 1:
  185. return args[0]
  186. acc = cls.add(args[0], args[1])
  187. for i in range(2, len(args)):
  188. acc = cls.add(acc, args[i])
  189. return acc
  190. @staticmethod
  191. def floordiv(a, b):
  192. return a // b
  193. @staticmethod
  194. def mod(x, y):
  195. return x % y
  196. @staticmethod
  197. def truncdiv(a, b):
  198. return a / b
  199. @staticmethod
  200. def to_dtype(x, dtype):
  201. if dtype == torch.float64:
  202. return torch.sym_float(x)
  203. raise NotImplementedError(f"to_dtype {dtype} NYI")
  204. @staticmethod
  205. def exp(x):
  206. raise AssertionError("exp is not valid shape sympy expr")
  207. @staticmethod
  208. def log(x):
  209. raise AssertionError("log is not valid shape sympy expr")
  210. @staticmethod
  211. def log2(x):
  212. return torch._sym_log2(x) # type: ignore[attr-defined]
  213. @staticmethod
  214. def sqrt(x):
  215. return torch._sym_sqrt(x) # type: ignore[attr-defined]
  216. @staticmethod
  217. def minimum(a, b):
  218. return torch.sym_min(a, b)
  219. @staticmethod
  220. def maximum(a, b):
  221. return torch.sym_max(a, b)
  222. @staticmethod
  223. def floor_to_int(x, dtype):
  224. return math.floor(x)
  225. @staticmethod
  226. def ceil_to_int(x, dtype):
  227. return math.ceil(x)
  228. @staticmethod
  229. def floor(x):
  230. return float(math.floor(x))
  231. @staticmethod
  232. def ceil(x):
  233. return float(math.ceil(x))
  234. @staticmethod
  235. def truediv(a, b):
  236. return a / b
  237. @staticmethod
  238. def pow(a, b):
  239. return a**b
  240. @staticmethod
  241. def pow_by_natural(a, b):
  242. # Pray that safe_pow is not needed here lol. In particular, this
  243. # never participates in VR low/high ranges, so overflow should be
  244. # unlikely
  245. return a**b
  246. @staticmethod
  247. def round_to_int(a, dtype):
  248. return round(a)
  249. @staticmethod
  250. def round_decimal(a, b):
  251. return round(a, ndigits=b)
  252. @staticmethod
  253. def bitwise_and(a, b):
  254. return a & b
  255. @staticmethod
  256. def bitwise_or(a, b):
  257. return a | b
  258. # Like PythonReferenceAnalysis, but makes some export-unfriendly choices
  259. # of operators to make things faster
  260. class OptimizedPythonReferenceAnalysis(PythonReferenceAnalysis):
  261. @staticmethod
  262. def sym_sum(args):
  263. return torch.sym_sum(args)
  264. def _to_dtype(x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
  265. return torch.ops.prims.convert_element_type.default(x, dtype)
  266. # Suppose we have some int/float arguments. This diagram commutes:
  267. #
  268. #   int/float  -- PythonReferenceAnalysis.op -->  int/float
  269. #       |                                             |
  270. #       |                                             |
  271. #      torch.tensor(..., dtype=torch.int64/torch.float64)
  272. #       |                                             |
  273. #       V                                             V
  274. #    Tensor    -- TensorReferenceAnalysis.op -->   Tensor
  275. #
  276. # NB: int before and after must be representable in int64 (we will
  277. # insert guards accordingly.)
  278. #
  279. # This is guaranteed to be FX traceable with OpOverloads only.
  280. class TensorReferenceAnalysis:
  281. # NB: This is actually dead, because with Proxy tracing the factory
  282. # function isn't traced correctly. Here for completeness.
  283. @staticmethod
  284. def constant(c, dtype):
  285. d: Union[int, float, bool]
  286. if dtype is torch.int64:
  287. d = int(c)
  288. elif dtype is torch.double:
  289. d = float(c)
  290. elif dtype is torch.bool:
  291. d = bool(c)
  292. else:
  293. raise AssertionError(f"unrecognized dtype {dtype}")
  294. return torch.ops.aten.scalar_tensor.default(d, dtype=dtype)
  295. @staticmethod
  296. def or_(a, b):
  297. return torch.ops.aten.logical_or.default(a, b)
  298. @staticmethod
  299. def and_(a, b):
  300. return torch.ops.aten.logical_and.default(a, b)
  301. @staticmethod
  302. def bitwise_and(a, b):
  303. return torch.ops.aten.bitwise_and(a, b)
  304. @staticmethod
  305. def bitwise_or(a, b):
  306. return torch.ops.aten.bitwise_or(a, b)
  307. @staticmethod
  308. def eq(a, b):
  309. return torch.ops.aten.eq.Tensor(a, b)
  310. @classmethod
  311. def ne(cls, a, b):
  312. return torch.ops.aten.ne.Tensor(a, b)
  313. @staticmethod
  314. def lt(a, b):
  315. return torch.ops.aten.lt.Tensor(a, b)
  316. @staticmethod
  317. def gt(a, b):
  318. return torch.ops.aten.gt.Tensor(a, b)
  319. @staticmethod
  320. def le(a, b):
  321. return torch.ops.aten.le.Tensor(a, b)
  322. @staticmethod
  323. def ge(a, b):
  324. return torch.ops.aten.ge.Tensor(a, b)
  325. @staticmethod
  326. def not_(a):
  327. return torch.ops.aten.logical_not.default(a)
  328. @staticmethod
  329. def reciprocal(x):
  330. return torch.ops.aten.reciprocal.default(x)
  331. @staticmethod
  332. def square(x):
  333. # TODO: maybe composite implicit autograd doesn't work here?
  334. return torch.ops.aten.square.default(x)
  335. @staticmethod
  336. def trunc_to_int(x, dtype):
  337. return _to_dtype(torch.ops.aten.trunc.default(x), dtype)
  338. @staticmethod
  339. def ceil_to_int(x, dtype):
  340. return _to_dtype(torch.ops.aten.ceil.default(x), dtype)
  341. @staticmethod
  342. def floor_to_int(x, dtype):
  343. return _to_dtype(torch.ops.aten.floor.default(x), dtype)
  344. @staticmethod
  345. def floor(x):
  346. return torch.ops.aten.floor.default(x)
  347. @staticmethod
  348. def ceil(x):
  349. return torch.ops.aten.ceil.default(x)
  350. @staticmethod
  351. def to_dtype(x, dtype):
  352. return _to_dtype(x, dtype)
  353. @staticmethod
  354. def mod(x, y):
  355. # TODO: https://github.com/pytorch/pytorch/pull/133654
  356. raise NotImplementedError(
  357. "no C-style modulus operation available from frontend atm"
  358. )
  359. @staticmethod
  360. def abs(x):
  361. return torch.ops.aten.abs.default(x)
  362. @staticmethod
  363. def neg(x):
  364. return torch.ops.aten.neg.default(x)
  365. @staticmethod
  366. def truediv(a, b):
  367. return torch.ops.aten.true_divide.Tensor(a, b)
  368. @staticmethod
  369. def int_truediv(a, b):
  370. raise NotImplementedError(
  371. "Python int truediv difficult to implement in PyTorch atm"
  372. )
  373. # TODO: This is wrong, CPython has a custom implementation of true
  374. # division that results in higher precision when the floats are
  375. # sufficiently large. Short term fix: add a guard here
  376. return torch.ops.aten.true_divide.default(
  377. _to_dtype(a, torch.float64), _to_dtype(b, torch.float64)
  378. )
  379. @staticmethod
  380. def floordiv(a, b):
  381. return torch.ops.aten.div.Tensor_mode(a, b, rounding_mode="floor")
  382. @staticmethod
  383. def truncdiv(a, b):
  384. raise NotImplementedError(
  385. "no C-style truncdiv operation available from frontend atm"
  386. )
  387. @staticmethod
  388. def add(a, b):
  389. return torch.ops.aten.add.Tensor(a, b)
  390. @staticmethod
  391. def mul(a, b):
  392. return torch.ops.aten.mul.Tensor(a, b)
  393. @staticmethod
  394. def sub(a, b):
  395. return torch.ops.aten.sub.Tensor(a, b)
  396. @staticmethod
  397. def exp(x):
  398. return torch.ops.aten.exp.default(x)
  399. @staticmethod
  400. def log(x):
  401. return torch.ops.aten.log.default(x)
  402. @staticmethod
  403. def log2(x):
  404. return torch.ops.aten.log2.default(x)
  405. @staticmethod
  406. def sqrt(x):
  407. return torch.ops.aten.sqrt.default(x)
  408. @staticmethod
  409. def sin(x):
  410. return torch.ops.aten.sin.default(x)
  411. @staticmethod
  412. def cos(x):
  413. return torch.ops.aten.cos.default(x)
  414. @staticmethod
  415. def tanh(x):
  416. return torch.ops.aten.tanh.default(x)
  417. @staticmethod
  418. def sinh(x):
  419. return torch.ops.aten.sinh.default(x)
  420. @staticmethod
  421. def cosh(x):
  422. return torch.ops.aten.cosh.default(x)
  423. @staticmethod
  424. def tan(x):
  425. return torch.ops.aten.tan.default(x)
  426. @staticmethod
  427. def acos(x):
  428. return torch.ops.aten.acos.default(x)
  429. @staticmethod
  430. def atan(x):
  431. return torch.ops.aten.atan.default(x)
  432. @staticmethod
  433. def asin(x):
  434. return torch.ops.aten.asin.default(x)
  435. @staticmethod
  436. def pow(a, b):
  437. return torch.ops.aten.pow.Tensor_Tensor(a, b)
  438. @staticmethod
  439. def pow_by_natural(a, b):
  440. # NB: pow handles int x int fine
  441. return torch.ops.aten.pow.Tensor_Tensor(a, b)
  442. @staticmethod
  443. def minimum(a, b):
  444. return torch.ops.aten.minimum.default(a, b)
  445. @staticmethod
  446. def maximum(a, b):
  447. return torch.ops.aten.maximum.default(a, b)
  448. @staticmethod
  449. def round_to_int(a, dtype):
  450. return torch.ops.aten.round.default(a)
  451. @staticmethod
  452. def round_decimal(a, b):
  453. raise NotImplementedError(
  454. "round decimal doesn't support Tensor second argument atm"
  455. )
  456. # return torch.ops.aten.round.decimals(a, b)