| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380 |
- # -------------------------------------------------------------------------
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License. See License.txt in the project root for
- # license information.
- # --------------------------------------------------------------------------
- import logging
- import numpy as np
- import torch
- from transformers import WhisperConfig
- from onnxruntime import InferenceSession
- logger = logging.getLogger(__name__)
# Create audio_features for encoder
# Shape is (batch_size, feature_size, sequence_length) = (batch_size, num_mel_filters, num_frames)
# where num_mel_filters is a model attribute and num_frames = (chunk_length * sample_rate) // hop_length.
#
# Hard-coded audio hyperparameters:
# SAMPLE_RATE = 16000
# N_FFT = 400
# HOP_LENGTH = 160
# CHUNK_LENGTH = 30 (i.e. 30-second chunk of audio)
# N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE = 30 * 16000 = 480000 (i.e. 480,000 samples in a 30-second chunk of audio)
# N_FRAMES = N_SAMPLES // HOP_LENGTH = 480000 // 160 = 3000 (i.e. 3000 frames in a mel spectrogram input)
#
# N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 = 160 * 2 = 320
# FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH = 16000 // 160 = 100 (i.e. 10 ms per audio frame)
# TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN = 16000 // 320 = 50 (i.e. 20 ms per audio token)
def get_sample_audio_features(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    sequence_length: int = 3000,
    use_fp16: bool = False,
):
    """Build a random audio-features tensor to feed the encoder.

    Args:
        config: Whisper configuration; only ``num_mel_bins`` is read.
        device: Device on which the tensor is created.
        batch_size: Size of the first dimension.
        sequence_length: Size of the last dimension (number of frames).
        use_fp16: Produce ``torch.float16`` when True, ``torch.float32`` otherwise.

    Returns:
        A ``torch.randn`` tensor of shape
        ``(batch_size, config.num_mel_bins, sequence_length)``.
    """
    if use_fp16:
        float_dtype = torch.float16
    else:
        float_dtype = torch.float32
    feature_shape = (batch_size, config.num_mel_bins, sequence_length)
    return torch.randn(feature_shape, device=device, dtype=float_dtype)
- # Create input_ids for decoder
- # Shape is (batch_size, sequence_length) where sequence_length is the initial decoder sequence length
def get_sample_decoder_input_ids(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    sequence_length: int,
    use_int32: bool = True,
):
    """Build random token ids to feed the decoder.

    Args:
        config: Whisper configuration; only ``vocab_size`` is read.
        device: Device on which the tensor is created.
        batch_size: Number of rows.
        sequence_length: Number of token ids per row.
        use_int32: Produce ``torch.int32`` when True, ``torch.int64`` otherwise.

    Returns:
        An integer tensor of shape ``(batch_size, sequence_length)`` with
        values drawn from ``[0, config.vocab_size)``.
    """
    id_dtype = torch.int64
    if use_int32:
        id_dtype = torch.int32
    id_shape = (batch_size, sequence_length)
    return torch.randint(0, config.vocab_size, id_shape, device=device, dtype=id_dtype)
- # Create encoder_hidden_states for decoder-init
- # Shape is (batch_size, num_frames // 2, hidden_size)
def get_sample_encoder_hidden_states(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    use_fp16: bool = False,
):
    """Build random encoder hidden states for the first decoder pass.

    Args:
        config: Whisper configuration; ``max_source_positions`` and
            ``d_model`` are read.
        device: Device on which the tensor is created.
        batch_size: Size of the first dimension.
        use_fp16: Produce ``torch.float16`` when True, ``torch.float32`` otherwise.

    Returns:
        A ``torch.randn`` tensor of shape
        ``(batch_size, config.max_source_positions, config.d_model)``.
    """
    hidden_shape = (batch_size, config.max_source_positions, config.d_model)
    float_dtype = torch.float32
    if use_fp16:
        float_dtype = torch.float16
    return torch.randn(hidden_shape, device=device, dtype=float_dtype)
