modeling_mobilevit.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998
  1. # coding=utf-8
  2. # Copyright 2022 Apple Inc. and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. # Original license: https://github.com/apple/ml-cvnets/blob/main/LICENSE
  17. """PyTorch MobileViT model."""
  18. import math
  19. from typing import Optional, Union
  20. import torch
  21. from torch import nn
  22. from torch.nn import CrossEntropyLoss
  23. from ...activations import ACT2FN
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import (
  26. BaseModelOutputWithNoAttention,
  27. BaseModelOutputWithPoolingAndNoAttention,
  28. ImageClassifierOutputWithNoAttention,
  29. SemanticSegmenterOutput,
  30. )
  31. from ...modeling_utils import PreTrainedModel
  32. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  33. from ...utils import auto_docstring, logging, torch_int
  34. from .configuration_mobilevit import MobileViTConfig
  # Module-level logger obtained from the library's logging utilities.
  logger = logging.get_logger(__name__)
  def make_divisible(value: int, divisor: int = 8, min_value: Optional[int] = None) -> int:
      """
      Round `value` to the nearest multiple of `divisor` so layer channel counts stay divisible by it.

      The result is never smaller than `min_value` (which defaults to `divisor`). If rounding to the
      nearest multiple lands more than 10% below `value`, the result is bumped up by one `divisor`.
      Adapted from the TensorFlow MobileNet reference implementation:
      https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
      """
      floor_value = divisor if min_value is None else min_value
      nearest_multiple = int(value + divisor / 2) // divisor * divisor
      rounded = max(floor_value, nearest_multiple)
      # Rounding to the nearest multiple can go down; never lose more than 10% of the requested value.
      if rounded < 0.9 * value:
          rounded += divisor
      return int(rounded)
  class MobileViTConvLayer(nn.Module):
      """
      A `nn.Conv2d` optionally followed by `nn.BatchNorm2d` and an activation.

      Padding is `int((kernel_size - 1) / 2) * dilation`. `use_activation` may be falsy (no activation),
      a string naming an entry of `ACT2FN`, or any other truthy value, in which case `config.hidden_act`
      is used (looked up in `ACT2FN` when it is a string, used as-is otherwise).

      Raises:
          ValueError: if `in_channels` or `out_channels` is not divisible by `groups`.
      """

      def __init__(
          self,
          config: MobileViTConfig,
          in_channels: int,
          out_channels: int,
          kernel_size: int,
          stride: int = 1,
          groups: int = 1,
          bias: bool = False,
          dilation: int = 1,
          use_normalization: bool = True,
          use_activation: Union[bool, str] = True,
      ) -> None:
          super().__init__()
          padding = int((kernel_size - 1) / 2) * dilation

          # Grouped convolutions require both channel counts to split evenly across the groups.
          if in_channels % groups != 0:
              raise ValueError(f"Input channels ({in_channels}) are not divisible by {groups} groups.")
          if out_channels % groups != 0:
              raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.")

          self.convolution = nn.Conv2d(
              in_channels=in_channels,
              out_channels=out_channels,
              kernel_size=kernel_size,
              stride=stride,
              padding=padding,
              dilation=dilation,
              groups=groups,
              bias=bias,
              padding_mode="zeros",
          )

          self.normalization = (
              nn.BatchNorm2d(
                  num_features=out_channels,
                  eps=1e-5,
                  momentum=0.1,
                  affine=True,
                  track_running_stats=True,
              )
              if use_normalization
              else None
          )

          # Resolve the activation: disabled, explicitly named, or inherited from the config.
          if not use_activation:
              self.activation = None
          elif isinstance(use_activation, str):
              self.activation = ACT2FN[use_activation]
          elif isinstance(config.hidden_act, str):
              self.activation = ACT2FN[config.hidden_act]
          else:
              self.activation = config.hidden_act

      def forward(self, features: torch.Tensor) -> torch.Tensor:
          """Apply convolution, then the optional normalization and activation, in that order."""
          features = self.convolution(features)
          for post_op in (self.normalization, self.activation):
              if post_op is not None:
                  features = post_op(features)
          return features
  class MobileViTInvertedResidual(nn.Module):
      """
      Inverted residual block (MobileNetv2): https://huggingface.co/papers/1801.04381

      1x1 expansion -> 3x3 depthwise conv (carries the stride/dilation) -> 1x1 linear reduction.
      The input is added back only when the block keeps both resolution (stride 1) and width.

      Raises:
          ValueError: if `stride` is not 1 or 2.
      """

      def __init__(
          self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int, dilation: int = 1
      ) -> None:
          super().__init__()
          hidden_width = make_divisible(int(round(in_channels * config.expand_ratio)), 8)

          if stride not in [1, 2]:
              raise ValueError(f"Invalid stride {stride}.")

          self.use_residual = stride == 1 and in_channels == out_channels

          self.expand_1x1 = MobileViTConvLayer(
              config, in_channels=in_channels, out_channels=hidden_width, kernel_size=1
          )
          # groups == channels makes this a depthwise convolution.
          self.conv_3x3 = MobileViTConvLayer(
              config,
              in_channels=hidden_width,
              out_channels=hidden_width,
              kernel_size=3,
              stride=stride,
              groups=hidden_width,
              dilation=dilation,
          )
          # Linear bottleneck: no activation after the reduction.
          self.reduce_1x1 = MobileViTConvLayer(
              config,
              in_channels=hidden_width,
              out_channels=out_channels,
              kernel_size=1,
              use_activation=False,
          )

      def forward(self, features: torch.Tensor) -> torch.Tensor:
          """Run expand -> depthwise -> reduce, adding the input back when `use_residual` is set."""
          transformed = self.reduce_1x1(self.conv_3x3(self.expand_1x1(features)))
          if self.use_residual:
              return features + transformed
          return transformed
  class MobileViTMobileNetLayer(nn.Module):
      """
      A stack of `num_stages` inverted residual blocks.

      Only the first block maps `in_channels -> out_channels` and applies `stride`; the remaining
      blocks map `out_channels -> out_channels` with stride 1.
      """

      def __init__(
          self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int = 1, num_stages: int = 1
      ) -> None:
          super().__init__()
          self.layer = nn.ModuleList(
              MobileViTInvertedResidual(
                  config,
                  in_channels=in_channels if block_index == 0 else out_channels,
                  out_channels=out_channels,
                  stride=stride if block_index == 0 else 1,
              )
              for block_index in range(num_stages)
          )

      def forward(self, features: torch.Tensor) -> torch.Tensor:
          """Apply the blocks sequentially."""
          for block in self.layer:
              features = block(features)
          return features
  class MobileViTSelfAttention(nn.Module):
      """
      Multi-head scaled dot-product self-attention.

      `forward` unpacks the input shape into exactly three dimensions, so it expects
      `(batch_size, seq_length, hidden_size)` and returns `(batch_size, seq_length, all_head_size)`.

      Raises:
          ValueError: if `hidden_size` is not a multiple of `config.num_attention_heads`.
      """

      def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
          super().__init__()
          head_count = config.num_attention_heads
          if hidden_size % head_count != 0:
              raise ValueError(
                  f"The hidden size {hidden_size} is not a multiple of the number of attention "
                  f"heads {config.num_attention_heads}."
              )

          self.num_attention_heads = head_count
          self.attention_head_size = int(hidden_size / head_count)
          self.all_head_size = self.num_attention_heads * self.attention_head_size

          self.query = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
          self.key = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
          self.value = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)

          self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

      def _split_heads(self, projected: torch.Tensor, batch_size: int) -> torch.Tensor:
          """Reshape `(batch, seq, all_head_size)` to `(batch, heads, seq, head_size)`."""
          per_head = projected.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
          return per_head.transpose(1, 2)

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          # Unpacking enforces a rank-3 input; only the batch size is needed afterwards.
          batch_size, _, _ = hidden_states.shape

          queries = self._split_heads(self.query(hidden_states), batch_size)
          keys = self._split_heads(self.key(hidden_states), batch_size)
          values = self._split_heads(self.value(hidden_states), batch_size)

          # Scaled dot-product scores, normalized over the key axis.
          scores = torch.matmul(queries, keys.transpose(-1, -2))
          scores = scores / math.sqrt(self.attention_head_size)
          weights = nn.functional.softmax(scores, dim=-1)

          # Dropout on the attention weights drops whole tokens from what each query attends to,
          # following the original Transformer.
          weights = self.dropout(weights)

          context = torch.matmul(weights, values)
          # Merge heads back: (batch, heads, seq, head_size) -> (batch, seq, all_head_size).
          context = context.permute(0, 2, 1, 3).contiguous()
          return context.view(*context.size()[:-2], self.all_head_size)
  class MobileViTSelfOutput(nn.Module):
      """Linear projection of the attention context followed by dropout (no residual here)."""

      def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
          super().__init__()
          self.dense = nn.Linear(hidden_size, hidden_size)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          return self.dropout(self.dense(hidden_states))
  class MobileViTAttention(nn.Module):
      """Self-attention plus its output projection, with support for pruning attention heads."""

      def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
          super().__init__()
          self.attention = MobileViTSelfAttention(config, hidden_size)
          self.output = MobileViTSelfOutput(config, hidden_size)
          self.pruned_heads = set()

      def prune_heads(self, heads: set[int]) -> None:
          """
          Remove the given attention heads in place.

          Shrinks the query/key/value projections (output rows) and the output projection (input
          columns), then updates the head bookkeeping. An empty `heads` is a no-op.
          """
          if not heads:
              return
          self_attention = self.attention
          heads, index = find_pruneable_heads_and_indices(
              heads, self_attention.num_attention_heads, self_attention.attention_head_size, self.pruned_heads
          )

          self_attention.query = prune_linear_layer(self_attention.query, index)
          self_attention.key = prune_linear_layer(self_attention.key, index)
          self_attention.value = prune_linear_layer(self_attention.value, index)
          self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

          self_attention.num_attention_heads -= len(heads)
          self_attention.all_head_size = self_attention.attention_head_size * self_attention.num_attention_heads
          self.pruned_heads = {*self.pruned_heads, *heads}

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          return self.output(self.attention(hidden_states))
  class MobileViTIntermediate(nn.Module):
      """First half of the feed-forward network: expand to `intermediate_size`, then activate."""

      def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
          super().__init__()
          self.dense = nn.Linear(hidden_size, intermediate_size)
          act = config.hidden_act
          # `hidden_act` may be a registry name or an already-built callable.
          self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          return self.intermediate_act_fn(self.dense(hidden_states))
  class MobileViTOutput(nn.Module):
      """Second half of the feed-forward network: project back, dropout, and add `input_tensor`."""

      def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
          super().__init__()
          self.dense = nn.Linear(intermediate_size, hidden_size)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)

      def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
          projected = self.dropout(self.dense(hidden_states))
          return projected + input_tensor
  class MobileViTTransformerLayer(nn.Module):
      """
      Pre-norm transformer layer.

      `x = x + attention(layernorm_before(x))`, then
      `out = x + ffn(layernorm_after(x))`, where the FFN's residual add lives in `MobileViTOutput`.
      """

      def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
          super().__init__()
          self.attention = MobileViTAttention(config, hidden_size)
          self.intermediate = MobileViTIntermediate(config, hidden_size, intermediate_size)
          self.output = MobileViTOutput(config, hidden_size, intermediate_size)
          self.layernorm_before = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
          self.layernorm_after = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          # Attention sub-layer with residual connection.
          hidden_states = self.attention(self.layernorm_before(hidden_states)) + hidden_states
          # Feed-forward sub-layer; MobileViTOutput adds `hidden_states` back.
          expanded = self.intermediate(self.layernorm_after(hidden_states))
          return self.output(expanded, hidden_states)
  class MobileViTTransformer(nn.Module):
      """`num_stages` identical transformer layers with FFN width `int(hidden_size * config.mlp_ratio)`."""

      def __init__(self, config: MobileViTConfig, hidden_size: int, num_stages: int) -> None:
          super().__init__()
          ffn_width = int(hidden_size * config.mlp_ratio)
          self.layer = nn.ModuleList(
              MobileViTTransformerLayer(config, hidden_size=hidden_size, intermediate_size=ffn_width)
              for _ in range(num_stages)
          )

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          for transformer_layer in self.layer:
              hidden_states = transformer_layer(hidden_states)
          return hidden_states
  class MobileViTLayer(GradientCheckpointingLayer):
      """
      MobileViT block: https://huggingface.co/papers/2110.02178

      Steps:
        1. Optional downsampling: an inverted residual block when `stride == 2`.
        2. Local features: `conv_kxk`, then a 1x1 projection to `hidden_size`.
        3. Global features: unfold into patches, run the transformer and LayerNorm, then fold back.
        4. Project back to the block width and fuse with the local residual.

      When `stride != 2`, no downsampling layer is built. In that case `out_channels` is unused and
      the block outputs `in_channels` channels.
      """

      def __init__(
          self,
          config: MobileViTConfig,
          in_channels: int,
          out_channels: int,
          stride: int,
          hidden_size: int,
          num_stages: int,
          dilation: int = 1,
      ) -> None:
          super().__init__()
          self.patch_width = config.patch_size
          self.patch_height = config.patch_size

          if stride == 2:
              # When the encoder dilates this stage (dilation > 1), keep stride 1 and dilate the
              # depthwise conv instead, so the spatial resolution is preserved.
              self.downsampling_layer = MobileViTInvertedResidual(
                  config,
                  in_channels=in_channels,
                  out_channels=out_channels,
                  stride=stride if dilation == 1 else 1,
                  dilation=dilation // 2 if dilation > 1 else 1,
              )
              # Everything after downsampling operates at the new width.
              in_channels = out_channels
          else:
              self.downsampling_layer = None

          self.conv_kxk = MobileViTConvLayer(
              config,
              in_channels=in_channels,
              out_channels=in_channels,
              kernel_size=config.conv_kernel_size,
          )

          # Plain linear projection into the transformer width (no norm, no activation).
          self.conv_1x1 = MobileViTConvLayer(
              config,
              in_channels=in_channels,
              out_channels=hidden_size,
              kernel_size=1,
              use_normalization=False,
              use_activation=False,
          )

          self.transformer = MobileViTTransformer(
              config,
              hidden_size=hidden_size,
              num_stages=num_stages,
          )

          self.layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

          self.conv_projection = MobileViTConvLayer(
              config, in_channels=hidden_size, out_channels=in_channels, kernel_size=1
          )

          # Fuses the concatenation [residual, global features], hence 2 * in_channels inputs.
          self.fusion = MobileViTConvLayer(
              config, in_channels=2 * in_channels, out_channels=in_channels, kernel_size=config.conv_kernel_size
          )

      @staticmethod
      def _round_up_to_multiple(size, multiple):
          """Smallest multiple of `multiple` that is >= `size`, computed with tensor ops while JIT tracing."""
          if torch.jit.is_tracing():
              return torch_int(torch.ceil(size / multiple) * multiple)
          return int(math.ceil(size / multiple) * multiple)

      def unfolding(self, features: torch.Tensor) -> tuple[torch.Tensor, dict]:
          """
          Split a `(batch_size, channels, height, width)` map into non-overlapping patches.

          If `height`/`width` are not multiples of the patch size, the map is first bilinearly resized
          up to the next multiples.

          Returns:
              patches: tensor of shape `(batch_size * patch_area, num_patches, channels)`.
              info_dict: sizes and flags that `folding` needs to invert this operation.
          """
          patch_width, patch_height = self.patch_width, self.patch_height
          patch_area = int(patch_width * patch_height)

          batch_size, channels, orig_height, orig_width = features.shape

          new_height = self._round_up_to_multiple(orig_height, patch_height)
          new_width = self._round_up_to_multiple(orig_width, patch_width)

          interpolate = False
          if new_width != orig_width or new_height != orig_height:
              # Note: Padding can be done, but then it needs to be handled in attention function.
              features = nn.functional.interpolate(
                  features, size=(new_height, new_width), mode="bilinear", align_corners=False
              )
              interpolate = True

          # number of patches along width and height
          num_patch_width = new_width // patch_width
          num_patch_height = new_height // patch_height
          num_patches = num_patch_height * num_patch_width

          # convert from shape (batch_size, channels, new_height, new_width)
          # to the shape (batch_size * patch_area, num_patches, channels)
          patches = features.reshape(
              batch_size * channels * num_patch_height, patch_height, num_patch_width, patch_width
          )
          patches = patches.transpose(1, 2)
          patches = patches.reshape(batch_size, channels, num_patches, patch_area)
          patches = patches.transpose(1, 3)
          patches = patches.reshape(batch_size * patch_area, num_patches, -1)

          info_dict = {
              "orig_size": (orig_height, orig_width),
              "batch_size": batch_size,
              "channels": channels,
              "interpolate": interpolate,
              "num_patches": num_patches,
              "num_patches_width": num_patch_width,
              "num_patches_height": num_patch_height,
          }
          return patches, info_dict

      def folding(self, patches: torch.Tensor, info_dict: dict) -> torch.Tensor:
          """
          Inverse of `unfolding`.

          Maps `(batch_size * patch_area, num_patches, channels)` back to a
          `(batch_size, channels, height, width)` feature map. If `unfolding` resized the map, it is
          resized back to `info_dict["orig_size"]`.
          """
          patch_width, patch_height = self.patch_width, self.patch_height
          patch_area = int(patch_width * patch_height)

          batch_size = info_dict["batch_size"]
          channels = info_dict["channels"]
          num_patches = info_dict["num_patches"]
          num_patch_height = info_dict["num_patches_height"]
          num_patch_width = info_dict["num_patches_width"]

          # convert from shape (batch_size * patch_area, num_patches, channels)
          # back to shape (batch_size, channels, padded_height, padded_width)
          features = patches.contiguous().view(batch_size, patch_area, num_patches, -1)
          features = features.transpose(1, 3)
          features = features.reshape(
              batch_size * channels * num_patch_height, num_patch_width, patch_height, patch_width
          )
          features = features.transpose(1, 2)
          features = features.reshape(
              batch_size, channels, num_patch_height * patch_height, num_patch_width * patch_width
          )

          if info_dict["interpolate"]:
              features = nn.functional.interpolate(
                  features, size=info_dict["orig_size"], mode="bilinear", align_corners=False
              )

          return features

      def forward(self, features: torch.Tensor) -> torch.Tensor:
          # Reduce spatial dimensions if needed. Compare against None explicitly rather than relying on
          # the module's truthiness: container modules (e.g. nn.Sequential) are falsy when empty.
          if self.downsampling_layer is not None:
              features = self.downsampling_layer(features)

          residual = features

          # local representation
          features = self.conv_kxk(features)
          features = self.conv_1x1(features)

          # convert feature map to patches
          patches, info_dict = self.unfolding(features)

          # learn global representations
          patches = self.transformer(patches)
          patches = self.layernorm(patches)

          # convert patches back to feature maps
          features = self.folding(patches, info_dict)

          features = self.conv_projection(features)
          features = self.fusion(torch.cat((residual, features), dim=1))
          return features
  class MobileViTEncoder(nn.Module):
      """
      The five-stage MobileViT backbone.

      - Stages 1-2 are MobileNetV2-style stacks of inverted residuals.
      - Stages 3-5 are MobileViT blocks, each with a stride-2 entry.
      - Channel widths come from `config.neck_hidden_sizes[0..5]`.
      - Transformer widths come from `config.hidden_sizes[0..2]`.

      For segmentation, `config.output_stride == 8` dilates stages 4 and 5, and
      `config.output_stride == 16` dilates stage 5 only. Dilation doubles at each dilated stage.
      """

      def __init__(self, config: MobileViTConfig) -> None:
          super().__init__()
          self.config = config
          self.layer = nn.ModuleList()
          self.gradient_checkpointing = False

          # Segmentation architectures like DeepLab and PSPNet trade the stride of the deepest
          # classification stages for dilation.
          dilate_layer_4 = config.output_stride == 8
          dilate_layer_5 = config.output_stride in (8, 16)

          widths = config.neck_hidden_sizes

          # MobileNet stages, each as (input width index, stride, number of blocks);
          # the output width is the next entry of `widths`.
          for width_index, stride, num_stages in ((0, 1, 1), (1, 2, 3)):
              self.layer.append(
                  MobileViTMobileNetLayer(
                      config,
                      in_channels=widths[width_index],
                      out_channels=widths[width_index + 1],
                      stride=stride,
                      num_stages=num_stages,
                  )
              )

          # MobileViT stages, each as (input width index, transformer width index,
          # transformer depth, whether to dilate).
          dilation = 1
          for width_index, hidden_index, num_stages, dilate in (
              (2, 0, 2, False),
              (3, 1, 4, dilate_layer_4),
              (4, 2, 3, dilate_layer_5),
          ):
              if dilate:
                  dilation *= 2
              self.layer.append(
                  MobileViTLayer(
                      config,
                      in_channels=widths[width_index],
                      out_channels=widths[width_index + 1],
                      stride=2,
                      hidden_size=config.hidden_sizes[hidden_index],
                      num_stages=num_stages,
                      dilation=dilation,
                  )
              )

      def forward(
          self,
          hidden_states: torch.Tensor,
          output_hidden_states: bool = False,
          return_dict: bool = True,
      ) -> Union[tuple, BaseModelOutputWithNoAttention]:
          # When requested, collect the output of every stage (the stem output is not included).
          stage_outputs = [] if output_hidden_states else None

          for stage in self.layer:
              hidden_states = stage(hidden_states)
              if stage_outputs is not None:
                  stage_outputs.append(hidden_states)

          all_hidden_states = tuple(stage_outputs) if stage_outputs is not None else None

          if return_dict:
              return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
          return tuple(v for v in (hidden_states, all_hidden_states) if v is not None)
  @auto_docstring
  class MobileViTPreTrainedModel(PreTrainedModel):
      # Shared base for MobileViT models: config class, checkpoint prefix, input name, and
      # the unit that must not be split across devices.
      config: MobileViTConfig
      base_model_prefix = "mobilevit"
      main_input_name = "pixel_values"
      supports_gradient_checkpointing = True
      _no_split_modules = ["MobileViTLayer"]

      def _init_weights(self, module: nn.Module) -> None:
          """
          Initialize `module`'s parameters in place.

          - Linear, Conv2d and BatchNorm2d: weights ~ N(0, config.initializer_range), biases (if any) zeroed.
          - LayerNorm: weight set to 1, bias zeroed.
          - Any other module type is left untouched.
          """
          if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
              # A plain normal is used where the TF reference uses truncated_normal,
              # cf https://github.com/pytorch/pytorch/pull/5617
              module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
              if module.bias is not None:
                  module.bias.data.zero_()
              return
          if isinstance(module, nn.LayerNorm):
              module.bias.data.zero_()
              module.weight.data.fill_(1.0)
  @auto_docstring
  class MobileViTModel(MobileViTPreTrainedModel):
      def __init__(self, config: MobileViTConfig, expand_output: bool = True):
          r"""
          expand_output (`bool`, *optional*, defaults to `True`):
              Whether to expand the output of the model using a 1x1 convolution. If `True`, the model will apply an additional
              1x1 convolution to expand the output channels from `config.neck_hidden_sizes[5]` to `config.neck_hidden_sizes[6]`.
          """
          super().__init__(config)
          self.config = config
          self.expand_output = expand_output

          # Stride-2 3x3 stem from image channels to the first neck width.
          self.conv_stem = MobileViTConvLayer(
              config,
              in_channels=config.num_channels,
              out_channels=config.neck_hidden_sizes[0],
              kernel_size=3,
              stride=2,
          )

          self.encoder = MobileViTEncoder(config)

          if self.expand_output:
              self.conv_1x1_exp = MobileViTConvLayer(
                  config,
                  in_channels=config.neck_hidden_sizes[5],
                  out_channels=config.neck_hidden_sizes[6],
                  kernel_size=1,
              )

          # Initialize weights and apply final processing
          self.post_init()

      def _prune_heads(self, heads_to_prune):
          """
          Prune attention heads in the encoder.

          `heads_to_prune` maps an encoder stage index to the heads to remove. The same heads are
          pruned in every transformer layer of that stage. Stages that are not MobileViT blocks
          (i.e. have no attention) are skipped. See the base class `PreTrainedModel`.
          """
          for layer_index, heads in heads_to_prune.items():
              stage = self.encoder.layer[layer_index]
              if not isinstance(stage, MobileViTLayer):
                  continue
              for transformer_layer in stage.transformer.layer:
                  transformer_layer.attention.prune_heads(heads)

      @auto_docstring
      def forward(
          self,
          pixel_values: Optional[torch.Tensor] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
          # Fall back to the config for unspecified flags.
          if output_hidden_states is None:
              output_hidden_states = self.config.output_hidden_states
          if return_dict is None:
              return_dict = self.config.use_return_dict

          if pixel_values is None:
              raise ValueError("You have to specify pixel_values")

          encoder_outputs = self.encoder(
              self.conv_stem(pixel_values),
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          last_hidden_state = encoder_outputs[0]
          pooled_output = None
          if self.expand_output:
              last_hidden_state = self.conv_1x1_exp(last_hidden_state)
              # Global average pooling: (batch_size, channels, height, width) -> (batch_size, channels)
              pooled_output = torch.mean(last_hidden_state, dim=[-2, -1], keepdim=False)

          if return_dict:
              return BaseModelOutputWithPoolingAndNoAttention(
                  last_hidden_state=last_hidden_state,
                  pooler_output=pooled_output,
                  hidden_states=encoder_outputs.hidden_states,
              )

          # Tuple output omits the pooled features entirely when there are none.
          head = (last_hidden_state,) if pooled_output is None else (last_hidden_state, pooled_output)
          return head + encoder_outputs[1:]
  @auto_docstring(
      custom_intro="""
      MobileViT model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
      ImageNet.
      """
  )
  class MobileViTForImageClassification(MobileViTPreTrainedModel):
      def __init__(self, config: MobileViTConfig) -> None:
          super().__init__(config)

          self.num_labels = config.num_labels
          self.mobilevit = MobileViTModel(config)

          # Classifier head: in-place dropout on the pooled features, then a linear layer.
          # With no labels configured, the head is an identity.
          self.dropout = nn.Dropout(config.classifier_dropout_prob, inplace=True)
          if config.num_labels > 0:
              self.classifier = nn.Linear(config.neck_hidden_sizes[-1], config.num_labels)
          else:
              self.classifier = nn.Identity()

          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      def forward(
          self,
          pixel_values: Optional[torch.Tensor] = None,
          output_hidden_states: Optional[bool] = None,
          labels: Optional[torch.Tensor] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
              config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
              `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
          """
          return_dict = self.config.use_return_dict if return_dict is None else return_dict

          outputs = self.mobilevit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)

          # Tuple outputs place the pooled features at index 1.
          pooled_output = outputs.pooler_output if return_dict else outputs[1]
          logits = self.classifier(self.dropout(pooled_output))

          loss = self.loss_function(labels, logits, self.config) if labels is not None else None

          if return_dict:
              return ImageClassifierOutputWithNoAttention(
                  loss=loss,
                  logits=logits,
                  hidden_states=outputs.hidden_states,
              )

          output = (logits,) + outputs[2:]
          return output if loss is None else (loss,) + output
  class MobileViTASPPPooling(nn.Module):
      """
      Image-level ASPP branch.

      Global average pooling, then a 1x1 conv with normalization and ReLU, then bilinear upsampling
      back to the input's spatial size.
      """

      def __init__(self, config: MobileViTConfig, in_channels: int, out_channels: int) -> None:
          super().__init__()
          self.global_pool = nn.AdaptiveAvgPool2d(output_size=1)
          self.conv_1x1 = MobileViTConvLayer(
              config,
              in_channels=in_channels,
              out_channels=out_channels,
              kernel_size=1,
              stride=1,
              use_normalization=True,
              use_activation="relu",
          )

      def forward(self, features: torch.Tensor) -> torch.Tensor:
          spatial_size = features.shape[-2:]
          pooled = self.conv_1x1(self.global_pool(features))
          return nn.functional.interpolate(pooled, size=spatial_size, mode="bilinear", align_corners=False)
  class MobileViTASPP(nn.Module):
      """
      ASPP module defined in DeepLab papers: https://huggingface.co/papers/1606.00915, https://huggingface.co/papers/1706.05587

      Five parallel branches run over the second-to-last neck feature map:
        - a 1x1 projection;
        - three dilated 3x3 convs, one per entry of `config.atrous_rates`;
        - an image-level pooling branch.
      Their outputs are concatenated on the channel axis, fused by a 1x1 projection, and passed
      through dropout.

      Raises:
          ValueError: if `config.atrous_rates` does not contain exactly 3 values.
      """

      def __init__(self, config: MobileViTConfig) -> None:
          super().__init__()

          in_channels = config.neck_hidden_sizes[-2]
          out_channels = config.aspp_out_channels

          if len(config.atrous_rates) != 3:
              raise ValueError("Expected 3 values for atrous_rates")

          self.convs = nn.ModuleList()
          self.convs.append(
              MobileViTConvLayer(
                  config,
                  in_channels=in_channels,
                  out_channels=out_channels,
                  kernel_size=1,
                  use_activation="relu",
              )
          )
          self.convs.extend(
              [
                  MobileViTConvLayer(
                      config,
                      in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=3,
                      dilation=rate,
                      use_activation="relu",
                  )
                  for rate in config.atrous_rates
              ]
          )
          self.convs.append(MobileViTASPPPooling(config, in_channels, out_channels))

          # Each branch contributes `out_channels` channels to the concatenation (5 branches in total).
          self.project = MobileViTConvLayer(
              config,
              in_channels=len(self.convs) * out_channels,
              out_channels=out_channels,
              kernel_size=1,
              use_activation="relu",
          )
          self.dropout = nn.Dropout(p=config.aspp_dropout_prob)

      def forward(self, features: torch.Tensor) -> torch.Tensor:
          pyramid = torch.cat([branch(features) for branch in self.convs], dim=1)
          return self.dropout(self.project(pyramid))
  712. class MobileViTDeepLabV3(nn.Module):
  713. """
  714. DeepLabv3 architecture: https://huggingface.co/papers/1706.05587
  715. """
  716. def __init__(self, config: MobileViTConfig) -> None:
  717. super().__init__()
  718. self.aspp = MobileViTASPP(config)
  719. self.dropout = nn.Dropout2d(config.classifier_dropout_prob)
  720. self.classifier = MobileViTConvLayer(
  721. config,
  722. in_channels=config.aspp_out_channels,
  723. out_channels=config.num_labels,
  724. kernel_size=1,
  725. use_normalization=False,
  726. use_activation=False,
  727. bias=True,
  728. )
  729. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  730. features = self.aspp(hidden_states[-1])
  731. features = self.dropout(features)
  732. features = self.classifier(features)
  733. return features
  734. @auto_docstring(
  735. custom_intro="""
  736. MobileViT model with a semantic segmentation head on top, e.g. for Pascal VOC.
  737. """
  738. )
  739. class MobileViTForSemanticSegmentation(MobileViTPreTrainedModel):
  740. def __init__(self, config: MobileViTConfig) -> None:
  741. super().__init__(config)
  742. self.num_labels = config.num_labels
  743. self.mobilevit = MobileViTModel(config, expand_output=False)
  744. self.segmentation_head = MobileViTDeepLabV3(config)
  745. # Initialize weights and apply final processing
  746. self.post_init()
  747. @auto_docstring
  748. def forward(
  749. self,
  750. pixel_values: Optional[torch.Tensor] = None,
  751. labels: Optional[torch.Tensor] = None,
  752. output_hidden_states: Optional[bool] = None,
  753. return_dict: Optional[bool] = None,
  754. ) -> Union[tuple, SemanticSegmenterOutput]:
  755. r"""
  756. labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
  757. Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
  758. config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
  759. Examples:
  760. ```python
  761. >>> import requests
  762. >>> import torch
  763. >>> from PIL import Image
  764. >>> from transformers import AutoImageProcessor, MobileViTForSemanticSegmentation
  765. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  766. >>> image = Image.open(requests.get(url, stream=True).raw)
  767. >>> image_processor = AutoImageProcessor.from_pretrained("apple/deeplabv3-mobilevit-small")
  768. >>> model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small")
  769. >>> inputs = image_processor(images=image, return_tensors="pt")
  770. >>> with torch.no_grad():
  771. ... outputs = model(**inputs)
  772. >>> # logits are of shape (batch_size, num_labels, height, width)
  773. >>> logits = outputs.logits
  774. ```"""
  775. output_hidden_states = (
  776. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  777. )
  778. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  779. if labels is not None and self.config.num_labels == 1:
  780. raise ValueError("The number of labels should be greater than one")
  781. outputs = self.mobilevit(
  782. pixel_values,
  783. output_hidden_states=True, # we need the intermediate hidden states
  784. return_dict=return_dict,
  785. )
  786. encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
  787. logits = self.segmentation_head(encoder_hidden_states)
  788. loss = None
  789. if labels is not None:
  790. # upsample logits to the images' original size
  791. upsampled_logits = nn.functional.interpolate(
  792. logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
  793. )
  794. loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
  795. loss = loss_fct(upsampled_logits, labels)
  796. if not return_dict:
  797. if output_hidden_states:
  798. output = (logits,) + outputs[1:]
  799. else:
  800. output = (logits,) + outputs[2:]
  801. return ((loss,) + output) if loss is not None else output
  802. return SemanticSegmenterOutput(
  803. loss=loss,
  804. logits=logits,
  805. hidden_states=outputs.hidden_states if output_hidden_states else None,
  806. attentions=None,
  807. )
# Public API of this module: the names exported by `from ... import *`.
__all__ = [
    "MobileViTForImageClassification",
    "MobileViTForSemanticSegmentation",
    "MobileViTModel",
    "MobileViTPreTrainedModel",
]