whisper_inputs.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. import logging
  7. import numpy as np
  8. import torch
  9. from transformers import WhisperConfig
  10. from onnxruntime import InferenceSession
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
  12. # Create audio_features for encoder
  13. # Shape is (batch_size, feature_size, sequence_length) = (batch_size, num_mel_filters, num_frames)
  14. # where num_mel_filters is a model attribute and num_frames = (chunk_length * sample_rate) // hop_length.
  15. #
  16. # Hard-coded audio hyperparameters:
  17. # SAMPLE_RATE = 16000
  18. # N_FFT = 400
  19. # HOP_LENGTH = 160
  20. # CHUNK_LENGTH = 30 (i.e. 30-second chunk of audio)
  21. # N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE = 30 * 16000 = 480000 (i.e. 480,000 samples in a 30-second chunk of audio)
  22. # N_FRAMES = N_SAMPLES // HOP_LENGTH = 480000 // 160 = 3000 (i.e. 3000 frames in a mel spectrogram input)
  23. #
  24. # N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 = 160 * 2 = 320
  25. # FRAMES_PER_TOKEN = SAMPLE_RATE // HOP_LENGTH = 16000 // 160 = 100 (i.e. 10 ms per audio frame)
  26. # TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN = 16000 // 320 = 50 (i.e. 20 ms per audio token)
  27. def get_sample_audio_features(
  28. config: WhisperConfig,
  29. device: torch.device,
  30. batch_size: int,
  31. sequence_length: int = 3000,
  32. use_fp16: bool = False,
  33. ):
  34. torch_dtype = torch.float16 if use_fp16 else torch.float32
  35. audio_features = torch.randn(batch_size, config.num_mel_bins, sequence_length, device=device, dtype=torch_dtype)
  36. return audio_features
  37. # Create input_ids for decoder
  38. # Shape is (batch_size, sequence_length) where sequence_length is the initial decoder sequence length
  39. def get_sample_decoder_input_ids(
  40. config: WhisperConfig,
  41. device: torch.device,
  42. batch_size: int,
  43. sequence_length: int,
  44. use_int32: bool = True,
  45. ):
  46. torch_dtype = torch.int32 if use_int32 else torch.int64
  47. decoder_input_ids = torch.randint(
  48. low=0, high=config.vocab_size, size=(batch_size, sequence_length), device=device, dtype=torch_dtype
  49. )
  50. return decoder_input_ids
  51. # Create encoder_hidden_states for decoder-init
  52. # Shape is (batch_size, num_frames // 2, hidden_size)
  53. def get_sample_encoder_hidden_states(
  54. config: WhisperConfig,
  55. device: torch.device,
  56. batch_size: int,
  57. use_fp16: bool = False,
  58. ):
  59. torch_dtype = torch.float16 if use_fp16 else torch.float32
  60. encoder_hidden_states = torch.randn(
  61. batch_size, config.max_source_positions, config.d_model, device=device, dtype=torch_dtype
  62. )
  63. return encoder_hidden_states
  64. # Create past_key_values
  65. # Self-attention KV caches are of shape (batch_size, num_heads, past_sequence_length, head_size)
  66. # Cross-attention KV caches are of shape (batch_size, num_heads, num_frames // 2, head_size)
  67. def get_sample_past_key_values(
  68. config: WhisperConfig,
  69. device: torch.device,
  70. batch_size: int,
  71. past_seq_len: int,
  72. use_fp16: bool = False,
  73. ):
  74. num_heads = config.decoder_attention_heads
  75. head_size = config.d_model // num_heads
  76. max_source_positions = (
  77. config.max_source_positions
  78. ) # equal to num_frames // 2 = encoder's sequence_length // 2 = 3000 // 2 = 1500
  79. torch_dtype = torch.float16 if use_fp16 else torch.float32
  80. self_attention_kv_caches = [
  81. (
  82. torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype),
  83. torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype),
  84. )
  85. for _ in range(config.decoder_layers)
  86. ]
  87. cross_attention_kv_caches = [
  88. (
  89. torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype),
  90. torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype),
  91. )
  92. for _ in range(config.decoder_layers)
  93. ]
  94. return flatten_past_key_values(self_attention_kv_caches, cross_attention_kv_caches)
  95. # Flatten KV caches into pairs-of-4 where each pair is defined as:
  96. # (self_attn_key_cache, self_attn_value_cache, cross_attn_key_cache, cross_attn_value_cache)
  97. def flatten_past_key_values(
  98. self_attn_kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
  99. cross_attn_kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
  100. ):
  101. past_key_values = []
  102. for (self_k_cache, self_v_cache), (cross_k_cache, cross_v_cache) in zip(
  103. self_attn_kv_caches, cross_attn_kv_caches, strict=False
  104. ):
  105. layer_kv_caches = (self_k_cache, self_v_cache, cross_k_cache, cross_v_cache)
  106. past_key_values.append(layer_kv_caches)
  107. return past_key_values
  108. # Group KV caches into two 1D lists where one list contains the self attention KV caches and
  109. # one list contains the cross attention KV caches
  110. def group_past_key_values(
  111. kv_caches: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
  112. ):
  113. self_attn_kv_caches, cross_attn_kv_caches = [], []
  114. for self_k_cache, self_v_cache, cross_k_cache, cross_v_cache in kv_caches:
  115. self_attn_kv_caches.append(self_k_cache)
  116. self_attn_kv_caches.append(self_v_cache)
  117. cross_attn_kv_caches.append(cross_k_cache)
  118. cross_attn_kv_caches.append(cross_v_cache)
  119. return self_attn_kv_caches, cross_attn_kv_caches
  120. # Create alignment heads for timestamps
  121. # Shape is (num_alignment_heads, 2)
  122. def get_sample_alignment_heads(
  123. config: WhisperConfig,
  124. device: torch.device,
  125. num_alignment_heads: int = 6,
  126. use_int32: bool = True,
  127. ):
  128. torch_dtype = torch.int32 if use_int32 else torch.int64
  129. alignment_heads = torch.ones((num_alignment_heads, 2), device=device, dtype=torch_dtype)
  130. return alignment_heads
  131. # Create length of start-of-transcription sequence for timestamps
  132. # Shape is (1)
  133. def get_sample_sot_sequence_length(
  134. device: torch.device,
  135. sot_sequence_length: int,
  136. use_int32: bool = False,
  137. ):
  138. torch_dtype = torch.int32 if use_int32 else torch.int64
  139. sot_length = torch.tensor([sot_sequence_length], device=device, dtype=torch_dtype)
  140. return sot_length
  141. # Create segment length for timestamps
  142. # Shape is (1)
  143. def get_sample_segment_length(
  144. device: torch.device,
  145. segment_length: int,
  146. use_int32: bool = False,
  147. ):
  148. torch_dtype = torch.int32 if use_int32 else torch.int64
  149. segment_size = torch.tensor([segment_length], device=device, dtype=torch_dtype)
  150. return segment_size
  151. # Create QKs for timestamps
  152. # Shape is (batch_size, num_heads, sequence_length, num_frames // 2)
  153. def get_sample_QKs( # noqa: N802
  154. config: WhisperConfig,
  155. device: torch.device,
  156. batch_size: int,
  157. sequence_length: int,
  158. use_fp16: bool = False,
  159. ):
  160. num_heads = config.decoder_attention_heads
  161. torch_dtype = torch.float16 if use_fp16 else torch.float32
  162. QKs = [ # noqa: N806
  163. torch.rand(
  164. batch_size, num_heads, sequence_length, config.max_source_positions, device=device, dtype=torch_dtype
  165. )
  166. for _ in range(config.decoder_layers)
  167. ]
  168. return QKs
  169. # Create inputs for encoder component of Whisper
  170. def get_sample_encoder_inputs(
  171. config: WhisperConfig,
  172. device: torch.device,
  173. batch_size: int,
  174. sequence_length: int = 3000,
  175. use_fp16: bool = False,
  176. ):
  177. audio_features = get_sample_audio_features(config, device, batch_size, sequence_length, use_fp16)
  178. return {"audio_features": audio_features}
  179. # Create inputs for encoder component + first pass through decoder component of Whisper
  180. def get_sample_encoder_decoder_init_inputs(
  181. config: WhisperConfig,
  182. device: torch.device,
  183. batch_size: int,
  184. decoder_sequence_length: int,
  185. encoder_sequence_length: int = 3000,
  186. use_fp16: bool = False,
  187. use_int32: bool = True,
  188. ):
  189. audio_features = get_sample_audio_features(config, device, batch_size, encoder_sequence_length, use_fp16)
  190. decoder_input_ids = get_sample_decoder_input_ids(config, device, batch_size, decoder_sequence_length, use_int32)
  191. return {"audio_features": audio_features, "decoder_input_ids": decoder_input_ids}
  192. # Create inputs for decoder component of Whisper
  193. # Inputs for first pass through the decoder (i.e. decoder-init): decoder_input_ids, encoder_hidden_states
  194. # Inputs for subsequent passes through the decoder (i.e. decoder-with-past): decoder_input_ids, past_key_values
  195. def get_sample_decoder_inputs(
  196. config: WhisperConfig,
  197. device: torch.device,
  198. batch_size: int,
  199. past_sequence_length: int,
  200. sequence_length: int,
  201. use_fp16: bool = False,
  202. use_int32: bool = True,
  203. ):
  204. decoder_input_ids = get_sample_decoder_input_ids(config, device, batch_size, sequence_length, use_int32)
  205. encoder_hidden_states = get_sample_encoder_hidden_states(config, device, batch_size, use_fp16)
  206. past_key_values = get_sample_past_key_values(config, device, batch_size, past_sequence_length, use_fp16)
  207. return {
  208. "decoder_input_ids": decoder_input_ids,
  209. "encoder_hidden_states": encoder_hidden_states,
  210. "past_key_values": past_key_values,
  211. }
  212. # Create inputs for timestamps component of Whisper
  213. def get_sample_jump_times_inputs(
  214. config: WhisperConfig,
  215. device: torch.device,
  216. batch_size: int,
  217. sequence_length: int,
  218. num_alignment_heads: int,
  219. sot_sequence_length: int,
  220. segment_length: int,
  221. use_fp16: bool = False,
  222. use_int32: bool = True,
  223. ):
  224. alignment_heads = get_sample_alignment_heads(config, device, num_alignment_heads, use_int32)
  225. # lengths need to be int64 because subsequent 'Slice' ops only take int64 inputs
  226. sot_sequence_length = get_sample_sot_sequence_length(device, sot_sequence_length)
  227. segment_length = get_sample_segment_length(device, segment_length)
  228. QKs = get_sample_QKs(config, device, batch_size, sequence_length, use_fp16) # noqa: N806
  229. return {
  230. "alignment_heads": alignment_heads,
  231. "sot_sequence_length": sot_sequence_length,
  232. "segment_length": segment_length,
  233. "QKs": QKs,
  234. }
  235. # Convert PyTorch inputs to ONNX Runtime inputs
  236. def convert_inputs_for_ort(
  237. inputs: dict,
  238. model: InferenceSession,
  239. ):
  240. self_attn_kv_caches, cross_attn_kv_caches = None, None
  241. batch_size, num_heads, past_seq_len, head_size = 0, 0, 0, 0
  242. num_beams, max_seq_len = 1, 448
  243. if "past_key_values" in inputs:
  244. (self_attn_kv_caches, cross_attn_kv_caches) = group_past_key_values(inputs["past_key_values"])
  245. batch_size, num_heads, past_seq_len, head_size = self_attn_kv_caches[0].shape
  246. ort_inputs = {}
  247. model_inputs = list(map(lambda i: i.name, model.get_inputs())) # noqa: C417
  248. use_buffer_sharing = "cache_indirection" in model_inputs
  249. for name in model_inputs:
  250. if name in {"audio_features", "encoder_input_ids"}:
  251. # Encoder input
  252. ort_inputs[name] = inputs["audio_features"].detach().cpu().numpy()
  253. elif name == "encoder_hidden_states":
  254. # Encoder output
  255. ort_inputs[name] = inputs["encoder_hidden_states"].detach().cpu().numpy()
  256. elif name in {"decoder_input_ids", "input_ids"}:
  257. # Decoder input
  258. ort_inputs[name] = inputs["decoder_input_ids"].detach().cpu().numpy()
  259. elif "past_key_self" in name or "past_value_self" in name:
  260. # Decoder input
  261. orig_kv_cache = self_attn_kv_caches.pop(0).detach().cpu().numpy()
  262. if use_buffer_sharing:
  263. new_kv_cache = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=orig_kv_cache.dtype)
  264. new_kv_cache[:batch_size, :num_heads, :past_seq_len, :head_size] = orig_kv_cache
  265. ort_inputs[name] = new_kv_cache
  266. else:
  267. ort_inputs[name] = orig_kv_cache
  268. elif "past_key_cross" in name or "past_value_cross" in name:
  269. # Decoder input
  270. orig_kv_cache = cross_attn_kv_caches.pop(0).detach().cpu().numpy()
  271. ort_inputs[name] = orig_kv_cache
  272. elif name == "past_sequence_length":
  273. # Decoder input
  274. ort_inputs[name] = np.array([past_seq_len], dtype=np.int32)
  275. elif name == "cache_indirection":
  276. # Decoder input
  277. ort_inputs[name] = np.zeros((batch_size, num_beams, max_seq_len), dtype=np.int32)
  278. elif name == "alignment_heads":
  279. # Jump times input
  280. ort_inputs[name] = inputs["alignment_heads"].detach().cpu().numpy()
  281. elif name == "sot_sequence_length":
  282. # Jump times input
  283. ort_inputs[name] = inputs["sot_sequence_length"].detach().cpu().numpy()
  284. elif name == "segment_length":
  285. # Jump times input
  286. ort_inputs[name] = inputs["segment_length"].detach().cpu().numpy()
  287. elif "cross_qk" in name:
  288. # Jump times input
  289. ort_inputs[name] = inputs["QKs"].pop(0).detach().cpu().numpy()
  290. else:
  291. raise ValueError(f"Unknown name not recognized: {name}")
  292. return ort_inputs
  293. # Get dynamic axes for all inputs and outputs to the model
  294. def get_model_dynamic_axes(
  295. config: WhisperConfig,
  296. input_names: list[str],
  297. output_names: list[str],
  298. ):
  299. dynamic_axes = {}
  300. for name in input_names + output_names:
  301. if name in {"audio_features", "encoder_input_ids"}:
  302. # shape is (batch_size, num_mels, num_frames)
  303. dynamic_axes[name] = {0: "batch_size"}
  304. elif name in {"input_ids", "decoder_input_ids"}:
  305. # shape is (batch_size, sequence_length)
  306. dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
  307. elif name == "alignment_heads":
  308. # shape is (num_alignment_heads, 2)
  309. dynamic_axes[name] = {0: "num_alignment_heads"}
  310. elif name in {"sot_sequence_length", "segment_length"}:
  311. # shape is (1)
  312. pass
  313. elif name == "logits":
  314. # shape is (batch_size, sequence_length, vocab_size)
  315. dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
  316. elif name == "encoder_hidden_states":
  317. # shape is (batch_size, num_frames // 2, hidden_size)
  318. dynamic_axes[name] = {0: "batch_size"}
  319. elif "past_key_self" in name or "past_value_self" in name:
  320. # shape is (batch_size, num_heads, past_sequence_length, head_size)
  321. dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length"}
  322. elif "present_key_self" in name or "present_value_self" in name:
  323. # shape is (batch_size, num_heads, past_sequence_length + sequence_length, head_size),
  324. # which is equal to (batch_size, num_heads, total_sequence_length, head_size)
  325. dynamic_axes[name] = {0: "batch_size", 2: "total_sequence_length"}
  326. elif (
  327. "past_key_cross" in name
  328. or "past_value_cross" in name
  329. or "present_key_cross" in name
  330. or "present_value_cross" in name
  331. ):
  332. # shape is (batch_size, num_heads, num_frames // 2, head_size)
  333. dynamic_axes[name] = {0: "batch_size"}
  334. elif "cross_qk" in name:
  335. # shape is (batch_size, num_heads, source_sequence_length, target_sequence_length)
  336. dynamic_axes[name] = {0: "batch_size", 2: "sequence_length"}
  337. elif "jump_times" in name:
  338. # shape is (batch_size, max_length)
  339. dynamic_axes[name] = {0: "batch_size", 1: "max_length"}
  340. else:
  341. raise Exception(f"Unknown input or output name found: {name}")
  342. return dynamic_axes