modeling_clap.py 81 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929
  1. # coding=utf-8
  2. # Copyright 2023 The LAION-AI Team and The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch CLAP model."""
  16. import collections
  17. import math
  18. from dataclasses import dataclass
  19. from typing import Any, Callable, Optional, Union
  20. import torch
  21. import torch.nn.functional as F
  22. from torch import nn
  23. from ...activations import ACT2FN
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import (
  26. BaseModelOutput,
  27. BaseModelOutputWithPooling,
  28. BaseModelOutputWithPoolingAndCrossAttentions,
  29. )
  30. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  31. from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
  32. from ...utils import ModelOutput, auto_docstring, can_return_tuple, filter_out_non_signature_kwargs, logging, torch_int
  33. from .configuration_clap import ClapAudioConfig, ClapConfig, ClapTextConfig
  34. logger = logging.get_logger(__name__)
  35. # Adapted from: https://github.com/LAION-AI/CLAP/blob/6ad05a971ba0622f6acee8c41993e0d02bbed639/src/open_clip/utils.py#L191
  36. def interpolate(hidden_states, ratio):
  37. """
  38. Interpolate data in time domain. This is used to compensate the resolution reduction in downsampling of a CNN.
  39. Args:
  40. hidden_states (`torch.FloatTensor` of shape (batch_size, time_length, classes_num)):
  41. Input hidden states
  42. ratio (`int`):
  43. The ratio of the length of the output to the length of the input.
  44. """
  45. (batch_size, time_length, classes_num) = hidden_states.shape
  46. upsampled = hidden_states[:, :, None, :].repeat(1, 1, ratio, 1)
  47. upsampled = upsampled.reshape(batch_size, time_length * ratio, classes_num)
  48. return upsampled
  49. # Adapted from https://github.com/LAION-AI/CLAP/blob/6ad05a971ba0622f6acee8c41993e0d02bbed639/src/open_clip/htsat.py#L249
  50. def window_partition(hidden_states, window_size):
  51. """
  52. Returns the resized hidden states. The output shape should be `(batch_size * num_windows, window_size, window_size,
  53. num_channels)`
  54. Args:
  55. hidden_states (`torch.FloatTensor` of shape `(batch_size, height, width, num_channels)`):
  56. Input hidden states
  57. window_size (`int`):
  58. Window size
  59. """
  60. batch_size, height, width, num_channels = hidden_states.shape
  61. hidden_states = hidden_states.view(
  62. batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
  63. )
  64. windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
  65. return windows
  66. # Adapted from https://github.com/LAION-AI/CLAP/blob/6ad05a971ba0622f6acee8c41993e0d02bbed639/src/open_clip/htsat.py#L263
  67. def window_reverse(windows, window_size, height, width):
  68. """
  69. Merges windows to produce higher resolution features.
  70. Args:
  71. windows (`torch.FloatTensor` of shape `(num_windows * batch_size, window_size, window_size, num_channels)`):
  72. Input windows
  73. window_size (`int`):
  74. Window size
  75. height (`int`):
  76. Height of the resized audio
  77. width (`int`):
  78. Width of the resized audio
  79. """
  80. num_channels = windows.shape[-1]
  81. windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
  82. windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
  83. return windows
  84. # Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
  85. def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
  86. """
  87. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  88. are ignored. This is modified from fairseq's `utils.make_positions`.
  89. Args:
  90. x: torch.Tensor x:
  91. Returns: torch.Tensor
  92. """
  93. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  94. mask = input_ids.ne(padding_idx).int()
  95. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
  96. return incremental_indices.long() + padding_idx
  97. # contrastive loss function, adapted from
  98. # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html#CLIP-loss-function
  99. def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
  100. labels = torch.arange(len(logits), device=logits.device)
  101. return nn.functional.cross_entropy(logits, labels)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for text model's outputs that also contains a pooling of the last hidden states.
    """
)
# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Clap
class ClapTextModelOutput(ModelOutput):
    r"""
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
        The text embeddings obtained by applying the projection layer to the pooler_output.
    """

    # Every field is optional and defaults to `None`.
    # NOTE(review): the field order is presumably significant for tuple conversion in `ModelOutput`
    # (defined outside this file) — do not reorder without checking.
    text_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    # One tensor per entry when present; what each entry corresponds to is decided by the producing model
    # (not visible in this block).
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring(
    custom_intro="""
    ClapAudio model output to mimic the output of the original implementation.
    """
)
class ClapAudioModelOutput(ModelOutput):
    r"""
    audio_embeds (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
        The Audio embeddings obtained by applying the projection layer to the pooler_output.
    """

    # Every field is optional and defaults to `None`.
    # NOTE(review): the field order is presumably significant for tuple conversion in `ModelOutput`
    # (defined outside this file) — do not reorder without checking.
    audio_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    # One tensor per entry when present; what each entry corresponds to is decided by the producing model
    # (not visible in this block).
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  133. @dataclass
  134. @auto_docstring
  135. # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Clap, vision->audio, Vision->Audio, image->audio
  136. class ClapOutput(ModelOutput):
  137. r"""
  138. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
  139. Contrastive loss for audio-text similarity.
  140. logits_per_audio (`torch.FloatTensor` of shape `(audio_batch_size, text_batch_size)`):
  141. The scaled dot product scores between `audio_embeds` and `text_embeds`. This represents the audio-text
  142. similarity scores.
  143. logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, audio_batch_size)`):
  144. The scaled dot product scores between `text_embeds` and `audio_embeds`. This represents the text-audio
  145. similarity scores.
  146. text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  147. The text embeddings obtained by applying the projection layer to the pooled output of [`ClapTextModel`].
  148. audio_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  149. The audio embeddings obtained by applying the projection layer to the pooled output of [`ClapAudioModel`].
  150. text_model_output (`BaseModelOutputWithPooling`):
  151. The output of the [`ClapTextModel`].
  152. audio_model_output (`BaseModelOutputWithPooling`):
  153. The output of the [`ClapAudioModel`].
  154. """
  155. loss: Optional[torch.FloatTensor] = None
  156. logits_per_audio: Optional[torch.FloatTensor] = None
  157. logits_per_text: Optional[torch.FloatTensor] = None
  158. text_embeds: Optional[torch.FloatTensor] = None
  159. audio_embeds: Optional[torch.FloatTensor] = None
  160. text_model_output: BaseModelOutputWithPooling = None
  161. audio_model_output: BaseModelOutputWithPooling = None
  162. def to_tuple(self) -> tuple[Any]:
  163. return tuple(
  164. self[k] if k not in ["text_model_output", "audio_model_output"] else getattr(self, k).to_tuple()
  165. for k in self.keys()
  166. )
  167. # Adapted from transformers.models.swin.modeling_swin.SwinDropPath
  168. class ClapDropPath(nn.Module):
  169. """
  170. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is a slightly
  171. refactored version of the `SwinDropPath` implementation.
  172. """
  173. def __init__(self, drop_prob=None):
  174. super().__init__()
  175. self.drop_prob = drop_prob
  176. def forward(self, hidden_states):
  177. if self.drop_prob == 0.0 or not self.training:
  178. return hidden_states
  179. keep_prob = 1 - self.drop_prob
  180. # work with diff dim tensors, not just 2D ConvNets
  181. shape = (hidden_states.shape[0],) + (1,) * (hidden_states.ndim - 1)
  182. random_tensor = keep_prob + torch.rand(shape, dtype=hidden_states.dtype, device=hidden_states.device)
  183. random_tensor.floor_() # binarize
  184. output = hidden_states.div(keep_prob) * random_tensor
  185. return output
  186. # Adapted from https://github.com/LAION-AI/CLAP/blob/6ad05a971ba0622f6acee8c41993e0d02bbed639/src/open_clip/feature_fusion.py#L133
  187. class ClapAudioAFFBlock(nn.Module):
  188. r"""
  189. ATTENTIONAL FEATURE FUSION Block from CLAP, since in CLAP we are always in 2D mode, it is not needed to implement
  190. the 1D version.
  191. """
  192. def __init__(self, config: ClapAudioConfig):
  193. super().__init__()
  194. channels = config.patch_embeds_hidden_size
  195. downsize_ratio = config.aff_block_r
  196. inter_channels = int(channels // downsize_ratio)
  197. self.local_att = nn.Sequential(
  198. nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
  199. nn.BatchNorm2d(inter_channels),
  200. nn.ReLU(inplace=True),
  201. nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
  202. nn.BatchNorm2d(channels),
  203. )
  204. self.global_att = nn.Sequential(
  205. nn.AdaptiveAvgPool2d(1),
  206. nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
  207. nn.BatchNorm2d(inter_channels),
  208. nn.ReLU(inplace=True),
  209. nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
  210. nn.BatchNorm2d(channels),
  211. )
  212. self.sigmoid = nn.Sigmoid()
  213. def forward(self, hidden_states, residual):
  214. attention_input = hidden_states + residual
  215. fused_layer_output = self.local_att(attention_input) + self.global_att(attention_input)
  216. fused_layer_output = self.sigmoid(fused_layer_output)
  217. output = 2 * hidden_states * fused_layer_output + 2 * residual * (1 - fused_layer_output)
  218. return output
  219. class ClapAudioPatchEmbed(nn.Module):
  220. """
  221. This module converts the hidden states reshaped as an image to patch embeddings ready to be passed to the
  222. Transformer block.
  223. """
  224. def __init__(self, config: ClapAudioConfig):
  225. super().__init__()
  226. img_size = (config.spec_size, config.spec_size) if isinstance(config.spec_size, int) else config.spec_size
  227. patch_size = (
  228. (config.patch_size, config.patch_size) if isinstance(config.patch_size, int) else config.patch_size
  229. )
  230. patch_stride = (
  231. (config.patch_stride, config.patch_stride) if isinstance(config.patch_stride, int) else config.patch_stride
  232. )
  233. self.img_size = img_size
  234. self.patch_stride = patch_stride
  235. self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1])
  236. self.num_patches = self.grid_size[0] * self.grid_size[1]
  237. self.flatten = config.flatten_patch_embeds
  238. self.enable_fusion = config.enable_fusion
  239. padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2)
  240. scale_factor = 4 if (self.enable_fusion) and (config.fusion_type == "channel_map") else 1
  241. self.proj = nn.Conv2d(
  242. config.patch_embed_input_channels * scale_factor,
  243. config.patch_embeds_hidden_size,
  244. kernel_size=patch_size,
  245. stride=patch_stride,
  246. padding=padding,
  247. )
  248. self.norm = nn.LayerNorm(config.patch_embeds_hidden_size) if config.enable_patch_layer_norm else nn.Identity()
  249. if self.enable_fusion:
  250. self.fusion_model = ClapAudioAFFBlock(config)
  251. self.mel_conv2d = nn.Conv2d(
  252. config.patch_embed_input_channels,
  253. config.patch_embeds_hidden_size,
  254. kernel_size=(patch_size[0], patch_size[1] * 3),
  255. stride=(patch_stride[0], patch_stride[1] * 3),
  256. padding=padding,
  257. )
  258. def forward(self, hidden_states, is_longer_idx=None):
  259. if self.enable_fusion:
  260. # retrieve the last mel as we have transposed the input
  261. global_hidden_states = hidden_states[:, 0:1, :, :]
  262. # global processing
  263. batch_size, num_channels, height, width = global_hidden_states.shape
  264. if height != self.img_size[0] or width != self.img_size[1]:
  265. raise ValueError(
  266. f"Input audio size ({height}*{width}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
  267. )
  268. global_hidden_states = self.proj(global_hidden_states)
  269. output_width = global_hidden_states.size(-1)
  270. if len(is_longer_idx) > 0:
  271. # local processing
  272. local_hidden_states = hidden_states[is_longer_idx, 1:, :, :].contiguous()
  273. batch_size, num_channels, height, width = local_hidden_states.shape
  274. local_hidden_states = local_hidden_states.view(batch_size * num_channels, 1, height, width)
  275. local_hidden_states = self.mel_conv2d(local_hidden_states)
  276. _, features, height, width = local_hidden_states.shape
  277. local_hidden_states = local_hidden_states.view(batch_size, num_channels, features, height, width)
  278. local_hidden_states = local_hidden_states.permute((0, 2, 3, 1, 4)).contiguous().flatten(3)
  279. local_width = local_hidden_states.size(-1)
  280. local_hidden_states = torch.nn.functional.pad(
  281. local_hidden_states, (0, output_width - local_width), "constant", 0
  282. )
  283. global_hidden_states[is_longer_idx] = self.fusion_model(
  284. global_hidden_states[is_longer_idx], local_hidden_states
  285. )
  286. hidden_states = global_hidden_states
  287. else:
  288. _, _, height, width = hidden_states.shape
  289. if height != self.img_size[0] or width != self.img_size[1]:
  290. raise ValueError(
  291. f"Input audio size ({height}*{width}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
  292. )
  293. hidden_states = self.proj(hidden_states)
  294. if self.flatten:
  295. hidden_states = hidden_states.flatten(2).transpose(1, 2)
  296. hidden_states = self.norm(hidden_states)
  297. return hidden_states
  298. # Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->ClapAudio
  299. class ClapAudioSelfAttention(nn.Module):
  300. def __init__(self, config, dim, num_heads, window_size):
  301. super().__init__()
  302. if dim % num_heads != 0:
  303. raise ValueError(
  304. f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
  305. )
  306. self.num_attention_heads = num_heads
  307. self.attention_head_size = int(dim / num_heads)
  308. self.all_head_size = self.num_attention_heads * self.attention_head_size
  309. self.window_size = (
  310. window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
  311. )
  312. self.relative_position_bias_table = nn.Parameter(
  313. torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
  314. )
  315. # get pair-wise relative position index for each token inside the window
  316. coords_h = torch.arange(self.window_size[0])
  317. coords_w = torch.arange(self.window_size[1])
  318. coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
  319. coords_flatten = torch.flatten(coords, 1)
  320. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
  321. relative_coords = relative_coords.permute(1, 2, 0).contiguous()
  322. relative_coords[:, :, 0] += self.window_size[0] - 1
  323. relative_coords[:, :, 1] += self.window_size[1] - 1
  324. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
  325. relative_position_index = relative_coords.sum(-1)
  326. self.register_buffer("relative_position_index", relative_position_index)
  327. self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  328. self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  329. self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  330. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  331. def forward(
  332. self,
  333. hidden_states: torch.Tensor,
  334. attention_mask: Optional[torch.FloatTensor] = None,
  335. head_mask: Optional[torch.FloatTensor] = None,
  336. output_attentions: Optional[bool] = False,
  337. ) -> tuple[torch.Tensor]:
  338. batch_size, dim, num_channels = hidden_states.shape
  339. hidden_shape = (batch_size, dim, -1, self.attention_head_size)
  340. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  341. key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  342. value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  343. # Take the dot product between "query" and "key" to get the raw attention scores.
  344. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  345. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  346. relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
  347. relative_position_bias = relative_position_bias.view(
  348. self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
  349. )
  350. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
  351. attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
  352. if attention_mask is not None:
  353. # Apply the attention mask is (precomputed for all layers in ClapAudioModel forward() function)
  354. mask_shape = attention_mask.shape[0]
  355. attention_scores = attention_scores.view(
  356. batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
  357. )
  358. attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
  359. attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
  360. # Normalize the attention scores to probabilities.
  361. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  362. # This is actually dropping out entire tokens to attend to, which might
  363. # seem a bit unusual, but is taken from the original Transformer paper.
  364. attention_probs = self.dropout(attention_probs)
  365. # Mask heads if we want to
  366. if head_mask is not None:
  367. attention_probs = attention_probs * head_mask
  368. context_layer = torch.matmul(attention_probs, value_layer)
  369. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  370. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  371. context_layer = context_layer.view(new_context_layer_shape)
  372. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  373. return outputs
  374. # Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->ClapAudio
  375. class ClapAudioSelfOutput(nn.Module):
  376. def __init__(self, config, dim):
  377. super().__init__()
  378. self.dense = nn.Linear(dim, dim)
  379. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  380. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  381. hidden_states = self.dense(hidden_states)
  382. hidden_states = self.dropout(hidden_states)
  383. return hidden_states
  384. # Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->ClapAudio
  385. class ClapAudioAttention(nn.Module):
  386. def __init__(self, config, dim, num_heads, window_size):
  387. super().__init__()
  388. self.self = ClapAudioSelfAttention(config, dim, num_heads, window_size)
  389. self.output = ClapAudioSelfOutput(config, dim)
  390. self.pruned_heads = set()
  391. def prune_heads(self, heads):
  392. if len(heads) == 0:
  393. return
  394. heads, index = find_pruneable_heads_and_indices(
  395. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  396. )
  397. # Prune linear layers
  398. self.self.query = prune_linear_layer(self.self.query, index)
  399. self.self.key = prune_linear_layer(self.self.key, index)
  400. self.self.value = prune_linear_layer(self.self.value, index)
  401. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  402. # Update hyper params and store pruned heads
  403. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  404. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  405. self.pruned_heads = self.pruned_heads.union(heads)
  406. def forward(
  407. self,
  408. hidden_states: torch.Tensor,
  409. attention_mask: Optional[torch.FloatTensor] = None,
  410. head_mask: Optional[torch.FloatTensor] = None,
  411. output_attentions: Optional[bool] = False,
  412. ) -> tuple[torch.Tensor]:
  413. self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
  414. attention_output = self.output(self_outputs[0], hidden_states)
  415. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  416. return outputs
  417. # Copied from transformers.models.swin.modeling_swin.SwinIntermediate with Swin->ClapAudio
  418. class ClapAudioIntermediate(nn.Module):
  419. def __init__(self, config, dim):
  420. super().__init__()
  421. self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
  422. if isinstance(config.hidden_act, str):
  423. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  424. else:
  425. self.intermediate_act_fn = config.hidden_act
  426. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  427. hidden_states = self.dense(hidden_states)
  428. hidden_states = self.intermediate_act_fn(hidden_states)
  429. return hidden_states
  430. # Copied from transformers.models.swin.modeling_swin.SwinOutput with Swin->ClapAudio
  431. class ClapAudioOutput(nn.Module):
  432. def __init__(self, config, dim):
  433. super().__init__()
  434. self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
  435. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  436. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  437. hidden_states = self.dense(hidden_states)
  438. hidden_states = self.dropout(hidden_states)
  439. return hidden_states
  440. # Copied from transformers.models.swin.modeling_swin.SwinLayer with SwinDropPath->ClapDropPath, Swin->ClapAudio
  441. class ClapAudioLayer(nn.Module):
    def __init__(self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0):
        """
        Build the sub-modules of one layer: pre-attention layer norm, windowed attention, optional drop path,
        post-attention layer norm and the two halves of the feed-forward block.

        Args:
            config: Configuration object; this constructor reads `chunk_size_feed_forward`, `window_size` and
                `layer_norm_eps` from it and hands it to the sub-modules.
            dim: Channel size used by the layer norms, the attention and the feed-forward block.
            input_resolution: Stored as-is on the instance.
            num_heads: Number of attention heads.
            drop_path_rate: Drop-path probability; values `<= 0.0` install an `nn.Identity` instead.
            shift_size: Stored as-is; a value `> 0` makes `get_attn_mask` build a mask.
        """
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.shift_size = shift_size
        # NOTE: `window_size` and `shift_size` may later be overwritten by `set_shift_and_window_size`.
        self.window_size = config.window_size
        self.input_resolution = input_resolution
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attention = ClapAudioAttention(config, dim, num_heads, window_size=self.window_size)
        # Skip the drop-path module entirely when it would never drop anything.
        self.drop_path = ClapDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.intermediate = ClapAudioIntermediate(config, dim)
        self.output = ClapAudioOutput(config, dim)
  454. def set_shift_and_window_size(self, input_resolution):
  455. if min(input_resolution) <= self.window_size:
  456. # if window size is larger than input resolution, we don't partition windows
  457. self.shift_size = torch_int(0)
  458. self.window_size = (
  459. torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
  460. )
  461. def get_attn_mask(self, height, width, dtype, device):
  462. if self.shift_size > 0:
  463. # calculate attention mask for SW-MSA
  464. img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
  465. height_slices = (
  466. slice(0, -self.window_size),
  467. slice(-self.window_size, -self.shift_size),
  468. slice(-self.shift_size, None),
  469. )
  470. width_slices = (
  471. slice(0, -self.window_size),
  472. slice(-self.window_size, -self.shift_size),
  473. slice(-self.shift_size, None),
  474. )
  475. count = 0
  476. for height_slice in height_slices:
  477. for width_slice in width_slices:
  478. img_mask[:, height_slice, width_slice, :] = count
  479. count += 1
  480. mask_windows = window_partition(img_mask, self.window_size)
  481. mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
  482. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  483. attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
  484. else:
  485. attn_mask = None
  486. return attn_mask
  487. def maybe_pad(self, hidden_states, height, width):
  488. pad_right = (self.window_size - width % self.window_size) % self.window_size
  489. pad_bottom = (self.window_size - height % self.window_size) % self.window_size
  490. pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
  491. hidden_states = nn.functional.pad(hidden_states, pad_values)
  492. return hidden_states, pad_values
  493. def forward(
  494. self,
  495. hidden_states: torch.Tensor,
  496. input_dimensions: tuple[int, int],
  497. head_mask: Optional[torch.FloatTensor] = None,
  498. output_attentions: Optional[bool] = False,
  499. always_partition: Optional[bool] = False,
  500. ) -> tuple[torch.Tensor, torch.Tensor]:
  501. if not always_partition:
  502. self.set_shift_and_window_size(input_dimensions)
  503. else:
  504. pass
  505. height, width = input_dimensions
  506. batch_size, _, channels = hidden_states.size()
  507. shortcut = hidden_states
  508. hidden_states = self.layernorm_before(hidden_states)
  509. hidden_states = hidden_states.view(batch_size, height, width, channels)
  510. # pad hidden_states to multiples of window size
  511. hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
  512. _, height_pad, width_pad, _ = hidden_states.shape
  513. # cyclic shift
  514. if self.shift_size > 0:
  515. shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
  516. else:
  517. shifted_hidden_states = hidden_states
  518. # partition windows
  519. hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
  520. hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
  521. attn_mask = self.get_attn_mask(
  522. height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
  523. )
  524. attention_outputs = self.attention(
  525. hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
  526. )
  527. attention_output = attention_outputs[0]
  528. attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
  529. shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)
  530. # reverse cyclic shift
  531. if self.shift_size > 0:
  532. attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
  533. else:
  534. attention_windows = shifted_windows
  535. was_padded = pad_values[3] > 0 or pad_values[5] > 0
  536. if was_padded:
  537. attention_windows = attention_windows[:, :height, :width, :].contiguous()
  538. attention_windows = attention_windows.view(batch_size, height * width, channels)
  539. hidden_states = shortcut + self.drop_path(attention_windows)
  540. layer_output = self.layernorm_after(hidden_states)
  541. layer_output = self.intermediate(layer_output)
  542. layer_output = hidden_states + self.output(layer_output)
  543. layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
  544. return layer_outputs
  545. # Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->ClapAudio
  546. class ClapAudioStage(GradientCheckpointingLayer):
  547. def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
  548. super().__init__()
  549. self.config = config
  550. self.dim = dim
  551. self.blocks = nn.ModuleList(
  552. [
  553. ClapAudioLayer(
  554. config=config,
  555. dim=dim,
  556. input_resolution=input_resolution,
  557. num_heads=num_heads,
  558. drop_path_rate=drop_path[i],
  559. shift_size=0 if (i % 2 == 0) else config.window_size // 2,
  560. )
  561. for i in range(depth)
  562. ]
  563. )
  564. # patch merging layer
  565. if downsample is not None:
  566. self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
  567. else:
  568. self.downsample = None
  569. self.pointing = False
  570. def forward(
  571. self,
  572. hidden_states: torch.Tensor,
  573. input_dimensions: tuple[int, int],
  574. head_mask: Optional[torch.FloatTensor] = None,
  575. output_attentions: Optional[bool] = False,
  576. always_partition: Optional[bool] = False,
  577. ) -> tuple[torch.Tensor]:
  578. height, width = input_dimensions
  579. for i, layer_module in enumerate(self.blocks):
  580. layer_head_mask = head_mask[i] if head_mask is not None else None
  581. layer_outputs = layer_module(
  582. hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
  583. )
  584. hidden_states = layer_outputs[0]
  585. hidden_states_before_downsampling = hidden_states
  586. if self.downsample is not None:
  587. height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
  588. output_dimensions = (height, width, height_downsampled, width_downsampled)
  589. hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
  590. else:
  591. output_dimensions = (height, width, height, width)
  592. stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)
  593. if output_attentions:
  594. stage_outputs += layer_outputs[1:]
  595. return stage_outputs
  596. # Copied from transformers.models.swin.modeling_swin.SwinPatchMerging with Swin->ClapAudio
  597. class ClapAudioPatchMerging(nn.Module):
  598. """
  599. Patch Merging Layer.
  600. Args:
  601. input_resolution (`tuple[int]`):
  602. Resolution of input feature.
  603. dim (`int`):
  604. Number of input channels.
  605. norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
  606. Normalization layer class.
  607. """
  608. def __init__(self, input_resolution: tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
  609. super().__init__()
  610. self.input_resolution = input_resolution
  611. self.dim = dim
  612. self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
  613. self.norm = norm_layer(4 * dim)
  614. def maybe_pad(self, input_feature, height, width):
  615. should_pad = (height % 2 == 1) or (width % 2 == 1)
  616. if should_pad:
  617. pad_values = (0, 0, 0, width % 2, 0, height % 2)
  618. input_feature = nn.functional.pad(input_feature, pad_values)
  619. return input_feature
  620. def forward(self, input_feature: torch.Tensor, input_dimensions: tuple[int, int]) -> torch.Tensor:
  621. height, width = input_dimensions
  622. # `dim` is height * width
  623. batch_size, dim, num_channels = input_feature.shape
  624. input_feature = input_feature.view(batch_size, height, width, num_channels)
  625. # pad input to be divisible by width and height, if needed
  626. input_feature = self.maybe_pad(input_feature, height, width)
  627. # [batch_size, height/2, width/2, num_channels]
  628. input_feature_0 = input_feature[:, 0::2, 0::2, :]
  629. # [batch_size, height/2, width/2, num_channels]
  630. input_feature_1 = input_feature[:, 1::2, 0::2, :]
  631. # [batch_size, height/2, width/2, num_channels]
  632. input_feature_2 = input_feature[:, 0::2, 1::2, :]
  633. # [batch_size, height/2, width/2, num_channels]
  634. input_feature_3 = input_feature[:, 1::2, 1::2, :]
  635. # batch_size height/2 width/2 4*num_channels
  636. input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
  637. input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C
  638. input_feature = self.norm(input_feature)
  639. input_feature = self.reduction(input_feature)
  640. return input_feature
  641. class ClapAudioEncoder(nn.Module):
  642. def __init__(self, config):
  643. super().__init__()
  644. self.num_layers = len(config.depths)
  645. self.config = config
  646. self.patch_embed = ClapAudioPatchEmbed(config)
  647. self.enable_fusion = config.enable_fusion
  648. self.patch_stride = self.patch_embed.patch_stride
  649. self.spec_size = config.spec_size
  650. self.freq_ratio = config.spec_size // config.num_mel_bins
  651. self.num_features = int(config.patch_embeds_hidden_size * 2 ** (self.num_layers - 1))
  652. drop_path_rate = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
  653. grid_size = self.patch_embed.grid_size
  654. self.input_resolutions = [(grid_size[0] // (2**i), grid_size[1] // (2**i)) for i in range(self.num_layers)]
  655. self.layers = nn.ModuleList(
  656. [
  657. ClapAudioStage(
  658. config=config,
  659. dim=int(config.patch_embeds_hidden_size * 2**i_layer),
  660. input_resolution=self.input_resolutions[i_layer],
  661. depth=config.depths[i_layer],
  662. num_heads=config.num_attention_heads[i_layer],
  663. drop_path=drop_path_rate[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
  664. downsample=ClapAudioPatchMerging if (i_layer < self.num_layers - 1) else None,
  665. )
  666. for i_layer in range(self.num_layers)
  667. ]
  668. )
  669. self.gradient_checkpointing = False
  670. self.batch_norm = nn.BatchNorm2d(config.num_mel_bins)
  671. self.norm = nn.LayerNorm(self.num_features)
  672. self.depths = config.depths
  673. self.avgpool = nn.AdaptiveAvgPool1d(1)
  674. def reshape_mel2img(self, normalized_input_features):
  675. """
  676. The input is 4 normalized log mel spectrograms. It is reshape to the common shape of images. Each channel
  677. should represent 1 of the 4 crops of the spectrogram. For more details, refer to the [`ClapFeatureExtractor`].
  678. """
  679. _, _, time_length, freq_length = normalized_input_features.shape
  680. spec_width = int(self.spec_size * self.freq_ratio)
  681. spec_height = self.spec_size // self.freq_ratio
  682. if time_length > spec_width or freq_length > spec_height:
  683. raise ValueError("the wav size should be less than or equal to the swin input size")
  684. # to avoid bicubic zero error
  685. if time_length < spec_width:
  686. normalized_input_features = nn.functional.interpolate(
  687. normalized_input_features, (spec_width, freq_length), mode="bicubic", align_corners=True
  688. )
  689. if freq_length < spec_height:
  690. normalized_input_features = nn.functional.interpolate(
  691. normalized_input_features, (time_length, spec_height), mode="bicubic", align_corners=True
  692. )
  693. batch, channels, time, freq = normalized_input_features.shape
  694. # batch_size, channels, spec_width, spec_height --> batch_size, channels, spec_height * freq_ratio, spec_width // freq_ratio
  695. normalized_input_features = normalized_input_features.reshape(
  696. batch, channels * self.freq_ratio, time // self.freq_ratio, freq
  697. )
  698. normalized_input_features = normalized_input_features.permute(0, 1, 3, 2).contiguous()
  699. normalized_input_features = normalized_input_features.reshape(
  700. batch, channels, freq * self.freq_ratio, time // self.freq_ratio
  701. )
  702. return normalized_input_features
  703. def forward(
  704. self,
  705. input_features,
  706. is_longer: Optional[torch.FloatTensor] = None,
  707. head_mask: Optional[torch.FloatTensor] = None,
  708. output_attentions: Optional[bool] = False,
  709. output_hidden_states: Optional[bool] = False,
  710. output_hidden_states_before_downsampling: Optional[bool] = False,
  711. always_partition: Optional[bool] = False,
  712. return_dict: Optional[bool] = True,
  713. ) -> Union[tuple, ClapAudioModelOutput]:
  714. input_features = input_features.transpose(1, 3)
  715. normalized_input_features = self.batch_norm(input_features)
  716. normalized_input_features = normalized_input_features.transpose(1, 3)
  717. is_longer_list_idx = None
  718. if self.enable_fusion:
  719. is_longer_list = is_longer.to(input_features.device)
  720. is_longer_list_idx = torch.where(is_longer_list == 1)[0]
  721. hidden_states = self.reshape_mel2img(normalized_input_features)
  722. frames_num = hidden_states.shape[2]
  723. hidden_states = self.patch_embed(hidden_states, is_longer_list_idx)
  724. all_hidden_states = () if output_hidden_states else None
  725. all_reshaped_hidden_states = () if output_hidden_states else None
  726. all_self_attentions = () if output_attentions else None
  727. input_dimensions = self.input_resolutions[0]
  728. if output_hidden_states:
  729. batch_size, _, hidden_size = hidden_states.shape
  730. # rearrange batch_size (height width) channels -> batch_size channel height width
  731. reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
  732. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  733. all_hidden_states += (hidden_states,)
  734. all_reshaped_hidden_states += (reshaped_hidden_state,)
  735. for i, layer_module in enumerate(self.layers):
  736. layer_head_mask = head_mask[i] if head_mask is not None else None
  737. input_dimensions = self.input_resolutions[i]
  738. layer_outputs = layer_module(
  739. hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
  740. )
  741. hidden_states = layer_outputs[0]
  742. hidden_states_before_downsampling = layer_outputs[1]
  743. output_dimensions = layer_outputs[2]
  744. input_dimensions = (output_dimensions[-2], output_dimensions[-1])
  745. if output_hidden_states and output_hidden_states_before_downsampling:
  746. batch_size, _, hidden_size = hidden_states_before_downsampling.shape
  747. # rearrange batch_size (height width) channels -> batch_size channel height width
  748. # here we use the original (not downsampled) height and width
  749. reshaped_hidden_state = hidden_states_before_downsampling.view(
  750. batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
  751. )
  752. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  753. all_hidden_states += (hidden_states_before_downsampling,)
  754. all_reshaped_hidden_states += (reshaped_hidden_state,)
  755. elif output_hidden_states and not output_hidden_states_before_downsampling:
  756. batch_size, _, hidden_size = hidden_states.shape
  757. # rearrange batch_size (height width) channels -> batch_size channel height width
  758. reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
  759. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  760. all_hidden_states += (hidden_states,)
  761. all_reshaped_hidden_states += (reshaped_hidden_state,)
  762. if output_attentions:
  763. all_self_attentions += layer_outputs[3:]
  764. last_hidden_state = self.norm(hidden_states)
  765. batch_size, _, n_channels = last_hidden_state.shape
  766. freq_shape = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
  767. temporal_shape = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
  768. last_hidden_state = (
  769. last_hidden_state.permute(0, 2, 1).contiguous().reshape(batch_size, n_channels, freq_shape, temporal_shape)
  770. )
  771. batch_size, n_channels, n_frequencies, n_temp = last_hidden_state.shape
  772. # group 2D CNN
  773. c_freq_bin = n_frequencies // self.freq_ratio
  774. last_hidden_state = last_hidden_state.reshape(
  775. batch_size, n_channels, n_frequencies // c_freq_bin, c_freq_bin, n_temp
  776. )
  777. last_hidden_state = (
  778. last_hidden_state.permute(0, 1, 3, 2, 4).contiguous().reshape(batch_size, n_channels, c_freq_bin, -1)
  779. )
  780. latent_output = self.avgpool(torch.flatten(last_hidden_state, 2))
  781. latent_output = torch.flatten(latent_output, 1)
  782. if not return_dict:
  783. return tuple(
  784. v
  785. for v in [
  786. last_hidden_state,
  787. latent_output,
  788. all_reshaped_hidden_states,
  789. all_self_attentions,
  790. ]
  791. if v is not None
  792. )
  793. return BaseModelOutputWithPooling(
  794. last_hidden_state=last_hidden_state,
  795. pooler_output=latent_output,
  796. hidden_states=all_reshaped_hidden_states,
  797. attentions=all_self_attentions,
  798. )
  799. class ClapProjectionLayer(nn.Module):
  800. def __init__(self, config: Union[ClapAudioConfig, ClapTextConfig]):
  801. super().__init__()
  802. self.config = config
  803. hidden_size = config.hidden_size
  804. projection_dim = config.projection_dim
  805. self.linear1 = nn.Linear(hidden_size, projection_dim)
  806. self.activation = ACT2FN[config.projection_hidden_act]
  807. self.linear2 = nn.Linear(projection_dim, projection_dim)
  808. def forward(self, hidden_states):
  809. hidden_states = self.linear1(hidden_states)
  810. hidden_states = self.activation(hidden_states)
  811. hidden_states = self.linear2(hidden_states)
  812. return hidden_states
  813. # Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->ClapText, persistent=False->persistent=True
  814. class ClapTextEmbeddings(nn.Module):
  815. """
  816. Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
  817. """
  818. # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
  819. def __init__(self, config):
  820. super().__init__()
  821. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  822. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  823. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  824. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  825. # any TensorFlow checkpoint file
  826. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  827. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  828. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  829. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  830. self.register_buffer(
  831. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=True
  832. )
  833. self.register_buffer(
  834. "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=True
  835. )
  836. # End copy
  837. self.padding_idx = config.pad_token_id
  838. self.position_embeddings = nn.Embedding(
  839. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  840. )
  841. def forward(
  842. self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
  843. ):
  844. if position_ids is None:
  845. if input_ids is not None:
  846. # Create the position ids from the input token ids. Any padded tokens remain padded.
  847. position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
  848. else:
  849. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
  850. if input_ids is not None:
  851. input_shape = input_ids.size()
  852. else:
  853. input_shape = inputs_embeds.size()[:-1]
  854. seq_length = input_shape[1]
  855. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  856. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  857. # issue #5664
  858. if token_type_ids is None:
  859. if hasattr(self, "token_type_ids"):
  860. buffered_token_type_ids = self.token_type_ids[:, :seq_length]
  861. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
  862. token_type_ids = buffered_token_type_ids_expanded
  863. else:
  864. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  865. if inputs_embeds is None:
  866. inputs_embeds = self.word_embeddings(input_ids)
  867. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  868. embeddings = inputs_embeds + token_type_embeddings
  869. if self.position_embedding_type == "absolute":
  870. position_embeddings = self.position_embeddings(position_ids)
  871. embeddings += position_embeddings
  872. embeddings = self.LayerNorm(embeddings)
  873. embeddings = self.dropout(embeddings)
  874. return embeddings
  875. def create_position_ids_from_inputs_embeds(self, inputs_embeds):
  876. """
  877. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  878. Args:
  879. inputs_embeds: torch.Tensor
  880. Returns: torch.Tensor
  881. """
  882. input_shape = inputs_embeds.size()[:-1]
  883. sequence_length = input_shape[1]
  884. position_ids = torch.arange(
  885. self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  886. )
  887. return position_ids.unsqueeze(0).expand(input_shape)
  888. # Copied from transformers.models.align.modeling_align.eager_attention_forward
  889. def eager_attention_forward(
  890. module: nn.Module,
  891. query: torch.Tensor,
  892. key: torch.Tensor,
  893. value: torch.Tensor,
  894. attention_mask: Optional[torch.Tensor],
  895. scaling: float,
  896. dropout: float = 0.0,
  897. head_mask: Optional[torch.Tensor] = None,
  898. **kwargs,
  899. ):
  900. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  901. if attention_mask is not None:
  902. causal_mask = attention_mask[:, :, :, : key.shape[-2]]
  903. attn_weights = attn_weights + causal_mask
  904. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  905. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  906. if head_mask is not None:
  907. attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
  908. attn_output = torch.matmul(attn_weights, value)
  909. attn_output = attn_output.transpose(1, 2).contiguous()
  910. return attn_output, attn_weights
  911. # Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with Align->Clap
  912. class ClapTextSelfAttention(nn.Module):
  913. def __init__(self, config):
  914. super().__init__()
  915. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  916. raise ValueError(
  917. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  918. f"heads ({config.num_attention_heads})"
  919. )
  920. self.config = config
  921. self.num_attention_heads = config.num_attention_heads
  922. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  923. self.all_head_size = self.num_attention_heads * self.attention_head_size
  924. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  925. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  926. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  927. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  928. self.attention_dropout = config.attention_probs_dropout_prob
  929. self.scaling = self.attention_head_size**-0.5
  930. def forward(
  931. self,
  932. hidden_states: torch.Tensor,
  933. attention_mask: Optional[torch.FloatTensor] = None,
  934. head_mask: Optional[torch.FloatTensor] = None,
  935. output_attentions: Optional[bool] = False,
  936. **kwargs,
  937. ) -> tuple[torch.Tensor]:
  938. input_shape = hidden_states.shape[:-1]
  939. hidden_shape = (*input_shape, -1, self.attention_head_size)
  940. query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  941. key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  942. value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  943. attention_interface: Callable = eager_attention_forward
  944. if self.config._attn_implementation != "eager":
  945. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  946. attn_output, attn_weights = attention_interface(
  947. self,
  948. query_states,
  949. key_states,
  950. value_states,
  951. attention_mask,
  952. dropout=0.0 if not self.training else self.attention_dropout,
  953. scaling=self.scaling,
  954. head_mask=head_mask,
  955. **kwargs,
  956. )
  957. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  958. outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
  959. return outputs
  960. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
  961. class ClapTextSelfOutput(nn.Module):
  962. def __init__(self, config):
  963. super().__init__()
  964. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  965. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  966. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  967. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  968. hidden_states = self.dense(hidden_states)
  969. hidden_states = self.dropout(hidden_states)
  970. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  971. return hidden_states
  972. # Copied from transformers.models.align.modeling_align.AlignTextAttention with Align->Clap
  973. class ClapTextAttention(nn.Module):
  974. def __init__(self, config):
  975. super().__init__()
  976. self.self = ClapTextSelfAttention(config)
  977. self.output = ClapTextSelfOutput(config)
  978. self.pruned_heads = set()
  979. def prune_heads(self, heads):
  980. if len(heads) == 0:
  981. return
  982. heads, index = find_pruneable_heads_and_indices(
  983. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  984. )
  985. # Prune linear layers
  986. self.self.query = prune_linear_layer(self.self.query, index)
  987. self.self.key = prune_linear_layer(self.self.key, index)
  988. self.self.value = prune_linear_layer(self.self.value, index)
  989. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  990. # Update hyper params and store pruned heads
  991. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  992. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  993. self.pruned_heads = self.pruned_heads.union(heads)
  994. def forward(
  995. self,
  996. hidden_states: torch.Tensor,
  997. attention_mask: Optional[torch.FloatTensor] = None,
  998. head_mask: Optional[torch.FloatTensor] = None,
  999. output_attentions: Optional[bool] = False,
  1000. **kwargs,
  1001. ) -> tuple[torch.Tensor]:
  1002. self_outputs = self.self(
  1003. hidden_states,
  1004. attention_mask=attention_mask,
  1005. head_mask=head_mask,
  1006. output_attentions=output_attentions,
  1007. **kwargs,
  1008. )
  1009. attention_output = self.output(self_outputs[0], hidden_states)
  1010. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  1011. return outputs
  1012. # Copied from transformers.models.bert.modeling_bert.BertIntermediate
  1013. class ClapTextIntermediate(nn.Module):
  1014. def __init__(self, config):
  1015. super().__init__()
  1016. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  1017. if isinstance(config.hidden_act, str):
  1018. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  1019. else:
  1020. self.intermediate_act_fn = config.hidden_act
  1021. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  1022. hidden_states = self.dense(hidden_states)
  1023. hidden_states = self.intermediate_act_fn(hidden_states)
  1024. return hidden_states
  1025. # Copied from transformers.models.bert.modeling_bert.BertOutput
  1026. class ClapTextOutput(nn.Module):
  1027. def __init__(self, config):
  1028. super().__init__()
  1029. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  1030. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  1031. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  1032. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  1033. hidden_states = self.dense(hidden_states)
  1034. hidden_states = self.dropout(hidden_states)
  1035. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  1036. return hidden_states
  1037. # Copied from transformers.models.align.modeling_align.AlignTextLayer with Align->Clap
  1038. class ClapTextLayer(GradientCheckpointingLayer):
  1039. def __init__(self, config):
  1040. super().__init__()
  1041. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  1042. self.seq_len_dim = 1
  1043. self.attention = ClapTextAttention(config)
  1044. self.intermediate = ClapTextIntermediate(config)
  1045. self.output = ClapTextOutput(config)
  1046. def forward(
  1047. self,
  1048. hidden_states: torch.Tensor,
  1049. attention_mask: Optional[torch.FloatTensor] = None,
  1050. head_mask: Optional[torch.FloatTensor] = None,
  1051. output_attentions: Optional[bool] = False,
  1052. **kwargs,
  1053. ) -> tuple[torch.Tensor]:
  1054. self_attention_outputs = self.attention(
  1055. hidden_states,
  1056. attention_mask=attention_mask,
  1057. head_mask=head_mask,
  1058. output_attentions=output_attentions,
  1059. **kwargs,
  1060. )
  1061. attention_output = self_attention_outputs[0]
  1062. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  1063. layer_output = apply_chunking_to_forward(
  1064. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  1065. )
  1066. outputs = (layer_output,) + outputs
  1067. return outputs
  1068. def feed_forward_chunk(self, attention_output):
  1069. intermediate_output = self.intermediate(attention_output)
  1070. layer_output = self.output(intermediate_output, attention_output)
  1071. return layer_output
  1072. # Copied from transformers.models.align.modeling_align.AlignTextEncoder with Align->Clap
  class ClapTextEncoder(nn.Module):
      """Sequential stack of `config.num_hidden_layers` `ClapTextLayer` modules."""

      def __init__(self, config):
          super().__init__()
          self.config = config
          blocks = [ClapTextLayer(config) for _ in range(config.num_hidden_layers)]
          self.layer = nn.ModuleList(blocks)
          self.gradient_checkpointing = False

      @can_return_tuple
      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.FloatTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          output_attentions: Optional[bool] = False,
          output_hidden_states: Optional[bool] = False,
          return_dict: Optional[bool] = True,
          **kwargs,
      ) -> Union[tuple[torch.Tensor], BaseModelOutput]:
          """Run `hidden_states` through every layer in order.

          When `output_hidden_states` is truthy, the input of every layer plus the final output are collected.
          When `output_attentions` is truthy, element 1 of each layer's output is collected. `head_mask`, if
          given, is indexed per layer. `return_dict` is not read inside this method.
          """
          collected_states = () if output_hidden_states else None
          collected_attentions = () if output_attentions else None

          for layer_index, block in enumerate(self.layer):
              if output_hidden_states:
                  # Record the state going *into* this layer.
                  collected_states += (hidden_states,)

              per_layer_mask = None if head_mask is None else head_mask[layer_index]
              block_result = block(
                  hidden_states=hidden_states,
                  attention_mask=attention_mask,
                  head_mask=per_layer_mask,
                  output_attentions=output_attentions,
                  **kwargs,
              )
              hidden_states = block_result[0]

              if output_attentions:
                  collected_attentions += (block_result[1],)

          if output_hidden_states:
              # Append the output of the last layer as well.
              collected_states += (hidden_states,)

          return BaseModelOutput(
              last_hidden_state=hidden_states,
              hidden_states=collected_states,
              attentions=collected_attentions,
          )
  1113. # Copied from transformers.models.bert.modeling_bert.BertPooler
  class ClapTextPooler(nn.Module):
      """Pool a sequence by projecting the hidden state of its first token through a dense layer and tanh."""

      def __init__(self, config):
          super().__init__()
          width = config.hidden_size
          self.dense = nn.Linear(width, width)
          self.activation = nn.Tanh()

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          # "Pooling" here just means selecting index 0 along the second dimension
          # (the first token) and transforming it.
          leading_token = hidden_states[:, 0]
          return self.activation(self.dense(leading_token))
  @auto_docstring
  class ClapPreTrainedModel(PreTrainedModel):
      config: ClapConfig
      base_model_prefix = "clap"
      supports_gradient_checkpointing = False

      def _init_weights(self, module: nn.Module):
          """Initialize the weights of `module` according to its type.

          The first matching type in the chain below decides which scheme is applied; modules matching none of
          the listed types are left untouched.
          """
          scale = self.config.initializer_factor

          if isinstance(module, ClapTextEmbeddings):
              # Position table first, then token-type table.
              for table in (module.position_embeddings, module.token_type_embeddings):
                  table.weight.data.normal_(mean=0.0, std=scale * 0.02)
          elif isinstance(module, ClapModel):
              initial_log_scale = math.log(self.config.logit_scale_init_value)
              module.logit_scale_a.data.fill_(initial_log_scale)
              module.logit_scale_t.data.fill_(initial_log_scale)
          elif isinstance(module, nn.Embedding):
              module.weight.data.normal_(mean=0.0, std=scale * 0.02)
          elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
              module.bias.data.zero_()
              module.weight.data.fill_(1.0)
          elif isinstance(module, (nn.Conv2d, nn.Linear)):
              # std = hidden_size^-0.5 * (2 * num_hidden_layers)^-0.5 * initializer_factor
              width_term = self.config.hidden_size**-0.5
              depth_term = (2 * self.config.num_hidden_layers) ** -0.5
              nn.init.normal_(module.weight, std=width_term * depth_term * scale)
              if module.bias is not None:
                  module.bias.data.zero_()
          elif isinstance(module, ClapAudioSelfAttention):
              module.relative_position_bias_table.data.zero_()
  class ClapAudioModel(ClapPreTrainedModel):
      # Thin wrapper that owns a `ClapAudioEncoder` and forwards inputs to it.
      config: ClapAudioConfig
      main_input_name = "input_features"

      def __init__(self, config: ClapAudioConfig):
          super().__init__(config)
          self.audio_encoder = ClapAudioEncoder(config)
          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self) -> nn.Module:
          patch_embedding = self.audio_encoder.patch_embed
          return patch_embedding.proj

      @auto_docstring
      def forward(
          self,
          input_features: Optional[torch.FloatTensor] = None,
          is_longer: Optional[torch.BoolTensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, BaseModelOutputWithPooling]:
          r"""
          is_longer (`torch.FloatTensor`, of shape `(batch_size, 1)`, *optional*):
              Whether the audio clip is longer than `max_length`. If `True`, a feature fusion will be enabled to enhance
              the features.

          Examples:

          ```python
          >>> from datasets import load_dataset
          >>> from transformers import AutoProcessor, ClapAudioModel

          >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
          >>> audio_sample = dataset["train"]["audio"][0]["array"]

          >>> model = ClapAudioModel.from_pretrained("laion/clap-htsat-fused")
          >>> processor = AutoProcessor.from_pretrained("laion/clap-htsat-fused")

          >>> inputs = processor(audios=audio_sample, return_tensors="pt")

          >>> outputs = model(**inputs)
          >>> last_hidden_state = outputs.last_hidden_state
          ```"""
          # Any flag left as `None` falls back to the value stored on the config.
          if return_dict is None:
              return_dict = self.config.use_return_dict
          if output_attentions is None:
              output_attentions = self.config.output_attentions
          if output_hidden_states is None:
              output_hidden_states = self.config.output_hidden_states

          encoder_arguments = {
              "input_features": input_features,
              "is_longer": is_longer,
              "output_attentions": output_attentions,
              "output_hidden_states": output_hidden_states,
              "return_dict": return_dict,
          }
          return self.audio_encoder(**encoder_arguments)
  @auto_docstring(
      custom_intro="""
      The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
      cross-attention is added between the self-attention layers, following the architecture described in *Attention is
      all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
      Kaiser and Illia Polosukhin.

      To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
      to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
      `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.

      .. _*Attention is all you need*: https://huggingface.co/papers/1706.03762
      """
  )
  class ClapTextModel(ClapPreTrainedModel):
      config: ClapTextConfig

      def __init__(self, config, add_pooling_layer=True):
          r"""
          add_pooling_layer (bool, *optional*, defaults to `True`):
              Whether to add a pooling layer
          """
          super().__init__(config)
          self.config = config

          self.embeddings = ClapTextEmbeddings(config)
          self.encoder = ClapTextEncoder(config)

          # Without a pooler, `pooler_output` of the forward pass is `None`.
          self.pooler = ClapTextPooler(config) if add_pooling_layer else None

          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self):
          return self.embeddings.word_embeddings

      def set_input_embeddings(self, value):
          self.embeddings.word_embeddings = value

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPooling]:
          # NOTE: the return annotation names `BaseModelOutputWithPooling` because that is the class this
          # method actually instantiates. `return_dict` is kept in the signature for interface compatibility
          # but is not read here: the encoder is always asked for a dict-style output, and any tuple
          # conversion is presumably performed by the `can_return_tuple` decorator (defined outside this view).
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )

          # Exactly one of `input_ids` / `inputs_embeds` must be supplied.
          if input_ids is not None and inputs_embeds is not None:
              raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
          elif input_ids is not None:
              self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
              input_shape = input_ids.size()
          elif inputs_embeds is not None:
              # Drop the trailing embedding dimension to recover (batch, sequence).
              input_shape = inputs_embeds.size()[:-1]
          else:
              raise ValueError("You have to specify either input_ids or inputs_embeds")

          batch_size, seq_length = input_shape
          device = input_ids.device if input_ids is not None else inputs_embeds.device

          if attention_mask is None:
              # Default: attend to every position.
              attention_mask = torch.ones(((batch_size, seq_length)), device=device)

          if token_type_ids is None:
              if hasattr(self.embeddings, "token_type_ids"):
                  # Reuse the buffer registered on the embeddings module, cropped and expanded to the batch.
                  buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                  buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                  token_type_ids = buffered_token_type_ids_expanded
              else:
                  token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

          # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
          # ourselves in which case we just need to make it broadcastable to all heads.
          extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

          # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
          head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

          embedding_output = self.embeddings(
              input_ids=input_ids,
              position_ids=position_ids,
              token_type_ids=token_type_ids,
              inputs_embeds=inputs_embeds,
          )
          encoder_outputs = self.encoder(
              embedding_output,
              attention_mask=extended_attention_mask,
              head_mask=head_mask,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=True,
          )
          sequence_output = encoder_outputs[0]
          pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

          return BaseModelOutputWithPooling(
              last_hidden_state=sequence_output,
              pooler_output=pooled_output,
              hidden_states=encoder_outputs.hidden_states,
              attentions=encoder_outputs.attentions,
          )
  @auto_docstring
  class ClapModel(ClapPreTrainedModel):
      config: ClapConfig

      def __init__(self, config: ClapConfig):
          super().__init__(config)

          if not isinstance(config.text_config, ClapTextConfig):
              raise TypeError(
                  "config.text_config is expected to be of type ClapTextConfig but is of type"
                  f" {type(config.text_config)}."
              )

          if not isinstance(config.audio_config, ClapAudioConfig):
              raise TypeError(
                  "config.audio_config is expected to be of type ClapAudioConfig but is of type"
                  f" {type(config.audio_config)}."
              )

          text_config = config.text_config
          audio_config = config.audio_config

          # Two independent learnable temperatures, stored in log space (see `.exp()` in `forward`).
          self.logit_scale_a = nn.Parameter(torch.tensor(math.log(config.logit_scale_init_value)))
          self.logit_scale_t = nn.Parameter(torch.tensor(math.log(config.logit_scale_init_value)))

          self.projection_dim = config.projection_dim

          self.text_model = ClapTextModel(text_config)
          self.text_projection = ClapProjectionLayer(text_config)

          self.audio_model = ClapAudioModel(audio_config)
          self.audio_projection = ClapProjectionLayer(audio_config)

          # Initialize weights and apply final processing
          self.post_init()

      @filter_out_non_signature_kwargs()
      @auto_docstring
      def get_text_features(
          self,
          input_ids: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
      ) -> torch.FloatTensor:
          r"""
          Returns:
              text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
              applying the projection layer to the pooled output of [`ClapTextModel`].

          Examples:

          ```python
          >>> import torch
          >>> from transformers import AutoTokenizer, ClapModel

          >>> model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
          >>> tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")

          >>> inputs = tokenizer(["the sound of a cat", "the sound of a dog"], padding=True, return_tensors="pt")
          >>> with torch.inference_mode():
          ...     text_features = model.get_text_features(**inputs)
          ```"""
          text_outputs: BaseModelOutputWithPooling = self.text_model(
              input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids
          )
          text_features = self.text_projection(text_outputs.pooler_output)
          # Normalize along the last dimension.
          text_features = F.normalize(text_features, dim=-1)

          return text_features

      @filter_out_non_signature_kwargs()
      @auto_docstring
      def get_audio_features(
          self,
          input_features: torch.Tensor,
          is_longer: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
      ) -> torch.FloatTensor:
          r"""
          is_longer (`torch.FloatTensor`, of shape `(batch_size, 1)`, *optional*):
              Whether the audio clip is longer than `max_length`. If `True`, a feature fusion will be enabled to enhance
              the features.

          Returns:
              audio_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The audio embeddings obtained by
              applying the projection layer to the pooled output of [`ClapAudioModel`].

          Examples:

          ```python
          >>> import torch
          >>> from transformers import AutoFeatureExtractor, ClapModel

          >>> model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
          >>> feature_extractor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused")
          >>> random_audio = torch.rand((16_000))

          >>> inputs = feature_extractor(random_audio, return_tensors="pt")
          >>> with torch.inference_mode():
          ...     audio_features = model.get_audio_features(**inputs)
          ```"""
          # `attention_mask` is accepted by the signature but is not forwarded to the audio model.
          audio_outputs: BaseModelOutputWithPooling = self.audio_model(
              input_features=input_features, is_longer=is_longer
          )
          audio_features = self.audio_projection(audio_outputs.pooler_output)
          # Normalize along the last dimension.
          audio_features = F.normalize(audio_features, dim=-1)

          return audio_features

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          input_features: Optional[torch.FloatTensor] = None,
          is_longer: Optional[torch.BoolTensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          return_loss: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, ClapOutput]:
          r"""
          is_longer (`torch.FloatTensor`, of shape `(batch_size, 1)`, *optional*):
              Whether the audio clip is longer than `max_length`. If `True`, a feature fusion will be enabled to enhance
              the features.
          return_loss (`bool`, *optional*):
              Whether or not to return the contrastive loss.

          Examples:

          ```python
          >>> from datasets import load_dataset
          >>> from transformers import AutoProcessor, ClapModel

          >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
          >>> audio_sample = dataset["train"]["audio"][0]["array"]

          >>> model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
          >>> processor = AutoProcessor.from_pretrained("laion/clap-htsat-unfused")

          >>> input_text = ["Sound of a dog", "Sound of vacuum cleaner"]

          >>> inputs = processor(text=input_text, audios=audio_sample, return_tensors="pt", padding=True)

          >>> outputs = model(**inputs)
          >>> logits_per_audio = outputs.logits_per_audio  # this is the audio-text similarity score
          >>> probs = logits_per_audio.softmax(dim=-1)  # we can take the softmax to get the label probabilities
          ```"""
          # Use CLAP model's config for some fields (if specified) instead of those of audio & text components.
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )

          # Both sub-models are always asked for dict-style outputs, so `pooler_output` can be read by name
          # below regardless of the caller's `return_dict`.
          audio_outputs = self.audio_model(
              input_features=input_features,
              is_longer=is_longer,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=True,
          )

          text_outputs = self.text_model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=True,
          )

          audio_embeds = self.audio_projection(audio_outputs.pooler_output)
          text_embeds = self.text_projection(text_outputs.pooler_output)

          # normalized features
          audio_embeds = audio_embeds / audio_embeds.norm(p=2, dim=-1, keepdim=True)
          text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

          # cosine similarity as logits
          logit_scale_text = self.logit_scale_t.exp()
          logit_scale_audio = self.logit_scale_a.exp()
          # Rows index texts / columns index audios, and vice versa.
          logits_per_text = torch.matmul(text_embeds, audio_embeds.t()) * logit_scale_text
          logits_per_audio = torch.matmul(audio_embeds, text_embeds.t()) * logit_scale_audio

          loss = None
          if return_loss:
              # Symmetric contrastive loss: one term per retrieval direction.
              # `logits_per_audio` already has one row per audio clip, so it is passed as-is. Transposing it
              # would produce a second text->audio term and leave the audio->text direction out of the loss.
              caption_loss = contrastive_loss(logits_per_text)
              audio_loss = contrastive_loss(logits_per_audio)
              loss = (caption_loss + audio_loss) / 2.0

          return ClapOutput(
              loss=loss,
              logits_per_audio=logits_per_audio,
              logits_per_text=logits_per_text,
              text_embeds=text_embeds,
              audio_embeds=audio_embeds,
              text_model_output=text_outputs,
              audio_model_output=audio_outputs,
          )
  @auto_docstring
  class ClapTextModelWithProjection(ClapPreTrainedModel):
      config: ClapTextConfig

      def __init__(self, config: ClapTextConfig):
          super().__init__(config)
          self.text_model = ClapTextModel(config)
          self.text_projection = ClapProjectionLayer(config)
          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self) -> nn.Module:
          text_embeddings = self.text_model.embeddings
          return text_embeddings.word_embeddings

      def set_input_embeddings(self, value):
          text_embeddings = self.text_model.embeddings
          text_embeddings.word_embeddings = value

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, ClapTextModelOutput]:
          r"""
          Examples:

          ```python
          >>> from transformers import AutoTokenizer, ClapTextModelWithProjection

          >>> model = ClapTextModelWithProjection.from_pretrained("laion/clap-htsat-unfused")
          >>> tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")

          >>> inputs = tokenizer(["a sound of a cat", "a sound of a dog"], padding=True, return_tensors="pt")

          >>> outputs = model(**inputs)
          >>> text_embeds = outputs.text_embeds
          ```"""
          if return_dict is None:
              return_dict = self.config.use_return_dict

          # The inner text model is always asked for a dict-style output.
          encoded_text = self.text_model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=True,
          )

          if return_dict:
              pooled = encoded_text.pooler_output
          else:
              pooled = encoded_text[1]
          projected = self.text_projection(pooled)

          return ClapTextModelOutput(
              text_embeds=projected,
              last_hidden_state=encoded_text.last_hidden_state,
              hidden_states=encoded_text.hidden_states,
              attentions=encoded_text.attentions,
          )
  @auto_docstring
  class ClapAudioModelWithProjection(ClapPreTrainedModel):
      config: ClapAudioConfig
      main_input_name = "input_features"

      def __init__(self, config: ClapAudioConfig):
          super().__init__(config)
          self.audio_model = ClapAudioModel(config)
          self.audio_projection = ClapProjectionLayer(config)
          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self) -> nn.Module:
          patch_embedding = self.audio_model.audio_encoder.patch_embed
          return patch_embedding.proj

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_features: Optional[torch.FloatTensor] = None,
          is_longer: Optional[torch.BoolTensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, ClapAudioModelOutput]:
          r"""
          is_longer (`torch.FloatTensor`, of shape `(batch_size, 1)`, *optional*):
              Whether the audio clip is longer than `max_length`. If `True`, a feature fusion will be enabled to enhance
              the features.

          Examples:

          ```python
          >>> from datasets import load_dataset
          >>> from transformers import ClapAudioModelWithProjection, ClapProcessor

          >>> model = ClapAudioModelWithProjection.from_pretrained("laion/clap-htsat-fused")
          >>> processor = ClapProcessor.from_pretrained("laion/clap-htsat-fused")

          >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
          >>> audio_sample = dataset["train"]["audio"][0]["array"]

          >>> inputs = processor(audios=audio_sample, return_tensors="pt")
          >>> outputs = model(**inputs)
          >>> audio_embeds = outputs.audio_embeds
          ```"""
          # Any flag left as `None` falls back to the value stored on the config.
          if return_dict is None:
              return_dict = self.config.use_return_dict
          if output_attentions is None:
              output_attentions = self.config.output_attentions
          if output_hidden_states is None:
              output_hidden_states = self.config.output_hidden_states

          # The inner audio model is always asked for a dict-style output.
          encoded_audio = self.audio_model(
              input_features=input_features,
              is_longer=is_longer,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=True,
          )

          if return_dict:
              pooled = encoded_audio.pooler_output
          else:
              pooled = encoded_audio[1]
          projected = self.audio_projection(pooled)

          return ClapAudioModelOutput(
              audio_embeds=projected,
              last_hidden_state=encoded_audio.last_hidden_state,
              attentions=encoded_audio.attentions,
              hidden_states=encoded_audio.hidden_states,
          )
  # Names exported by `from <module> import *`.
  __all__ = [
      "ClapModel",
      "ClapPreTrainedModel",
      "ClapTextModel",
      "ClapTextModelWithProjection",
      "ClapAudioModel",
      "ClapAudioModelWithProjection",
  ]