modeling_beit.py 63 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547
  1. # coding=utf-8
  2. # Copyright 2021 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch BEiT model."""
  16. import collections.abc
  17. import math
  18. import warnings
  19. from dataclasses import dataclass
  20. from typing import Optional, Union
  21. import torch
  22. from torch import Tensor, nn
  23. from torch.nn import CrossEntropyLoss
  24. from ...activations import ACT2FN
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import (
  27. BackboneOutput,
  28. BaseModelOutput,
  29. BaseModelOutputWithPooling,
  30. ImageClassifierOutput,
  31. MaskedLMOutput,
  32. SemanticSegmenterOutput,
  33. )
  34. from ...modeling_utils import PreTrainedModel
  35. from ...pytorch_utils import compile_compatible_method_lru_cache, find_pruneable_heads_and_indices, prune_linear_layer
  36. from ...utils import auto_docstring, logging, torch_int
  37. from ...utils.backbone_utils import BackboneMixin
  38. from .configuration_beit import BeitConfig
  39. logger = logging.get_logger(__name__)
  40. @dataclass
  41. @auto_docstring(
  42. custom_intro="""
  43. Class for outputs of [`BeitModel`].
  44. """
  45. )
  46. class BeitModelOutputWithPooling(BaseModelOutputWithPooling):
  47. r"""
  48. pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
  49. Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
  50. *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
  51. will be returned.
  52. """
  53. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  54. """
  55. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  56. Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
  57. however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  58. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
  59. layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
  60. argument.
  61. """
  62. if drop_prob == 0.0 or not training:
  63. return input
  64. keep_prob = 1 - drop_prob
  65. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  66. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  67. random_tensor.floor_() # binarize
  68. output = input.div(keep_prob) * random_tensor
  69. return output
  70. class BeitDropPath(nn.Module):
  71. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  72. def __init__(self, drop_prob: Optional[float] = None) -> None:
  73. super().__init__()
  74. self.drop_prob = drop_prob
  75. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  76. return drop_path(hidden_states, self.drop_prob, self.training)
  77. def extra_repr(self) -> str:
  78. return f"p={self.drop_prob}"
  79. # Based on timm implementation, which can be found here:
  80. # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  81. class BeitEmbeddings(nn.Module):
  82. """
  83. Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
  84. """
  85. def __init__(self, config: BeitConfig) -> None:
  86. super().__init__()
  87. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  88. if config.use_mask_token:
  89. self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  90. else:
  91. self.mask_token = None
  92. self.patch_embeddings = BeitPatchEmbeddings(config)
  93. self.patch_size = config.patch_size
  94. self.image_size = (
  95. config.image_size
  96. if isinstance(config.image_size, collections.abc.Iterable)
  97. else (config.image_size, config.image_size)
  98. )
  99. num_patches = self.patch_embeddings.num_patches
  100. if config.use_absolute_position_embeddings:
  101. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
  102. else:
  103. self.position_embeddings = None
  104. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  105. # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
  106. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  107. """
  108. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  109. images. This method is also adapted to support torch.jit tracing.
  110. Adapted from:
  111. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  112. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  113. """
  114. num_patches = embeddings.shape[1] - 1
  115. num_positions = self.position_embeddings.shape[1] - 1
  116. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  117. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  118. return self.position_embeddings
  119. class_pos_embed = self.position_embeddings[:, :1]
  120. patch_pos_embed = self.position_embeddings[:, 1:]
  121. dim = embeddings.shape[-1]
  122. new_height = height // self.patch_size
  123. new_width = width // self.patch_size
  124. sqrt_num_positions = torch_int(num_positions**0.5)
  125. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  126. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  127. patch_pos_embed = nn.functional.interpolate(
  128. patch_pos_embed,
  129. size=(new_height, new_width),
  130. mode="bicubic",
  131. align_corners=False,
  132. )
  133. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  134. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  135. def forward(
  136. self,
  137. pixel_values: torch.Tensor,
  138. bool_masked_pos: Optional[torch.BoolTensor] = None,
  139. interpolate_pos_encoding: Optional[bool] = None,
  140. ) -> torch.Tensor:
  141. if self.position_embeddings is not None and interpolate_pos_encoding is not None:
  142. warnings.warn(
  143. "`interpolate_pos_encoding` argument has no effect for BEiTEmbeddings, embeddings are always "
  144. "interpolated to the input image size. The argument will be removed in transformers v4.51.0."
  145. )
  146. _, _, height, width = pixel_values.shape
  147. embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values)
  148. batch_size, seq_len, _ = embeddings.size()
  149. if bool_masked_pos is not None:
  150. mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
  151. # replace the masked visual tokens by mask_tokens
  152. w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
  153. embeddings = embeddings * (1 - w) + mask_tokens * w
  154. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  155. embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  156. if self.position_embeddings is not None:
  157. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  158. embeddings = self.dropout(embeddings)
  159. return embeddings, (patch_height, patch_width)
  160. class BeitPatchEmbeddings(nn.Module):
  161. """
  162. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  163. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  164. Transformer.
  165. """
  166. def __init__(self, config):
  167. super().__init__()
  168. image_size, patch_size = config.image_size, config.patch_size
  169. num_channels, hidden_size = config.num_channels, config.hidden_size
  170. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  171. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  172. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  173. patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
  174. self.image_size = image_size
  175. self.patch_size = patch_size
  176. self.num_channels = num_channels
  177. self.num_patches = num_patches
  178. self.patch_shape = patch_shape
  179. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  180. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  181. batch_size, num_channels, height, width = pixel_values.shape
  182. if num_channels != self.num_channels:
  183. raise ValueError(
  184. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  185. )
  186. embeddings = self.projection(pixel_values)
  187. patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
  188. embeddings = embeddings.flatten(2).transpose(1, 2)
  189. return embeddings, (patch_height, patch_width)
  190. class BeitSelfAttention(nn.Module):
  191. def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None:
  192. super().__init__()
  193. self.config = config
  194. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  195. raise ValueError(
  196. f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
  197. f"heads {config.num_attention_heads}."
  198. )
  199. self.num_attention_heads = config.num_attention_heads
  200. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  201. self.all_head_size = self.num_attention_heads * self.attention_head_size
  202. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  203. self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
  204. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  205. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  206. self.has_relative_position_bias = bool(window_size)
  207. if self.has_relative_position_bias:
  208. self.relative_position_bias = BeitRelativePositionBias(config, window_size=window_size)
  209. def forward(
  210. self,
  211. hidden_states: torch.Tensor,
  212. head_mask: Optional[torch.Tensor] = None,
  213. output_attentions: bool = False,
  214. relative_position_bias: Optional[torch.Tensor] = None,
  215. interpolate_pos_encoding: bool = False,
  216. resolution: Optional[tuple[int]] = None,
  217. ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
  218. batch_size, seq_length, _ = hidden_states.shape
  219. query_layer = (
  220. self.query(hidden_states)
  221. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  222. .transpose(1, 2)
  223. )
  224. key_layer = (
  225. self.key(hidden_states)
  226. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  227. .transpose(1, 2)
  228. )
  229. value_layer = (
  230. self.value(hidden_states)
  231. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  232. .transpose(1, 2)
  233. )
  234. # Take the dot product between "query" and "key" to get the raw attention scores.
  235. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  236. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  237. # Add relative position bias if present.
  238. if self.has_relative_position_bias:
  239. height, width = resolution
  240. window_size = (height // self.config.patch_size, width // self.config.patch_size)
  241. attention_scores = attention_scores + self.relative_position_bias(
  242. window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
  243. )
  244. # Add shared relative position bias if provided.
  245. if relative_position_bias is not None:
  246. attention_scores = attention_scores + relative_position_bias
  247. # Normalize the attention scores to probabilities.
  248. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  249. # This is actually dropping out entire tokens to attend to, which might
  250. # seem a bit unusual, but is taken from the original Transformer paper.
  251. attention_probs = self.dropout(attention_probs)
  252. # Mask heads if we want to
  253. if head_mask is not None:
  254. attention_probs = attention_probs * head_mask
  255. context_layer = torch.matmul(attention_probs, value_layer)
  256. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  257. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  258. context_layer = context_layer.view(*new_context_layer_shape)
  259. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  260. return outputs
  261. class BeitSdpaSelfAttention(BeitSelfAttention):
  262. def forward(
  263. self,
  264. hidden_states: torch.Tensor,
  265. head_mask: Optional[torch.Tensor] = None,
  266. output_attentions: bool = False,
  267. relative_position_bias: Optional[torch.Tensor] = None,
  268. interpolate_pos_encoding: bool = False,
  269. resolution: Optional[tuple[int]] = None,
  270. ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
  271. if output_attentions or head_mask is not None:
  272. logger.warning_once(
  273. "`BeitSdpaSelfAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not "
  274. "support `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, "
  275. "but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
  276. 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
  277. )
  278. return super().forward(
  279. hidden_states=hidden_states,
  280. head_mask=head_mask,
  281. output_attentions=output_attentions,
  282. relative_position_bias=relative_position_bias,
  283. interpolate_pos_encoding=interpolate_pos_encoding,
  284. resolution=resolution,
  285. )
  286. batch_size, seq_length, _ = hidden_states.shape
  287. query_layer = (
  288. self.query(hidden_states)
  289. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  290. .transpose(1, 2)
  291. )
  292. key_layer = (
  293. self.key(hidden_states)
  294. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  295. .transpose(1, 2)
  296. )
  297. value_layer = (
  298. self.value(hidden_states)
  299. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  300. .transpose(1, 2)
  301. )
  302. attn_bias = None
  303. if self.has_relative_position_bias:
  304. height, width = resolution
  305. window_size = (height // self.config.patch_size, width // self.config.patch_size)
  306. attn_bias = self.relative_position_bias(
  307. window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
  308. )
  309. # Add shared relative position bias if provided.
  310. if relative_position_bias is not None:
  311. if attn_bias is None:
  312. attn_bias = relative_position_bias
  313. else:
  314. attn_bias += relative_position_bias
  315. scaling = 1 / math.sqrt(self.attention_head_size)
  316. context_layer = torch.nn.functional.scaled_dot_product_attention(
  317. query_layer,
  318. key_layer,
  319. value_layer,
  320. attn_mask=attn_bias,
  321. dropout_p=self.config.attention_probs_dropout_prob if self.training else 0.0,
  322. is_causal=False,
  323. scale=scaling,
  324. )
  325. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  326. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  327. context_layer = context_layer.view(*new_context_layer_shape)
  328. return context_layer, None
  329. class BeitSelfOutput(nn.Module):
  330. """
  331. The residual connection is defined in BeitLayer instead of here (as is the case with other models), due to the
  332. layernorm applied before each block.
  333. """
  334. def __init__(self, config: BeitConfig) -> None:
  335. super().__init__()
  336. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  337. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  338. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma=None) -> torch.Tensor:
  339. hidden_states = self.dense(hidden_states)
  340. hidden_states = self.dropout(hidden_states)
  341. return hidden_states
  342. BEIT_SELF_ATTENTION_CLASSES = {
  343. "eager": BeitSelfAttention,
  344. "sdpa": BeitSdpaSelfAttention,
  345. }
  346. class BeitAttention(nn.Module):
  347. def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None:
  348. super().__init__()
  349. self.attention = BEIT_SELF_ATTENTION_CLASSES[config._attn_implementation](config, window_size=window_size)
  350. self.output = BeitSelfOutput(config)
  351. self.pruned_heads = set()
  352. def prune_heads(self, heads):
  353. if len(heads) == 0:
  354. return
  355. heads, index = find_pruneable_heads_and_indices(
  356. heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
  357. )
  358. # Prune linear layers
  359. self.attention.query = prune_linear_layer(self.attention.query, index)
  360. self.attention.key = prune_linear_layer(self.attention.key, index)
  361. self.attention.value = prune_linear_layer(self.attention.value, index)
  362. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  363. # Update hyper params and store pruned heads
  364. self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
  365. self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
  366. self.pruned_heads = self.pruned_heads.union(heads)
  367. def forward(
  368. self,
  369. hidden_states: torch.Tensor,
  370. head_mask: Optional[torch.Tensor] = None,
  371. output_attentions: bool = False,
  372. relative_position_bias: Optional[torch.Tensor] = None,
  373. interpolate_pos_encoding: bool = False,
  374. resolution: Optional[tuple[int]] = None,
  375. ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
  376. self_outputs = self.attention(
  377. hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding, resolution
  378. )
  379. attention_output = self.output(self_outputs[0], hidden_states)
  380. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  381. return outputs
  382. class BeitIntermediate(nn.Module):
  383. def __init__(self, config: BeitConfig) -> None:
  384. super().__init__()
  385. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  386. if isinstance(config.hidden_act, str):
  387. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  388. else:
  389. self.intermediate_act_fn = config.hidden_act
  390. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  391. hidden_states = self.dense(hidden_states)
  392. hidden_states = self.intermediate_act_fn(hidden_states)
  393. return hidden_states
  394. class BeitOutput(nn.Module):
  395. def __init__(self, config: BeitConfig) -> None:
  396. super().__init__()
  397. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  398. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  399. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  400. hidden_states = self.dense(hidden_states)
  401. hidden_states = self.dropout(hidden_states)
  402. return hidden_states
  403. class BeitLayer(GradientCheckpointingLayer):
  404. """This corresponds to the Block class in the timm implementation."""
  405. def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0) -> None:
  406. super().__init__()
  407. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  408. self.seq_len_dim = 1
  409. self.attention = BeitAttention(config, window_size=window_size)
  410. self.intermediate = BeitIntermediate(config)
  411. self.output = BeitOutput(config)
  412. self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  413. self.drop_path = BeitDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
  414. self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  415. init_values = config.layer_scale_init_value
  416. if init_values > 0:
  417. self.lambda_1 = nn.Parameter(init_values * torch.ones(config.hidden_size), requires_grad=True)
  418. self.lambda_2 = nn.Parameter(init_values * torch.ones(config.hidden_size), requires_grad=True)
  419. else:
  420. self.lambda_1, self.lambda_2 = None, None
  421. def forward(
  422. self,
  423. hidden_states: torch.Tensor,
  424. head_mask: Optional[torch.Tensor] = None,
  425. output_attentions: bool = False,
  426. relative_position_bias: Optional[torch.Tensor] = None,
  427. interpolate_pos_encoding: bool = False,
  428. resolution: Optional[tuple[int, int]] = None,
  429. ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
  430. self_attention_outputs = self.attention(
  431. self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention
  432. head_mask,
  433. output_attentions=output_attentions,
  434. relative_position_bias=relative_position_bias,
  435. interpolate_pos_encoding=interpolate_pos_encoding,
  436. resolution=resolution,
  437. )
  438. attention_output = self_attention_outputs[0]
  439. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  440. # apply lambda_1 if present
  441. if self.lambda_1 is not None:
  442. attention_output = self.lambda_1 * attention_output
  443. # first residual connection
  444. hidden_states = self.drop_path(attention_output) + hidden_states
  445. # in BEiT, layernorm is also applied after self-attention
  446. layer_output = self.layernorm_after(hidden_states)
  447. layer_output = self.intermediate(layer_output)
  448. layer_output = self.output(layer_output)
  449. if self.lambda_2 is not None:
  450. layer_output = self.lambda_2 * layer_output
  451. # second residual connection
  452. layer_output = self.drop_path(layer_output) + hidden_states
  453. outputs = (layer_output,) + outputs
  454. return outputs
  455. class BeitRelativePositionBias(nn.Module):
  456. def __init__(self, config: BeitConfig, window_size: tuple) -> None:
  457. super().__init__()
  458. self.window_size = window_size
  459. self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
  460. self.relative_position_bias_table = nn.Parameter(
  461. torch.zeros(self.num_relative_distance, config.num_attention_heads)
  462. ) # 2*Wh-1 * 2*Ww-1, nH
  463. # cls to token & token 2 cls & cls to cls
  464. @compile_compatible_method_lru_cache(maxsize=10)
  465. def generate_relative_position_index(self, window_size: tuple[int, int]) -> torch.Tensor:
  466. """
  467. This method creates the relative position index, modified to support arbitrary window sizes,
  468. as introduced in [MiDaS v3.1](https://huggingface.co/papers/2307.14460).
  469. """
  470. num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
  471. # cls to token & token 2 cls & cls to cls
  472. # get pair-wise relative position index for each token inside the window
  473. window_area = window_size[0] * window_size[1]
  474. grid = torch.meshgrid(torch.arange(window_size[0]), torch.arange(window_size[1]), indexing="ij")
  475. coords = torch.stack(grid) # 2, Wh, Ww
  476. coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
  477. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
  478. relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
  479. relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
  480. relative_coords[:, :, 1] += window_size[1] - 1
  481. relative_coords[:, :, 0] *= 2 * window_size[1] - 1
  482. relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
  483. relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
  484. relative_position_index[0, 0:] = num_relative_distance - 3
  485. relative_position_index[0:, 0] = num_relative_distance - 2
  486. relative_position_index[0, 0] = num_relative_distance - 1
  487. return relative_position_index
  488. def forward(self, window_size, interpolate_pos_encoding: bool = False, dim_size=None) -> torch.Tensor:
  489. """
  490. Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes.
  491. """
  492. old_height = 2 * self.window_size[0] - 1
  493. old_width = 2 * self.window_size[1] - 1
  494. new_height = 2 * window_size[0] - 1
  495. new_width = 2 * window_size[1] - 1
  496. old_relative_position_bias_table = self.relative_position_bias_table
  497. old_num_relative_distance = self.num_relative_distance
  498. new_num_relative_distance = new_height * new_width + 3
  499. old_sub_table = old_relative_position_bias_table[: old_num_relative_distance - 3]
  500. old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
  501. new_sub_table = nn.functional.interpolate(
  502. old_sub_table, size=(torch_int(new_height), torch_int(new_width)), mode="bilinear"
  503. )
  504. new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)
  505. new_relative_position_bias_table = torch.cat(
  506. [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3 :]]
  507. )
  508. relative_position_index = self.generate_relative_position_index(window_size)
  509. relative_position_bias = new_relative_position_bias_table[relative_position_index.view(-1)]
  510. # patch_size*num_patches_height, patch_size*num_patches_width, num_attention_heads
  511. relative_position_bias = relative_position_bias.view(
  512. window_size[0] * window_size[1] + 1, window_size[0] * window_size[1] + 1, -1
  513. )
  514. # num_attention_heads, patch_size*num_patches_width, patch_size*num_patches_height
  515. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
  516. if interpolate_pos_encoding:
  517. relative_position_bias = nn.functional.interpolate(
  518. relative_position_bias.unsqueeze(1),
  519. size=(dim_size, dim_size),
  520. mode="bilinear",
  521. align_corners=False,
  522. ).squeeze(1)
  523. return relative_position_bias.unsqueeze(0)
  524. class BeitEncoder(nn.Module):
  525. def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None:
  526. super().__init__()
  527. self.config = config
  528. self.has_relative_position_bias = config.use_shared_relative_position_bias
  529. if self.has_relative_position_bias:
  530. self.relative_position_bias = BeitRelativePositionBias(config, window_size=window_size)
  531. # stochastic depth decay rule
  532. dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")]
  533. self.layer = nn.ModuleList(
  534. [
  535. BeitLayer(
  536. config,
  537. window_size=window_size if config.use_relative_position_bias else None,
  538. drop_path_rate=dpr[i],
  539. )
  540. for i in range(config.num_hidden_layers)
  541. ]
  542. )
  543. self.gradient_checkpointing = False
  544. def forward(
  545. self,
  546. hidden_states: torch.Tensor,
  547. head_mask: Optional[torch.Tensor] = None,
  548. output_attentions: bool = False,
  549. output_hidden_states: bool = False,
  550. interpolate_pos_encoding: bool = False,
  551. resolution: Optional[tuple[int, int]] = None,
  552. return_dict: bool = True,
  553. ) -> Union[tuple, BaseModelOutput]:
  554. all_hidden_states = () if output_hidden_states else None
  555. all_self_attentions = () if output_attentions else None
  556. for i, layer_module in enumerate(self.layer):
  557. if output_hidden_states:
  558. all_hidden_states = all_hidden_states + (hidden_states,)
  559. if self.has_relative_position_bias:
  560. height, width = resolution
  561. window_size = (height // self.config.patch_size, width // self.config.patch_size)
  562. relative_position_bias = self.relative_position_bias(
  563. window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1]
  564. )
  565. else:
  566. relative_position_bias = None
  567. layer_head_mask = head_mask[i] if head_mask is not None else None
  568. layer_outputs = layer_module(
  569. hidden_states,
  570. head_mask=layer_head_mask,
  571. output_attentions=output_attentions,
  572. relative_position_bias=relative_position_bias,
  573. interpolate_pos_encoding=interpolate_pos_encoding,
  574. resolution=resolution,
  575. )
  576. hidden_states = layer_outputs[0]
  577. if output_attentions:
  578. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  579. if output_hidden_states:
  580. all_hidden_states = all_hidden_states + (hidden_states,)
  581. if not return_dict:
  582. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  583. return BaseModelOutput(
  584. last_hidden_state=hidden_states,
  585. hidden_states=all_hidden_states,
  586. attentions=all_self_attentions,
  587. )
  588. @auto_docstring
  589. class BeitPreTrainedModel(PreTrainedModel):
  590. config: BeitConfig
  591. base_model_prefix = "beit"
  592. main_input_name = "pixel_values"
  593. supports_gradient_checkpointing = True
  594. _no_split_modules = ["BeitLayer"]
  595. _keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"]
  596. _supports_sdpa = True
  597. def _init_weights(self, module):
  598. """Initialize the weights"""
  599. if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
  600. # Slightly different from the TF version which uses truncated_normal for initialization
  601. # cf https://github.com/pytorch/pytorch/pull/5617
  602. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  603. if module.bias is not None:
  604. module.bias.data.zero_()
  605. elif isinstance(module, nn.Embedding):
  606. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  607. if module.padding_idx is not None:
  608. module.weight.data[module.padding_idx].zero_()
  609. elif isinstance(module, nn.LayerNorm):
  610. module.bias.data.zero_()
  611. module.weight.data.fill_(1.0)
  612. elif isinstance(module, BeitEmbeddings):
  613. module.cls_token.data.zero_()
  614. if module.mask_token is not None:
  615. module.mask_token.data.zero_()
  616. if module.position_embeddings is not None:
  617. module.position_embeddings.data.zero_()
  618. elif isinstance(module, BeitRelativePositionBias):
  619. module.relative_position_bias_table.data.zero_()
  620. elif isinstance(module, BeitLayer):
  621. if module.lambda_1 is not None:
  622. module.lambda_1.data.fill_(self.config.layer_scale_init_value)
  623. module.lambda_2.data.fill_(self.config.layer_scale_init_value)
  624. @auto_docstring
  625. class BeitModel(BeitPreTrainedModel):
  626. def __init__(self, config: BeitConfig, add_pooling_layer: bool = True) -> None:
  627. r"""
  628. add_pooling_layer (bool, *optional*, defaults to `True`):
  629. Whether to add a pooling layer
  630. """
  631. super().__init__(config)
  632. self.config = config
  633. self.embeddings = BeitEmbeddings(config)
  634. self.encoder = BeitEncoder(config, window_size=self.embeddings.patch_embeddings.patch_shape)
  635. self.layernorm = (
  636. nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  637. )
  638. self.pooler = BeitPooler(config) if add_pooling_layer else None
  639. # Initialize weights and apply final processing
  640. self.post_init()
  641. def get_input_embeddings(self):
  642. return self.embeddings.patch_embeddings
  643. def _prune_heads(self, heads_to_prune):
  644. """
  645. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  646. class PreTrainedModel
  647. """
  648. for layer, heads in heads_to_prune.items():
  649. self.encoder.layer[layer].attention.prune_heads(heads)
  650. @auto_docstring
  651. def forward(
  652. self,
  653. pixel_values: torch.Tensor,
  654. bool_masked_pos: Optional[torch.BoolTensor] = None,
  655. head_mask: Optional[torch.Tensor] = None,
  656. output_attentions: Optional[bool] = None,
  657. output_hidden_states: Optional[bool] = None,
  658. interpolate_pos_encoding: bool = False,
  659. return_dict: Optional[bool] = None,
  660. ) -> Union[tuple, BeitModelOutputWithPooling]:
  661. r"""
  662. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
  663. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  664. """
  665. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  666. output_hidden_states = (
  667. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  668. )
  669. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  670. # Prepare head mask if needed
  671. # 1.0 in head_mask indicate we keep the head
  672. # attention_probs has shape bsz x n_heads x N x N
  673. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  674. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  675. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  676. embedding_output, _ = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
  677. resolution = pixel_values.shape[2:]
  678. encoder_outputs = self.encoder(
  679. embedding_output,
  680. head_mask=head_mask,
  681. output_attentions=output_attentions,
  682. output_hidden_states=output_hidden_states,
  683. resolution=resolution,
  684. return_dict=return_dict,
  685. interpolate_pos_encoding=interpolate_pos_encoding,
  686. )
  687. sequence_output = encoder_outputs[0]
  688. sequence_output = self.layernorm(sequence_output)
  689. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  690. if not return_dict:
  691. head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
  692. return head_outputs + encoder_outputs[1:]
  693. return BeitModelOutputWithPooling(
  694. last_hidden_state=sequence_output,
  695. pooler_output=pooled_output,
  696. hidden_states=encoder_outputs.hidden_states,
  697. attentions=encoder_outputs.attentions,
  698. )
  699. class BeitPooler(nn.Module):
  700. def __init__(self, config: BeitConfig) -> None:
  701. super().__init__()
  702. self.layernorm = (
  703. nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_mean_pooling else None
  704. )
  705. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  706. if self.layernorm is not None:
  707. # Mean pool the final hidden states of the patch tokens
  708. patch_tokens = hidden_states[:, 1:, :]
  709. pooled_output = self.layernorm(patch_tokens.mean(1))
  710. else:
  711. # Pool by simply taking the final hidden state of the [CLS] token
  712. pooled_output = hidden_states[:, 0]
  713. return pooled_output
  714. @auto_docstring(
  715. custom_intro="""
  716. Beit Model transformer with a 'language' modeling head on top. BEiT does masked image modeling by predicting
  717. visual tokens of a Vector-Quantize Variational Autoencoder (VQ-VAE), whereas other vision models like ViT and DeiT
  718. predict RGB pixel values. As a result, this class is incompatible with [`AutoModelForMaskedImageModeling`], so you
  719. will need to use [`BeitForMaskedImageModeling`] directly if you wish to do masked image modeling with BEiT.
  720. """
  721. )
  722. class BeitForMaskedImageModeling(BeitPreTrainedModel):
  723. def __init__(self, config: BeitConfig) -> None:
  724. super().__init__(config)
  725. self.num_labels = config.num_labels
  726. self.beit = BeitModel(config, add_pooling_layer=False)
  727. # Classifier head
  728. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  729. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
  730. # Initialize weights and apply final processing
  731. self.post_init()
  732. def get_output_embeddings(self):
  733. return None
  734. @auto_docstring
  735. def forward(
  736. self,
  737. pixel_values: Optional[torch.Tensor] = None,
  738. bool_masked_pos: Optional[torch.BoolTensor] = None,
  739. head_mask: Optional[torch.Tensor] = None,
  740. labels: Optional[torch.Tensor] = None,
  741. output_attentions: Optional[bool] = None,
  742. output_hidden_states: Optional[bool] = None,
  743. interpolate_pos_encoding: bool = False,
  744. return_dict: Optional[bool] = None,
  745. ) -> Union[tuple, MaskedLMOutput]:
  746. r"""
  747. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
  748. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  749. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  750. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  751. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  752. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  753. Examples:
  754. ```python
  755. >>> from transformers import AutoImageProcessor, BeitForMaskedImageModeling
  756. >>> import torch
  757. >>> from PIL import Image
  758. >>> import requests
  759. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  760. >>> image = Image.open(requests.get(url, stream=True).raw)
  761. >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k")
  762. >>> model = BeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k")
  763. >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
  764. >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
  765. >>> # create random boolean mask of shape (batch_size, num_patches)
  766. >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
  767. >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
  768. >>> loss, logits = outputs.loss, outputs.logits
  769. >>> list(logits.shape)
  770. [1, 196, 8192]
  771. ```"""
  772. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  773. outputs = self.beit(
  774. pixel_values,
  775. bool_masked_pos=bool_masked_pos,
  776. head_mask=head_mask,
  777. output_attentions=output_attentions,
  778. output_hidden_states=output_hidden_states,
  779. interpolate_pos_encoding=interpolate_pos_encoding,
  780. return_dict=return_dict,
  781. )
  782. sequence_output = outputs[0]
  783. sequence_output = self.layernorm(sequence_output)
  784. prediction_scores = self.lm_head(sequence_output[:, 1:])
  785. masked_lm_loss = None
  786. if labels is not None:
  787. loss_fct = CrossEntropyLoss() # -100 index = padding token
  788. masked_lm_loss = loss_fct(prediction_scores[bool_masked_pos], labels)
  789. if not return_dict:
  790. output = (prediction_scores,) + outputs[1:]
  791. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  792. return MaskedLMOutput(
  793. loss=masked_lm_loss,
  794. logits=prediction_scores,
  795. hidden_states=outputs.hidden_states,
  796. attentions=outputs.attentions,
  797. )
  798. @auto_docstring(
  799. custom_intro="""
  800. Beit Model transformer with an image classification head on top (a linear layer on top of the average of the final
  801. hidden states of the patch tokens) e.g. for ImageNet.
  802. """
  803. )
  804. class BeitForImageClassification(BeitPreTrainedModel):
  805. def __init__(self, config: BeitConfig) -> None:
  806. super().__init__(config)
  807. self.num_labels = config.num_labels
  808. self.beit = BeitModel(config, add_pooling_layer=True)
  809. # Classifier head
  810. self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
  811. # Initialize weights and apply final processing
  812. self.post_init()
  813. @auto_docstring
  814. def forward(
  815. self,
  816. pixel_values: Optional[torch.Tensor] = None,
  817. head_mask: Optional[torch.Tensor] = None,
  818. labels: Optional[torch.Tensor] = None,
  819. output_attentions: Optional[bool] = None,
  820. output_hidden_states: Optional[bool] = None,
  821. interpolate_pos_encoding: bool = False,
  822. return_dict: Optional[bool] = None,
  823. ) -> Union[tuple, ImageClassifierOutput]:
  824. r"""
  825. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  826. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  827. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  828. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  829. """
  830. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  831. outputs = self.beit(
  832. pixel_values,
  833. head_mask=head_mask,
  834. output_attentions=output_attentions,
  835. output_hidden_states=output_hidden_states,
  836. interpolate_pos_encoding=interpolate_pos_encoding,
  837. return_dict=return_dict,
  838. )
  839. pooled_output = outputs.pooler_output if return_dict else outputs[1]
  840. logits = self.classifier(pooled_output)
  841. loss = None
  842. if labels is not None:
  843. loss = self.loss_function(labels, logits, self.config)
  844. if not return_dict:
  845. output = (logits,) + outputs[2:]
  846. return ((loss,) + output) if loss is not None else output
  847. return ImageClassifierOutput(
  848. loss=loss,
  849. logits=logits,
  850. hidden_states=outputs.hidden_states,
  851. attentions=outputs.attentions,
  852. )
  853. class BeitConvModule(nn.Module):
  854. """
  855. A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
  856. layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
  857. Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
  858. """
  859. def __init__(
  860. self,
  861. in_channels: int,
  862. out_channels: int,
  863. kernel_size: Union[int, tuple[int, int]],
  864. padding: Union[int, tuple[int, int], str] = 0,
  865. bias: bool = False,
  866. dilation: Union[int, tuple[int, int]] = 1,
  867. ) -> None:
  868. super().__init__()
  869. self.conv = nn.Conv2d(
  870. in_channels=in_channels,
  871. out_channels=out_channels,
  872. kernel_size=kernel_size,
  873. padding=padding,
  874. bias=bias,
  875. dilation=dilation,
  876. )
  877. self.bn = nn.BatchNorm2d(out_channels)
  878. self.activation = nn.ReLU()
  879. def forward(self, input: torch.Tensor) -> torch.Tensor:
  880. output = self.conv(input)
  881. output = self.bn(output)
  882. output = self.activation(output)
  883. return output
  884. class BeitPyramidPoolingBlock(nn.Module):
  885. def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None:
  886. super().__init__()
  887. self.layers = [
  888. nn.AdaptiveAvgPool2d(pool_scale),
  889. BeitConvModule(in_channels, channels, kernel_size=1),
  890. ]
  891. for i, layer in enumerate(self.layers):
  892. self.add_module(str(i), layer)
  893. def forward(self, input: torch.Tensor) -> torch.Tensor:
  894. hidden_state = input
  895. for layer in self.layers:
  896. hidden_state = layer(hidden_state)
  897. return hidden_state
  898. class BeitPyramidPoolingModule(nn.Module):
  899. """
  900. Pyramid Pooling Module (PPM) used in PSPNet.
  901. Args:
  902. pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
  903. Module.
  904. in_channels (int): Input channels.
  905. channels (int): Channels after modules, before conv_seg.
  906. align_corners (bool): align_corners argument of F.interpolate.
  907. Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
  908. """
  909. def __init__(self, pool_scales: tuple[int, ...], in_channels: int, channels: int, align_corners: bool) -> None:
  910. super().__init__()
  911. self.pool_scales = pool_scales
  912. self.align_corners = align_corners
  913. self.in_channels = in_channels
  914. self.channels = channels
  915. self.blocks = []
  916. for i, pool_scale in enumerate(pool_scales):
  917. block = BeitPyramidPoolingBlock(pool_scale=pool_scale, in_channels=in_channels, channels=channels)
  918. self.blocks.append(block)
  919. self.add_module(str(i), block)
  920. def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
  921. ppm_outs = []
  922. for ppm in self.blocks:
  923. ppm_out = ppm(x)
  924. upsampled_ppm_out = nn.functional.interpolate(
  925. ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
  926. )
  927. ppm_outs.append(upsampled_ppm_out)
  928. return ppm_outs
  929. class BeitUperHead(nn.Module):
  930. """
  931. Unified Perceptual Parsing for Scene Understanding. This head is the implementation of
  932. [UPerNet](https://huggingface.co/papers/1807.10221).
  933. Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
  934. """
  935. def __init__(self, config: BeitConfig) -> None:
  936. super().__init__()
  937. self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6)
  938. self.in_channels = [config.hidden_size] * 4 # e.g. [768, 768, 768, 768]
  939. self.channels = config.hidden_size
  940. self.align_corners = False
  941. self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
  942. # PSP Module
  943. self.psp_modules = BeitPyramidPoolingModule(
  944. self.pool_scales,
  945. self.in_channels[-1],
  946. self.channels,
  947. align_corners=self.align_corners,
  948. )
  949. self.bottleneck = BeitConvModule(
  950. self.in_channels[-1] + len(self.pool_scales) * self.channels,
  951. self.channels,
  952. kernel_size=3,
  953. padding=1,
  954. )
  955. # FPN Module
  956. self.lateral_convs = nn.ModuleList()
  957. self.fpn_convs = nn.ModuleList()
  958. for in_channels in self.in_channels[:-1]: # skip the top layer
  959. l_conv = BeitConvModule(in_channels, self.channels, kernel_size=1)
  960. fpn_conv = BeitConvModule(self.channels, self.channels, kernel_size=3, padding=1)
  961. self.lateral_convs.append(l_conv)
  962. self.fpn_convs.append(fpn_conv)
  963. self.fpn_bottleneck = BeitConvModule(
  964. len(self.in_channels) * self.channels,
  965. self.channels,
  966. kernel_size=3,
  967. padding=1,
  968. )
  969. def psp_forward(self, inputs):
  970. x = inputs[-1]
  971. psp_outs = [x]
  972. psp_outs.extend(self.psp_modules(x))
  973. psp_outs = torch.cat(psp_outs, dim=1)
  974. output = self.bottleneck(psp_outs)
  975. return output
  976. def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
  977. # build laterals
  978. laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
  979. laterals.append(self.psp_forward(encoder_hidden_states))
  980. # build top-down path
  981. used_backbone_levels = len(laterals)
  982. for i in range(used_backbone_levels - 1, 0, -1):
  983. prev_shape = laterals[i - 1].shape[2:]
  984. laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate(
  985. laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
  986. )
  987. # build outputs
  988. fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
  989. # append psp feature
  990. fpn_outs.append(laterals[-1])
  991. for i in range(used_backbone_levels - 1, 0, -1):
  992. fpn_outs[i] = nn.functional.interpolate(
  993. fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
  994. )
  995. fpn_outs = torch.cat(fpn_outs, dim=1)
  996. output = self.fpn_bottleneck(fpn_outs)
  997. output = self.classifier(output)
  998. return output
  999. class BeitFCNHead(nn.Module):
  1000. """
  1001. Fully Convolution Networks for Semantic Segmentation. This head is implemented of
  1002. [FCNNet](https://huggingface.co/papers/1411.4038>).
  1003. Args:
  1004. config (BeitConfig): Configuration.
  1005. in_channels
  1006. kernel_size (int): The kernel size for convs in the head. Default: 3.
  1007. dilation (int): The dilation rate for convs in the head. Default: 1.
  1008. Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
  1009. """
  1010. def __init__(
  1011. self, config: BeitConfig, in_index: int = 2, kernel_size: int = 3, dilation: Union[int, tuple[int, int]] = 1
  1012. ) -> None:
  1013. super().__init__()
  1014. self.in_channels = config.hidden_size
  1015. self.channels = config.auxiliary_channels
  1016. self.num_convs = config.auxiliary_num_convs
  1017. self.concat_input = config.auxiliary_concat_input
  1018. self.in_index = in_index
  1019. conv_padding = (kernel_size // 2) * dilation
  1020. convs = []
  1021. convs.append(
  1022. BeitConvModule(
  1023. self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
  1024. )
  1025. )
  1026. for i in range(self.num_convs - 1):
  1027. convs.append(
  1028. BeitConvModule(
  1029. self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
  1030. )
  1031. )
  1032. if self.num_convs == 0:
  1033. self.convs = nn.Identity()
  1034. else:
  1035. self.convs = nn.Sequential(*convs)
  1036. if self.concat_input:
  1037. self.conv_cat = BeitConvModule(
  1038. self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2
  1039. )
  1040. self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
  1041. def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
  1042. # just take the relevant feature maps
  1043. hidden_states = encoder_hidden_states[self.in_index]
  1044. output = self.convs(hidden_states)
  1045. if self.concat_input:
  1046. output = self.conv_cat(torch.cat([hidden_states, output], dim=1))
  1047. output = self.classifier(output)
  1048. return output
  1049. @auto_docstring
  1050. class BeitForSemanticSegmentation(BeitPreTrainedModel):
  1051. def __init__(self, config: BeitConfig) -> None:
  1052. super().__init__(config)
  1053. self.num_labels = config.num_labels
  1054. self.beit = BeitModel(config, add_pooling_layer=False)
  1055. # FPNs
  1056. if len(self.config.out_indices) != 4:
  1057. raise ValueError(
  1058. "BeitForSemanticSegmentation requires config.out_indices to be a list of 4 integers, "
  1059. "specifying which features to use from the backbone. One can use [3, 5, 7, 11] in case of "
  1060. "a base-sized architecture."
  1061. )
  1062. self.fpn1 = nn.Sequential(
  1063. nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
  1064. nn.BatchNorm2d(config.hidden_size),
  1065. nn.GELU(),
  1066. nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
  1067. )
  1068. self.fpn2 = nn.Sequential(
  1069. nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
  1070. )
  1071. self.fpn3 = nn.Identity()
  1072. self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
  1073. # Semantic segmentation head(s)
  1074. self.decode_head = BeitUperHead(config)
  1075. self.auxiliary_head = BeitFCNHead(config) if config.use_auxiliary_head else None
  1076. # Initialize weights and apply final processing
  1077. self.post_init()
  1078. def compute_loss(self, logits, auxiliary_logits, labels):
  1079. # upsample logits to the images' original size
  1080. upsampled_logits = nn.functional.interpolate(
  1081. logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
  1082. )
  1083. if auxiliary_logits is not None:
  1084. upsampled_auxiliary_logits = nn.functional.interpolate(
  1085. auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
  1086. )
  1087. # compute weighted loss
  1088. loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
  1089. main_loss = loss_fct(upsampled_logits, labels)
  1090. loss = main_loss
  1091. if auxiliary_logits is not None:
  1092. auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
  1093. loss += self.config.auxiliary_loss_weight * auxiliary_loss
  1094. return loss
  1095. @auto_docstring
  1096. def forward(
  1097. self,
  1098. pixel_values: Optional[torch.Tensor] = None,
  1099. head_mask: Optional[torch.Tensor] = None,
  1100. labels: Optional[torch.Tensor] = None,
  1101. output_attentions: Optional[bool] = None,
  1102. output_hidden_states: Optional[bool] = None,
  1103. interpolate_pos_encoding: bool = False,
  1104. return_dict: Optional[bool] = None,
  1105. ) -> Union[tuple, SemanticSegmenterOutput]:
  1106. r"""
  1107. labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
  1108. Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
  1109. config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
  1110. Examples:
  1111. ```python
  1112. >>> from transformers import AutoImageProcessor, BeitForSemanticSegmentation
  1113. >>> from PIL import Image
  1114. >>> import requests
  1115. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1116. >>> image = Image.open(requests.get(url, stream=True).raw)
  1117. >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")
  1118. >>> model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")
  1119. >>> inputs = image_processor(images=image, return_tensors="pt")
  1120. >>> outputs = model(**inputs)
  1121. >>> # logits are of shape (batch_size, num_labels, height, width)
  1122. >>> logits = outputs.logits
  1123. ```"""
  1124. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1125. output_hidden_states = (
  1126. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1127. )
  1128. if labels is not None and self.config.num_labels == 1:
  1129. raise ValueError("The number of labels should be greater than one")
  1130. outputs = self.beit(
  1131. pixel_values,
  1132. head_mask=head_mask,
  1133. output_attentions=output_attentions,
  1134. output_hidden_states=True, # we need the intermediate hidden states
  1135. interpolate_pos_encoding=interpolate_pos_encoding,
  1136. return_dict=return_dict,
  1137. )
  1138. encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
  1139. # only keep certain features, and reshape
  1140. # note that we do +1 as the encoder_hidden_states also includes the initial embeddings
  1141. features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]
  1142. batch_size = pixel_values.shape[0]
  1143. patch_resolution = self.config.image_size // self.config.patch_size
  1144. features = [
  1145. x[:, 1:, :].permute(0, 2, 1).reshape(batch_size, -1, patch_resolution, patch_resolution) for x in features
  1146. ]
  1147. # apply FPNs
  1148. ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
  1149. for i in range(len(features)):
  1150. features[i] = ops[i](features[i])
  1151. logits = self.decode_head(features)
  1152. auxiliary_logits = None
  1153. if self.auxiliary_head is not None:
  1154. auxiliary_logits = self.auxiliary_head(features)
  1155. loss = None
  1156. if labels is not None:
  1157. loss = self.compute_loss(logits, auxiliary_logits, labels)
  1158. if not return_dict:
  1159. if output_hidden_states:
  1160. output = (logits,) + outputs[1:]
  1161. else:
  1162. output = (logits,) + outputs[2:]
  1163. return ((loss,) + output) if loss is not None else output
  1164. return SemanticSegmenterOutput(
  1165. loss=loss,
  1166. logits=logits,
  1167. hidden_states=outputs.hidden_states if output_hidden_states else None,
  1168. attentions=outputs.attentions,
  1169. )
  1170. @auto_docstring(
  1171. custom_intro="""
  1172. BEiT backbone, to be used with frameworks like DETR and MaskFormer.
  1173. """
  1174. )
  1175. class BeitBackbone(BeitPreTrainedModel, BackboneMixin):
  1176. def __init__(self, config):
  1177. super().__init__(config)
  1178. super()._init_backbone(config)
  1179. self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
  1180. self.embeddings = BeitEmbeddings(config)
  1181. self.encoder = BeitEncoder(config, window_size=self.embeddings.patch_embeddings.patch_shape)
  1182. if config.add_fpn:
  1183. if len(self.config.out_indices) != 4:
  1184. raise ValueError(
  1185. "BeitBackbone requires config.out_indices to be a list of 4 integers, "
  1186. "specifying which features to use from the backbone. One can use [3, 5, 7, 11] in case of "
  1187. "a base-sized architecture."
  1188. )
  1189. hidden_size = config.hidden_size
  1190. self.fpn1 = nn.Sequential(
  1191. nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2),
  1192. nn.BatchNorm2d(hidden_size, eps=config.batch_norm_eps),
  1193. nn.GELU(),
  1194. nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2),
  1195. )
  1196. self.fpn2 = nn.Sequential(nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2))
  1197. self.fpn3 = nn.Identity()
  1198. self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
  1199. # initialize weights and apply final processing
  1200. self.post_init()
  1201. def get_input_embeddings(self):
  1202. return self.embeddings.patch_embeddings
  1203. @auto_docstring
  1204. def forward(
  1205. self,
  1206. pixel_values: Tensor,
  1207. output_hidden_states: Optional[bool] = None,
  1208. output_attentions: Optional[bool] = None,
  1209. return_dict: Optional[bool] = None,
  1210. ) -> BackboneOutput:
  1211. r"""
  1212. Examples:
  1213. ```python
  1214. >>> from transformers import AutoImageProcessor, AutoBackbone
  1215. >>> import torch
  1216. >>> from PIL import Image
  1217. >>> import requests
  1218. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1219. >>> image = Image.open(requests.get(url, stream=True).raw)
  1220. >>> processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224")
  1221. >>> model = AutoBackbone.from_pretrained(
  1222. ... "microsoft/beit-base-patch16-224", out_features=["stage1", "stage2", "stage3", "stage4"]
  1223. ... )
  1224. >>> inputs = processor(image, return_tensors="pt")
  1225. >>> outputs = model(**inputs)
  1226. >>> feature_maps = outputs.feature_maps
  1227. >>> list(feature_maps[-1].shape)
  1228. [1, 768, 14, 14]
  1229. ```"""
  1230. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1231. output_hidden_states = (
  1232. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1233. )
  1234. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1235. batch_size = pixel_values.shape[0]
  1236. embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values)
  1237. resolution = pixel_values.shape[2:]
  1238. outputs = self.encoder(
  1239. embedding_output,
  1240. output_hidden_states=True,
  1241. output_attentions=output_attentions,
  1242. resolution=resolution,
  1243. return_dict=return_dict,
  1244. )
  1245. hidden_states = outputs.hidden_states if return_dict else outputs[1]
  1246. feature_maps = ()
  1247. for stage, hidden_state in zip(self.stage_names, hidden_states):
  1248. if stage in self.out_features:
  1249. if self.config.reshape_hidden_states:
  1250. hidden_state = hidden_state[:, 1:, :]
  1251. hidden_state = hidden_state.permute(0, 2, 1)
  1252. hidden_state = hidden_state.reshape(batch_size, -1, patch_height, patch_width)
  1253. feature_maps += (hidden_state,)
  1254. if self.config.add_fpn:
  1255. feature_maps = [
  1256. self.fpn1(feature_maps[0]),
  1257. self.fpn2(feature_maps[1]),
  1258. self.fpn3(feature_maps[2]),
  1259. self.fpn4(feature_maps[3]),
  1260. ]
  1261. feature_maps = tuple(feature_maps)
  1262. if not return_dict:
  1263. if output_hidden_states:
  1264. output = (feature_maps,) + outputs[1:]
  1265. else:
  1266. output = (feature_maps,) + outputs[2:]
  1267. return output
  1268. return BackboneOutput(
  1269. feature_maps=feature_maps,
  1270. hidden_states=outputs.hidden_states if output_hidden_states else None,
  1271. attentions=outputs.attentions,
  1272. )
  1273. __all__ = [
  1274. "BeitForImageClassification",
  1275. "BeitForMaskedImageModeling",
  1276. "BeitForSemanticSegmentation",
  1277. "BeitModel",
  1278. "BeitPreTrainedModel",
  1279. "BeitBackbone",
  1280. ]