- # Create past_key_values
- # Self-attention KV caches are of shape (batch_size, num_heads, past_sequence_length, head_size)
- # Cross-attention KV caches are of shape (batch_size, num_heads, num_frames // 2, head_size)
def get_sample_past_key_values(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    past_seq_len: int,
    use_fp16: bool = False,
):
    """Build random self- and cross-attention KV caches for every decoder layer.

    Args:
        config: Whisper configuration; ``decoder_attention_heads``, ``d_model``,
            ``max_source_positions`` and ``decoder_layers`` are read.
        device: Device on which the tensors are created.
        batch_size: Size of the first dimension of every cache.
        past_seq_len: Sequence dimension of the self-attention caches.
        use_fp16: Produce ``torch.float16`` when True, ``torch.float32`` otherwise.

    Returns:
        The result of ``flatten_past_key_values`` applied to the per-layer
        self-attention pairs of shape
        ``(batch_size, num_heads, past_seq_len, head_size)`` and cross-attention
        pairs of shape
        ``(batch_size, num_heads, config.max_source_positions, head_size)``.
    """
    float_dtype = torch.float16 if use_fp16 else torch.float32
    head_count = config.decoder_attention_heads
    per_head_dim = config.d_model // head_count
    # Cross-attention caches span the encoder output length.
    encoder_positions = config.max_source_positions

    def random_kv_pair(seq_len):
        # Key is drawn before value so the random stream is consumed in a fixed order.
        key_cache = torch.rand(batch_size, head_count, seq_len, per_head_dim, device=device, dtype=float_dtype)
        value_cache = torch.rand(batch_size, head_count, seq_len, per_head_dim, device=device, dtype=float_dtype)
        return (key_cache, value_cache)

    # All self-attention pairs are created first, then all cross-attention pairs.
    self_pairs = [random_kv_pair(past_seq_len) for _ in range(config.decoder_layers)]
    cross_pairs = [random_kv_pair(encoder_positions) for _ in range(config.decoder_layers)]
    return flatten_past_key_values(self_pairs, cross_pairs)
# Flatten KV caches into a list of 4-tuples, one per layer, where each tuple is defined as:
# (self_attn_key_cache, self_attn_value_cache, cross_attn_key_cache, cross_attn_value_cache)
- def flatten_past_key_values(
- self_attn_kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
- cross_attn_kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
- ):
- past_key_values = []
- for (self_k_cache, self_v_cache), (cross_k_cache, cross_v_cache) in zip(
- self_attn_kv_caches, cross_attn_kv_caches, strict=False
- ):
- layer_kv_caches = (self_k_cache, self_v_cache, cross_k_cache, cross_v_cache)
- past_key_values.append(layer_kv_caches)
- return past_key_values
- # Group KV caches into two 1D lists where one list contains the self attention KV caches and
- # one list contains the cross attention KV caches
def group_past_key_values(
    kv_caches: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
):
    """Split per-layer 4-tuples into flat self- and cross-attention lists.

    Args:
        kv_caches: Entries of the form
            ``(self_key, self_value, cross_key, cross_value)``.

    Returns:
        A ``(self_attn, cross_attn)`` tuple of flat lists. Each list holds
        key then value for the first entry, key then value for the second
        entry, and so on.
    """
    self_flat = []
    cross_flat = []
    for layer_caches in kv_caches:
        self_key, self_value, cross_key, cross_value = layer_caches
        self_flat.extend((self_key, self_value))
        cross_flat.extend((cross_key, cross_value))
    return self_flat, cross_flat
- # Create alignment heads for timestamps
- # Shape is (num_alignment_heads, 2)
def get_sample_alignment_heads(
    config: WhisperConfig,
    device: torch.device,
    num_alignment_heads: int = 6,
    use_int32: bool = True,
):
    """Build an all-ones alignment-heads tensor for the timestamps component.

    Args:
        config: Whisper configuration (not read by this function).
        device: Device on which the tensor is created.
        num_alignment_heads: Number of rows.
        use_int32: Produce ``torch.int32`` when True, ``torch.int64`` otherwise.

    Returns:
        A tensor of ones with shape ``(num_alignment_heads, 2)``.
    """
    if use_int32:
        int_dtype = torch.int32
    else:
        int_dtype = torch.int64
    return torch.ones(num_alignment_heads, 2, device=device, dtype=int_dtype)
