whisper_chain.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. import logging
  7. import os
  8. import onnx
  9. from benchmark_helper import Precision
  10. from convert_generation import (
  11. get_shared_initializers,
  12. update_decoder_subgraph_output_cross_attention,
  13. update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha,
  14. )
  15. from onnx import TensorProto, helper
  16. from transformers import WhisperConfig, WhisperTokenizer
  17. logger = logging.getLogger(__name__)
  18. def verify_inputs(beam_inputs, graph_inputs):
  19. # Verify that ONNX graph's inputs match beam search op's inputs
  20. beam_required_inputs = list(filter(lambda beam_input: beam_input, beam_inputs))
  21. assert len(graph_inputs) == len(beam_required_inputs)
  22. for graph_input, beam_input in zip(graph_inputs, beam_required_inputs, strict=False):
  23. # Check if graph_input is in beam_input to handle beam_input names with the "_fp16" suffix
  24. assert graph_input.name in beam_input
  25. def clean_list(arr, remove_all_strings=True):
  26. if remove_all_strings:
  27. # Remove all empty strings in list
  28. return list(filter(lambda elm: elm != "", arr))
  29. # Remove empty strings at end of list
  30. while len(arr) > 0:
  31. if arr[-1] == "":
  32. arr.pop()
  33. else:
  34. break
  35. return arr
  36. def chain_model(args):
  37. # Load encoder/decoder and insert necessary (but unused) graph inputs expected by WhisperBeamSearch op
  38. encoder_model = onnx.load_model(args.encoder_path, load_external_data=True)
  39. encoder_model.graph.name = "encoderdecoderinit subgraph"
  40. decoder_model = onnx.load_model(args.decoder_path, load_external_data=True)
  41. decoder_model.graph.name = "decoder subgraph"
  42. config = WhisperConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
  43. tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
  44. # Create inputs/outputs for WhisperBeamSearch op
  45. temperature_name = "temperature_fp16" if args.precision == Precision.FLOAT16 else "temperature"
  46. beam_inputs = [
  47. "input_features_fp16" if args.precision == Precision.FLOAT16 else "input_features",
  48. "max_length",
  49. "min_length",
  50. "num_beams",
  51. "num_return_sequences",
  52. "length_penalty_fp16" if args.precision == Precision.FLOAT16 else "length_penalty",
  53. "repetition_penalty_fp16" if args.precision == Precision.FLOAT16 else "repetition_penalty",
  54. "vocab_mask" if args.use_vocab_mask else "",
  55. "prefix_vocab_mask" if args.use_prefix_vocab_mask else "",
  56. "", # attention mask
  57. "decoder_input_ids" if args.use_forced_decoder_ids else "",
  58. "logits_processor" if args.use_logits_processor else "",
  59. "cross_qk_layer_head" if args.collect_cross_qk else "",
  60. "extra_decoding_ids" if args.extra_decoding_ids else "",
  61. temperature_name if args.use_temperature else "",
  62. ]
  63. sequence_scores_name = "sequence_scores_fp16" if args.precision == Precision.FLOAT16 else "sequence_scores"
  64. scores_name = "scores_fp16" if args.precision == Precision.FLOAT16 else "scores"
  65. beam_outputs = [
  66. "sequences",
  67. sequence_scores_name if args.output_sequence_scores else "",
  68. scores_name if args.output_scores else "",
  69. "cross_qk" if args.collect_cross_qk else "",
  70. "no_speech_probs_beam" if args.output_no_speech_probs else "",
  71. ]
  72. graph_nodes = []
  73. if args.precision == Precision.FLOAT16:
  74. input_features_cast_node = helper.make_node(
  75. "Cast",
  76. inputs=["input_features"],
  77. outputs=["input_features_fp16"],
  78. name="CastInputFeaturesToFp16",
  79. to=TensorProto.FLOAT16,
  80. )
  81. len_pen_cast_node = helper.make_node(
  82. "Cast",
  83. inputs=["length_penalty"],
  84. outputs=["length_penalty_fp16"],
  85. name="CastLengthPenaltyToFp16",
  86. to=TensorProto.FLOAT16,
  87. )
  88. rep_pen_cast_node = helper.make_node(
  89. "Cast",
  90. inputs=["repetition_penalty"],
  91. outputs=["repetition_penalty_fp16"],
  92. name="CastRepetitionPenaltyToFp16",
  93. to=TensorProto.FLOAT16,
  94. )
  95. graph_nodes.extend([input_features_cast_node, len_pen_cast_node, rep_pen_cast_node])
  96. if args.use_temperature:
  97. temp_cast_node = helper.make_node(
  98. "Cast",
  99. inputs=["temperature"],
  100. outputs=["temperature_fp16"],
  101. name="temperature_to_fp16",
  102. to=TensorProto.FLOAT16,
  103. )
  104. graph_nodes.append(temp_cast_node)
  105. if args.output_sequence_scores:
  106. output_sequence_scores_cast_node = helper.make_node(
  107. "Cast",
  108. inputs=["sequence_scores_fp16"],
  109. outputs=["sequence_scores"],
  110. name="CastOutputSequenceScoresToFp32",
  111. to=TensorProto.FLOAT,
  112. )
  113. graph_nodes.append(output_sequence_scores_cast_node)
  114. if args.output_scores:
  115. output_scores_cast_node = helper.make_node(
  116. "Cast",
  117. inputs=["scores_fp16"],
  118. outputs=["scores"],
  119. name="CastScoresToFp32",
  120. to=TensorProto.FLOAT,
  121. )
  122. graph_nodes.append(output_scores_cast_node)
  123. # Create WhisperBeamSearch op
  124. beam_search_attrs = [
  125. helper.make_attribute("eos_token_id", config.eos_token_id),
  126. helper.make_attribute("pad_token_id", config.pad_token_id),
  127. helper.make_attribute(
  128. "decoder_start_token_id", config.decoder_start_token_id
  129. ), # same as tokenizer.convert_tokens_to_ids(['<|startoftranscript|>'])[0]
  130. helper.make_attribute("translate_token_id", tokenizer.convert_tokens_to_ids(["<|translate|>"])[0]),
  131. helper.make_attribute("transcribe_token_id", tokenizer.convert_tokens_to_ids(["<|transcribe|>"])[0]),
  132. helper.make_attribute("start_of_lm_token_id", tokenizer.convert_tokens_to_ids(["<|startoflm|>"])[0]),
  133. (
  134. helper.make_attribute("no_speech_token_id", tokenizer.convert_tokens_to_ids(["<|nospeech|>"])[0])
  135. if args.output_no_speech_probs
  136. else ""
  137. ),
  138. helper.make_attribute("no_timestamps_token_id", tokenizer.convert_tokens_to_ids(["<|notimestamps|>"])[0]),
  139. helper.make_attribute("beginning_timestamp_token_id", tokenizer.convert_tokens_to_ids(["<|0.00|>"])[0]),
  140. helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
  141. helper.make_attribute("early_stopping", True),
  142. helper.make_attribute("model_type", 2),
  143. helper.make_attribute("decoder_output_cross_qk", 1) if args.collect_cross_qk else "",
  144. ]
  145. node = helper.make_node(
  146. "WhisperBeamSearch",
  147. inputs=clean_list(beam_inputs, remove_all_strings=False),
  148. outputs=clean_list(beam_outputs, remove_all_strings=False),
  149. name="BeamSearch",
  150. domain="com.microsoft",
  151. )
  152. node.attribute.extend(clean_list(beam_search_attrs, remove_all_strings=True))
  153. # Graph inputs
  154. input_features = helper.make_tensor_value_info(
  155. "input_features", TensorProto.FLOAT, ["batch_size", "feature_size", "sequence_length"]
  156. )
  157. max_length = helper.make_tensor_value_info("max_length", TensorProto.INT32, [1])
  158. min_length = helper.make_tensor_value_info("min_length", TensorProto.INT32, [1])
  159. num_beams = helper.make_tensor_value_info("num_beams", TensorProto.INT32, [1])
  160. num_return_sequences = helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1])
  161. length_penalty = helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1])
  162. repetition_penalty = helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1])
  163. vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [config.vocab_size])
  164. prefix_vocab_mask = helper.make_tensor_value_info(
  165. "prefix_vocab_mask", TensorProto.INT32, ["batch_size", config.vocab_size]
  166. )
  167. decoder_input_ids = helper.make_tensor_value_info(
  168. "decoder_input_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"]
  169. )
  170. logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1])
  171. cross_qk_layer_head = helper.make_tensor_value_info("cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2])
  172. extra_decoding_ids = helper.make_tensor_value_info(
  173. "extra_decoding_ids", TensorProto.INT32, ["batch_size", "extra_decoding_ids_len"]
  174. )
  175. temperature = helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1])
  176. graph_inputs = clean_list(
  177. [
  178. input_features,
  179. max_length,
  180. min_length,
  181. num_beams,
  182. num_return_sequences,
  183. length_penalty,
  184. repetition_penalty,
  185. vocab_mask if args.use_vocab_mask else "",
  186. prefix_vocab_mask if args.use_prefix_vocab_mask else "",
  187. decoder_input_ids if args.use_forced_decoder_ids else "",
  188. logits_processor if args.use_logits_processor else "",
  189. cross_qk_layer_head if args.collect_cross_qk else "",
  190. extra_decoding_ids if args.extra_decoding_ids else "",
  191. temperature if args.use_temperature else "",
  192. ]
  193. )
  194. # Graph outputs
  195. sequences = helper.make_tensor_value_info(
  196. "sequences", TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"]
  197. )
  198. sequence_scores = helper.make_tensor_value_info("sequence_scores", TensorProto.FLOAT, ["batch_size"])
  199. scores = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["batch_size"])
  200. cross_qk = helper.make_tensor_value_info(
  201. "cross_qk",
  202. TensorProto.FLOAT,
  203. ["batch_size", "num_return_sequences", "num_layer_head_cross_qk", "max_length", "frames"],
  204. )
  205. no_speech_probs = helper.make_tensor_value_info("no_speech_probs", TensorProto.FLOAT, ["batch_size"])
  206. graph_outputs = clean_list(
  207. [
  208. sequences,
  209. sequence_scores if args.output_sequence_scores else "",
  210. scores if args.output_scores else "",
  211. cross_qk if args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk) else "",
  212. no_speech_probs if args.output_no_speech_probs else "",
  213. ]
  214. )
  215. # Replace MultiHeadAttention with DecoderMaskedMultiHeadAttention for CUDA EP inference
  216. if hasattr(args, "use_gpu") and args.use_gpu:
  217. if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph):
  218. logger.info("Updated whisper decoder subgraph to use DecoderMaskedMultiHeadAttention successfully!")
  219. else:
  220. logger.warning("DecoderMaskedMultiHeadAttention could not be applied to whisper decoder subgraph")
  221. if hasattr(args, "collect_cross_qk") and args.collect_cross_qk:
  222. update_decoder_subgraph_output_cross_attention(decoder_model.graph)
  223. # Initializers/opsets
  224. # Delete shared data between decoder/encoder and move to larger graph initializers
  225. initializers = get_shared_initializers(encoder_model, decoder_model)
  226. node.attribute.extend(
  227. [
  228. helper.make_attribute("decoder", decoder_model.graph),
  229. helper.make_attribute("encoder", encoder_model.graph),
  230. ]
  231. )
  232. opset_import = [helper.make_opsetid(domain="com.microsoft", version=1), helper.make_opsetid(domain="", version=17)]
  233. graph_nodes.append(node)
  234. if args.output_no_speech_probs:
  235. prob_cast_node = helper.make_node(
  236. "Cast",
  237. inputs=["no_speech_probs_beam"],
  238. outputs=["no_speech_probs"],
  239. name="no_speech_probs_cast_to_fp32",
  240. to=TensorProto.FLOAT,
  241. )
  242. graph_nodes.append(prob_cast_node)
  243. # Make graph with WhisperBeamSearch op
  244. beam_graph = helper.make_graph(
  245. graph_nodes,
  246. name="WhisperBeamSearch Graph",
  247. inputs=graph_inputs,
  248. outputs=graph_outputs,
  249. initializer=initializers,
  250. )
  251. beam_graph_input_names = [gi.name for gi in graph_inputs]
  252. beam_graph_output_names = [go.name for go in graph_outputs]
  253. if args.cross_qk_onnx_model:
  254. post_qk_model = onnx.load_model(args.cross_qk_onnx_model, load_external_data=True)
  255. post_qk_graph = post_qk_model.graph
  256. beam_graph.initializer.extend(post_qk_graph.initializer)
  257. beam_graph.node.extend(post_qk_graph.node)
  258. # If tensor from cross_qk_onnx_model has same name as tensor in beamsearch graph, treat them as same tensor.
  259. # User should notice this rule when provide cross_qk_onnx_model to append to the beamsearch node.
  260. for pgi in post_qk_graph.input:
  261. if (
  262. (pgi.name not in beam_graph_input_names)
  263. and (pgi.name not in beam_graph_output_names)
  264. and (pgi.name != "cross_qk")
  265. ):
  266. beam_graph.input.extend([pgi])
  267. beam_graph.output.extend(post_qk_graph.output)
  268. # Verify graph's inputs match beam search's inputs
  269. verify_inputs(beam_inputs, graph_inputs)
  270. assert decoder_model.ir_version == encoder_model.ir_version
  271. logger.info(f"Using IR version {decoder_model.ir_version} for chained model")
  272. # Set IR version of chained model to IR version of subgraphs in order to generate a working E2E model
  273. beam_model = helper.make_model_gen_version(
  274. beam_graph,
  275. producer_name="onnxruntime.transformers",
  276. opset_imports=opset_import,
  277. ir_version=decoder_model.ir_version,
  278. )
  279. # Save WhisperBeamSearch graph and external data
  280. if os.path.isfile(args.beam_model_output_dir):
  281. logger.info(f"Overwriting {args.beam_model_output_dir} and {args.beam_model_output_dir + '.data'}")
  282. if os.path.exists(args.beam_model_output_dir):
  283. os.remove(args.beam_model_output_dir)
  284. if os.path.exists(args.beam_model_output_dir + ".data"):
  285. os.remove(args.beam_model_output_dir + ".data")
  286. onnx.save(
  287. beam_model,
  288. args.beam_model_output_dir,
  289. save_as_external_data=args.use_external_data_format,
  290. all_tensors_to_one_file=True,
  291. convert_attribute=True,
  292. location=f"{os.path.basename(args.beam_model_output_dir)}.data",
  293. )
  294. try:
  295. onnx.checker.check_model(args.beam_model_output_dir, full_check=True)
  296. except Exception as e:
  297. logger.error(f"An error occurred while running the ONNX checker: {e}", exc_info=True) # noqa: G201