modeling_pvt.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591
  1. # coding=utf-8
  2. # Copyright 2023 Authors: Wenhai Wang, Enze Xie, Xiang Li, Deng-Ping Fan,
  3. # Kaitao Song, Ding Liang, Tong Lu, Ping Luo, Ling Shao and The HuggingFace Inc. team.
  4. # All rights reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. """PyTorch PVT model."""
  18. import collections
  19. import math
  20. from collections.abc import Iterable
  21. from typing import Optional, Union
  22. import torch
  23. import torch.nn.functional as F
  24. from torch import nn
  25. from ...activations import ACT2FN
  26. from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
  27. from ...modeling_utils import PreTrainedModel
  28. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  29. from ...utils import auto_docstring, logging
  30. from .configuration_pvt import PvtConfig
# Module-level logger obtained from the library's logging helper, keyed by this module's name.
logger = logging.get_logger(__name__)
  32. # Copied from transformers.models.beit.modeling_beit.drop_path
  33. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  34. """
  35. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  36. Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
  37. however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  38. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
  39. layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
  40. argument.
  41. """
  42. if drop_prob == 0.0 or not training:
  43. return input
  44. keep_prob = 1 - drop_prob
  45. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  46. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  47. random_tensor.floor_() # binarize
  48. output = input.div(keep_prob) * random_tensor
  49. return output
  50. # Copied from transformers.models.convnext.modeling_convnext.ConvNextDropPath with ConvNext->Pvt
  51. class PvtDropPath(nn.Module):
  52. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  53. def __init__(self, drop_prob: Optional[float] = None) -> None:
  54. super().__init__()
  55. self.drop_prob = drop_prob
  56. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  57. return drop_path(hidden_states, self.drop_prob, self.training)
  58. def extra_repr(self) -> str:
  59. return f"p={self.drop_prob}"
  60. class PvtPatchEmbeddings(nn.Module):
  61. """
  62. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  63. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  64. Transformer.
  65. """
  66. def __init__(
  67. self,
  68. config: PvtConfig,
  69. image_size: Union[int, Iterable[int]],
  70. patch_size: Union[int, Iterable[int]],
  71. stride: int,
  72. num_channels: int,
  73. hidden_size: int,
  74. cls_token: bool = False,
  75. ):
  76. super().__init__()
  77. self.config = config
  78. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  79. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  80. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  81. self.image_size = image_size
  82. self.patch_size = patch_size
  83. self.num_channels = num_channels
  84. self.num_patches = num_patches
  85. self.position_embeddings = nn.Parameter(
  86. torch.randn(1, num_patches + 1 if cls_token else num_patches, hidden_size)
  87. )
  88. self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size)) if cls_token else None
  89. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=stride, stride=patch_size)
  90. self.layer_norm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
  91. self.dropout = nn.Dropout(p=config.hidden_dropout_prob)
  92. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  93. num_patches = height * width
  94. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  95. if not torch.jit.is_tracing() and num_patches == self.config.image_size * self.config.image_size:
  96. return self.position_embeddings
  97. embeddings = embeddings.reshape(1, height, width, -1).permute(0, 3, 1, 2)
  98. interpolated_embeddings = F.interpolate(embeddings, size=(height, width), mode="bilinear")
  99. interpolated_embeddings = interpolated_embeddings.reshape(1, -1, height * width).permute(0, 2, 1)
  100. return interpolated_embeddings
  101. def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, int, int]:
  102. batch_size, num_channels, height, width = pixel_values.shape
  103. if num_channels != self.num_channels:
  104. raise ValueError(
  105. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  106. )
  107. patch_embed = self.projection(pixel_values)
  108. *_, height, width = patch_embed.shape
  109. patch_embed = patch_embed.flatten(2).transpose(1, 2)
  110. embeddings = self.layer_norm(patch_embed)
  111. if self.cls_token is not None:
  112. cls_token = self.cls_token.expand(batch_size, -1, -1)
  113. embeddings = torch.cat((cls_token, embeddings), dim=1)
  114. position_embeddings = self.interpolate_pos_encoding(self.position_embeddings[:, 1:], height, width)
  115. position_embeddings = torch.cat((self.position_embeddings[:, :1], position_embeddings), dim=1)
  116. else:
  117. position_embeddings = self.interpolate_pos_encoding(self.position_embeddings, height, width)
  118. embeddings = self.dropout(embeddings + position_embeddings)
  119. return embeddings, height, width
  120. class PvtSelfOutput(nn.Module):
  121. def __init__(self, config: PvtConfig, hidden_size: int):
  122. super().__init__()
  123. self.dense = nn.Linear(hidden_size, hidden_size)
  124. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  125. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  126. hidden_states = self.dense(hidden_states)
  127. hidden_states = self.dropout(hidden_states)
  128. return hidden_states
  129. class PvtEfficientSelfAttention(nn.Module):
  130. """Efficient self-attention mechanism with reduction of the sequence [PvT paper](https://huggingface.co/papers/2102.12122)."""
  131. def __init__(
  132. self, config: PvtConfig, hidden_size: int, num_attention_heads: int, sequences_reduction_ratio: float
  133. ):
  134. super().__init__()
  135. self.hidden_size = hidden_size
  136. self.num_attention_heads = num_attention_heads
  137. if self.hidden_size % self.num_attention_heads != 0:
  138. raise ValueError(
  139. f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
  140. f"heads ({self.num_attention_heads})"
  141. )
  142. self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
  143. self.all_head_size = self.num_attention_heads * self.attention_head_size
  144. self.query = nn.Linear(self.hidden_size, self.all_head_size, bias=config.qkv_bias)
  145. self.key = nn.Linear(self.hidden_size, self.all_head_size, bias=config.qkv_bias)
  146. self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=config.qkv_bias)
  147. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  148. self.sequences_reduction_ratio = sequences_reduction_ratio
  149. if sequences_reduction_ratio > 1:
  150. self.sequence_reduction = nn.Conv2d(
  151. hidden_size, hidden_size, kernel_size=sequences_reduction_ratio, stride=sequences_reduction_ratio
  152. )
  153. self.layer_norm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
  154. def transpose_for_scores(self, hidden_states: int) -> torch.Tensor:
  155. new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
  156. hidden_states = hidden_states.view(new_shape)
  157. return hidden_states.permute(0, 2, 1, 3)
  158. def forward(
  159. self,
  160. hidden_states: torch.Tensor,
  161. height: int,
  162. width: int,
  163. output_attentions: bool = False,
  164. ) -> tuple[torch.Tensor]:
  165. query_layer = self.transpose_for_scores(self.query(hidden_states))
  166. if self.sequences_reduction_ratio > 1:
  167. batch_size, seq_len, num_channels = hidden_states.shape
  168. # Reshape to (batch_size, num_channels, height, width)
  169. hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
  170. # Apply sequence reduction
  171. hidden_states = self.sequence_reduction(hidden_states)
  172. # Reshape back to (batch_size, seq_len, num_channels)
  173. hidden_states = hidden_states.reshape(batch_size, num_channels, -1).permute(0, 2, 1)
  174. hidden_states = self.layer_norm(hidden_states)
  175. key_layer = self.transpose_for_scores(self.key(hidden_states))
  176. value_layer = self.transpose_for_scores(self.value(hidden_states))
  177. # Take the dot product between "query" and "key" to get the raw attention scores.
  178. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  179. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  180. # Normalize the attention scores to probabilities.
  181. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  182. # This is actually dropping out entire tokens to attend to, which might
  183. # seem a bit unusual, but is taken from the original Transformer paper.
  184. attention_probs = self.dropout(attention_probs)
  185. context_layer = torch.matmul(attention_probs, value_layer)
  186. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  187. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  188. context_layer = context_layer.view(new_context_layer_shape)
  189. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  190. return outputs
  191. class PvtAttention(nn.Module):
  192. def __init__(
  193. self, config: PvtConfig, hidden_size: int, num_attention_heads: int, sequences_reduction_ratio: float
  194. ):
  195. super().__init__()
  196. self.self = PvtEfficientSelfAttention(
  197. config,
  198. hidden_size=hidden_size,
  199. num_attention_heads=num_attention_heads,
  200. sequences_reduction_ratio=sequences_reduction_ratio,
  201. )
  202. self.output = PvtSelfOutput(config, hidden_size=hidden_size)
  203. self.pruned_heads = set()
  204. def prune_heads(self, heads):
  205. if len(heads) == 0:
  206. return
  207. heads, index = find_pruneable_heads_and_indices(
  208. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  209. )
  210. # Prune linear layers
  211. self.self.query = prune_linear_layer(self.self.query, index)
  212. self.self.key = prune_linear_layer(self.self.key, index)
  213. self.self.value = prune_linear_layer(self.self.value, index)
  214. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  215. # Update hyper params and store pruned heads
  216. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  217. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  218. self.pruned_heads = self.pruned_heads.union(heads)
  219. def forward(
  220. self, hidden_states: torch.Tensor, height: int, width: int, output_attentions: bool = False
  221. ) -> tuple[torch.Tensor]:
  222. self_outputs = self.self(hidden_states, height, width, output_attentions)
  223. attention_output = self.output(self_outputs[0])
  224. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  225. return outputs
  226. class PvtFFN(nn.Module):
  227. def __init__(
  228. self,
  229. config: PvtConfig,
  230. in_features: int,
  231. hidden_features: Optional[int] = None,
  232. out_features: Optional[int] = None,
  233. ):
  234. super().__init__()
  235. out_features = out_features if out_features is not None else in_features
  236. self.dense1 = nn.Linear(in_features, hidden_features)
  237. if isinstance(config.hidden_act, str):
  238. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  239. else:
  240. self.intermediate_act_fn = config.hidden_act
  241. self.dense2 = nn.Linear(hidden_features, out_features)
  242. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  243. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  244. hidden_states = self.dense1(hidden_states)
  245. hidden_states = self.intermediate_act_fn(hidden_states)
  246. hidden_states = self.dropout(hidden_states)
  247. hidden_states = self.dense2(hidden_states)
  248. hidden_states = self.dropout(hidden_states)
  249. return hidden_states
  250. class PvtLayer(nn.Module):
  251. def __init__(
  252. self,
  253. config: PvtConfig,
  254. hidden_size: int,
  255. num_attention_heads: int,
  256. drop_path: float,
  257. sequences_reduction_ratio: float,
  258. mlp_ratio: float,
  259. ):
  260. super().__init__()
  261. self.layer_norm_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
  262. self.attention = PvtAttention(
  263. config=config,
  264. hidden_size=hidden_size,
  265. num_attention_heads=num_attention_heads,
  266. sequences_reduction_ratio=sequences_reduction_ratio,
  267. )
  268. self.drop_path = PvtDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  269. self.layer_norm_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
  270. mlp_hidden_size = int(hidden_size * mlp_ratio)
  271. self.mlp = PvtFFN(config=config, in_features=hidden_size, hidden_features=mlp_hidden_size)
  272. def forward(self, hidden_states: torch.Tensor, height: int, width: int, output_attentions: bool = False):
  273. self_attention_outputs = self.attention(
  274. hidden_states=self.layer_norm_1(hidden_states),
  275. height=height,
  276. width=width,
  277. output_attentions=output_attentions,
  278. )
  279. attention_output = self_attention_outputs[0]
  280. outputs = self_attention_outputs[1:]
  281. attention_output = self.drop_path(attention_output)
  282. hidden_states = attention_output + hidden_states
  283. mlp_output = self.mlp(self.layer_norm_2(hidden_states))
  284. mlp_output = self.drop_path(mlp_output)
  285. layer_output = hidden_states + mlp_output
  286. outputs = (layer_output,) + outputs
  287. return outputs
  288. class PvtEncoder(nn.Module):
  289. def __init__(self, config: PvtConfig):
  290. super().__init__()
  291. self.config = config
  292. # stochastic depth decay rule
  293. drop_path_decays = torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu").tolist()
  294. # patch embeddings
  295. embeddings = []
  296. for i in range(config.num_encoder_blocks):
  297. embeddings.append(
  298. PvtPatchEmbeddings(
  299. config=config,
  300. image_size=config.image_size if i == 0 else self.config.image_size // (2 ** (i + 1)),
  301. patch_size=config.patch_sizes[i],
  302. stride=config.strides[i],
  303. num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
  304. hidden_size=config.hidden_sizes[i],
  305. cls_token=i == config.num_encoder_blocks - 1,
  306. )
  307. )
  308. self.patch_embeddings = nn.ModuleList(embeddings)
  309. # Transformer blocks
  310. blocks = []
  311. cur = 0
  312. for i in range(config.num_encoder_blocks):
  313. # each block consists of layers
  314. layers = []
  315. if i != 0:
  316. cur += config.depths[i - 1]
  317. for j in range(config.depths[i]):
  318. layers.append(
  319. PvtLayer(
  320. config=config,
  321. hidden_size=config.hidden_sizes[i],
  322. num_attention_heads=config.num_attention_heads[i],
  323. drop_path=drop_path_decays[cur + j],
  324. sequences_reduction_ratio=config.sequence_reduction_ratios[i],
  325. mlp_ratio=config.mlp_ratios[i],
  326. )
  327. )
  328. blocks.append(nn.ModuleList(layers))
  329. self.block = nn.ModuleList(blocks)
  330. # Layer norms
  331. self.layer_norm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps)
  332. def forward(
  333. self,
  334. pixel_values: torch.FloatTensor,
  335. output_attentions: Optional[bool] = False,
  336. output_hidden_states: Optional[bool] = False,
  337. return_dict: Optional[bool] = True,
  338. ) -> Union[tuple, BaseModelOutput]:
  339. all_hidden_states = () if output_hidden_states else None
  340. all_self_attentions = () if output_attentions else None
  341. batch_size = pixel_values.shape[0]
  342. num_blocks = len(self.block)
  343. hidden_states = pixel_values
  344. for idx, (embedding_layer, block_layer) in enumerate(zip(self.patch_embeddings, self.block)):
  345. # first, obtain patch embeddings
  346. hidden_states, height, width = embedding_layer(hidden_states)
  347. # second, send embeddings through blocks
  348. for block in block_layer:
  349. layer_outputs = block(hidden_states, height, width, output_attentions)
  350. hidden_states = layer_outputs[0]
  351. if output_attentions:
  352. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  353. if output_hidden_states:
  354. all_hidden_states = all_hidden_states + (hidden_states,)
  355. if idx != num_blocks - 1:
  356. hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
  357. hidden_states = self.layer_norm(hidden_states)
  358. if output_hidden_states:
  359. all_hidden_states = all_hidden_states + (hidden_states,)
  360. if not return_dict:
  361. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  362. return BaseModelOutput(
  363. last_hidden_state=hidden_states,
  364. hidden_states=all_hidden_states,
  365. attentions=all_self_attentions,
  366. )
  367. @auto_docstring
  368. class PvtPreTrainedModel(PreTrainedModel):
  369. config: PvtConfig
  370. base_model_prefix = "pvt"
  371. main_input_name = "pixel_values"
  372. _no_split_modules = []
  373. def _init_weights(self, module: nn.Module) -> None:
  374. """Initialize the weights"""
  375. std = self.config.initializer_range
  376. if isinstance(module, (nn.Linear, nn.Conv2d)):
  377. # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
  378. # `trunc_normal_cpu` not implemented in `half` issues
  379. nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std)
  380. if module.bias is not None:
  381. module.bias.data.zero_()
  382. elif isinstance(module, nn.LayerNorm):
  383. module.bias.data.zero_()
  384. module.weight.data.fill_(1.0)
  385. elif isinstance(module, PvtPatchEmbeddings):
  386. module.position_embeddings.data = nn.init.trunc_normal_(
  387. module.position_embeddings.data,
  388. mean=0.0,
  389. std=std,
  390. )
  391. if module.cls_token is not None:
  392. module.cls_token.data = nn.init.trunc_normal_(
  393. module.cls_token.data,
  394. mean=0.0,
  395. std=std,
  396. )
  397. @auto_docstring
  398. class PvtModel(PvtPreTrainedModel):
  399. def __init__(self, config: PvtConfig):
  400. super().__init__(config)
  401. self.config = config
  402. # hierarchical Transformer encoder
  403. self.encoder = PvtEncoder(config)
  404. # Initialize weights and apply final processing
  405. self.post_init()
  406. def _prune_heads(self, heads_to_prune):
  407. """
  408. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  409. class PreTrainedModel
  410. """
  411. for layer, heads in heads_to_prune.items():
  412. self.encoder.layer[layer].attention.prune_heads(heads)
  413. @auto_docstring
  414. def forward(
  415. self,
  416. pixel_values: torch.FloatTensor,
  417. output_attentions: Optional[bool] = None,
  418. output_hidden_states: Optional[bool] = None,
  419. return_dict: Optional[bool] = None,
  420. ) -> Union[tuple, BaseModelOutput]:
  421. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  422. output_hidden_states = (
  423. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  424. )
  425. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  426. encoder_outputs = self.encoder(
  427. pixel_values=pixel_values,
  428. output_attentions=output_attentions,
  429. output_hidden_states=output_hidden_states,
  430. return_dict=return_dict,
  431. )
  432. sequence_output = encoder_outputs[0]
  433. if not return_dict:
  434. return (sequence_output,) + encoder_outputs[1:]
  435. return BaseModelOutput(
  436. last_hidden_state=sequence_output,
  437. hidden_states=encoder_outputs.hidden_states,
  438. attentions=encoder_outputs.attentions,
  439. )
  440. @auto_docstring(
  441. custom_intro="""
  442. Pvt Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
  443. the [CLS] token) e.g. for ImageNet.
  444. """
  445. )
  446. class PvtForImageClassification(PvtPreTrainedModel):
  447. def __init__(self, config: PvtConfig) -> None:
  448. super().__init__(config)
  449. self.num_labels = config.num_labels
  450. self.pvt = PvtModel(config)
  451. # Classifier head
  452. self.classifier = (
  453. nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
  454. )
  455. # Initialize weights and apply final processing
  456. self.post_init()
  457. @auto_docstring
  458. def forward(
  459. self,
  460. pixel_values: Optional[torch.Tensor],
  461. labels: Optional[torch.Tensor] = None,
  462. output_attentions: Optional[bool] = None,
  463. output_hidden_states: Optional[bool] = None,
  464. return_dict: Optional[bool] = None,
  465. ) -> Union[tuple, ImageClassifierOutput]:
  466. r"""
  467. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  468. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  469. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  470. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  471. """
  472. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  473. outputs = self.pvt(
  474. pixel_values=pixel_values,
  475. output_attentions=output_attentions,
  476. output_hidden_states=output_hidden_states,
  477. return_dict=return_dict,
  478. )
  479. sequence_output = outputs[0]
  480. logits = self.classifier(sequence_output[:, 0, :])
  481. loss = None
  482. if labels is not None:
  483. loss = self.loss_function(labels, logits, self.config)
  484. if not return_dict:
  485. output = (logits,) + outputs[1:]
  486. return ((loss,) + output) if loss is not None else output
  487. return ImageClassifierOutput(
  488. loss=loss,
  489. logits=logits,
  490. hidden_states=outputs.hidden_states,
  491. attentions=outputs.attentions,
  492. )
# Names exported by a star-import of this module.
__all__ = ["PvtForImageClassification", "PvtModel", "PvtPreTrainedModel"]