| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164 |
- # -------------------------------------------------------------------------
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License. See License.txt in the project root for
- # license information.
- # --------------------------------------------------------------------------
- import logging
- import os
- import tempfile
- from pathlib import Path
- import numpy as np
- import onnx
- import torch
- from float16 import convert_float_to_float16
- from onnx import ModelProto
- from onnx_model import OnnxModel
- from transformers import WhisperConfig
- from whisper_inputs import get_model_dynamic_axes, get_sample_encoder_inputs
- from onnxruntime import InferenceSession
- logger = logging.getLogger(__name__)
class WhisperEncoder(torch.nn.Module):
    """Wrap the encoder of a Whisper model for ONNX export and verification.

    The wrapper hides the difference between the supported model
    implementations: when ``model_impl`` is ``"openai"`` the encoder is taken
    from ``model.encoder`` and its output is returned unchanged; for any other
    value the encoder is taken from ``model.model.encoder`` and the
    ``last_hidden_state`` attribute of its output is returned.
    """

    def __init__(self, config: WhisperConfig, model: torch.nn.Module, model_impl: str):
        """Store the config and pick the encoder sub-module out of ``model``.

        Args:
            config (WhisperConfig): model configuration, forwarded to the
                sample-input and dynamic-axes helpers.
            model (torch.nn.Module): full Whisper model; its ``device``
                attribute is recorded and its encoder is extracted.
            model_impl (str): ``"openai"`` selects ``model.encoder``; anything
                else selects ``model.model.encoder``.
        """
        super().__init__()
        self.config = config
        self.device = model.device
        self.model_impl = model_impl
        self.encoder = model.encoder if model_impl == "openai" else model.model.encoder

    def forward(self, audio_features: torch.Tensor):
        """Run the encoder and return its hidden states.

        Args:
            audio_features (torch.Tensor): input passed straight to the encoder.

        Returns:
            The raw encoder output for the ``"openai"`` implementation,
            otherwise the ``last_hidden_state`` attribute of that output.
        """
        outputs = self.encoder(audio_features)
        if self.model_impl == "openai":
            return outputs
        return outputs.last_hidden_state

    def input_names(self):
        """Return the ONNX graph input names (a new list on every call)."""
        return ["audio_features"]

    def output_names(self):
        """Return the ONNX graph output names (a new list on every call)."""
        return ["encoder_hidden_states"]

    def dynamic_axes(self, input_names, output_names):
        """Return the dynamic-axes mapping for the given input/output names.

        The mapping is built by ``get_model_dynamic_axes`` from ``self.config``.
        """
        return get_model_dynamic_axes(self.config, input_names, output_names)

    def fix_layernorm_weights(self, model: ModelProto, use_fp16_inputs: bool):
        """Convert the ONNX model to float16 when the OpenAI fp16 export needs it.

        Args:
            model (ModelProto): exported ONNX model.
            use_fp16_inputs (bool): whether the model was exported with float16 inputs.

        Returns:
            The converted model for the ``"openai"`` implementation with
            float16 inputs; otherwise ``model`` is returned untouched.
        """
        if self.model_impl == "openai" and use_fp16_inputs:
            # Cast ONNX model to float16 to ensure LayerNorm weights are converted from
            # float32 to float16 since exported model already has float16 weights everywhere
            # except for LayerNorm ops. This happens because OpenAI always upcasts to float32
            # when computing LayerNorm.
            #
            # Reference:
            # https://github.com/openai/whisper/blob/90db0de1896c23cbfaf0c58bc2d30665f709f170/whisper/model.py#L41
            model = convert_float_to_float16(model)
        return model

    def export_onnx(
        self,
        onnx_model_path: str,
        provider: str,
        verbose: bool = True,
        use_external_data_format: bool = False,
        use_fp16_inputs: bool = False,
    ):
        """Export encoder to ONNX, then verify parity against PyTorch

        Args:
            onnx_model_path (str): path to save ONNX model
            provider (str): provider to use for verifying parity on ONNX model
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
            use_fp16_inputs (bool, optional): use float16 inputs for the audio_features. Defaults to False.
        """
        # Shape of encoder's tensors:
        # Inputs:
        #    audio_features: (batch_size, num_mels, num_frames)
        # Outputs:
        #    encoder_hidden_states: (batch_size, num_frames // 2, hidden_size)
        inputs = get_sample_encoder_inputs(
            self.config,
            self.device,
            batch_size=2,
            use_fp16=use_fp16_inputs,
        )

        input_names = self.input_names()
        output_names = self.output_names()
        dynamic_axes = self.dynamic_axes(input_names, output_names)

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        with tempfile.TemporaryDirectory() as tmp_dir_name:
            # With external data the raw export goes to a scratch file first and is
            # re-saved to the final location below; otherwise it is written directly
            # to the final path. The temporary directory already exists here, so no
            # extra mkdir is needed for the scratch file.
            temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx")
            out_path = temp_onnx_model_path if use_external_data_format else onnx_model_path

            torch.onnx.export(
                self,
                # Explicit 1-tuple: a bare parenthesised tensor is not a tuple.
                args=(inputs["audio_features"],),
                f=out_path,
                export_params=True,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                opset_version=17,
                do_constant_folding=True,
                verbose=verbose,
            )

            # Load while the scratch file (and its external data) still exists.
            model = onnx.load_model(out_path, load_external_data=use_external_data_format)
            model = self.fix_layernorm_weights(model, use_fp16_inputs)
            OnnxModel.save(
                model,
                onnx_model_path,
                save_as_external_data=use_external_data_format,
                all_tensors_to_one_file=True,
            )

        self.verify_onnx(onnx_model_path, provider, use_fp16_inputs)

    def verify_onnx(
        self,
        onnx_model_path: str,
        provider: str,
        use_fp16_inputs: bool,
    ):
        """Verify ONNX model outputs and PyTorch model outputs match

        The maximum absolute difference between both outputs is logged; nothing
        is returned and no threshold is enforced.

        Args:
            onnx_model_path (str): path of the ONNX model to load
            provider (str): execution provider for ONNX model
            use_fp16_inputs (bool): use float16 inputs for the audio_features
        """
        # Shape of encoder's tensors:
        # Inputs:
        #    audio_features: (batch_size, num_mels, num_frames)
        # Outputs:
        #    encoder_hidden_states: (batch_size, num_frames // 2, hidden_size)
        inputs = get_sample_encoder_inputs(
            self.config,
            self.device,
            batch_size=2,
            use_fp16=use_fp16_inputs,
        )
        audio_features = inputs["audio_features"]

        # Run PyTorch model (reference run only, so autograd tracking is disabled)
        with torch.no_grad():
            pt_outputs = self.forward(audio_features).detach().cpu().numpy()

        # Run ONNX model
        sess = InferenceSession(onnx_model_path, providers=[provider])
        ort_outputs = sess.run(None, {"audio_features": audio_features.detach().cpu().numpy()})[0]

        # Calculate output difference
        diff = np.abs(pt_outputs - ort_outputs)
        logger.warning("Comparing encoder_hidden_states...")
        logger.warning(f"Max diff: {np.max(diff)}")
|