quantizer.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. # mypy: allow-untyped-defs
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Annotated, Callable, Optional, Union

import torch
from torch import Tensor
from torch.ao.quantization import ObserverOrFakeQuantize
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.fx import Node

  10. __all__ = [
  11. "Quantizer",
  12. "QuantizationSpecBase",
  13. "QuantizationSpec",
  14. "FixedQParamsQuantizationSpec",
  15. "EdgeOrNode",
  16. "SharedQuantizationSpec",
  17. "DerivedQuantizationSpec",
  18. "QuantizationAnnotation",
  19. ]
  20. class QuantizationSpecBase(ABC): # noqa: B024
  21. """Base class for different types of quantization specs that allows users to
  22. specify how to quantize a Tensor (input/output of a Node) in the model
  23. """
  24. @dataclass(eq=True, frozen=True)
  25. class QuantizationSpec(QuantizationSpecBase):
  26. """Quantization spec for common operators that allows user to specify how to
  27. quantize a Tensor, this includes dtype, quant_min, quant_max etc.
  28. """
  29. dtype: torch.dtype
  30. # observer or fake_quantize constructor such as
  31. # MinMaxObserver, PerChannelHistogramObserver etc.
  32. # or we can attach some custom args to them
  33. # e.g. MinMaxObserver.with_args(eps=eps)
  34. observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor
  35. quant_min: Optional[int] = None
  36. quant_max: Optional[int] = None
  37. qscheme: Optional[torch.qscheme] = None
  38. ch_axis: Optional[int] = None
  39. is_dynamic: bool = False
  40. def __post_init__(self):
  41. # TODO: add init for quant_min/quant_max
  42. # quant_min must be less than quant_max
  43. if (
  44. self.quant_min is not None
  45. and self.quant_max is not None
  46. and self.quant_min > self.quant_max
  47. ):
  48. raise ValueError(
  49. f"quant_min {self.quant_min} must be <= quant_max {self.quant_max}."
  50. )
  51. # ch_axis must be less than the number of channels
  52. # but no way to check here. Just check that it is not < 0.
  53. if self.ch_axis is not None and self.ch_axis < 0:
  54. raise ValueError("Ch_axis is < 0.")
  55. @dataclass(eq=True, frozen=True)
  56. class FixedQParamsQuantizationSpec(QuantizationSpecBase):
  57. dtype: torch.dtype
  58. scale: float
  59. zero_point: int
  60. quant_min: Optional[int] = None
  61. quant_max: Optional[int] = None
  62. qscheme: Optional[torch.qscheme] = None
  63. is_dynamic: bool = False
  64. """
  65. The way we refer to other points of quantization in the graph will be either
  66. an input edge or an output value
  67. input edge is the connection between input node and the node consuming the input, so it's a Tuple[Node, Node]
  68. output value is an fx Node
  69. """
  70. EdgeOrNode = Union[tuple[Node, Node], Node]
  71. EdgeOrNode.__module__ = "torch.ao.quantization.quantizer.quantizer"
  72. @dataclass(eq=True, frozen=True)
  73. class SharedQuantizationSpec(QuantizationSpecBase):
  74. """
  75. Quantization spec for the Tensors whose quantization parameters are shared with other Tensors
  76. """
  77. # the edge or node to share observer or fake quant instances with
  78. edge_or_node: EdgeOrNode
  79. @dataclass(eq=True, frozen=True)
  80. class DerivedQuantizationSpec(QuantizationSpecBase):
  81. """Quantization spec for the Tensors whose quantization parameters are derived from other Tensors"""
  82. derived_from: list[EdgeOrNode]
  83. derive_qparams_fn: Callable[[list[ObserverOrFakeQuantize]], tuple[Tensor, Tensor]]
  84. dtype: torch.dtype
  85. quant_min: Optional[int] = None
  86. quant_max: Optional[int] = None
  87. qscheme: Optional[torch.qscheme] = None
  88. ch_axis: Optional[int] = None
  89. is_dynamic: bool = False
  90. @dataclass
  91. class QuantizationAnnotation:
  92. """How are input argument or output should be quantized,
  93. expressed as QuantizationSpec, this corresponds to how a Tensor in the
  94. operator Graph is observed (PTQ) or fake quantized (QAT)
  95. """
  96. # a map from torch.fx.Node to a type of QuantizationSpecBase
  97. input_qspec_map: dict[Node, Optional[QuantizationSpecBase]] = field(
  98. default_factory=dict
  99. )
  100. # How the output of this node is quantized, expressed as QuantizationSpec
  101. # TODO: change the value to QuantizationSpec in a separate PR
  102. output_qspec: Optional[QuantizationSpecBase] = None
  103. # For a Node: node1 and edge: (node1, node2), since they are observing the same
  104. # Tensor, we may want to implicitly share observers, this flag allows people to
  105. # turn off this behavior for the output of the node
  106. allow_implicit_sharing: bool = True
  107. # whether the node is annotated or not
  108. _annotated: bool = False
  109. class Quantizer(ABC):
  110. def transform_for_annotation(
  111. self, model: torch.fx.GraphModule
  112. ) -> torch.fx.GraphModule:
  113. """Allows for user defined transforms to run before annotating the graph.
  114. This allows quantizer to allow quantizing part of the model that are otherwise not quantizable.
  115. For example quantizer can
  116. a) decompose a compound operator like scaled dot product attention,
  117. into bmm and softmax if quantizer knows how to quantize bmm/softmax but not sdpa
  118. or b) transform scalars to tensor to allow quantizing scalares.
  119. Note: this is an optional method
  120. """
  121. return model
  122. # annotate nodes in the graph with observer or fake quant constructors
  123. # to convey the desired way of quantization
  124. @abstractmethod
  125. def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
  126. pass
  127. # validate the annotated graph is supported by the backend
  128. @abstractmethod
  129. def validate(self, model: torch.fx.GraphModule) -> None:
  130. pass
  131. def prepare_obs_or_fq_callback(
  132. self,
  133. model: torch.fx.GraphModule,
  134. edge_or_node_to_obs_or_fq: dict[EdgeOrNode, ObserverOrFakeQuantize],
  135. ) -> None:
  136. """A callback that will be called after the observers or fake quants are created
  137. for each sharing group, but before they are inserted into the graph. The
  138. callback can be used to make final quantization adjustments, such as enforcing
  139. specific scale and zero point on model input or output.
  140. Args:
  141. * `model`: the graph module being prepared.
  142. * `edge_or_node_to_obs_or_fq`: a dictionary mapping each annotated edge and
  143. node to the corresponding observer or fake quant object. Note that multiple
  144. edges and/or nodes can map to the same observer / fake quant instance if
  145. they were annotated with SharedQuantizationSpec. This dictionary can be
  146. modified by the callback.
  147. """
  148. return