| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628 |
- # coding=utf-8
- # Copyright 2023 The Suno AI Authors and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch BARK model."""
- import math
- import warnings
- from typing import Optional, Union
- import numpy as np
- import torch
- from torch import nn
- from torch.nn import functional as F
- from ...cache_utils import Cache, DynamicCache
- from ...generation import GenerationMixin
- from ...generation.logits_process import (
- AlternatingCodebooksLogitsProcessor,
- BarkEosPrioritizerLogitsProcessor,
- SuppressTokensLogitsProcessor,
- )
- from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
- from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
- from ...modeling_utils import PreTrainedModel, get_parameter_device
- from ...utils import (
- auto_docstring,
- is_accelerate_available,
- is_torch_accelerator_available,
- logging,
- )
- from ..auto import AutoModel
- from .configuration_bark import (
- BarkCoarseConfig,
- BarkConfig,
- BarkFineConfig,
- BarkSemanticConfig,
- BarkSubModelConfig,
- )
- from .generation_configuration_bark import (
- BarkCoarseGenerationConfig,
- BarkFineGenerationConfig,
- BarkSemanticGenerationConfig,
- )
- if is_flash_attn_available():
- from ...modeling_flash_attention_utils import _flash_attention_forward
# Module-level logger; used below for the gradient-checkpointing and legacy-cache warnings in `BarkCausalModel.forward`.
logger = logging.get_logger(__name__)
class BarkSelfAttention(nn.Module):
    """
    Multi-head self-attention shared by the Bark sub-models (adapted from GPT-Neo's self-attention and the original
    Bark code).

    The module runs either full (bidirectional) attention or causal attention. When `is_causal` is set, a
    lower-triangular boolean buffer named `bias` of shape `(1, 1, block_size, block_size)` is registered and used
    to hide future positions.
    """

    def __init__(self, config, is_causal=False, layer_idx=None):
        super().__init__()

        # Dropout: the raw probability is kept as well because subclasses (flash attention) consume it directly.
        self.dropout = config.dropout
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.embed_dim // self.num_heads

        if config.hidden_size % config.num_heads != 0:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        width = config.hidden_size
        # Fused query/key/value projection, followed by the output projection.
        self.att_proj = nn.Linear(width, 3 * width, bias=config.bias)
        self.out_proj = nn.Linear(width, width, bias=config.bias)

        self.is_causal = is_causal
        self.layer_idx = layer_idx

        if is_causal:
            size = config.block_size
            # True on and below the diagonal: position i may attend to positions <= i.
            lower_triangle = torch.ones((size, size), dtype=bool).tril()
            self.register_buffer("bias", lower_triangle.view(1, 1, size, size))

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Reshape `(batch, seq_len, num_heads * head_dim)` into `(batch, num_heads, seq_len, head_dim)`.
        """
        per_head = tensor.view(*tensor.shape[:-1], num_heads, attn_head_size)
        return per_head.permute(0, 2, 1, 3)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """
        Inverse of `_split_heads`: `(batch, num_heads, seq_len, head_dim)` -> `(batch, seq_len, num_heads * head_dim)`.
        """
        seq_major = tensor.transpose(1, 2).contiguous()
        return seq_major.view(*seq_major.shape[:-2], num_heads * attn_head_size)

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        """
        Scaled dot-product attention.

        Returns the attended values `(batch, num_heads, q_len, head_dim)` and the post-dropout attention
        probabilities `(batch, num_heads, q_len, k_len)`.
        """
        # Unlike GPT-Neo, scores are scaled by 1/sqrt(head_dim).
        scale = 1.0 / math.sqrt(self.head_dim)
        scores = torch.matmul(query, key.transpose(-1, -2)) * scale

        if self.is_causal:
            q_len = query.size(-2)
            k_len = key.size(-2)
            # The queries are the last `q_len` of the `k_len` positions. Every key above the diagonal (a
            # future position) gets the most negative finite value, so its softmax weight is 0.
            future = self.bias[:, :, k_len - q_len : k_len, :k_len] == 0
            scores = scores.masked_fill(future, torch.finfo(scores.dtype).min)

        if attention_mask is not None:
            # Additive mask: 0 to keep a position, a large negative value to drop it.
            scores = scores + attention_mask

        probs = nn.functional.softmax(scores, dim=-1).to(value.dtype)
        probs = self.attn_dropout(probs)

        if head_mask is not None:
            probs = probs * head_mask

        return torch.matmul(probs, value), probs

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        past_key_values=None,
        head_mask=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
    ):
        """
        Attend over `hidden_states`, appending the new keys/values to `past_key_values` when a cache is given.

        Returns `(attn_output, attn_weights)`. `use_cache` and `output_attentions` are accepted for interface
        compatibility but not read here.
        """
        q, k, v = self.att_proj(hidden_states).split(self.embed_dim, dim=2)
        q, k, v = (self._split_heads(part, self.num_heads, self.head_dim) for part in (q, k, v))

        if past_key_values is not None:
            k, v = past_key_values.update(k, v, self.layer_idx, {"cache_position": cache_position})

        context, weights = self._attn(q, k, v, attention_mask, head_mask)
        context = self._merge_heads(context, self.num_heads, self.head_dim)
        context = self.resid_dropout(self.out_proj(context))

        return context, weights
class BarkSelfFlashAttention2(BarkSelfAttention):
    """
    Bark flash attention module. This module inherits from `BarkSelfAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads.

        Unlike the eager module, no permutation is applied: flash attention consumes
        (batch, seq_length, num_heads, head_features) directly.
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        return tensor.view(new_shape)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden_size:
        (batch, seq_len, num_heads, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size)
        """
        return tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,))

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        past_key_values=None,
        head_mask=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
    ):
        """
        Flash-attention forward pass.

        Returns `(attn_output, None)`: flash attention does not materialize attention weights. `head_mask`,
        `use_cache` and `output_attentions` are accepted for interface compatibility but not read here.
        """
        # Unpacking also checks that `hidden_states` is (batch, seq_len, hidden_size).
        _, query_len, _ = hidden_states.size()

        query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2)

        # (batch, seq_length, num_heads, head_features)
        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if past_key_values is not None:
            # The cache is shared with the eager path, which hands it (batch, num_heads, seq_length, head_features)
            # tensors. DynamicCache concatenates along dim -2 and reports its length from that dim. Passing the
            # flash layout here would grow the cache along the heads axis and make `get_seq_length()` return
            # `num_heads`. Convert to the cache layout for the update, then back to the flash layout.
            key, value = past_key_values.update(
                key.transpose(1, 2), value.transpose(1, 2), self.layer_idx, {"cache_position": cache_position}
            )
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)

        attn_output = _flash_attention_forward(
            query,
            key,
            value,
            attention_mask,
            query_len,
            dropout=self.dropout if self.training else 0.0,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
        )

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        return attn_output, None
# Attention module used by each `BarkBlock`, keyed by `config._attn_implementation`.
BARK_ATTENTION_CLASSES = dict(
    eager=BarkSelfAttention,
    flash_attention_2=BarkSelfFlashAttention2,
)
class BarkMLP(nn.Module):
    """
    Position-wise feed-forward network of a Bark block:
    Linear(hidden -> 4 * hidden) -> GELU -> Linear(4 * hidden -> hidden) -> Dropout.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        expanded_width = 4 * width
        self.in_proj = nn.Linear(width, expanded_width, bias=config.bias)
        self.out_proj = nn.Linear(expanded_width, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)
        self.gelu = nn.GELU()

    def forward(self, hidden_states):
        """Apply the feed-forward transform to the last dimension of `hidden_states`."""
        activated = self.gelu(self.in_proj(hidden_states))
        return self.dropout(self.out_proj(activated))
class BarkBlock(GradientCheckpointingLayer):
    """
    Pre-LayerNorm transformer block used by the Bark sub-models:
    `h = x + attn(layernorm_1(x))`, then `out = h + mlp(layernorm_2(h))`.

    `forward` returns a tuple. Its first element is the block output, followed by every element the attention
    module returned after its attention output.
    """

    def __init__(self, config, is_causal=False, layer_idx=None):
        super().__init__()

        # Causal (autoregressive) blocks -- the "Text"/semantic and "Coarse" models -- follow Bark's choice of an
        # optional LayerNorm bias controlled by `config.bias`. Non-causal blocks keep LayerNorm's default bias.
        layernorm_kwargs = {"bias": config.bias} if is_causal else {}
        self.layernorm_1 = nn.LayerNorm(config.hidden_size, **layernorm_kwargs)
        self.layernorm_2 = nn.LayerNorm(config.hidden_size, **layernorm_kwargs)

        attention_class = BARK_ATTENTION_CLASSES[config._attn_implementation]
        self.attn = attention_class(config, is_causal=is_causal, layer_idx=layer_idx)

        self.mlp = BarkMLP(config)

    def forward(
        self,
        hidden_states,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
    ):
        normed = self.layernorm_1(hidden_states)
        attn_result = self.attn(
            normed,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        # The attention modules return (attn_output, attn_weights_or_None).
        attn_output, extra_outputs = attn_result[0], attn_result[1:]

        # Residual connections around the attention and the MLP sub-layers.
        hidden = hidden_states + attn_output
        hidden = hidden + self.mlp(self.layernorm_2(hidden))

        return (hidden, *extra_outputs)
@auto_docstring
class BarkPreTrainedModel(PreTrainedModel):
    config: BarkConfig
    supports_gradient_checkpointing = False
    _supports_flash_attn = True

    def _init_weights(self, module):
        """
        Initialize the weights of `module`.

        - `nn.Linear`: weight ~ N(0, config.initializer_range), bias (if any) set to 0.
        - `nn.Embedding`: weight ~ N(0, config.initializer_range), row `padding_idx` (if any) set to 0.
        - `nn.LayerNorm`: weight set to 1 and bias set to 0, each only when present.
        """
        if isinstance(module, (nn.Linear,)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # Causal Bark blocks and the final norm are built with `nn.LayerNorm(..., bias=config.bias)`.
            # With `bias=False` the `.bias` attribute is None (likewise `.weight` without elementwise affine),
            # so guard both instead of crashing with AttributeError.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    @property
    def device(self) -> torch.device:
        """
        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
        device).
        """

        # if has _hf_hook, has been offloaded so the device has to be found in the hook
        if not hasattr(self, "_hf_hook"):
            return get_parameter_device(self)
        for module in self.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)

        return get_parameter_device(self)
# GPT2-like autoregressive model
class BarkCausalModel(BarkPreTrainedModel, GenerationMixin):
    """
    GPT-2-like autoregressive Bark sub-model.

    It consists of token embeddings plus learned absolute position embeddings, a stack of causal `BarkBlock`s, a
    final LayerNorm and a bias-free LM head. It is the common base of `BarkSemanticModel` and `BarkCoarseModel`.
    """

    config: BarkSubModelConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # initialize as an autoregressive GPT-like model
        self.input_embeds_layer = nn.Embedding(config.input_vocab_size, config.hidden_size)
        # Learned absolute positions: position ids (past + current) must stay below `config.block_size`.
        self.position_embeds_layer = nn.Embedding(config.block_size, config.hidden_size)

        self.drop = nn.Dropout(config.dropout)

        self.layers = nn.ModuleList([BarkBlock(config, is_causal=True, layer_idx=i) for i in range(config.num_layers)])

        self.layernorm_final = nn.LayerNorm(config.hidden_size, bias=config.bias)

        self.lm_head = nn.Linear(config.hidden_size, config.output_vocab_size, bias=False)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        # NOTE: get_output_embeddings() must return None to prevent accidental weight tying.
        # See e.g. https://github.com/huggingface/transformers/pull/39339#discussion_r2219126400
        return None

    def get_input_embeddings(self):
        return self.input_embeds_layer

    def set_input_embeddings(self, new_embeddings):
        self.input_embeds_layer = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids,
        attention_mask=None,
        input_embeds=None,
        past_key_values=None,
        position_ids=None,
        use_cache=None,
        cache_position=None,
        **kwargs,
    ):
        # Overwritten -- bark uses `input_embeds` not `inputS_embeds`.
        # Delegate to the generic implementation under its `inputs_embeds` name, then rename the key back so
        # `forward` receives `input_embeds` (None when the generic implementation did not keep the embeddings).

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=input_embeds,
            past_key_values=past_key_values,
            position_ids=position_ids,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        model_inputs["input_embeds"] = model_inputs.pop("inputs_embeds", None)

        return model_inputs

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        input_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Union[tuple[torch.Tensor], CausalLMOutputWithPast]:
        r"""
        input_embeds (`torch.FloatTensor` of shape `(batch_size, input_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            `input_ids` and `input_embeds` are mutually exclusive: passing both raises a `ValueError`. When only
            `input_embeds` is given it is used as-is, with or without `past_key_values`. This lets the semantic
            model feed its pre-combined text and history embeddings on the first forward pass.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        loss = None
        if labels is not None:
            raise NotImplementedError(
                "Training is not implemented yet for Bark - ensure you do not pass `labels` to the model."
            )

        # Resolve the input embeddings: use `input_embeds` when given, otherwise embed `input_ids`.
        if input_ids is not None and input_embeds is not None:
            raise ValueError("You cannot specify both input_ids and input_embeds at the same time")
        elif input_embeds is not None and past_key_values is None:
            # we want to return the input_embeds in priority so that it is in line with a weird hack
            # of Bark which combines two bits of the input_embeds on the first forward pass of the semantic model
            pass
        elif input_ids is not None:
            input_embeds = self.input_embeds_layer(input_ids)  # token embeddings of shape (b, t, n_embd)
        elif input_embeds is not None:
            # `input_embeds` together with a cache: also used as-is.
            pass
        else:
            raise ValueError("You have to specify either input_ids or input_embeds")

        input_shape = input_embeds.size()[:-1]
        batch_size = input_embeds.shape[0]
        seq_length = input_shape[-1]

        device = input_ids.device if input_ids is not None else input_embeds.device

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)
        if use_cache and isinstance(past_key_values, tuple):
            logger.warning_once(
                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                "You should pass an instance of `DynamicCache` instead, e.g. "
                "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
            )
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        # NOTE(review): a legacy tuple passed with `use_cache=False` is not converted above, so the
        # `get_seq_length()` call below (and the per-layer `update`) would fail on it -- confirm whether that
        # combination needs support.
        past_length = past_key_values.get_seq_length() if past_key_values is not None else 0

        if position_ids is None:
            # Positions continue after the cached prefix.
            position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0)  # shape (1, seq_length)

        position_embeds = self.position_embeds_layer(position_ids)  # position embeddings of shape (1, t, n_embd)

        # Attention mask.
        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            if self.config._attn_implementation == "flash_attention_2":
                # Flash attention takes the 2D padding mask as-is; drop it entirely when nothing is padded.
                attention_mask = attention_mask if 0 in attention_mask else None
            else:
                attention_mask = attention_mask.view(batch_size, -1)
                # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
                # from_seq_length is 1 to easily broadcast
                attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x num_heads x N x N
        # head_mask has shape num_layers x batch x num_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)

        hidden_states = self.drop(input_embeds + position_embeds)
        output_shape = input_shape + (hidden_states.size(-1),)

        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for i, block in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # NOTE: the cache object is handed to every block whenever it is not None (even with
            # `use_cache=False`); the attention layers update it in place. `cache_position` is forwarded
            # unchanged and may be None.
            outputs = block(
                hidden_states,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                use_cache=use_cache,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )

            hidden_states = outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[1],)

        hidden_states = self.layernorm_final(hidden_states)

        hidden_states = hidden_states.view(output_shape)

        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        logits = self.lm_head(hidden_states)

        if not return_dict:
            # None entries (loss is always None here) are skipped, so tuple positions depend on the output flags.
            return tuple(
                v for v in [None, logits, past_key_values, all_hidden_states, all_self_attentions] if v is not None
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
@auto_docstring(
    custom_intro="""
    Bark semantic (or text) model. It shares the same architecture as the coarse model.
    It is a GPT-2 like autoregressive model with a language modeling head on top.
    """
)
class BarkSemanticModel(BarkCausalModel):
    base_model_prefix = "semantic"
    config: BarkSemanticConfig

    def generate(
        self,
        input_ids: torch.Tensor,
        semantic_generation_config: Optional[BarkSemanticGenerationConfig] = None,
        history_prompt: Optional[dict[str, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.LongTensor:
        """
        Generates text semantic tokens from an input prompt and an additional optional `Bark` speaker prompt.

        Args:
            input_ids (`torch.Tensor` of shape (batch_size, seq_len)):
                Input ids, i.e tokenized input sentences. Will be truncated up to
                semantic_generation_config.max_input_semantic_length tokens. Note that the output audios will be as
                long as the longest generation among the batch.
            semantic_generation_config (`BarkSemanticGenerationConfig`):
                Generation config indicating how to generate the semantic tokens. Required: a `ValueError` is
                raised when it is `None`.
            history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*):
                Optional `Bark` speaker prompt. Only its `"semantic_prompt"` entry is read here.
            attention_mask (`Optional[torch.Tensor]`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            kwargs:
                Forwarded to `GenerationMixin.generate`. If `min_eos_p` is present it overrides
                `semantic_generation_config.min_eos_p` for the EOS prioritizer.

        Returns:
            torch.LongTensor: Output semantic tokens, i.e. the generated ids that follow the
            `max_input_semantic_length + 1` prompt positions.
        """
        if semantic_generation_config is None:
            raise ValueError("`semantic_generation_config` has to be provided")

        batch_size = input_ids.shape[0]

        max_input_semantic_length = semantic_generation_config.max_input_semantic_length

        input_ids = input_ids + semantic_generation_config.text_encoding_offset

        if attention_mask is not None:
            # Replace masked (padding) positions with the dedicated text pad token.
            input_ids = input_ids.masked_fill((1 - attention_mask).bool(), semantic_generation_config.text_pad_token)

        if history_prompt is not None:
            semantic_history = history_prompt["semantic_prompt"][-max_input_semantic_length:]
            semantic_history = nn.functional.pad(
                semantic_history,
                (0, max_input_semantic_length - len(semantic_history)),
                value=semantic_generation_config.semantic_pad_token,
                mode="constant",
            )
            # Embedded alongside the model-device tensors below; move it like the `else` branch does.
            semantic_history = semantic_history.to(self.device)
        else:
            semantic_history = torch.tensor(
                [semantic_generation_config.semantic_pad_token] * max_input_semantic_length, dtype=torch.int
            ).to(self.device)

        semantic_history = torch.repeat_interleave(semantic_history[None], batch_size, dim=0)

        infer_array = torch.tensor(
            [[semantic_generation_config.semantic_infer_token]] * batch_size, dtype=torch.int
        ).to(self.device)

        # Bark's first-step input: (text embeddings + history embeddings) followed by one "infer" token embedding.
        input_embeds = torch.cat(
            [
                self.input_embeds_layer(input_ids[:, :max_input_semantic_length])
                + self.input_embeds_layer(semantic_history[:, : max_input_semantic_length + 1]),
                self.input_embeds_layer(infer_array),
            ],
            dim=1,
        )

        # Only allow semantic tokens and the semantic pad token (used as EOS) to be generated.
        tokens_to_suppress = list(
            range(semantic_generation_config.semantic_vocab_size, semantic_generation_config.semantic_pad_token)
        )
        tokens_to_suppress.extend(
            list(range(semantic_generation_config.semantic_pad_token + 1, self.config.output_vocab_size))
        )

        suppress_tokens_logits_processor = SuppressTokensLogitsProcessor(tokens_to_suppress, device=input_ids.device)

        min_eos_p = kwargs.get("min_eos_p", semantic_generation_config.min_eos_p)
        early_stopping_logits_processor = BarkEosPrioritizerLogitsProcessor(
            eos_token_id=semantic_generation_config.eos_token_id, min_eos_p=min_eos_p, device=input_ids.device
        )

        # pass input_ids in order to stay consistent with the transformers generate method even though it is not used
        # (except to get the input seq_len - that's why it has max_input_semantic_length + 1 positions, matching
        # `input_embeds`)
        semantic_output = super().generate(
            torch.ones((batch_size, max_input_semantic_length + 1), dtype=torch.int, device=self.device),
            input_embeds=input_embeds,
            logits_processor=[suppress_tokens_logits_processor, early_stopping_logits_processor],
            generation_config=semantic_generation_config,
            **kwargs,
        )

        # take the generated semantic tokens
        semantic_output = semantic_output[:, max_input_semantic_length + 1 :]

        return semantic_output
- @auto_docstring(
- custom_intro="""
- Bark coarse acoustics model.
- It shares the same architecture as the semantic (or text) model. It is a GPT-2 like autoregressive model with a
- language modeling head on top.
- """
- )
- class BarkCoarseModel(BarkCausalModel):
- base_model_prefix = "coarse_acoustics"
- config: BarkCoarseConfig
def preprocess_histories(
    self,
    max_coarse_history: int,
    semantic_to_coarse_ratio: int,
    batch_size: int,
    semantic_generation_config: "BarkSemanticGenerationConfig",
    codebook_size: int,
    history_prompt: Optional[dict[str, torch.Tensor]] = None,
):
    """
    Preprocess the optional `Bark` speaker prompts before `self.generate`.

    Args:
        max_coarse_history (`int`):
            Maximum size of coarse tokens used.
        semantic_to_coarse_ratio (`int`):
            Ratio of semantic to coarse frequency.
        batch_size (`int`):
            Batch size, i.e the number of samples.
        semantic_generation_config (`BarkSemanticGenerationConfig`):
            Generation config indicating how to generate the semantic tokens. Only its `semantic_vocab_size` is
            read here, to offset the coarse tokens.
        codebook_size (`int`):
            Codebook channel size, i.e. the size of the output vocabulary per codebook channel. When `None`, the
            per-codebook offset is skipped.
        history_prompt (`Optional[dict[str,torch.Tensor]]`):
            Optional `Bark` speaker prompt, read through its `"semantic_prompt"` and `"coarse_prompt"` entries.

    Returns:
        `tuple(torch.IntTensor)`:
            - **x_semantic_history** (`torch.IntTensor` of shape `(batch_size, n_semantic)`) -- Processed semantic
              speaker prompt.
            - **x_coarse_history** (`torch.IntTensor` of shape `(batch_size, n_coarse)`) -- Processed coarse
              speaker prompt.
        Both tensors are on `self.device`, and both have zero columns when no `history_prompt` is given or when
        the prompt is too short to keep anything.
    """
    if history_prompt is not None:
        x_semantic_history = torch.repeat_interleave(history_prompt["semantic_prompt"][None], batch_size, dim=0)
        # clone to avoid modifying history_prompt.coarse_prompt
        x_coarse_history = history_prompt["coarse_prompt"].clone()

        # offset codebook n by n * codebook_size so each codebook gets its own token range
        if codebook_size is not None:
            for n in range(1, x_coarse_history.shape[0]):
                x_coarse_history[n, :] += codebook_size * n

        # flatten x_coarse_history: (n_codebooks, seq) -> (seq * n_codebooks,), codebooks interleaved per step
        x_coarse_history = torch.transpose(x_coarse_history, 0, 1).reshape(-1)

        # e.g: after SEMANTIC_VOCAB_SIZE (10000), 1024 tokens dedicated to first codebook, 1024 next tokens
        # dedicated to second codebook.
        x_coarse_history = x_coarse_history + semantic_generation_config.semantic_vocab_size

        x_coarse_history = torch.repeat_interleave(x_coarse_history[None], batch_size, dim=0)

        max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))

        # trim histories correctly
        n_semantic_hist_provided = min(
            [
                max_semantic_history,
                x_semantic_history.shape[1] - x_semantic_history.shape[1] % 2,
                int(np.floor(x_coarse_history.shape[1] / semantic_to_coarse_ratio)),
            ]
        )

        n_coarse_hist_provided = int(round(n_semantic_hist_provided * semantic_to_coarse_ratio))

        # Keep the last n entries via an explicit start index: `t[:, -n:]` would keep the WHOLE tensor for n == 0.
        semantic_start = max(x_semantic_history.shape[1] - n_semantic_hist_provided, 0)
        coarse_start = max(x_coarse_history.shape[1] - n_coarse_hist_provided, 0)
        x_semantic_history = x_semantic_history[:, semantic_start:].int()
        x_coarse_history = x_coarse_history[:, coarse_start:].int()
        # bit of a hack for time alignment (sounds better) - from Bark original implementation
        x_coarse_history = x_coarse_history[:, :-2]

        # same device as the no-prompt branch below
        x_semantic_history = x_semantic_history.to(self.device)
        x_coarse_history = x_coarse_history.to(self.device)

    else:
        # shape: (batch_size, 0)
        x_semantic_history = torch.tensor([[]] * batch_size, dtype=torch.int, device=self.device)
        x_coarse_history = torch.tensor([[]] * batch_size, dtype=torch.int, device=self.device)

    return x_semantic_history, x_coarse_history
- def generate(
- self,
- semantic_output: torch.Tensor,
- semantic_generation_config: Optional[BarkSemanticGenerationConfig] = None,
- coarse_generation_config: Optional[BarkCoarseGenerationConfig] = None,
- codebook_size: int = 1024,
- history_prompt: Optional[dict[str, torch.Tensor]] = None,
- return_output_lengths: Optional[bool] = None,
- **kwargs,
- ) -> Union[torch.LongTensor, tuple[torch.LongTensor, torch.LongTensor]]:
- """
- Generates coarse acoustics tokens from input text semantic tokens and an additional optional `Bark` speaker
- prompt.
- Args:
- semantic_output (`torch.Tensor` of shape (batch_size, seq_len), *optional*):
- Input text semantic ids, i.e the output of `BarkSemanticModel.generate`.
- semantic_generation_config (`BarkSemanticGenerationConfig`):
- Generation config indicating how to generate the semantic tokens.
- coarse_generation_config (`BarkCoarseGenerationConfig`):
- Generation config indicating how to generate the coarse tokens.
- codebook_size (`int`, *optional*, defaults to 1024):
- Codebook channel size, i.e. the size of the output vocabulary per codebook channel.
- history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*):
- Optional `Bark` speaker prompt.
- return_output_lengths (`bool`, *optional*):
- Whether or not to return the output lengths. Useful when batching.
- Returns:
- By default:
- torch.LongTensor: Output coarse acoustics tokens.
- If `return_output_lengths=True`:
- `Tuple(torch.Tensor, torch.Tensor): The output coarse acoustics tokens, and the length of each sample
- of the batch.
- """
- if semantic_generation_config is None:
- raise ValueError("`semantic_generation_config` has to be provided")
- if coarse_generation_config is None:
- raise ValueError("`coarse_generation_config` has to be provided")
- max_coarse_input_length = coarse_generation_config.max_coarse_input_length
- max_coarse_history = coarse_generation_config.max_coarse_history
- sliding_window_len = coarse_generation_config.sliding_window_len
- # replace semantic_pad_token (eos_tok and pad_tok here) with coarse_semantic_pad_token i.e the pad_token
- # used in the next model
- semantic_output.masked_fill_(
- semantic_output == semantic_generation_config.semantic_pad_token,
- coarse_generation_config.coarse_semantic_pad_token,
- )
- semantic_to_coarse_ratio = (
- coarse_generation_config.coarse_rate_hz
- / semantic_generation_config.semantic_rate_hz
- * coarse_generation_config.n_coarse_codebooks
- )
- max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
- output_lengths = (semantic_output != coarse_generation_config.coarse_semantic_pad_token).sum(1)
- output_lengths = torch.floor(
- output_lengths * semantic_to_coarse_ratio / coarse_generation_config.n_coarse_codebooks
- )
- output_lengths = torch.round(output_lengths * coarse_generation_config.n_coarse_codebooks).int()
- max_generated_len = torch.max(output_lengths).item()
- batch_size = semantic_output.shape[0]
- x_semantic_history, x_coarse = self.preprocess_histories(
- history_prompt=history_prompt,
- max_coarse_history=max_coarse_history,
- semantic_to_coarse_ratio=semantic_to_coarse_ratio,
- batch_size=batch_size,
- semantic_generation_config=semantic_generation_config,
- codebook_size=codebook_size,
- )
- base_semantic_idx = x_semantic_history.shape[1]
- semantic_output = torch.hstack([x_semantic_history, semantic_output])
- n_window_steps = int(np.ceil(max_generated_len / sliding_window_len))
- total_generated_len = 0
- len_coarse_history = x_coarse.shape[1]
- for _ in range(n_window_steps):
- semantic_idx = base_semantic_idx + int(round(total_generated_len / semantic_to_coarse_ratio))
- # pad from right side
- input_coarse = semantic_output[:, np.max([0, semantic_idx - max_semantic_history]) :]
- input_coarse = input_coarse[:, :max_coarse_input_length]
- input_coarse = F.pad(
- input_coarse,
- (0, max_coarse_input_length - input_coarse.shape[-1]),
- "constant",
- coarse_generation_config.coarse_semantic_pad_token,
- )
- input_coarse = torch.hstack(
- [
- input_coarse,
- torch.tensor([[coarse_generation_config.coarse_infer_token]] * batch_size, device=self.device),
- x_coarse[:, -max_coarse_history:],
- ]
- )
- alternatingLogitsProcessor = AlternatingCodebooksLogitsProcessor(
- input_coarse.shape[1],
- semantic_generation_config.semantic_vocab_size,
- codebook_size,
- )
- output_coarse = super().generate(
- input_coarse,
- logits_processor=[alternatingLogitsProcessor],
- max_new_tokens=min(sliding_window_len, max_generated_len - total_generated_len),
- generation_config=coarse_generation_config,
- **kwargs,
- )
- input_coarse_len = input_coarse.shape[1]
- x_coarse = torch.hstack([x_coarse, output_coarse[:, input_coarse_len:]])
- total_generated_len = x_coarse.shape[1] - len_coarse_history
- del output_coarse
- coarse_output = x_coarse[:, len_coarse_history:]
- if return_output_lengths:
- return coarse_output, output_lengths
- return coarse_output
@auto_docstring(
    custom_intro="""
    Bark fine acoustics model. It is a non-causal GPT-like model with `config.n_codes_total` embedding layers and
    language modeling heads, one for each codebook.
    """
)
class BarkFineModel(BarkPreTrainedModel):
    base_model_prefix = "fine_acoustics"
    config: BarkFineConfig
    main_input_name = "codebook_idx"

    def __init__(self, config):
        # non-causal gpt-like model with one embedding layer and one lm_head for each codebook of Encodec
        super().__init__(config)
        self.config = config

        # initialize a modified non causal GPT-like model
        # there is one embedding layer per codebook (n_codes_total of them) and one lm_head per predicted
        # codebook (n_codes_total - n_codes_given of them)
        self.input_embeds_layers = nn.ModuleList(
            [nn.Embedding(config.input_vocab_size, config.hidden_size) for _ in range(config.n_codes_total)]
        )
        self.position_embeds_layer = nn.Embedding(config.block_size, config.hidden_size)

        self.drop = nn.Dropout(config.dropout)

        self.layers = nn.ModuleList(
            [BarkBlock(config, is_causal=False, layer_idx=i) for i in range(config.num_layers)]
        )

        self.layernorm_final = nn.LayerNorm(config.hidden_size)

        self.lm_heads = nn.ModuleList(
            [
                nn.Linear(config.hidden_size, config.output_vocab_size, bias=False)
                for _ in range(config.n_codes_given, config.n_codes_total)
            ]
        )
        self.gradient_checkpointing = False
        self.n_codes_total = config.n_codes_total

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # one embedding layers for each codebook
        return self.input_embeds_layers

    def set_input_embeddings(self, new_embeddings):
        # one embedding layers for each codebook
        self.input_embeds_layers = new_embeddings

    def get_output_embeddings(self):
        # one lm_head for each codebook
        return self.lm_heads

    def set_output_embeddings(self, new_output_embeddings):
        # one lm_head for each codebook
        self.lm_heads = new_output_embeddings

    def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean_resizing=True):
        """Resize every per-codebook input embedding (and, when untied, every lm_head) to `new_num_tokens`."""
        old_embeddings_list = self.get_input_embeddings()
        new_embeddings_list = nn.ModuleList(
            [
                self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of, mean_resizing)
                for old_embeddings in old_embeddings_list
            ]
        )
        self.set_input_embeddings(new_embeddings_list)
        # the actual size may differ from the request because of `pad_to_multiple_of`
        new_num_tokens = new_embeddings_list[0].weight.shape[0]

        # if word embeddings are not tied, make sure that lm head is resized as well
        if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
            old_lm_head_list = self.get_output_embeddings()
            new_lm_head_list = nn.ModuleList(
                [self._get_resized_lm_head(old_lm_head, new_num_tokens) for old_lm_head in old_lm_head_list]
            )
            self.set_output_embeddings(new_lm_head_list)

        return self.get_input_embeddings()

    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        mean_resizing: bool = True,
    ) -> nn.ModuleList:
        """
        Resizes the input token embedding matrices (one per codebook) of the model if
        `new_num_tokens != config.vocab_size`.

        Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

        Arguments:
            new_num_tokens (`int`, *optional*):
                The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
                vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
                returns the input token embedding modules of the model without doing anything.
            pad_to_multiple_of (`int`, *optional*):
                If set will pad the embedding matrix to a multiple of the provided value.

                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
                details about this, or help on choosing the correct value for resizing, refer to this guide:
                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
            mean_resizing (`bool`, *optional*, defaults to `True`):
                Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
                covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`.

                Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
                where the generated tokens' probabilities won't be affected by the added embeddings because initializing the new embeddings with the
                old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings.
                Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html

        Return:
            `torch.nn.ModuleList`: The list of input token embedding modules (one per codebook) of the model.
        """
        model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
        if new_num_tokens is None and pad_to_multiple_of is None:
            return model_embeds

        # Update base model and current model config
        self.config.output_vocab_size = model_embeds[0].weight.shape[0]
        self.config.vocab_size = model_embeds[0].weight.shape[0]
        self.output_vocab_size = model_embeds[0].weight.shape[0]
        self.vocab_size = model_embeds[0].weight.shape[0]

        # Tie weights again if needed
        self.tie_weights()

        return model_embeds

    def _tie_weights(self):
        """Tie (or clone) each lm_head with the input embedding of a following codebook."""
        if getattr(self.config, "tie_word_embeddings", True):
            self._tied_weights_keys = []
            output_embeddings = self.get_output_embeddings()
            input_embeddings = self.get_input_embeddings()

            for i in range(self.config.n_codes_total - self.config.n_codes_given):
                # NOTE(review): lm_heads[i] predicts codebook `i + n_codes_given` (see `forward`), so pairing it
                # with input_embeds_layers[i + 1] is consistent only when n_codes_given == 1.
                self._tie_or_clone_weights(output_embeddings[i], input_embeddings[i + 1])
                self._tied_weights_keys.append(f"lm_heads.{i}.weight")

    def tie_weights(self):
        """
        Tie the weights between the input embeddings list and the output embeddings list.

        If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
        weights instead.
        """
        for module in self.modules():
            if hasattr(module, "_tie_weights"):
                module._tie_weights()

    @auto_docstring
    def forward(
        self,
        codebook_idx: int,  # an additional idx corresponding to the id of the codebook that will be predicted
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        input_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
        r"""
        codebook_idx (`int`):
            Index of the codebook that will be predicted.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            NOT IMPLEMENTED YET.
        input_embeds (`torch.FloatTensor` of shape `(batch_size, input_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        loss = None
        if labels is not None:
            raise NotImplementedError("Training is not implemented yet")

        if codebook_idx == 0:
            raise ValueError("Cannot predict 0th codebook - 0th codebook should be predicted by the coarse model")

        if input_ids is not None and input_embeds is not None:
            raise ValueError("You cannot specify both input_ids and input_embeds at the same time")

        if input_ids is None and input_embeds is None:
            raise ValueError("You have to specify either input_ids or input_embeds")

        if input_ids is not None:
            # input_ids is indexed as (batch, seq_len, codebook): one embedding layer per codebook.
            # the input_embeddings are the sum of the j previous codebooks embeddings before
            # the current codebook_idx codebook
            input_embeds = [
                input_embeds_layer(input_ids[:, :, i]).unsqueeze(-1)
                for i, input_embeds_layer in enumerate(self.input_embeds_layers)
            ]  # token embeddings of shape (b, t, n_embd)
            input_embeds = torch.cat(input_embeds, dim=-1)
            input_embeds = input_embeds[:, :, :, : codebook_idx + 1].sum(dim=-1)

        input_shape = input_embeds.size()[:-1]
        batch_size = input_embeds.shape[0]
        seq_length = input_shape[1]

        device = input_ids.device if input_ids is not None else input_embeds.device

        if position_ids is None:
            position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0)  # shape (1, seq_length)

        position_embeds = self.position_embeds_layer(position_ids)  # position embeddings of shape (1, t, n_embd)

        # Attention mask.
        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            if self.config._attn_implementation == "flash_attention_2":
                # flash attention only needs the mask when some position is actually masked
                attention_mask = attention_mask if 0 in attention_mask else None
            else:
                # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
                # from_seq_length is 1 to easily broadcast
                attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)

        head_mask = self.get_head_mask(head_mask, self.config.num_layers)

        hidden_states = self.drop(input_embeds + position_embeds)
        output_shape = input_shape + (hidden_states.size(-1),)

        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for i, block in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            outputs = block(
                hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                output_attentions=output_attentions,
            )

            hidden_states = outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[1],)

        hidden_states = self.layernorm_final(hidden_states)
        hidden_states = hidden_states.view(output_shape)

        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # lm_heads[k] predicts codebook `k + n_codes_given`
        logits = self.lm_heads[codebook_idx - self.config.n_codes_given](hidden_states)

        if not return_dict:
            return tuple(v for v in [None, logits, all_hidden_states, all_self_attentions] if v is not None)

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        coarse_output: torch.Tensor,
        semantic_generation_config: Optional[BarkSemanticGenerationConfig] = None,
        coarse_generation_config: Optional[BarkCoarseGenerationConfig] = None,
        fine_generation_config: Optional[BarkFineGenerationConfig] = None,
        codebook_size: int = 1024,
        history_prompt: Optional[dict[str, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.LongTensor:
        """
        Generates fine acoustics tokens from input coarse acoustics tokens and an additional optional `Bark` speaker
        prompt.

        Args:
            coarse_output (`torch.Tensor` of shape (batch_size, n_coarse_codebooks * seq_len)):
                Input coarse acoustics ids, i.e the output of `BarkCoarseModel.generate`.
            semantic_generation_config (`BarkSemanticGenerationConfig`):
                Generation config indicating how to generate the semantic tokens.
            coarse_generation_config (`BarkCoarseGenerationConfig`):
                Generation config indicating how to generate the coarse tokens.
            fine_generation_config (`BarkFineGenerationConfig`):
                Generation config indicating how to generate the fine tokens.
            codebook_size (`int`, *optional*, defaults to 1024):
                Codebook channel size, i.e. the size of the output vocabulary per codebook channel.
            history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*):
                Optional `Bark` speaker prompt.
            kwargs:
                Only `temperature` is read; it overrides `fine_generation_config.temperature`.

        Returns:
            torch.LongTensor: Output fine acoustics tokens, of shape (batch_size, n_fine_codebooks, seq_len).

        Raises:
            ValueError: if one of the generation configs is missing, or if the output length does not match the input.
        """
        if semantic_generation_config is None:
            raise ValueError("`semantic_generation_config` has to be provided")

        if coarse_generation_config is None:
            raise ValueError("`coarse_generation_config` has to be provided")

        if fine_generation_config is None:
            raise ValueError("`fine_generation_config` has to be provided")

        # since we don't really use GenerationConfig through the fine model (autoencoder)
        # and since only temperature is used from the classic GenerationConfig parameters
        # manually impose the kwargs priority over the generation config
        temperature = kwargs.get("temperature", fine_generation_config.temperature)

        max_fine_history_length = fine_generation_config.max_fine_history_length
        max_fine_input_length = fine_generation_config.max_fine_input_length

        # shape: (batch, n_coarse_codebooks * seq_len)
        # new_shape: (batch, seq_len, n_coarse_codebooks)
        coarse_output = coarse_output.view(coarse_output.shape[0], -1, coarse_generation_config.n_coarse_codebooks)

        # brings ids into the range [0, codebook_size -1]
        coarse_output = torch.remainder(coarse_output - semantic_generation_config.semantic_vocab_size, codebook_size)
        batch_size = coarse_output.shape[0]

        if history_prompt is not None:
            # transpose to get to shape (seq_len, n_fine_codebooks), then repeat for every sample of the batch
            x_fine_history = torch.repeat_interleave(history_prompt["fine_prompt"].T[None], batch_size, dim=0)
        else:
            x_fine_history = None

        n_coarse = coarse_generation_config.n_coarse_codebooks

        # pad the remaining (n_fine_codebooks - n_coarse) codebook channels with the value `codebook_size`
        fine_input = F.pad(
            coarse_output,
            (0, fine_generation_config.n_fine_codebooks - n_coarse),
            "constant",
            codebook_size,
        )

        # prepend history if available (max max_fine_history_length)
        if x_fine_history is not None:
            trimmed_history = x_fine_history[:, -max_fine_history_length:, :]
            fine_input = torch.cat([trimmed_history, fine_input], dim=1)

            # len of the fine_history that has been added to fine_input
            n_history = trimmed_history.shape[1]
        else:
            n_history = 0

        n_remove_from_end = 0
        # need to pad if too short (since non-causal model)
        if fine_input.shape[1] < max_fine_input_length:
            n_remove_from_end = max_fine_input_length - fine_input.shape[1]
            fine_input = F.pad(fine_input, (0, 0, 0, n_remove_from_end), mode="constant", value=codebook_size)

        # we can be lazy about fractional loop and just keep overwriting codebooks.
        # seems that coarse_output.shape[1] - (max_fine_input_length - n_history) is equal to minus n_remove_from_end
        # So if we needed to pad because too short, n_loops is always 1 (because n_remove_from_end > 0)
        # If not, we loop over at least twice.
        n_loops = (coarse_output.shape[1] - (max_fine_input_length - n_history)) / max_fine_history_length
        n_loops = int(np.ceil(n_loops))
        n_loops = max(0, n_loops) + 1

        for n_outer in range(n_loops):
            start_idx = min([n_outer * max_fine_history_length, fine_input.shape[1] - max_fine_input_length])

            start_fill_idx = min(
                [n_history + n_outer * max_fine_history_length, fine_input.shape[1] - max_fine_history_length]
            )
            rel_start_fill_idx = start_fill_idx - start_idx
            # basic slicing: input_buffer is a view of fine_input
            input_buffer = fine_input[:, start_idx : start_idx + max_fine_input_length, :]
            for n_inner in range(n_coarse, fine_generation_config.n_fine_codebooks):
                logits = self.forward(n_inner, input_buffer).logits
                if temperature is None or temperature == 1.0:
                    relevant_logits = logits[:, rel_start_fill_idx:, :codebook_size]
                    codebook_preds = torch.argmax(relevant_logits, -1)
                else:
                    relevant_logits = logits[:, :, :codebook_size] / temperature
                    # apply softmax
                    probs = F.softmax(relevant_logits, dim=-1)[:, rel_start_fill_idx:max_fine_input_length]
                    # reshape to 2D: (batch_size, seq_len, codebook_size) -> (batch_size*seq_len, codebook_size)
                    probs = probs.reshape((-1, codebook_size))
                    # multinomial then reshape : (batch_size*seq_len)-> (batch_size,seq_len)
                    codebook_preds = torch.multinomial(probs, num_samples=1).view(batch_size, -1)
                codebook_preds = codebook_preds.to(torch.int32)
                input_buffer[:, rel_start_fill_idx:, n_inner] = codebook_preds
                del logits, codebook_preds

            # transfer into fine_input
            for n_inner in range(n_coarse, fine_generation_config.n_fine_codebooks):
                fine_input[
                    :, start_fill_idx : start_fill_idx + (max_fine_input_length - rel_start_fill_idx), n_inner
                ] = input_buffer[:, rel_start_fill_idx:, n_inner]
            del input_buffer

        # (batch, seq_len, n_fine_codebooks) -> (batch, n_fine_codebooks, seq_len), without the history
        fine_input = fine_input.transpose(1, 2)[:, :, n_history:]
        if n_remove_from_end > 0:
            fine_input = fine_input[:, :, :-n_remove_from_end]

        if fine_input.shape[-1] != coarse_output.shape[-2]:
            raise ValueError("input and output should have the same seq_len")

        return fine_input
- @auto_docstring(
- custom_intro="""
- The full Bark model, a text-to-speech model composed of 4 sub-models:
- - [`BarkSemanticModel`] (also referred to as the 'text' model): a causal auto-regressive transformer model that
- takes
- as input tokenized text, and predicts semantic text tokens that capture the meaning of the text.
- - [`BarkCoarseModel`] (also referred to as the 'coarse acoustics' model), also a causal autoregressive transformer,
- that takes into input the results of the last model. It aims at regressing the first two audio codebooks necessary
- to `encodec`.
- - [`BarkFineModel`] (the 'fine acoustics' model), this time a non-causal autoencoder transformer, which iteratively
- predicts the last codebooks based on the sum of the previous codebooks embeddings.
- - having predicted all the codebook channels from the [`EncodecModel`], Bark uses it to decode the output audio
- array.
- It should be noted that each of the first three modules can support conditional speaker embeddings to condition the
- output sound according to specific predefined voice.
- """
- )
- class BarkModel(BarkPreTrainedModel):
- config: BarkConfig
- def __init__(self, config):
- super().__init__(config)
- self.semantic = BarkSemanticModel(config.semantic_config)
- self.coarse_acoustics = BarkCoarseModel(config.coarse_acoustics_config)
- self.fine_acoustics = BarkFineModel(config.fine_acoustics_config)
- self.codec_model = AutoModel.from_config(config.codec_config)
- self.config = config
- @classmethod
- def can_generate(cls) -> bool:
- # Bark has a unique model structure, where the external class (`BarkModel`) doesn't need to inherit from
- # `GenerationMixin` (it has a non-standard generation method), but one of the internal models do
- # (`BarkSemanticModel`). This means that the base `can_generate()` will return `False`, but we need to
- # override it so as to do `GenerationConfig` handling in multiple parts of the codebase.
- return True
- @property
- def device(self) -> torch.device:
- """
- `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
- device).
- """
- # for bark_model, device must be verified on its sub-models
- # if has _hf_hook, has been offloaded so the device has to be found in the hook
- if not hasattr(self.semantic, "_hf_hook"):
- return get_parameter_device(self)
- for module in self.semantic.modules():
- if (
- hasattr(module, "_hf_hook")
- and hasattr(module._hf_hook, "execution_device")
- and module._hf_hook.execution_device is not None
- ):
- return torch.device(module._hf_hook.execution_device)
- def enable_cpu_offload(
- self,
- accelerator_id: Optional[int] = 0,
- **kwargs,
- ):
- r"""
- Offloads all sub-models to CPU using accelerate, reducing memory usage with a low impact on performance. This
- method moves one whole sub-model at a time to the accelerator when it is used, and the sub-model remains in accelerator until the next sub-model runs.
- Args:
- accelerator_id (`int`, *optional*, defaults to 0):
- accelerator id on which the sub-models will be loaded and offloaded. This argument is deprecated.
- kwargs (`dict`, *optional*):
- additional keyword arguments:
- `gpu_id`: accelerator id on which the sub-models will be loaded and offloaded.
- """
- if is_accelerate_available():
- from accelerate import cpu_offload_with_hook
- else:
- raise ImportError("`enable_model_cpu_offload` requires `accelerate`.")
- gpu_id = kwargs.get("gpu_id", 0)
- if gpu_id != 0:
- warnings.warn(
- "The argument `gpu_id` is deprecated and will be removed in version 4.54.0 of Transformers. Please use `accelerator_id` instead.",
- FutureWarning,
- )
- accelerator_id = gpu_id
- device_type = "cuda"
- if is_torch_accelerator_available():
- device_type = torch.accelerator.current_accelerator().type
- device = torch.device(f"{device_type}:{accelerator_id}")
- torch_accelerator_module = getattr(torch, device_type)
- if self.device.type != "cpu":
- self.to("cpu")
- torch_accelerator_module.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
- # this layer is used outside the first forward pass of semantic so need to be loaded before semantic
- self.semantic.input_embeds_layer, _ = cpu_offload_with_hook(self.semantic.input_embeds_layer, device)
- hook = None
- for cpu_offloaded_model in [
- self.semantic,
- self.coarse_acoustics,
- self.fine_acoustics,
- ]:
- _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
- self.fine_acoustics_hook = hook
- _, hook = cpu_offload_with_hook(self.codec_model, device, prev_module_hook=hook)
- # We'll offload the last model manually.
- self.codec_model_hook = hook
- def codec_decode(self, fine_output, output_lengths=None):
- """Turn quantized audio codes into audio array using encodec."""
- fine_output = fine_output.transpose(0, 1)
- emb = self.codec_model.quantizer.decode(fine_output)
- if output_lengths is not None:
- # encodec uses LSTMs which behaves differently with appended padding
- # decoding with encodec takes around 0.1% of the total generation time
- # to keep generation quality, we break batching
- out = [sample[:, :l].unsqueeze(0) for (sample, l) in zip(emb, output_lengths)]
- audio_arr = [self.codec_model.decoder(sample).squeeze() for sample in out]
- else:
- out = self.codec_model.decoder(emb)
- audio_arr = out.squeeze(1) # squeeze the codebook dimension
- return audio_arr
- @torch.no_grad()
- def generate(
- self,
- input_ids: Optional[torch.Tensor] = None,
- history_prompt: Optional[dict[str, torch.Tensor]] = None,
- return_output_lengths: Optional[bool] = None,
- **kwargs,
- ) -> torch.LongTensor:
- """
- Generates audio from an input prompt and an additional optional `Bark` speaker prompt.
- Args:
- input_ids (`Optional[torch.Tensor]` of shape (batch_size, seq_len), *optional*):
- Input ids. Will be truncated up to 256 tokens. Note that the output audios will be as long as the
- longest generation among the batch.
- history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*):
- Optional `Bark` speaker prompt. Note that for now, this model takes only one speaker prompt per batch.
- kwargs (*optional*): Remaining dictionary of keyword arguments. Keyword arguments are of two types:
- - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model.
- - With a *semantic_*, *coarse_*, *fine_* prefix, they will be input for the `generate` method of the
- semantic, coarse and fine respectively. It has the priority over the keywords without a prefix.
- This means you can, for example, specify a generation strategy for all sub-models except one.
- return_output_lengths (`bool`, *optional*):
- Whether or not to return the waveform lengths. Useful when batching.
- Returns:
- By default:
- - **audio_waveform** (`torch.Tensor` of shape (batch_size, seq_len)): Generated audio waveform.
- When `return_output_lengths=True`:
- Returns a tuple made of:
- - **audio_waveform** (`torch.Tensor` of shape (batch_size, seq_len)): Generated audio waveform.
- - **output_lengths** (`torch.Tensor` of shape (batch_size)): The length of each waveform in the batch
- Example:
- ```python
- >>> from transformers import AutoProcessor, BarkModel
- >>> processor = AutoProcessor.from_pretrained("suno/bark-small")
- >>> model = BarkModel.from_pretrained("suno/bark-small")
- >>> # To add a voice preset, you can pass `voice_preset` to `BarkProcessor.__call__(...)`
- >>> voice_preset = "v2/en_speaker_6"
- >>> inputs = processor("Hello, my dog is cute, I need him in my life", voice_preset=voice_preset)
- >>> audio_array = model.generate(**inputs, semantic_max_new_tokens=100)
- >>> audio_array = audio_array.cpu().numpy().squeeze()
- ```
- """
- # TODO (joao):workaround until nested generation config is compatible with PreTrained Model
- # todo: dict
- semantic_generation_config = BarkSemanticGenerationConfig(**self.generation_config.semantic_config)
- coarse_generation_config = BarkCoarseGenerationConfig(**self.generation_config.coarse_acoustics_config)
- fine_generation_config = BarkFineGenerationConfig(**self.generation_config.fine_acoustics_config)
- kwargs_semantic = {
- # if "attention_mask" is set, it should not be passed to CoarseModel and FineModel
- "attention_mask": kwargs.pop("attention_mask", None),
- "min_eos_p": kwargs.pop("min_eos_p", None),
- }
- kwargs_coarse = {}
- kwargs_fine = {}
- for key, value in kwargs.items():
- if key.startswith("semantic_"):
- key = key[len("semantic_") :]
- kwargs_semantic[key] = value
- elif key.startswith("coarse_"):
- key = key[len("coarse_") :]
- kwargs_coarse[key] = value
- elif key.startswith("fine_"):
- key = key[len("fine_") :]
- kwargs_fine[key] = value
- else:
- # If the key is already in a specific config, then it's been set with a
- # submodules specific value and we don't override
- if key not in kwargs_semantic:
- kwargs_semantic[key] = value
- if key not in kwargs_coarse:
- kwargs_coarse[key] = value
- if key not in kwargs_fine:
- kwargs_fine[key] = value
- # 1. Generate from the semantic model
- if "generation_config" in kwargs_semantic:
- kwargs_semantic.pop("generation_config")
- semantic_output = self.semantic.generate(
- input_ids,
- history_prompt=history_prompt,
- semantic_generation_config=semantic_generation_config,
- **kwargs_semantic,
- )
- # 2. Generate from the coarse model
- if "generation_config" in kwargs_coarse:
- kwargs_coarse.pop("generation_config")
- coarse_output = self.coarse_acoustics.generate(
- semantic_output,
- history_prompt=history_prompt,
- semantic_generation_config=semantic_generation_config,
- coarse_generation_config=coarse_generation_config,
- codebook_size=self.generation_config.codebook_size,
- return_output_lengths=return_output_lengths,
- **kwargs_coarse,
- )
- output_lengths = None
- if return_output_lengths:
- coarse_output, output_lengths = coarse_output
- # (batch_size, seq_len*coarse_codebooks) -> (batch_size, seq_len)
- output_lengths = output_lengths // coarse_generation_config.n_coarse_codebooks
- # 3. "generate" from the fine model
- if "generation_config" in kwargs_fine:
- kwargs_fine.pop("generation_config")
- output = self.fine_acoustics.generate(
- coarse_output,
- history_prompt=history_prompt,
- semantic_generation_config=semantic_generation_config,
- coarse_generation_config=coarse_generation_config,
- fine_generation_config=fine_generation_config,
- codebook_size=self.generation_config.codebook_size,
- **kwargs_fine,
- )
- if getattr(self, "fine_acoustics_hook", None) is not None:
- # Manually offload fine_acoustics to CPU
- # and load codec_model to GPU
- # since bark doesn't use codec_model forward pass
- self.fine_acoustics_hook.offload()
- self.codec_model = self.codec_model.to(self.device)
- # 4. Decode the output and generate audio array
- audio = self.codec_decode(output, output_lengths)
- if getattr(self, "codec_model_hook", None) is not None:
- # Offload codec_model to CPU
- self.codec_model_hook.offload()
- if return_output_lengths:
- output_lengths = [len(sample) for sample in audio]
- audio = nn.utils.rnn.pad_sequence(audio, batch_first=True, padding_value=0)
- return audio, output_lengths
- return audio
def tie_weights(self):
    """
    Tie input/output embedding weights by delegating to each sub-module's ``_tie_weights`` hook.

    Every module yielded by ``self.modules()`` is inspected in iteration order. Modules that define a
    ``_tie_weights`` attribute have it called immediately; all others are skipped.

    Note: this method performs no tying or cloning of its own. Any torchscript-specific handling
    (cloning parameters instead of sharing them) is presumably done inside the individual
    ``_tie_weights`` hooks. TODO confirm against the sub-model implementations.
    """
    # Lazy generator: each hook runs as soon as its module is reached, exactly as in a plain loop.
    tieable_modules = (mod for mod in self.modules() if hasattr(mod, "_tie_weights"))
    for tieable in tieable_modules:
        tieable._tie_weights()
# Public names exported by this module (what `from <module> import *` exposes).
__all__ = [
    "BarkFineModel", "BarkSemanticModel", "BarkCoarseModel",
    "BarkModel", "BarkPreTrainedModel", "BarkCausalModel",
]
|