modeling_clvp.py 85 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961
  1. # coding=utf-8
  2. # Copyright 2023 The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch CLVP model."""
  16. import copy
  17. import math
  18. from dataclasses import dataclass
  19. from typing import Callable, Optional, Union
  20. import torch
  21. from torch import nn
  22. from torch.nn import CrossEntropyLoss
  23. from ...activations import ACT2FN, get_activation
  24. from ...cache_utils import Cache, DynamicCache
  25. from ...generation import GenerationConfig, GenerationMixin
  26. from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
  27. from ...modeling_outputs import (
  28. BaseModelOutput,
  29. BaseModelOutputWithPastAndCrossAttentions,
  30. BaseModelOutputWithPooling,
  31. CausalLMOutputWithCrossAttentions,
  32. )
  33. from ...modeling_utils import PreTrainedModel
  34. from ...pytorch_utils import Conv1D, isin_mps_friendly
  35. from ...utils import (
  36. ModelOutput,
  37. auto_docstring,
  38. logging,
  39. )
  40. from ...utils.deprecation import deprecate_kwarg
  41. from .configuration_clvp import (
  42. ClvpConfig,
  43. ClvpDecoderConfig,
  44. ClvpEncoderConfig,
  45. )
# Module-level logger for this file, obtained through the library's `logging.get_logger` helper.
logger = logging.get_logger(__name__)
  47. # Copied from transformers.models.clip.modeling_clip.contrastive_loss
  48. def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
  49. return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
  50. # Copied from transformers.models.clip.modeling_clip.clip_loss with clip->clvp, image_loss->speech_loss
  51. def clvp_loss(similarity: torch.Tensor) -> torch.Tensor:
  52. caption_loss = contrastive_loss(similarity)
  53. speech_loss = contrastive_loss(similarity.t())
  54. return (caption_loss + speech_loss) / 2.0
  55. # Copied from transformers.models.llama.modeling_llama.rotate_half
  56. def rotate_half(x):
  57. """Rotates half the hidden dims of the input."""
  58. x1 = x[..., : x.shape[-1] // 2]
  59. x2 = x[..., x.shape[-1] // 2 :]
  60. return torch.cat((-x2, x1), dim=-1)
  61. def apply_rotary_pos_emb(q, k, v, cos, sin, position_ids, unsqueeze_dim=1):
  62. """Applies Rotary Position Embedding to the query and key tensors.
  63. Args:
  64. q (`torch.Tensor`): The query tensor.
  65. k (`torch.Tensor`): The key tensor.
  66. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  67. sin (`torch.Tensor`): The sine part of the rotary embedding.
  68. position_ids (`torch.Tensor`):
  69. The position indices of the tokens corresponding to the query and key tensors. For example, this can be
  70. used to pass offsetted position ids when working with a KV-cache.
  71. unsqueeze_dim (`int`, *optional*, defaults to 1):
  72. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  73. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  74. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  75. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  76. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  77. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  78. Returns:
  79. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  80. """
  81. cos = cos[position_ids].unsqueeze(unsqueeze_dim)
  82. sin = sin[position_ids].unsqueeze(unsqueeze_dim)
  83. q_embed = (q * cos) + (rotate_half(q) * sin)
  84. k_embed = (k * cos) + (rotate_half(k) * sin)
  85. v_embed = (v * cos) + (rotate_half(v) * sin)
  86. return q_embed, k_embed, v_embed
  87. def _pad_extra_bos_eos_tokens(
  88. input_ids,
  89. attention_mask=None,
  90. pad_token_id=0,
  91. bos_token_id=255,
  92. eos_token_id=0,
  93. add_bos_token=True,
  94. add_eos_token=True,
  95. ):
  96. """
  97. This method adds extra bos and eos tokens to input_ids and accordingly modifies the attention_mask which is used in
  98. `ClvpConditioningEncoder` and the generation loop of the `ClvpModelForConditionalGeneration`.
  99. """
  100. # add the bos token at the beginning
  101. if add_bos_token:
  102. input_ids = torch.nn.functional.pad(input_ids, (1, 0), value=bos_token_id)
  103. attention_mask = (
  104. torch.nn.functional.pad(attention_mask, (1, 0), value=1) if attention_mask is not None else attention_mask
  105. )
  106. modified_input_ids = input_ids
  107. if add_eos_token:
  108. modified_input_ids = torch.zeros(
  109. (input_ids.shape[0], input_ids.shape[1] + 1), dtype=input_ids.dtype, device=input_ids.device
  110. )
  111. for i, each_input_id in enumerate(input_ids):
  112. # locate where the valid tokens end and then add the eos token
  113. if isin_mps_friendly(each_input_id, pad_token_id).sum():
  114. pos = torch.where(each_input_id == pad_token_id)[0].min()
  115. modified_input_ids[i] = torch.concatenate(
  116. [each_input_id[:pos], torch.tensor([eos_token_id], device=input_ids.device), each_input_id[pos:]]
  117. )
  118. else:
  119. # if there are no pad tokens present, then add eos to the end
  120. modified_input_ids[i] = torch.nn.functional.pad(each_input_id, (0, 1), value=eos_token_id)
  121. attention_mask = (
  122. torch.nn.functional.pad(attention_mask, (1, 0), value=1) if attention_mask is not None else attention_mask
  123. )
  124. return modified_input_ids, attention_mask
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for CLVP encoder's outputs that contains a pooling of the last hidden states as well as a projection
    output (a linear layer on top of the pooled output).
    """
)
class ClvpEncoderOutput(ModelOutput):
    r"""
    embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
        The embeddings obtained by applying the projection layer to the pooler_output.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        The hidden state of the last layer of the model.
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
        Pooled output of the `last_hidden_state`.
    """

    # Every field defaults to None, so the encoder only fills what it actually computed.
    embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    pooler_output: Optional[torch.FloatTensor] = None
    # Tuple of per-layer hidden states. NOTE(review): not described in the docstring above; presumably
    # `auto_docstring` supplies the standard description, and the encoder fills it only on request — confirm.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    # Tuple of per-layer attention weights. NOTE(review): same assumption as `hidden_states` — confirm.
    attentions: Optional[tuple[torch.FloatTensor]] = None
  146. @dataclass
  147. @auto_docstring
  148. class ClvpOutput(ModelOutput):
  149. r"""
  150. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
  151. Contrastive loss for speech-text similarity.
  152. speech_ids (`torch.LongTensor`, *optional*):
  153. speech_ids (or speech candidates) generated by the `ClvpForCausalLM` model.
  154. logits_per_speech (`torch.FloatTensor` of shape `(speech_batch_size, text_batch_size)`):
  155. The scaled dot product scores between `speech_embeds` and `text_embeds`. This represents the speech-text
  156. similarity scores.
  157. logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, speech_batch_size)`):
  158. The scaled dot product scores between `text_embeds` and `speech_embeds`. This represents the text-speech
  159. similarity scores.
  160. text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  161. The text embeddings obtained by applying the projection layer to the pooled output of the text encoder
  162. model.
  163. speech_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  164. The speech embeddings obtained by applying the projection layer to the pooled output of the speech encoder
  165. model.
  166. text_model_output (`BaseModelOutputWithPooling`):
  167. The pooled output of the `last_hidden_state` of the text encoder Model.
  168. speech_model_output (`BaseModelOutputWithPooling`):
  169. The pooled output of the `last_hidden_state` of the speech encoder Model.
  170. decoder_hidden_states (`torch.FloatTensor`, *optional*):
  171. The hidden states of the decoder model.
  172. text_encoder_hidden_states (`torch.FloatTensor`, *optional*):
  173. The hidden states of the text encoder model.
  174. speech_encoder_hidden_states (`torch.FloatTensor`, *optional*):
  175. The hidden states of the speech encoder model.
  176. """
  177. loss: Optional[torch.FloatTensor] = None
  178. speech_ids: Optional[torch.LongTensor] = None
  179. logits_per_speech: Optional[torch.FloatTensor] = None
  180. logits_per_text: Optional[torch.FloatTensor] = None
  181. text_embeds: Optional[torch.FloatTensor] = None
  182. speech_embeds: Optional[torch.FloatTensor] = None
  183. text_model_output: BaseModelOutputWithPooling = None
  184. speech_model_output: BaseModelOutputWithPooling = None
  185. decoder_hidden_states: Optional[torch.FloatTensor] = None
  186. text_encoder_hidden_states: Optional[torch.FloatTensor] = None
  187. speech_encoder_hidden_states: Optional[torch.FloatTensor] = None
  188. # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Clvp
  189. class ClvpRMSNorm(nn.Module):
  190. def __init__(self, hidden_size, eps=1e-6):
  191. """
  192. ClvpRMSNorm is equivalent to T5LayerNorm
  193. """
  194. super().__init__()
  195. self.weight = nn.Parameter(torch.ones(hidden_size))
  196. self.variance_epsilon = eps
  197. def forward(self, hidden_states):
  198. input_dtype = hidden_states.dtype
  199. hidden_states = hidden_states.to(torch.float32)
  200. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  201. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  202. return self.weight * hidden_states.to(input_dtype)
  203. def extra_repr(self):
  204. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  205. class ClvpRotaryPositionalEmbedding(nn.Module):
  206. """
  207. Rotary Position Embedding Class for CLVP. It was proposed in the paper 'ROFORMER: ENHANCED TRANSFORMER WITH ROTARY
  208. POSITION EMBEDDING', Please see https://huggingface.co/papers/2104.09864v1.pdf .
  209. """
  210. def __init__(self, config):
  211. super().__init__()
  212. dim = max(config.projection_dim // (config.num_attention_heads * 2), 32)
  213. inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
  214. self.register_buffer("inv_freq", inv_freq)
  215. self.cached_sequence_length = None
  216. self.cached_rotary_positional_embedding = None
  217. def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
  218. sequence_length = hidden_states.shape[1]
  219. if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
  220. return self.cached_rotary_positional_embedding
  221. self.cached_sequence_length = sequence_length
  222. time_stamps = torch.arange(sequence_length, device=hidden_states.device).type_as(self.inv_freq)
  223. freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
  224. embeddings = torch.cat((freqs, freqs), dim=-1)
  225. self.cached_rotary_positional_embedding = embeddings.unsqueeze(0)
  226. return self.cached_rotary_positional_embedding
class ClvpSelfAttention(nn.Module):
    """
    Multi-headed attention to combine Absolute and Rotary Positional Embeddings into a single Attention module.

    In this file it is built by `ClvpEncoderLayer` as `ClvpSelfAttention(config)` and by `ClvpDecoderLayer` with a
    `layer_idx`, which `forward` passes to `Cache.update`.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        # 1/sqrt(head_dim); multiplied into the queries in `forward` instead of into the scores.
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        # Index of this layer inside `past_key_values`.
        self.layer_idx = layer_idx

        # Lower-triangular boolean buffer, only built when the config defines `max_position_embeddings`.
        # persistent=False keeps it out of the state dict.
        # NOTE(review): `self.bias` is not read in `forward` below; masking there comes from `attention_mask`.
        if hasattr(config, "max_position_embeddings"):
            max_positions = config.max_position_embeddings
            bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
            bias = bias.view(1, 1, max_positions, max_positions)
            self.register_buffer("bias", bias, persistent=False)

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_attention_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_attention_bias)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_attention_bias)
        # The output projection uses nn.Linear's default bias, independent of `config.use_attention_bias`.
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.FloatTensor,
        rotary_pos_emb: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, embed_dim)`): layer input.
            rotary_pos_emb (`torch.FloatTensor`, *optional*):
                Rotation angles; `.cos()`/`.sin()` are taken here. Only the first `rotary_pos_emb.shape[-1]`
                channels of each head are rotated (queries, keys and values).
            attention_mask (`torch.FloatTensor` of shape `(batch_size, 1, tgt_len, src_len)`, *optional*):
                Added to the raw scores before the softmax; any other shape raises `ValueError`.
            position_ids (`torch.LongTensor`, *optional*): required whenever `rotary_pos_emb` is given.
            past_key_values (`Cache`, *optional*):
                If given, this step's keys/values are passed to `update(...)` for `self.layer_idx`, and the returned
                keys/values are the ones attended over.
            use_cache (`bool`, *optional*): NOTE(review): not read in this method.
            head_mask (`torch.FloatTensor`, *optional*): multiplied into the post-softmax weights.
            output_attentions (`bool`, *optional*): NOTE(review): not read in this method.
            cache_position (`torch.Tensor`, *optional*): forwarded to `past_key_values.update`.

        Returns:
            `(attn_output, attn_weights)`: the projected output of shape `(batch_size, tgt_len, embed_dim)` and the
            post-softmax, post-head-mask, pre-dropout weights of shape `(batch_size, num_heads, tgt_len, src_len)`.
        """
        # Raise error when position_ids is None but rotary_pos_emb is provided, because we need that when applying
        # rotary_pos_emb to query and key states.
        if rotary_pos_emb is not None and position_ids is None:
            raise ValueError("`position_ids` must be provided when `rotary_pos_emb` is not None.")

        # NOTE(review): `embed_dim` is unpacked but not used below.
        bsz, _, embed_dim = hidden_states.size()

        # get query proj
        # All three become (bsz, num_heads, seq_len, head_dim); only the queries are pre-scaled.
        query_states = self._shape(self.q_proj(hidden_states), -1, bsz) * self.scale
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if past_key_values is not None:
            # Hand this step's keys/values to the cache and attend over what it returns for this layer.
            key_states, value_states = past_key_values.update(
                key_states, value_states, self.layer_idx, {"cache_position": cache_position}
            )

        if rotary_pos_emb is not None:
            rotary_emb_dim = rotary_pos_emb.shape[-1]

            # Partial rotary embedding
            # The first `rotary_emb_dim` channels of each head are rotated; the rest pass through unchanged.
            query_rot, query_pass = (
                query_states[..., :rotary_emb_dim],
                query_states[..., rotary_emb_dim:],
            )
            key_rot, key_pass = (
                key_states[..., :rotary_emb_dim],
                key_states[..., rotary_emb_dim:],
            )
            value_rot, value_pass = (
                value_states[..., :rotary_emb_dim],
                value_states[..., rotary_emb_dim:],
            )

            # squeeze(0) drops a leading size-1 axis (ClvpRotaryPositionalEmbedding returns `(1, seq_len, dim)`).
            cos, sin = rotary_pos_emb.cos().squeeze(0), rotary_pos_emb.sin().squeeze(0)
            query_rot, key_rot, value_rot = apply_rotary_pos_emb(query_rot, key_rot, value_rot, cos, sin, position_ids)

            # [batch_size, num_heads, seq_length, head_dim]
            query_states = torch.cat((query_rot, query_pass), dim=-1)
            key_states = torch.cat((key_rot, key_pass), dim=-1)
            value_states = torch.cat((value_rot, value_pass), dim=-1)

        tgt_len = query_states.shape[2]
        src_len = key_states.shape[2]
        # Scores of shape (bsz, num_heads, tgt_len, src_len); the 1/sqrt(head_dim) factor is already in the queries.
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            # Additive mask, broadcast over the head axis.
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        # Dropout only affects the probabilities used for the weighted sum; the returned weights are pre-dropout.
        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.matmul(attn_probs, value_states)

        if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # (bsz, num_heads, tgt_len, head_dim) -> (bsz, tgt_len, embed_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        # NOTE(review): the weights are returned regardless of `output_attentions`.
        return attn_output, attn_weights
  327. class ClvpGatedLinearUnit(nn.Module):
  328. """
  329. `ClvpGatedLinearUnit` uses the second half of the `hidden_states` to act as a gate for the first half of the
  330. `hidden_states` which controls the flow of data from the first of the tensor.
  331. """
  332. def __init__(self, config):
  333. super().__init__()
  334. self.activation_fn = ACT2FN[config.hidden_act]
  335. self.proj = nn.Linear(config.hidden_size, config.intermediate_size * 2)
  336. def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
  337. hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
  338. return hidden_states * self.activation_fn(gate)
  339. class ClvpEncoderMLP(nn.Module):
  340. """
  341. This MLP is used in CLVP speech or text encoder models.
  342. """
  343. def __init__(self, config):
  344. super().__init__()
  345. self.config = config
  346. self.fc1 = ClvpGatedLinearUnit(config)
  347. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  348. self.dropout_layer = nn.Dropout(config.dropout)
  349. def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
  350. hidden_states = self.fc1(hidden_states)
  351. hidden_states = self.dropout_layer(hidden_states)
  352. hidden_states = self.fc2(hidden_states)
  353. return hidden_states
  354. class ClvpEncoderLayer(nn.Module):
  355. def __init__(self, config: ClvpConfig):
  356. super().__init__()
  357. self.config = config
  358. self.embed_dim = config.hidden_size
  359. self.self_attn = ClvpSelfAttention(config)
  360. self.mlp = ClvpEncoderMLP(config)
  361. self.input_rmsnorm = ClvpRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
  362. self.post_attention_rmsnorm = ClvpRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
  363. def forward(
  364. self,
  365. hidden_states: torch.FloatTensor,
  366. rotary_pos_emb: torch.FloatTensor,
  367. attention_mask: torch.LongTensor,
  368. position_ids: torch.LongTensor,
  369. output_attentions: Optional[bool] = False,
  370. ) -> tuple[torch.FloatTensor]:
  371. """
  372. Args:
  373. hidden_states (`torch.FloatTensor` of shape `(batch, seq_len, embed_dim)`):
  374. input to the layer.
  375. rotary_pos_emb (`torch.FloatTensor`):
  376. rotary position embeddings generated by `ClvpRotaryPositionalEmbedding` module.
  377. attention_mask (`torch.FloatTensor` of shape `(batch, 1, tgt_len, src_len)`):
  378. attention mask where padding elements are indicated by very large negative values.
  379. position_ids (`torch.LongTensor`):
  380. Denotes position ids of the input tokens.
  381. output_attentions (`bool`, *optional*, defaults to `False`):
  382. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  383. returned tensors for more detail.
  384. """
  385. residual = hidden_states
  386. hidden_states = self.input_rmsnorm(hidden_states)
  387. hidden_states, attn_weights = self.self_attn(
  388. hidden_states=hidden_states,
  389. rotary_pos_emb=rotary_pos_emb,
  390. attention_mask=attention_mask,
  391. position_ids=position_ids,
  392. output_attentions=output_attentions,
  393. )
  394. hidden_states = residual + hidden_states
  395. residual = hidden_states
  396. hidden_states = self.post_attention_rmsnorm(hidden_states)
  397. hidden_states = self.mlp(hidden_states)
  398. hidden_states = residual + hidden_states
  399. return hidden_states, attn_weights
  400. # Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Clvp
  401. class ClvpSequenceSummary(nn.Module):
  402. r"""
  403. Compute a single vector summary of a sequence hidden states.
  404. Args:
  405. config ([`ClvpConfig`]):
  406. The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
  407. config class of your model for the default values it uses):
  408. - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
  409. - `"last"` -- Take the last token hidden state (like XLNet)
  410. - `"first"` -- Take the first token hidden state (like Bert)
  411. - `"mean"` -- Take the mean of all tokens hidden states
  412. - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
  413. - `"attn"` -- Not implemented now, use multi-head attention
  414. - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
  415. - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
  416. (otherwise to `config.hidden_size`).
  417. - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
  418. another string or `None` will add no activation.
  419. - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
  420. - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
  421. """
  422. def __init__(self, config: ClvpConfig):
  423. super().__init__()
  424. self.summary_type = getattr(config, "summary_type", "last")
  425. if self.summary_type == "attn":
  426. # We should use a standard multi-head attention module with absolute positional embedding for that.
  427. # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
  428. # We can probably just use the multi-head attention module of PyTorch >=1.1.0
  429. raise NotImplementedError
  430. self.summary = nn.Identity()
  431. if hasattr(config, "summary_use_proj") and config.summary_use_proj:
  432. if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
  433. num_classes = config.num_labels
  434. else:
  435. num_classes = config.hidden_size
  436. self.summary = nn.Linear(config.hidden_size, num_classes)
  437. activation_string = getattr(config, "summary_activation", None)
  438. self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
  439. self.first_dropout = nn.Identity()
  440. if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
  441. self.first_dropout = nn.Dropout(config.summary_first_dropout)
  442. self.last_dropout = nn.Identity()
  443. if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
  444. self.last_dropout = nn.Dropout(config.summary_last_dropout)
  445. def forward(
  446. self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
  447. ) -> torch.FloatTensor:
  448. """
  449. Compute a single vector summary of a sequence hidden states.
  450. Args:
  451. hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
  452. The hidden states of the last layer.
  453. cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
  454. Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
  455. Returns:
  456. `torch.FloatTensor`: The summary of the sequence hidden states.
  457. """
  458. if self.summary_type == "last":
  459. output = hidden_states[:, -1]
  460. elif self.summary_type == "first":
  461. output = hidden_states[:, 0]
  462. elif self.summary_type == "mean":
  463. output = hidden_states.mean(dim=1)
  464. elif self.summary_type == "cls_index":
  465. if cls_index is None:
  466. cls_index = torch.full_like(
  467. hidden_states[..., :1, :],
  468. hidden_states.shape[-2] - 1,
  469. dtype=torch.long,
  470. )
  471. else:
  472. cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
  473. cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
  474. # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
  475. output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
  476. elif self.summary_type == "attn":
  477. raise NotImplementedError
  478. output = self.first_dropout(output)
  479. output = self.summary(output)
  480. output = self.activation(output)
  481. output = self.last_dropout(output)
  482. return output
  483. # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP with GPT2->ClvpDecoderMLP
  484. class ClvpDecoderMLP(nn.Module):
  485. def __init__(self, intermediate_size, config):
  486. super().__init__()
  487. embed_dim = config.hidden_size
  488. self.c_fc = Conv1D(intermediate_size, embed_dim)
  489. self.c_proj = Conv1D(embed_dim, intermediate_size)
  490. self.act = ACT2FN[config.activation_function]
  491. self.dropout = nn.Dropout(config.resid_pdrop)
  492. def forward(self, hidden_states: Optional[tuple[torch.FloatTensor]]) -> torch.FloatTensor:
  493. hidden_states = self.c_fc(hidden_states)
  494. hidden_states = self.act(hidden_states)
  495. hidden_states = self.c_proj(hidden_states)
  496. hidden_states = self.dropout(hidden_states)
  497. return hidden_states
  498. class ClvpDecoderLayer(nn.Module):
  499. def __init__(self, config, layer_idx=None):
  500. super().__init__()
  501. hidden_size = config.hidden_size
  502. inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
  503. self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  504. self.attn = ClvpSelfAttention(config, layer_idx=layer_idx)
  505. self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  506. self.mlp = ClvpDecoderMLP(inner_dim, config)
  507. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  508. def forward(
  509. self,
  510. hidden_states: Optional[tuple[torch.FloatTensor]],
  511. past_key_values: Optional[Cache] = None,
  512. attention_mask: Optional[torch.LongTensor] = None,
  513. position_ids: Optional[torch.LongTensor] = None,
  514. head_mask: Optional[torch.FloatTensor] = None,
  515. use_cache: Optional[bool] = False,
  516. output_attentions: Optional[bool] = False,
  517. cache_position: Optional[torch.Tensor] = None,
  518. ) -> Union[tuple[torch.Tensor], Optional[tuple[torch.Tensor, tuple[torch.FloatTensor, ...]]]]:
  519. residual = hidden_states
  520. hidden_states = self.input_layernorm(hidden_states)
  521. attn_outputs = self.attn(
  522. hidden_states,
  523. past_key_values=past_key_values,
  524. attention_mask=attention_mask,
  525. position_ids=position_ids,
  526. head_mask=head_mask,
  527. use_cache=use_cache,
  528. output_attentions=output_attentions,
  529. cache_position=cache_position,
  530. )
  531. attn_output = attn_outputs[0]
  532. # residual connection
  533. hidden_states = attn_output + residual
  534. residual = hidden_states
  535. hidden_states = self.post_attention_layernorm(hidden_states)
  536. feed_forward_hidden_states = self.mlp(hidden_states)
  537. # residual connection
  538. hidden_states = residual + feed_forward_hidden_states
  539. return (hidden_states,) + attn_outputs[1:]
  540. class ClvpConditioningEncoder(nn.Module):
  541. """
  542. This class processes the log-mel spectrograms(extracted by the Feature Extractor) and text tokens(produced by the
  543. tokenizer) as inputs for the decoder model.
  544. First each log-mel spectrogram is processed into a single vector which captures valuable characteristics from each
  545. of them, then the text tokens are converted into token embeddings and position embeddings are added afterwards.
  546. Both of these vectors are concatenated and then passed to the decoder model.
  547. The text tokens helps to incorporate the "text information" and the log-mel spectrogram is used to specify the
  548. "voice characteristics" into the generated mel tokens.
  549. """
  550. def __init__(self, config: ClvpConfig):
  551. super().__init__()
  552. self.text_config = config.text_config
  553. self.decoder_config = config.decoder_config
  554. self.text_token_embedding = nn.Embedding(self.text_config.vocab_size, self.decoder_config.hidden_size)
  555. self.text_position_embedding = nn.Embedding(
  556. self.decoder_config.max_text_tokens, self.decoder_config.hidden_size
  557. )
  558. self.mel_conv = nn.Conv1d(self.decoder_config.feature_size, self.decoder_config.hidden_size, kernel_size=1)
  559. # define group norms to be used before each attention layer
  560. num_groups = self.compute_groupnorm_groups(self.decoder_config.hidden_size)
  561. self.group_norms = nn.ModuleList(
  562. [
  563. nn.GroupNorm(num_groups, self.decoder_config.hidden_size, eps=1e-5, affine=True)
  564. for _ in range(self.decoder_config.num_mel_attn_blocks)
  565. ]
  566. )
  567. # define the attention layers
  568. self.mel_attn_blocks = nn.ModuleList(
  569. [ClvpSelfAttention(self.decoder_config) for _ in range(self.decoder_config.num_mel_attn_blocks)]
  570. )
  571. self.gradient_checkpointing = False
  572. def compute_groupnorm_groups(self, channels: int, groups: int = 32):
  573. """
  574. Calculates the value of `num_groups` for nn.GroupNorm. This logic is taken from the official tortoise
  575. repository. link :
  576. https://github.com/neonbjb/tortoise-tts/blob/4003544b6ff4b68c09856e04d3eff9da26d023c2/tortoise/models/arch_util.py#L26
  577. """
  578. if channels <= 16:
  579. groups = 8
  580. elif channels <= 64:
  581. groups = 16
  582. while channels % groups != 0:
  583. groups = int(groups / 2)
  584. if groups <= 2:
  585. raise ValueError(
  586. f"Number of groups for the GroupNorm must be greater than 2, but it is {groups}."
  587. f"Please consider using a different `hidden_size`"
  588. )
  589. return groups
  590. def forward(
  591. self,
  592. input_features: torch.FloatTensor,
  593. input_ids: Optional[torch.LongTensor] = None,
  594. inputs_embeds: Optional[torch.FloatTensor] = None,
  595. attention_mask: Optional[torch.LongTensor] = None,
  596. ):
  597. # process text
  598. if input_ids is not None and inputs_embeds is not None:
  599. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  600. elif input_ids is not None:
  601. batch_size, seq_length = input_ids.size()
  602. elif inputs_embeds is not None:
  603. batch_size, seq_length = inputs_embeds.size()[:-1]
  604. else:
  605. raise ValueError("You have to specify either input_ids or inputs_embeds")
  606. # construct attention mask if not given
  607. if attention_mask is None:
  608. attention_mask = torch.ones([batch_size, seq_length], dtype=torch.long, device=input_ids.device)
  609. # We add bos and eos input_ids in the modeling file instead of the tokenizer file to keep the logic simple
  610. # This logic is specific to ClvpConditioningEncoder and not used by other modules.
  611. input_ids, attention_mask = _pad_extra_bos_eos_tokens(
  612. input_ids,
  613. attention_mask,
  614. bos_token_id=self.text_config.bos_token_id,
  615. eos_token_id=self.text_config.eos_token_id,
  616. )
  617. inputs_embeds = self.text_token_embedding(input_ids)
  618. position_ids = attention_mask.cumsum(-1) - 1
  619. position_embeds = self.text_position_embedding(position_ids)
  620. text_embeds = inputs_embeds + position_embeds
  621. if self.gradient_checkpointing and self.training:
  622. # process each log-mel spectrogram into a single vector
  623. mel_spec = torch.utils.checkpoint.checkpoint(self.mel_conv, input_features)
  624. for i, mel_attn_block in enumerate(self.mel_attn_blocks):
  625. residual_mel_spec = mel_spec.transpose(1, 2)
  626. mel_spec = torch.utils.checkpoint.checkpoint(self.group_norms[i], mel_spec).transpose(1, 2)
  627. mel_spec = torch.utils.checkpoint.checkpoint(mel_attn_block, mel_spec)[0] + residual_mel_spec
  628. mel_spec = mel_spec.transpose(1, 2)
  629. else:
  630. # process each log-mel spectrogram into a single vector
  631. mel_spec = self.mel_conv(input_features)
  632. for i, mel_attn_block in enumerate(self.mel_attn_blocks):
  633. residual_mel_spec = mel_spec.transpose(1, 2)
  634. mel_spec = self.group_norms[i](mel_spec).transpose(1, 2)
  635. mel_spec = mel_attn_block(mel_spec)[0] + residual_mel_spec
  636. mel_spec = mel_spec.transpose(1, 2)
  637. mel_spec = mel_spec[:, :, 0]
  638. mel_spec = mel_spec.unsqueeze(1)
  639. # repeat if there is either (1 text vs N audios) or (N texts vs 1 audio)
  640. if text_embeds.shape[0] == 1 and mel_spec.shape[0] != 1:
  641. text_embeds = text_embeds.repeat(mel_spec.shape[0], 1, 1)
  642. elif text_embeds.shape[0] != 1 and mel_spec.shape[0] == 1:
  643. mel_spec = mel_spec.repeat(text_embeds.shape[0], 1, 1)
  644. # If there is N texts and M audios we will raise error since the number of text and audio must be same.
  645. elif text_embeds.shape[0] != mel_spec.shape[0]:
  646. raise ValueError(
  647. f"The number of texts and number of audios must be same. "
  648. f"Found {text_embeds.shape[0]} texts vs {mel_spec.shape[0]} audios"
  649. )
  650. return torch.concat([mel_spec, text_embeds], dim=1)
  651. @auto_docstring
  652. class ClvpPreTrainedModel(PreTrainedModel):
  653. config: ClvpConfig
  654. base_model_prefix = "clvp"
  655. supports_gradient_checkpointing = True
  656. _skip_keys_device_placement = "past_key_values"
  657. def _init_weights(self, module: nn.Module):
  658. """Initialize the weights"""
  659. factor = self.config.initializer_factor
  660. if isinstance(module, nn.Embedding):
  661. module.weight.data.normal_(mean=0.0, std=factor * 0.02)
  662. elif isinstance(module, (nn.Linear, Conv1D, nn.Conv1d)):
  663. module.weight.data.normal_(mean=0.0, std=factor * 0.02)
  664. if module.bias is not None:
  665. module.bias.data.zero_()
  666. elif isinstance(module, ClvpRMSNorm):
  667. module.weight.data.fill_(1.0)
  668. elif isinstance(module, ClvpEncoderMLP):
  669. in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  670. fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
  671. nn.init.normal_(module.fc1.proj.weight if getattr(module.fc1, "proj") else module.fc1.weight, std=fc_std)
  672. nn.init.normal_(module.fc2.weight, std=in_proj_std)
  673. elif isinstance(module, ClvpEncoder):
  674. config = self.config.get_text_config()
  675. factor = config.initializer_factor
  676. module.projection.weight.data.normal_(mean=0.0, std=factor * (config.hidden_size**-0.5))
  677. elif isinstance(module, ClvpConditioningEncoder):
  678. module.mel_conv.weight.data.normal_(mean=0.0, std=factor)
  679. module.mel_conv.bias.data.zero_()
  680. elif isinstance(module, ClvpForCausalLM):
  681. for name, p in module.named_parameters():
  682. if name == "c_proj.weight":
  683. p.data.normal_(
  684. mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.num_hidden_layers))
  685. )
  686. elif isinstance(module, ClvpModelForConditionalGeneration):
  687. module.logit_scale.data.fill_(self.config.logit_scale_init_value)
  688. if isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
  689. module.bias.data.zero_()
  690. module.weight.data.fill_(1.0)
  691. class ClvpEncoder(ClvpPreTrainedModel):
  692. """
  693. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  694. [`ClvpEncoderLayer`].
  695. Args:
  696. config: ClvpConfig
  697. """
  698. def __init__(self, config: ClvpConfig):
  699. super().__init__(config)
  700. self.config = config
  701. self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
  702. self.rotary_pos_emb = ClvpRotaryPositionalEmbedding(config) if config.use_rotary_embedding else None
  703. self.layers = nn.ModuleList([ClvpEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  704. self.sequence_summary = ClvpSequenceSummary(config)
  705. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  706. self.projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
  707. self.gradient_checkpointing = False
  708. self.post_init()
  709. def get_input_embeddings(self):
  710. return self.token_embedding
  711. def set_input_embeddings(self, value):
  712. self.token_embedding = value
  713. def forward(
  714. self,
  715. input_ids: Optional[torch.LongTensor] = None,
  716. inputs_embeds: Optional[torch.LongTensor] = None,
  717. attention_mask: Optional[torch.LongTensor] = None,
  718. position_ids: Optional[torch.LongTensor] = None,
  719. output_attentions: Optional[bool] = None,
  720. output_hidden_states: Optional[bool] = None,
  721. return_dict: Optional[bool] = None,
  722. ) -> Union[tuple, BaseModelOutput]:
  723. r"""
  724. Args:
  725. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
  726. Indices of input sequence tokens in the vocabulary.
  727. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  728. [`PreTrainedTokenizer.__call__`] for details.
  729. [What are input IDs?](../glossary#input-ids)
  730. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  731. input embeddings for the model. This bypasses the model's internal embedding lookup matrix.
  732. attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  733. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  734. - 1 for tokens that are **not masked**,
  735. - 0 for tokens that are **masked**.
  736. [What are attention masks?](../glossary#attention-mask)
  737. position_ids (`torch.LongTensor`, *optional*):
  738. Denotes the position ids of `input_ids`.
  739. output_attentions (`bool`, *optional*):
  740. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  741. returned tensors for more detail.
  742. output_hidden_states (`bool`, *optional*):
  743. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  744. for more detail.
  745. return_dict (`bool`, *optional*):
  746. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  747. """
  748. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  749. output_hidden_states = (
  750. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  751. )
  752. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  753. if input_ids is not None and inputs_embeds is not None:
  754. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  755. elif input_ids is not None:
  756. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  757. input_shape = input_ids.size()
  758. input_ids = input_ids.view(-1, input_shape[-1])
  759. inputs_embeds = self.token_embedding(input_ids)
  760. elif inputs_embeds is not None:
  761. input_shape = inputs_embeds.size()[:-1]
  762. else:
  763. raise ValueError("You have to specify either input_ids or inputs_embeds")
  764. # expand attention_mask and create position_ids if needed
  765. if attention_mask is not None:
  766. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  767. attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
  768. if position_ids is None:
  769. device = input_ids.device if input_ids is not None else inputs_embeds.device
  770. position_ids = torch.arange(input_shape[1], dtype=torch.long, device=device)
  771. position_ids = position_ids.unsqueeze(0)
  772. encoder_states = () if output_hidden_states else None
  773. all_attentions = () if output_attentions else None
  774. rotary_pos_emb = self.rotary_pos_emb(inputs_embeds) if self.rotary_pos_emb is not None else None
  775. hidden_states = inputs_embeds
  776. for idx, encoder_layer in enumerate(self.layers):
  777. if output_hidden_states:
  778. encoder_states = encoder_states + (hidden_states,)
  779. if self.gradient_checkpointing and self.training:
  780. layer_outputs = torch.utils.checkpoint.checkpoint(
  781. encoder_layer.__call__,
  782. hidden_states,
  783. rotary_pos_emb,
  784. attention_mask,
  785. position_ids,
  786. )
  787. else:
  788. layer_outputs = encoder_layer(
  789. hidden_states,
  790. rotary_pos_emb,
  791. attention_mask,
  792. position_ids,
  793. output_attentions=output_attentions,
  794. )
  795. hidden_states = layer_outputs[0]
  796. if output_attentions:
  797. all_attentions = all_attentions + (layer_outputs[1],)
  798. if output_hidden_states:
  799. encoder_states = encoder_states + (hidden_states,)
  800. last_hidden_state = hidden_states
  801. last_hidden_state = self.final_layer_norm(last_hidden_state)
  802. # take the mean over axis 1 and get pooled output
  803. pooled_output = self.sequence_summary(last_hidden_state)
  804. # apply the projection layer
  805. embeds = self.projection(pooled_output)
  806. if not return_dict:
  807. return tuple(
  808. v for v in [embeds, last_hidden_state, pooled_output, encoder_states, all_attentions] if v is not None
  809. )
  810. return ClvpEncoderOutput(
  811. embeds=embeds,
  812. last_hidden_state=last_hidden_state,
  813. pooler_output=pooled_output,
  814. hidden_states=encoder_states,
  815. attentions=all_attentions,
  816. )
  817. class ClvpDecoder(ClvpPreTrainedModel):
  818. """
  819. Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ClvpDecoderLayer`]
  820. """
  821. def __init__(self, config):
  822. super().__init__(config)
  823. self.config = config
  824. self.input_embeds_layer = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
  825. self.position_embeds_layer = nn.Embedding(self.config.max_position_embeddings, self.config.hidden_size)
  826. self.drop = nn.Dropout(self.config.embd_pdrop)
  827. self.layers = nn.ModuleList(
  828. [ClvpDecoderLayer(self.config, layer_idx=i) for i in range(self.config.num_hidden_layers)]
  829. )
  830. self.layer_norm = nn.LayerNorm(self.config.hidden_size, eps=self.config.layer_norm_epsilon)
  831. self.gradient_checkpointing = False
  832. # Initialize weights and apply final processing
  833. self.post_init()
  834. def get_input_embeddings(self):
  835. return self.input_embeds_layer
  836. def set_input_embeddings(self, new_embeddings):
  837. self.input_embeds_layer = new_embeddings
  838. def _prune_heads(self, heads_to_prune):
  839. """
  840. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
  841. """
  842. for layer, heads in heads_to_prune.items():
  843. self.layers[layer].attn.prune_heads(heads)
  844. @auto_docstring
  845. def forward(
  846. self,
  847. input_ids: Optional[torch.LongTensor] = None,
  848. attention_mask: Optional[torch.FloatTensor] = None,
  849. token_type_ids: Optional[torch.LongTensor] = None,
  850. position_ids: Optional[torch.LongTensor] = None,
  851. head_mask: Optional[torch.FloatTensor] = None,
  852. past_key_values: Optional[Cache] = None,
  853. inputs_embeds: Optional[torch.FloatTensor] = None,
  854. use_cache: Optional[bool] = None,
  855. output_attentions: Optional[bool] = None,
  856. output_hidden_states: Optional[bool] = None,
  857. return_dict: Optional[bool] = None,
  858. cache_position: Optional[torch.Tensor] = None,
  859. ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
  860. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  861. output_hidden_states = (
  862. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  863. )
  864. use_cache = use_cache if use_cache is not None else self.config.use_cache
  865. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  866. if input_ids is not None and inputs_embeds is not None:
  867. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  868. elif input_ids is not None:
  869. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  870. input_shape = input_ids.size()
  871. input_ids = input_ids.view(-1, input_shape[-1])
  872. input_ids.shape[0]
  873. elif inputs_embeds is not None:
  874. input_shape = inputs_embeds.size()[:-1]
  875. inputs_embeds.shape[0]
  876. else:
  877. raise ValueError("You have to specify either input_ids or inputs_embeds")
  878. device = input_ids.device if input_ids is not None else inputs_embeds.device
  879. if token_type_ids is not None:
  880. token_type_ids = token_type_ids.view(-1, input_shape[-1])
  881. if self.gradient_checkpointing and self.training:
  882. if use_cache:
  883. logger.warning_once(
  884. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  885. )
  886. use_cache = False
  887. if use_cache and past_key_values is None:
  888. past_key_values = DynamicCache(config=self.config)
  889. if use_cache and isinstance(past_key_values, tuple):
  890. logger.warning_once(
  891. "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
  892. "You should pass an instance of `DynamicCache` instead, e.g. "
  893. "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
  894. )
  895. past_key_values = DynamicCache.from_legacy_cache(past_key_values)
  896. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  897. if position_ids is None:
  898. position_ids = torch.arange(
  899. past_key_values_length, input_shape[-1] + past_key_values_length, dtype=torch.long, device=device
  900. )
  901. position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
  902. if inputs_embeds is None:
  903. inputs_embeds = self.input_embeds_layer(input_ids)
  904. position_embeds = self.position_embeds_layer(position_ids)
  905. inputs_embeds = inputs_embeds + position_embeds
  906. attention_mask = _prepare_4d_causal_attention_mask(
  907. attention_mask, input_shape, inputs_embeds, past_key_values_length
  908. )
  909. # Prepare head mask if needed
  910. # 1.0 in head_mask indicate we keep the head
  911. # attention_probs has shape bsz x num_attention_heads x N x N
  912. # head_mask has shape num_hidden_layers x batch x num_attention_heads x N x N
  913. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  914. hidden_states = inputs_embeds
  915. if token_type_ids is not None:
  916. token_type_embeds = self.input_embeds_layer(token_type_ids)
  917. hidden_states = hidden_states + token_type_embeds
  918. hidden_states = self.drop(hidden_states)
  919. output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
  920. all_self_attentions = () if output_attentions else None
  921. all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
  922. all_hidden_states = () if output_hidden_states else None
  923. for i, block in enumerate(self.layers):
  924. if output_hidden_states:
  925. all_hidden_states = all_hidden_states + (hidden_states,)
  926. if self.gradient_checkpointing and self.training:
  927. outputs = torch.utils.checkpoint.checkpoint(
  928. block.__call__,
  929. hidden_states,
  930. None,
  931. attention_mask,
  932. position_ids,
  933. head_mask[i],
  934. cache_position,
  935. )
  936. else:
  937. outputs = block(
  938. hidden_states,
  939. past_key_values=past_key_values,
  940. attention_mask=attention_mask,
  941. position_ids=position_ids,
  942. head_mask=head_mask[i],
  943. use_cache=use_cache,
  944. output_attentions=output_attentions,
  945. cache_position=cache_position,
  946. )
  947. hidden_states = outputs[0]
  948. if output_attentions:
  949. all_self_attentions = all_self_attentions + (outputs[1],)
  950. if self.config.add_cross_attention:
  951. all_cross_attentions = all_cross_attentions + (outputs[2],)
  952. hidden_states = self.layer_norm(hidden_states)
  953. hidden_states = hidden_states.view(output_shape)
  954. # Add last hidden state
  955. if output_hidden_states:
  956. all_hidden_states = all_hidden_states + (hidden_states,)
  957. if not return_dict:
  958. return tuple(
  959. v
  960. for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions]
  961. if v is not None
  962. )
  963. return BaseModelOutputWithPastAndCrossAttentions(
  964. last_hidden_state=hidden_states,
  965. past_key_values=past_key_values,
  966. hidden_states=all_hidden_states,
  967. attentions=all_self_attentions,
  968. cross_attentions=all_cross_attentions,
  969. )
  970. @auto_docstring
  971. class ClvpModel(ClvpPreTrainedModel):
  972. def __init__(self, config: ClvpDecoderConfig):
  973. super().__init__(config)
  974. self.config = config
  975. self.decoder = ClvpDecoder(self.config)
  976. # Initialize weights and apply final processing
  977. self.post_init()
  978. def get_input_embeddings(self):
  979. return self.decoder.input_embeds_layer
  980. def set_input_embeddings(self, value):
  981. self.decoder.input_embeds_layer = value
  982. @auto_docstring
  983. def forward(
  984. self,
  985. input_ids: Optional[torch.LongTensor] = None,
  986. attention_mask: Optional[torch.FloatTensor] = None,
  987. token_type_ids: Optional[torch.LongTensor] = None,
  988. position_ids: Optional[torch.LongTensor] = None,
  989. head_mask: Optional[torch.FloatTensor] = None,
  990. past_key_values: Optional[Cache] = None,
  991. inputs_embeds: Optional[torch.FloatTensor] = None,
  992. use_cache: Optional[bool] = None,
  993. output_attentions: Optional[bool] = None,
  994. output_hidden_states: Optional[bool] = None,
  995. return_dict: Optional[bool] = None,
  996. cache_position: Optional[torch.Tensor] = None,
  997. ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
  998. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  999. output_hidden_states = (
  1000. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1001. )
  1002. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1003. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1004. # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
  1005. decoder_outputs = self.decoder(
  1006. input_ids=input_ids,
  1007. attention_mask=attention_mask,
  1008. token_type_ids=token_type_ids,
  1009. position_ids=position_ids,
  1010. head_mask=head_mask,
  1011. past_key_values=past_key_values,
  1012. inputs_embeds=inputs_embeds,
  1013. use_cache=use_cache,
  1014. output_attentions=output_attentions,
  1015. output_hidden_states=output_hidden_states,
  1016. return_dict=return_dict,
  1017. cache_position=cache_position,
  1018. )
  1019. if not return_dict:
  1020. return decoder_outputs
  1021. return BaseModelOutputWithPastAndCrossAttentions(
  1022. last_hidden_state=decoder_outputs.last_hidden_state,
  1023. past_key_values=decoder_outputs.past_key_values,
  1024. hidden_states=decoder_outputs.hidden_states,
  1025. attentions=decoder_outputs.attentions,
  1026. cross_attentions=decoder_outputs.cross_attentions,
  1027. )
  1028. @auto_docstring(
  1029. custom_intro="""
  1030. The CLVP decoder model with a language modelling head on top.
  1031. """
  1032. )
  1033. class ClvpForCausalLM(ClvpPreTrainedModel, GenerationMixin):
  1034. def __init__(self, config):
  1035. super().__init__(config)
  1036. self.config = config
  1037. self.model = ClvpModel(self.config)
  1038. self.final_norm = nn.LayerNorm(self.config.hidden_size)
  1039. self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=True)
  1040. # Initialize weights and apply final processing
  1041. self.post_init()
  1042. def get_output_embeddings(self):
  1043. return None
  1044. def get_input_embeddings(self):
  1045. return self.model.decoder.input_embeds_layer
  1046. def set_input_embeddings(self, new_embeddings):
  1047. self.model.decoder.input_embeds_layer = new_embeddings
  1048. def _prepare_model_inputs(
  1049. self,
  1050. inputs: Optional[torch.Tensor] = None,
  1051. bos_token_id: Optional[int] = None,
  1052. model_kwargs: Optional[dict[str, torch.Tensor]] = None,
  1053. ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]:
  1054. """
  1055. This function extracts the model-specific `inputs` for generation.
  1056. """
  1057. input_name = self.main_input_name
  1058. model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None}
  1059. inputs_kwarg = model_kwargs.pop(input_name, None)
  1060. if inputs_kwarg is not None and inputs is not None:
  1061. raise ValueError(
  1062. f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed."
  1063. f"Make sure to either pass {inputs} or {input_name}=..."
  1064. )
  1065. elif inputs_kwarg is not None:
  1066. inputs = inputs_kwarg
  1067. if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
  1068. model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
  1069. inputs, bos_token_id, model_kwargs=model_kwargs
  1070. )
  1071. inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
  1072. # Check if conditioning_embeds are provided or not, if yes then concatenate the bos_token_id at the end of the conditioning_embeds.
  1073. # Then we must subtract the positional_ids because during the forward pass it will be added anyways, so we must cancel them out here.
  1074. conditioning_embeds = model_kwargs.get("conditioning_embeds")
  1075. if conditioning_embeds is not None:
  1076. mel_start_token_embedding = self.model.decoder.input_embeds_layer(
  1077. torch.full(
  1078. (conditioning_embeds.shape[0], 1),
  1079. fill_value=self.config.bos_token_id,
  1080. device=conditioning_embeds.device,
  1081. )
  1082. )
  1083. mel_start_token_embedding += self.model.decoder.position_embeds_layer(
  1084. torch.full((conditioning_embeds.shape[0], 1), fill_value=0, device=conditioning_embeds.device)
  1085. )
  1086. conditioning_embeds = torch.concat([conditioning_embeds, mel_start_token_embedding], dim=1)
  1087. # subtract the positional_ids here
  1088. if hasattr(model_kwargs, "attention_mask"):
  1089. position_ids = model_kwargs["attention_mask"].long().cumsum(-1) - 1
  1090. else:
  1091. position_ids = torch.arange(
  1092. 0, conditioning_embeds.shape[1], dtype=torch.long, device=conditioning_embeds.device
  1093. )
  1094. position_ids = position_ids.unsqueeze(0).repeat(conditioning_embeds.shape[0], 1)
  1095. model_kwargs["inputs_embeds"] = conditioning_embeds - self.model.decoder.position_embeds_layer(
  1096. position_ids
  1097. )
  1098. model_kwargs["input_ids"] = (
  1099. torch.ones((model_kwargs["inputs_embeds"].shape[0], 1), dtype=torch.long, device=self.device)
  1100. * self.config.bos_token_id
  1101. )
  1102. return model_kwargs["inputs_embeds"], "inputs_embeds", model_kwargs
  1103. inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
  1104. return inputs, input_name, model_kwargs
  1105. def prepare_inputs_for_generation(
  1106. self,
  1107. input_ids,
  1108. past_key_values=None,
  1109. inputs_embeds=None,
  1110. conditioning_embeds=None,
  1111. cache_position=None,
  1112. **kwargs,
  1113. ):
  1114. # Overwritten: has `conditioning_embeds`-related logic
  1115. input_ids_length = input_ids.shape[-1]
  1116. model_inputs = super().prepare_inputs_for_generation(
  1117. input_ids,
  1118. past_key_values=past_key_values,
  1119. inputs_embeds=inputs_embeds,
  1120. cache_position=cache_position,
  1121. **kwargs,
  1122. )
  1123. if conditioning_embeds is not None and cache_position[0] != 0:
  1124. model_inputs["position_ids"] = torch.tensor([input_ids_length], dtype=torch.long, device=input_ids.device)
  1125. return model_inputs
  1126. @auto_docstring
  1127. def forward(
  1128. self,
  1129. input_ids: Optional[torch.LongTensor] = None,
  1130. past_key_values: Optional[Cache] = None,
  1131. attention_mask: Optional[torch.FloatTensor] = None,
  1132. token_type_ids: Optional[torch.LongTensor] = None,
  1133. position_ids: Optional[torch.LongTensor] = None,
  1134. head_mask: Optional[torch.FloatTensor] = None,
  1135. inputs_embeds: Optional[torch.FloatTensor] = None,
  1136. labels: Optional[torch.LongTensor] = None,
  1137. use_cache: Optional[bool] = None,
  1138. output_attentions: Optional[bool] = None,
  1139. output_hidden_states: Optional[bool] = None,
  1140. return_dict: Optional[bool] = None,
  1141. cache_position: Optional[torch.Tensor] = None,
  1142. ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
  1143. r"""
  1144. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1145. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  1146. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  1147. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  1148. """
  1149. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1150. output_hidden_states = (
  1151. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1152. )
  1153. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1154. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1155. outputs = self.model(
  1156. input_ids=input_ids,
  1157. past_key_values=past_key_values,
  1158. attention_mask=attention_mask,
  1159. token_type_ids=token_type_ids,
  1160. position_ids=position_ids,
  1161. head_mask=head_mask,
  1162. inputs_embeds=inputs_embeds,
  1163. use_cache=use_cache,
  1164. output_attentions=output_attentions,
  1165. output_hidden_states=output_hidden_states,
  1166. return_dict=return_dict,
  1167. cache_position=cache_position,
  1168. )
  1169. hidden_states = outputs[0]
  1170. lm_logits = self.final_norm(hidden_states)
  1171. lm_logits = self.lm_head(lm_logits)
  1172. loss = None
  1173. if labels is not None:
  1174. labels = labels.to(lm_logits.device)
  1175. # Shift so that tokens < n predict n
  1176. shift_logits = lm_logits[..., :-1, :].contiguous()
  1177. shift_labels = labels[..., 1:].contiguous()
  1178. # Flatten the tokens
  1179. loss_fct = CrossEntropyLoss()
  1180. loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
  1181. if not return_dict:
  1182. output = (lm_logits,) + outputs[1:]
  1183. return ((loss,) + output) if loss is not None else output
  1184. return CausalLMOutputWithCrossAttentions(
  1185. loss=loss,
  1186. logits=lm_logits,
  1187. past_key_values=outputs.past_key_values,
  1188. hidden_states=outputs.hidden_states,
  1189. attentions=outputs.attentions,
  1190. cross_attentions=outputs.cross_attentions,
  1191. )
  1192. @auto_docstring(
  1193. custom_intro="""
  1194. The composite CLVP model with a text encoder, speech encoder and speech decoder model.
  1195. """
  1196. )
  1197. class ClvpModelForConditionalGeneration(ClvpPreTrainedModel, GenerationMixin):
  1198. config: ClvpConfig
  1199. def __init__(self, config: ClvpConfig):
  1200. super().__init__(config)
  1201. if not isinstance(config.text_config, ClvpEncoderConfig):
  1202. raise TypeError(
  1203. "config.text_config is expected to be of type `ClvpEncoderConfig` but is of type"
  1204. f" {type(config.text_config)}."
  1205. )
  1206. if not isinstance(config.speech_config, ClvpEncoderConfig):
  1207. raise TypeError(
  1208. "config.speech_config is expected to be of type `ClvpEncoderConfig` but is of type"
  1209. f" {type(config.speech_config)}."
  1210. )
  1211. if not isinstance(config.decoder_config, ClvpDecoderConfig):
  1212. raise TypeError(
  1213. "config.decoder_config is expected to be of type `ClvpDecoderConfig` but is of type"
  1214. f" {type(config.decoder_config)}."
  1215. )
  1216. self.conditioning_encoder = ClvpConditioningEncoder(config)
  1217. self.speech_decoder_model = ClvpForCausalLM(config.decoder_config)
  1218. self.text_encoder_model = ClvpEncoder(config.text_config)
  1219. self.speech_encoder_model = ClvpEncoder(config.speech_config)
  1220. self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
  1221. # Initialize weights and apply final processing
  1222. self.post_init()
  1223. # taken from the original repo,
  1224. # link : https://github.com/neonbjb/tortoise-tts/blob/4003544b6ff4b68c09856e04d3eff9da26d023c2/tortoise/api.py#L117
  1225. def fix_speech_decoder_output(self, speech_ids: torch.LongTensor) -> torch.LongTensor:
  1226. """
  1227. This method modifies the output of the decoder model, such as replacing the `eos_token_id` and changing the
  1228. last few tokens of each sequence.
  1229. Args:
  1230. speech_ids (`torch.LongTensor`):
  1231. This refers to the output of the decoder model.
  1232. """
  1233. decoder_fixing_codes = self.config.decoder_config.decoder_fixing_codes
  1234. speech_ids = speech_ids[:, 1:]
  1235. stop_token_indices = torch.where(speech_ids == self.speech_decoder_model.config.eos_token_id, 1, 0)
  1236. speech_ids = torch.masked_fill(speech_ids, mask=stop_token_indices.bool(), value=decoder_fixing_codes[0])
  1237. for i, each_seq_stop_token_index in enumerate(stop_token_indices):
  1238. # This means that no stop tokens were found so the sentence was still being generated, in that case we don't need
  1239. # to apply any padding so just skip to the next sequence of tokens.
  1240. if each_seq_stop_token_index.sum() == 0:
  1241. continue
  1242. stm = each_seq_stop_token_index.argmax()
  1243. speech_ids[i, stm:] = decoder_fixing_codes[0]
  1244. if stm - 3 < speech_ids.shape[1]:
  1245. speech_ids[i, -3:] = torch.tensor(
  1246. [decoder_fixing_codes[1:]], device=speech_ids.device, dtype=torch.long
  1247. )
  1248. return speech_ids
  1249. def get_text_features(
  1250. self,
  1251. input_ids: Optional[torch.LongTensor] = None,
  1252. text_encoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1253. attention_mask: Optional[torch.LongTensor] = None,
  1254. ) -> torch.FloatTensor:
  1255. r"""
  1256. This method can be used to extract text_embeds from a text. The text embeddings obtained by applying the
  1257. projection layer to the pooled output of the CLVP text encoder model.
  1258. Args:
  1259. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1260. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  1261. provide it.
  1262. [What are input IDs?](../glossary#input-ids)
  1263. text_encoder_inputs_embeds (`torch.FloatTensor`, *optional*):
  1264. inputs_embeds for the text encoder model passed in place of `input_ids`.
  1265. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1266. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1267. - 1 for tokens that are **not masked**,
  1268. - 0 for tokens that are **masked**.
  1269. [What are attention masks?](../glossary#attention-mask)
  1270. Returns:
  1271. `torch.FloatTensor` of shape `(batch_size, output_dim)`:
  1272. The text embeddings obtained by applying the projection layer to the pooled output of the CLVP Text
  1273. Model.
  1274. Examples:
  1275. ```python
  1276. >>> from transformers import ClvpProcessor, ClvpModelForConditionalGeneration
  1277. >>> # Define the Text
  1278. >>> text = "This is an example text."
  1279. >>> # Define processor and model
  1280. >>> processor = ClvpProcessor.from_pretrained("susnato/clvp_dev")
  1281. >>> model = ClvpModelForConditionalGeneration.from_pretrained("susnato/clvp_dev")
  1282. >>> # Generate processor output and text embeds
  1283. >>> processor_output = processor(text=text, return_tensors="pt")
  1284. >>> text_embeds = model.get_text_features(input_ids=processor_output["input_ids"])
  1285. ```
  1286. """
  1287. outputs = self.text_encoder_model(
  1288. input_ids=input_ids,
  1289. inputs_embeds=text_encoder_inputs_embeds,
  1290. attention_mask=attention_mask,
  1291. )
  1292. return outputs[0]
  1293. def get_speech_features(
  1294. self,
  1295. speech_ids: Optional[torch.LongTensor] = None,
  1296. input_ids: Optional[torch.LongTensor] = None,
  1297. input_features: Optional[torch.FloatTensor] = None,
  1298. conditioning_encoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1299. attention_mask: Optional[torch.Tensor] = None,
  1300. generation_config: Optional[GenerationConfig] = None,
  1301. **kwargs,
  1302. ) -> torch.FloatTensor:
  1303. r"""
  1304. This method can be used to extract speech_embeds. The speech embeddings are obtained by applying the speech
  1305. model on speech_ids. If speech_ids is not present but both input_ids and input_features are given then the
  1306. decoder model will be used to first generate the speech_ids and then applying the speech model.
  1307. Args:
  1308. speech_ids (`torch.LongTensor` of shape `(batch_size, num_speech_ids)`, *optional*):
  1309. Speech Tokens. Padding will be ignored by default should you provide it. If speech_ids are provided
  1310. then input_ids and input_features will be automatically ignored.
  1311. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1312. Input text Tokens. Processed from the [`ClvpTokenizer`]. If speech_ids is not provided, then input_ids
  1313. and input_features will be used.
  1314. conditioning_encoder_inputs_embeds (`torch.FloatTensor`, *optional*):
  1315. inputs_embeds for `ClvpConditioningEncoder`. Can be used in place of `input_ids`.
  1316. attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1317. Mask to avoid performing attention on padding speech token indices. Mask values selected in `[0, 1]`:
  1318. - 1 for tokens that are **not masked**,
  1319. - 0 for tokens that are **masked**.
  1320. [What are attention masks?](../glossary#attention-mask)
  1321. generation_config (`GenerationConfig`, *optional*):
  1322. generation config to control the generation of speech_ids if they are not provided.
  1323. Returns:
  1324. `torch.FloatTensor` of shape `(batch_size, output_dim)`:
  1325. The speech embeddings obtained by applying the projection layer to the pooled output of the CLVP Speech
  1326. Model.
  1327. Examples:
  1328. ```python
  1329. >>> import datasets
  1330. >>> from transformers import ClvpProcessor, ClvpModelForConditionalGeneration
  1331. >>> # Define the Text and Load the Audio (We are taking an audio example from HuggingFace Hub using `datasets` library)
  1332. >>> text = "This is an example text."
  1333. >>> ds = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  1334. >>> ds = ds.cast_column("audio", datasets.Audio(sampling_rate=22050))
  1335. >>> audio = ds.sort("id")["audio"][0]
  1336. >>> audio_sample, sr = audio["array"], audio["sampling_rate"]
  1337. >>> # Define processor and model
  1338. >>> processor = ClvpProcessor.from_pretrained("susnato/clvp_dev")
  1339. >>> model = ClvpModelForConditionalGeneration.from_pretrained("susnato/clvp_dev")
  1340. >>> # Generate processor output and model output
  1341. >>> processor_output = processor(raw_speech=audio_sample, sampling_rate=sr, text=text, return_tensors="pt")
  1342. >>> speech_embeds = model.get_speech_features(
  1343. ... input_ids=processor_output["input_ids"], input_features=processor_output["input_features"]
  1344. ... )
  1345. ```
  1346. """
  1347. if speech_ids is None:
  1348. if (input_ids is None and conditioning_encoder_inputs_embeds is None) or input_features is None:
  1349. raise ValueError(
  1350. "Either speech_ids or input_ids/conditioning_encoder_inputs_embeds and input_features must be provided."
  1351. )
  1352. if generation_config is None:
  1353. generation_config = self.generation_config
  1354. generation_config.update(**kwargs)
  1355. conditioning_embeds = self.conditioning_encoder(
  1356. input_features=input_features,
  1357. input_ids=input_ids,
  1358. inputs_embeds=conditioning_encoder_inputs_embeds,
  1359. attention_mask=attention_mask,
  1360. )
  1361. speech_ids = self.speech_decoder_model.generate(
  1362. conditioning_embeds=conditioning_embeds,
  1363. generation_config=generation_config,
  1364. )
  1365. speech_ids = self.fix_speech_decoder_output(speech_ids[0])
  1366. outputs = self.speech_encoder_model(
  1367. input_ids=speech_ids,
  1368. attention_mask=attention_mask,
  1369. )
  1370. return outputs[0]
  1371. @auto_docstring
  1372. def forward(
  1373. self,
  1374. input_ids: Optional[torch.LongTensor] = None,
  1375. input_features: Optional[torch.FloatTensor] = None,
  1376. conditioning_encoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1377. text_encoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1378. attention_mask: Optional[torch.LongTensor] = None,
  1379. return_loss: Optional[bool] = None,
  1380. output_hidden_states: Optional[bool] = None,
  1381. output_attentions: Optional[bool] = False,
  1382. return_dict: Optional[bool] = None,
  1383. cache_position: Optional[torch.Tensor] = None,
  1384. ) -> Union[tuple, ClvpOutput]:
  1385. r"""
  1386. conditioning_encoder_inputs_embeds (`torch.FloatTensor`, *optional*):
  1387. inputs_embeds for `ClvpConditioningEncoder`. Can be used in place of `input_ids`.
  1388. text_encoder_inputs_embeds (`torch.FloatTensor`, *optional*):
  1389. inputs_embeds for the text encoder model passed in place of `input_ids`.
  1390. return_loss (`bool`, *optional*):
  1391. Whether or not to return the contrastive loss.
  1392. Examples:
  1393. ```python
  1394. >>> import datasets
  1395. >>> from transformers import ClvpProcessor, ClvpModelForConditionalGeneration
  1396. >>> # Define the Text and Load the Audio (We are taking an audio example from HuggingFace Hub using `datasets` library)
  1397. >>> text = "This is an example text."
  1398. >>> ds = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  1399. >>> ds = ds.cast_column("audio", datasets.Audio(sampling_rate=22050))
  1400. >>> audio = ds.sort("id")["audio"][0]
  1401. >>> audio_sample, sr = audio["array"], audio["sampling_rate"]
  1402. >>> # Define processor and model
  1403. >>> processor = ClvpProcessor.from_pretrained("susnato/clvp_dev")
  1404. >>> model = ClvpModelForConditionalGeneration.from_pretrained("susnato/clvp_dev")
  1405. >>> # processor outputs and model outputs
  1406. >>> processor_output = processor(raw_speech=audio_sample, sampling_rate=sr, text=text, return_tensors="pt")
  1407. >>> outputs = model(
  1408. ... input_ids=processor_output["input_ids"],
  1409. ... input_features=processor_output["input_features"],
  1410. ... return_dict=True,
  1411. ... )
  1412. ```
  1413. """
  1414. # Use CLVP model's config for some fields (if specified) instead of those of speech & text components.
  1415. output_hidden_states = (
  1416. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1417. )
  1418. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1419. conditioning_embeds = self.conditioning_encoder(
  1420. input_features=input_features,
  1421. input_ids=input_ids,
  1422. inputs_embeds=conditioning_encoder_inputs_embeds,
  1423. attention_mask=attention_mask,
  1424. )
  1425. decoder_outputs = self.speech_decoder_model(
  1426. inputs_embeds=conditioning_embeds,
  1427. output_hidden_states=output_hidden_states,
  1428. return_dict=return_dict,
  1429. cache_position=cache_position,
  1430. )
  1431. speech_ids = decoder_outputs[0]
  1432. # since we will get the embeds of shape `(batch_size, seq_len, embedding_dim)` during the forward pass
  1433. # we must convert it to tokens, to make it compaitable with speech_transformer
  1434. if speech_ids.ndim == 3:
  1435. speech_ids = speech_ids.argmax(2)
  1436. speech_ids = self.fix_speech_decoder_output(speech_ids)
  1437. speech_outputs = self.speech_encoder_model(
  1438. input_ids=speech_ids,
  1439. output_hidden_states=output_hidden_states,
  1440. return_dict=return_dict,
  1441. )
  1442. text_outputs = self.text_encoder_model(
  1443. input_ids=input_ids,
  1444. inputs_embeds=text_encoder_inputs_embeds,
  1445. attention_mask=attention_mask,
  1446. output_hidden_states=output_hidden_states,
  1447. return_dict=return_dict,
  1448. )
  1449. speech_embeds = speech_outputs[0]
  1450. text_embeds = text_outputs[0]
  1451. # normalized features
  1452. speech_embeds = speech_embeds / speech_embeds.norm(p=2, dim=-1, keepdim=True)
  1453. text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
  1454. # cosine similarity as logits
  1455. logit_scale = self.logit_scale.exp()
  1456. logits_per_text = torch.matmul(text_embeds, speech_embeds.t()) * logit_scale
  1457. logits_per_speech = logits_per_text.t()
  1458. loss = None
  1459. if return_loss:
  1460. loss = clvp_loss(logits_per_text)
  1461. if not return_dict:
  1462. output = (
  1463. logits_per_speech,
  1464. logits_per_text,
  1465. text_embeds,
  1466. speech_embeds,
  1467. text_outputs[2],
  1468. speech_outputs[2],
  1469. )
  1470. if output_hidden_states:
  1471. output += (
  1472. decoder_outputs[-1],
  1473. text_outputs[-1],
  1474. speech_outputs[-1],
  1475. )
  1476. return ((loss,) + output) if loss is not None else output
  1477. return ClvpOutput(
  1478. loss=loss,
  1479. logits_per_speech=logits_per_speech,
  1480. logits_per_text=logits_per_text,
  1481. text_embeds=text_embeds,
  1482. speech_embeds=speech_embeds,
  1483. text_model_output=text_outputs[2],
  1484. speech_model_output=speech_outputs[2],
  1485. decoder_hidden_states=decoder_outputs.hidden_states,
  1486. text_encoder_hidden_states=text_outputs.hidden_states,
  1487. speech_encoder_hidden_states=speech_outputs.hidden_states,
  1488. )
  1489. @torch.no_grad()
  1490. def generate(
  1491. self,
  1492. input_ids: Optional[torch.LongTensor] = None,
  1493. input_features: Optional[torch.FloatTensor] = None,
  1494. attention_mask: Optional[torch.LongTensor] = None,
  1495. generation_config: Optional[GenerationConfig] = None,
  1496. pad_to_max_mel_tokens: Optional[int] = None,
  1497. output_hidden_states: Optional[bool] = None,
  1498. **kwargs,
  1499. ):
  1500. """
  1501. Generate method for `ClvpModelForConditionalGeneration`, this method calls the `generate` method of
  1502. `ClvpForCausalLM` and then uses those generated `speech_ids` to process `text_embeds` and `speech_embeds` using
  1503. `ClvpEncoder`.
  1504. Args:
  1505. input_ids (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1506. Input text Tokens. Processed from the [`ClvpTokenizer`].
  1507. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1508. Mask to avoid performing attention on padding text token indices. Mask values selected in `[0, 1]`:
  1509. - 1 for tokens that are **not masked**,
  1510. - 0 for tokens that are **masked**.
  1511. [What are attention masks?](../glossary#attention-mask)
  1512. generation_config (`~generation.GenerationConfig`, *optional*):
  1513. The generation configuration to be used as base parametrization for the generation call. `**kwargs`
  1514. passed to generate matching the attributes of `generation_config` will override them. If
  1515. `generation_config` is not provided, the default will be used, which had the following loading
  1516. priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
  1517. configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
  1518. default values, whose documentation should be checked to parameterize generation.
  1519. pad_to_max_mel_tokens (`int`, *optional*):
  1520. Pads generated speech_ids to the specified value. This is to implement the same logic from the official
  1521. repo, link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L430
  1522. and to make sure the logits are same.
  1523. This does not affect generation quality so please don't consider using it since it is less efficient.
  1524. output_hidden_states (`bool`, *optional*):
  1525. Whether or not to return the hidden states of decoder model, text encoder and speech encoder models.
  1526. Returns:
  1527. `ClvpOutput` or tuple: A `ClvpOutput` (if `return_dict_in_generate=True` or when
  1528. `config.return_dict_in_generate=True`) or a tuple.
  1529. """
  1530. # If the input sequences are larger than (self.config.decoder_config.max_text_tokens - 3) then raise error,
  1531. # because we need to add 3 tokens ( 1 bos tokens and 2 eos tokens) to the input_ids in ClvpConditioningEncoder to
  1532. # properly sample
  1533. sequence_length = input_ids.shape[-1]
  1534. if sequence_length > (self.config.decoder_config.max_text_tokens - 3):
  1535. raise ValueError(
  1536. f"Maximum sequence length reached! Found input_ids of length {sequence_length}."
  1537. f"Please make sure that the maximum length of input_ids is {self.config.decoder_config.max_text_tokens - 3}"
  1538. )
  1539. if generation_config is None:
  1540. generation_config = self.generation_config
  1541. generation_config = copy.deepcopy(generation_config)
  1542. model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
  1543. generation_config.validate()
  1544. self._validate_model_kwargs(model_kwargs.copy())
  1545. # pad input_ids as specified in the original repo
  1546. # link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L380
  1547. input_ids, attention_mask = _pad_extra_bos_eos_tokens(
  1548. input_ids,
  1549. attention_mask,
  1550. add_bos_token=False,
  1551. bos_token_id=self.config.text_config.bos_token_id,
  1552. eos_token_id=self.config.text_config.eos_token_id,
  1553. )
  1554. conditioning_embeds = self.conditioning_encoder(
  1555. input_features=input_features,
  1556. input_ids=input_ids,
  1557. attention_mask=attention_mask,
  1558. )
  1559. decoder_outputs = self.speech_decoder_model.generate(
  1560. conditioning_embeds=conditioning_embeds,
  1561. generation_config=generation_config,
  1562. output_hidden_states=output_hidden_states,
  1563. return_dict=generation_config.return_dict_in_generate,
  1564. )
  1565. if isinstance(decoder_outputs, ModelOutput):
  1566. speech_ids = decoder_outputs.sequences
  1567. # pad to pad_to_max_mel_tokens if given, to replicate the original repo logic
  1568. # link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L430
  1569. if pad_to_max_mel_tokens is not None:
  1570. padding_needed = pad_to_max_mel_tokens - speech_ids.shape[-1]
  1571. speech_ids = torch.nn.functional.pad(
  1572. speech_ids, (0, padding_needed), value=self.generation_config.eos_token_id
  1573. )
  1574. speech_ids = self.fix_speech_decoder_output(speech_ids)
  1575. speech_outputs = self.speech_encoder_model(
  1576. input_ids=speech_ids,
  1577. output_hidden_states=output_hidden_states,
  1578. return_dict=generation_config.return_dict_in_generate,
  1579. )
  1580. text_outputs = self.text_encoder_model(
  1581. input_ids=input_ids,
  1582. attention_mask=attention_mask,
  1583. output_hidden_states=output_hidden_states,
  1584. return_dict=generation_config.return_dict_in_generate,
  1585. )
  1586. speech_embeds = speech_outputs[0]
  1587. text_embeds = text_outputs[0]
  1588. # normalized features
  1589. speech_embeds = speech_embeds / speech_embeds.norm(p=2, dim=-1, keepdim=True)
  1590. text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
  1591. # cosine similarity as logits
  1592. logit_scale = self.logit_scale.exp()
  1593. logits_per_text = torch.matmul(text_embeds, speech_embeds.t()) * logit_scale
  1594. logits_per_speech = logits_per_text.t()
  1595. if not generation_config.return_dict_in_generate:
  1596. output = (
  1597. speech_ids,
  1598. logits_per_speech,
  1599. logits_per_text,
  1600. text_embeds,
  1601. speech_embeds,
  1602. text_outputs[2],
  1603. speech_outputs[2],
  1604. )
  1605. if output_hidden_states:
  1606. output += (
  1607. decoder_outputs[-1],
  1608. text_outputs[-1],
  1609. speech_outputs[-1],
  1610. )
  1611. return output
  1612. return ClvpOutput(
  1613. speech_ids=speech_ids,
  1614. logits_per_speech=logits_per_speech,
  1615. logits_per_text=logits_per_text,
  1616. text_embeds=text_embeds,
  1617. speech_embeds=speech_embeds,
  1618. text_model_output=text_outputs[2],
  1619. speech_model_output=speech_outputs[2],
  1620. decoder_hidden_states=decoder_outputs.hidden_states,
  1621. text_encoder_hidden_states=text_outputs.hidden_states,
  1622. speech_encoder_hidden_states=speech_outputs.hidden_states,
  1623. )
# Public names exported by this module; this is the set `from <module> import *` binds.
__all__ = [
    "ClvpModelForConditionalGeneration",
    "ClvpForCausalLM",
    "ClvpModel",
    "ClvpPreTrainedModel",
    "ClvpEncoder",
    "ClvpDecoder",
]