modeling_levit.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657
  1. # coding=utf-8
  2. # Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch LeViT model."""
  16. import itertools
  17. from dataclasses import dataclass
  18. from typing import Optional, Union
  19. import torch
  20. from torch import nn
  21. from ...modeling_outputs import (
  22. BaseModelOutputWithNoAttention,
  23. BaseModelOutputWithPoolingAndNoAttention,
  24. ImageClassifierOutputWithNoAttention,
  25. ModelOutput,
  26. )
  27. from ...modeling_utils import PreTrainedModel
  28. from ...utils import auto_docstring, logging
  29. from .configuration_levit import LevitConfig
  30. logger = logging.get_logger(__name__)
  31. @dataclass
  32. @auto_docstring(
  33. custom_intro="""
  34. Output type of [`LevitForImageClassificationWithTeacher`].
  35. """
  36. )
  37. class LevitForImageClassificationWithTeacherOutput(ModelOutput):
  38. r"""
  39. logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
  40. Prediction scores as the average of the `cls_logits` and `distillation_logits`.
  41. cls_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
  42. Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the
  43. class token).
  44. distillation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
  45. Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the
  46. distillation token).
  47. """
  48. logits: Optional[torch.FloatTensor] = None
  49. cls_logits: Optional[torch.FloatTensor] = None
  50. distillation_logits: Optional[torch.FloatTensor] = None
  51. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  52. class LevitConvEmbeddings(nn.Module):
  53. """
  54. LeViT Conv Embeddings with Batch Norm, used in the initial patch embedding layer.
  55. """
  56. def __init__(
  57. self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, groups=1, bn_weight_init=1
  58. ):
  59. super().__init__()
  60. self.convolution = nn.Conv2d(
  61. in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, groups=groups, bias=False
  62. )
  63. self.batch_norm = nn.BatchNorm2d(out_channels)
  64. def forward(self, embeddings):
  65. embeddings = self.convolution(embeddings)
  66. embeddings = self.batch_norm(embeddings)
  67. return embeddings
  68. class LevitPatchEmbeddings(nn.Module):
  69. """
  70. LeViT patch embeddings, for final embeddings to be passed to transformer blocks. It consists of multiple
  71. `LevitConvEmbeddings`.
  72. """
  73. def __init__(self, config):
  74. super().__init__()
  75. self.embedding_layer_1 = LevitConvEmbeddings(
  76. config.num_channels, config.hidden_sizes[0] // 8, config.kernel_size, config.stride, config.padding
  77. )
  78. self.activation_layer_1 = nn.Hardswish()
  79. self.embedding_layer_2 = LevitConvEmbeddings(
  80. config.hidden_sizes[0] // 8, config.hidden_sizes[0] // 4, config.kernel_size, config.stride, config.padding
  81. )
  82. self.activation_layer_2 = nn.Hardswish()
  83. self.embedding_layer_3 = LevitConvEmbeddings(
  84. config.hidden_sizes[0] // 4, config.hidden_sizes[0] // 2, config.kernel_size, config.stride, config.padding
  85. )
  86. self.activation_layer_3 = nn.Hardswish()
  87. self.embedding_layer_4 = LevitConvEmbeddings(
  88. config.hidden_sizes[0] // 2, config.hidden_sizes[0], config.kernel_size, config.stride, config.padding
  89. )
  90. self.num_channels = config.num_channels
  91. def forward(self, pixel_values):
  92. num_channels = pixel_values.shape[1]
  93. if num_channels != self.num_channels:
  94. raise ValueError(
  95. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  96. )
  97. embeddings = self.embedding_layer_1(pixel_values)
  98. embeddings = self.activation_layer_1(embeddings)
  99. embeddings = self.embedding_layer_2(embeddings)
  100. embeddings = self.activation_layer_2(embeddings)
  101. embeddings = self.embedding_layer_3(embeddings)
  102. embeddings = self.activation_layer_3(embeddings)
  103. embeddings = self.embedding_layer_4(embeddings)
  104. return embeddings.flatten(2).transpose(1, 2)
  105. class MLPLayerWithBN(nn.Module):
  106. def __init__(self, input_dim, output_dim, bn_weight_init=1):
  107. super().__init__()
  108. self.linear = nn.Linear(in_features=input_dim, out_features=output_dim, bias=False)
  109. self.batch_norm = nn.BatchNorm1d(output_dim)
  110. def forward(self, hidden_state):
  111. hidden_state = self.linear(hidden_state)
  112. hidden_state = self.batch_norm(hidden_state.flatten(0, 1)).reshape_as(hidden_state)
  113. return hidden_state
  114. class LevitSubsample(nn.Module):
  115. def __init__(self, stride, resolution):
  116. super().__init__()
  117. self.stride = stride
  118. self.resolution = resolution
  119. def forward(self, hidden_state):
  120. batch_size, _, channels = hidden_state.shape
  121. hidden_state = hidden_state.view(batch_size, self.resolution, self.resolution, channels)[
  122. :, :: self.stride, :: self.stride
  123. ].reshape(batch_size, -1, channels)
  124. return hidden_state
  125. class LevitAttention(nn.Module):
  126. def __init__(self, hidden_sizes, key_dim, num_attention_heads, attention_ratio, resolution):
  127. super().__init__()
  128. self.num_attention_heads = num_attention_heads
  129. self.scale = key_dim**-0.5
  130. self.key_dim = key_dim
  131. self.attention_ratio = attention_ratio
  132. self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads * 2
  133. self.out_dim_projection = attention_ratio * key_dim * num_attention_heads
  134. self.queries_keys_values = MLPLayerWithBN(hidden_sizes, self.out_dim_keys_values)
  135. self.activation = nn.Hardswish()
  136. self.projection = MLPLayerWithBN(self.out_dim_projection, hidden_sizes, bn_weight_init=0)
  137. points = list(itertools.product(range(resolution), range(resolution)))
  138. len_points = len(points)
  139. attention_offsets, indices = {}, []
  140. for p1 in points:
  141. for p2 in points:
  142. offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
  143. if offset not in attention_offsets:
  144. attention_offsets[offset] = len(attention_offsets)
  145. indices.append(attention_offsets[offset])
  146. self.attention_bias_cache = {}
  147. self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets)))
  148. self.register_buffer(
  149. "attention_bias_idxs", torch.LongTensor(indices).view(len_points, len_points), persistent=False
  150. )
  151. @torch.no_grad()
  152. def train(self, mode=True):
  153. super().train(mode)
  154. if mode and self.attention_bias_cache:
  155. self.attention_bias_cache = {} # clear ab cache
  156. def get_attention_biases(self, device):
  157. if self.training:
  158. return self.attention_biases[:, self.attention_bias_idxs]
  159. else:
  160. device_key = str(device)
  161. if device_key not in self.attention_bias_cache:
  162. self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
  163. return self.attention_bias_cache[device_key]
  164. def forward(self, hidden_state):
  165. batch_size, seq_length, _ = hidden_state.shape
  166. queries_keys_values = self.queries_keys_values(hidden_state)
  167. query, key, value = queries_keys_values.view(batch_size, seq_length, self.num_attention_heads, -1).split(
  168. [self.key_dim, self.key_dim, self.attention_ratio * self.key_dim], dim=3
  169. )
  170. query = query.permute(0, 2, 1, 3)
  171. key = key.permute(0, 2, 1, 3)
  172. value = value.permute(0, 2, 1, 3)
  173. attention = query @ key.transpose(-2, -1) * self.scale + self.get_attention_biases(hidden_state.device)
  174. attention = attention.softmax(dim=-1)
  175. hidden_state = (attention @ value).transpose(1, 2).reshape(batch_size, seq_length, self.out_dim_projection)
  176. hidden_state = self.projection(self.activation(hidden_state))
  177. return hidden_state
  178. class LevitAttentionSubsample(nn.Module):
  179. def __init__(
  180. self,
  181. input_dim,
  182. output_dim,
  183. key_dim,
  184. num_attention_heads,
  185. attention_ratio,
  186. stride,
  187. resolution_in,
  188. resolution_out,
  189. ):
  190. super().__init__()
  191. self.num_attention_heads = num_attention_heads
  192. self.scale = key_dim**-0.5
  193. self.key_dim = key_dim
  194. self.attention_ratio = attention_ratio
  195. self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads
  196. self.out_dim_projection = attention_ratio * key_dim * num_attention_heads
  197. self.resolution_out = resolution_out
  198. # resolution_in is the initial resolution, resolution_out is final resolution after downsampling
  199. self.keys_values = MLPLayerWithBN(input_dim, self.out_dim_keys_values)
  200. self.queries_subsample = LevitSubsample(stride, resolution_in)
  201. self.queries = MLPLayerWithBN(input_dim, key_dim * num_attention_heads)
  202. self.activation = nn.Hardswish()
  203. self.projection = MLPLayerWithBN(self.out_dim_projection, output_dim)
  204. self.attention_bias_cache = {}
  205. points = list(itertools.product(range(resolution_in), range(resolution_in)))
  206. points_ = list(itertools.product(range(resolution_out), range(resolution_out)))
  207. len_points, len_points_ = len(points), len(points_)
  208. attention_offsets, indices = {}, []
  209. for p1 in points_:
  210. for p2 in points:
  211. size = 1
  212. offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2), abs(p1[1] * stride - p2[1] + (size - 1) / 2))
  213. if offset not in attention_offsets:
  214. attention_offsets[offset] = len(attention_offsets)
  215. indices.append(attention_offsets[offset])
  216. self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets)))
  217. self.register_buffer(
  218. "attention_bias_idxs", torch.LongTensor(indices).view(len_points_, len_points), persistent=False
  219. )
  220. @torch.no_grad()
  221. def train(self, mode=True):
  222. super().train(mode)
  223. if mode and self.attention_bias_cache:
  224. self.attention_bias_cache = {} # clear ab cache
  225. def get_attention_biases(self, device):
  226. if self.training:
  227. return self.attention_biases[:, self.attention_bias_idxs]
  228. else:
  229. device_key = str(device)
  230. if device_key not in self.attention_bias_cache:
  231. self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
  232. return self.attention_bias_cache[device_key]
  233. def forward(self, hidden_state):
  234. batch_size, seq_length, _ = hidden_state.shape
  235. key, value = (
  236. self.keys_values(hidden_state)
  237. .view(batch_size, seq_length, self.num_attention_heads, -1)
  238. .split([self.key_dim, self.attention_ratio * self.key_dim], dim=3)
  239. )
  240. key = key.permute(0, 2, 1, 3)
  241. value = value.permute(0, 2, 1, 3)
  242. query = self.queries(self.queries_subsample(hidden_state))
  243. query = query.view(batch_size, self.resolution_out**2, self.num_attention_heads, self.key_dim).permute(
  244. 0, 2, 1, 3
  245. )
  246. attention = query @ key.transpose(-2, -1) * self.scale + self.get_attention_biases(hidden_state.device)
  247. attention = attention.softmax(dim=-1)
  248. hidden_state = (attention @ value).transpose(1, 2).reshape(batch_size, -1, self.out_dim_projection)
  249. hidden_state = self.projection(self.activation(hidden_state))
  250. return hidden_state
  251. class LevitMLPLayer(nn.Module):
  252. """
  253. MLP Layer with `2X` expansion in contrast to ViT with `4X`.
  254. """
  255. def __init__(self, input_dim, hidden_dim):
  256. super().__init__()
  257. self.linear_up = MLPLayerWithBN(input_dim, hidden_dim)
  258. self.activation = nn.Hardswish()
  259. self.linear_down = MLPLayerWithBN(hidden_dim, input_dim)
  260. def forward(self, hidden_state):
  261. hidden_state = self.linear_up(hidden_state)
  262. hidden_state = self.activation(hidden_state)
  263. hidden_state = self.linear_down(hidden_state)
  264. return hidden_state
  265. class LevitResidualLayer(nn.Module):
  266. """
  267. Residual Block for LeViT
  268. """
  269. def __init__(self, module, drop_rate):
  270. super().__init__()
  271. self.module = module
  272. self.drop_rate = drop_rate
  273. def forward(self, hidden_state):
  274. if self.training and self.drop_rate > 0:
  275. rnd = torch.rand(hidden_state.size(0), 1, 1, device=hidden_state.device)
  276. rnd = rnd.ge_(self.drop_rate).div(1 - self.drop_rate).detach()
  277. hidden_state = hidden_state + self.module(hidden_state) * rnd
  278. return hidden_state
  279. else:
  280. hidden_state = hidden_state + self.module(hidden_state)
  281. return hidden_state
  282. class LevitStage(nn.Module):
  283. """
  284. LeViT Stage consisting of `LevitMLPLayer` and `LevitAttention` layers.
  285. """
  286. def __init__(
  287. self,
  288. config,
  289. idx,
  290. hidden_sizes,
  291. key_dim,
  292. depths,
  293. num_attention_heads,
  294. attention_ratio,
  295. mlp_ratio,
  296. down_ops,
  297. resolution_in,
  298. ):
  299. super().__init__()
  300. self.layers = []
  301. self.config = config
  302. self.resolution_in = resolution_in
  303. # resolution_in is the initial resolution, resolution_out is final resolution after downsampling
  304. for _ in range(depths):
  305. self.layers.append(
  306. LevitResidualLayer(
  307. LevitAttention(hidden_sizes, key_dim, num_attention_heads, attention_ratio, resolution_in),
  308. self.config.drop_path_rate,
  309. )
  310. )
  311. if mlp_ratio > 0:
  312. hidden_dim = hidden_sizes * mlp_ratio
  313. self.layers.append(
  314. LevitResidualLayer(LevitMLPLayer(hidden_sizes, hidden_dim), self.config.drop_path_rate)
  315. )
  316. if down_ops[0] == "Subsample":
  317. self.resolution_out = (self.resolution_in - 1) // down_ops[5] + 1
  318. self.layers.append(
  319. LevitAttentionSubsample(
  320. *self.config.hidden_sizes[idx : idx + 2],
  321. key_dim=down_ops[1],
  322. num_attention_heads=down_ops[2],
  323. attention_ratio=down_ops[3],
  324. stride=down_ops[5],
  325. resolution_in=resolution_in,
  326. resolution_out=self.resolution_out,
  327. )
  328. )
  329. self.resolution_in = self.resolution_out
  330. if down_ops[4] > 0:
  331. hidden_dim = self.config.hidden_sizes[idx + 1] * down_ops[4]
  332. self.layers.append(
  333. LevitResidualLayer(
  334. LevitMLPLayer(self.config.hidden_sizes[idx + 1], hidden_dim), self.config.drop_path_rate
  335. )
  336. )
  337. self.layers = nn.ModuleList(self.layers)
  338. def get_resolution(self):
  339. return self.resolution_in
  340. def forward(self, hidden_state):
  341. for layer in self.layers:
  342. hidden_state = layer(hidden_state)
  343. return hidden_state
  344. class LevitEncoder(nn.Module):
  345. """
  346. LeViT Encoder consisting of multiple `LevitStage` stages.
  347. """
  348. def __init__(self, config):
  349. super().__init__()
  350. self.config = config
  351. resolution = self.config.image_size // self.config.patch_size
  352. self.stages = []
  353. self.config.down_ops.append([""])
  354. for stage_idx in range(len(config.depths)):
  355. stage = LevitStage(
  356. config,
  357. stage_idx,
  358. config.hidden_sizes[stage_idx],
  359. config.key_dim[stage_idx],
  360. config.depths[stage_idx],
  361. config.num_attention_heads[stage_idx],
  362. config.attention_ratio[stage_idx],
  363. config.mlp_ratio[stage_idx],
  364. config.down_ops[stage_idx],
  365. resolution,
  366. )
  367. resolution = stage.get_resolution()
  368. self.stages.append(stage)
  369. self.stages = nn.ModuleList(self.stages)
  370. def forward(self, hidden_state, output_hidden_states=False, return_dict=True):
  371. all_hidden_states = () if output_hidden_states else None
  372. for stage in self.stages:
  373. if output_hidden_states:
  374. all_hidden_states = all_hidden_states + (hidden_state,)
  375. hidden_state = stage(hidden_state)
  376. if output_hidden_states:
  377. all_hidden_states = all_hidden_states + (hidden_state,)
  378. if not return_dict:
  379. return tuple(v for v in [hidden_state, all_hidden_states] if v is not None)
  380. return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=all_hidden_states)
  381. class LevitClassificationLayer(nn.Module):
  382. """
  383. LeViT Classification Layer
  384. """
  385. def __init__(self, input_dim, output_dim):
  386. super().__init__()
  387. self.batch_norm = nn.BatchNorm1d(input_dim)
  388. self.linear = nn.Linear(input_dim, output_dim)
  389. def forward(self, hidden_state):
  390. hidden_state = self.batch_norm(hidden_state)
  391. logits = self.linear(hidden_state)
  392. return logits
  393. @auto_docstring
  394. class LevitPreTrainedModel(PreTrainedModel):
  395. config: LevitConfig
  396. base_model_prefix = "levit"
  397. main_input_name = "pixel_values"
  398. _no_split_modules = ["LevitResidualLayer"]
  399. def _init_weights(self, module):
  400. """Initialize the weights"""
  401. if isinstance(module, (nn.Linear, nn.Conv2d)):
  402. # Slightly different from the TF version which uses truncated_normal for initialization
  403. # cf https://github.com/pytorch/pytorch/pull/5617
  404. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  405. if module.bias is not None:
  406. module.bias.data.zero_()
  407. elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
  408. module.bias.data.zero_()
  409. module.weight.data.fill_(1.0)
  410. @auto_docstring
  411. class LevitModel(LevitPreTrainedModel):
  412. def __init__(self, config):
  413. super().__init__(config)
  414. self.config = config
  415. self.patch_embeddings = LevitPatchEmbeddings(config)
  416. self.encoder = LevitEncoder(config)
  417. # Initialize weights and apply final processing
  418. self.post_init()
  419. @auto_docstring
  420. def forward(
  421. self,
  422. pixel_values: Optional[torch.FloatTensor] = None,
  423. output_hidden_states: Optional[bool] = None,
  424. return_dict: Optional[bool] = None,
  425. ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
  426. output_hidden_states = (
  427. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  428. )
  429. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  430. if pixel_values is None:
  431. raise ValueError("You have to specify pixel_values")
  432. embeddings = self.patch_embeddings(pixel_values)
  433. encoder_outputs = self.encoder(
  434. embeddings,
  435. output_hidden_states=output_hidden_states,
  436. return_dict=return_dict,
  437. )
  438. last_hidden_state = encoder_outputs[0]
  439. # global average pooling, (batch_size, seq_length, hidden_sizes) -> (batch_size, hidden_sizes)
  440. pooled_output = last_hidden_state.mean(dim=1)
  441. if not return_dict:
  442. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  443. return BaseModelOutputWithPoolingAndNoAttention(
  444. last_hidden_state=last_hidden_state,
  445. pooler_output=pooled_output,
  446. hidden_states=encoder_outputs.hidden_states,
  447. )
  448. @auto_docstring(
  449. custom_intro="""
  450. Levit Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
  451. ImageNet.
  452. """
  453. )
  454. class LevitForImageClassification(LevitPreTrainedModel):
  455. def __init__(self, config):
  456. super().__init__(config)
  457. self.config = config
  458. self.num_labels = config.num_labels
  459. self.levit = LevitModel(config)
  460. # Classifier head
  461. self.classifier = (
  462. LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
  463. if config.num_labels > 0
  464. else torch.nn.Identity()
  465. )
  466. # Initialize weights and apply final processing
  467. self.post_init()
  468. @auto_docstring
  469. def forward(
  470. self,
  471. pixel_values: Optional[torch.FloatTensor] = None,
  472. labels: Optional[torch.LongTensor] = None,
  473. output_hidden_states: Optional[bool] = None,
  474. return_dict: Optional[bool] = None,
  475. ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
  476. r"""
  477. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  478. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  479. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  480. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  481. """
  482. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  483. outputs = self.levit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
  484. sequence_output = outputs[0]
  485. sequence_output = sequence_output.mean(1)
  486. logits = self.classifier(sequence_output)
  487. loss = None
  488. if labels is not None:
  489. loss = self.loss_function(labels, logits, self.config)
  490. if not return_dict:
  491. output = (logits,) + outputs[2:]
  492. return ((loss,) + output) if loss is not None else output
  493. return ImageClassifierOutputWithNoAttention(
  494. loss=loss,
  495. logits=logits,
  496. hidden_states=outputs.hidden_states,
  497. )
  498. @auto_docstring(
  499. custom_intro="""
  500. LeViT Model transformer with image classification heads on top (a linear layer on top of the final hidden state and
  501. a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet. .. warning::
  502. This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet
  503. supported.
  504. """
  505. )
  506. class LevitForImageClassificationWithTeacher(LevitPreTrainedModel):
  507. def __init__(self, config):
  508. super().__init__(config)
  509. self.config = config
  510. self.num_labels = config.num_labels
  511. self.levit = LevitModel(config)
  512. # Classifier head
  513. self.classifier = (
  514. LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
  515. if config.num_labels > 0
  516. else torch.nn.Identity()
  517. )
  518. self.classifier_distill = (
  519. LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
  520. if config.num_labels > 0
  521. else torch.nn.Identity()
  522. )
  523. # Initialize weights and apply final processing
  524. self.post_init()
  525. @auto_docstring
  526. def forward(
  527. self,
  528. pixel_values: Optional[torch.FloatTensor] = None,
  529. output_hidden_states: Optional[bool] = None,
  530. return_dict: Optional[bool] = None,
  531. ) -> Union[tuple, LevitForImageClassificationWithTeacherOutput]:
  532. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  533. outputs = self.levit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
  534. sequence_output = outputs[0]
  535. sequence_output = sequence_output.mean(1)
  536. cls_logits, distill_logits = self.classifier(sequence_output), self.classifier_distill(sequence_output)
  537. logits = (cls_logits + distill_logits) / 2
  538. if not return_dict:
  539. output = (logits, cls_logits, distill_logits) + outputs[2:]
  540. return output
  541. return LevitForImageClassificationWithTeacherOutput(
  542. logits=logits,
  543. cls_logits=cls_logits,
  544. distillation_logits=distill_logits,
  545. hidden_states=outputs.hidden_states,
  546. )
  547. __all__ = [
  548. "LevitForImageClassification",
  549. "LevitForImageClassificationWithTeacher",
  550. "LevitModel",
  551. "LevitPreTrainedModel",
  552. ]