- # Create length of start-of-transcription sequence for timestamps
- # Shape is (1)
def get_sample_sot_sequence_length(
    device: torch.device,
    sot_sequence_length: int,
    use_int32: bool = False,
):
    """Wrap the start-of-transcription sequence length in a one-element tensor.

    Args:
        device: Device on which the tensor is created.
        sot_sequence_length: Value to store.
        use_int32: Produce ``torch.int32`` when True, ``torch.int64`` otherwise.

    Returns:
        A tensor of shape ``(1,)`` holding ``sot_sequence_length``.
    """
    int_dtype = torch.int64
    if use_int32:
        int_dtype = torch.int32
    values = [sot_sequence_length]
    return torch.tensor(values, device=device, dtype=int_dtype)
- # Create segment length for timestamps
- # Shape is (1)
def get_sample_segment_length(
    device: torch.device,
    segment_length: int,
    use_int32: bool = False,
):
    """Wrap the segment length in a one-element tensor.

    Args:
        device: Device on which the tensor is created.
        segment_length: Value to store.
        use_int32: Produce ``torch.int32`` when True, ``torch.int64`` otherwise.

    Returns:
        A tensor of shape ``(1,)`` holding ``segment_length``.
    """
    if use_int32:
        int_dtype = torch.int32
    else:
        int_dtype = torch.int64
    values = [segment_length]
    return torch.tensor(values, device=device, dtype=int_dtype)
- # Create QKs for timestamps
- # Shape is (batch_size, num_heads, sequence_length, num_frames // 2)
def get_sample_QKs(  # noqa: N802
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    sequence_length: int,
    use_fp16: bool = False,
):
    """Build one random QK tensor per decoder layer for the timestamps component.

    Args:
        config: Whisper configuration; ``decoder_attention_heads``,
            ``max_source_positions`` and ``decoder_layers`` are read.
        device: Device on which the tensors are created.
        batch_size: Size of the first dimension.
        sequence_length: Size of the third dimension.
        use_fp16: Produce ``torch.float16`` when True, ``torch.float32`` otherwise.

    Returns:
        A list of ``config.decoder_layers`` ``torch.rand`` tensors, each of
        shape ``(batch_size, num_heads, sequence_length,
        config.max_source_positions)``.
    """
    float_dtype = torch.float16 if use_fp16 else torch.float32
    qk_shape = (
        batch_size,
        config.decoder_attention_heads,
        sequence_length,
        config.max_source_positions,
    )
    per_layer_qks = []
    for _layer in range(config.decoder_layers):
        per_layer_qks.append(torch.rand(qk_shape, device=device, dtype=float_dtype))
    return per_layer_qks
- # Create inputs for encoder component of Whisper
def get_sample_encoder_inputs(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    sequence_length: int = 3000,
    use_fp16: bool = False,
):
    """Assemble the sample inputs for the encoder component.

    Args:
        config: Whisper configuration forwarded to ``get_sample_audio_features``.
        device: Device on which the tensor is created.
        batch_size: Batch dimension of the audio features.
        sequence_length: Number of frames in the audio features.
        use_fp16: Request half-precision features when True.

    Returns:
        A dict with the single key ``"audio_features"``.
    """
    encoder_inputs = {}
    encoder_inputs["audio_features"] = get_sample_audio_features(
        config, device, batch_size, sequence_length, use_fp16
    )
    return encoder_inputs
