embed_layernorm.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import logging
  2. import onnx
  3. from onnx import onnx_pb as onnx_proto # noqa: F401
  4. from ..quant_utils import attribute_to_kwarg, ms_domain
  5. from .base_operator import QuantOperatorBase
  6. """
  7. Quantizes the EmbedLayerNorm fused ONNXRuntime Op.
  8. This Quant operator keeps the input and segment IDs at int32 but will quantize all initializer and
  9. weight inputs associated with the node to uint8.
  10. """
  11. class EmbedLayerNormalizationQuant(QuantOperatorBase):
  12. def __init__(self, onnx_quantizer, onnx_node):
  13. super().__init__(onnx_quantizer, onnx_node)
  14. def should_quantize(self):
  15. return self.quantizer.should_quantize_node(self.node)
  16. def quantize(self):
  17. node = self.node
  18. assert node.op_type == "EmbedLayerNormalization"
  19. if len(node.output) > 2:
  20. logging.info(f"Quantization is not applied to {node.name} since it has 3 outputs")
  21. return super().quantize()
  22. """
  23. Pre-quantization EmbedLayerNorm inputs:
  24. [0] input_ids (int32)
  25. [1] segment_ids (int32)
  26. [2] word_embedding (float32)
  27. [3] position_embedding (float32)
  28. [4] segment_embedding (float32)
  29. [5] gamma (float32)
  30. [6] beta (float32)
  31. [7] mask (int32) (optional)
  32. """
  33. (
  34. quantized_input_names,
  35. zero_point_names,
  36. scale_names,
  37. nodes,
  38. ) = self.quantizer.quantize_activation(node, [2, 3, 4, 5, 6])
  39. if quantized_input_names is None:
  40. return super().quantize()
  41. qembed_layer_norm_name = "" if not node.name else node.name + "_quant"
  42. """
  43. Quantized Input Tensor List
  44. [0] input_ids (int32)
  45. [1] segment_ids (int32)
  46. [2] word_embedding (uint8)
  47. [3] position_embedding (uint8)
  48. [4] segment_embedding (uint8)
  49. [5] gamma (uint8)
  50. [6] beta (uint8)
  51. [7] mask (int32) (optional)
  52. [8] word_embedding_scale (float)
  53. [9] position_embedding_scale (float)
  54. [10] segment_embedding_scale (float)
  55. [11] gamma_scale (float)
  56. [12] beta_scale (float)
  57. [13] word_embedding_zero_point (uint8)
  58. [14] position_embedding_zero_point (uint8)
  59. [15] segment_embedding_zero_point (uint8)
  60. [16] gamma_zero_point (uint8)
  61. [17] beta_zero_point (uint8)
  62. """
  63. inputs = []
  64. # 'input_ids'
  65. inputs.extend([node.input[0]])
  66. # 'segment_ids'
  67. inputs.extend([node.input[1]])
  68. # 'word_embedding_quant'
  69. inputs.extend([quantized_input_names[0]])
  70. # 'position_embedding_quant'
  71. inputs.extend([quantized_input_names[1]])
  72. # 'segment_embedding_quant'
  73. inputs.extend([quantized_input_names[2]])
  74. # 'gamma_quant'
  75. inputs.extend([quantized_input_names[3]])
  76. # 'beta_quant'
  77. inputs.extend([quantized_input_names[4]])
  78. # 'mask' (optional)
  79. inputs.extend([node.input[7] if len(node.input) > 7 else ""])
  80. # Add all scales:
  81. inputs.extend([scale_names[0]])
  82. inputs.extend([scale_names[1]])
  83. inputs.extend([scale_names[2]])
  84. inputs.extend([scale_names[3]])
  85. inputs.extend([scale_names[4]])
  86. # Add all zero points:
  87. inputs.extend([zero_point_names[0]])
  88. inputs.extend([zero_point_names[1]])
  89. inputs.extend([zero_point_names[2]])
  90. inputs.extend([zero_point_names[3]])
  91. inputs.extend([zero_point_names[4]])
  92. kwargs = {}
  93. for attribute in node.attribute:
  94. kwargs.update(attribute_to_kwarg(attribute))
  95. kwargs["domain"] = ms_domain
  96. qembed_layer_norm_node = onnx.helper.make_node(
  97. "QEmbedLayerNormalization",
  98. inputs,
  99. node.output,
  100. qembed_layer_norm_name,
  101. **kwargs,
  102. )
  103. nodes.append(qembed_layer_norm_node)
  104. self.quantizer.new_nodes += nodes