mappings.py 18 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757
  1. import operator
  2. from typing import Callable, Optional
  3. import torch
  4. import torch.ao.nn.intrinsic as nni
  5. import torch.ao.nn.intrinsic.qat as nniqat
  6. import torch.ao.nn.intrinsic.quantized as nniq
  7. import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
  8. import torch.ao.nn.qat as nnqat
  9. import torch.ao.nn.qat.dynamic as nnqatd
  10. import torch.ao.nn.quantized as nnq
  11. import torch.ao.nn.quantized.dynamic as nnqd
  12. import torch.ao.quantization.fx._lower_to_native_backend as _lower_to_native_backend
  13. import torch.ao.quantization.quantization_mappings as quantization_mappings
  14. import torch.nn as nn
  15. import torch.nn.functional as F
  16. from torch.ao.quantization.backend_config import get_native_backend_config
  17. from .ns_types import NSNodeTargetType
# Short module-level alias for the `torch.ops.quantized` operator namespace.
toq = torch.ops.quantized
  19. def get_base_name_to_sets_of_related_ops() -> dict[str, set[NSNodeTargetType]]:
  20. # note: this set is modified below by items from backend_config
  21. sets_of_related_ops: list[set[NSNodeTargetType]] = [
  22. # conv modules
  23. {
  24. nn.Conv1d,
  25. },
  26. {
  27. nn.Conv2d,
  28. },
  29. {
  30. nn.Conv3d,
  31. },
  32. # conv functionals
  33. {
  34. F.conv1d,
  35. },
  36. {
  37. F.conv2d,
  38. },
  39. {
  40. F.conv3d,
  41. },
  42. # linear modules
  43. {
  44. nn.Linear,
  45. },
  46. # linear functionals
  47. {
  48. F.linear,
  49. },
  50. # average pool
  51. {
  52. nn.AvgPool1d,
  53. torch.avg_pool1d,
  54. },
  55. {
  56. nn.AvgPool2d,
  57. torch._C._nn.avg_pool2d,
  58. },
  59. {
  60. nn.AvgPool3d,
  61. torch._C._nn.avg_pool3d,
  62. },
  63. # adaptive average pool
  64. {
  65. nn.AdaptiveAvgPool1d,
  66. F.adaptive_avg_pool1d,
  67. },
  68. {
  69. nn.AdaptiveAvgPool2d,
  70. F.adaptive_avg_pool2d,
  71. },
  72. {
  73. nn.AdaptiveAvgPool3d,
  74. F.adaptive_avg_pool3d,
  75. },
  76. # LSTM
  77. {
  78. nn.LSTM,
  79. },
  80. # add
  81. {
  82. torch.add,
  83. operator.add, # x + y
  84. },
  85. # cat
  86. {
  87. torch.cat,
  88. },
  89. # mul
  90. {
  91. torch.mul,
  92. operator.mul,
  93. },
  94. # relu
  95. {
  96. F.relu,
  97. nn.ReLU,
  98. "relu",
  99. "relu_",
  100. torch.relu,
  101. },
  102. # maxpool
  103. {
  104. nn.MaxPool1d,
  105. F.max_pool1d,
  106. },
  107. {
  108. nn.MaxPool2d,
  109. F.max_pool2d,
  110. },
  111. {
  112. nn.MaxPool3d,
  113. F.max_pool3d,
  114. },
  115. # sigmoid
  116. {
  117. torch.sigmoid,
  118. "sigmoid",
  119. "sigmoid_",
  120. nn.Sigmoid,
  121. F.sigmoid,
  122. },
  123. # BatchNorm
  124. {
  125. nn.BatchNorm2d,
  126. },
  127. {
  128. nn.BatchNorm3d,
  129. },
  130. # ConvTranspose
  131. {
  132. nn.ConvTranspose1d,
  133. },
  134. {
  135. nn.ConvTranspose2d,
  136. },
  137. {
  138. nn.ConvTranspose3d,
  139. },
  140. # functional transposed conv
  141. {
  142. F.conv_transpose1d,
  143. },
  144. {
  145. F.conv_transpose2d,
  146. },
  147. {
  148. F.conv_transpose3d,
  149. },
  150. # ELU
  151. {
  152. nn.ELU,
  153. },
  154. # Embedding
  155. {
  156. nn.Embedding,
  157. },
  158. # EmbeddingBag
  159. {
  160. nn.EmbeddingBag,
  161. },
  162. # GroupNorm
  163. {
  164. nn.GroupNorm,
  165. },
  166. # Hardswish
  167. {
  168. nn.Hardswish,
  169. },
  170. # InstanceNorm
  171. {
  172. nn.InstanceNorm1d,
  173. },
  174. {
  175. nn.InstanceNorm2d,
  176. },
  177. {
  178. nn.InstanceNorm3d,
  179. },
  180. # LayerNorm
  181. {
  182. nn.LayerNorm,
  183. },
  184. # LeakyReLU
  185. {
  186. nn.LeakyReLU,
  187. },
  188. # ReLU6
  189. {
  190. nn.ReLU6,
  191. F.relu6,
  192. },
  193. # F.elu
  194. {
  195. F.elu,
  196. },
  197. # F.hardswish
  198. {
  199. F.hardswish,
  200. },
  201. # F.group_norm
  202. {
  203. F.group_norm,
  204. },
  205. # F.instance_norm
  206. {
  207. F.instance_norm,
  208. },
  209. # F.layer_norm
  210. {
  211. F.layer_norm,
  212. },
  213. # F.leaky_relu
  214. {
  215. F.leaky_relu,
  216. },
  217. # F.silu
  218. {
  219. nn.SiLU,
  220. F.silu,
  221. },
  222. # F.mish
  223. {
  224. nn.Mish,
  225. F.mish,
  226. },
  227. # F.tanh
  228. {
  229. nn.Tanh,
  230. F.tanh,
  231. torch.tanh,
  232. "tanh_",
  233. "tanh",
  234. },
  235. # F.hardsigmoid
  236. {
  237. "hardsigmoid_",
  238. "hardsigmoid",
  239. F.hardsigmoid,
  240. nn.Hardsigmoid,
  241. },
  242. # F.hardtanh
  243. {
  244. nn.Hardtanh,
  245. F.hardtanh,
  246. F.hardtanh_,
  247. },
  248. # floordiv
  249. {
  250. operator.floordiv,
  251. },
  252. # unsqueeze
  253. {
  254. torch.unsqueeze,
  255. },
  256. # stack
  257. {
  258. torch.stack,
  259. },
  260. # squeeze
  261. {
  262. torch.squeeze,
  263. },
  264. # sort
  265. {
  266. torch.sort,
  267. },
  268. # repeat_interleave
  269. {
  270. torch.repeat_interleave,
  271. },
  272. # min
  273. {
  274. torch.min,
  275. },
  276. # mean
  277. {
  278. torch.mean,
  279. },
  280. # max
  281. {
  282. torch.max,
  283. },
  284. # transpose
  285. {
  286. torch.transpose,
  287. },
  288. # flatten
  289. {
  290. torch.flatten,
  291. },
  292. # clamp
  293. {
  294. torch.clamp,
  295. },
  296. # chunk
  297. {
  298. torch.chunk,
  299. },
  300. # interpolate
  301. {
  302. torch.nn.functional.interpolate,
  303. },
  304. # dropout
  305. {
  306. nn.Dropout,
  307. },
  308. # F.dropout
  309. {
  310. F.dropout,
  311. },
  312. # matmul
  313. {
  314. torch.matmul,
  315. },
  316. # Softmax
  317. {
  318. nn.Softmax,
  319. },
  320. # PReLU
  321. {
  322. nn.PReLU,
  323. nnq.PReLU,
  324. },
  325. # F.prelu
  326. {
  327. F.prelu,
  328. toq.prelu,
  329. },
  330. # pixel shuffle
  331. {
  332. nn.PixelShuffle,
  333. },
  334. {
  335. F.pixel_shuffle,
  336. },
  337. # pixel unshuffle
  338. {
  339. nn.PixelUnshuffle,
  340. },
  341. {
  342. F.pixel_unshuffle,
  343. },
  344. # narrow
  345. {
  346. torch.narrow,
  347. },
  348. ]
  349. # for each floating point op, add versions of the op added by
  350. # backend_config
  351. backend_config = get_native_backend_config()
  352. new_connections: list[tuple[Callable, Callable]] = [
  353. # technical debt edge case
  354. (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear),
  355. ]
  356. for pattern, config in backend_config._pattern_complex_format_to_config.items():
  357. # pattern format: (c, (b, a))
  358. first_element = pattern
  359. # look from the end, because pattern is in reverse order
  360. while isinstance(first_element, (list, tuple)):
  361. first_element = first_element[-1]
  362. if config.fused_module is not None:
  363. # case 1: pattern fuses a pattern of ops into an op
  364. # example: nn.Conv1d, nn.ReLU fused into nni.ConvReLU1d
  365. new_connections.append((first_element, config.fused_module))
  366. if config.qat_module is not None:
  367. # case 2: pattern swaps a module into a QAT module
  368. # example: nni.ConvReLU1d swapped into nniqat.ConvReLU1d
  369. new_connections.append((first_element, config.qat_module))
  370. if config.reference_quantized_module is not None:
  371. # case 3: reference version of floating point module, such as
  372. # nn.Conv2d and nnqr.Conv2d
  373. new_connections.append((first_element, config.reference_quantized_module))
  374. #
  375. # Add reference module swaps from default lowering path
  376. #
  377. for source_to_target in (
  378. _lower_to_native_backend.STATIC_LOWER_MODULE_MAP,
  379. _lower_to_native_backend.DYNAMIC_LOWER_MODULE_MAP,
  380. _lower_to_native_backend.WEIGHT_ONLY_LOWER_MODULE_MAP,
  381. _lower_to_native_backend.SPECIAL_PATTERN_LOWER_MODULE_MAP,
  382. ):
  383. for source, target in source_to_target.items(): # type: ignore[attr-defined]
  384. new_connections.append((source, target))
  385. for source_to_double_target in (
  386. _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_MAP,
  387. _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP,
  388. _lower_to_native_backend.DYNAMIC_LOWER_FUSED_MODULE_MAP,
  389. ):
  390. for source, (target1, target2) in source_to_double_target.items(): # type: ignore[attr-defined]
  391. new_connections.append((source, target1))
  392. new_connections.append((source, target2))
  393. #
  394. # Add function swaps from default lowering path
  395. #
  396. for source, ( # type:ignore[assignment]
  397. target1,
  398. target2,
  399. ) in _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items():
  400. new_connections.append((source, target1))
  401. new_connections.append((source, target2))
  402. for source_to_target in (
  403. _lower_to_native_backend.QBIN_OP_MAPPING,
  404. _lower_to_native_backend.QBIN_RELU_OP_MAPPING,
  405. quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS,
  406. ):
  407. for source, target in source_to_target.items(): # type:ignore[assignment]
  408. new_connections.append((source, target))
  409. #
  410. # Add other swaps, ideally in the future this could be removed
  411. # after the lowering code stops using these.
  412. #
  413. for source_to_target in (
  414. quantization_mappings.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS,
  415. ):
  416. for source, target in source_to_target.items(): # type:ignore[assignment]
  417. new_connections.append((source, target))
  418. # add the new connections from backend_config
  419. for item1, item2 in new_connections:
  420. for set_of_related_ops in sets_of_related_ops:
  421. if item1 in set_of_related_ops or item2 in set_of_related_ops:
  422. set_of_related_ops.add(item1)
  423. set_of_related_ops.add(item2)
  424. break
  425. base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]] = {}
  426. for counter, set_of_related_ops in enumerate(sets_of_related_ops):
  427. base_name = str(counter)
  428. base_name_to_sets_of_related_ops[base_name] = set_of_related_ops
  429. return base_name_to_sets_of_related_ops
  430. def get_base_name_for_op(
  431. base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]],
  432. op: NSNodeTargetType,
  433. ) -> Optional[str]:
  434. for base_name, set_of_related_ops in base_name_to_sets_of_related_ops.items():
  435. if op in set_of_related_ops:
  436. return base_name
  437. return None
  438. def add_op_to_sets_of_related_ops(
  439. base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]],
  440. op: NSNodeTargetType,
  441. related_op: Optional[NSNodeTargetType],
  442. ) -> None:
  443. if related_op is not None:
  444. for set_of_related_ops in base_name_to_sets_of_related_ops.values():
  445. if related_op in set_of_related_ops:
  446. set_of_related_ops.add(op)
  447. return
  448. # if we got here, related_op was not found
  449. raise AssertionError(f"{related_op} was not found")
  450. else:
  451. counter = 0
  452. while str(counter) in base_name_to_sets_of_related_ops:
  453. counter += 1
  454. base_name_to_sets_of_related_ops[str(counter)] = {op}
  455. # TODO(future PR): clean this up
  456. def get_node_type_to_io_type_map() -> dict[str, set[NSNodeTargetType]]:
  457. FUNS_IO_TYPE_FP32: set[NSNodeTargetType] = {
  458. F.linear,
  459. F.conv1d,
  460. F.conv2d,
  461. F.conv3d,
  462. torch.cat,
  463. F.elu,
  464. F.hardswish,
  465. F.instance_norm,
  466. F.layer_norm,
  467. F.leaky_relu,
  468. F.dropout,
  469. F.silu,
  470. F.mish,
  471. operator.add,
  472. torch.add,
  473. operator.mul,
  474. torch.mul,
  475. torch.sum,
  476. F.prelu,
  477. }
  478. FUNS_IO_TYPE_FP16: set[NSNodeTargetType] = set()
  479. FUNS_IO_TYPE_INT8: set[NSNodeTargetType] = {
  480. toq.linear,
  481. toq.linear_relu,
  482. toq.conv1d,
  483. toq.conv1d_relu,
  484. toq.conv2d,
  485. toq.conv2d_relu,
  486. toq.conv3d,
  487. toq.conv3d_relu,
  488. toq.cat,
  489. toq.elu,
  490. toq.hardswish,
  491. toq.instance_norm,
  492. toq.layer_norm,
  493. toq.leaky_relu,
  494. toq.dropout,
  495. toq.prelu,
  496. # TODO(future PR): implement shadowing for binary ops and
  497. # uncomment below
  498. # toq.add,
  499. # toq.mul,
  500. }
  501. FUNS_IO_TYPE_FP32_OR_INT8: set[NSNodeTargetType] = {
  502. F.relu,
  503. F.tanh,
  504. torch.tanh,
  505. F.sigmoid,
  506. torch.sigmoid,
  507. F.hardsigmoid,
  508. operator.floordiv,
  509. torch.adaptive_avg_pool1d,
  510. F.adaptive_avg_pool2d,
  511. F.adaptive_avg_pool3d,
  512. F.dropout,
  513. F.hardtanh,
  514. F.hardtanh_,
  515. F.interpolate,
  516. F.max_pool1d,
  517. F.max_pool2d,
  518. F.max_pool3d,
  519. F.relu6,
  520. F.pixel_shuffle,
  521. F.pixel_unshuffle,
  522. torch.avg_pool1d,
  523. torch._C._nn.avg_pool2d,
  524. torch._C._nn.avg_pool3d,
  525. torch.cat,
  526. torch.chunk,
  527. torch.clamp,
  528. torch.flatten,
  529. torch.transpose,
  530. torch.max,
  531. torch.mean,
  532. torch.min,
  533. torch.narrow,
  534. torch.repeat_interleave,
  535. torch.sort,
  536. torch.squeeze,
  537. torch.stack,
  538. torch.unsqueeze,
  539. operator.add,
  540. }
  541. MODS_IO_TYPE_FP32: set[NSNodeTargetType] = {
  542. nn.Linear,
  543. nnqat.Linear,
  544. nnqatd.Linear,
  545. nnqd.Linear,
  546. torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
  547. nn.Conv1d,
  548. nn.Conv2d,
  549. nn.Conv3d,
  550. nnqat.Conv1d,
  551. nnqat.Conv2d,
  552. nnqat.Conv3d,
  553. nnqat.Embedding,
  554. nnqat.EmbeddingBag,
  555. nn.LSTM,
  556. # note: nnqd.Linear is an instance of nnq.Linear, so this
  557. # check has to happen before the int8 module check
  558. nnqd.LSTM,
  559. nn.BatchNorm2d,
  560. nn.BatchNorm3d,
  561. nn.Dropout,
  562. nn.ConvTranspose1d,
  563. nn.ConvTranspose2d,
  564. nn.ConvTranspose3d,
  565. nn.ELU,
  566. nn.GroupNorm,
  567. nn.InstanceNorm1d,
  568. nn.InstanceNorm2d,
  569. nn.InstanceNorm3d,
  570. nn.LayerNorm,
  571. nn.Hardswish,
  572. nn.LeakyReLU,
  573. nn.ReLU6,
  574. nn.SiLU,
  575. nn.Mish,
  576. nn.Softmax,
  577. nn.PReLU,
  578. nni.BNReLU2d,
  579. nni.BNReLU3d,
  580. nni.ConvReLU1d,
  581. nni.ConvReLU2d,
  582. nni.ConvReLU3d,
  583. nni.LinearReLU,
  584. nni.LinearBn1d,
  585. nni.ConvBn1d,
  586. nni.ConvBn2d,
  587. nni.ConvBn3d,
  588. nniqat.ConvBn1d,
  589. nniqat.ConvBn2d,
  590. nniqat.ConvBn3d,
  591. nniqat.ConvBnReLU1d,
  592. nniqat.ConvBnReLU2d,
  593. nniqat.ConvBnReLU3d,
  594. nniqat.ConvReLU1d,
  595. nniqat.ConvReLU2d,
  596. nniqat.ConvReLU3d,
  597. nniqat.LinearReLU,
  598. nniqat.LinearBn1d,
  599. nniqd.LinearReLU,
  600. nni.LinearLeakyReLU,
  601. nni.LinearTanh,
  602. nni.ConvAdd2d,
  603. nni.ConvAddReLU2d,
  604. }
  605. MODS_IO_TYPE_INT8: set[NSNodeTargetType] = {
  606. nnq.Linear,
  607. nnq.Conv1d,
  608. nnq.Conv2d,
  609. nnq.Conv3d,
  610. nnq.BatchNorm2d,
  611. nnq.BatchNorm3d,
  612. nnq.Dropout,
  613. nnq.ConvTranspose1d,
  614. nnq.ConvTranspose2d,
  615. nnq.ELU,
  616. nnq.InstanceNorm1d,
  617. nnq.InstanceNorm2d,
  618. nnq.InstanceNorm3d,
  619. nnq.LayerNorm,
  620. nnq.Hardswish,
  621. nnq.LeakyReLU,
  622. nnq.Embedding,
  623. nnq.EmbeddingBag,
  624. nnq.Dropout,
  625. nnq.Softmax,
  626. nnq.PReLU,
  627. nniq.BNReLU2d,
  628. nniq.BNReLU3d,
  629. nniq.ConvReLU1d,
  630. nniq.ConvReLU2d,
  631. nniq.ConvReLU3d,
  632. nniq.LinearReLU,
  633. nniq.LinearLeakyReLU,
  634. nniq.LinearTanh,
  635. nniq.ConvAdd2d,
  636. nniq.ConvAddReLU2d,
  637. }
  638. MODS_IO_TYPE_FP32_OR_INT8: set[NSNodeTargetType] = {
  639. nn.ReLU,
  640. nn.Tanh,
  641. nn.Sigmoid,
  642. nn.Hardsigmoid,
  643. nn.AdaptiveAvgPool1d,
  644. nn.AdaptiveAvgPool2d,
  645. nn.AdaptiveAvgPool3d,
  646. nn.AvgPool1d,
  647. nn.AvgPool2d,
  648. nn.AvgPool3d,
  649. nn.Dropout,
  650. nn.Hardtanh,
  651. nn.Identity,
  652. nn.MaxPool1d,
  653. nn.MaxPool2d,
  654. nn.MaxPool3d,
  655. nn.PixelShuffle,
  656. nn.PixelUnshuffle,
  657. nn.ReLU6,
  658. }
  659. METHS_IO_TYPE_FP32_OR_INT8: set[NSNodeTargetType] = {
  660. "sigmoid_",
  661. "sigmoid",
  662. "tanh_",
  663. "tanh",
  664. "hardsigmoid_",
  665. "hardsigmoid",
  666. "relu_",
  667. "relu",
  668. }
  669. return {
  670. "funs_io_type_fp32": FUNS_IO_TYPE_FP32,
  671. "funs_io_type_fp16": FUNS_IO_TYPE_FP16,
  672. "funs_io_type_int8": FUNS_IO_TYPE_INT8,
  673. "funs_io_type_fp32_or_int8": FUNS_IO_TYPE_FP32_OR_INT8,
  674. "mods_io_type_fp32": MODS_IO_TYPE_FP32,
  675. "mods_io_type_int8": MODS_IO_TYPE_INT8,
  676. "mods_io_type_fp32_or_int8": MODS_IO_TYPE_FP32_OR_INT8,
  677. "meths_io_type_fp32_or_int8": METHS_IO_TYPE_FP32_OR_INT8,
  678. }
  679. def get_unmatchable_types_map() -> dict[str, set[NSNodeTargetType]]:
  680. FUNS_UNMATCHABLE: set[NSNodeTargetType] = {
  681. torch.quantize_per_tensor,
  682. operator.getitem,
  683. }
  684. MODS_UNMATCHABLE: set[NSNodeTargetType] = {
  685. nn.Identity,
  686. }
  687. METHS_UNMATCHABLE: set[NSNodeTargetType] = {
  688. "to",
  689. "dequantize",
  690. "reshape",
  691. "view",
  692. "unsqueeze_",
  693. "unsqueeze",
  694. "transpose",
  695. "squeeze_",
  696. "squeeze",
  697. "size",
  698. "shape",
  699. "resize_",
  700. "repeat_interleave",
  701. "repeat",
  702. "permute",
  703. "numel",
  704. "mean",
  705. "detach_",
  706. "detach",
  707. "contiguous",
  708. "clamp",
  709. "chunk",
  710. }
  711. return {
  712. "funs_unmatchable": FUNS_UNMATCHABLE,
  713. "mods_unmatchable": MODS_UNMATCHABLE,
  714. "meths_unmatchable": METHS_UNMATCHABLE,
  715. }