- # Create inputs for encoder component + first pass through decoder component of Whisper
def get_sample_encoder_decoder_init_inputs(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    decoder_sequence_length: int,
    encoder_sequence_length: int = 3000,
    use_fp16: bool = False,
    use_int32: bool = True,
):
    """Assemble sample inputs for the encoder plus the first decoder pass.

    Args:
        config: Whisper configuration forwarded to the tensor builders.
        device: Device on which the tensors are created.
        batch_size: Batch dimension of both tensors.
        decoder_sequence_length: Number of decoder token ids per row.
        encoder_sequence_length: Number of frames in the audio features.
        use_fp16: Request half-precision audio features when True.
        use_int32: Request int32 token ids when True, int64 otherwise.

    Returns:
        A dict with keys ``"audio_features"`` and ``"decoder_input_ids"``.
    """
    sample_inputs = {}
    # Audio features are generated before the token ids.
    sample_inputs["audio_features"] = get_sample_audio_features(
        config, device, batch_size, encoder_sequence_length, use_fp16
    )
    sample_inputs["decoder_input_ids"] = get_sample_decoder_input_ids(
        config, device, batch_size, decoder_sequence_length, use_int32
    )
    return sample_inputs
- # Create inputs for decoder component of Whisper
- # Inputs for first pass through the decoder (i.e. decoder-init): decoder_input_ids, encoder_hidden_states
- # Inputs for subsequent passes through the decoder (i.e. decoder-with-past): decoder_input_ids, past_key_values
def get_sample_decoder_inputs(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    past_sequence_length: int,
    sequence_length: int,
    use_fp16: bool = False,
    use_int32: bool = True,
):
    """Assemble sample inputs for the decoder component.

    All three inputs are always produced; which ones a particular decoder
    variant consumes is left to the caller.

    Args:
        config: Whisper configuration forwarded to the tensor builders.
        device: Device on which the tensors are created.
        batch_size: Batch dimension of every tensor.
        past_sequence_length: Sequence dimension of the self-attention KV caches.
        sequence_length: Number of decoder token ids per row.
        use_fp16: Request half-precision floating-point tensors when True.
        use_int32: Request int32 token ids when True, int64 otherwise.

    Returns:
        A dict with keys ``"decoder_input_ids"``, ``"encoder_hidden_states"``
        and ``"past_key_values"``.
    """
    decoder_inputs = {}
    # Generation order: token ids, encoder hidden states, then KV caches.
    decoder_inputs["decoder_input_ids"] = get_sample_decoder_input_ids(
        config, device, batch_size, sequence_length, use_int32
    )
    decoder_inputs["encoder_hidden_states"] = get_sample_encoder_hidden_states(
        config, device, batch_size, use_fp16
    )
    decoder_inputs["past_key_values"] = get_sample_past_key_values(
        config, device, batch_size, past_sequence_length, use_fp16
    )
    return decoder_inputs
- # Create inputs for timestamps component of Whisper
def get_sample_jump_times_inputs(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    sequence_length: int,
    num_alignment_heads: int,
    sot_sequence_length: int,
    segment_length: int,
    use_fp16: bool = False,
    use_int32: bool = True,
):
    """Assemble sample inputs for the timestamps (jump times) component.

    Args:
        config: Whisper configuration forwarded to the tensor builders.
        device: Device on which the tensors are created.
        batch_size: Batch dimension of the QK tensors.
        sequence_length: Sequence dimension of the QK tensors.
        num_alignment_heads: Number of rows in the alignment-heads tensor.
        sot_sequence_length: Start-of-transcription sequence length value.
        segment_length: Segment length value.
        use_fp16: Request half-precision QK tensors when True.
        use_int32: Request an int32 alignment-heads tensor when True.

    Returns:
        A dict with keys ``"alignment_heads"``, ``"sot_sequence_length"``,
        ``"segment_length"`` and ``"QKs"``.
    """
    heads_tensor = get_sample_alignment_heads(config, device, num_alignment_heads, use_int32)
    # The two lengths are built with the helpers' default dtype (int64) rather than
    # ``use_int32``, because subsequent 'Slice' ops only take int64 inputs.
    sot_length_tensor = get_sample_sot_sequence_length(device, sot_sequence_length)
    segment_length_tensor = get_sample_segment_length(device, segment_length)
    qk_tensors = get_sample_QKs(config, device, batch_size, sequence_length, use_fp16)
    return {
        "alignment_heads": heads_tensor,
        "sot_sequence_length": sot_length_tensor,
        "segment_length": segment_length_tensor,
        "QKs": qk_tensors,
    }
