whisper_decoder.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. import logging
  7. import os
  8. import tempfile
  9. from itertools import chain
  10. from pathlib import Path
  11. import numpy as np
  12. import onnx
  13. import torch
  14. from float16 import convert_float_to_float16
  15. from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
  16. from onnx import ModelProto, ValueInfoProto
  17. from onnx_model import OnnxModel
  18. from past_helper import PastKeyValuesHelper
  19. from transformers import WhisperConfig
  20. from whisper_inputs import (
  21. convert_inputs_for_ort,
  22. get_model_dynamic_axes,
  23. get_sample_decoder_inputs,
  24. group_past_key_values,
  25. )
  26. from onnxruntime import InferenceSession
# Module-level logger; WhisperDecoder.verify_onnx reports PyTorch vs. ONNX output differences through it.
logger = logging.getLogger(__name__)
  28. class WhisperDecoder(torch.nn.Module):
  29. """A Whisper decoder with optional past key values"""
  30. def __init__(self, config: WhisperConfig, model: torch.nn.Module, model_impl: str, no_beam_search_op: bool = False):
  31. super().__init__()
  32. self.config = config
  33. self.device = model.device
  34. self.model_impl = model_impl
  35. self.no_beam_search_op = no_beam_search_op
  36. self.decoder = None if model_impl == "openai" else model.model.decoder
  37. self.proj_out = None if model_impl == "openai" else model.proj_out
  38. self.model = model if model_impl == "openai" else None
  39. self.max_source_positions = self.config.max_source_positions
  40. self.num_heads = self.config.decoder_attention_heads
  41. self.head_size = self.config.d_model // self.num_heads
  42. def hf_forward(
  43. self,
  44. decoder_input_ids: torch.Tensor,
  45. encoder_hidden_states: torch.Tensor | None = None,
  46. past_key_values: list[tuple[torch.Tensor]] | None = None,
  47. ):
  48. outputs = self.decoder(
  49. encoder_hidden_states=encoder_hidden_states,
  50. input_ids=decoder_input_ids,
  51. past_key_values=past_key_values,
  52. use_cache=True,
  53. )
  54. logits = self.proj_out(outputs.last_hidden_state)
  55. present_key_values = outputs.past_key_values
  56. if past_key_values is None:
  57. # Return present_self_* and present_cross_* for decoder-init
  58. return logits, present_key_values
  59. # Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
  60. # (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1),
  61. # After: (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1), ...,
  62. # (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1), ...
  63. present_self, present_cross = PastKeyValuesHelper.group_by_self_and_cross(present_key_values)
  64. # Return present_self_* for decoder-with-past since past_cross_* and present_cross_* are identical
  65. return logits, present_self
  66. def oai_forward(
  67. self,
  68. decoder_input_ids: torch.Tensor,
  69. encoder_hidden_states: torch.Tensor | None = None,
  70. past_key_values: list[tuple[torch.Tensor]] | None = None,
  71. ):
  72. past_kv_cache = {}
  73. if past_key_values is not None:
  74. # Convert past KV caches (BxNxSxH --> BxSxNxH --> BxSxD) for OpenAI's forward pass
  75. self_attn_kv_caches, cross_attn_kv_caches = group_past_key_values(past_key_values)
  76. self_attn_kv_caches = [past_kv.transpose(1, 2) for past_kv in self_attn_kv_caches]
  77. self_attn_kv_caches = [past_kv.reshape((*past_kv.shape[:2], -1)) for past_kv in self_attn_kv_caches]
  78. cross_attn_kv_caches = [past_kv.transpose(1, 2) for past_kv in cross_attn_kv_caches]
  79. cross_attn_kv_caches = [past_kv.reshape((*past_kv.shape[:2], -1)) for past_kv in cross_attn_kv_caches]
  80. for idx, block in enumerate(self.model.decoder.blocks):
  81. past_kv_cache[block.attn.key] = self_attn_kv_caches[2 * idx]
  82. past_kv_cache[block.attn.value] = self_attn_kv_caches[2 * idx + 1]
  83. past_kv_cache[block.cross_attn.key] = cross_attn_kv_caches[2 * idx]
  84. past_kv_cache[block.cross_attn.value] = cross_attn_kv_caches[2 * idx + 1]
  85. # Install OpenAI's hooks on the forward pass of each nn.Linear for key and value
  86. # since the hooks will capture the output of the key and value MatMuls, which
  87. # represent the current keys and values.
  88. #
  89. # For OpenAI's forward pass, the hook function will also perform the concat
  90. # operation (past_kv + curr_kv --> pres_kv) if needed. However, the ONNX model
  91. # will not contain this concat operation because the present KV caches aren't
  92. # returned by OpenAI's forward pass.
  93. kv_cache, hooks = self.model.install_kv_cache_hooks()
  94. # Run forward pass
  95. # NOTE: There is a bug with openai-whisper==20240930 with the introduction of SDPA.
  96. # In the Whisper codebase, the following line
  97. #
  98. # is_causal = mask is not None and n_ctx > 1
  99. #
  100. # has been added where `mask` is a torch tensor. The right-hand side evaluates to `tensor(True/False)`
  101. # but `is_causal` only accepts the boolean value. The fix is to apply `.item()` after the right-hand
  102. # side has been evaluated. In other words, the line should be
  103. #
  104. # is_causal = (mask is not None and n_ctx > 1).item()
  105. #
  106. # instead.
  107. logits = self.model.decoder(x=decoder_input_ids, xa=encoder_hidden_states, kv_cache=past_kv_cache)
  108. # Re-do concat operation on self attention KV caches for ONNX export (if past self attention KV caches exist)
  109. if past_key_values is not None:
  110. for block in self.model.decoder.blocks:
  111. kv_cache[block.attn.key] = torch.cat(
  112. [past_kv_cache[block.attn.key], kv_cache[block.attn.key]], dim=1
  113. ).detach()
  114. kv_cache[block.attn.value] = torch.cat(
  115. [past_kv_cache[block.attn.value], kv_cache[block.attn.value]], dim=1
  116. ).detach()
  117. present_self, present_cross = [], []
  118. for block in self.model.decoder.blocks:
  119. # Group self and cross values
  120. present_self.append(kv_cache[block.attn.key])
  121. present_self.append(kv_cache[block.attn.value])
  122. if past_key_values is None:
  123. # Return present_self_* and present_cross_* for decoder-init
  124. present_cross.append(kv_cache[block.cross_attn.key])
  125. present_cross.append(kv_cache[block.cross_attn.value])
  126. # Convert present KV caches (BxSxD --> BxSxNxH --> BxNxSxH) after OpenAI's forward pass
  127. present_self = [
  128. present_kv.reshape((*present_kv.shape[:2], -1, self.head_size)).transpose(1, 2)
  129. for present_kv in present_self
  130. ]
  131. present_cross = [
  132. present_kv.reshape((*present_kv.shape[:2], -1, self.head_size)).transpose(1, 2)
  133. for present_kv in present_cross
  134. ]
  135. # Remove OpenAI's hooks since they can persist after this function completes
  136. for hook in hooks:
  137. hook.remove()
  138. if past_key_values is None:
  139. # Return present_self_* and present_cross_* for decoder-init
  140. present_key_values = PastKeyValuesHelper.group_by_layer(
  141. present_self + present_cross, len(present_self) // 2
  142. )
  143. return logits, present_key_values
  144. # Return present_self_* for decoder-with-past since past_cross_* and present_cross_* are identical
  145. return logits, present_self
  146. def forward(
  147. self,
  148. decoder_input_ids: torch.Tensor,
  149. encoder_hidden_states: torch.Tensor | None = None,
  150. past_key_values: list[tuple[torch.Tensor]] | None = None,
  151. ):
  152. if self.model_impl == "openai":
  153. return self.oai_forward(decoder_input_ids, encoder_hidden_states, past_key_values)
  154. return self.hf_forward(decoder_input_ids, encoder_hidden_states, past_key_values)
  155. def input_names(self):
  156. if self.first_pass:
  157. input_names = ["input_ids", "encoder_hidden_states"]
  158. else:
  159. input_names = [
  160. "input_ids",
  161. "encoder_hidden_states",
  162. *list(
  163. chain.from_iterable(
  164. (f"past_key_self_{i}", f"past_value_self_{i}", f"past_key_cross_{i}", f"past_value_cross_{i}")
  165. for i in range(self.config.decoder_layers)
  166. )
  167. ),
  168. ]
  169. return input_names
  170. def output_names(self):
  171. if self.first_pass:
  172. output_names = [
  173. "logits",
  174. *list(
  175. chain.from_iterable(
  176. (
  177. f"present_key_self_{i}",
  178. f"present_value_self_{i}",
  179. f"present_key_cross_{i}",
  180. f"present_value_cross_{i}",
  181. )
  182. for i in range(self.config.decoder_layers)
  183. )
  184. ),
  185. ]
  186. else:
  187. output_names = [
  188. "logits",
  189. *list(
  190. chain.from_iterable(
  191. (f"present_key_self_{i}", f"present_value_self_{i}") for i in range(self.config.decoder_layers)
  192. )
  193. ),
  194. ]
  195. return output_names
  196. def dynamic_axes(self, input_names, output_names):
  197. dynamic_axes = get_model_dynamic_axes(self.config, input_names, output_names)
  198. if "input_ids" in dynamic_axes and not self.no_beam_search_op:
  199. # Set dynamic axes for `input_ids` when using beam search op to {0: "batch_size"} only
  200. del dynamic_axes["input_ids"][1]
  201. return dynamic_axes
  202. def inputs(self, use_fp16_inputs: bool, use_int32_inputs: bool, return_dict: bool = False):
  203. inputs = get_sample_decoder_inputs(
  204. self.config,
  205. self.device,
  206. batch_size=2,
  207. past_sequence_length=(0 if self.first_pass else 6),
  208. sequence_length=(6 if self.first_pass else 1),
  209. use_fp16=use_fp16_inputs,
  210. use_int32=use_int32_inputs,
  211. )
  212. if return_dict:
  213. if self.first_pass:
  214. del inputs["past_key_values"]
  215. return inputs
  216. if self.first_pass:
  217. return (
  218. inputs["decoder_input_ids"],
  219. inputs["encoder_hidden_states"],
  220. )
  221. return (
  222. inputs["decoder_input_ids"],
  223. inputs["encoder_hidden_states"],
  224. inputs["past_key_values"],
  225. )
  226. def fix_key_value_cache_dims(self, io: ValueInfoProto, is_cross: bool = False, is_output: bool = False):
  227. # Shape should be (batch_size, num_heads, sequence_length, head_size) for self attention KV caches
  228. # and (batch_size, num_heads, num_frames // 2, head_size) for cross attention KV caches
  229. num_heads = io.type.tensor_type.shape.dim[1]
  230. if "_dim_" in num_heads.dim_param:
  231. num_heads.Clear()
  232. num_heads.dim_value = self.num_heads
  233. sequence_length = io.type.tensor_type.shape.dim[2]
  234. if "_dim_" in sequence_length.dim_param:
  235. sequence_length.Clear()
  236. if is_cross:
  237. sequence_length.dim_value = self.max_source_positions
  238. else:
  239. sequence_length.dim_param = "total_sequence_length" if is_output else "past_sequence_length"
  240. head_size = io.type.tensor_type.shape.dim[3]
  241. if "_dim_" in head_size.dim_param:
  242. head_size.Clear()
  243. head_size.dim_value = self.head_size
  244. return io
  245. def fix_io(self, io_list: RepeatedCompositeFieldContainer, is_output: bool = False):
  246. # Fix order of inputs/outputs and each dim_value of input/output
  247. reordered_io = []
  248. self_attn_kv_caches = []
  249. cross_attn_kv_caches = []
  250. for io in io_list:
  251. if "past" not in io.name and "present" not in io.name:
  252. reordered_io.append(io)
  253. elif "self" in io.name:
  254. # Self attention KV caches
  255. new_io = self.fix_key_value_cache_dims(io, is_cross=False, is_output=is_output)
  256. if self.no_beam_search_op:
  257. reordered_io.append(new_io)
  258. else:
  259. self_attn_kv_caches.append(new_io)
  260. else:
  261. # Cross attention KV caches
  262. new_io = self.fix_key_value_cache_dims(io, is_cross=True, is_output=is_output)
  263. if self.no_beam_search_op:
  264. reordered_io.append(new_io)
  265. else:
  266. cross_attn_kv_caches.append(new_io)
  267. if not self.no_beam_search_op:
  268. reordered_io += self_attn_kv_caches + cross_attn_kv_caches
  269. return reordered_io
  270. def fix_inputs_and_outputs(self, model: ModelProto):
  271. # ONNX exporter might mark dimensions like 'Transposepresent_value_self_1_dim_2' in shape inference.
  272. # We now change the dim_values to the correct one.
  273. reordered_inputs = self.fix_io(model.graph.input, is_output=False)
  274. while len(model.graph.input) > 0:
  275. model.graph.input.pop()
  276. model.graph.input.extend(reordered_inputs)
  277. reordered_outputs = self.fix_io(model.graph.output, is_output=True)
  278. while len(model.graph.output) > 0:
  279. model.graph.output.pop()
  280. model.graph.output.extend(reordered_outputs)
  281. return model
  282. def fix_layernorm_weights(self, model: ModelProto, use_fp16_inputs: bool):
  283. if self.model_impl == "openai" and use_fp16_inputs:
  284. # Cast ONNX model to float16 to ensure LayerNorm weights are converted from
  285. # float32 to float16 since exported model already has float16 weights everywhere
  286. # except for LayerNorm ops. This happens because OpenAI always upcasts to float32
  287. # when computing LayerNorm.
  288. #
  289. # Reference:
  290. # https://github.com/openai/whisper/blob/90db0de1896c23cbfaf0c58bc2d30665f709f170/whisper/model.py#L41
  291. model = convert_float_to_float16(model)
  292. return model
  293. def export_onnx(
  294. self,
  295. onnx_model_path: str,
  296. provider: str,
  297. verbose: bool = True,
  298. use_external_data_format: bool = False,
  299. use_fp16_inputs: bool = False,
  300. use_int32_inputs: bool = True,
  301. use_encoder_hidden_states: bool = False,
  302. use_kv_cache_inputs: bool = True,
  303. ):
  304. """Export decoder to ONNX
  305. Args:
  306. onnx_model_path (str): path to save ONNX model
  307. provider (str): provider to use for verifying parity on ONNX model
  308. verbose (bool, optional): print verbose information. Defaults to True.
  309. use_external_data_format (bool, optional): use external data format or not. Defaults to False.
  310. use_fp16_inputs (bool, optional): use float16 inputs for the KV caches. Defaults to False.
  311. use_int32_inputs (bool, optional): use int32 inputs for the decoder_input_ids. Defaults to True.
  312. use_encoder_hidden_states (bool, optional): use encoder_hidden_states as model input for decoder-init/decoder-without-past models. Defaults to False.
  313. use_kv_cache_inputs (bool, optional): use KV caches as model inputs for decoder-with-past models. Defaults to True.
  314. """
  315. # Shape of decoder's tensors:
  316. # Required Inputs:
  317. # decoder_input_ids: (batch_size, sequence_length)
  318. # Optional Inputs:
  319. # encoder_hidden_states (comes from encoder's outputs): (batch_size, num_frames // 2, hidden_size)
  320. # past_{key/value}_self_* (past self attention KV caches): (batch_size, num_heads, past_sequence_length, head_size)
  321. # past_{key/value}_cross_* (past cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
  322. # Outputs:
  323. # logits: (batch_size, sequence_length, vocab_size)
  324. # present_{key/value}_self_* (present self attention KV caches): (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
  325. # present_{key/value}_cross_* (present cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
  326. # For the first pass through the decoder (i.e. decoder-init/decoder-without-past)
  327. self.first_pass = use_encoder_hidden_states and not use_kv_cache_inputs
  328. # For subsequent passes through the decoder (i.e. decoder-with-past)
  329. self.later_pass = not use_encoder_hidden_states and use_kv_cache_inputs
  330. assert self.first_pass or self.later_pass, (
  331. "Only one of `use_encoder_hidden_states` and `use_kv_cache_inputs` can be true at once."
  332. )
  333. inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs)
  334. input_names = self.input_names()
  335. output_names = self.output_names()
  336. dynamic_axes = self.dynamic_axes(input_names, output_names)
  337. Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
  338. with tempfile.TemporaryDirectory() as tmp_dir_name:
  339. temp_onnx_model_path = os.path.join(tmp_dir_name, "decoder.onnx")
  340. Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
  341. out_path = temp_onnx_model_path if use_external_data_format else onnx_model_path
  342. torch.onnx.export(
  343. self,
  344. args=inputs,
  345. f=out_path,
  346. export_params=True,
  347. input_names=input_names,
  348. output_names=output_names,
  349. dynamic_axes=dynamic_axes,
  350. opset_version=17,
  351. do_constant_folding=True,
  352. verbose=verbose,
  353. )
  354. model = onnx.load_model(out_path, load_external_data=use_external_data_format)
  355. model = self.fix_inputs_and_outputs(model)
  356. model = self.fix_layernorm_weights(model, use_fp16_inputs)
  357. OnnxModel.save(
  358. model,
  359. onnx_model_path,
  360. save_as_external_data=use_external_data_format,
  361. all_tensors_to_one_file=True,
  362. )
  363. self.verify_onnx(onnx_model_path, provider, use_fp16_inputs, use_int32_inputs)
  364. def verify_onnx(
  365. self,
  366. onnx_model_path: str,
  367. provider: str,
  368. use_fp16_inputs: bool,
  369. use_int32_inputs: bool,
  370. ):
  371. """Verify ONNX model outputs and PyTorch model outputs match
  372. Args:
  373. onnx_model_path (str): path to save ONNX model
  374. provider (str): execution provider for ONNX model
  375. use_fp16_inputs (bool, optional): use float16 inputs for the KV caches
  376. use_int32_inputs (bool, optional): use int32 inputs for the decoder_input_ids
  377. """
  378. # Shape of decoder's tensors:
  379. # Required Inputs:
  380. # decoder_input_ids: (batch_size, sequence_length)
  381. # Optional Inputs:
  382. # encoder_hidden_states (comes from encoder's outputs): (batch_size, num_frames // 2, hidden_size)
  383. # past_{key/value}_self_* (past self attention KV caches): (batch_size, num_heads, past_sequence_length, head_size)
  384. # past_{key/value}_cross_* (past cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
  385. # Outputs:
  386. # logits: (batch_size, sequence_length, vocab_size)
  387. # present_{key/value}_self_* (present self attention KV caches): (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
  388. # present_{key/value}_cross_* (present cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
  389. # Run PyTorch model
  390. inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs, return_dict=True)
  391. pt_outputs = []
  392. if self.first_pass:
  393. out = self.forward(**inputs)
  394. pt_outputs.append(out[0].detach().cpu().numpy())
  395. for present_key_value_layer in out[1]:
  396. for present_key_value in present_key_value_layer:
  397. pt_outputs.append(present_key_value.detach().cpu().numpy())
  398. else:
  399. out = self.forward(**inputs)
  400. pt_outputs.append(out[0].detach().cpu().numpy())
  401. for present_self_key_value in out[1]:
  402. pt_outputs.append(present_self_key_value.detach().cpu().numpy())
  403. # Run ONNX model
  404. sess = InferenceSession(onnx_model_path, providers=[provider])
  405. ort_outputs = sess.run(None, convert_inputs_for_ort(inputs, sess))
  406. # Calculate output difference
  407. try:
  408. for i, output_name in enumerate(self.output_names()):
  409. diff = np.abs(pt_outputs[i] - ort_outputs[i])
  410. logger.warning(f"Comparing {output_name}...")
  411. logger.warning(f"Max diff: {np.max(diff)}")
  412. except: # noqa: E722
  413. pass