modeling_align.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292
  1. # coding=utf-8
  2. # Copyright 2023 The Google Research Team Authors and The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch ALIGN model."""
  16. import math
  17. from dataclasses import dataclass
  18. from typing import Any, Callable, Optional, Union
  19. import torch
  20. from torch import nn
  21. from ...activations import ACT2FN
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import (
  24. BaseModelOutput,
  25. BaseModelOutputWithNoAttention,
  26. BaseModelOutputWithPooling,
  27. BaseModelOutputWithPoolingAndNoAttention,
  28. )
  29. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  30. from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
  31. from ...utils import ModelOutput, auto_docstring, can_return_tuple, filter_out_non_signature_kwargs, logging
  32. from .configuration_align import AlignConfig, AlignTextConfig, AlignVisionConfig
  33. logger = logging.get_logger(__name__)
  34. @dataclass
  35. @auto_docstring(
  36. custom_intro="""
  37. Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
  38. """
  39. )
  40. class AlignVisionModelOutput(ModelOutput):
  41. r"""
  42. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
  43. The image embeddings obtained by applying the projection layer to the pooler_output.
  44. """
  45. image_embeds: Optional[torch.FloatTensor] = None
  46. last_hidden_state: Optional[torch.FloatTensor] = None
  47. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  48. @dataclass
  49. @auto_docstring(
  50. custom_intro="""
  51. Base class for text model's outputs that also contains a pooling of the last hidden states.
  52. """
  53. )
  54. class AlignTextModelOutput(ModelOutput):
  55. r"""
  56. text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
  57. The text embeddings obtained by applying the projection layer to the pooler_output.
  58. """
  59. text_embeds: Optional[torch.FloatTensor] = None
  60. last_hidden_state: Optional[torch.FloatTensor] = None
  61. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  62. attentions: Optional[tuple[torch.FloatTensor]] = None
  63. @dataclass
  64. @auto_docstring
  65. class AlignOutput(ModelOutput):
  66. r"""
  67. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
  68. Contrastive loss for image-text similarity.
  69. logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
  70. The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
  71. similarity scores.
  72. logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
  73. The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
  74. similarity scores.
  75. text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  76. The text embeddings obtained by applying the projection layer to the pooled output of [`AlignTextModel`].
  77. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  78. The output of [`AlignVisionModel`].
  79. text_model_output (`BaseModelOutputWithPooling`):
  80. The output of the [`AlignTextModel`].
  81. vision_model_output (`BaseModelOutputWithPoolingAndNoAttention`):
  82. The output of the [`AlignVisionModel`].
  83. """
  84. loss: Optional[torch.FloatTensor] = None
  85. logits_per_image: Optional[torch.FloatTensor] = None
  86. logits_per_text: Optional[torch.FloatTensor] = None
  87. text_embeds: Optional[torch.FloatTensor] = None
  88. image_embeds: Optional[torch.FloatTensor] = None
  89. text_model_output: BaseModelOutputWithPooling = None
  90. vision_model_output: BaseModelOutputWithPoolingAndNoAttention = None
  91. def to_tuple(self) -> tuple[Any]:
  92. return tuple(
  93. self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
  94. for k in self.keys()
  95. )
  96. # contrastive loss function, adapted from
  97. # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
  98. def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
  99. return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device), label_smoothing=0.1)
  100. def align_loss(similarity: torch.Tensor) -> torch.Tensor:
  101. caption_loss = contrastive_loss(similarity)
  102. image_loss = contrastive_loss(similarity.t())
  103. return (caption_loss + image_loss) / 2.0
  104. # Copied from transformers.models.efficientnet.modeling_efficientnet.round_filters with EfficientNet->AlignVision
  105. def round_filters(config: AlignVisionConfig, num_channels: int):
  106. r"""
  107. Round number of filters based on depth multiplier.
  108. """
  109. divisor = config.depth_divisor
  110. num_channels *= config.width_coefficient
  111. new_dim = max(divisor, int(num_channels + divisor / 2) // divisor * divisor)
  112. # Make sure that round down does not go down by more than 10%.
  113. if new_dim < 0.9 * num_channels:
  114. new_dim += divisor
  115. return int(new_dim)
  116. # Copied from transformers.models.efficientnet.modeling_efficientnet.correct_pad
  117. def correct_pad(kernel_size: Union[int, tuple], adjust: bool = True):
  118. r"""
  119. Utility function to get the tuple padding value for the depthwise convolution.
  120. Args:
  121. kernel_size (`int` or `tuple`):
  122. Kernel size of the convolution layers.
  123. adjust (`bool`, *optional*, defaults to `True`):
  124. Adjusts padding value to apply to right and bottom sides of the input.
  125. """
  126. if isinstance(kernel_size, int):
  127. kernel_size = (kernel_size, kernel_size)
  128. correct = (kernel_size[0] // 2, kernel_size[1] // 2)
  129. if adjust:
  130. return (correct[1] - 1, correct[1], correct[0] - 1, correct[0])
  131. else:
  132. return (correct[1], correct[1], correct[0], correct[0])
  133. # Copied from transformers.models.efficientnet.modeling_efficientnet.EfficientNetEmbeddings with EfficientNet->AlignVision
  134. class AlignVisionEmbeddings(nn.Module):
  135. r"""
  136. A module that corresponds to the stem module of the original work.
  137. """
  138. def __init__(self, config: AlignVisionConfig):
  139. super().__init__()
  140. self.out_dim = round_filters(config, 32)
  141. self.padding = nn.ZeroPad2d(padding=(0, 1, 0, 1))
  142. self.convolution = nn.Conv2d(
  143. config.num_channels, self.out_dim, kernel_size=3, stride=2, padding="valid", bias=False
  144. )
  145. self.batchnorm = nn.BatchNorm2d(self.out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum)
  146. self.activation = ACT2FN[config.hidden_act]
  147. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  148. features = self.padding(pixel_values)
  149. features = self.convolution(features)
  150. features = self.batchnorm(features)
  151. features = self.activation(features)
  152. return features
  153. # Copied from transformers.models.efficientnet.modeling_efficientnet.EfficientNetDepthwiseConv2d with EfficientNet->AlignVision
  154. class AlignVisionDepthwiseConv2d(nn.Conv2d):
  155. def __init__(
  156. self,
  157. in_channels,
  158. depth_multiplier=1,
  159. kernel_size=3,
  160. stride=1,
  161. padding=0,
  162. dilation=1,
  163. bias=True,
  164. padding_mode="zeros",
  165. ):
  166. out_channels = in_channels * depth_multiplier
  167. super().__init__(
  168. in_channels=in_channels,
  169. out_channels=out_channels,
  170. kernel_size=kernel_size,
  171. stride=stride,
  172. padding=padding,
  173. dilation=dilation,
  174. groups=in_channels,
  175. bias=bias,
  176. padding_mode=padding_mode,
  177. )
  178. # Copied from transformers.models.efficientnet.modeling_efficientnet.EfficientNetExpansionLayer with EfficientNet->AlignVision
  179. class AlignVisionExpansionLayer(nn.Module):
  180. r"""
  181. This corresponds to the expansion phase of each block in the original implementation.
  182. """
  183. def __init__(self, config: AlignVisionConfig, in_dim: int, out_dim: int, stride: int):
  184. super().__init__()
  185. self.expand_conv = nn.Conv2d(
  186. in_channels=in_dim,
  187. out_channels=out_dim,
  188. kernel_size=1,
  189. padding="same",
  190. bias=False,
  191. )
  192. self.expand_bn = nn.BatchNorm2d(num_features=out_dim, eps=config.batch_norm_eps)
  193. self.expand_act = ACT2FN[config.hidden_act]
  194. def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
  195. # Expand phase
  196. hidden_states = self.expand_conv(hidden_states)
  197. hidden_states = self.expand_bn(hidden_states)
  198. hidden_states = self.expand_act(hidden_states)
  199. return hidden_states
  200. # Copied from transformers.models.efficientnet.modeling_efficientnet.EfficientNetDepthwiseLayer with EfficientNet->AlignVision
  201. class AlignVisionDepthwiseLayer(nn.Module):
  202. r"""
  203. This corresponds to the depthwise convolution phase of each block in the original implementation.
  204. """
  205. def __init__(
  206. self,
  207. config: AlignVisionConfig,
  208. in_dim: int,
  209. stride: int,
  210. kernel_size: int,
  211. adjust_padding: bool,
  212. ):
  213. super().__init__()
  214. self.stride = stride
  215. conv_pad = "valid" if self.stride == 2 else "same"
  216. padding = correct_pad(kernel_size, adjust=adjust_padding)
  217. self.depthwise_conv_pad = nn.ZeroPad2d(padding=padding)
  218. self.depthwise_conv = AlignVisionDepthwiseConv2d(
  219. in_dim, kernel_size=kernel_size, stride=stride, padding=conv_pad, bias=False
  220. )
  221. self.depthwise_norm = nn.BatchNorm2d(
  222. num_features=in_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
  223. )
  224. self.depthwise_act = ACT2FN[config.hidden_act]
  225. def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
  226. # Depthwise convolution
  227. if self.stride == 2:
  228. hidden_states = self.depthwise_conv_pad(hidden_states)
  229. hidden_states = self.depthwise_conv(hidden_states)
  230. hidden_states = self.depthwise_norm(hidden_states)
  231. hidden_states = self.depthwise_act(hidden_states)
  232. return hidden_states
  233. # Copied from transformers.models.efficientnet.modeling_efficientnet.EfficientNetSqueezeExciteLayer with EfficientNet->AlignVision
  234. class AlignVisionSqueezeExciteLayer(nn.Module):
  235. r"""
  236. This corresponds to the Squeeze and Excitement phase of each block in the original implementation.
  237. """
  238. def __init__(self, config: AlignVisionConfig, in_dim: int, expand_dim: int, expand: bool = False):
  239. super().__init__()
  240. self.dim = expand_dim if expand else in_dim
  241. self.dim_se = max(1, int(in_dim * config.squeeze_expansion_ratio))
  242. self.squeeze = nn.AdaptiveAvgPool2d(output_size=1)
  243. self.reduce = nn.Conv2d(
  244. in_channels=self.dim,
  245. out_channels=self.dim_se,
  246. kernel_size=1,
  247. padding="same",
  248. )
  249. self.expand = nn.Conv2d(
  250. in_channels=self.dim_se,
  251. out_channels=self.dim,
  252. kernel_size=1,
  253. padding="same",
  254. )
  255. self.act_reduce = ACT2FN[config.hidden_act]
  256. self.act_expand = nn.Sigmoid()
  257. def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
  258. inputs = hidden_states
  259. hidden_states = self.squeeze(hidden_states)
  260. hidden_states = self.reduce(hidden_states)
  261. hidden_states = self.act_reduce(hidden_states)
  262. hidden_states = self.expand(hidden_states)
  263. hidden_states = self.act_expand(hidden_states)
  264. hidden_states = torch.mul(inputs, hidden_states)
  265. return hidden_states
  266. class AlignVisionFinalBlockLayer(nn.Module):
  267. r"""
  268. This corresponds to the final phase of each block in the original implementation.
  269. """
  270. def __init__(
  271. self, config: AlignVisionConfig, in_dim: int, out_dim: int, stride: int, drop_rate: float, id_skip: bool
  272. ):
  273. super().__init__()
  274. self.apply_dropout = stride == 1 and not id_skip
  275. self.project_conv = nn.Conv2d(
  276. in_channels=in_dim,
  277. out_channels=out_dim,
  278. kernel_size=1,
  279. padding="same",
  280. bias=False,
  281. )
  282. self.project_bn = nn.BatchNorm2d(
  283. num_features=out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
  284. )
  285. self.dropout = nn.Dropout(p=drop_rate)
  286. def forward(self, embeddings: torch.FloatTensor, hidden_states: torch.FloatTensor) -> torch.Tensor:
  287. hidden_states = self.project_conv(hidden_states)
  288. hidden_states = self.project_bn(hidden_states)
  289. if self.apply_dropout:
  290. hidden_states = self.dropout(hidden_states)
  291. hidden_states = hidden_states + embeddings
  292. return hidden_states
  293. class AlignVisionBlock(nn.Module):
  294. r"""
  295. This corresponds to the block module of original the EfficientNet vision encoder implementation.
  296. Args:
  297. config ([`AlignVisionConfig`]):
  298. Model configuration class.
  299. in_dim (`int`):
  300. Number of input channels.
  301. out_dim (`int`):
  302. Number of output channels.
  303. stride (`int`):
  304. Stride size to be used in convolution layers.
  305. expand_ratio (`int`):
  306. Expand ratio to set the output dimensions for the expansion and squeeze-excite layers.
  307. kernel_size (`int`):
  308. Kernel size for the depthwise convolution layer.
  309. drop_rate (`float`):
  310. Dropout rate to be used in the final phase of each block.
  311. id_skip (`bool`):
  312. Whether to apply dropout and sum the final hidden states with the input embeddings during the final phase
  313. of each block. Set to `True` for the first block of each stage.
  314. adjust_padding (`bool`):
  315. Whether to apply padding to only right and bottom side of the input kernel before the depthwise convolution
  316. operation, set to `True` for inputs with odd input sizes.
  317. """
  318. def __init__(
  319. self,
  320. config: AlignVisionConfig,
  321. in_dim: int,
  322. out_dim: int,
  323. stride: int,
  324. expand_ratio: int,
  325. kernel_size: int,
  326. drop_rate: float,
  327. id_skip: bool,
  328. adjust_padding: bool,
  329. ):
  330. super().__init__()
  331. self.expand_ratio = expand_ratio
  332. self.expand = self.expand_ratio != 1
  333. expand_in_dim = in_dim * expand_ratio
  334. if self.expand:
  335. self.expansion = AlignVisionExpansionLayer(
  336. config=config, in_dim=in_dim, out_dim=expand_in_dim, stride=stride
  337. )
  338. self.depthwise_conv = AlignVisionDepthwiseLayer(
  339. config=config,
  340. in_dim=expand_in_dim if self.expand else in_dim,
  341. stride=stride,
  342. kernel_size=kernel_size,
  343. adjust_padding=adjust_padding,
  344. )
  345. self.squeeze_excite = AlignVisionSqueezeExciteLayer(
  346. config=config, in_dim=in_dim, expand_dim=expand_in_dim, expand=self.expand
  347. )
  348. self.projection = AlignVisionFinalBlockLayer(
  349. config=config,
  350. in_dim=expand_in_dim if self.expand else in_dim,
  351. out_dim=out_dim,
  352. stride=stride,
  353. drop_rate=drop_rate,
  354. id_skip=id_skip,
  355. )
  356. def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
  357. embeddings = hidden_states
  358. # Expansion and depthwise convolution phase
  359. if self.expand_ratio != 1:
  360. hidden_states = self.expansion(hidden_states)
  361. hidden_states = self.depthwise_conv(hidden_states)
  362. # Squeeze and excite phase
  363. hidden_states = self.squeeze_excite(hidden_states)
  364. hidden_states = self.projection(embeddings, hidden_states)
  365. return hidden_states
  366. class AlignVisionEncoder(nn.Module):
  367. r"""
  368. Forward propagates the embeddings through each vision encoder (EfficientNet) block.
  369. Args:
  370. config ([`AlignVisionConfig`]):
  371. Model configuration class.
  372. """
  373. def __init__(self, config: AlignVisionConfig):
  374. super().__init__()
  375. self.depth_coefficient = config.depth_coefficient
  376. def round_repeats(repeats):
  377. # Round number of block repeats based on depth multiplier.
  378. return int(math.ceil(self.depth_coefficient * repeats))
  379. num_base_blocks = len(config.in_channels)
  380. num_blocks = sum(round_repeats(n) for n in config.num_block_repeats)
  381. curr_block_num = 0
  382. blocks = []
  383. for i in range(num_base_blocks):
  384. in_dim = round_filters(config, config.in_channels[i])
  385. out_dim = round_filters(config, config.out_channels[i])
  386. stride = config.strides[i]
  387. kernel_size = config.kernel_sizes[i]
  388. expand_ratio = config.expand_ratios[i]
  389. for j in range(round_repeats(config.num_block_repeats[i])):
  390. id_skip = j == 0
  391. stride = 1 if j > 0 else stride
  392. in_dim = out_dim if j > 0 else in_dim
  393. adjust_padding = curr_block_num not in config.depthwise_padding
  394. drop_rate = config.drop_connect_rate * curr_block_num / num_blocks
  395. block = AlignVisionBlock(
  396. config=config,
  397. in_dim=in_dim,
  398. out_dim=out_dim,
  399. stride=stride,
  400. kernel_size=kernel_size,
  401. expand_ratio=expand_ratio,
  402. drop_rate=drop_rate,
  403. id_skip=id_skip,
  404. adjust_padding=adjust_padding,
  405. )
  406. blocks.append(block)
  407. curr_block_num += 1
  408. self.blocks = nn.ModuleList(blocks)
  409. def forward(
  410. self,
  411. hidden_states: torch.FloatTensor,
  412. output_hidden_states: Optional[bool] = False,
  413. return_dict: Optional[bool] = True,
  414. ) -> BaseModelOutputWithPoolingAndNoAttention:
  415. all_hidden_states = (hidden_states,) if output_hidden_states else None
  416. for block in self.blocks:
  417. hidden_states = block(hidden_states)
  418. if output_hidden_states:
  419. all_hidden_states += (hidden_states,)
  420. if not return_dict:
  421. return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
  422. return BaseModelOutputWithNoAttention(
  423. last_hidden_state=hidden_states,
  424. hidden_states=all_hidden_states,
  425. )
  426. class AlignTextEmbeddings(nn.Module):
  427. """Construct the embeddings from word, position and token_type embeddings."""
  428. def __init__(self, config):
  429. super().__init__()
  430. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  431. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  432. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  433. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  434. # any TensorFlow checkpoint file
  435. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  436. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  437. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  438. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  439. self.register_buffer(
  440. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  441. )
  442. self.register_buffer(
  443. "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
  444. )
  445. def forward(
  446. self,
  447. input_ids: Optional[torch.LongTensor] = None,
  448. token_type_ids: Optional[torch.LongTensor] = None,
  449. position_ids: Optional[torch.LongTensor] = None,
  450. inputs_embeds: Optional[torch.FloatTensor] = None,
  451. ) -> torch.Tensor:
  452. if input_ids is not None:
  453. input_shape = input_ids.size()
  454. else:
  455. input_shape = inputs_embeds.size()[:-1]
  456. seq_length = input_shape[1]
  457. if position_ids is None:
  458. position_ids = self.position_ids[:, :seq_length]
  459. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  460. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  461. # issue #5664
  462. if token_type_ids is None:
  463. if hasattr(self, "token_type_ids"):
  464. buffered_token_type_ids = self.token_type_ids[:, :seq_length]
  465. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
  466. token_type_ids = buffered_token_type_ids_expanded
  467. else:
  468. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  469. if inputs_embeds is None:
  470. inputs_embeds = self.word_embeddings(input_ids)
  471. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  472. embeddings = inputs_embeds + token_type_embeddings
  473. if self.position_embedding_type == "absolute":
  474. position_embeddings = self.position_embeddings(position_ids)
  475. embeddings += position_embeddings
  476. embeddings = self.LayerNorm(embeddings)
  477. embeddings = self.dropout(embeddings)
  478. return embeddings
  479. def eager_attention_forward(
  480. module: nn.Module,
  481. query: torch.Tensor,
  482. key: torch.Tensor,
  483. value: torch.Tensor,
  484. attention_mask: Optional[torch.Tensor],
  485. scaling: float,
  486. dropout: float = 0.0,
  487. head_mask: Optional[torch.Tensor] = None,
  488. **kwargs,
  489. ):
  490. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  491. if attention_mask is not None:
  492. causal_mask = attention_mask[:, :, :, : key.shape[-2]]
  493. attn_weights = attn_weights + causal_mask
  494. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  495. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  496. if head_mask is not None:
  497. attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
  498. attn_output = torch.matmul(attn_weights, value)
  499. attn_output = attn_output.transpose(1, 2).contiguous()
  500. return attn_output, attn_weights
  501. class AlignTextSelfAttention(nn.Module):
  502. def __init__(self, config):
  503. super().__init__()
  504. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  505. raise ValueError(
  506. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  507. f"heads ({config.num_attention_heads})"
  508. )
  509. self.config = config
  510. self.num_attention_heads = config.num_attention_heads
  511. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  512. self.all_head_size = self.num_attention_heads * self.attention_head_size
  513. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  514. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  515. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  516. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  517. self.attention_dropout = config.attention_probs_dropout_prob
  518. self.scaling = self.attention_head_size**-0.5
  519. def forward(
  520. self,
  521. hidden_states: torch.Tensor,
  522. attention_mask: Optional[torch.FloatTensor] = None,
  523. head_mask: Optional[torch.FloatTensor] = None,
  524. output_attentions: Optional[bool] = False,
  525. **kwargs,
  526. ) -> tuple[torch.Tensor]:
  527. input_shape = hidden_states.shape[:-1]
  528. hidden_shape = (*input_shape, -1, self.attention_head_size)
  529. query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  530. key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  531. value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  532. attention_interface: Callable = eager_attention_forward
  533. if self.config._attn_implementation != "eager":
  534. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  535. attn_output, attn_weights = attention_interface(
  536. self,
  537. query_states,
  538. key_states,
  539. value_states,
  540. attention_mask,
  541. dropout=0.0 if not self.training else self.attention_dropout,
  542. scaling=self.scaling,
  543. head_mask=head_mask,
  544. **kwargs,
  545. )
  546. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  547. outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
  548. return outputs
  549. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->AlignText
  550. class AlignTextSelfOutput(nn.Module):
  551. def __init__(self, config):
  552. super().__init__()
  553. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  554. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  555. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  556. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  557. hidden_states = self.dense(hidden_states)
  558. hidden_states = self.dropout(hidden_states)
  559. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  560. return hidden_states
  561. class AlignTextAttention(nn.Module):
  562. def __init__(self, config):
  563. super().__init__()
  564. self.self = AlignTextSelfAttention(config)
  565. self.output = AlignTextSelfOutput(config)
  566. self.pruned_heads = set()
  567. def prune_heads(self, heads):
  568. if len(heads) == 0:
  569. return
  570. heads, index = find_pruneable_heads_and_indices(
  571. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  572. )
  573. # Prune linear layers
  574. self.self.query = prune_linear_layer(self.self.query, index)
  575. self.self.key = prune_linear_layer(self.self.key, index)
  576. self.self.value = prune_linear_layer(self.self.value, index)
  577. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  578. # Update hyper params and store pruned heads
  579. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  580. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  581. self.pruned_heads = self.pruned_heads.union(heads)
  582. def forward(
  583. self,
  584. hidden_states: torch.Tensor,
  585. attention_mask: Optional[torch.FloatTensor] = None,
  586. head_mask: Optional[torch.FloatTensor] = None,
  587. output_attentions: Optional[bool] = False,
  588. **kwargs,
  589. ) -> tuple[torch.Tensor]:
  590. self_outputs = self.self(
  591. hidden_states,
  592. attention_mask=attention_mask,
  593. head_mask=head_mask,
  594. output_attentions=output_attentions,
  595. **kwargs,
  596. )
  597. attention_output = self.output(self_outputs[0], hidden_states)
  598. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  599. return outputs
  # Adapted from transformers.models.bert.modeling_bert.BertIntermediate with Bert->AlignText
  class AlignTextIntermediate(nn.Module):
      """Feed-forward expansion: ``hidden_size -> intermediate_size`` followed by the activation."""

      def __init__(self, config):
          super().__init__()
          self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
          act = config.hidden_act
          # `hidden_act` is either a key into ACT2FN or an already-callable activation.
          self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          return self.intermediate_act_fn(self.dense(hidden_states))
  # Adapted from transformers.models.bert.modeling_bert.BertOutput with Bert->AlignText
  class AlignTextOutput(nn.Module):
      """Feed-forward contraction (``intermediate_size -> hidden_size``) with dropout, residual add and LayerNorm."""

      def __init__(self, config):
          super().__init__()
          self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
          self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)

      def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
          projected = self.dropout(self.dense(hidden_states))
          # Residual connection with the sub-layer input, then normalization.
          residual_sum = projected + input_tensor
          return self.LayerNorm(residual_sum)
  class AlignTextLayer(GradientCheckpointingLayer):
      """One block of the ALIGN text encoder: attention sub-layer followed by a feed-forward network."""

      def __init__(self, config):
          super().__init__()
          self.chunk_size_feed_forward = config.chunk_size_feed_forward
          # Passed to `apply_chunking_to_forward` as the dimension to chunk the feed-forward input along.
          self.seq_len_dim = 1
          self.attention = AlignTextAttention(config)
          self.intermediate = AlignTextIntermediate(config)
          self.output = AlignTextOutput(config)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.FloatTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          output_attentions: Optional[bool] = False,
          **kwargs,
      ) -> tuple[torch.Tensor]:
          """Run attention, then the (optionally chunked) feed-forward network.

          Returns:
              ``(layer_output, *extras)`` where ``extras`` are the items after the first one returned by
              ``self.attention`` (attention weights, when requested).
          """
          attention_outputs = self.attention(
              hidden_states,
              attention_mask=attention_mask,
              head_mask=head_mask,
              output_attentions=output_attentions,
              **kwargs,
          )
          attention_output, extra_outputs = attention_outputs[0], attention_outputs[1:]

          layer_output = apply_chunking_to_forward(
              self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
          )
          return (layer_output,) + extra_outputs

      def feed_forward_chunk(self, attention_output):
          """Feed-forward network applied to one chunk of the attention output (with residual)."""
          return self.output(self.intermediate(attention_output), attention_output)
  class AlignTextEncoder(nn.Module):
      """Stack of ``config.num_hidden_layers`` ``AlignTextLayer`` blocks."""

      def __init__(self, config):
          super().__init__()
          self.config = config
          self.layer = nn.ModuleList([AlignTextLayer(config) for _ in range(config.num_hidden_layers)])
          self.gradient_checkpointing = False

      @can_return_tuple
      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.FloatTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          output_attentions: Optional[bool] = False,
          output_hidden_states: Optional[bool] = False,
          return_dict: Optional[bool] = True,
          **kwargs,
      ) -> Union[tuple[torch.Tensor], BaseModelOutput]:
          """Run every layer in order.

          ``head_mask`` (if given) is indexed per layer. When ``output_hidden_states`` is set, the
          collected hidden states are the input to each layer plus the final output. ``return_dict``
          is not read in this body.
          """
          hidden_history = [] if output_hidden_states else None
          attention_history = [] if output_attentions else None

          for layer_idx, layer_module in enumerate(self.layer):
              if hidden_history is not None:
                  hidden_history.append(hidden_states)

              layer_outputs = layer_module(
                  hidden_states=hidden_states,
                  attention_mask=attention_mask,
                  head_mask=None if head_mask is None else head_mask[layer_idx],
                  output_attentions=output_attentions,
                  **kwargs,
              )
              hidden_states = layer_outputs[0]

              if attention_history is not None:
                  attention_history.append(layer_outputs[1])

          if hidden_history is not None:
              hidden_history.append(hidden_states)

          return BaseModelOutput(
              last_hidden_state=hidden_states,
              hidden_states=None if hidden_history is None else tuple(hidden_history),
              attentions=None if attention_history is None else tuple(attention_history),
          )
  # Adapted from transformers.models.bert.modeling_bert.BertPooler with Bert -> AlignText
  class AlignTextPooler(nn.Module):
      """Pool a sequence into one vector: ``tanh(dense(hidden state of the first token))``."""

      def __init__(self, config):
          super().__init__()
          self.dense = nn.Linear(config.hidden_size, config.hidden_size)
          self.activation = nn.Tanh()

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          # Index 0 on dim 1 picks the first position of every sequence in the batch.
          first_token = hidden_states[:, 0]
          return self.activation(self.dense(first_token))
  @auto_docstring
  class AlignPreTrainedModel(PreTrainedModel):
      # Shared base class of AlignTextModel, AlignVisionModel and AlignModel.
      config: AlignConfig
      base_model_prefix = "align"
      supports_gradient_checkpointing = True

      def _init_weights(self, module: nn.Module):
          """Initialize the weights of ``module`` in place.

          - ``nn.Linear`` / ``nn.Conv2d``: weight ~ N(0, ``initializer_range``), bias (if any) zeroed.
          - ``AlignModel``: Xavier-uniform ``text_projection`` weight, zero projection bias, and
            ``temperature`` reset to ``config.temperature_init_value``.
          - ``nn.Embedding``: weight ~ N(0, ``initializer_range``), padding row (if any) zeroed.
          - ``nn.LayerNorm`` / ``nn.BatchNorm2d``: weight filled with 1, bias zeroed.
          """
          init_std = self.config.initializer_range

          if isinstance(module, (nn.Linear, nn.Conv2d)):
              module.weight.data.normal_(mean=0.0, std=init_std)
              bias = module.bias
              if bias is not None:
                  bias.data.zero_()
          elif isinstance(module, AlignModel):
              projection = module.text_projection
              nn.init.xavier_uniform_(projection.weight)
              projection.bias.data.zero_()
              module.temperature.data.fill_(self.config.temperature_init_value)
          elif isinstance(module, nn.Embedding):
              embedding_weight = module.weight.data
              embedding_weight.normal_(mean=0.0, std=init_std)
              if module.padding_idx is not None:
                  embedding_weight[module.padding_idx].zero_()

          # Normalization layers are checked independently of the chain above.
          if isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
              module.weight.data.fill_(1.0)
              module.bias.data.zero_()
  @auto_docstring(
      custom_intro="""
      The text model from ALIGN without any head or projection on top.
      """
  )
  class AlignTextModel(AlignPreTrainedModel):
      config: AlignTextConfig
      _no_split_modules = ["AlignTextEmbeddings"]

      def __init__(self, config: AlignTextConfig, add_pooling_layer: bool = True):
          r"""
          add_pooling_layer (bool, *optional*, defaults to `True`):
              Whether to add a pooling layer
          """
          super().__init__(config)
          self.config = config

          self.embeddings = AlignTextEmbeddings(config)
          self.encoder = AlignTextEncoder(config)
          # Without a pooler, `pooler_output` in the returned output is None.
          self.pooler = AlignTextPooler(config) if add_pooling_layer else None

          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self):
          return self.embeddings.word_embeddings

      def set_input_embeddings(self, value):
          self.embeddings.word_embeddings = value

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          **kwargs,
      ) -> Union[tuple, BaseModelOutputWithPooling]:
          r"""
          Examples:

          ```python
          >>> from transformers import AutoTokenizer, AlignTextModel

          >>> model = AlignTextModel.from_pretrained("kakaobrain/align-base")
          >>> tokenizer = AutoTokenizer.from_pretrained("kakaobrain/align-base")

          >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

          >>> outputs = model(**inputs)
          >>> last_hidden_state = outputs.last_hidden_state
          >>> pooled_output = outputs.pooler_output  # pooled first-token states
          ```"""
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )
          # NOTE: `return_dict` is not read in this body; the result is always built as a
          # BaseModelOutputWithPooling (tuple conversion is presumably handled by @can_return_tuple).

          if input_ids is not None and inputs_embeds is not None:
              raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
          elif input_ids is not None:
              self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
              input_shape = input_ids.size()
          elif inputs_embeds is not None:
              input_shape = inputs_embeds.size()[:-1]
          else:
              raise ValueError("You have to specify either input_ids or inputs_embeds")

          batch_size, seq_length = input_shape
          device = input_ids.device if input_ids is not None else inputs_embeds.device

          if attention_mask is None:
              # No mask given: attend to every position.
              attention_mask = torch.ones((batch_size, seq_length), device=device)

          if token_type_ids is None:
              if hasattr(self.embeddings, "token_type_ids"):
                  # Reuse the embeddings' `token_type_ids` tensor, truncated to seq_length and
                  # broadcast over the batch.
                  buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                  token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length)
              else:
                  token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

          # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
          # ourselves in which case we just need to make it broadcastable to all heads.
          extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

          # Prepare head mask if needed
          # 1.0 in head_mask indicate we keep the head
          # attention_probs has shape bsz x n_heads x N x N
          # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
          # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
          head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

          embedding_output = self.embeddings(
              input_ids=input_ids,
              position_ids=position_ids,
              token_type_ids=token_type_ids,
              inputs_embeds=inputs_embeds,
          )
          encoder_outputs = self.encoder(
              embedding_output,
              attention_mask=extended_attention_mask,
              head_mask=head_mask,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=True,
              **kwargs,
          )

          sequence_output = encoder_outputs[0]
          pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

          return BaseModelOutputWithPooling(
              last_hidden_state=sequence_output,
              pooler_output=pooled_output,
              hidden_states=encoder_outputs.hidden_states,
              attentions=encoder_outputs.attentions,
          )
  @auto_docstring(
      custom_intro="""
      The vision model from ALIGN without any head or projection on top.
      """
  )
  class AlignVisionModel(AlignPreTrainedModel):
      config: AlignVisionConfig
      main_input_name = "pixel_values"
      supports_gradient_checkpointing = False

      def __init__(self, config: AlignVisionConfig):
          super().__init__(config)
          self.config = config
          self.embeddings = AlignVisionEmbeddings(config)
          self.encoder = AlignVisionEncoder(config)

          # Final pooling layer over the encoder's last feature map, selected by `config.pooling_type`.
          if config.pooling_type == "mean":
              self.pooler = nn.AvgPool2d(config.hidden_dim, ceil_mode=True)
          elif config.pooling_type == "max":
              self.pooler = nn.MaxPool2d(config.hidden_dim, ceil_mode=True)
          else:
              # Report the field that was actually checked (the config has no `pooling` attribute
              # guaranteed, and reading one here would mask this ValueError with an AttributeError).
              raise ValueError(f"config.pooling_type must be one of ['mean', 'max'] got {config.pooling_type}")

          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self) -> nn.Module:
          """Return the `convolution` module of this model's embeddings."""
          # This class owns `embeddings` directly; it has no `vision_model` attribute.
          return self.embeddings.convolution

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          pixel_values: Optional[torch.FloatTensor] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
          r"""
          Examples:

          ```python
          >>> from PIL import Image
          >>> import requests
          >>> from transformers import AutoProcessor, AlignVisionModel

          >>> model = AlignVisionModel.from_pretrained("kakaobrain/align-base")
          >>> processor = AutoProcessor.from_pretrained("kakaobrain/align-base")

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> image = Image.open(requests.get(url, stream=True).raw)

          >>> inputs = processor(images=image, return_tensors="pt")

          >>> outputs = model(**inputs)
          >>> last_hidden_state = outputs.last_hidden_state
          >>> pooled_output = outputs.pooler_output  # spatially pooled (mean or max) feature map
          ```"""
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )
          # NOTE: `return_dict` is not read in this body; the result is always built as a
          # BaseModelOutputWithPoolingAndNoAttention (tuple conversion is presumably handled by @can_return_tuple).

          if pixel_values is None:
              raise ValueError("You have to specify pixel_values")

          embedding_output = self.embeddings(pixel_values)

          encoder_outputs = self.encoder(
              embedding_output,
              output_hidden_states=output_hidden_states,
              return_dict=True,
          )

          # Apply pooling
          last_hidden_state = encoder_outputs[0]
          pooled_output = self.pooler(last_hidden_state)
          # (batch_size, channels, 1, 1) -> (batch_size, channels). The reshape only succeeds when the
          # pooling window covered the whole spatial extent, leaving a 1x1 map.
          pooled_output = pooled_output.reshape(pooled_output.shape[:2])

          return BaseModelOutputWithPoolingAndNoAttention(
              last_hidden_state=last_hidden_state,
              pooler_output=pooled_output,
              hidden_states=encoder_outputs.hidden_states,
          )
  @auto_docstring
  class AlignModel(AlignPreTrainedModel):
      config: AlignConfig

      def __init__(self, config: AlignConfig):
          super().__init__(config)

          if not isinstance(config.text_config, AlignTextConfig):
              raise TypeError(
                  "config.text_config is expected to be of type AlignTextConfig but is of type"
                  f" {type(config.text_config)}."
              )

          if not isinstance(config.vision_config, AlignVisionConfig):
              raise TypeError(
                  "config.vision_config is expected to be of type AlignVisionConfig but is of type"
                  f" {type(config.vision_config)}."
              )

          text_config = config.text_config
          vision_config = config.vision_config

          self.projection_dim = config.projection_dim
          self.text_embed_dim = text_config.hidden_size

          self.text_model = AlignTextModel(text_config)
          self.vision_model = AlignVisionModel(vision_config)

          # Only the text side has a learned projection; image embeddings are the vision pooler output.
          self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim)
          # Learnable divisor applied to the cosine-similarity logits in `forward`.
          self.temperature = nn.Parameter(torch.tensor(self.config.temperature_init_value))

          # Initialize weights and apply final processing
          self.post_init()

      @filter_out_non_signature_kwargs()
      @auto_docstring
      def get_text_features(
          self,
          input_ids: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
      ) -> torch.FloatTensor:
          r"""
          Returns:
              text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
              applying the projection layer to the last hidden state of the first token of [`AlignTextModel`]
              (not to its `pooler_output`). The embeddings are not L2-normalized.

          Examples:

          ```python
          >>> import torch
          >>> from transformers import AutoTokenizer, AlignModel

          >>> model = AlignModel.from_pretrained("kakaobrain/align-base")
          >>> tokenizer = AutoTokenizer.from_pretrained("kakaobrain/align-base")

          >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

          >>> with torch.inference_mode():
          ...     text_features = model.get_text_features(**inputs)
          ```"""
          # Request a structured output explicitly (as `forward` does) so field access does not depend
          # on the text sub-config's return_dict default.
          text_outputs = self.text_model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              return_dict=True,
          )
          # Last-layer hidden state of the first token of every sequence.
          first_token_state = text_outputs.last_hidden_state[:, 0, :]
          text_features = self.text_projection(first_token_state)
          return text_features

      @filter_out_non_signature_kwargs()
      @auto_docstring
      def get_image_features(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
          r"""
          Returns:
              image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings, i.e. the
              pooled output of [`AlignVisionModel`]. No projection layer is applied on the image side and the
              embeddings are not L2-normalized.

          Examples:

          ```python
          >>> import torch
          >>> from transformers import AutoProcessor, AlignModel
          >>> from transformers.image_utils import load_image

          >>> model = AlignModel.from_pretrained("kakaobrain/align-base")
          >>> processor = AutoProcessor.from_pretrained("kakaobrain/align-base")

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> image = load_image(url)

          >>> inputs = processor(images=image, return_tensors="pt")

          >>> with torch.inference_mode():
          ...     image_features = model.get_image_features(**inputs)
          ```"""
          # Request a structured output explicitly (as `forward` does) so `.pooler_output` is always available.
          vision_outputs = self.vision_model(pixel_values=pixel_values, return_dict=True)
          image_features = vision_outputs.pooler_output
          return image_features

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          pixel_values: Optional[torch.FloatTensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
          return_loss: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, AlignOutput]:
          r"""
          return_loss (`bool`, *optional*):
              Whether or not to return the contrastive loss.

          Examples:

          ```python
          >>> import torch
          >>> from transformers import AutoProcessor, AlignModel
          >>> from transformers.image_utils import load_image

          >>> model = AlignModel.from_pretrained("kakaobrain/align-base")
          >>> processor = AutoProcessor.from_pretrained("kakaobrain/align-base")

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> image = load_image(url)

          >>> inputs = processor(
          ...     images=image, text=["a photo of a cat", "a photo of a dog"], return_tensors="pt", padding=True
          ... )

          >>> with torch.inference_mode():
          ...     outputs = model(**inputs)

          >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
          >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
          ```"""
          # Use ALIGN model's config for some fields (if specified) instead of those of vision & text components.
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )
          # NOTE: `return_dict` is not read in this body; the result is always built as an AlignOutput
          # (tuple conversion is presumably handled by @can_return_tuple).

          vision_outputs = self.vision_model(
              pixel_values=pixel_values,
              output_hidden_states=output_hidden_states,
              return_dict=True,
          )

          text_outputs = self.text_model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=True,
          )

          image_embeds = vision_outputs.pooler_output
          # Same text embedding as `get_text_features`: projected first-token last hidden state.
          text_embeds = text_outputs.last_hidden_state[:, 0, :]
          text_embeds = self.text_projection(text_embeds)

          # normalized features
          image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
          text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

          # cosine similarity as logits, scaled by the learnable temperature
          logits_per_text = torch.matmul(text_embeds, image_embeds.t()) / self.temperature
          logits_per_image = logits_per_text.t()

          loss = None
          if return_loss:
              loss = align_loss(logits_per_text)

          return AlignOutput(
              loss=loss,
              logits_per_image=logits_per_image,
              logits_per_text=logits_per_text,
              text_embeds=text_embeds,
              image_embeds=image_embeds,
              text_model_output=text_outputs,
              vision_model_output=vision_outputs,
          )
  # Public API exported by this module.
  __all__ = [
      "AlignPreTrainedModel",
      "AlignTextModel",
      "AlignVisionModel",
      "AlignModel",
  ]