- # Convert PyTorch inputs to ONNX Runtime inputs
def _to_numpy(tensor):
    """Detach a torch tensor, move it to the CPU and return it as a NumPy array."""
    return tensor.detach().cpu().numpy()


def _next_tensor(pending, input_name, source_key):
    """Return the next tensor from ``pending`` or raise a descriptive ``ValueError``."""
    try:
        return next(pending)
    except StopIteration:
        raise ValueError(f"inputs['{source_key}'] has no tensor left for model input: {input_name}") from None


def convert_inputs_for_ort(
    inputs: dict,
    model: InferenceSession,
):
    """Convert PyTorch sample inputs into the NumPy feed expected by an ORT session.

    The model's input names (``model.get_inputs()``) drive the conversion: each
    name is mapped to the matching entry of ``inputs``. KV caches and QKs are
    consumed in order, one tensor per matching model input. ``inputs`` itself is
    never modified, so the same dict can be converted repeatedly.

    When the model has a ``cache_indirection`` input, buffer sharing is assumed:
    each self-attention cache is copied into a zero-initialized buffer whose
    sequence dimension is the hard-coded maximum of 448.

    Args:
        inputs: Mapping from sample-input names (``"audio_features"``,
            ``"decoder_input_ids"``, ``"encoder_hidden_states"``,
            ``"past_key_values"``, ``"alignment_heads"``,
            ``"sot_sequence_length"``, ``"segment_length"``, ``"QKs"``) to torch
            tensors or lists of tensors. Only the entries the model asks for
            need to be present.
        model: Session whose ``get_inputs()`` lists the expected input names.

    Returns:
        A dict mapping each model input name to a NumPy array.

    Raises:
        ValueError: If a model input name is not recognized, or if the model
            asks for more KV-cache / QK tensors than ``inputs`` provides.
    """
    batch_size, num_heads, past_seq_len, head_size = 0, 0, 0, 0
    num_beams, max_seq_len = 1, 448

    self_attn_kv_caches, cross_attn_kv_caches = [], []
    if "past_key_values" in inputs:
        self_attn_kv_caches, cross_attn_kv_caches = group_past_key_values(inputs["past_key_values"])
        batch_size, num_heads, past_seq_len, head_size = self_attn_kv_caches[0].shape

    # Iterators hand out tensors in order without popping from (and thereby
    # emptying) lists that belong to the caller.
    pending_self_kv = iter(self_attn_kv_caches)
    pending_cross_kv = iter(cross_attn_kv_caches)
    pending_qks = iter(inputs.get("QKs", ()))

    model_inputs = [model_input.name for model_input in model.get_inputs()]
    use_buffer_sharing = "cache_indirection" in model_inputs

    ort_inputs = {}
    for name in model_inputs:
        if name in {"audio_features", "encoder_input_ids"}:
            # Encoder input
            ort_inputs[name] = _to_numpy(inputs["audio_features"])
        elif name == "encoder_hidden_states":
            # Encoder output
            ort_inputs[name] = _to_numpy(inputs["encoder_hidden_states"])
        elif name in {"decoder_input_ids", "input_ids"}:
            # Decoder input
            ort_inputs[name] = _to_numpy(inputs["decoder_input_ids"])
        elif "past_key_self" in name or "past_value_self" in name:
            # Decoder input
            orig_kv_cache = _to_numpy(_next_tensor(pending_self_kv, name, "past_key_values"))
            if use_buffer_sharing:
                # Pre-allocate the full-length buffer and copy the existing cache into its prefix.
                new_kv_cache = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=orig_kv_cache.dtype)
                new_kv_cache[:batch_size, :num_heads, :past_seq_len, :head_size] = orig_kv_cache
                ort_inputs[name] = new_kv_cache
            else:
                ort_inputs[name] = orig_kv_cache
        elif "past_key_cross" in name or "past_value_cross" in name:
            # Decoder input
            ort_inputs[name] = _to_numpy(_next_tensor(pending_cross_kv, name, "past_key_values"))
        elif name == "past_sequence_length":
            # Decoder input
            ort_inputs[name] = np.array([past_seq_len], dtype=np.int32)
        elif name == "cache_indirection":
            # Decoder input
            ort_inputs[name] = np.zeros((batch_size, num_beams, max_seq_len), dtype=np.int32)
        elif name == "alignment_heads":
            # Jump times input
            ort_inputs[name] = _to_numpy(inputs["alignment_heads"])
        elif name == "sot_sequence_length":
            # Jump times input
            ort_inputs[name] = _to_numpy(inputs["sot_sequence_length"])
        elif name == "segment_length":
            # Jump times input
            ort_inputs[name] = _to_numpy(inputs["segment_length"])
        elif "cross_qk" in name:
            # Jump times input
            ort_inputs[name] = _to_numpy(_next_tensor(pending_qks, name, "QKs"))
        else:
            raise ValueError(f"Unknown name not recognized: {name}")
    return ort_inputs
