modeling_bark.py 70 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628
  1. # coding=utf-8
  2. # Copyright 2023 The Suno AI Authors and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch BARK model."""
  16. import math
  17. import warnings
  18. from typing import Optional, Union
  19. import numpy as np
  20. import torch
  21. from torch import nn
  22. from torch.nn import functional as F
  23. from ...cache_utils import Cache, DynamicCache
  24. from ...generation import GenerationMixin
  25. from ...generation.logits_process import (
  26. AlternatingCodebooksLogitsProcessor,
  27. BarkEosPrioritizerLogitsProcessor,
  28. SuppressTokensLogitsProcessor,
  29. )
  30. from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
  31. from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
  32. from ...modeling_layers import GradientCheckpointingLayer
  33. from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
  34. from ...modeling_utils import PreTrainedModel, get_parameter_device
  35. from ...utils import (
  36. auto_docstring,
  37. is_accelerate_available,
  38. is_torch_accelerator_available,
  39. logging,
  40. )
  41. from ..auto import AutoModel
  42. from .configuration_bark import (
  43. BarkCoarseConfig,
  44. BarkConfig,
  45. BarkFineConfig,
  46. BarkSemanticConfig,
  47. BarkSubModelConfig,
  48. )
  49. from .generation_configuration_bark import (
  50. BarkCoarseGenerationConfig,
  51. BarkFineGenerationConfig,
  52. BarkSemanticGenerationConfig,
  53. )
# `_flash_attention_forward` is imported only when flash-attn is reported as available;
# it is referenced solely by `BarkSelfFlashAttention2.forward`.
if is_flash_attn_available():
    from ...modeling_flash_attention_utils import _flash_attention_forward


