modeling_mobilevitv2.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947
  1. # coding=utf-8
  2. # Copyright 2023 Apple Inc. and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. # Original license: https://github.com/apple/ml-cvnets/blob/main/LICENSE
  17. """PyTorch MobileViTV2 model."""
  18. from typing import Optional, Union
  19. import torch
  20. from torch import nn
  21. from torch.nn import CrossEntropyLoss
  22. from ...activations import ACT2FN
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import (
  25. BaseModelOutputWithNoAttention,
  26. BaseModelOutputWithPoolingAndNoAttention,
  27. ImageClassifierOutputWithNoAttention,
  28. SemanticSegmenterOutput,
  29. )
  30. from ...modeling_utils import PreTrainedModel
  31. from ...utils import auto_docstring, logging
  32. from .configuration_mobilevitv2 import MobileViTV2Config
  33. logger = logging.get_logger(__name__)
  34. # Copied from transformers.models.mobilevit.modeling_mobilevit.make_divisible
  35. def make_divisible(value: int, divisor: int = 8, min_value: Optional[int] = None) -> int:
  36. """
  37. Ensure that all layers have a channel count that is divisible by `divisor`. This function is taken from the
  38. original TensorFlow repo. It can be seen here:
  39. https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
  40. """
  41. if min_value is None:
  42. min_value = divisor
  43. new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
  44. # Make sure that round down does not go down by more than 10%.
  45. if new_value < 0.9 * value:
  46. new_value += divisor
  47. return int(new_value)
  48. def clip(value: float, min_val: float = float("-inf"), max_val: float = float("inf")) -> float:
  49. return max(min_val, min(max_val, value))
  50. # Copied from transformers.models.mobilevit.modeling_mobilevit.MobileViTConvLayer with MobileViT->MobileViTV2
  51. class MobileViTV2ConvLayer(nn.Module):
  52. def __init__(
  53. self,
  54. config: MobileViTV2Config,
  55. in_channels: int,
  56. out_channels: int,
  57. kernel_size: int,
  58. stride: int = 1,
  59. groups: int = 1,
  60. bias: bool = False,
  61. dilation: int = 1,
  62. use_normalization: bool = True,
  63. use_activation: Union[bool, str] = True,
  64. ) -> None:
  65. super().__init__()
  66. padding = int((kernel_size - 1) / 2) * dilation
  67. if in_channels % groups != 0:
  68. raise ValueError(f"Input channels ({in_channels}) are not divisible by {groups} groups.")
  69. if out_channels % groups != 0:
  70. raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.")
  71. self.convolution = nn.Conv2d(
  72. in_channels=in_channels,
  73. out_channels=out_channels,
  74. kernel_size=kernel_size,
  75. stride=stride,
  76. padding=padding,
  77. dilation=dilation,
  78. groups=groups,
  79. bias=bias,
  80. padding_mode="zeros",
  81. )
  82. if use_normalization:
  83. self.normalization = nn.BatchNorm2d(
  84. num_features=out_channels,
  85. eps=1e-5,
  86. momentum=0.1,
  87. affine=True,
  88. track_running_stats=True,
  89. )
  90. else:
  91. self.normalization = None
  92. if use_activation:
  93. if isinstance(use_activation, str):
  94. self.activation = ACT2FN[use_activation]
  95. elif isinstance(config.hidden_act, str):
  96. self.activation = ACT2FN[config.hidden_act]
  97. else:
  98. self.activation = config.hidden_act
  99. else:
  100. self.activation = None
  101. def forward(self, features: torch.Tensor) -> torch.Tensor:
  102. features = self.convolution(features)
  103. if self.normalization is not None:
  104. features = self.normalization(features)
  105. if self.activation is not None:
  106. features = self.activation(features)
  107. return features
  108. # Copied from transformers.models.mobilevit.modeling_mobilevit.MobileViTInvertedResidual with MobileViT->MobileViTV2
  109. class MobileViTV2InvertedResidual(nn.Module):
  110. """
  111. Inverted residual block (MobileNetv2): https://huggingface.co/papers/1801.04381
  112. """
  113. def __init__(
  114. self, config: MobileViTV2Config, in_channels: int, out_channels: int, stride: int, dilation: int = 1
  115. ) -> None:
  116. super().__init__()
  117. expanded_channels = make_divisible(int(round(in_channels * config.expand_ratio)), 8)
  118. if stride not in [1, 2]:
  119. raise ValueError(f"Invalid stride {stride}.")
  120. self.use_residual = (stride == 1) and (in_channels == out_channels)
  121. self.expand_1x1 = MobileViTV2ConvLayer(
  122. config, in_channels=in_channels, out_channels=expanded_channels, kernel_size=1
  123. )
  124. self.conv_3x3 = MobileViTV2ConvLayer(
  125. config,
  126. in_channels=expanded_channels,
  127. out_channels=expanded_channels,
  128. kernel_size=3,
  129. stride=stride,
  130. groups=expanded_channels,
  131. dilation=dilation,
  132. )
  133. self.reduce_1x1 = MobileViTV2ConvLayer(
  134. config,
  135. in_channels=expanded_channels,
  136. out_channels=out_channels,
  137. kernel_size=1,
  138. use_activation=False,
  139. )
  140. def forward(self, features: torch.Tensor) -> torch.Tensor:
  141. residual = features
  142. features = self.expand_1x1(features)
  143. features = self.conv_3x3(features)
  144. features = self.reduce_1x1(features)
  145. return residual + features if self.use_residual else features
  146. # Copied from transformers.models.mobilevit.modeling_mobilevit.MobileViTMobileNetLayer with MobileViT->MobileViTV2
  147. class MobileViTV2MobileNetLayer(nn.Module):
  148. def __init__(
  149. self, config: MobileViTV2Config, in_channels: int, out_channels: int, stride: int = 1, num_stages: int = 1
  150. ) -> None:
  151. super().__init__()
  152. self.layer = nn.ModuleList()
  153. for i in range(num_stages):
  154. layer = MobileViTV2InvertedResidual(
  155. config,
  156. in_channels=in_channels,
  157. out_channels=out_channels,
  158. stride=stride if i == 0 else 1,
  159. )
  160. self.layer.append(layer)
  161. in_channels = out_channels
  162. def forward(self, features: torch.Tensor) -> torch.Tensor:
  163. for layer_module in self.layer:
  164. features = layer_module(features)
  165. return features
  166. class MobileViTV2LinearSelfAttention(nn.Module):
  167. """
  168. This layer applies a self-attention with linear complexity, as described in MobileViTV2 paper:
  169. https://huggingface.co/papers/2206.02680
  170. Args:
  171. config (`MobileVitv2Config`):
  172. Model configuration object
  173. embed_dim (`int`):
  174. `input_channels` from an expected input of size :math:`(batch_size, input_channels, height, width)`
  175. """
  176. def __init__(self, config: MobileViTV2Config, embed_dim: int) -> None:
  177. super().__init__()
  178. self.qkv_proj = MobileViTV2ConvLayer(
  179. config=config,
  180. in_channels=embed_dim,
  181. out_channels=1 + (2 * embed_dim),
  182. bias=True,
  183. kernel_size=1,
  184. use_normalization=False,
  185. use_activation=False,
  186. )
  187. self.attn_dropout = nn.Dropout(p=config.attn_dropout)
  188. self.out_proj = MobileViTV2ConvLayer(
  189. config=config,
  190. in_channels=embed_dim,
  191. out_channels=embed_dim,
  192. bias=True,
  193. kernel_size=1,
  194. use_normalization=False,
  195. use_activation=False,
  196. )
  197. self.embed_dim = embed_dim
  198. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  199. # (batch_size, embed_dim, num_pixels_in_patch, num_patches) --> (batch_size, 1+2*embed_dim, num_pixels_in_patch, num_patches)
  200. qkv = self.qkv_proj(hidden_states)
  201. # Project hidden_states into query, key and value
  202. # Query --> [batch_size, 1, num_pixels_in_patch, num_patches]
  203. # value, key --> [batch_size, embed_dim, num_pixels_in_patch, num_patches]
  204. query, key, value = torch.split(qkv, split_size_or_sections=[1, self.embed_dim, self.embed_dim], dim=1)
  205. # apply softmax along num_patches dimension
  206. context_scores = torch.nn.functional.softmax(query, dim=-1)
  207. context_scores = self.attn_dropout(context_scores)
  208. # Compute context vector
  209. # [batch_size, embed_dim, num_pixels_in_patch, num_patches] x [batch_size, 1, num_pixels_in_patch, num_patches] -> [batch_size, embed_dim, num_pixels_in_patch, num_patches]
  210. context_vector = key * context_scores
  211. # [batch_size, embed_dim, num_pixels_in_patch, num_patches] --> [batch_size, embed_dim, num_pixels_in_patch, 1]
  212. context_vector = torch.sum(context_vector, dim=-1, keepdim=True)
  213. # combine context vector with values
  214. # [batch_size, embed_dim, num_pixels_in_patch, num_patches] * [batch_size, embed_dim, num_pixels_in_patch, 1] --> [batch_size, embed_dim, num_pixels_in_patch, num_patches]
  215. out = torch.nn.functional.relu(value) * context_vector.expand_as(value)
  216. out = self.out_proj(out)
  217. return out
  218. class MobileViTV2FFN(nn.Module):
  219. def __init__(
  220. self,
  221. config: MobileViTV2Config,
  222. embed_dim: int,
  223. ffn_latent_dim: int,
  224. ffn_dropout: float = 0.0,
  225. ) -> None:
  226. super().__init__()
  227. self.conv1 = MobileViTV2ConvLayer(
  228. config=config,
  229. in_channels=embed_dim,
  230. out_channels=ffn_latent_dim,
  231. kernel_size=1,
  232. stride=1,
  233. bias=True,
  234. use_normalization=False,
  235. use_activation=True,
  236. )
  237. self.dropout1 = nn.Dropout(ffn_dropout)
  238. self.conv2 = MobileViTV2ConvLayer(
  239. config=config,
  240. in_channels=ffn_latent_dim,
  241. out_channels=embed_dim,
  242. kernel_size=1,
  243. stride=1,
  244. bias=True,
  245. use_normalization=False,
  246. use_activation=False,
  247. )
  248. self.dropout2 = nn.Dropout(ffn_dropout)
  249. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  250. hidden_states = self.conv1(hidden_states)
  251. hidden_states = self.dropout1(hidden_states)
  252. hidden_states = self.conv2(hidden_states)
  253. hidden_states = self.dropout2(hidden_states)
  254. return hidden_states
  255. class MobileViTV2TransformerLayer(nn.Module):
  256. def __init__(
  257. self,
  258. config: MobileViTV2Config,
  259. embed_dim: int,
  260. ffn_latent_dim: int,
  261. dropout: float = 0.0,
  262. ) -> None:
  263. super().__init__()
  264. self.layernorm_before = nn.GroupNorm(num_groups=1, num_channels=embed_dim, eps=config.layer_norm_eps)
  265. self.attention = MobileViTV2LinearSelfAttention(config, embed_dim)
  266. self.dropout1 = nn.Dropout(p=dropout)
  267. self.layernorm_after = nn.GroupNorm(num_groups=1, num_channels=embed_dim, eps=config.layer_norm_eps)
  268. self.ffn = MobileViTV2FFN(config, embed_dim, ffn_latent_dim, config.ffn_dropout)
  269. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  270. layernorm_1_out = self.layernorm_before(hidden_states)
  271. attention_output = self.attention(layernorm_1_out)
  272. hidden_states = attention_output + hidden_states
  273. layer_output = self.layernorm_after(hidden_states)
  274. layer_output = self.ffn(layer_output)
  275. layer_output = layer_output + hidden_states
  276. return layer_output
  277. class MobileViTV2Transformer(nn.Module):
  278. def __init__(self, config: MobileViTV2Config, n_layers: int, d_model: int) -> None:
  279. super().__init__()
  280. ffn_multiplier = config.ffn_multiplier
  281. ffn_dims = [ffn_multiplier * d_model] * n_layers
  282. # ensure that dims are multiple of 16
  283. ffn_dims = [int((d // 16) * 16) for d in ffn_dims]
  284. self.layer = nn.ModuleList()
  285. for block_idx in range(n_layers):
  286. transformer_layer = MobileViTV2TransformerLayer(
  287. config, embed_dim=d_model, ffn_latent_dim=ffn_dims[block_idx]
  288. )
  289. self.layer.append(transformer_layer)
  290. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  291. for layer_module in self.layer:
  292. hidden_states = layer_module(hidden_states)
  293. return hidden_states
  294. class MobileViTV2Layer(GradientCheckpointingLayer):
  295. """
  296. MobileViTV2 layer: https://huggingface.co/papers/2206.02680
  297. """
  298. def __init__(
  299. self,
  300. config: MobileViTV2Config,
  301. in_channels: int,
  302. out_channels: int,
  303. attn_unit_dim: int,
  304. n_attn_blocks: int = 2,
  305. dilation: int = 1,
  306. stride: int = 2,
  307. ) -> None:
  308. super().__init__()
  309. self.patch_width = config.patch_size
  310. self.patch_height = config.patch_size
  311. cnn_out_dim = attn_unit_dim
  312. if stride == 2:
  313. self.downsampling_layer = MobileViTV2InvertedResidual(
  314. config,
  315. in_channels=in_channels,
  316. out_channels=out_channels,
  317. stride=stride if dilation == 1 else 1,
  318. dilation=dilation // 2 if dilation > 1 else 1,
  319. )
  320. in_channels = out_channels
  321. else:
  322. self.downsampling_layer = None
  323. # Local representations
  324. self.conv_kxk = MobileViTV2ConvLayer(
  325. config,
  326. in_channels=in_channels,
  327. out_channels=in_channels,
  328. kernel_size=config.conv_kernel_size,
  329. groups=in_channels,
  330. )
  331. self.conv_1x1 = MobileViTV2ConvLayer(
  332. config,
  333. in_channels=in_channels,
  334. out_channels=cnn_out_dim,
  335. kernel_size=1,
  336. use_normalization=False,
  337. use_activation=False,
  338. )
  339. # Global representations
  340. self.transformer = MobileViTV2Transformer(config, d_model=attn_unit_dim, n_layers=n_attn_blocks)
  341. # self.layernorm = MobileViTV2LayerNorm2D(attn_unit_dim, eps=config.layer_norm_eps)
  342. self.layernorm = nn.GroupNorm(num_groups=1, num_channels=attn_unit_dim, eps=config.layer_norm_eps)
  343. # Fusion
  344. self.conv_projection = MobileViTV2ConvLayer(
  345. config,
  346. in_channels=cnn_out_dim,
  347. out_channels=in_channels,
  348. kernel_size=1,
  349. use_normalization=True,
  350. use_activation=False,
  351. )
  352. def unfolding(self, feature_map: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
  353. batch_size, in_channels, img_height, img_width = feature_map.shape
  354. patches = nn.functional.unfold(
  355. feature_map,
  356. kernel_size=(self.patch_height, self.patch_width),
  357. stride=(self.patch_height, self.patch_width),
  358. )
  359. patches = patches.reshape(batch_size, in_channels, self.patch_height * self.patch_width, -1)
  360. return patches, (img_height, img_width)
  361. def folding(self, patches: torch.Tensor, output_size: tuple[int, int]) -> torch.Tensor:
  362. batch_size, in_dim, patch_size, n_patches = patches.shape
  363. patches = patches.reshape(batch_size, in_dim * patch_size, n_patches)
  364. feature_map = nn.functional.fold(
  365. patches,
  366. output_size=output_size,
  367. kernel_size=(self.patch_height, self.patch_width),
  368. stride=(self.patch_height, self.patch_width),
  369. )
  370. return feature_map
  371. def forward(self, features: torch.Tensor) -> torch.Tensor:
  372. # reduce spatial dimensions if needed
  373. if self.downsampling_layer:
  374. features = self.downsampling_layer(features)
  375. # local representation
  376. features = self.conv_kxk(features)
  377. features = self.conv_1x1(features)
  378. # convert feature map to patches
  379. patches, output_size = self.unfolding(features)
  380. # learn global representations
  381. patches = self.transformer(patches)
  382. patches = self.layernorm(patches)
  383. # convert patches back to feature maps
  384. # [batch_size, patch_height, patch_width, input_dim] --> [batch_size, input_dim, patch_height, patch_width]
  385. features = self.folding(patches, output_size)
  386. features = self.conv_projection(features)
  387. return features
  388. class MobileViTV2Encoder(nn.Module):
  389. def __init__(self, config: MobileViTV2Config) -> None:
  390. super().__init__()
  391. self.config = config
  392. self.layer = nn.ModuleList()
  393. self.gradient_checkpointing = False
  394. # segmentation architectures like DeepLab and PSPNet modify the strides
  395. # of the classification backbones
  396. dilate_layer_4 = dilate_layer_5 = False
  397. if config.output_stride == 8:
  398. dilate_layer_4 = True
  399. dilate_layer_5 = True
  400. elif config.output_stride == 16:
  401. dilate_layer_5 = True
  402. dilation = 1
  403. layer_0_dim = make_divisible(
  404. clip(value=32 * config.width_multiplier, min_val=16, max_val=64), divisor=8, min_value=16
  405. )
  406. layer_1_dim = make_divisible(64 * config.width_multiplier, divisor=16)
  407. layer_2_dim = make_divisible(128 * config.width_multiplier, divisor=8)
  408. layer_3_dim = make_divisible(256 * config.width_multiplier, divisor=8)
  409. layer_4_dim = make_divisible(384 * config.width_multiplier, divisor=8)
  410. layer_5_dim = make_divisible(512 * config.width_multiplier, divisor=8)
  411. layer_1 = MobileViTV2MobileNetLayer(
  412. config,
  413. in_channels=layer_0_dim,
  414. out_channels=layer_1_dim,
  415. stride=1,
  416. num_stages=1,
  417. )
  418. self.layer.append(layer_1)
  419. layer_2 = MobileViTV2MobileNetLayer(
  420. config,
  421. in_channels=layer_1_dim,
  422. out_channels=layer_2_dim,
  423. stride=2,
  424. num_stages=2,
  425. )
  426. self.layer.append(layer_2)
  427. layer_3 = MobileViTV2Layer(
  428. config,
  429. in_channels=layer_2_dim,
  430. out_channels=layer_3_dim,
  431. attn_unit_dim=make_divisible(config.base_attn_unit_dims[0] * config.width_multiplier, divisor=8),
  432. n_attn_blocks=config.n_attn_blocks[0],
  433. )
  434. self.layer.append(layer_3)
  435. if dilate_layer_4:
  436. dilation *= 2
  437. layer_4 = MobileViTV2Layer(
  438. config,
  439. in_channels=layer_3_dim,
  440. out_channels=layer_4_dim,
  441. attn_unit_dim=make_divisible(config.base_attn_unit_dims[1] * config.width_multiplier, divisor=8),
  442. n_attn_blocks=config.n_attn_blocks[1],
  443. dilation=dilation,
  444. )
  445. self.layer.append(layer_4)
  446. if dilate_layer_5:
  447. dilation *= 2
  448. layer_5 = MobileViTV2Layer(
  449. config,
  450. in_channels=layer_4_dim,
  451. out_channels=layer_5_dim,
  452. attn_unit_dim=make_divisible(config.base_attn_unit_dims[2] * config.width_multiplier, divisor=8),
  453. n_attn_blocks=config.n_attn_blocks[2],
  454. dilation=dilation,
  455. )
  456. self.layer.append(layer_5)
  457. def forward(
  458. self,
  459. hidden_states: torch.Tensor,
  460. output_hidden_states: bool = False,
  461. return_dict: bool = True,
  462. ) -> Union[tuple, BaseModelOutputWithNoAttention]:
  463. all_hidden_states = () if output_hidden_states else None
  464. for i, layer_module in enumerate(self.layer):
  465. hidden_states = layer_module(hidden_states)
  466. if output_hidden_states:
  467. all_hidden_states = all_hidden_states + (hidden_states,)
  468. if not return_dict:
  469. return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
  470. return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
  471. @auto_docstring
  472. class MobileViTV2PreTrainedModel(PreTrainedModel):
  473. config: MobileViTV2Config
  474. base_model_prefix = "mobilevitv2"
  475. main_input_name = "pixel_values"
  476. supports_gradient_checkpointing = True
  477. _no_split_modules = ["MobileViTV2Layer"]
  478. def _init_weights(self, module: nn.Module) -> None:
  479. """Initialize the weights"""
  480. if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
  481. # Slightly different from the TF version which uses truncated_normal for initialization
  482. # cf https://github.com/pytorch/pytorch/pull/5617
  483. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  484. if module.bias is not None:
  485. module.bias.data.zero_()
  486. elif isinstance(module, nn.GroupNorm):
  487. module.bias.data.zero_()
  488. module.weight.data.fill_(1.0)
  489. @auto_docstring
  490. class MobileViTV2Model(MobileViTV2PreTrainedModel):
  491. def __init__(self, config: MobileViTV2Config, expand_output: bool = True):
  492. r"""
  493. expand_output (`bool`, *optional*, defaults to `True`):
  494. Whether to expand the output of the model. If `True`, the model will output pooled features in addition to
  495. hidden states. If `False`, only the hidden states will be returned.
  496. """
  497. super().__init__(config)
  498. self.config = config
  499. self.expand_output = expand_output
  500. layer_0_dim = make_divisible(
  501. clip(value=32 * config.width_multiplier, min_val=16, max_val=64), divisor=8, min_value=16
  502. )
  503. self.conv_stem = MobileViTV2ConvLayer(
  504. config,
  505. in_channels=config.num_channels,
  506. out_channels=layer_0_dim,
  507. kernel_size=3,
  508. stride=2,
  509. use_normalization=True,
  510. use_activation=True,
  511. )
  512. self.encoder = MobileViTV2Encoder(config)
  513. # Initialize weights and apply final processing
  514. self.post_init()
  515. def _prune_heads(self, heads_to_prune):
  516. """Prunes heads of the model.
  517. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel
  518. """
  519. for layer_index, heads in heads_to_prune.items():
  520. mobilevitv2_layer = self.encoder.layer[layer_index]
  521. if isinstance(mobilevitv2_layer, MobileViTV2Layer):
  522. for transformer_layer in mobilevitv2_layer.transformer.layer:
  523. transformer_layer.attention.prune_heads(heads)
  524. @auto_docstring
  525. def forward(
  526. self,
  527. pixel_values: Optional[torch.Tensor] = None,
  528. output_hidden_states: Optional[bool] = None,
  529. return_dict: Optional[bool] = None,
  530. ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
  531. output_hidden_states = (
  532. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  533. )
  534. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  535. if pixel_values is None:
  536. raise ValueError("You have to specify pixel_values")
  537. embedding_output = self.conv_stem(pixel_values)
  538. encoder_outputs = self.encoder(
  539. embedding_output,
  540. output_hidden_states=output_hidden_states,
  541. return_dict=return_dict,
  542. )
  543. if self.expand_output:
  544. last_hidden_state = encoder_outputs[0]
  545. # global average pooling: (batch_size, channels, height, width) -> (batch_size, channels)
  546. pooled_output = torch.mean(last_hidden_state, dim=[-2, -1], keepdim=False)
  547. else:
  548. last_hidden_state = encoder_outputs[0]
  549. pooled_output = None
  550. if not return_dict:
  551. output = (last_hidden_state, pooled_output) if pooled_output is not None else (last_hidden_state,)
  552. return output + encoder_outputs[1:]
  553. return BaseModelOutputWithPoolingAndNoAttention(
  554. last_hidden_state=last_hidden_state,
  555. pooler_output=pooled_output,
  556. hidden_states=encoder_outputs.hidden_states,
  557. )
  558. @auto_docstring(
  559. custom_intro="""
  560. MobileViTV2 model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
  561. ImageNet.
  562. """
  563. )
  564. class MobileViTV2ForImageClassification(MobileViTV2PreTrainedModel):
  565. def __init__(self, config: MobileViTV2Config) -> None:
  566. super().__init__(config)
  567. self.num_labels = config.num_labels
  568. self.mobilevitv2 = MobileViTV2Model(config)
  569. out_channels = make_divisible(512 * config.width_multiplier, divisor=8) # layer 5 output dimension
  570. # Classifier head
  571. self.classifier = (
  572. nn.Linear(in_features=out_channels, out_features=config.num_labels)
  573. if config.num_labels > 0
  574. else nn.Identity()
  575. )
  576. # Initialize weights and apply final processing
  577. self.post_init()
  578. @auto_docstring
  579. def forward(
  580. self,
  581. pixel_values: Optional[torch.Tensor] = None,
  582. output_hidden_states: Optional[bool] = None,
  583. labels: Optional[torch.Tensor] = None,
  584. return_dict: Optional[bool] = None,
  585. ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
  586. r"""
  587. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  588. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  589. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
  590. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  591. """
  592. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  593. outputs = self.mobilevitv2(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
  594. pooled_output = outputs.pooler_output if return_dict else outputs[1]
  595. logits = self.classifier(pooled_output)
  596. loss = None
  597. if labels is not None:
  598. loss = self.loss_function(labels, logits, self.config)
  599. if not return_dict:
  600. output = (logits,) + outputs[2:]
  601. return ((loss,) + output) if loss is not None else output
  602. return ImageClassifierOutputWithNoAttention(
  603. loss=loss,
  604. logits=logits,
  605. hidden_states=outputs.hidden_states,
  606. )
  607. # Copied from transformers.models.mobilevit.modeling_mobilevit.MobileViTASPPPooling with MobileViT->MobileViTV2
  608. class MobileViTV2ASPPPooling(nn.Module):
  609. def __init__(self, config: MobileViTV2Config, in_channels: int, out_channels: int) -> None:
  610. super().__init__()
  611. self.global_pool = nn.AdaptiveAvgPool2d(output_size=1)
  612. self.conv_1x1 = MobileViTV2ConvLayer(
  613. config,
  614. in_channels=in_channels,
  615. out_channels=out_channels,
  616. kernel_size=1,
  617. stride=1,
  618. use_normalization=True,
  619. use_activation="relu",
  620. )
  621. def forward(self, features: torch.Tensor) -> torch.Tensor:
  622. spatial_size = features.shape[-2:]
  623. features = self.global_pool(features)
  624. features = self.conv_1x1(features)
  625. features = nn.functional.interpolate(features, size=spatial_size, mode="bilinear", align_corners=False)
  626. return features
  627. class MobileViTV2ASPP(nn.Module):
  628. """
  629. ASPP module defined in DeepLab papers: https://huggingface.co/papers/1606.00915, https://huggingface.co/papers/1706.05587
  630. """
  631. def __init__(self, config: MobileViTV2Config) -> None:
  632. super().__init__()
  633. encoder_out_channels = make_divisible(512 * config.width_multiplier, divisor=8) # layer 5 output dimension
  634. in_channels = encoder_out_channels
  635. out_channels = config.aspp_out_channels
  636. if len(config.atrous_rates) != 3:
  637. raise ValueError("Expected 3 values for atrous_rates")
  638. self.convs = nn.ModuleList()
  639. in_projection = MobileViTV2ConvLayer(
  640. config,
  641. in_channels=in_channels,
  642. out_channels=out_channels,
  643. kernel_size=1,
  644. use_activation="relu",
  645. )
  646. self.convs.append(in_projection)
  647. self.convs.extend(
  648. [
  649. MobileViTV2ConvLayer(
  650. config,
  651. in_channels=in_channels,
  652. out_channels=out_channels,
  653. kernel_size=3,
  654. dilation=rate,
  655. use_activation="relu",
  656. )
  657. for rate in config.atrous_rates
  658. ]
  659. )
  660. pool_layer = MobileViTV2ASPPPooling(config, in_channels, out_channels)
  661. self.convs.append(pool_layer)
  662. self.project = MobileViTV2ConvLayer(
  663. config, in_channels=5 * out_channels, out_channels=out_channels, kernel_size=1, use_activation="relu"
  664. )
  665. self.dropout = nn.Dropout(p=config.aspp_dropout_prob)
  666. def forward(self, features: torch.Tensor) -> torch.Tensor:
  667. pyramid = []
  668. for conv in self.convs:
  669. pyramid.append(conv(features))
  670. pyramid = torch.cat(pyramid, dim=1)
  671. pooled_features = self.project(pyramid)
  672. pooled_features = self.dropout(pooled_features)
  673. return pooled_features
  674. # Copied from transformers.models.mobilevit.modeling_mobilevit.MobileViTDeepLabV3 with MobileViT->MobileViTV2
  675. class MobileViTV2DeepLabV3(nn.Module):
  676. """
  677. DeepLabv3 architecture: https://huggingface.co/papers/1706.05587
  678. """
  679. def __init__(self, config: MobileViTV2Config) -> None:
  680. super().__init__()
  681. self.aspp = MobileViTV2ASPP(config)
  682. self.dropout = nn.Dropout2d(config.classifier_dropout_prob)
  683. self.classifier = MobileViTV2ConvLayer(
  684. config,
  685. in_channels=config.aspp_out_channels,
  686. out_channels=config.num_labels,
  687. kernel_size=1,
  688. use_normalization=False,
  689. use_activation=False,
  690. bias=True,
  691. )
  692. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  693. features = self.aspp(hidden_states[-1])
  694. features = self.dropout(features)
  695. features = self.classifier(features)
  696. return features
  697. @auto_docstring(
  698. custom_intro="""
  699. MobileViTV2 model with a semantic segmentation head on top, e.g. for Pascal VOC.
  700. """
  701. )
  702. class MobileViTV2ForSemanticSegmentation(MobileViTV2PreTrainedModel):
  703. def __init__(self, config: MobileViTV2Config) -> None:
  704. super().__init__(config)
  705. self.num_labels = config.num_labels
  706. self.mobilevitv2 = MobileViTV2Model(config, expand_output=False)
  707. self.segmentation_head = MobileViTV2DeepLabV3(config)
  708. # Initialize weights and apply final processing
  709. self.post_init()
  710. @auto_docstring
  711. def forward(
  712. self,
  713. pixel_values: Optional[torch.Tensor] = None,
  714. labels: Optional[torch.Tensor] = None,
  715. output_hidden_states: Optional[bool] = None,
  716. return_dict: Optional[bool] = None,
  717. ) -> Union[tuple, SemanticSegmenterOutput]:
  718. r"""
  719. labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
  720. Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
  721. config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
  722. Examples:
  723. ```python
  724. >>> import requests
  725. >>> import torch
  726. >>> from PIL import Image
  727. >>> from transformers import AutoImageProcessor, MobileViTV2ForSemanticSegmentation
  728. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  729. >>> image = Image.open(requests.get(url, stream=True).raw)
  730. >>> image_processor = AutoImageProcessor.from_pretrained("apple/mobilevitv2-1.0-imagenet1k-256")
  731. >>> model = MobileViTV2ForSemanticSegmentation.from_pretrained("apple/mobilevitv2-1.0-imagenet1k-256")
  732. >>> inputs = image_processor(images=image, return_tensors="pt")
  733. >>> with torch.no_grad():
  734. ... outputs = model(**inputs)
  735. >>> # logits are of shape (batch_size, num_labels, height, width)
  736. >>> logits = outputs.logits
  737. ```"""
  738. output_hidden_states = (
  739. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  740. )
  741. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  742. if labels is not None and self.config.num_labels == 1:
  743. raise ValueError("The number of labels should be greater than one")
  744. outputs = self.mobilevitv2(
  745. pixel_values,
  746. output_hidden_states=True, # we need the intermediate hidden states
  747. return_dict=return_dict,
  748. )
  749. encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
  750. logits = self.segmentation_head(encoder_hidden_states)
  751. loss = None
  752. if labels is not None:
  753. # upsample logits to the images' original size
  754. upsampled_logits = nn.functional.interpolate(
  755. logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
  756. )
  757. loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
  758. loss = loss_fct(upsampled_logits, labels)
  759. if not return_dict:
  760. if output_hidden_states:
  761. output = (logits,) + outputs[1:]
  762. else:
  763. output = (logits,) + outputs[2:]
  764. return ((loss,) + output) if loss is not None else output
  765. return SemanticSegmenterOutput(
  766. loss=loss,
  767. logits=logits,
  768. hidden_states=outputs.hidden_states if output_hidden_states else None,
  769. attentions=None,
  770. )
# Names exported by `from <this module> import *`.
__all__ = [
    "MobileViTV2ForImageClassification",
    "MobileViTV2ForSemanticSegmentation",
    "MobileViTV2Model",
    "MobileViTV2PreTrainedModel",
]