- # Get dynamic axes for all inputs and outputs to the model
def get_model_dynamic_axes(
    config: WhisperConfig,
    input_names: list[str],
    output_names: list[str],
):
    """Return the dynamic-axes mapping for all model inputs and outputs.

    Exact names are checked first; any other name is matched by substring
    against the KV-cache, ``cross_qk`` and ``jump_times`` patterns, in that
    order. ``sot_sequence_length`` and ``segment_length`` are recognized but
    get no entry because they have no dynamic axis.

    Args:
        config: Whisper configuration (not read by this function).
        input_names: Names of the model inputs.
        output_names: Names of the model outputs.

    Returns:
        A dict mapping each name that has dynamic axes to
        ``{axis_index: axis_label}``.

    Raises:
        ValueError: If a name matches none of the known patterns.
    """
    # Names matched exactly. ``None`` marks names without any dynamic axis.
    exact_axes = {
        # shape is (batch_size, num_mels, num_frames)
        "audio_features": {0: "batch_size"},
        "encoder_input_ids": {0: "batch_size"},
        # shape is (batch_size, sequence_length)
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "decoder_input_ids": {0: "batch_size", 1: "sequence_length"},
        # shape is (num_alignment_heads, 2)
        "alignment_heads": {0: "num_alignment_heads"},
        # shape is (1)
        "sot_sequence_length": None,
        "segment_length": None,
        # shape is (batch_size, sequence_length, vocab_size)
        "logits": {0: "batch_size", 1: "sequence_length"},
        # shape is (batch_size, num_frames // 2, hidden_size)
        "encoder_hidden_states": {0: "batch_size"},
    }
    # Names matched by substring; the first matching rule wins.
    substring_axes = (
        # shape is (batch_size, num_heads, past_sequence_length, head_size)
        (("past_key_self", "past_value_self"), {0: "batch_size", 2: "past_sequence_length"}),
        # shape is (batch_size, num_heads, past_sequence_length + sequence_length, head_size),
        # which is equal to (batch_size, num_heads, total_sequence_length, head_size)
        (("present_key_self", "present_value_self"), {0: "batch_size", 2: "total_sequence_length"}),
        # shape is (batch_size, num_heads, num_frames // 2, head_size)
        (
            ("past_key_cross", "past_value_cross", "present_key_cross", "present_value_cross"),
            {0: "batch_size"},
        ),
        # shape is (batch_size, num_heads, sequence_length, num_frames // 2)
        (("cross_qk",), {0: "batch_size", 2: "sequence_length"}),
        # shape is (batch_size, max_length)
        (("jump_times",), {0: "batch_size", 1: "max_length"}),
    )

    dynamic_axes = {}
    for name in input_names + output_names:
        if name in exact_axes:
            axes = exact_axes[name]
            if axes is not None:
                # Copy so every entry owns an independent dict.
                dynamic_axes[name] = dict(axes)
            continue
        for fragments, axes in substring_axes:
            if any(fragment in name for fragment in fragments):
                dynamic_axes[name] = dict(axes)
                break
        else:
            raise ValueError(f"Unknown input or output name found: {name}")
    return dynamic_axes
|