# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- import logging import os import tempfile from pathlib import Path import numpy as np import onnx import torch from float16 import convert_float_to_float16 from onnx import ModelProto from onnx_model import OnnxModel from transformers import WhisperConfig from whisper_inputs import get_model_dynamic_axes, get_sample_encoder_inputs from onnxruntime import InferenceSession logger = logging.getLogger(__name__) class WhisperEncoder(torch.nn.Module): """Whisper encoder component""" def __init__(self, config: WhisperConfig, model: torch.nn.Module, model_impl: str): super().__init__() self.config = config self.device = model.device self.model_impl = model_impl self.encoder = model.encoder if model_impl == "openai" else model.model.encoder def forward(self, audio_features: torch.Tensor): outputs = self.encoder(audio_features) return outputs if self.model_impl == "openai" else outputs.last_hidden_state def input_names(self): input_names = ["audio_features"] return input_names def output_names(self): output_names = ["encoder_hidden_states"] return output_names def dynamic_axes(self, input_names, output_names): dynamic_axes = get_model_dynamic_axes(self.config, input_names, output_names) return dynamic_axes def fix_layernorm_weights(self, model: ModelProto, use_fp16_inputs: bool): if self.model_impl == "openai" and use_fp16_inputs: # Cast ONNX model to float16 to ensure LayerNorm weights are converted from # float32 to float16 since exported model already has float16 weights everywhere # except for LayerNorm ops. This happens because OpenAI always upcasts to float32 # when computing LayerNorm. # # Reference: # https://github.com/openai/whisper/blob/90db0de1896c23cbfaf0c58bc2d30665f709f170/whisper/model.py#L41 model = convert_float_to_float16(model) return model def export_onnx( self, onnx_model_path: str, provider: str, verbose: bool = True, use_external_data_format: bool = False, use_fp16_inputs: bool = False, ): """Export encoder to ONNX Args: onnx_model_path (str): path to save ONNX model provider (str): provider to use for verifying parity on ONNX model verbose (bool, optional): print verbose information. Defaults to True. use_external_data_format (bool, optional): use external data format or not. Defaults to False. use_fp16_inputs (bool, optional): use float16 inputs for the audio_features. Defaults to False. """ # Shape of encoder's tensors: # Inputs: # audio_features: (batch_size, num_mels, num_frames) # Outputs: # encoder_hidden_states: (batch_size, num_frames // 2, hidden_size) inputs = get_sample_encoder_inputs( self.config, self.device, batch_size=2, use_fp16=use_fp16_inputs, ) input_names = self.input_names() output_names = self.output_names() dynamic_axes = self.dynamic_axes(input_names, output_names) Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True) with tempfile.TemporaryDirectory() as tmp_dir_name: temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx") Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True) out_path = temp_onnx_model_path if use_external_data_format else onnx_model_path torch.onnx.export( self, args=(inputs["audio_features"]), f=out_path, export_params=True, input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, opset_version=17, do_constant_folding=True, verbose=verbose, ) model = onnx.load_model(out_path, load_external_data=use_external_data_format) model = self.fix_layernorm_weights(model, use_fp16_inputs) OnnxModel.save( model, onnx_model_path, save_as_external_data=use_external_data_format, all_tensors_to_one_file=True, ) self.verify_onnx(onnx_model_path, provider, use_fp16_inputs) def verify_onnx( self, onnx_model_path: str, provider: str, use_fp16_inputs: bool, ): """Verify ONNX model outputs and PyTorch model outputs match Args: onnx_model_path (str): path to save ONNX model provider (str): execution provider for ONNX model use_fp16_inputs (bool, optional): use float16 inputs for the audio_features """ # Shape of encoder's tensors: # Inputs: # audio_features: (batch_size, num_mels, num_frames) # Outputs: # encoder_hidden_states: (batch_size, num_frames // 2, hidden_size) inputs = get_sample_encoder_inputs( self.config, self.device, batch_size=2, use_fp16=use_fp16_inputs, ) # Run PyTorch model pt_outputs = self.forward(inputs["audio_features"]).detach().cpu().numpy() # Run ONNX model sess = InferenceSession(onnx_model_path, providers=[provider]) ort_outputs = sess.run(None, {"audio_features": inputs["audio_features"].detach().cpu().numpy()})[0] # Calculate output difference diff = np.abs(pt_outputs - ort_outputs) logger.warning("Comparing encoder_hidden_states...") logger.warning(f"Max diff: {np.max(diff)}")