| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464 |
- # -------------------------------------------------------------------------
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License. See License.txt in the project root for
- # license information.
- # --------------------------------------------------------------------------
- import logging
- import os
- import tempfile
- from itertools import chain
- from pathlib import Path
- import numpy as np
- import onnx
- import torch
- from float16 import convert_float_to_float16
- from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
- from onnx import ModelProto, ValueInfoProto
- from onnx_model import OnnxModel
- from past_helper import PastKeyValuesHelper
- from transformers import WhisperConfig
- from whisper_inputs import (
- convert_inputs_for_ort,
- get_model_dynamic_axes,
- get_sample_decoder_inputs,
- group_past_key_values,
- )
- from onnxruntime import InferenceSession
# Module-level logger; `WhisperDecoder.verify_onnx` uses it to report PyTorch-vs-ONNX output diffs.
logger = logging.getLogger(__name__)
class WhisperDecoder(torch.nn.Module):
    """A Whisper decoder with optional past key values, wrapped for ONNX export.

    The wrapped implementation is selected by ``model_impl``:

    * ``"openai"``: an OpenAI Whisper model; ``forward`` dispatches to ``oai_forward``.
    * any other value: a Hugging Face Whisper model; ``forward`` dispatches to ``hf_forward``.

    ``export_onnx`` exports one of two graphs, chosen by its flags:

    * first pass (decoder-init / decoder-without-past): declared inputs are ``input_ids`` and
      ``encoder_hidden_states``; outputs are ``logits`` plus self- and cross-attention KV caches
      for every decoder layer.
    * later pass (decoder-with-past): declared inputs are ``input_ids``, ``encoder_hidden_states``
      and the past self-/cross-attention KV caches; outputs are ``logits`` plus the present
      self-attention KV caches only.

    ``first_pass``/``later_pass`` are only assigned in ``export_onnx``; ``input_names``,
    ``output_names``, ``inputs`` and ``verify_onnx`` read ``first_pass`` and therefore must not be
    called before ``export_onnx`` has set it.
    """

    def __init__(self, config: WhisperConfig, model: torch.nn.Module, model_impl: str, no_beam_search_op: bool = False):
        """Store the config and the sub-modules needed by the selected implementation.

        Args:
            config: Whisper model configuration (reads ``max_source_positions``,
                ``decoder_attention_heads``, ``d_model`` and ``decoder_layers``).
            model: the full Whisper model.
            model_impl: ``"openai"`` for the OpenAI implementation; anything else selects the
                Hugging Face path.
            no_beam_search_op: when False, KV-cache inputs/outputs are regrouped (all self caches,
                then all cross caches) and ``input_ids`` only keeps a dynamic batch axis.
        """
        super().__init__()
        self.config = config
        self.device = model.device
        self.model_impl = model_impl
        self.no_beam_search_op = no_beam_search_op

        # The HF path calls the decoder stack and the output projection separately;
        # the OpenAI path calls the whole model's decoder through `self.model`.
        self.decoder = None if model_impl == "openai" else model.model.decoder
        self.proj_out = None if model_impl == "openai" else model.proj_out
        self.model = model if model_impl == "openai" else None

        self.max_source_positions = self.config.max_source_positions
        self.num_heads = self.config.decoder_attention_heads
        self.head_size = self.config.d_model // self.num_heads

    def hf_forward(
        self,
        decoder_input_ids: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        past_key_values: list[tuple[torch.Tensor]] | None = None,
    ):
        """Run the Hugging Face Whisper decoder.

        Returns:
            ``(logits, present_key_values)`` as returned by the HF decoder when
            ``past_key_values`` is None (decoder-init), otherwise ``(logits, present_self)`` with
            only the self-attention caches (decoder-with-past).
        """
        outputs = self.decoder(
            encoder_hidden_states=encoder_hidden_states,
            input_ids=decoder_input_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        logits = self.proj_out(outputs.last_hidden_state)
        present_key_values = outputs.past_key_values

        if past_key_values is None:
            # Return present_self_* and present_cross_* for decoder-init
            return logits, present_key_values

        # Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
        #         (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1),
        # After:  (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1), ...,
        #         (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1), ...
        present_self, _present_cross = PastKeyValuesHelper.group_by_self_and_cross(present_key_values)

        # Return present_self_* for decoder-with-past since past_cross_* and present_cross_* are identical
        return logits, present_self

    def oai_forward(
        self,
        decoder_input_ids: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        past_key_values: list[tuple[torch.Tensor]] | None = None,
    ):
        """Run the OpenAI Whisper decoder.

        KV caches are exchanged with the caller in BxNxSxH layout and converted to/from the
        BxSxD layout OpenAI's modules use (see the reshape/transpose steps below).

        Returns:
            ``(logits, present_key_values)`` grouped per layer (self + cross caches) when
            ``past_key_values`` is None (decoder-init), otherwise ``(logits, present_self)`` with
            only the self-attention caches (decoder-with-past).
        """
        past_kv_cache = {}
        if past_key_values is not None:
            # Convert past KV caches (BxNxSxH --> BxSxNxH --> BxSxD) for OpenAI's forward pass
            self_attn_kv_caches, cross_attn_kv_caches = group_past_key_values(past_key_values)
            self_attn_kv_caches = [past_kv.transpose(1, 2) for past_kv in self_attn_kv_caches]
            self_attn_kv_caches = [past_kv.reshape((*past_kv.shape[:2], -1)) for past_kv in self_attn_kv_caches]
            cross_attn_kv_caches = [past_kv.transpose(1, 2) for past_kv in cross_attn_kv_caches]
            cross_attn_kv_caches = [past_kv.reshape((*past_kv.shape[:2], -1)) for past_kv in cross_attn_kv_caches]

            for idx, block in enumerate(self.model.decoder.blocks):
                past_kv_cache[block.attn.key] = self_attn_kv_caches[2 * idx]
                past_kv_cache[block.attn.value] = self_attn_kv_caches[2 * idx + 1]
                past_kv_cache[block.cross_attn.key] = cross_attn_kv_caches[2 * idx]
                past_kv_cache[block.cross_attn.value] = cross_attn_kv_caches[2 * idx + 1]

        # Install OpenAI's hooks on the forward pass of each nn.Linear for key and value
        # since the hooks will capture the output of the key and value MatMuls, which
        # represent the current keys and values.
        #
        # For OpenAI's forward pass, the hook function will also perform the concat
        # operation (past_kv + curr_kv --> pres_kv) if needed. However, the ONNX model
        # will not contain this concat operation because the present KV caches aren't
        # returned by OpenAI's forward pass.
        kv_cache, hooks = self.model.install_kv_cache_hooks()

        # Run forward pass
        # NOTE: There is a bug with openai-whisper==20240930 with the introduction of SDPA.
        # In the Whisper codebase, the following line
        #
        # is_causal = mask is not None and n_ctx > 1
        #
        # has been added where `mask` is a torch tensor. The right-hand side evaluates to `tensor(True/False)`
        # but `is_causal` only accepts the boolean value. The fix is to apply `.item()` after the right-hand
        # side has been evaluated. In other words, the line should be
        #
        # is_causal = (mask is not None and n_ctx > 1).item()
        #
        # instead.
        try:
            logits = self.model.decoder(x=decoder_input_ids, xa=encoder_hidden_states, kv_cache=past_kv_cache)
        finally:
            # Remove OpenAI's hooks since they persist on the model after this function completes.
            # Doing it in `finally` also removes them when the forward pass raises. The hooks are not
            # needed past this point: the code below only reads the `kv_cache` dict they filled.
            for hook in hooks:
                hook.remove()

        # Re-do concat operation on self attention KV caches for ONNX export (if past self attention KV caches exist)
        if past_key_values is not None:
            for block in self.model.decoder.blocks:
                kv_cache[block.attn.key] = torch.cat(
                    [past_kv_cache[block.attn.key], kv_cache[block.attn.key]], dim=1
                ).detach()
                kv_cache[block.attn.value] = torch.cat(
                    [past_kv_cache[block.attn.value], kv_cache[block.attn.value]], dim=1
                ).detach()

        present_self, present_cross = [], []
        for block in self.model.decoder.blocks:
            # Group self and cross values
            present_self.append(kv_cache[block.attn.key])
            present_self.append(kv_cache[block.attn.value])
            if past_key_values is None:
                # Return present_self_* and present_cross_* for decoder-init
                present_cross.append(kv_cache[block.cross_attn.key])
                present_cross.append(kv_cache[block.cross_attn.value])

        # Convert present KV caches (BxSxD --> BxSxNxH --> BxNxSxH) after OpenAI's forward pass
        present_self = [
            present_kv.reshape((*present_kv.shape[:2], -1, self.head_size)).transpose(1, 2)
            for present_kv in present_self
        ]
        present_cross = [
            present_kv.reshape((*present_kv.shape[:2], -1, self.head_size)).transpose(1, 2)
            for present_kv in present_cross
        ]

        if past_key_values is None:
            # Return present_self_* and present_cross_* for decoder-init
            present_key_values = PastKeyValuesHelper.group_by_layer(
                present_self + present_cross, len(present_self) // 2
            )
            return logits, present_key_values

        # Return present_self_* for decoder-with-past since past_cross_* and present_cross_* are identical
        return logits, present_self

    def forward(
        self,
        decoder_input_ids: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        past_key_values: list[tuple[torch.Tensor]] | None = None,
    ):
        """Dispatch to ``oai_forward`` or ``hf_forward`` based on ``model_impl``."""
        if self.model_impl == "openai":
            return self.oai_forward(decoder_input_ids, encoder_hidden_states, past_key_values)
        return self.hf_forward(decoder_input_ids, encoder_hidden_states, past_key_values)

    def input_names(self):
        """Return the ONNX input names for the graph selected by ``self.first_pass``.

        Past KV-cache names are listed per layer: key/value self, then key/value cross.
        """
        if self.first_pass:
            input_names = ["input_ids", "encoder_hidden_states"]
        else:
            input_names = [
                "input_ids",
                "encoder_hidden_states",
                *list(
                    chain.from_iterable(
                        (f"past_key_self_{i}", f"past_value_self_{i}", f"past_key_cross_{i}", f"past_value_cross_{i}")
                        for i in range(self.config.decoder_layers)
                    )
                ),
            ]
        return input_names

    def output_names(self):
        """Return the ONNX output names for the graph selected by ``self.first_pass``.

        The first pass emits self and cross caches per layer; later passes emit only the
        self-attention caches.
        """
        if self.first_pass:
            output_names = [
                "logits",
                *list(
                    chain.from_iterable(
                        (
                            f"present_key_self_{i}",
                            f"present_value_self_{i}",
                            f"present_key_cross_{i}",
                            f"present_value_cross_{i}",
                        )
                        for i in range(self.config.decoder_layers)
                    )
                ),
            ]
        else:
            output_names = [
                "logits",
                *list(
                    chain.from_iterable(
                        (f"present_key_self_{i}", f"present_value_self_{i}") for i in range(self.config.decoder_layers)
                    )
                ),
            ]
        return output_names

    def dynamic_axes(self, input_names, output_names):
        """Return the ``dynamic_axes`` mapping passed to ``torch.onnx.export``."""
        dynamic_axes = get_model_dynamic_axes(self.config, input_names, output_names)
        if "input_ids" in dynamic_axes and not self.no_beam_search_op:
            # Set dynamic axes for `input_ids` when using beam search op to {0: "batch_size"} only.
            # `pop(..., None)` tolerates the axis already being absent instead of raising KeyError.
            dynamic_axes["input_ids"].pop(1, None)
        return dynamic_axes

    def inputs(self, use_fp16_inputs: bool, use_int32_inputs: bool, return_dict: bool = False):
        """Build sample inputs for export/verification.

        Uses batch size 2; the first pass uses 6 input tokens and no past, later passes use 1 input
        token with a past sequence length of 6.

        Returns:
            A dict (when ``return_dict``) whose keys match ``forward``'s parameter names, otherwise
            a positional tuple for ``torch.onnx.export``.
        """
        inputs = get_sample_decoder_inputs(
            self.config,
            self.device,
            batch_size=2,
            past_sequence_length=(0 if self.first_pass else 6),
            sequence_length=(6 if self.first_pass else 1),
            use_fp16=use_fp16_inputs,
            use_int32=use_int32_inputs,
        )
        if return_dict:
            if self.first_pass:
                del inputs["past_key_values"]
            return inputs

        if self.first_pass:
            return (
                inputs["decoder_input_ids"],
                inputs["encoder_hidden_states"],
            )
        return (
            inputs["decoder_input_ids"],
            inputs["encoder_hidden_states"],
            inputs["past_key_values"],
        )

    def fix_key_value_cache_dims(self, io: ValueInfoProto, is_cross: bool = False, is_output: bool = False):
        """Replace exporter-generated symbolic dims (names containing ``_dim_``) on a KV-cache value.

        Dims 1 and 3 become ``num_heads`` and ``head_size``. Dim 2 becomes
        ``max_source_positions`` for cross caches, or the symbolic ``total_sequence_length``
        (outputs) / ``past_sequence_length`` (inputs) for self caches. Mutates ``io`` in place
        and returns it.
        """
        # Shape should be (batch_size, num_heads, sequence_length, head_size) for self attention KV caches
        # and (batch_size, num_heads, num_frames // 2, head_size) for cross attention KV caches
        num_heads = io.type.tensor_type.shape.dim[1]
        if "_dim_" in num_heads.dim_param:
            num_heads.Clear()
            num_heads.dim_value = self.num_heads

        sequence_length = io.type.tensor_type.shape.dim[2]
        if "_dim_" in sequence_length.dim_param:
            sequence_length.Clear()
            if is_cross:
                sequence_length.dim_value = self.max_source_positions
            else:
                sequence_length.dim_param = "total_sequence_length" if is_output else "past_sequence_length"

        head_size = io.type.tensor_type.shape.dim[3]
        if "_dim_" in head_size.dim_param:
            head_size.Clear()
            head_size.dim_value = self.head_size

        return io

    def fix_io(self, io_list: RepeatedCompositeFieldContainer, is_output: bool = False):
        """Fix KV-cache dims and return the (possibly reordered) list of graph inputs/outputs.

        Without the beam search op the original order is kept. With it, non-cache entries come
        first, then all self-attention caches, then all cross-attention caches.
        """
        # Fix order of inputs/outputs and each dim_value of input/output
        reordered_io = []
        self_attn_kv_caches = []
        cross_attn_kv_caches = []

        for io in io_list:
            if "past" not in io.name and "present" not in io.name:
                reordered_io.append(io)
            elif "self" in io.name:
                # Self attention KV caches
                new_io = self.fix_key_value_cache_dims(io, is_cross=False, is_output=is_output)
                if self.no_beam_search_op:
                    reordered_io.append(new_io)
                else:
                    self_attn_kv_caches.append(new_io)
            else:
                # Cross attention KV caches
                new_io = self.fix_key_value_cache_dims(io, is_cross=True, is_output=is_output)
                if self.no_beam_search_op:
                    reordered_io.append(new_io)
                else:
                    cross_attn_kv_caches.append(new_io)

        if not self.no_beam_search_op:
            reordered_io += self_attn_kv_caches + cross_attn_kv_caches
        return reordered_io

    def fix_inputs_and_outputs(self, model: ModelProto):
        """Rewrite ``model.graph.input``/``output`` in place with fixed dims and order; return ``model``."""
        # ONNX exporter might mark dimensions like 'Transposepresent_value_self_1_dim_2' in shape inference.
        # We now change the dim_values to the correct one.
        reordered_inputs = self.fix_io(model.graph.input, is_output=False)
        while len(model.graph.input) > 0:
            model.graph.input.pop()
        model.graph.input.extend(reordered_inputs)

        reordered_outputs = self.fix_io(model.graph.output, is_output=True)
        while len(model.graph.output) > 0:
            model.graph.output.pop()
        model.graph.output.extend(reordered_outputs)
        return model

    def fix_layernorm_weights(self, model: ModelProto, use_fp16_inputs: bool):
        """Convert the whole model to float16 for the OpenAI implementation with fp16 inputs."""
        if self.model_impl == "openai" and use_fp16_inputs:
            # Cast ONNX model to float16 to ensure LayerNorm weights are converted from
            # float32 to float16 since exported model already has float16 weights everywhere
            # except for LayerNorm ops. This happens because OpenAI always upcasts to float32
            # when computing LayerNorm.
            #
            # Reference:
            # https://github.com/openai/whisper/blob/90db0de1896c23cbfaf0c58bc2d30665f709f170/whisper/model.py#L41
            model = convert_float_to_float16(model)
        return model

    def export_onnx(
        self,
        onnx_model_path: str,
        provider: str,
        verbose: bool = True,
        use_external_data_format: bool = False,
        use_fp16_inputs: bool = False,
        use_int32_inputs: bool = True,
        use_encoder_hidden_states: bool = False,
        use_kv_cache_inputs: bool = True,
    ):
        """Export decoder to ONNX

        Args:
            onnx_model_path (str): path to save ONNX model
            provider (str): provider to use for verifying parity on ONNX model
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
            use_fp16_inputs (bool, optional): use float16 inputs for the KV caches. Defaults to False.
            use_int32_inputs (bool, optional): use int32 inputs for the decoder_input_ids. Defaults to True.
            use_encoder_hidden_states (bool, optional): use encoder_hidden_states as model input for decoder-init/decoder-without-past models. Defaults to False.
            use_kv_cache_inputs (bool, optional): use KV caches as model inputs for decoder-with-past models. Defaults to True.

        Raises:
            AssertionError: unless exactly one of ``use_encoder_hidden_states`` and
                ``use_kv_cache_inputs`` is true.
        """
        # Shape of decoder's tensors:
        # Required Inputs:
        #     decoder_input_ids: (batch_size, sequence_length)
        # Optional Inputs:
        #     encoder_hidden_states (comes from encoder's outputs): (batch_size, num_frames // 2, hidden_size)
        #     past_{key/value}_self_* (past self attention KV caches): (batch_size, num_heads, past_sequence_length, head_size)
        #     past_{key/value}_cross_* (past cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
        # Outputs:
        #     logits: (batch_size, sequence_length, vocab_size)
        #     present_{key/value}_self_* (present self attention KV caches): (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
        #     present_{key/value}_cross_* (present cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)

        # For the first pass through the decoder (i.e. decoder-init/decoder-without-past)
        self.first_pass = use_encoder_hidden_states and not use_kv_cache_inputs

        # For subsequent passes through the decoder (i.e. decoder-with-past)
        self.later_pass = not use_encoder_hidden_states and use_kv_cache_inputs

        # Also fails when both flags are false, so the message says "exactly one".
        assert self.first_pass or self.later_pass, (
            "Exactly one of `use_encoder_hidden_states` and `use_kv_cache_inputs` must be true."
        )

        inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs)
        input_names = self.input_names()
        output_names = self.output_names()
        dynamic_axes = self.dynamic_axes(input_names, output_names)

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            temp_onnx_model_path = os.path.join(tmp_dir_name, "decoder.onnx")
            Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
            # With external data, export to a temp dir first so the final save can gather all
            # tensors into a single external file next to `onnx_model_path`.
            out_path = temp_onnx_model_path if use_external_data_format else onnx_model_path

            torch.onnx.export(
                self,
                args=inputs,
                f=out_path,
                export_params=True,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                opset_version=17,
                do_constant_folding=True,
                verbose=verbose,
            )

            model = onnx.load_model(out_path, load_external_data=use_external_data_format)
            model = self.fix_inputs_and_outputs(model)
            model = self.fix_layernorm_weights(model, use_fp16_inputs)
            OnnxModel.save(
                model,
                onnx_model_path,
                save_as_external_data=use_external_data_format,
                all_tensors_to_one_file=True,
            )

        self.verify_onnx(onnx_model_path, provider, use_fp16_inputs, use_int32_inputs)

    def verify_onnx(
        self,
        onnx_model_path: str,
        provider: str,
        use_fp16_inputs: bool,
        use_int32_inputs: bool,
    ):
        """Verify ONNX model outputs and PyTorch model outputs match

        Runs both models on the same sample inputs and logs the max absolute difference per output.
        This is best-effort: missing outputs or shape mismatches are logged, not raised.

        Args:
            onnx_model_path (str): path of the exported ONNX model to load
            provider (str): execution provider for ONNX model
            use_fp16_inputs (bool, optional): use float16 inputs for the KV caches
            use_int32_inputs (bool, optional): use int32 inputs for the decoder_input_ids
        """
        # Shape of decoder's tensors:
        # Required Inputs:
        #     decoder_input_ids: (batch_size, sequence_length)
        # Optional Inputs:
        #     encoder_hidden_states (comes from encoder's outputs): (batch_size, num_frames // 2, hidden_size)
        #     past_{key/value}_self_* (past self attention KV caches): (batch_size, num_heads, past_sequence_length, head_size)
        #     past_{key/value}_cross_* (past cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
        # Outputs:
        #     logits: (batch_size, sequence_length, vocab_size)
        #     present_{key/value}_self_* (present self attention KV caches): (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
        #     present_{key/value}_cross_* (present cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)

        # Run PyTorch model; flatten outputs in the same order as `self.output_names()`
        inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs, return_dict=True)
        out = self.forward(**inputs)
        pt_outputs = [out[0].detach().cpu().numpy()]
        if self.first_pass:
            # Per layer: key/value self, then key/value cross
            for present_key_value_layer in out[1]:
                for present_key_value in present_key_value_layer:
                    pt_outputs.append(present_key_value.detach().cpu().numpy())
        else:
            for present_self_key_value in out[1]:
                pt_outputs.append(present_self_key_value.detach().cpu().numpy())

        # Run ONNX model
        sess = InferenceSession(onnx_model_path, providers=[provider])
        ort_outputs = sess.run(None, convert_inputs_for_ort(inputs, sess))

        # Match outputs by name, not position: `fix_io` regroups KV-cache outputs (all self caches,
        # then all cross caches) when the beam search op is used, so the ORT output order can differ
        # from the per-layer order of `self.output_names()` / `pt_outputs`.
        ort_outputs_by_name = {
            output_meta.name: value for output_meta, value in zip(sess.get_outputs(), ort_outputs)
        }

        # Calculate output difference
        for output_name, pt_output in zip(self.output_names(), pt_outputs):
            ort_output = ort_outputs_by_name.get(output_name)
            if ort_output is None:
                logger.warning("Output %s not found in ONNX model outputs; skipping comparison", output_name)
                continue
            if pt_output.shape != ort_output.shape:
                logger.warning(
                    "Shape mismatch for %s: PyTorch %s vs ONNX Runtime %s; skipping comparison",
                    output_name,
                    pt_output.shape,
                    ort_output.shape,
                )
                continue

            diff = np.abs(pt_output - ort_output)
            logger.warning("Comparing %s...", output_name)
            logger.warning("Max diff: %s", np.max(diff) if diff.size else 0.0)
|