# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import logging
import os

import onnx
from benchmark_helper import Precision
from convert_generation import (
    get_shared_initializers,
    update_decoder_subgraph_output_cross_attention,
    update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha,
)
from onnx import TensorProto, helper
from transformers import WhisperConfig, WhisperTokenizer

logger = logging.getLogger(__name__)


def verify_inputs(beam_inputs, graph_inputs):
    """Check that the ONNX graph's inputs line up with the beam search op's inputs.

    Empty entries in ``beam_inputs`` (unused optional inputs) are ignored. The
    remaining names are compared position by position with ``graph_inputs``.

    Args:
        beam_inputs: Input names given to the beam search node; ``""`` marks an
            omitted optional input.
        graph_inputs: Graph input objects exposing a ``name`` attribute.

    Raises:
        AssertionError: If the number of inputs differs, or if a graph input's
            name is not contained in the beam input name at the same position.
    """
    beam_required_inputs = [beam_input for beam_input in beam_inputs if beam_input]
    assert len(graph_inputs) == len(beam_required_inputs), (
        f"Graph has {len(graph_inputs)} inputs but beam search op expects {len(beam_required_inputs)}"
    )

    for graph_input, beam_input in zip(graph_inputs, beam_required_inputs, strict=False):
        # Substring check (not equality) so that beam input names carrying the "_fp16" suffix still match
        assert graph_input.name in beam_input, (
            f"Graph input '{graph_input.name}' does not match beam search input '{beam_input}'"
        )


def clean_list(arr, remove_all_strings=True):
    """Return a copy of ``arr`` with empty-string placeholders removed.

    The caller's list is never modified; a new list is returned in both modes.

    Args:
        arr: List that may contain ``""`` placeholders.
        remove_all_strings: Despite the name, only *empty* strings are affected.
            When True, every ``""`` element is dropped. When False, only the
            trailing run of ``""`` elements is dropped, so interior placeholders
            keep their positions.

    Returns:
        A new list with the requested empty strings removed.
    """
    if remove_all_strings:
        # Remove all empty strings in list
        return [elm for elm in arr if elm != ""]

    # Remove empty strings at end of list
    end = len(arr)
    while end > 0 and arr[end - 1] == "":
        end -= 1
    return list(arr[:end])


