onnx_model_conformer.py 1.3 KB

1234567891011121314151617181920212223242526272829303132
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import logging
  6. from fusion_attention import AttentionMask
  7. from fusion_conformer_attention import FusionConformerAttention
  8. from fusion_options import FusionOptions
  9. from onnx_model_bert import BertOnnxModel
  10. logger = logging.getLogger(__name__)
  11. class ConformerOnnxModel(BertOnnxModel):
  12. def __init__(self, model, num_heads, hidden_size):
  13. super().__init__(model, num_heads, hidden_size)
  14. self.attention_mask = AttentionMask(self)
  15. self.attention_fusion = FusionConformerAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
  16. def optimize(self, options: FusionOptions | None = None, add_dynamic_axes: bool = False):
  17. self.attention_fusion.use_multi_head_attention = False if options is None else options.use_multi_head_attention
  18. self.attention_fusion.disable_multi_head_attention_bias = (
  19. False if options is None else options.disable_multi_head_attention_bias
  20. )
  21. super().optimize(options, add_dynamic_axes)
  22. def fuse_attention(self):
  23. self.attention_fusion.apply()
  24. def preprocess(self):
  25. self.adjust_reshape_and_expand()