fake_quantize.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650
  1. # mypy: allow-untyped-decorators
  2. # mypy: allow-untyped-defs
  3. """Implements modules used to perform fake quantization."""
  4. import re
  5. from abc import ABC, abstractmethod
  6. from typing import Any
  7. import torch
  8. from torch.ao.quantization.observer import (
  9. _with_args,
  10. default_fixed_qparams_range_0to1_observer,
  11. default_fixed_qparams_range_neg1to1_observer,
  12. FixedQParamsObserver,
  13. HistogramObserver,
  14. MovingAverageMinMaxObserver,
  15. MovingAveragePerChannelMinMaxObserver,
  16. )
  17. from torch.nn import Module
  18. __all__ = [
  19. "FakeQuantizeBase",
  20. "FakeQuantize",
  21. "FixedQParamsFakeQuantize",
  22. "FusedMovingAvgObsFakeQuantize",
  23. "disable_fake_quant",
  24. "disable_observer",
  25. "enable_fake_quant",
  26. "enable_observer",
  27. "default_fake_quant",
  28. "default_weight_fake_quant",
  29. "default_dynamic_fake_quant",
  30. "default_fixed_qparams_range_neg1to1_fake_quant",
  31. "default_fixed_qparams_range_0to1_fake_quant",
  32. "default_symmetric_fixed_qparams_fake_quant",
  33. "default_affine_fixed_qparams_fake_quant",
  34. "default_per_channel_weight_fake_quant",
  35. "default_embedding_fake_quant",
  36. "default_embedding_fake_quant_4bit",
  37. "default_histogram_fake_quant",
  38. "default_fused_act_fake_quant",
  39. "default_fused_wt_fake_quant",
  40. "default_fused_per_channel_wt_fake_quant",
  41. "fused_wt_fake_quant_range_neg_127_to_127",
  42. "fused_per_channel_wt_fake_quant_range_neg_127_to_127",
  43. ]
  44. def _is_per_channel(qscheme: "torch.qscheme") -> bool:
  45. return qscheme in [
  46. torch.per_channel_symmetric,
  47. torch.per_channel_affine,
  48. torch.per_channel_affine_float_qparams,
  49. ]
  50. def _is_per_tensor(qscheme: "torch.qscheme") -> bool:
  51. return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]
  52. def _is_symmetric_quant(qscheme: "torch.qscheme") -> bool:
  53. return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric]
  54. def _is_float_qparams(qscheme: "torch.qscheme") -> bool:
  55. return qscheme in [
  56. torch.per_channel_affine_float_qparams,
  57. ]
  58. class FakeQuantizeBase(ABC, Module):
  59. r"""Base fake quantize module.
  60. Base fake quantize module
  61. Any fake quantize implementation should derive from this class.
  62. Concrete fake quantize module should follow the same API. In forward, they will update
  63. the statistics of the observed Tensor and fake quantize the input. They should also provide a
  64. `calculate_qparams` function that computes the quantization parameters given
  65. the collected statistics.
  66. """
  67. fake_quant_enabled: torch.Tensor
  68. observer_enabled: torch.Tensor
  69. def __init__(self) -> None:
  70. """Set fake_quant_enabled and observer_enabled."""
  71. super().__init__()
  72. # fake_quant_enabled and observer_enabled are buffers to support their
  73. # replication in DDP. Data type is uint8 because NCCL does not support
  74. # bool tensors.
  75. self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.uint8))
  76. self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.uint8))
  77. @abstractmethod
  78. def forward(self, x):
  79. pass
  80. @abstractmethod
  81. def calculate_qparams(self, **kwargs):
  82. pass
  83. @torch.jit.export
  84. def enable_fake_quant(self, enabled: bool = True) -> None:
  85. self.fake_quant_enabled[0] = 1 if enabled else 0
  86. @torch.jit.export
  87. def disable_fake_quant(self):
  88. self.enable_fake_quant(False)
  89. @torch.jit.export
  90. def enable_observer(self, enabled: bool = True) -> None:
  91. self.observer_enabled[0] = 1 if enabled else 0
  92. @torch.jit.export
  93. def disable_observer(self):
  94. self.enable_observer(False)
  95. @classmethod
  96. def with_args(cls, **kwargs):
  97. fake_quant_constructor = _with_args(cls, **kwargs)
  98. # need to assign the correct module to fake_quantize
  99. # constructors to satisfy public v private requirements
  100. fake_quant_constructor.__module__ = "torch.ao.quantization.fake_quantize"
  101. return fake_quant_constructor
  102. class FakeQuantize(FakeQuantizeBase):
  103. r"""Simulate the quantize and dequantize operations in training time.
  104. The output of this module is given by::
  105. x_out = (
  106. clamp(round(x / scale + zero_point), quant_min, quant_max) - zero_point
  107. ) * scale
  108. * :attr:`is_dynamic` indicates whether the fake quantie is a placeholder for dynamic quantization
  109. operators (choose_qparams -> q -> dq) or static quantization operators (q -> dq)
  110. * :attr:`scale` defines the scale factor used for quantization.
  111. * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps to
  112. * :attr:`fake_quant_enabled` controls the application of fake quantization on tensors, note that
  113. statistics can still be updated.
  114. * :attr:`observer_enabled` controls statistics collection on tensors
  115. * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization,
  116. allowable values are torch.qint8 and torch.quint8.
  117. Args:
  118. observer (module): Module for observing statistics on input tensors and calculating scale
  119. and zero-point.
  120. observer_kwargs (optional): Arguments for the observer module
  121. Attributes:
  122. activation_post_process (Module): User provided module that collects statistics on the input tensor and
  123. provides a method to calculate scale and zero-point.
  124. """
  125. scale: torch.Tensor
  126. zero_point: torch.Tensor
  127. def __init__(
  128. self,
  129. observer=MovingAverageMinMaxObserver,
  130. quant_min=None,
  131. quant_max=None,
  132. is_dynamic=False,
  133. **observer_kwargs,
  134. ):
  135. super().__init__()
  136. # Populate quant_min/quant_max to observer_kwargs if valid
  137. if quant_min is not None and quant_max is not None:
  138. assert quant_min <= quant_max, (
  139. "quant_min must be less than or equal to quant_max"
  140. )
  141. dtype = observer_kwargs.get("dtype", torch.quint8)
  142. if hasattr(observer, "p"):
  143. # In case observer is _PartialWrapper, dtype can be stored in
  144. # observer.p.keywords["dtype"]
  145. dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get(
  146. "dtype", dtype
  147. )
  148. assert torch.iinfo(dtype).min <= quant_min, "quant_min out of bound"
  149. assert quant_max <= torch.iinfo(dtype).max, "quant_max out of bound"
  150. observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max})
  151. observer_kwargs["is_dynamic"] = is_dynamic
  152. self.activation_post_process = observer(**observer_kwargs)
  153. # TODO: keeping self.quant_min/max for BC; remove after a couple releases
  154. # Users should use self.activation_post_process.quant_min
  155. self.quant_min = self.activation_post_process.quant_min
  156. self.quant_max = self.activation_post_process.quant_max
  157. self.is_dynamic = self.activation_post_process.is_dynamic
  158. if _is_float_qparams(self.activation_post_process.qscheme):
  159. zero_point_dtype = torch.float
  160. else:
  161. zero_point_dtype = torch.int
  162. self.register_buffer("scale", torch.tensor([1.0], dtype=torch.float))
  163. self.register_buffer("zero_point", torch.tensor([0], dtype=zero_point_dtype))
  164. self.dtype = self.activation_post_process.dtype
  165. self.qscheme = self.activation_post_process.qscheme
  166. self.ch_axis = (
  167. self.activation_post_process.ch_axis
  168. if hasattr(self.activation_post_process, "ch_axis")
  169. else -1
  170. )
  171. assert _is_per_channel(self.qscheme) or _is_per_tensor(self.qscheme), (
  172. "Only per channel and per tensor quantization are supported in fake quantize"
  173. + " got qscheme: "
  174. + str(self.qscheme)
  175. )
  176. self.is_per_channel = _is_per_channel(self.qscheme)
  177. @torch.jit.export
  178. def calculate_qparams(self): # type: ignore[override]
  179. return self.activation_post_process.calculate_qparams()
  180. def forward(self, X):
  181. if self.observer_enabled[0] == 1:
  182. self.activation_post_process(X.detach())
  183. _scale, _zero_point = self.calculate_qparams()
  184. _scale, _zero_point = (
  185. _scale.to(self.scale.device),
  186. _zero_point.to(self.zero_point.device),
  187. )
  188. if self.scale.shape != _scale.shape:
  189. self.scale.resize_(_scale.shape)
  190. self.zero_point.resize_(_zero_point.shape)
  191. self.scale.copy_(_scale)
  192. self.zero_point.copy_(_zero_point)
  193. if self.fake_quant_enabled[0] == 1:
  194. if self.is_per_channel:
  195. X = torch.fake_quantize_per_channel_affine(
  196. X,
  197. self.scale,
  198. self.zero_point,
  199. self.ch_axis,
  200. self.activation_post_process.quant_min,
  201. self.activation_post_process.quant_max,
  202. )
  203. else:
  204. X = torch.fake_quantize_per_tensor_affine(
  205. X,
  206. self.scale,
  207. self.zero_point,
  208. self.activation_post_process.quant_min,
  209. self.activation_post_process.quant_max,
  210. )
  211. return X
  212. @torch.jit.export
  213. def extra_repr(self):
  214. return (
  215. f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, "
  216. f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, "
  217. f"dtype={self.dtype}, qscheme={self.qscheme}, ch_axis={self.ch_axis}, "
  218. f"scale={self.scale}, zero_point={self.zero_point}"
  219. )
  220. def _save_to_state_dict(self, destination, prefix, keep_vars):
  221. # We cannot currently register scalar values as buffers, so need to manually
  222. # specify serialization here.
  223. super()._save_to_state_dict(destination, prefix, keep_vars)
  224. destination[prefix + "scale"] = self.scale
  225. destination[prefix + "zero_point"] = self.zero_point
  226. def _load_from_state_dict(
  227. self,
  228. state_dict,
  229. prefix,
  230. local_metadata,
  231. strict,
  232. missing_keys,
  233. unexpected_keys,
  234. error_msgs,
  235. ):
  236. # Removing this function throws an error that the size of the loaded tensor does not match the original size
  237. # i.e., These buffers start out with numel 0 and become numel 1 once they have their first forward pass.
  238. local_state = ["scale", "zero_point"]
  239. for name in local_state:
  240. key = prefix + name
  241. if key in state_dict:
  242. val = state_dict[key]
  243. # Custom handling to allow loading scale and zero_point
  244. # of size N into uninitialized buffers of size 0. The
  245. # buffers are resized here, and the values are copied in
  246. # the default state_dict loading code of the parent.
  247. if name == "scale":
  248. self.scale.resize_(val.shape)
  249. else:
  250. assert name == "zero_point"
  251. self.zero_point.resize_(val.shape)
  252. # For torchscript module we need to update the attributes here since we do not
  253. # call the `_load_from_state_dict` function defined module.py
  254. if torch.jit.is_scripting():
  255. if name == "scale":
  256. self.scale.copy_(val)
  257. else:
  258. assert name == "zero_point"
  259. self.zero_point.copy_(val)
  260. elif strict:
  261. missing_keys.append(key)
  262. super()._load_from_state_dict(
  263. state_dict,
  264. prefix,
  265. local_metadata,
  266. strict,
  267. missing_keys,
  268. unexpected_keys,
  269. error_msgs,
  270. )
  271. class FixedQParamsFakeQuantize(FakeQuantize):
  272. """Simulate quantize and dequantize in training time.
  273. Simulate quantize and dequantize with fixed quantization
  274. parameters in training time. Only per tensor quantization
  275. is supported.
  276. """
  277. # TODO: rename observer to observer_ctr
  278. def __init__(self, observer):
  279. super().__init__(observer=observer)
  280. assert type(self.activation_post_process) == FixedQParamsObserver, (
  281. f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}"
  282. )
  283. self._observer_ctr = observer
  284. self.scale = self.activation_post_process.scale
  285. self.zero_point = self.activation_post_process.zero_point
  286. assert _is_per_tensor(self.qscheme), (
  287. "Only per tensor quantization is supported"
  288. + " FixedQParamsFakeQuantize module, got qscheme:"
  289. + str(self.qscheme)
  290. )
  291. @torch.jit.export
  292. def calculate_qparams(self): # type: ignore[override]
  293. return self.scale, self.zero_point
  294. @torch.jit.export
  295. def extra_repr(self):
  296. """Define a string representation of the object's attributes."""
  297. return (
  298. f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, "
  299. f"scale={self.scale}, zero_point={self.zero_point}, "
  300. f"dtype={self.dtype}, quant_min={self.activation_post_process.quant_min}, "
  301. f"quant_max={self.activation_post_process.quant_max}, qscheme={self.qscheme}"
  302. )
  303. class FusedMovingAvgObsFakeQuantize(FakeQuantize):
  304. r"""Define a fused module to observe the tensor.
  305. Fused module that is used to observe the input tensor (compute min/max), compute
  306. scale/zero_point and fake_quantize the tensor.
  307. This module uses calculation similar MovingAverageMinMaxObserver for the inputs,
  308. to compute the min/max values in order to compute the scale/zero_point.
  309. The qscheme input in the observer is used to differentiate between symmetric/affine
  310. quantization scheme.
  311. The output of this module is given by
  312. x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale
  313. Similar to :class:`~torch.ao.quantization.FakeQuantize`, and accepts the same attributes as the
  314. base class.
  315. """
  316. def __init__(
  317. self,
  318. observer: Any = MovingAverageMinMaxObserver,
  319. quant_min: int = 0,
  320. quant_max: int = 255,
  321. **observer_kwargs: Any,
  322. ) -> None:
  323. super().__init__(observer, quant_min, quant_max, **observer_kwargs)
  324. assert isinstance(
  325. self.activation_post_process,
  326. (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver),
  327. ), (
  328. "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver"
  329. )
  330. self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long))
  331. self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long))
  332. self.is_symmetric_quant = _is_symmetric_quant(
  333. self.activation_post_process.qscheme
  334. )
  335. @torch.jit.export
  336. def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: # type: ignore[override]
  337. return self.activation_post_process.calculate_qparams()
  338. @torch.jit.export
  339. def extra_repr(self) -> str:
  340. return (
  341. f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, "
  342. f"scale={self.scale}, zero_point={self.zero_point}, dtype={self.dtype}, "
  343. f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, "
  344. f"qscheme={self.qscheme}, reduce_range={self.activation_post_process.reduce_range}"
  345. )
  346. def forward(self, X: torch.Tensor) -> torch.Tensor:
  347. return torch.fused_moving_avg_obs_fake_quant(
  348. X,
  349. self.observer_enabled,
  350. self.fake_quant_enabled,
  351. self.activation_post_process.min_val,
  352. self.activation_post_process.max_val,
  353. self.scale,
  354. self.zero_point,
  355. self.activation_post_process.averaging_constant,
  356. self.activation_post_process.quant_min,
  357. self.activation_post_process.quant_max,
  358. self.ch_axis,
  359. self.is_per_channel,
  360. self.is_symmetric_quant,
  361. )
  362. default_fake_quant = FakeQuantize.with_args(
  363. observer=MovingAverageMinMaxObserver,
  364. quant_min=0,
  365. quant_max=255,
  366. dtype=torch.quint8,
  367. qscheme=torch.per_tensor_affine,
  368. reduce_range=True,
  369. )
  370. """
  371. Default fake_quant for activations.
  372. """
  373. default_weight_fake_quant = FakeQuantize.with_args(
  374. observer=MovingAverageMinMaxObserver,
  375. quant_min=-128,
  376. quant_max=127,
  377. dtype=torch.qint8,
  378. qscheme=torch.per_tensor_symmetric,
  379. reduce_range=False,
  380. )
  381. """
  382. Default fake_quant for weights.
  383. Observer is memoryless since averaging_constant is 1.
  384. """
  385. default_dynamic_fake_quant = FakeQuantize.with_args(
  386. observer=MovingAverageMinMaxObserver,
  387. quant_min=0,
  388. quant_max=255,
  389. is_dynamic=True,
  390. dtype=torch.quint8,
  391. averaging_constant=1,
  392. )
  393. """
  394. Default dynamic fake_quant for activations.
  395. """
  396. default_fixed_qparams_range_neg1to1_fake_quant = FixedQParamsFakeQuantize.with_args(
  397. observer=default_fixed_qparams_range_neg1to1_observer
  398. )
  399. default_fixed_qparams_range_0to1_fake_quant = FixedQParamsFakeQuantize.with_args(
  400. observer=default_fixed_qparams_range_0to1_observer
  401. )
  402. # TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases
  403. default_symmetric_fixed_qparams_fake_quant = (
  404. default_fixed_qparams_range_neg1to1_fake_quant
  405. )
  406. default_affine_fixed_qparams_fake_quant = default_fixed_qparams_range_0to1_fake_quant
  407. default_per_channel_weight_fake_quant = FakeQuantize.with_args(
  408. observer=MovingAveragePerChannelMinMaxObserver,
  409. quant_min=-128,
  410. quant_max=127,
  411. dtype=torch.qint8,
  412. qscheme=torch.per_channel_symmetric,
  413. reduce_range=False,
  414. ch_axis=0,
  415. )
  416. """
  417. Default fake_quant for per-channel weights.
  418. Observer is memoryless since averaging_constant is 1.
  419. """
  420. default_embedding_fake_quant = FakeQuantize.with_args(
  421. observer=MovingAveragePerChannelMinMaxObserver,
  422. qscheme=torch.per_channel_affine_float_qparams,
  423. dtype=torch.quint8,
  424. quant_min=0,
  425. quant_max=255,
  426. ch_axis=0,
  427. averaging_constant=1,
  428. )
  429. """
  430. Default fake_quant for embeddings.
  431. Observer is memoryless since averaging_constant is 1.
  432. """
  433. default_embedding_fake_quant_4bit = FakeQuantize.with_args(
  434. observer=MovingAveragePerChannelMinMaxObserver,
  435. qscheme=torch.per_channel_affine_float_qparams,
  436. ch_axis=0,
  437. dtype=torch.quint4x2,
  438. averaging_constant=1,
  439. )
  440. default_histogram_fake_quant = FakeQuantize.with_args(
  441. observer=HistogramObserver,
  442. quant_min=0,
  443. quant_max=255,
  444. dtype=torch.quint8,
  445. qscheme=torch.per_tensor_affine,
  446. reduce_range=True,
  447. )
  448. """
  449. Fake_quant for activations using a histogram..
  450. """
  451. default_fused_act_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(
  452. observer=MovingAverageMinMaxObserver,
  453. quant_min=0,
  454. quant_max=255,
  455. dtype=torch.quint8,
  456. )
  457. """
  458. Fused version of `default_fake_quant`, with improved performance.
  459. """
  460. default_fused_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(
  461. observer=MovingAverageMinMaxObserver,
  462. quant_min=-128,
  463. quant_max=127,
  464. dtype=torch.qint8,
  465. qscheme=torch.per_tensor_symmetric,
  466. )
  467. """
  468. Fused version of `default_weight_fake_quant`, with improved performance.
  469. """
  470. default_fused_per_channel_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(
  471. observer=MovingAveragePerChannelMinMaxObserver,
  472. quant_min=-128,
  473. quant_max=127,
  474. dtype=torch.qint8,
  475. qscheme=torch.per_channel_symmetric,
  476. )
  477. """
  478. Fused version of `default_per_channel_weight_fake_quant`, with improved performance.
  479. """
  480. fused_wt_fake_quant_range_neg_127_to_127 = FusedMovingAvgObsFakeQuantize.with_args(
  481. observer=MovingAverageMinMaxObserver,
  482. quant_min=-127,
  483. quant_max=127,
  484. dtype=torch.qint8,
  485. qscheme=torch.per_tensor_symmetric,
  486. eps=2**-12,
  487. )
  488. """
  489. Fused version of `default_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128.
  490. """
  491. fused_per_channel_wt_fake_quant_range_neg_127_to_127 = (
  492. FusedMovingAvgObsFakeQuantize.with_args(
  493. observer=MovingAveragePerChannelMinMaxObserver,
  494. quant_min=-127,
  495. quant_max=127,
  496. dtype=torch.qint8,
  497. qscheme=torch.per_channel_symmetric,
  498. eps=2**-12,
  499. )
  500. )
  501. """
  502. Fused version of `default_per_channel_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128.
  503. """
  504. def _is_fake_quant_script_module(mod):
  505. """Return true if given mod is an instance of FakeQuantize script module."""
  506. if isinstance(mod, torch.jit.RecursiveScriptModule):
  507. # qualified name looks like '__torch__.torch.ao.quantization.fake_quantize.___torch_mangle_2.FakeQuantize'
  508. suffix = mod._c.qualified_name.split(".", 1)[1]
  509. name = re.sub(r"\.___torch_mangle_\d+", "", suffix)
  510. return (
  511. name == "torch.ao.quantization.fake_quantize.FakeQuantize"
  512. or name
  513. == "torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize"
  514. )
  515. return False
  516. def disable_fake_quant(mod):
  517. """Disable fake quantization for the module.
  518. Disable fake quantization for this module, if applicable. Example usage::
  519. # model is any PyTorch model
  520. model.apply(torch.ao.quantization.disable_fake_quant)
  521. """
  522. if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
  523. mod.disable_fake_quant()
  524. def enable_fake_quant(mod):
  525. """Enable fake quantization for the module.
  526. Enable fake quantization for this module, if applicable. Example usage::
  527. # model is any PyTorch model
  528. model.apply(torch.ao.quantization.enable_fake_quant)
  529. """
  530. if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
  531. mod.enable_fake_quant()
  532. def disable_observer(mod):
  533. """Disable observation for this module.
  534. Disable observation for this module, if applicable. Example usage::
  535. # model is any PyTorch model
  536. model.apply(torch.ao.quantization.disable_observer)
  537. """
  538. if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
  539. mod.disable_observer()
  540. def enable_observer(mod):
  541. """Enable observation for this module.
  542. Enable observation for this module, if applicable. Example usage::
  543. # model is any PyTorch model
  544. model.apply(torch.ao.quantization.enable_observer)
  545. """
  546. if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
  547. mod.enable_observer()