whisper_encoder.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
  6. import logging
  7. import os
  8. import tempfile
  9. from pathlib import Path
  10. import numpy as np
  11. import onnx
  12. import torch
  13. from float16 import convert_float_to_float16
  14. from onnx import ModelProto
  15. from onnx_model import OnnxModel
  16. from transformers import WhisperConfig
  17. from whisper_inputs import get_model_dynamic_axes, get_sample_encoder_inputs
  18. from onnxruntime import InferenceSession
  19. logger = logging.getLogger(__name__)
  20. class WhisperEncoder(torch.nn.Module):
  21. """Whisper encoder component"""
  22. def __init__(self, config: WhisperConfig, model: torch.nn.Module, model_impl: str):
  23. super().__init__()
  24. self.config = config
  25. self.device = model.device
  26. self.model_impl = model_impl
  27. self.encoder = model.encoder if model_impl == "openai" else model.model.encoder
  28. def forward(self, audio_features: torch.Tensor):
  29. outputs = self.encoder(audio_features)
  30. return outputs if self.model_impl == "openai" else outputs.last_hidden_state
  31. def input_names(self):
  32. input_names = ["audio_features"]
  33. return input_names
  34. def output_names(self):
  35. output_names = ["encoder_hidden_states"]
  36. return output_names
  37. def dynamic_axes(self, input_names, output_names):
  38. dynamic_axes = get_model_dynamic_axes(self.config, input_names, output_names)
  39. return dynamic_axes
  40. def fix_layernorm_weights(self, model: ModelProto, use_fp16_inputs: bool):
  41. if self.model_impl == "openai" and use_fp16_inputs:
  42. # Cast ONNX model to float16 to ensure LayerNorm weights are converted from
  43. # float32 to float16 since exported model already has float16 weights everywhere
  44. # except for LayerNorm ops. This happens because OpenAI always upcasts to float32
  45. # when computing LayerNorm.
  46. #
  47. # Reference:
  48. # https://github.com/openai/whisper/blob/90db0de1896c23cbfaf0c58bc2d30665f709f170/whisper/model.py#L41
  49. model = convert_float_to_float16(model)
  50. return model
  51. def export_onnx(
  52. self,
  53. onnx_model_path: str,
  54. provider: str,
  55. verbose: bool = True,
  56. use_external_data_format: bool = False,
  57. use_fp16_inputs: bool = False,
  58. ):
  59. """Export encoder to ONNX
  60. Args:
  61. onnx_model_path (str): path to save ONNX model
  62. provider (str): provider to use for verifying parity on ONNX model
  63. verbose (bool, optional): print verbose information. Defaults to True.
  64. use_external_data_format (bool, optional): use external data format or not. Defaults to False.
  65. use_fp16_inputs (bool, optional): use float16 inputs for the audio_features. Defaults to False.
  66. """
  67. # Shape of encoder's tensors:
  68. # Inputs:
  69. # audio_features: (batch_size, num_mels, num_frames)
  70. # Outputs:
  71. # encoder_hidden_states: (batch_size, num_frames // 2, hidden_size)
  72. inputs = get_sample_encoder_inputs(
  73. self.config,
  74. self.device,
  75. batch_size=2,
  76. use_fp16=use_fp16_inputs,
  77. )
  78. input_names = self.input_names()
  79. output_names = self.output_names()
  80. dynamic_axes = self.dynamic_axes(input_names, output_names)
  81. Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
  82. with tempfile.TemporaryDirectory() as tmp_dir_name:
  83. temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx")
  84. Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
  85. out_path = temp_onnx_model_path if use_external_data_format else onnx_model_path
  86. torch.onnx.export(
  87. self,
  88. args=(inputs["audio_features"]),
  89. f=out_path,
  90. export_params=True,
  91. input_names=input_names,
  92. output_names=output_names,
  93. dynamic_axes=dynamic_axes,
  94. opset_version=17,
  95. do_constant_folding=True,
  96. verbose=verbose,
  97. )
  98. model = onnx.load_model(out_path, load_external_data=use_external_data_format)
  99. model = self.fix_layernorm_weights(model, use_fp16_inputs)
  100. OnnxModel.save(
  101. model,
  102. onnx_model_path,
  103. save_as_external_data=use_external_data_format,
  104. all_tensors_to_one_file=True,
  105. )
  106. self.verify_onnx(onnx_model_path, provider, use_fp16_inputs)
  107. def verify_onnx(
  108. self,
  109. onnx_model_path: str,
  110. provider: str,
  111. use_fp16_inputs: bool,
  112. ):
  113. """Verify ONNX model outputs and PyTorch model outputs match
  114. Args:
  115. onnx_model_path (str): path to save ONNX model
  116. provider (str): execution provider for ONNX model
  117. use_fp16_inputs (bool, optional): use float16 inputs for the audio_features
  118. """
  119. # Shape of encoder's tensors:
  120. # Inputs:
  121. # audio_features: (batch_size, num_mels, num_frames)
  122. # Outputs:
  123. # encoder_hidden_states: (batch_size, num_frames // 2, hidden_size)
  124. inputs = get_sample_encoder_inputs(
  125. self.config,
  126. self.device,
  127. batch_size=2,
  128. use_fp16=use_fp16_inputs,
  129. )
  130. # Run PyTorch model
  131. pt_outputs = self.forward(inputs["audio_features"]).detach().cpu().numpy()
  132. # Run ONNX model
  133. sess = InferenceSession(onnx_model_path, providers=[provider])
  134. ort_outputs = sess.run(None, {"audio_features": inputs["audio_features"].detach().cpu().numpy()})[0]
  135. # Calculate output difference
  136. diff = np.abs(pt_outputs - ort_outputs)
  137. logger.warning("Comparing encoder_hidden_states...")
  138. logger.warning(f"Max diff: {np.max(diff)}")