def chain_model(args):
    """Build and save a single ONNX model that wraps the encoder/decoder in a WhisperBeamSearch op.

    The encoder and decoder models are loaded from ``args.encoder_path`` and
    ``args.decoder_path`` and attached as subgraph attributes of one
    ``WhisperBeamSearch`` node (domain ``com.microsoft``). Cast nodes are added
    around that node when ``args.precision`` is ``Precision.FLOAT16``, and the
    nodes of ``args.cross_qk_onnx_model`` are appended when it is set. The
    resulting model is written to ``args.beam_model_output_dir`` and then run
    through the ONNX checker (checker failures are logged, not raised).

    Args:
        args: Namespace-like object carrying the conversion options read below
            (paths, precision, and the ``use_*`` / ``output_*`` / ``collect_*`` flags).

    Raises:
        AssertionError: If the graph inputs do not match the beam search inputs,
            or if the encoder and decoder models have different IR versions.
    """
    is_fp16 = args.precision == Precision.FLOAT16

    # Load encoder/decoder and insert necessary (but unused) graph inputs expected by WhisperBeamSearch op
    encoder_model = onnx.load_model(args.encoder_path, load_external_data=True)
    encoder_model.graph.name = "encoderdecoderinit subgraph"

    decoder_model = onnx.load_model(args.decoder_path, load_external_data=True)
    decoder_model.graph.name = "decoder subgraph"

    config = WhisperConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
    tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)

    def token_id(token):
        # Look up the vocabulary id of a single special token
        return tokenizer.convert_tokens_to_ids([token])[0]

    # Create inputs/outputs for WhisperBeamSearch op
    # Inputs are positional, so unused optional inputs are kept as "" placeholders
    temperature_name = "temperature_fp16" if is_fp16 else "temperature"
    beam_inputs = [
        "input_features_fp16" if is_fp16 else "input_features",
        "max_length",
        "min_length",
        "num_beams",
        "num_return_sequences",
        "length_penalty_fp16" if is_fp16 else "length_penalty",
        "repetition_penalty_fp16" if is_fp16 else "repetition_penalty",
        "vocab_mask" if args.use_vocab_mask else "",
        "prefix_vocab_mask" if args.use_prefix_vocab_mask else "",
        "",  # attention mask
        "decoder_input_ids" if args.use_forced_decoder_ids else "",
        "logits_processor" if args.use_logits_processor else "",
        "cross_qk_layer_head" if args.collect_cross_qk else "",
        "extra_decoding_ids" if args.extra_decoding_ids else "",
        temperature_name if args.use_temperature else "",
    ]

    sequence_scores_name = "sequence_scores_fp16" if is_fp16 else "sequence_scores"
    scores_name = "scores_fp16" if is_fp16 else "scores"
    beam_outputs = [
        "sequences",
        sequence_scores_name if args.output_sequence_scores else "",
        scores_name if args.output_scores else "",
        "cross_qk" if args.collect_cross_qk else "",
        "no_speech_probs_beam" if args.output_no_speech_probs else "",
    ]

    graph_nodes = []
    if is_fp16:
        # Graph inputs/outputs stay fp32; cast to/from fp16 around the beam search op
        input_features_cast_node = helper.make_node(
            "Cast",
            inputs=["input_features"],
            outputs=["input_features_fp16"],
            name="CastInputFeaturesToFp16",
            to=TensorProto.FLOAT16,
        )
        len_pen_cast_node = helper.make_node(
            "Cast",
            inputs=["length_penalty"],
            outputs=["length_penalty_fp16"],
            name="CastLengthPenaltyToFp16",
            to=TensorProto.FLOAT16,
        )
        rep_pen_cast_node = helper.make_node(
            "Cast",
            inputs=["repetition_penalty"],
            outputs=["repetition_penalty_fp16"],
            name="CastRepetitionPenaltyToFp16",
            to=TensorProto.FLOAT16,
        )
        graph_nodes.extend([input_features_cast_node, len_pen_cast_node, rep_pen_cast_node])

        if args.use_temperature:
            temp_cast_node = helper.make_node(
                "Cast",
                inputs=["temperature"],
                outputs=["temperature_fp16"],
                name="temperature_to_fp16",
                to=TensorProto.FLOAT16,
            )
            graph_nodes.append(temp_cast_node)

        if args.output_sequence_scores:
            output_sequence_scores_cast_node = helper.make_node(
                "Cast",
                inputs=["sequence_scores_fp16"],
                outputs=["sequence_scores"],
                name="CastOutputSequenceScoresToFp32",
                to=TensorProto.FLOAT,
            )
            graph_nodes.append(output_sequence_scores_cast_node)

        if args.output_scores:
            output_scores_cast_node = helper.make_node(
                "Cast",
                inputs=["scores_fp16"],
                outputs=["scores"],
                name="CastScoresToFp32",
                to=TensorProto.FLOAT,
            )
            graph_nodes.append(output_scores_cast_node)

    # Create WhisperBeamSearch op
    # Optional attributes use "" placeholders that are stripped before being attached to the node
    beam_search_attrs = [
        helper.make_attribute("eos_token_id", config.eos_token_id),
        helper.make_attribute("pad_token_id", config.pad_token_id),
        helper.make_attribute(
            "decoder_start_token_id", config.decoder_start_token_id
        ),  # same as tokenizer.convert_tokens_to_ids(['<|startoftranscript|>'])[0]
        helper.make_attribute("translate_token_id", token_id("<|translate|>")),
        helper.make_attribute("transcribe_token_id", token_id("<|transcribe|>")),
        helper.make_attribute("start_of_lm_token_id", token_id("<|startoflm|>")),
        (
            helper.make_attribute("no_speech_token_id", token_id("<|nospeech|>"))
            if args.output_no_speech_probs
            else ""
        ),
        helper.make_attribute("no_timestamps_token_id", token_id("<|notimestamps|>")),
        helper.make_attribute("beginning_timestamp_token_id", token_id("<|0.00|>")),
        helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
        helper.make_attribute("early_stopping", True),
        helper.make_attribute("model_type", 2),
        helper.make_attribute("decoder_output_cross_qk", 1) if args.collect_cross_qk else "",
    ]
    node = helper.make_node(
        "WhisperBeamSearch",
        # Only trailing placeholders are dropped so the remaining inputs/outputs keep their positions
        inputs=clean_list(beam_inputs, remove_all_strings=False),
        outputs=clean_list(beam_outputs, remove_all_strings=False),
        name="BeamSearch",
        domain="com.microsoft",
    )
    node.attribute.extend(clean_list(beam_search_attrs, remove_all_strings=True))

    # Graph inputs
    input_features = helper.make_tensor_value_info(
        "input_features", TensorProto.FLOAT, ["batch_size", "feature_size", "sequence_length"]
    )
    max_length = helper.make_tensor_value_info("max_length", TensorProto.INT32, [1])
    min_length = helper.make_tensor_value_info("min_length", TensorProto.INT32, [1])
    num_beams = helper.make_tensor_value_info("num_beams", TensorProto.INT32, [1])
    num_return_sequences = helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1])
    length_penalty = helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1])
    repetition_penalty = helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1])
    vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [config.vocab_size])
    prefix_vocab_mask = helper.make_tensor_value_info(
        "prefix_vocab_mask", TensorProto.INT32, ["batch_size", config.vocab_size]
    )
    decoder_input_ids = helper.make_tensor_value_info(
        "decoder_input_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"]
    )
    logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1])
    cross_qk_layer_head = helper.make_tensor_value_info("cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2])
    extra_decoding_ids = helper.make_tensor_value_info(
        "extra_decoding_ids", TensorProto.INT32, ["batch_size", "extra_decoding_ids_len"]
    )
    temperature = helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1])

    graph_inputs = clean_list(
        [
            input_features,
            max_length,
            min_length,
            num_beams,
            num_return_sequences,
            length_penalty,
            repetition_penalty,
            vocab_mask if args.use_vocab_mask else "",
            prefix_vocab_mask if args.use_prefix_vocab_mask else "",
            decoder_input_ids if args.use_forced_decoder_ids else "",
            logits_processor if args.use_logits_processor else "",
            cross_qk_layer_head if args.collect_cross_qk else "",
            extra_decoding_ids if args.extra_decoding_ids else "",
            temperature if args.use_temperature else "",
        ]
    )

    # Graph outputs
    sequences = helper.make_tensor_value_info(
        "sequences", TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"]
    )
    sequence_scores = helper.make_tensor_value_info("sequence_scores", TensorProto.FLOAT, ["batch_size"])
    scores = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["batch_size"])
    cross_qk = helper.make_tensor_value_info(
        "cross_qk",
        TensorProto.FLOAT,
        ["batch_size", "num_return_sequences", "num_layer_head_cross_qk", "max_length", "frames"],
    )
    no_speech_probs = helper.make_tensor_value_info("no_speech_probs", TensorProto.FLOAT, ["batch_size"])

    # cross_qk is exposed when explicitly requested, or when it is collected and no
    # post-processing model is supplied to consume it
    expose_cross_qk = args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk)
    graph_outputs = clean_list(
        [
            sequences,
            sequence_scores if args.output_sequence_scores else "",
            scores if args.output_scores else "",
            cross_qk if expose_cross_qk else "",
            no_speech_probs if args.output_no_speech_probs else "",
        ]
    )

    # Replace MultiHeadAttention with DecoderMaskedMultiHeadAttention for CUDA EP inference
    if getattr(args, "use_gpu", False):
        if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph):
            logger.info("Updated whisper decoder subgraph to use DecoderMaskedMultiHeadAttention successfully!")
        else:
            logger.warning("DecoderMaskedMultiHeadAttention could not be applied to whisper decoder subgraph")
        if getattr(args, "collect_cross_qk", False):
            update_decoder_subgraph_output_cross_attention(decoder_model.graph)

    # Initializers/opsets
    # Delete shared data between decoder/encoder and move to larger graph initializers
    initializers = get_shared_initializers(encoder_model, decoder_model)
    # Attach the subgraphs only after the decoder updates and initializer sharing above are done
    node.attribute.extend(
        [
            helper.make_attribute("decoder", decoder_model.graph),
            helper.make_attribute("encoder", encoder_model.graph),
        ]
    )

    opset_import = [helper.make_opsetid(domain="com.microsoft", version=1), helper.make_opsetid(domain="", version=17)]

    graph_nodes.append(node)
    if args.output_no_speech_probs:
        prob_cast_node = helper.make_node(
            "Cast",
            inputs=["no_speech_probs_beam"],
            outputs=["no_speech_probs"],
            name="no_speech_probs_cast_to_fp32",
            to=TensorProto.FLOAT,
        )
        graph_nodes.append(prob_cast_node)

    # Make graph with WhisperBeamSearch op
    beam_graph = helper.make_graph(
        graph_nodes,
        name="WhisperBeamSearch Graph",
        inputs=graph_inputs,
        outputs=graph_outputs,
        initializer=initializers,
    )

    if args.cross_qk_onnx_model:
        post_qk_model = onnx.load_model(args.cross_qk_onnx_model, load_external_data=True)
        post_qk_graph = post_qk_model.graph
        beam_graph.initializer.extend(post_qk_graph.initializer)
        beam_graph.node.extend(post_qk_graph.node)

        # If tensor from cross_qk_onnx_model has same name as tensor in beamsearch graph, treat them as same tensor.
        # User should notice this rule when provide cross_qk_onnx_model to append to the beamsearch node.
        already_defined_names = {graph_input.name for graph_input in graph_inputs}
        already_defined_names.update(graph_output.name for graph_output in graph_outputs)
        already_defined_names.add("cross_qk")
        for pgi in post_qk_graph.input:
            if pgi.name not in already_defined_names:
                beam_graph.input.extend([pgi])
        beam_graph.output.extend(post_qk_graph.output)

    # Verify graph's inputs match beam search's inputs
    verify_inputs(beam_inputs, graph_inputs)

    assert decoder_model.ir_version == encoder_model.ir_version, (
        f"Decoder IR version {decoder_model.ir_version} does not match encoder IR version {encoder_model.ir_version}"
    )
    logger.info(f"Using IR version {decoder_model.ir_version} for chained model")

    # Set IR version of chained model to IR version of subgraphs in order to generate a working E2E model
    beam_model = helper.make_model_gen_version(
        beam_graph,
        producer_name="onnxruntime.transformers",
        opset_imports=opset_import,
        ir_version=decoder_model.ir_version,
    )

    # Save WhisperBeamSearch graph and external data
    # Any previous output (model file and its external data file) is deleted before saving
    if os.path.isfile(args.beam_model_output_dir):
        logger.info(f"Overwriting {args.beam_model_output_dir} and {args.beam_model_output_dir + '.data'}")
        for stale_path in (args.beam_model_output_dir, args.beam_model_output_dir + ".data"):
            if os.path.exists(stale_path):
                os.remove(stale_path)

    onnx.save(
        beam_model,
        args.beam_model_output_dir,
        save_as_external_data=args.use_external_data_format,
        all_tensors_to_one_file=True,
        convert_attribute=True,
        location=f"{os.path.basename(args.beam_model_output_dir)}.data",
    )
    try:
        onnx.checker.check_model(args.beam_model_output_dir, full_check=True)
    except Exception as e:
        # Best effort: the model has already been saved, so a checker failure is logged rather than raised
        logger.error(f"An error occurred while running the ONNX checker: {e}", exc_info=True)  # noqa: G201