# Module-level logger shared by every class in this file.
logger = logging.get_logger(__name__)
  57. class BarkSelfAttention(nn.Module):
  58. # adapted from GPTNeoSelfAttention and Bark code
  59. # BarkSelfAttention can have two attention type, i.e full attention or causal attention
  60. def __init__(self, config, is_causal=False, layer_idx=None):
  61. super().__init__()
  62. # regularization
  63. self.dropout = config.dropout
  64. self.attn_dropout = nn.Dropout(config.dropout)
  65. self.resid_dropout = nn.Dropout(config.dropout)
  66. self.embed_dim = config.hidden_size
  67. self.num_heads = config.num_heads
  68. self.head_dim = self.embed_dim // self.num_heads
  69. if config.hidden_size % config.num_heads != 0:
  70. raise ValueError(
  71. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  72. f" {self.num_heads})."
  73. )
  74. # key, query, value projections for all heads, but in a batch
  75. self.att_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.bias)
  76. # output projection
  77. self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.bias)
  78. self.is_causal = is_causal
  79. self.layer_idx = layer_idx
  80. if is_causal:
  81. block_size = config.block_size
  82. bias = torch.tril(torch.ones((block_size, block_size), dtype=bool)).view(1, 1, block_size, block_size)
  83. self.register_buffer("bias", bias)
  84. # Copied from transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoSelfAttention._split_heads
  85. def _split_heads(self, tensor, num_heads, attn_head_size):
  86. """
  87. Splits hidden_size dim into attn_head_size and num_heads
  88. """
  89. new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
  90. tensor = tensor.view(new_shape)
  91. return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
  92. def _merge_heads(self, tensor, num_heads, attn_head_size):
  93. """
  94. Merges attn_head_size dim and num_attn_heads dim into hidden_size
  95. """
  96. # re-assemble all head outputs side by side
  97. # (batch, num_heads, seq_len, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size)
  98. tensor = tensor.transpose(1, 2).contiguous()
  99. tensor = tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,))
  100. return tensor
  101. def _attn(self, query, key, value, attention_mask=None, head_mask=None):
  102. # unlike GPTNeo's SelfAttention, divide by the square root of the dimension of the query and the key
  103. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * (1.0 / math.sqrt(self.head_dim))
  104. if self.is_causal:
  105. query_length, key_length = query.size(-2), key.size(-2)
  106. # fill the upper left part of the attention weights with inf
  107. attn_weights = attn_weights.masked_fill(
  108. self.bias[:, :, key_length - query_length : key_length, :key_length] == 0,
  109. torch.finfo(attn_weights.dtype).min,
  110. )
  111. if attention_mask is not None:
  112. # Apply the attention mask
  113. attn_weights = attn_weights + attention_mask
  114. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  115. attn_weights = attn_weights.to(value.dtype)
  116. attn_weights = self.attn_dropout(attn_weights)
  117. # Mask heads if we want to
  118. if head_mask is not None:
  119. attn_weights = attn_weights * head_mask
  120. # (batch, num_heads, seq_len, seq_len) x (batch, num_heads, seq_len, attn_head_size)
  121. # -> (batch, num_heads, seq_len, attn_head_size)
  122. attn_output = torch.matmul(attn_weights, value)
  123. return attn_output, attn_weights
  124. def forward(
  125. self,
  126. hidden_states,
  127. attention_mask=None,
  128. past_key_values=None,
  129. head_mask=None,
  130. use_cache=False,
  131. output_attentions=False,
  132. cache_position=None,
  133. ):
  134. # calculate query, key, values for all heads in batch and move head forward to be the batch dim
  135. query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2)
  136. query = self._split_heads(query, self.num_heads, self.head_dim)
  137. key = self._split_heads(key, self.num_heads, self.head_dim)
  138. value = self._split_heads(value, self.num_heads, self.head_dim)
  139. if past_key_values is not None:
  140. key, value = past_key_values.update(key, value, self.layer_idx, {"cache_position": cache_position})
  141. attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
  142. attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
  143. attn_output = self.out_proj(attn_output)
  144. attn_output = self.resid_dropout(attn_output)
  145. return attn_output, attn_weights
  146. class BarkSelfFlashAttention2(BarkSelfAttention):
  147. """
  148. Bark flash attention module. This module inherits from `BarkSelfAttention` as the weights of the module stays
  149. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  150. flash attention and deal with padding tokens in case the input contains any of them.
  151. """
  152. def __init__(self, *args, **kwargs):
  153. super().__init__(*args, **kwargs)
  154. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
  155. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  156. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  157. self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
  158. def _split_heads(self, tensor, num_heads, attn_head_size):
  159. """
  160. Splits hidden_size dim into attn_head_size and num_heads
  161. """
  162. new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
  163. tensor = tensor.view(new_shape)
  164. # Flash attention requires the input to have the shape
  165. # batch_size x seq_length x head_dim x hidden_dim - (batch, seq_length, head, head_features)
  166. return tensor
  167. def _merge_heads(self, tensor, num_heads, attn_head_size):
  168. """
  169. Merges attn_head_size dim and num_attn_heads dim into hidden_size
  170. """
  171. # re-assemble all head outputs side by side
  172. # (batch, seq_len, num_heads, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size)
  173. tensor = tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,))
  174. return tensor
  175. def forward(
  176. self,
  177. hidden_states,
  178. attention_mask=None,
  179. past_key_values=None,
  180. head_mask=None,
  181. use_cache=False,
  182. output_attentions=False,
  183. cache_position=None,
  184. ):
  185. batch_size, query_len, _ = hidden_states.size()
  186. # calculate query, key, values for all heads in batch and move head forward to be the batch dim
  187. query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2)
  188. query = self._split_heads(query, self.num_heads, self.head_dim)
  189. key = self._split_heads(key, self.num_heads, self.head_dim)
  190. value = self._split_heads(value, self.num_heads, self.head_dim)
  191. if past_key_values is not None:
  192. key, value = past_key_values.update(key, value, self.layer_idx, {"cache_position": cache_position})
  193. attn_output = _flash_attention_forward(
  194. query,
  195. key,
  196. value,
  197. attention_mask,
  198. query_len,
  199. dropout=self.dropout if self.training else 0.0,
  200. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  201. is_causal=self.is_causal,
  202. )
  203. attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
  204. attn_output = self.out_proj(attn_output)
  205. attn_output = self.resid_dropout(attn_output)
  206. return attn_output, None
  207. BARK_ATTENTION_CLASSES = {
  208. "eager": BarkSelfAttention,
  209. "flash_attention_2": BarkSelfFlashAttention2,
  210. }
  211. class BarkMLP(nn.Module):
  212. def __init__(self, config):
  213. super().__init__()
  214. self.in_proj = nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=config.bias)
  215. self.out_proj = nn.Linear(4 * config.hidden_size, config.hidden_size, bias=config.bias)
  216. self.dropout = nn.Dropout(config.dropout)
  217. self.gelu = nn.GELU()
  218. def forward(self, hidden_states):
  219. hidden_states = self.in_proj(hidden_states)
  220. hidden_states = self.gelu(hidden_states)
  221. hidden_states = self.out_proj(hidden_states)
  222. hidden_states = self.dropout(hidden_states)
  223. return hidden_states
  224. class BarkBlock(GradientCheckpointingLayer):
  225. def __init__(self, config, is_causal=False, layer_idx=None):
  226. super().__init__()
  227. if is_causal:
  228. # if causal, the layerNorm bias is optional to stick with Bark choice of leaving optional bias
  229. # in AutoRegressive models (corresponding to the "Text" and the "Coarse" modules)
  230. self.layernorm_1 = nn.LayerNorm(config.hidden_size, bias=config.bias)
  231. self.layernorm_2 = nn.LayerNorm(config.hidden_size, bias=config.bias)
  232. else:
  233. self.layernorm_1 = nn.LayerNorm(config.hidden_size)
  234. self.layernorm_2 = nn.LayerNorm(config.hidden_size)
  235. self.attn = BARK_ATTENTION_CLASSES[config._attn_implementation](
  236. config, is_causal=is_causal, layer_idx=layer_idx
  237. )
  238. self.mlp = BarkMLP(config)
  239. def forward(
  240. self,
  241. hidden_states,
  242. past_key_values=None,
  243. attention_mask=None,
  244. head_mask=None,
  245. use_cache=False,
  246. output_attentions=False,
  247. cache_position=None,
  248. ):
  249. intermediary_hidden_states = self.layernorm_1(hidden_states)
  250. attn_outputs = self.attn(
  251. intermediary_hidden_states,
  252. past_key_values=past_key_values,
  253. attention_mask=attention_mask,
  254. head_mask=head_mask,
  255. use_cache=use_cache,
  256. output_attentions=output_attentions,
  257. cache_position=cache_position,
  258. )
  259. attn_output = attn_outputs[0] # output_attn: output, present_key_values, (attn_weights)
  260. outputs = attn_outputs[1:]
  261. intermediary_hidden_states = hidden_states + attn_output
  262. intermediary_hidden_states = intermediary_hidden_states + self.mlp(
  263. self.layernorm_2(intermediary_hidden_states)
  264. )
  265. return (intermediary_hidden_states,) + outputs
  266. @auto_docstring
  267. class BarkPreTrainedModel(PreTrainedModel):
  268. config: BarkConfig
  269. supports_gradient_checkpointing = False
  270. _supports_flash_attn = True
  271. def _init_weights(self, module):
  272. """Initialize the weights."""
  273. if isinstance(module, (nn.Linear,)):
  274. # Slightly different from the TF version which uses truncated_normal for initialization
  275. # cf https://github.com/pytorch/pytorch/pull/5617
  276. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  277. if module.bias is not None:
  278. module.bias.data.zero_()
  279. elif isinstance(module, nn.Embedding):
  280. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  281. if module.padding_idx is not None:
  282. module.weight.data[module.padding_idx].zero_()
  283. elif isinstance(module, nn.LayerNorm):
  284. module.bias.data.zero_()
  285. module.weight.data.fill_(1.0)
  286. def __init__(self, *inputs, **kwargs):
  287. super().__init__(*inputs, **kwargs)
  288. @property
  289. def device(self) -> torch.device:
  290. """
  291. `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
  292. device).
  293. """
  294. # if has _hf_hook, has been offloaded so the device has to be found in the hook
  295. if not hasattr(self, "_hf_hook"):
  296. return get_parameter_device(self)
  297. for module in self.modules():
  298. if (
  299. hasattr(module, "_hf_hook")
  300. and hasattr(module._hf_hook, "execution_device")
  301. and module._hf_hook.execution_device is not None
  302. ):
  303. return torch.device(module._hf_hook.execution_device)
  304. return get_parameter_device(self)
  305. # GPT2-like autoregressive model
  306. class BarkCausalModel(BarkPreTrainedModel, GenerationMixin):
  307. config: BarkSubModelConfig
  308. def __init__(self, config):
  309. super().__init__(config)
  310. self.config = config
  311. # initialize as an autoregressive GPT-like model
  312. self.input_embeds_layer = nn.Embedding(config.input_vocab_size, config.hidden_size)
  313. self.position_embeds_layer = nn.Embedding(config.block_size, config.hidden_size)
  314. self.drop = nn.Dropout(config.dropout)
  315. self.layers = nn.ModuleList([BarkBlock(config, is_causal=True, layer_idx=i) for i in range(config.num_layers)])
  316. self.layernorm_final = nn.LayerNorm(config.hidden_size, bias=config.bias)
  317. self.lm_head = nn.Linear(config.hidden_size, config.output_vocab_size, bias=False)
  318. self.gradient_checkpointing = False
  319. # Initialize weights and apply final processing
  320. self.post_init()
  321. def get_output_embeddings(self):
  322. # NOTE: get_output_embeddings() must return None to prevent accidental weight tying.
  323. # See e.g. https://github.com/huggingface/transformers/pull/39339#discussion_r2219126400
  324. return None
  325. def get_input_embeddings(self):
  326. return self.input_embeds_layer
  327. def set_input_embeddings(self, new_embeddings):
  328. self.input_embeds_layer = new_embeddings
  329. def prepare_inputs_for_generation(
  330. self,
  331. input_ids,
  332. attention_mask=None,
  333. input_embeds=None,
  334. past_key_values=None,
  335. position_ids=None,
  336. use_cache=None,
  337. cache_position=None,
  338. **kwargs,
  339. ):
  340. # Overwritten -- bark uses `input_embeds` not `inputS_embeds`
  341. model_inputs = super().prepare_inputs_for_generation(
  342. input_ids,
  343. attention_mask=attention_mask,
  344. inputs_embeds=input_embeds,
  345. past_key_values=past_key_values,
  346. position_ids=position_ids,
  347. use_cache=use_cache,
  348. cache_position=cache_position,
  349. **kwargs,
  350. )
  351. model_inputs["input_embeds"] = model_inputs.pop("inputs_embeds", None)
  352. return model_inputs
  353. @auto_docstring
  354. def forward(
  355. self,
  356. input_ids: Optional[torch.Tensor] = None,
  357. past_key_values: Optional[Cache] = None,
  358. attention_mask: Optional[torch.Tensor] = None,
  359. position_ids: Optional[torch.Tensor] = None,
  360. head_mask: Optional[torch.Tensor] = None,
  361. labels: Optional[torch.LongTensor] = None,
  362. input_embeds: Optional[torch.Tensor] = None,
  363. use_cache: Optional[bool] = None,
  364. output_attentions: Optional[bool] = None,
  365. output_hidden_states: Optional[bool] = None,
  366. return_dict: Optional[bool] = None,
  367. cache_position: Optional[torch.Tensor] = None,
  368. ) -> Union[tuple[torch.Tensor], CausalLMOutputWithPast]:
  369. r"""
  370. input_embeds (`torch.FloatTensor` of shape `(batch_size, input_sequence_length, hidden_size)`, *optional*):
  371. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  372. Here, due to `Bark` particularities, if `past_key_values` is used, `input_embeds` will be ignored and you
  373. have to use `input_ids`. If `past_key_values` is not used and `use_cache` is set to `True`, `input_embeds`
  374. is used in priority instead of `input_ids`.
  375. """
  376. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  377. output_hidden_states = (
  378. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  379. )
  380. use_cache = use_cache if use_cache is not None else self.config.use_cache
  381. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  382. loss = None
  383. if labels is not None:
  384. raise NotImplementedError(
  385. "Training is not implemented yet for Bark - ensure you do not pass `labels` to the model."
  386. )
  387. # Verify if input_embeds already exists
  388. # then compute embeddings.
  389. if input_ids is not None and input_embeds is not None:
  390. raise ValueError("You cannot specify both input_ids and input_embeds at the same time")
  391. elif input_embeds is not None and past_key_values is None:
  392. # we want to return the input_embeds in priority so that it is in line with a weird hack
  393. # of Bark which concatenate two bits of the input_embeds on the first forward pass of the semantic model
  394. pass
  395. elif input_ids is not None:
  396. input_embeds = self.input_embeds_layer(input_ids) # token embeddings of shape (b, t, n_embd)
  397. elif input_embeds is not None:
  398. pass
  399. else:
  400. raise ValueError("You have to specify either input_ids or input_embeds")
  401. input_shape = input_embeds.size()[:-1]
  402. batch_size = input_embeds.shape[0]
  403. seq_length = input_shape[-1]
  404. device = input_ids.device if input_ids is not None else input_embeds.device
  405. if self.gradient_checkpointing and self.training:
  406. if use_cache:
  407. logger.warning_once(
  408. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  409. )
  410. use_cache = False
  411. if use_cache and past_key_values is None:
  412. past_key_values = DynamicCache(config=self.config)
  413. if use_cache and isinstance(past_key_values, tuple):
  414. logger.warning_once(
  415. "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
  416. "You should pass an instance of `DynamicCache` instead, e.g. "
  417. "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
  418. )
  419. past_key_values = DynamicCache.from_legacy_cache(past_key_values)
  420. past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  421. if position_ids is None:
  422. position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
  423. position_ids = position_ids.unsqueeze(0) # shape (1, seq_length)
  424. position_embeds = self.position_embeds_layer(position_ids) # position embeddings of shape (1, t, n_embd)
  425. # Attention mask.
  426. if attention_mask is not None:
  427. if batch_size <= 0:
  428. raise ValueError("batch_size has to be defined and > 0")
  429. if self.config._attn_implementation == "flash_attention_2":
  430. attention_mask = attention_mask if 0 in attention_mask else None
  431. else:
  432. attention_mask = attention_mask.view(batch_size, -1)
  433. # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
  434. # from_seq_length is 1 to easily broadcast
  435. attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)
  436. # Prepare head mask if needed
  437. # 1.0 in head_mask indicate we keep the head
  438. # attention_probs has shape bsz x num_heads x N x N
  439. # head_mask has shape num_layers x batch x num_heads x N x N
  440. head_mask = self.get_head_mask(head_mask, self.config.num_layers)
  441. hidden_states = self.drop(input_embeds + position_embeds)
  442. output_shape = input_shape + (hidden_states.size(-1),)
  443. all_self_attentions = () if output_attentions else None
  444. all_hidden_states = () if output_hidden_states else None
  445. for i, block in enumerate(self.layers):
  446. if output_hidden_states:
  447. all_hidden_states = all_hidden_states + (hidden_states,)
  448. outputs = block(
  449. hidden_states,
  450. past_key_values=past_key_values,
  451. attention_mask=attention_mask,
  452. head_mask=head_mask[i],
  453. use_cache=use_cache,
  454. output_attentions=output_attentions,
  455. cache_position=cache_position,
  456. )
  457. hidden_states = outputs[0]
  458. if output_attentions:
  459. all_self_attentions = all_self_attentions + (outputs[1],)
  460. hidden_states = self.layernorm_final(hidden_states)
  461. hidden_states = hidden_states.view(output_shape)
  462. # Add last hidden state
  463. if output_hidden_states:
  464. all_hidden_states = all_hidden_states + (hidden_states,)
  465. logits = self.lm_head(hidden_states)
  466. if not return_dict:
  467. return tuple(
  468. v for v in [None, logits, past_key_values, all_hidden_states, all_self_attentions] if v is not None
  469. )
  470. return CausalLMOutputWithPast(
  471. loss=loss,
  472. logits=logits,
  473. past_key_values=past_key_values,
  474. hidden_states=all_hidden_states,
  475. attentions=all_self_attentions,
  476. )
  477. @auto_docstring(
  478. custom_intro="""
  479. Bark semantic (or text) model. It shares the same architecture as the coarse model.
  480. It is a GPT-2 like autoregressive model with a language modeling head on top.
  481. """
  482. )
  483. class BarkSemanticModel(BarkCausalModel):
  484. base_model_prefix = "semantic"
  485. config: BarkSemanticConfig
  486. def generate(
  487. self,
  488. input_ids: torch.Tensor,
  489. semantic_generation_config: Optional[BarkSemanticGenerationConfig] = None,
  490. history_prompt: Optional[dict[str, torch.Tensor]] = None,
  491. attention_mask: Optional[torch.Tensor] = None,
  492. **kwargs,
  493. ) -> torch.LongTensor:
  494. """
  495. Generates text semantic tokens from an input prompt and an additional optional `Bark` speaker prompt.
  496. Args:
  497. input_ids (`Optional[torch.Tensor]` of shape (batch_size, seq_len), *optional*):
  498. Input ids, i.e tokenized input sentences. Will be truncated up to
  499. semantic_generation_config.max_input_semantic_length tokens. Note that the output audios will be as
  500. long as the longest generation among the batch.
  501. semantic_generation_config (`BarkSemanticGenerationConfig`):
  502. Generation config indicating how to generate the semantic tokens.
  503. history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*):
  504. Optional `Bark` speaker prompt.
  505. attention_mask (`Optional[torch.Tensor]`, *optional*):
  506. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  507. - 1 for tokens that are **not masked**,
  508. - 0 for tokens that are **masked**.
  509. [What are attention masks?](../glossary#attention-mask)
  510. Returns:
  511. torch.LongTensor: Output semantic tokens.
  512. """
  513. if semantic_generation_config is None:
  514. raise ValueError("`semantic_generation_config` has to be provided")
  515. batch_size = input_ids.shape[0]
  516. max_input_semantic_length = semantic_generation_config.max_input_semantic_length
  517. input_ids = input_ids + semantic_generation_config.text_encoding_offset
  518. if attention_mask is not None:
  519. input_ids = input_ids.masked_fill((1 - attention_mask).bool(), semantic_generation_config.text_pad_token)
  520. if history_prompt is not None:
  521. semantic_history = history_prompt["semantic_prompt"][-max_input_semantic_length:]
  522. semantic_history = nn.functional.pad(
  523. semantic_history,
  524. (0, max_input_semantic_length - len(semantic_history)),
  525. value=semantic_generation_config.semantic_pad_token,
  526. mode="constant",
  527. )
  528. else:
  529. semantic_history = torch.tensor(
  530. [semantic_generation_config.semantic_pad_token] * max_input_semantic_length, dtype=torch.int
  531. ).to(self.device)
  532. semantic_history = torch.repeat_interleave(semantic_history[None], batch_size, dim=0)
  533. infer_array = torch.tensor(
  534. [[semantic_generation_config.semantic_infer_token]] * batch_size, dtype=torch.int
  535. ).to(self.device)
  536. input_embeds = torch.cat(
  537. [
  538. self.input_embeds_layer(input_ids[:, :max_input_semantic_length])
  539. + self.input_embeds_layer(semantic_history[:, : max_input_semantic_length + 1]),
  540. self.input_embeds_layer(infer_array),
  541. ],
  542. dim=1,
  543. )
  544. tokens_to_suppress = list(
  545. range(semantic_generation_config.semantic_vocab_size, semantic_generation_config.semantic_pad_token)
  546. )
  547. tokens_to_suppress.extend(
  548. list(range(semantic_generation_config.semantic_pad_token + 1, self.config.output_vocab_size))
  549. )
  550. suppress_tokens_logits_processor = SuppressTokensLogitsProcessor(tokens_to_suppress, device=input_ids.device)
  551. min_eos_p = kwargs.get("min_eos_p", semantic_generation_config.min_eos_p)
  552. early_stopping_logits_processor = BarkEosPrioritizerLogitsProcessor(
  553. eos_token_id=semantic_generation_config.eos_token_id, min_eos_p=min_eos_p, device=input_ids.device
  554. )
  555. # pass input_ids in order to stay consistent with the transformers generate method even though it is not used
  556. # (except to get the input seq_len - that's why we keep the first 257 tokens)
  557. semantic_output = super().generate(
  558. torch.ones((batch_size, max_input_semantic_length + 1), dtype=torch.int, device=self.device),
  559. input_embeds=input_embeds,
  560. logits_processor=[suppress_tokens_logits_processor, early_stopping_logits_processor],
  561. generation_config=semantic_generation_config,
  562. **kwargs,
  563. ) # size: 10048
  564. # take the generated semantic tokens
  565. semantic_output = semantic_output[:, max_input_semantic_length + 1 :]
  566. return semantic_output
  567. @auto_docstring(
  568. custom_intro="""
  569. Bark coarse acoustics model.
  570. It shares the same architecture as the semantic (or text) model. It is a GPT-2 like autoregressive model with a
  571. language modeling head on top.
  572. """
  573. )
  574. class BarkCoarseModel(BarkCausalModel):
  575. base_model_prefix = "coarse_acoustics"
  576. config: BarkCoarseConfig
  577. def preprocess_histories(
  578. self,
  579. max_coarse_history: int,
  580. semantic_to_coarse_ratio: int,
  581. batch_size: int,
  582. semantic_generation_config: int,
  583. codebook_size: int,
  584. history_prompt: Optional[dict[str, torch.Tensor]] = None,
  585. ):
  586. """
  587. Preprocess the optional `Bark` speaker prompts before `self.generate`.
  588. Args:
  589. max_coarse_history (`int`):
  590. Maximum size of coarse tokens used.
  591. semantic_to_coarse_ratio (`int`):
  592. Ratio of semantic to coarse frequency
  593. batch_size (`int`):
  594. Batch size, i.e the number of samples.
  595. semantic_generation_config (`BarkSemanticGenerationConfig`):
  596. Generation config indicating how to generate the semantic tokens.
  597. codebook_size (`int`):
  598. Codebook channel size, i.e. the size of the output vocabulary per codebook channel.
  599. history_prompt (`Optional[dict[str,torch.Tensor]]`):
  600. Optional `Bark` speaker prompt.
  601. Returns: Returns:
  602. `tuple(torch.FloatTensor)`:
  603. - **x_semantic_history** (`torch.FloatTensor` -- Processed semantic speaker prompt.
  604. - **x_coarse_history** (`torch.FloatTensor`) -- Processed coarse speaker prompt.
  605. """
  606. if history_prompt is not None:
  607. x_semantic_history = torch.repeat_interleave(history_prompt["semantic_prompt"][None], batch_size, dim=0)
  608. # clone to avoid modifying history_prompt.coarse_prompt
  609. x_coarse_history = history_prompt["coarse_prompt"].clone()
  610. # offset x_coarse_history
  611. if codebook_size is not None:
  612. for n in range(1, x_coarse_history.shape[0]):
  613. # offset
  614. x_coarse_history[n, :] += codebook_size * n
  615. # flatten x_coarse_history
  616. x_coarse_history = torch.transpose(x_coarse_history, 0, 1).reshape(-1)
  617. x_coarse_history = x_coarse_history + semantic_generation_config.semantic_vocab_size
  618. x_coarse_history = torch.repeat_interleave(x_coarse_history[None], batch_size, dim=0)
  619. # e.g: after SEMANTIC_VOCAB_SIZE (10000), 1024 tokens dedicated to first codebook, 1024 next tokens
  620. # dedicated to second codebook.
  621. max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
  622. # trim histories correctly
  623. n_semantic_hist_provided = min(
  624. [
  625. max_semantic_history,
  626. x_semantic_history.shape[1] - x_semantic_history.shape[1] % 2,
  627. int(np.floor(x_coarse_history.shape[1] / semantic_to_coarse_ratio)),
  628. ]
  629. )
  630. n_coarse_hist_provided = int(round(n_semantic_hist_provided * semantic_to_coarse_ratio))
  631. x_semantic_history = x_semantic_history[:, -n_semantic_hist_provided:].int()
  632. x_coarse_history = x_coarse_history[:, -n_coarse_hist_provided:].int()
  633. # bit of a hack for time alignment (sounds better) - from Bark original implementation
  634. x_coarse_history = x_coarse_history[:, :-2]
  635. else:
  636. # shape: (batch_size, 0)
  637. x_semantic_history = torch.tensor([[]] * batch_size, dtype=torch.int, device=self.device)
  638. x_coarse_history = torch.tensor([[]] * batch_size, dtype=torch.int, device=self.device)
  639. return x_semantic_history, x_coarse_history
  640. def generate(
  641. self,
  642. semantic_output: torch.Tensor,
  643. semantic_generation_config: Optional[BarkSemanticGenerationConfig] = None,
  644. coarse_generation_config: Optional[BarkCoarseGenerationConfig] = None,
  645. codebook_size: int = 1024,
  646. history_prompt: Optional[dict[str, torch.Tensor]] = None,
  647. return_output_lengths: Optional[bool] = None,
  648. **kwargs,
  649. ) -> Union[torch.LongTensor, tuple[torch.LongTensor, torch.LongTensor]]:
  650. """
  651. Generates coarse acoustics tokens from input text semantic tokens and an additional optional `Bark` speaker
  652. prompt.
  653. Args:
  654. semantic_output (`torch.Tensor` of shape (batch_size, seq_len), *optional*):
  655. Input text semantic ids, i.e the output of `BarkSemanticModel.generate`.
  656. semantic_generation_config (`BarkSemanticGenerationConfig`):
  657. Generation config indicating how to generate the semantic tokens.
  658. coarse_generation_config (`BarkCoarseGenerationConfig`):
  659. Generation config indicating how to generate the coarse tokens.
  660. codebook_size (`int`, *optional*, defaults to 1024):
  661. Codebook channel size, i.e. the size of the output vocabulary per codebook channel.
  662. history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*):
  663. Optional `Bark` speaker prompt.
  664. return_output_lengths (`bool`, *optional*):
  665. Whether or not to return the output lengths. Useful when batching.
  666. Returns:
  667. By default:
  668. torch.LongTensor: Output coarse acoustics tokens.
  669. If `return_output_lengths=True`:
  670. `Tuple(torch.Tensor, torch.Tensor): The output coarse acoustics tokens, and the length of each sample
  671. of the batch.
  672. """
  673. if semantic_generation_config is None:
  674. raise ValueError("`semantic_generation_config` has to be provided")
  675. if coarse_generation_config is None:
  676. raise ValueError("`coarse_generation_config` has to be provided")
  677. max_coarse_input_length = coarse_generation_config.max_coarse_input_length
  678. max_coarse_history = coarse_generation_config.max_coarse_history
  679. sliding_window_len = coarse_generation_config.sliding_window_len
  680. # replace semantic_pad_token (eos_tok and pad_tok here) with coarse_semantic_pad_token i.e the pad_token
  681. # used in the next model
  682. semantic_output.masked_fill_(
  683. semantic_output == semantic_generation_config.semantic_pad_token,
  684. coarse_generation_config.coarse_semantic_pad_token,
  685. )
  686. semantic_to_coarse_ratio = (
  687. coarse_generation_config.coarse_rate_hz
  688. / semantic_generation_config.semantic_rate_hz
  689. * coarse_generation_config.n_coarse_codebooks
  690. )
  691. max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
  692. output_lengths = (semantic_output != coarse_generation_config.coarse_semantic_pad_token).sum(1)
  693. output_lengths = torch.floor(
  694. output_lengths * semantic_to_coarse_ratio / coarse_generation_config.n_coarse_codebooks
  695. )
  696. output_lengths = torch.round(output_lengths * coarse_generation_config.n_coarse_codebooks).int()
  697. max_generated_len = torch.max(output_lengths).item()
  698. batch_size = semantic_output.shape[0]
  699. x_semantic_history, x_coarse = self.preprocess_histories(
  700. history_prompt=history_prompt,
  701. max_coarse_history=max_coarse_history,
  702. semantic_to_coarse_ratio=semantic_to_coarse_ratio,
  703. batch_size=batch_size,
  704. semantic_generation_config=semantic_generation_config,
  705. codebook_size=codebook_size,
  706. )
  707. base_semantic_idx = x_semantic_history.shape[1]
  708. semantic_output = torch.hstack([x_semantic_history, semantic_output])
  709. n_window_steps = int(np.ceil(max_generated_len / sliding_window_len))
  710. total_generated_len = 0
  711. len_coarse_history = x_coarse.shape[1]
  712. for _ in range(n_window_steps):
  713. semantic_idx = base_semantic_idx + int(round(total_generated_len / semantic_to_coarse_ratio))
  714. # pad from right side
  715. input_coarse = semantic_output[:, np.max([0, semantic_idx - max_semantic_history]) :]
  716. input_coarse = input_coarse[:, :max_coarse_input_length]
  717. input_coarse = F.pad(
  718. input_coarse,
  719. (0, max_coarse_input_length - input_coarse.shape[-1]),
  720. "constant",
  721. coarse_generation_config.coarse_semantic_pad_token,
  722. )
  723. input_coarse = torch.hstack(
  724. [
  725. input_coarse,
  726. torch.tensor([[coarse_generation_config.coarse_infer_token]] * batch_size, device=self.device),
  727. x_coarse[:, -max_coarse_history:],
  728. ]
  729. )
  730. alternatingLogitsProcessor = AlternatingCodebooksLogitsProcessor(
  731. input_coarse.shape[1],
  732. semantic_generation_config.semantic_vocab_size,
  733. codebook_size,
  734. )
  735. output_coarse = super().generate(
  736. input_coarse,
  737. logits_processor=[alternatingLogitsProcessor],
  738. max_new_tokens=min(sliding_window_len, max_generated_len - total_generated_len),
  739. generation_config=coarse_generation_config,
  740. **kwargs,
  741. )
  742. input_coarse_len = input_coarse.shape[1]
  743. x_coarse = torch.hstack([x_coarse, output_coarse[:, input_coarse_len:]])
  744. total_generated_len = x_coarse.shape[1] - len_coarse_history
  745. del output_coarse
  746. coarse_output = x_coarse[:, len_coarse_history:]
  747. if return_output_lengths:
  748. return coarse_output, output_lengths
  749. return coarse_output
  750. @auto_docstring(
  751. custom_intro="""
  752. Bark fine acoustics model. It is a non-causal GPT-like model with `config.n_codes_total` embedding layers and
  753. language modeling heads, one for each codebook.
  754. """
  755. )
  756. class BarkFineModel(BarkPreTrainedModel):
  757. base_model_prefix = "fine_acoustics"
  758. config: BarkFineConfig
  759. main_input_name = "codebook_idx"
  760. def __init__(self, config):
  761. # non-causal gpt-like model with one embedding layer and one lm_head for each codebook of Encodec
  762. super().__init__(config)
  763. self.config = config
  764. # initialize a modified non causal GPT-like model
  765. # note that for there is one embedding layer and one lm_head for each codebook of Encodec
  766. self.input_embeds_layers = nn.ModuleList(
  767. [nn.Embedding(config.input_vocab_size, config.hidden_size) for _ in range(config.n_codes_total)]
  768. )
  769. self.position_embeds_layer = nn.Embedding(config.block_size, config.hidden_size)
  770. self.drop = nn.Dropout(config.dropout)
  771. self.layers = nn.ModuleList(
  772. [BarkBlock(config, is_causal=False, layer_idx=i) for i in range(config.num_layers)]
  773. )
  774. self.layernorm_final = nn.LayerNorm(config.hidden_size)
  775. self.lm_heads = nn.ModuleList(
  776. [
  777. nn.Linear(config.hidden_size, config.output_vocab_size, bias=False)
  778. for _ in range(config.n_codes_given, config.n_codes_total)
  779. ]
  780. )
  781. self.gradient_checkpointing = False
  782. self.n_codes_total = config.n_codes_total
  783. # Initialize weights and apply final processing
  784. self.post_init()
  785. def get_input_embeddings(self):
  786. # one embedding layers for each codebook
  787. return self.input_embeds_layers
  788. def set_input_embeddings(self, new_embeddings):
  789. # one embedding layers for each codebook
  790. self.input_embeds_layers = new_embeddings
  791. def get_output_embeddings(self):
  792. # one lm_head for each codebook
  793. return self.lm_heads
  794. def set_output_embeddings(self, new_output_embeddings):
  795. # one lm_head for each codebook
  796. self.lm_heads = new_output_embeddings
  797. def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean_resizing=True):
  798. old_embeddings_list = self.get_input_embeddings()
  799. new_embeddings_list = nn.ModuleList(
  800. [
  801. self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of, mean_resizing)
  802. for old_embeddings in old_embeddings_list
  803. ]
  804. )
  805. self.set_input_embeddings(new_embeddings_list)
  806. new_num_tokens = new_embeddings_list[0].weight.shape[0]
  807. # if word embeddings are not tied, make sure that lm head is resized as well
  808. if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
  809. old_lm_head_list = self.get_output_embeddings()
  810. new_lm_head_list = nn.ModuleList(
  811. [self._get_resized_lm_head(old_lm_head, new_num_tokens) for old_lm_head in old_lm_head_list]
  812. )
  813. self.set_output_embeddings(new_lm_head_list)
  814. return self.get_input_embeddings()
  815. def resize_token_embeddings(
  816. self,
  817. new_num_tokens: Optional[int] = None,
  818. pad_to_multiple_of: Optional[int] = None,
  819. mean_resizing: bool = True,
  820. ) -> nn.Embedding:
  821. """
  822. Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
  823. Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
  824. Arguments:
  825. new_num_tokens (`int`, *optional*):
  826. The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
  827. vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
  828. returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
  829. pad_to_multiple_of (`int`, *optional*):
  830. If set will pad the embedding matrix to a multiple of the provided value.
  831. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
  832. `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
  833. details about this, or help on choosing the correct value for resizing, refer to this guide:
  834. https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
  835. mean_resizing (`bool`):
  836. Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
  837. covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`.
  838. Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
  839. where the generated tokens' probabilities won't be affected by the added embeddings because initializing the new embeddings with the
  840. old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings.
  841. Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
  842. Return:
  843. `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
  844. """
  845. model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
  846. if new_num_tokens is None and pad_to_multiple_of is None:
  847. return model_embeds
  848. # Update base model and current model config
  849. self.config.output_vocab_size = model_embeds[0].weight.shape[0]
  850. self.config.vocab_size = model_embeds[0].weight.shape[0]
  851. self.output_vocab_size = model_embeds[0].weight.shape[0]
  852. self.vocab_size = model_embeds[0].weight.shape[0]
  853. # Tie weights again if needed
  854. self.tie_weights()
  855. return model_embeds
  856. def _tie_weights(self):
  857. if getattr(self.config, "tie_word_embeddings", True):
  858. self._tied_weights_keys = []
  859. output_embeddings = self.get_output_embeddings()
  860. input_embeddings = self.get_input_embeddings()
  861. for i in range(self.config.n_codes_total - self.config.n_codes_given):
  862. # self.input_embeds_layers[i + 1].weight = self.lm_heads[i].weight
  863. self._tie_or_clone_weights(output_embeddings[i], input_embeddings[i + 1])
  864. self._tied_weights_keys.append(f"lm_heads.{i}.weight")
  865. def tie_weights(self):
  866. """
  867. Tie the weights between the input embeddings list and the output embeddings list.
  868. If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
  869. weights instead.
  870. """
  871. for module in self.modules():
  872. if hasattr(module, "_tie_weights"):
  873. module._tie_weights()
  874. @auto_docstring
  875. def forward(
  876. self,
  877. codebook_idx: int, # an additional idx corresponding to the id of the codebook that will be predicted
  878. input_ids: Optional[torch.Tensor] = None,
  879. attention_mask: Optional[torch.Tensor] = None,
  880. position_ids: Optional[torch.Tensor] = None,
  881. head_mask: Optional[torch.Tensor] = None,
  882. labels: Optional[torch.LongTensor] = None,
  883. input_embeds: Optional[torch.Tensor] = None,
  884. output_attentions: Optional[bool] = None,
  885. output_hidden_states: Optional[bool] = None,
  886. return_dict: Optional[bool] = None,
  887. ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
  888. r"""
  889. codebook_idx (`int`):
  890. Index of the codebook that will be predicted.
  891. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  892. NOT IMPLEMENTED YET.
  893. input_embeds (`torch.FloatTensor` of shape `(batch_size, input_sequence_length, hidden_size)`, *optional*):
  894. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. If
  895. `past_key_values` is used, optionally only the last `input_embeds` have to be input (see
  896. `past_key_values`). This is useful if you want more control over how to convert `input_ids` indices into
  897. associated vectors than the model's internal embedding lookup matrix.
  898. """
  899. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  900. output_hidden_states = (
  901. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  902. )
  903. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  904. loss = None
  905. if labels is not None:
  906. raise NotImplementedError("Training is not implemented yet")
  907. if codebook_idx == 0:
  908. raise ValueError("Cannot predict 0th codebook - 0th codebook should be predicted by the coarse model")
  909. if input_ids is not None and input_embeds is not None:
  910. raise ValueError("You cannot specify both input_ids and input_embeds at the same time")
  911. if input_ids is None and input_embeds is None:
  912. raise ValueError("You have to specify either input_ids or input_embeds")
  913. if input_ids is not None:
  914. # the input_embeddings are the sum of the j previous codebooks embeddings before
  915. # the current codebook_idx codebook
  916. # forward the GPT model itself
  917. input_embeds = [
  918. input_embeds_layer(input_ids[:, :, i]).unsqueeze(-1)
  919. for i, input_embeds_layer in enumerate(self.input_embeds_layers)
  920. ] # token embeddings of shape (b, t, n_embd)
  921. input_embeds = torch.cat(input_embeds, dim=-1)
  922. input_embeds = input_embeds[:, :, :, : codebook_idx + 1].sum(dim=-1)
  923. input_shape = input_embeds.size()[:-1]
  924. batch_size = input_embeds.shape[0]
  925. seq_length = input_shape[1]
  926. device = input_ids.device if input_ids is not None else input_embeds.device
  927. if position_ids is None:
  928. position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
  929. position_ids = position_ids.unsqueeze(0) # shape (1, seq_length)
  930. position_embeds = self.position_embeds_layer(position_ids) # position embeddings of shape (1, t, n_embd)
  931. # Attention mask.
  932. if attention_mask is not None:
  933. if batch_size <= 0:
  934. raise ValueError("batch_size has to be defined and > 0")
  935. if self.config._attn_implementation == "flash_attention_2":
  936. attention_mask = attention_mask if 0 in attention_mask else None
  937. else:
  938. # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
  939. # from_seq_length is 1 to easily broadcast
  940. attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)
  941. head_mask = self.get_head_mask(head_mask, self.config.num_layers)
  942. hidden_states = self.drop(input_embeds + position_embeds)
  943. output_shape = input_shape + (hidden_states.size(-1),)
  944. all_self_attentions = () if output_attentions else None
  945. all_hidden_states = () if output_hidden_states else None
  946. for i, block in enumerate(self.layers):
  947. if output_hidden_states:
  948. all_hidden_states = all_hidden_states + (hidden_states,)
  949. outputs = block(
  950. hidden_states,
  951. attention_mask=attention_mask,
  952. head_mask=head_mask[i],
  953. output_attentions=output_attentions,
  954. )
  955. hidden_states = outputs[0]
  956. if output_attentions:
  957. all_self_attentions = all_self_attentions + (outputs[1],)
  958. hidden_states = self.layernorm_final(hidden_states)
  959. hidden_states = hidden_states.view(output_shape)
  960. # Add last hidden state
  961. if output_hidden_states:
  962. all_hidden_states = all_hidden_states + (hidden_states,)
  963. logits = self.lm_heads[codebook_idx - self.config.n_codes_given](hidden_states)
  964. if not return_dict:
  965. return tuple(v for v in [None, logits, all_hidden_states, all_self_attentions] if v is not None)
  966. return MaskedLMOutput(
  967. loss=loss,
  968. logits=logits,
  969. hidden_states=all_hidden_states,
  970. attentions=all_self_attentions,
  971. )
  972. @torch.no_grad()
  973. def generate(
  974. self,
  975. coarse_output: torch.Tensor,
  976. semantic_generation_config: Optional[BarkSemanticGenerationConfig] = None,
  977. coarse_generation_config: Optional[BarkCoarseGenerationConfig] = None,
  978. fine_generation_config: BarkFineGenerationConfig = None,
  979. codebook_size: int = 1024,
  980. history_prompt: Optional[dict[str, torch.Tensor]] = None,
  981. **kwargs,
  982. ) -> torch.LongTensor:
  983. """
  984. Generates fine acoustics tokens from input coarse acoustics tokens and an additional optional `Bark` speaker
  985. prompt.
  986. Args:
  987. coarse_output (`torch.Tensor` of shape (batch_size, seq_len)):
  988. Input coarse acoustics ids, i.e the output of `BarkCoarseModel.generate`.
  989. semantic_generation_config (`BarkSemanticGenerationConfig`):
  990. Generation config indicating how to generate the semantic tokens.
  991. coarse_generation_config (`BarkCoarseGenerationConfig`):
  992. Generation config indicating how to generate the coarse tokens.
  993. fine_generation_config (`BarkFineGenerationConfig`):
  994. Generation config indicating how to generate the fine tokens.
  995. codebook_size (`int`, *optional*, defaults to 1024):
  996. Codebook channel size, i.e. the size of the output vocabulary per codebook channel.
  997. history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*):
  998. Optional `Bark` speaker prompt.
  999. Returns:
  1000. torch.LongTensor: Output fine acoustics tokens.
  1001. """
  1002. if semantic_generation_config is None:
  1003. raise ValueError("`semantic_generation_config` has to be provided")
  1004. if coarse_generation_config is None:
  1005. raise ValueError("`coarse_generation_config` has to be provided")
  1006. if fine_generation_config is None:
  1007. raise ValueError("`fine_generation_config` has to be provided")
  1008. # since we don't really use GenerationConfig through the fine model (autoencoder)
  1009. # and since only temperature is used from the classic GenerationConfig parameters
  1010. # manually impose the kwargs priority over the generation config
  1011. temperature = kwargs.get("temperature", fine_generation_config.temperature)
  1012. max_fine_history_length = fine_generation_config.max_fine_history_length
  1013. max_fine_input_length = fine_generation_config.max_fine_input_length
  1014. # shape: (batch, n_coarse_codebooks * seq_len)
  1015. # new_shape: (batch, seq_len, n_coarse_codebooks)
  1016. coarse_output = coarse_output.view(coarse_output.shape[0], -1, coarse_generation_config.n_coarse_codebooks)
  1017. # brings ids into the range [0, codebook_size -1]
  1018. coarse_output = torch.remainder(coarse_output - semantic_generation_config.semantic_vocab_size, codebook_size)
  1019. batch_size = coarse_output.shape[0]
  1020. if history_prompt is not None:
  1021. x_fine_history = torch.repeat_interleave(history_prompt["fine_prompt"].T[None], batch_size, dim=0)
  1022. # transpose to get to shape (seq_len, n_fine_codebooks)
  1023. else:
  1024. x_fine_history = None
  1025. n_coarse = coarse_generation_config.n_coarse_codebooks
  1026. # pad the last 6th codebooks
  1027. fine_input = F.pad(
  1028. coarse_output,
  1029. (0, fine_generation_config.n_fine_codebooks - n_coarse),
  1030. "constant",
  1031. codebook_size,
  1032. )
  1033. # prepend history if available (max max_fine_history_length)
  1034. if x_fine_history is not None:
  1035. fine_input = torch.cat([x_fine_history[:, -max_fine_history_length:, :], fine_input], dim=1)
  1036. # len of the fine_history that has been added to fine_input
  1037. n_history = x_fine_history[:, -max_fine_history_length:, :].shape[1]
  1038. else:
  1039. n_history = 0
  1040. n_remove_from_end = 0
  1041. # need to pad if too short (since non-causal model)
  1042. if fine_input.shape[1] < max_fine_input_length:
  1043. n_remove_from_end = max_fine_input_length - fine_input.shape[1]
  1044. fine_input = F.pad(fine_input, (0, 0, 0, n_remove_from_end), mode="constant", value=codebook_size)
  1045. # we can be lazy about fractional loop and just keep overwriting codebooks.
  1046. # seems that coarse_output.shape[1] - (max_fine_input_length - n_history) is equal to minus n_remove_from_end
  1047. # So if we needed to pad because too short, n_loops is always 1 (because n_remove_from_end > 0)
  1048. # If not, we loop over at least twice.
  1049. n_loops = (coarse_output.shape[1] - (max_fine_input_length - n_history)) / max_fine_history_length
  1050. n_loops = int(np.ceil(n_loops))
  1051. n_loops = max(0, n_loops) + 1
  1052. for n_outer in range(n_loops):
  1053. start_idx = min([n_outer * max_fine_history_length, fine_input.shape[1] - max_fine_input_length])
  1054. start_fill_idx = min(
  1055. [n_history + n_outer * max_fine_history_length, fine_input.shape[1] - max_fine_history_length]
  1056. )
  1057. rel_start_fill_idx = start_fill_idx - start_idx
  1058. input_buffer = fine_input[:, start_idx : start_idx + max_fine_input_length, :]
  1059. for n_inner in range(n_coarse, fine_generation_config.n_fine_codebooks):
  1060. logits = self.forward(n_inner, input_buffer).logits
  1061. if temperature is None or temperature == 1.0:
  1062. relevant_logits = logits[:, rel_start_fill_idx:, :codebook_size]
  1063. codebook_preds = torch.argmax(relevant_logits, -1)
  1064. else:
  1065. relevant_logits = logits[:, :, :codebook_size] / temperature
  1066. # apply softmax
  1067. probs = F.softmax(relevant_logits, dim=-1)[:, rel_start_fill_idx:max_fine_input_length]
  1068. # reshape to 2D: (batch_size, seq_len, codebook_size) -> (batch_size*seq_len, codebook_size)
  1069. probs = probs.reshape((-1, codebook_size))
  1070. # multinomial then reshape : (batch_size*seq_len)-> (batch_size,seq_len)
  1071. codebook_preds = torch.multinomial(probs, num_samples=1).view(batch_size, -1)
  1072. codebook_preds = codebook_preds.to(torch.int32)
  1073. input_buffer[:, rel_start_fill_idx:, n_inner] = codebook_preds
  1074. del logits, codebook_preds
  1075. # transfer into fine_input
  1076. for n_inner in range(n_coarse, fine_generation_config.n_fine_codebooks):
  1077. fine_input[
  1078. :, start_fill_idx : start_fill_idx + (max_fine_input_length - rel_start_fill_idx), n_inner
  1079. ] = input_buffer[:, rel_start_fill_idx:, n_inner]
  1080. del input_buffer
  1081. fine_input = fine_input.transpose(1, 2)[:, :, n_history:]
  1082. if n_remove_from_end > 0:
  1083. fine_input = fine_input[:, :, :-n_remove_from_end]
  1084. if fine_input.shape[-1] != coarse_output.shape[-2]:
  1085. raise ValueError("input and output should have the same seq_len")
  1086. return fine_input
  1087. @auto_docstring(
  1088. custom_intro="""
  1089. The full Bark model, a text-to-speech model composed of 4 sub-models:
  1090. - [`BarkSemanticModel`] (also referred to as the 'text' model): a causal auto-regressive transformer model that
  1091. takes
  1092. as input tokenized text, and predicts semantic text tokens that capture the meaning of the text.
  1093. - [`BarkCoarseModel`] (also referred to as the 'coarse acoustics' model), also a causal autoregressive transformer,
  1094. that takes into input the results of the last model. It aims at regressing the first two audio codebooks necessary
  1095. to `encodec`.
  1096. - [`BarkFineModel`] (the 'fine acoustics' model), this time a non-causal autoencoder transformer, which iteratively
  1097. predicts the last codebooks based on the sum of the previous codebooks embeddings.
  1098. - having predicted all the codebook channels from the [`EncodecModel`], Bark uses it to decode the output audio
  1099. array.
  1100. It should be noted that each of the first three modules can support conditional speaker embeddings to condition the
  1101. output sound according to specific predefined voice.
  1102. """
  1103. )
  1104. class BarkModel(BarkPreTrainedModel):
  1105. config: BarkConfig
  1106. def __init__(self, config):
  1107. super().__init__(config)
  1108. self.semantic = BarkSemanticModel(config.semantic_config)
  1109. self.coarse_acoustics = BarkCoarseModel(config.coarse_acoustics_config)
  1110. self.fine_acoustics = BarkFineModel(config.fine_acoustics_config)
  1111. self.codec_model = AutoModel.from_config(config.codec_config)
  1112. self.config = config
  1113. @classmethod
  1114. def can_generate(cls) -> bool:
  1115. # Bark has a unique model structure, where the external class (`BarkModel`) doesn't need to inherit from
  1116. # `GenerationMixin` (it has a non-standard generation method), but one of the internal models do
  1117. # (`BarkSemanticModel`). This means that the base `can_generate()` will return `False`, but we need to
  1118. # override it so as to do `GenerationConfig` handling in multiple parts of the codebase.
  1119. return True
  1120. @property
  1121. def device(self) -> torch.device:
  1122. """
  1123. `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
  1124. device).
  1125. """
  1126. # for bark_model, device must be verified on its sub-models
  1127. # if has _hf_hook, has been offloaded so the device has to be found in the hook
  1128. if not hasattr(self.semantic, "_hf_hook"):
  1129. return get_parameter_device(self)
  1130. for module in self.semantic.modules():
  1131. if (
  1132. hasattr(module, "_hf_hook")
  1133. and hasattr(module._hf_hook, "execution_device")
  1134. and module._hf_hook.execution_device is not None
  1135. ):
  1136. return torch.device(module._hf_hook.execution_device)
  def enable_cpu_offload(
      self,
      accelerator_id: Optional[int] = 0,
      **kwargs,
  ):
      r"""
      Offloads all sub-models to CPU using accelerate, reducing memory usage with a low impact on performance. This
      method moves one whole sub-model at a time to the accelerator when it is used, and the sub-model remains in
      accelerator until the next sub-model runs.

      Args:
          accelerator_id (`int`, *optional*, defaults to 0):
              accelerator id on which the sub-models will be loaded and offloaded.
          kwargs (`dict`, *optional*):
              additional keyword arguments:
                  `gpu_id`: deprecated alias of `accelerator_id`. A non-zero value overrides `accelerator_id`
                  and emits a `FutureWarning`.

      Raises:
          ImportError: if `accelerate` is not available.
      """
      if not is_accelerate_available():
          raise ImportError("`enable_cpu_offload` requires `accelerate`.")
      from accelerate import cpu_offload_with_hook

      # Backward compatibility: honour the deprecated `gpu_id` keyword when it carries a non-default value.
      gpu_id = kwargs.get("gpu_id", 0)
      if gpu_id != 0:
          warnings.warn(
              "The argument `gpu_id` is deprecated and will be removed in version 4.54.0 of Transformers. Please use `accelerator_id` instead.",
              FutureWarning,
          )
          accelerator_id = gpu_id

      # Default to CUDA, but use the current accelerator's type when torch reports one.
      device_type = "cuda"
      if is_torch_accelerator_available():
          device_type = torch.accelerator.current_accelerator().type
      device = torch.device(f"{device_type}:{accelerator_id}")

      torch_accelerator_module = getattr(torch, device_type)
      if self.device.type != "cpu":
          self.to("cpu")
          torch_accelerator_module.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

      # this layer is used outside the first forward pass of semantic so need to be loaded before semantic
      self.semantic.input_embeds_layer, _ = cpu_offload_with_hook(self.semantic.input_embeds_layer, device)

      # Chain the hooks: each call receives the previous sub-model's hook as `prev_module_hook`.
      hook = None
      for cpu_offloaded_model in [
          self.semantic,
          self.coarse_acoustics,
          self.fine_acoustics,
      ]:
          _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

      # Kept so that `generate` can offload the fine model manually before decoding.
      self.fine_acoustics_hook = hook

      _, hook = cpu_offload_with_hook(self.codec_model, device, prev_module_hook=hook)

      # We'll offload the last model manually.
      self.codec_model_hook = hook
  def codec_decode(self, fine_output, output_lengths=None):
      """
      Turn quantized audio codes into audio using the codec model (encodec).

      Args:
          fine_output:
              Codes produced by the fine model. Its first two dimensions are swapped before the codes are
              handed to the codec quantizer.
          output_lengths (*optional*):
              Per-sample lengths. When given, each sample is cut to its own length and decoded separately.

      Returns:
          A list holding one fully squeezed decoder output per sample when `output_lengths` is given,
          otherwise the batched decoder output with dimension 1 squeezed out.
      """
      codes = fine_output.transpose(0, 1)
      embeddings = self.codec_model.quantizer.decode(codes)
      decoder = self.codec_model.decoder

      if output_lengths is None:
          # Whole batch in a single decoder call, then drop dimension 1.
          return decoder(embeddings).squeeze(1)

      # encodec uses LSTMs which behaves differently with appended padding
      # decoding with encodec takes around 0.1% of the total generation time
      # to keep generation quality, we break batching
      trimmed_samples = []
      for sample, length in zip(embeddings, output_lengths):
          trimmed_samples.append(sample[:, :length].unsqueeze(0))

      waveforms = []
      for trimmed in trimmed_samples:
          waveforms.append(decoder(trimmed).squeeze())
      return waveforms
  @torch.no_grad()
  def generate(
      self,
      input_ids: Optional[torch.Tensor] = None,
      history_prompt: Optional[dict[str, torch.Tensor]] = None,
      return_output_lengths: Optional[bool] = None,
      **kwargs,
  ) -> torch.LongTensor:
      """
      Generates audio from an input prompt and an additional optional `Bark` speaker prompt, by running the
      semantic, coarse and fine sub-models one after the other and decoding the result with the codec model.

      Args:
          input_ids (`Optional[torch.Tensor]` of shape (batch_size, seq_len), *optional*):
              Input ids. Will be truncated up to 256 tokens. Note that the output audios will be as long as the
              longest generation among the batch.
          history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*):
              Optional `Bark` speaker prompt. Note that for now, this model takes only one speaker prompt per batch.
          return_output_lengths (`bool`, *optional*):
              Whether or not to return the waveform lengths. Useful when batching.
          kwargs (*optional*): Remaining dictionary of keyword arguments. Keyword arguments are of two types:

              - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model.
              - With a *semantic_*, *coarse_*, *fine_* prefix, they will be input for the `generate` method of the
              semantic, coarse and fine respectively. It has the priority over the keywords without a prefix.

              This means you can, for example, specify a generation strategy for all sub-models except one.
              `attention_mask` and `min_eos_p` are only forwarded to the semantic sub-model, and any
              `generation_config` entry is dropped before the sub-models are called.

      Returns:
          By default:
              - **audio_waveform** (`torch.Tensor` of shape (batch_size, seq_len)): Generated audio waveform.
          When `return_output_lengths=True`:
              Returns a tuple made of:
              - **audio_waveform** (`torch.Tensor` of shape (batch_size, seq_len)): Generated audio waveforms,
                zero-padded to a common length.
              - **output_lengths** (`list[int]` of length batch_size): The length of each waveform in the batch

      Example:

      ```python
      >>> from transformers import AutoProcessor, BarkModel

      >>> processor = AutoProcessor.from_pretrained("suno/bark-small")
      >>> model = BarkModel.from_pretrained("suno/bark-small")

      >>> # To add a voice preset, you can pass `voice_preset` to `BarkProcessor.__call__(...)`
      >>> voice_preset = "v2/en_speaker_6"

      >>> inputs = processor("Hello, my dog is cute, I need him in my life", voice_preset=voice_preset)

      >>> audio_array = model.generate(**inputs, semantic_max_new_tokens=100)
      >>> audio_array = audio_array.cpu().numpy().squeeze()
      ```
      """
      # TODO (joao):workaround until nested generation config is compatible with PreTrained Model
      # todo: dict
      generation_config = self.generation_config
      semantic_generation_config = BarkSemanticGenerationConfig(**generation_config.semantic_config)
      coarse_generation_config = BarkCoarseGenerationConfig(**generation_config.coarse_acoustics_config)
      fine_generation_config = BarkFineGenerationConfig(**generation_config.fine_acoustics_config)

      # `attention_mask` and `min_eos_p` are taken out of `kwargs` up front so that they only ever reach
      # the semantic model and are never passed to the coarse and fine models.
      semantic_kwargs = {
          "attention_mask": kwargs.pop("attention_mask", None),
          "min_eos_p": kwargs.pop("min_eos_p", None),
      }
      coarse_kwargs = {}
      fine_kwargs = {}
      routed_kwargs = {
          "semantic_": semantic_kwargs,
          "coarse_": coarse_kwargs,
          "fine_": fine_kwargs,
      }

      for name, value in kwargs.items():
          for prefix, target in routed_kwargs.items():
              if name.startswith(prefix):
                  # Prefixed key: strip the prefix and hand it to that sub-model only.
                  target[name[len(prefix) :]] = value
                  break
          else:
              # Un-prefixed key: shared by every sub-model, unless a sub-model specific value is already
              # there, in which case we don't override it.
              for target in routed_kwargs.values():
                  target.setdefault(name, value)

      # Each sub-model receives its dedicated Bark generation config below, so a `generation_config`
      # entry that ended up in the routed kwargs is dropped.
      for target in routed_kwargs.values():
          target.pop("generation_config", None)

      # 1. Generate from the semantic model
      semantic_output = self.semantic.generate(
          input_ids,
          history_prompt=history_prompt,
          semantic_generation_config=semantic_generation_config,
          **semantic_kwargs,
      )

      # 2. Generate from the coarse model
      coarse_result = self.coarse_acoustics.generate(
          semantic_output,
          history_prompt=history_prompt,
          semantic_generation_config=semantic_generation_config,
          coarse_generation_config=coarse_generation_config,
          codebook_size=generation_config.codebook_size,
          return_output_lengths=return_output_lengths,
          **coarse_kwargs,
      )

      if return_output_lengths:
          coarse_output, output_lengths = coarse_result
          # (batch_size, seq_len*coarse_codebooks) -> (batch_size, seq_len)
          output_lengths = output_lengths // coarse_generation_config.n_coarse_codebooks
      else:
          coarse_output = coarse_result
          output_lengths = None

      # 3. "generate" from the fine model
      output = self.fine_acoustics.generate(
          coarse_output,
          history_prompt=history_prompt,
          semantic_generation_config=semantic_generation_config,
          coarse_generation_config=coarse_generation_config,
          fine_generation_config=fine_generation_config,
          codebook_size=generation_config.codebook_size,
          **fine_kwargs,
      )

      fine_hook = getattr(self, "fine_acoustics_hook", None)
      if fine_hook is not None:
          # Manually offload fine_acoustics to CPU
          # and load codec_model to GPU
          # since bark doesn't use codec_model forward pass
          fine_hook.offload()
          self.codec_model = self.codec_model.to(self.device)

      # 4. Decode the output and generate audio array
      audio = self.codec_decode(output, output_lengths)

      codec_hook = getattr(self, "codec_model_hook", None)
      if codec_hook is not None:
          # Offload codec_model to CPU
          codec_hook.offload()

      if not return_output_lengths:
          return audio

      # Per-sample waveforms: record their lengths, then zero-pad them into a single batch.
      output_lengths = [len(sample) for sample in audio]
      audio = nn.utils.rnn.pad_sequence(audio, batch_first=True, padding_value=0)
      return audio, output_lengths
  def tie_weights(self):
      """
      Tie the weights of every module that knows how to do so.

      Walks over `self.modules()` and calls `_tie_weights()` on each module that defines it; modules without
      that attribute are left untouched.
      """
      for submodule in self.modules():
          if not hasattr(submodule, "_tie_weights"):
              continue
          submodule._tie_weights()
  # Public names exported by this module.
  __all__ = [
      "BarkFineModel", "BarkSemanticModel", "BarkCoarseModel",
      "BarkModel", "BarkPreTrainedModel", "BarkCausalModel",
  ]