modeling_trocr.py 38 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865
  1. # coding=utf-8
  2. # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch TrOCR decoder model (based on RoBERTa)."""
  16. import math
  17. from typing import Optional, Union
  18. import torch
  19. from torch import nn
  20. from torch.nn import CrossEntropyLoss
  21. from ...activations import ACT2FN
  22. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  23. from ...generation import GenerationMixin
  24. from ...modeling_attn_mask_utils import (
  25. _prepare_4d_attention_mask,
  26. _prepare_4d_causal_attention_mask,
  27. )
  28. from ...modeling_layers import GradientCheckpointingLayer
  29. from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
  30. from ...modeling_utils import PreTrainedModel
  31. from ...utils import auto_docstring, logging
  32. from ...utils.deprecation import deprecate_kwarg
  33. from .configuration_trocr import TrOCRConfig
  34. logger = logging.get_logger(__name__)
  35. # Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->TrOCR
  36. class TrOCRLearnedPositionalEmbedding(nn.Embedding):
  37. """
  38. This module learns positional embeddings up to a fixed maximum size.
  39. """
  40. def __init__(self, num_embeddings: int, embedding_dim: int):
  41. # TrOCR is set up so that if padding_idx is specified then offset the embedding ids by 2
  42. # and adjust num_embeddings appropriately. Other models don't have this hack
  43. self.offset = 2
  44. super().__init__(num_embeddings + self.offset, embedding_dim)
  45. def forward(
  46. self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None
  47. ):
  48. """`input_ids' shape is expected to be [bsz x seqlen]."""
  49. if position_ids is None:
  50. bsz, seq_len = input_ids.shape[:2]
  51. position_ids = torch.arange(
  52. past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
  53. ).expand(bsz, -1)
  54. else:
  55. position_ids = position_ids.unsqueeze(0)
  56. return super().forward(position_ids + self.offset)
  57. # Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->TrOCR
  58. class TrOCRScaledWordEmbedding(nn.Embedding):
  59. """
  60. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  61. """
  62. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
  63. super().__init__(num_embeddings, embedding_dim, padding_idx)
  64. self.embed_scale = embed_scale
  65. def forward(self, input_ids: torch.Tensor):
  66. return super().forward(input_ids) * self.embed_scale
  67. class TrOCRSinusoidalPositionalEmbedding(nn.Module):
  68. """This module produces sinusoidal positional embeddings of any length."""
  69. def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
  70. super().__init__()
  71. self.offset = 2
  72. self.embedding_dim = embedding_dim
  73. self.padding_idx = padding_idx
  74. self.weights = self.get_embedding(num_positions, embedding_dim, padding_idx)
  75. self.register_buffer("_float_tensor", torch.FloatTensor(1))
  76. @staticmethod
  77. def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
  78. """
  79. Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the
  80. description in Section 3.5 of "Attention Is All You Need".
  81. """
  82. half_dim = embedding_dim // 2
  83. emb = math.log(10000) / (half_dim - 1)
  84. emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
  85. emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
  86. emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
  87. if embedding_dim % 2 == 1:
  88. # zero pad
  89. emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
  90. if padding_idx is not None:
  91. emb[padding_idx, :] = 0
  92. return emb.to(torch.get_default_dtype())
  93. @torch.no_grad()
  94. def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
  95. bsz, seq_len = input_ids.size()
  96. # Create the position ids from the input token ids. Any padded tokens remain padded.
  97. position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to(
  98. input_ids.device
  99. )
  100. # expand embeddings if needed
  101. max_pos = self.padding_idx + 1 + seq_len
  102. if self.weights is None or max_pos > self.weights.size(0):
  103. # recompute/expand embeddings if needed
  104. self.weights = self.get_embedding(max_pos, self.embedding_dim, self.padding_idx)
  105. self.weights = self.weights.to(self._float_tensor)
  106. x = self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
  107. return x
  108. def create_position_ids_from_input_ids(
  109. self, input_ids: torch.Tensor, padding_idx: int, past_key_values_length: Optional[int] = 0
  110. ):
  111. """
  112. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
  113. symbols are ignored. This is modified from fairseq's `utils.make_positions`.
  114. """
  115. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  116. mask = input_ids.ne(padding_idx).int()
  117. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
  118. return incremental_indices.long() + padding_idx
  119. class TrOCRAttention(nn.Module):
  120. """Multi-headed attention from 'Attention Is All You Need' paper."""
  121. def __init__(
  122. self,
  123. config,
  124. embed_dim: int,
  125. num_heads: int,
  126. kdim: Optional[int] = None,
  127. vdim: Optional[int] = None,
  128. dropout: Optional[float] = 0.0,
  129. is_decoder: Optional[bool] = False,
  130. bias: Optional[bool] = True,
  131. is_cross_attention: Optional[bool] = False,
  132. layer_idx: Optional[bool] = None,
  133. ):
  134. super().__init__()
  135. self.embed_dim = embed_dim
  136. self.kdim = kdim if kdim is not None else embed_dim
  137. self.vdim = vdim if vdim is not None else embed_dim
  138. self.num_heads = num_heads
  139. self.dropout = dropout
  140. self.head_dim = embed_dim // num_heads
  141. if not (self.head_dim * num_heads == self.embed_dim):
  142. raise ValueError(
  143. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  144. f" {num_heads})."
  145. )
  146. self.scaling = self.head_dim**-0.5
  147. self.is_decoder = is_decoder
  148. self.layer_idx = layer_idx
  149. self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
  150. self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
  151. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  152. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  153. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  154. def forward(
  155. self,
  156. hidden_states: torch.Tensor,
  157. key_value_states: Optional[torch.Tensor] = None,
  158. past_key_values: Optional[Cache] = None,
  159. attention_mask: Optional[torch.Tensor] = None,
  160. layer_head_mask: Optional[torch.Tensor] = None,
  161. output_attentions: Optional[bool] = False,
  162. cache_position: Optional[torch.Tensor] = None,
  163. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  164. """Input shape: Batch x Time x Channel"""
  165. # if key_value_states are provided this layer is used as a cross-attention layer
  166. # for the decoder
  167. is_cross_attention = key_value_states is not None
  168. bsz, tgt_len, embed_dim = hidden_states.size()
  169. # get query proj
  170. query_states = self.q_proj(hidden_states) * self.scaling
  171. is_updated = False
  172. if past_key_values is not None:
  173. if isinstance(past_key_values, EncoderDecoderCache):
  174. is_updated = past_key_values.is_updated.get(self.layer_idx)
  175. if is_cross_attention:
  176. # after the first generated id, we can subsequently re-use all key/value_states from cache
  177. curr_past_key_value = past_key_values.cross_attention_cache
  178. else:
  179. curr_past_key_value = past_key_values.self_attention_cache
  180. else:
  181. curr_past_key_value = past_key_values
  182. current_states = key_value_states if is_cross_attention else hidden_states
  183. if is_cross_attention and past_key_values is not None and is_updated:
  184. # reuse k,v, cross_attentions
  185. key_states = curr_past_key_value.layers[self.layer_idx].keys
  186. value_states = curr_past_key_value.layers[self.layer_idx].values
  187. else:
  188. key_states = self.k_proj(current_states)
  189. value_states = self.v_proj(current_states)
  190. key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  191. value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  192. if past_key_values is not None:
  193. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  194. cache_position = cache_position if not is_cross_attention else None
  195. key_states, value_states = curr_past_key_value.update(
  196. key_states, value_states, self.layer_idx, {"cache_position": cache_position}
  197. )
  198. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  199. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  200. past_key_values.is_updated[self.layer_idx] = True
  201. proj_shape = (bsz * self.num_heads, -1, self.head_dim)
  202. query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
  203. query_states = query_states.reshape(*proj_shape)
  204. key_states = key_states.reshape(*proj_shape)
  205. value_states = value_states.reshape(*proj_shape)
  206. src_len = key_states.size(1)
  207. attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
  208. if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  209. raise ValueError(
  210. f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
  211. f" {attn_weights.size()}"
  212. )
  213. if attention_mask is not None:
  214. if attention_mask.size() != (bsz, 1, tgt_len, src_len):
  215. raise ValueError(
  216. f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
  217. )
  218. attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
  219. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  220. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  221. if layer_head_mask is not None:
  222. if layer_head_mask.size() != (self.num_heads,):
  223. raise ValueError(
  224. f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
  225. f" {layer_head_mask.size()}"
  226. )
  227. attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  228. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  229. if output_attentions:
  230. # this operation is a bit awkward, but it's required to
  231. # make sure that attn_weights keeps its gradient.
  232. # In order to do so, attn_weights have to be reshaped
  233. # twice and have to be reused in the following
  234. attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  235. attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
  236. else:
  237. attn_weights_reshaped = None
  238. attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  239. attn_output = torch.bmm(attn_probs, value_states)
  240. if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
  241. raise ValueError(
  242. f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
  243. f" {attn_output.size()}"
  244. )
  245. attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
  246. attn_output = attn_output.transpose(1, 2)
  247. attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
  248. attn_output = self.out_proj(attn_output)
  249. return attn_output, attn_weights_reshaped
  250. class TrOCRDecoderLayer(GradientCheckpointingLayer):
  251. def __init__(self, config: TrOCRConfig, layer_idx=None):
  252. super().__init__()
  253. self.embed_dim = config.hidden_size
  254. self.self_attn = TrOCRAttention(
  255. config,
  256. embed_dim=self.embed_dim,
  257. num_heads=config.decoder_attention_heads,
  258. dropout=config.attention_dropout,
  259. is_decoder=True,
  260. layer_idx=layer_idx,
  261. )
  262. self.dropout = config.dropout
  263. self.activation_fn = ACT2FN[config.activation_function]
  264. self.activation_dropout = config.activation_dropout
  265. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  266. if config.is_decoder:
  267. self.encoder_attn = TrOCRAttention(
  268. config,
  269. embed_dim=self.embed_dim,
  270. num_heads=config.decoder_attention_heads,
  271. kdim=config.cross_attention_hidden_size,
  272. vdim=config.cross_attention_hidden_size,
  273. dropout=config.attention_dropout,
  274. is_decoder=True,
  275. is_cross_attention=True,
  276. layer_idx=layer_idx,
  277. )
  278. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  279. self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
  280. self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
  281. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  282. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  283. def forward(
  284. self,
  285. hidden_states: torch.Tensor,
  286. attention_mask: Optional[torch.Tensor] = None,
  287. encoder_hidden_states: Optional[torch.Tensor] = None,
  288. encoder_attention_mask: Optional[torch.Tensor] = None,
  289. layer_head_mask: Optional[torch.Tensor] = None,
  290. cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
  291. past_key_values: Optional[Cache] = None,
  292. output_attentions: Optional[bool] = False,
  293. use_cache: Optional[bool] = True,
  294. cache_position: Optional[torch.Tensor] = None,
  295. ):
  296. """
  297. Args:
  298. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  299. attention_mask (`torch.FloatTensor`): attention mask of size
  300. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  301. encoder_hidden_states (`torch.FloatTensor`):
  302. cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
  303. encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
  304. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  305. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
  306. `(encoder_attention_heads,)`.
  307. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
  308. size *(decoder_attention_heads,)*.
  309. past_key_values (`Cache`): cached past key and value projection states
  310. output_attentions (`bool`, *optional*):
  311. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  312. returned tensors for more detail.
  313. """
  314. residual = hidden_states
  315. # Self Attention
  316. hidden_states, self_attn_weights = self.self_attn(
  317. hidden_states=hidden_states,
  318. past_key_values=past_key_values,
  319. attention_mask=attention_mask,
  320. layer_head_mask=layer_head_mask,
  321. output_attentions=output_attentions,
  322. cache_position=cache_position,
  323. )
  324. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  325. hidden_states = residual + hidden_states
  326. hidden_states = self.self_attn_layer_norm(hidden_states)
  327. # Cross-Attention Block
  328. cross_attn_weights = None
  329. if encoder_hidden_states is not None:
  330. residual = hidden_states
  331. hidden_states, cross_attn_weights = self.encoder_attn(
  332. hidden_states=hidden_states,
  333. key_value_states=encoder_hidden_states,
  334. attention_mask=encoder_attention_mask,
  335. layer_head_mask=cross_attn_layer_head_mask,
  336. past_key_values=past_key_values,
  337. output_attentions=output_attentions,
  338. cache_position=cache_position,
  339. )
  340. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  341. hidden_states = residual + hidden_states
  342. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  343. # Fully Connected
  344. residual = hidden_states
  345. hidden_states = self.activation_fn(self.fc1(hidden_states))
  346. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  347. hidden_states = self.fc2(hidden_states)
  348. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  349. hidden_states = residual + hidden_states
  350. hidden_states = self.final_layer_norm(hidden_states)
  351. outputs = (hidden_states,)
  352. if output_attentions:
  353. outputs += (self_attn_weights, cross_attn_weights)
  354. return outputs
  355. @auto_docstring
  356. class TrOCRPreTrainedModel(PreTrainedModel):
  357. config: TrOCRConfig
  358. base_model_prefix = "model"
  359. supports_gradient_checkpointing = True
  360. _no_split_modules = ["TrOCRDecoderLayer"]
  361. def _init_weights(self, module):
  362. std = self.config.init_std
  363. if isinstance(module, (nn.Linear, nn.Conv1d)):
  364. module.weight.data.normal_(mean=0.0, std=std)
  365. if module.bias is not None:
  366. module.bias.data.zero_()
  367. elif isinstance(module, nn.Embedding):
  368. module.weight.data.normal_(mean=0.0, std=std)
  369. if module.padding_idx is not None:
  370. module.weight.data[module.padding_idx].zero_()
  371. class TrOCRDecoder(TrOCRPreTrainedModel):
  372. """
  373. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TrOCRDecoderLayer`]
  374. Args:
  375. config: TrOCRConfig
  376. """
  377. def __init__(self, config: TrOCRConfig):
  378. super().__init__(config)
  379. self.dropout = config.dropout
  380. self.layerdrop = config.decoder_layerdrop
  381. self.padding_idx = config.pad_token_id
  382. embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
  383. self.embed_tokens = TrOCRScaledWordEmbedding(
  384. config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embed_scale
  385. )
  386. if config.use_learned_position_embeddings:
  387. self.embed_positions = TrOCRLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)
  388. else:
  389. self.embed_positions = TrOCRSinusoidalPositionalEmbedding(
  390. config.max_position_embeddings + self.padding_idx + 1,
  391. config.hidden_size,
  392. self.padding_idx,
  393. )
  394. if config.layernorm_embedding:
  395. self.layernorm_embedding = nn.LayerNorm(config.hidden_size)
  396. else:
  397. self.layernorm_embedding = None
  398. self.layers = nn.ModuleList([TrOCRDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
  399. self.gradient_checkpointing = False
  400. # Initialize weights and apply final processing
  401. self.post_init()
  402. def forward(
  403. self,
  404. input_ids=None,
  405. attention_mask=None,
  406. encoder_hidden_states=None,
  407. encoder_attention_mask=None,
  408. head_mask=None,
  409. cross_attn_head_mask=None,
  410. past_key_values=None,
  411. inputs_embeds=None,
  412. use_cache=None,
  413. output_attentions=None,
  414. output_hidden_states=None,
  415. return_dict=None,
  416. cache_position=None,
  417. ):
  418. r"""
  419. Args:
  420. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  421. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  422. provide it.
  423. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  424. [`PreTrainedTokenizer.__call__`] for details.
  425. [What are input IDs?](../glossary#input-ids)
  426. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  427. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  428. - 1 for tokens that are **not masked**,
  429. - 0 for tokens that are **masked**.
  430. [What are attention masks?](../glossary#attention-mask)
  431. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  432. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  433. of the decoder.
  434. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
  435. Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
  436. selected in `[0, 1]`:
  437. - 1 for tokens that are **not masked**,
  438. - 0 for tokens that are **masked**.
  439. [What are attention masks?](../glossary#attention-mask)
  440. head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  441. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  442. - 1 indicates the head is **not masked**,
  443. - 0 indicates the head is **masked**.
  444. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  445. Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
  446. on hidden heads. Mask values selected in `[0, 1]`:
  447. - 1 indicates the head is **not masked**,
  448. - 0 indicates the head is **masked**.
  449. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  450. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  451. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
  452. cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
  453. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
  454. that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
  455. all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  456. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  457. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  458. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  459. than the model's internal embedding lookup matrix.
  460. output_attentions (`bool`, *optional*):
  461. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  462. returned tensors for more detail.
  463. output_hidden_states (`bool`, *optional*):
  464. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  465. for more detail.
  466. return_dict (`bool`, *optional*):
  467. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  468. """
  469. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  470. output_hidden_states = (
  471. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  472. )
  473. use_cache = use_cache if use_cache is not None else self.config.use_cache
  474. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  475. # retrieve input_ids and inputs_embeds
  476. if input_ids is not None and inputs_embeds is not None:
  477. raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
  478. elif input_ids is not None:
  479. input = input_ids
  480. input_ids = input_ids.view(-1, input.shape[-1])
  481. elif inputs_embeds is not None:
  482. input_shape = inputs_embeds.size()[:-1]
  483. input = inputs_embeds[:, :, -1]
  484. else:
  485. raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
  486. if self.gradient_checkpointing and self.training:
  487. if use_cache:
  488. logger.warning_once(
  489. "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
  490. )
  491. use_cache = False
  492. if use_cache and past_key_values is None:
  493. past_key_values = (
  494. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  495. if encoder_hidden_states is not None
  496. else DynamicCache(config=self.config)
  497. )
  498. if use_cache and isinstance(past_key_values, tuple):
  499. logger.warning_once(
  500. "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
  501. "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
  502. "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
  503. )
  504. past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
  505. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  506. if inputs_embeds is None:
  507. inputs_embeds = self.embed_tokens(input_ids)
  508. if self.config.use_learned_position_embeddings:
  509. embed_pos = self.embed_positions(input, past_key_values_length=past_key_values_length)
  510. else:
  511. embed_pos = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
  512. hidden_states = inputs_embeds + embed_pos
  513. if self.layernorm_embedding is not None:
  514. hidden_states = self.layernorm_embedding(hidden_states)
  515. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  516. input_shape = input.shape
  517. attention_mask = _prepare_4d_causal_attention_mask(
  518. attention_mask, input_shape, inputs_embeds, past_key_values_length
  519. )
  520. # expand encoder attention mask
  521. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  522. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  523. encoder_attention_mask = _prepare_4d_attention_mask(
  524. encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
  525. )
  526. # decoder layers
  527. all_hidden_states = () if output_hidden_states else None
  528. all_self_attns = () if output_attentions else None
  529. all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
  530. # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
  531. for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
  532. if attn_mask is not None:
  533. if attn_mask.size()[0] != (len(self.layers)):
  534. raise ValueError(
  535. f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
  536. f" {head_mask.size()[0]}."
  537. )
  538. for idx, decoder_layer in enumerate(self.layers):
  539. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  540. if output_hidden_states:
  541. all_hidden_states += (hidden_states,)
  542. if self.training:
  543. dropout_probability = torch.rand([])
  544. if dropout_probability < self.layerdrop:
  545. continue
  546. layer_outputs = decoder_layer(
  547. hidden_states,
  548. attention_mask,
  549. encoder_hidden_states, # as a positional argument for gradient checkpointing
  550. encoder_attention_mask=encoder_attention_mask,
  551. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  552. cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
  553. past_key_values=past_key_values,
  554. output_attentions=output_attentions,
  555. use_cache=use_cache,
  556. cache_position=cache_position,
  557. )
  558. hidden_states = layer_outputs[0]
  559. if output_attentions:
  560. all_self_attns += (layer_outputs[1],)
  561. if encoder_hidden_states is not None:
  562. all_cross_attentions += (layer_outputs[2],)
  563. # add hidden states from the last decoder layer
  564. if output_hidden_states:
  565. all_hidden_states += (hidden_states,)
  566. if not return_dict:
  567. return tuple(
  568. v
  569. for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
  570. if v is not None
  571. )
  572. return BaseModelOutputWithPastAndCrossAttentions(
  573. last_hidden_state=hidden_states,
  574. past_key_values=past_key_values,
  575. hidden_states=all_hidden_states,
  576. attentions=all_self_attns,
  577. cross_attentions=all_cross_attentions,
  578. )
  579. @auto_docstring(
  580. custom_intro="""
  581. The TrOCR Model with a language modeling head. Can be used for summarization.
  582. This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
  583. used in combination with the [`EncoderDecoderModel`] framework.
  584. """
  585. )
  586. class TrOCRDecoderWrapper(TrOCRPreTrainedModel):
  587. def __init__(self, config):
  588. super().__init__(config)
  589. self.decoder = TrOCRDecoder(config)
  590. def forward(self, *args, **kwargs):
  591. return self.decoder(*args, **kwargs)
  592. @auto_docstring(
  593. custom_intro="""
  594. The TrOCR Decoder with a language modeling head. Can be used as the decoder part of [`EncoderDecoderModel`] and
  595. """
  596. )
  597. class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin):
  598. _tied_weights_keys = ["output_projection.weight"]
  599. def __init__(self, config):
  600. config.is_decoder = True
  601. config.is_encoder_decoder = False
  602. super().__init__(config)
  603. self.model = TrOCRDecoderWrapper(config)
  604. self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  605. # Initialize weights and apply final processing
  606. self.post_init()
  607. def get_input_embeddings(self):
  608. return self.model.decoder.embed_tokens
  609. def set_input_embeddings(self, value):
  610. self.model.decoder.embed_tokens = value
  611. def get_output_embeddings(self):
  612. return self.output_projection
  613. def set_output_embeddings(self, new_embeddings):
  614. self.output_projection = new_embeddings
  615. def set_decoder(self, decoder):
  616. self.model.decoder = decoder
  617. def get_decoder(self):
  618. return self.model.decoder
  619. @auto_docstring
  620. def forward(
  621. self,
  622. input_ids: Optional[torch.LongTensor] = None,
  623. attention_mask: Optional[torch.Tensor] = None,
  624. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  625. encoder_attention_mask: Optional[torch.LongTensor] = None,
  626. head_mask: Optional[torch.Tensor] = None,
  627. cross_attn_head_mask: Optional[torch.Tensor] = None,
  628. past_key_values: Optional[Cache] = None,
  629. inputs_embeds: Optional[torch.FloatTensor] = None,
  630. labels: Optional[torch.LongTensor] = None,
  631. use_cache: Optional[bool] = None,
  632. output_attentions: Optional[bool] = None,
  633. output_hidden_states: Optional[bool] = None,
  634. return_dict: Optional[bool] = None,
  635. cache_position: Optional[torch.Tensor] = None,
  636. ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
  637. r"""
  638. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  639. Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
  640. - 1 indicates the head is **not masked**,
  641. - 0 indicates the head is **masked**.
  642. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  643. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  644. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  645. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  646. Example:
  647. ```python
  648. >>> from transformers import (
  649. ... TrOCRConfig,
  650. ... TrOCRProcessor,
  651. ... TrOCRForCausalLM,
  652. ... ViTConfig,
  653. ... ViTModel,
  654. ... VisionEncoderDecoderModel,
  655. ... )
  656. >>> import requests
  657. >>> from PIL import Image
  658. >>> # TrOCR is a decoder model and should be used within a VisionEncoderDecoderModel
  659. >>> # init vision2text model with random weights
  660. >>> encoder = ViTModel(ViTConfig())
  661. >>> decoder = TrOCRForCausalLM(TrOCRConfig())
  662. >>> model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
  663. >>> # If you want to start from the pretrained model, load the checkpoint with `VisionEncoderDecoderModel`
  664. >>> processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
  665. >>> model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
  666. >>> # load image from the IAM dataset
  667. >>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
  668. >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
  669. >>> pixel_values = processor(image, return_tensors="pt").pixel_values
  670. >>> text = "industry, ' Mr. Brown commented icily. ' Let us have a"
  671. >>> # training
  672. >>> model.config.decoder_start_token_id = processor.tokenizer.eos_token_id
  673. >>> model.config.pad_token_id = processor.tokenizer.pad_token_id
  674. >>> model.config.vocab_size = model.config.decoder.vocab_size
  675. >>> labels = processor.tokenizer(text, return_tensors="pt").input_ids
  676. >>> outputs = model(pixel_values, labels=labels)
  677. >>> loss = outputs.loss
  678. >>> round(loss.item(), 2)
  679. 5.30
  680. >>> # inference
  681. >>> generated_ids = model.generate(pixel_values)
  682. >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
  683. >>> generated_text
  684. 'industry, " Mr. Brown commented icily. " Let us have a'
  685. ```"""
  686. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  687. output_hidden_states = (
  688. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  689. )
  690. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  691. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  692. outputs = self.model.decoder(
  693. input_ids=input_ids,
  694. attention_mask=attention_mask,
  695. encoder_hidden_states=encoder_hidden_states,
  696. encoder_attention_mask=encoder_attention_mask,
  697. head_mask=head_mask,
  698. cross_attn_head_mask=cross_attn_head_mask,
  699. past_key_values=past_key_values,
  700. inputs_embeds=inputs_embeds,
  701. use_cache=use_cache,
  702. output_attentions=output_attentions,
  703. output_hidden_states=output_hidden_states,
  704. return_dict=return_dict,
  705. cache_position=cache_position,
  706. )
  707. logits = self.output_projection(outputs[0])
  708. loss = None
  709. if labels is not None:
  710. loss_fct = CrossEntropyLoss()
  711. loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
  712. if not return_dict:
  713. output = (logits,) + outputs[1:]
  714. return (loss,) + output if loss is not None else output
  715. return CausalLMOutputWithCrossAttentions(
  716. loss=loss,
  717. logits=logits,
  718. past_key_values=outputs.past_key_values,
  719. hidden_states=outputs.hidden_states,
  720. attentions=outputs.attentions,
  721. cross_attentions=outputs.cross_attentions,
  722. )
  723. __all__ = ["TrOCRForCausalLM", "TrOCRPreTrainedModel"]