modeling_focalnet.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937
  1. # coding=utf-8
  2. # Copyright 2023 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch FocalNet model."""
  16. import collections.abc
  17. import math
  18. from dataclasses import dataclass
  19. from typing import Optional, Union
  20. import torch
  21. from torch import nn
  22. from ...activations import ACT2FN
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import BackboneOutput
  25. from ...modeling_utils import PreTrainedModel
  26. from ...utils import ModelOutput, auto_docstring, logging
  27. from ...utils.backbone_utils import BackboneMixin
  28. from .configuration_focalnet import FocalNetConfig
# Module-level logger obtained through the library's `logging` utilities, named after this module.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
FocalNet encoder's outputs, with potential hidden states.
"""
)
class FocalNetEncoderOutput(ModelOutput):
    r"""
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each stage plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # Sequence output of the last encoder stage, shaped (batch_size, sequence_length, hidden_size).
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Embedding output followed by one sequence output per stage; only filled when hidden states are requested.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    # The same tensors as `hidden_states`, rearranged into channels-first spatial maps.
    reshaped_hidden_states: Optional[tuple[torch.FloatTensor]] = None
@dataclass
@auto_docstring(
    custom_intro="""
FocalNet model's outputs that also contains a pooling of the last hidden states.
"""
)
class FocalNetModelOutput(ModelOutput):
    r"""
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
        Average pooling of the last layer hidden-state.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each stage plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # Layer-normalized sequence output of the final encoder stage.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Token-averaged `last_hidden_state`; None when the model has no pooling layer.
    pooler_output: Optional[torch.FloatTensor] = None
    # Embedding output followed by one sequence output per stage; only filled when hidden states are requested.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    # The same tensors as `hidden_states`, rearranged into channels-first spatial maps.
    reshaped_hidden_states: Optional[tuple[torch.FloatTensor]] = None
@dataclass
@auto_docstring(
    custom_intro="""
FocalNet masked image model outputs.
"""
)
class FocalNetMaskedImageModelingOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
        Masked image modeling (MIM) loss.
    reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
        Reconstructed pixel values.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each stage plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # L1 reconstruction loss over masked pixels; only computed when `bool_masked_pos` is given.
    loss: Optional[torch.FloatTensor] = None
    # Decoder output: the reconstructed image.
    reconstruction: Optional[torch.FloatTensor] = None
    # Embedding output followed by one sequence output per stage; only filled when hidden states are requested.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    # The same tensors as `hidden_states`, rearranged into channels-first spatial maps.
    reshaped_hidden_states: Optional[tuple[torch.FloatTensor]] = None
@dataclass
@auto_docstring(
    custom_intro="""
FocalNet outputs for image classification.
"""
)
class FocalNetImageClassifierOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification (or regression if config.num_labels==1) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Classification (or regression if config.num_labels==1) scores (before SoftMax).
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each stage plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # Training loss; only present when labels were supplied.
    loss: Optional[torch.FloatTensor] = None
    # Per-class scores before any softmax.
    logits: Optional[torch.FloatTensor] = None
    # Embedding output followed by one sequence output per stage; only filled when hidden states are requested.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    # The same tensors as `hidden_states`, rearranged into channels-first spatial maps.
    reshaped_hidden_states: Optional[tuple[torch.FloatTensor]] = None
  111. class FocalNetEmbeddings(nn.Module):
  112. """
  113. Construct the patch embeddings and layernorm. Optionally, also the mask token.
  114. """
  115. def __init__(self, config, use_mask_token=False):
  116. super().__init__()
  117. self.patch_embeddings = FocalNetPatchEmbeddings(
  118. config=config,
  119. image_size=config.image_size,
  120. patch_size=config.patch_size,
  121. num_channels=config.num_channels,
  122. embed_dim=config.embed_dim,
  123. use_conv_embed=config.use_conv_embed,
  124. is_stem=True,
  125. )
  126. self.patch_grid = self.patch_embeddings.grid_size
  127. self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None
  128. self.norm = nn.LayerNorm(config.embed_dim, eps=config.layer_norm_eps)
  129. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  130. def forward(
  131. self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
  132. ) -> tuple[torch.Tensor]:
  133. embeddings, output_dimensions = self.patch_embeddings(pixel_values)
  134. embeddings = self.norm(embeddings)
  135. batch_size, seq_len, _ = embeddings.size()
  136. if bool_masked_pos is not None:
  137. mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
  138. # replace the masked visual tokens by mask_tokens
  139. mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
  140. embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
  141. embeddings = self.dropout(embeddings)
  142. return embeddings, output_dimensions
  143. class FocalNetPatchEmbeddings(nn.Module):
  144. def __init__(
  145. self,
  146. config,
  147. image_size,
  148. patch_size,
  149. num_channels,
  150. embed_dim,
  151. add_norm=False,
  152. use_conv_embed=False,
  153. is_stem=False,
  154. ):
  155. super().__init__()
  156. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  157. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  158. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  159. self.image_size = image_size
  160. self.patch_size = patch_size
  161. self.num_channels = num_channels
  162. self.num_patches = num_patches
  163. self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
  164. if use_conv_embed:
  165. # if we choose to use conv embedding, then we treat the stem and non-stem differently
  166. if is_stem:
  167. kernel_size = 7
  168. padding = 2
  169. stride = 4
  170. else:
  171. kernel_size = 3
  172. padding = 1
  173. stride = 2
  174. self.projection = nn.Conv2d(
  175. num_channels, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
  176. )
  177. else:
  178. self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
  179. if add_norm:
  180. self.norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  181. else:
  182. self.norm = None
  183. def maybe_pad(self, pixel_values, height, width):
  184. if width % self.patch_size[1] != 0:
  185. pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
  186. pixel_values = nn.functional.pad(pixel_values, pad_values)
  187. if height % self.patch_size[0] != 0:
  188. pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
  189. pixel_values = nn.functional.pad(pixel_values, pad_values)
  190. return pixel_values
  191. def forward(self, pixel_values: Optional[torch.FloatTensor]) -> tuple[torch.Tensor, tuple[int]]:
  192. _, num_channels, height, width = pixel_values.shape
  193. if num_channels != self.num_channels:
  194. raise ValueError(
  195. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  196. )
  197. # pad the input to be divisible by self.patch_size, if needed
  198. pixel_values = self.maybe_pad(pixel_values, height, width)
  199. embeddings = self.projection(pixel_values)
  200. _, _, height, width = embeddings.shape
  201. output_dimensions = (height, width)
  202. embeddings = embeddings.flatten(2).transpose(1, 2)
  203. if self.norm is not None:
  204. embeddings = self.norm(embeddings)
  205. return embeddings, output_dimensions
  206. # Copied from transformers.models.beit.modeling_beit.drop_path
  207. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  208. """
  209. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  210. Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
  211. however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  212. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
  213. layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
  214. argument.
  215. """
  216. if drop_prob == 0.0 or not training:
  217. return input
  218. keep_prob = 1 - drop_prob
  219. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  220. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  221. random_tensor.floor_() # binarize
  222. output = input.div(keep_prob) * random_tensor
  223. return output
  224. # Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->FocalNet
  225. class FocalNetDropPath(nn.Module):
  226. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  227. def __init__(self, drop_prob: Optional[float] = None) -> None:
  228. super().__init__()
  229. self.drop_prob = drop_prob
  230. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  231. return drop_path(hidden_states, self.drop_prob, self.training)
  232. def extra_repr(self) -> str:
  233. return f"p={self.drop_prob}"
  234. class FocalNetModulation(nn.Module):
  235. def __init__(self, config, index, dim, focal_factor=2, bias=True, projection_dropout=0.0):
  236. super().__init__()
  237. self.dim = dim
  238. self.focal_window = config.focal_windows[index]
  239. self.focal_level = config.focal_levels[index]
  240. self.focal_factor = focal_factor
  241. self.use_post_layernorm_in_modulation = config.use_post_layernorm_in_modulation
  242. self.normalize_modulator = config.normalize_modulator
  243. self.projection_in = nn.Linear(dim, 2 * dim + (self.focal_level + 1), bias=bias)
  244. self.projection_context = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=bias)
  245. self.activation = nn.GELU()
  246. self.projection_out = nn.Linear(dim, dim)
  247. self.projection_dropout = nn.Dropout(projection_dropout)
  248. self.focal_layers = nn.ModuleList()
  249. self.kernel_sizes = []
  250. for k in range(self.focal_level):
  251. kernel_size = self.focal_factor * k + self.focal_window
  252. self.focal_layers.append(
  253. nn.Sequential(
  254. nn.Conv2d(
  255. dim, dim, kernel_size=kernel_size, stride=1, groups=dim, padding=kernel_size // 2, bias=False
  256. ),
  257. nn.GELU(),
  258. )
  259. )
  260. self.kernel_sizes.append(kernel_size)
  261. if self.use_post_layernorm_in_modulation:
  262. self.layernorm = nn.LayerNorm(dim, eps=config.layer_norm_eps)
  263. def forward(self, hidden_state):
  264. """
  265. Args:
  266. hidden_state:
  267. Input features with shape of (batch_size, height, width, num_channels)
  268. """
  269. num_channels = hidden_state.shape[-1]
  270. # pre linear projection
  271. x = self.projection_in(hidden_state).permute(0, 3, 1, 2).contiguous()
  272. q, ctx, gates = torch.split(x, (num_channels, num_channels, self.focal_level + 1), 1)
  273. # context aggregation
  274. ctx_all = 0
  275. for level in range(self.focal_level):
  276. ctx = self.focal_layers[level](ctx)
  277. ctx_all = ctx_all + ctx * gates[:, level : level + 1]
  278. ctx_global = self.activation(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
  279. ctx_all = ctx_all + ctx_global * gates[:, self.focal_level :]
  280. # normalize context
  281. if self.normalize_modulator:
  282. ctx_all = ctx_all / (self.focal_level + 1)
  283. # focal modulation
  284. modulator = self.projection_context(ctx_all)
  285. x_out = q * modulator
  286. x_out = x_out.permute(0, 2, 3, 1).contiguous()
  287. if self.use_post_layernorm_in_modulation:
  288. x_out = self.layernorm(x_out)
  289. # post linear projection
  290. x_out = self.projection_out(x_out)
  291. x_out = self.projection_dropout(x_out)
  292. return x_out
  293. class FocalNetMlp(nn.Module):
  294. def __init__(self, config, in_features, hidden_features=None, out_features=None, drop=0.0):
  295. super().__init__()
  296. out_features = out_features or in_features
  297. hidden_features = hidden_features or in_features
  298. self.fc1 = nn.Linear(in_features, hidden_features)
  299. self.activation = ACT2FN[config.hidden_act]
  300. self.fc2 = nn.Linear(hidden_features, out_features)
  301. self.drop = nn.Dropout(drop)
  302. def forward(self, hidden_state):
  303. hidden_state = self.fc1(hidden_state)
  304. hidden_state = self.activation(hidden_state)
  305. hidden_state = self.drop(hidden_state)
  306. hidden_state = self.fc2(hidden_state)
  307. hidden_state = self.drop(hidden_state)
  308. return hidden_state
  309. class FocalNetLayer(nn.Module):
  310. r"""Focal Modulation Network layer (block).
  311. Args:
  312. config (`FocalNetConfig`):
  313. Model config.
  314. index (`int`):
  315. Layer index.
  316. dim (`int`):
  317. Number of input channels.
  318. input_resolution (`tuple[int]`):
  319. Input resolution.
  320. drop_path (`float`, *optional*, defaults to 0.0):
  321. Stochastic depth rate.
  322. """
  323. def __init__(self, config, index, dim, input_resolution, drop_path=0.0):
  324. super().__init__()
  325. self.config = config
  326. # layer-specific attributes
  327. self.dim = dim
  328. self.input_resolution = input_resolution
  329. # general attributes
  330. self.drop = config.hidden_dropout_prob
  331. self.use_post_layernorm = config.use_post_layernorm
  332. self.norm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
  333. self.modulation = FocalNetModulation(
  334. config=config,
  335. index=index,
  336. dim=dim,
  337. projection_dropout=self.drop,
  338. )
  339. self.drop_path = FocalNetDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  340. self.norm2 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
  341. mlp_hidden_dim = int(dim * config.mlp_ratio)
  342. self.mlp = FocalNetMlp(config=config, in_features=dim, hidden_features=mlp_hidden_dim, drop=self.drop)
  343. self.gamma_1 = 1.0
  344. self.gamma_2 = 1.0
  345. if config.use_layerscale:
  346. self.gamma_1 = nn.Parameter(config.layerscale_value * torch.ones(dim), requires_grad=True)
  347. self.gamma_2 = nn.Parameter(config.layerscale_value * torch.ones(dim), requires_grad=True)
  348. def forward(self, hidden_state, input_dimensions):
  349. height, width = input_dimensions
  350. batch_size, _, num_channels = hidden_state.shape
  351. shortcut = hidden_state
  352. # Focal Modulation
  353. hidden_state = hidden_state if self.use_post_layernorm else self.norm1(hidden_state)
  354. hidden_state = hidden_state.view(batch_size, height, width, num_channels)
  355. hidden_state = self.modulation(hidden_state).view(batch_size, height * width, num_channels)
  356. hidden_state = hidden_state if not self.use_post_layernorm else self.norm1(hidden_state)
  357. # FFN
  358. hidden_state = shortcut + self.drop_path(self.gamma_1 * hidden_state)
  359. hidden_state = hidden_state + self.drop_path(
  360. self.gamma_2
  361. * (self.norm2(self.mlp(hidden_state)) if self.use_post_layernorm else self.mlp(self.norm2(hidden_state)))
  362. )
  363. return hidden_state
  364. class FocalNetStage(GradientCheckpointingLayer):
  365. def __init__(self, config, index, input_resolution):
  366. super().__init__()
  367. self.config = config
  368. self.num_stages = len(config.depths)
  369. embed_dim = [config.embed_dim * (2**i) for i in range(self.num_stages)]
  370. dim = embed_dim[index]
  371. out_dim = embed_dim[index + 1] if (index < self.num_stages - 1) else None
  372. downsample = FocalNetPatchEmbeddings if (index < self.num_stages - 1) else None
  373. # stochastic depth decay rule
  374. dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
  375. drop_path = dpr[sum(config.depths[:index]) : sum(config.depths[: index + 1])]
  376. self.layers = nn.ModuleList(
  377. [
  378. FocalNetLayer(
  379. config=config,
  380. index=index,
  381. dim=dim,
  382. input_resolution=input_resolution,
  383. drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
  384. )
  385. for i in range(config.depths[index])
  386. ]
  387. )
  388. if downsample is not None:
  389. self.downsample = downsample(
  390. config=config,
  391. image_size=input_resolution,
  392. patch_size=2,
  393. num_channels=dim,
  394. embed_dim=out_dim,
  395. add_norm=True,
  396. use_conv_embed=config.use_conv_embed,
  397. is_stem=False,
  398. )
  399. else:
  400. self.downsample = None
  401. self.pointing = False
  402. def forward(self, hidden_states: torch.Tensor, input_dimensions: tuple[int, int]) -> tuple[torch.Tensor]:
  403. height, width = input_dimensions
  404. for layer_module in self.layers:
  405. hidden_states = layer_module(hidden_states, input_dimensions)
  406. hidden_states_before_downsampling = hidden_states
  407. if self.downsample is not None:
  408. height, width = input_dimensions
  409. hidden_states = hidden_states.transpose(1, 2).reshape(
  410. hidden_states_before_downsampling.shape[0], -1, height, width
  411. )
  412. hidden_states, output_dimensions = self.downsample(hidden_states)
  413. else:
  414. output_dimensions = (height, width, height, width)
  415. stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)
  416. return stage_outputs
  417. class FocalNetEncoder(nn.Module):
  418. def __init__(self, config, grid_size):
  419. super().__init__()
  420. self.num_stages = len(config.depths)
  421. self.config = config
  422. self.stages = nn.ModuleList(
  423. [
  424. FocalNetStage(
  425. config=config,
  426. index=i_layer,
  427. input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
  428. )
  429. for i_layer in range(self.num_stages)
  430. ]
  431. )
  432. self.gradient_checkpointing = False
  433. def forward(
  434. self,
  435. hidden_states: torch.Tensor,
  436. input_dimensions: tuple[int, int],
  437. output_hidden_states: Optional[bool] = False,
  438. output_hidden_states_before_downsampling: Optional[bool] = False,
  439. return_dict: Optional[bool] = True,
  440. ) -> Union[tuple, FocalNetEncoderOutput]:
  441. all_hidden_states = () if output_hidden_states else None
  442. all_reshaped_hidden_states = () if output_hidden_states else None
  443. if output_hidden_states:
  444. batch_size, _, hidden_size = hidden_states.shape
  445. # rearrange b (h w) c -> b c h w
  446. reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
  447. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  448. all_hidden_states += (hidden_states,)
  449. all_reshaped_hidden_states += (reshaped_hidden_state,)
  450. for i, stage_module in enumerate(self.stages):
  451. stage_outputs = stage_module(hidden_states, input_dimensions)
  452. hidden_states = stage_outputs[0]
  453. hidden_states_before_downsampling = stage_outputs[1]
  454. output_dimensions = stage_outputs[2]
  455. input_dimensions = (output_dimensions[-2], output_dimensions[-1])
  456. if output_hidden_states and output_hidden_states_before_downsampling:
  457. batch_size, _, hidden_size = hidden_states_before_downsampling.shape
  458. # rearrange b (h w) c -> b c h w
  459. # here we use the original (not downsampled) height and width
  460. reshaped_hidden_state = hidden_states_before_downsampling.view(
  461. batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
  462. )
  463. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  464. all_hidden_states += (hidden_states_before_downsampling,)
  465. all_reshaped_hidden_states += (reshaped_hidden_state,)
  466. elif output_hidden_states and not output_hidden_states_before_downsampling:
  467. batch_size, _, hidden_size = hidden_states.shape
  468. # rearrange b (h w) c -> b c h w
  469. reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
  470. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  471. all_hidden_states += (hidden_states,)
  472. all_reshaped_hidden_states += (reshaped_hidden_state,)
  473. if not return_dict:
  474. return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
  475. return FocalNetEncoderOutput(
  476. last_hidden_state=hidden_states,
  477. hidden_states=all_hidden_states,
  478. reshaped_hidden_states=all_reshaped_hidden_states,
  479. )
  480. @auto_docstring
  481. class FocalNetPreTrainedModel(PreTrainedModel):
  482. config: FocalNetConfig
  483. base_model_prefix = "focalnet"
  484. main_input_name = "pixel_values"
  485. supports_gradient_checkpointing = True
  486. _no_split_modules = ["FocalNetStage"]
  487. def _init_weights(self, module):
  488. """Initialize the weights"""
  489. if isinstance(module, (nn.Linear, nn.Conv2d)):
  490. # Slightly different from the TF version which uses truncated_normal for initialization
  491. # cf https://github.com/pytorch/pytorch/pull/5617
  492. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  493. if module.bias is not None:
  494. module.bias.data.zero_()
  495. elif isinstance(module, nn.LayerNorm):
  496. module.bias.data.zero_()
  497. module.weight.data.fill_(1.0)
  498. elif isinstance(module, FocalNetEmbeddings):
  499. if module.mask_token is not None:
  500. module.mask_token.data.zero_()
  501. elif isinstance(module, FocalNetLayer):
  502. if self.config.use_layerscale:
  503. module.gamma_1.data.fill_(self.config.layerscale_value)
  504. module.gamma_2.data.fill_(self.config.layerscale_value)
  505. @auto_docstring
  506. class FocalNetModel(FocalNetPreTrainedModel):
  507. def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
  508. r"""
  509. add_pooling_layer (bool, *optional*, defaults to `True`):
  510. Whether to add a pooling layer
  511. use_mask_token (`bool`, *optional*, defaults to `False`):
  512. Whether to use a mask token for masked image modeling.
  513. """
  514. super().__init__(config)
  515. self.config = config
  516. self.num_stages = len(config.depths)
  517. self.num_features = int(config.embed_dim * 2 ** (self.num_stages - 1))
  518. self.embeddings = FocalNetEmbeddings(config, use_mask_token=use_mask_token)
  519. self.encoder = FocalNetEncoder(config, self.embeddings.patch_grid)
  520. self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
  521. self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
  522. # Initialize weights and apply final processing
  523. self.post_init()
  524. def get_input_embeddings(self):
  525. return self.embeddings.patch_embeddings
  526. @auto_docstring
  527. def forward(
  528. self,
  529. pixel_values: Optional[torch.FloatTensor] = None,
  530. bool_masked_pos: Optional[torch.BoolTensor] = None,
  531. output_hidden_states: Optional[bool] = None,
  532. return_dict: Optional[bool] = None,
  533. ) -> Union[tuple, FocalNetModelOutput]:
  534. r"""
  535. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
  536. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  537. """
  538. output_hidden_states = (
  539. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  540. )
  541. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  542. if pixel_values is None:
  543. raise ValueError("You have to specify pixel_values")
  544. embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
  545. encoder_outputs = self.encoder(
  546. embedding_output,
  547. input_dimensions,
  548. output_hidden_states=output_hidden_states,
  549. return_dict=return_dict,
  550. )
  551. sequence_output = encoder_outputs[0]
  552. sequence_output = self.layernorm(sequence_output)
  553. pooled_output = None
  554. if self.pooler is not None:
  555. pooled_output = self.pooler(sequence_output.transpose(1, 2))
  556. pooled_output = torch.flatten(pooled_output, 1)
  557. if not return_dict:
  558. output = (sequence_output, pooled_output) + encoder_outputs[1:]
  559. return output
  560. return FocalNetModelOutput(
  561. last_hidden_state=sequence_output,
  562. pooler_output=pooled_output,
  563. hidden_states=encoder_outputs.hidden_states,
  564. reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
  565. )
  566. @auto_docstring(
  567. custom_intro="""
  568. FocalNet Model with a decoder on top for masked image modeling.
  569. This follows the same implementation as in [SimMIM](https://huggingface.co/papers/2111.09886).
  570. <Tip>
  571. Note that we provide a script to pre-train this model on custom data in our [examples
  572. directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
  573. </Tip>
  574. """
  575. )
  576. class FocalNetForMaskedImageModeling(FocalNetPreTrainedModel):
  577. def __init__(self, config):
  578. super().__init__(config)
  579. self.focalnet = FocalNetModel(config, add_pooling_layer=False, use_mask_token=True)
  580. self.num_stages = len(config.depths)
  581. num_features = int(config.embed_dim * 2 ** (self.num_stages - 1))
  582. self.decoder = nn.Sequential(
  583. nn.Conv2d(
  584. in_channels=num_features, out_channels=config.encoder_stride**2 * config.num_channels, kernel_size=1
  585. ),
  586. nn.PixelShuffle(config.encoder_stride),
  587. )
  588. # Initialize weights and apply final processing
  589. self.post_init()
  590. @auto_docstring
  591. def forward(
  592. self,
  593. pixel_values: Optional[torch.FloatTensor] = None,
  594. bool_masked_pos: Optional[torch.BoolTensor] = None,
  595. output_hidden_states: Optional[bool] = None,
  596. return_dict: Optional[bool] = None,
  597. ) -> Union[tuple, FocalNetMaskedImageModelingOutput]:
  598. r"""
  599. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
  600. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  601. Examples:
  602. ```python
  603. >>> from transformers import AutoImageProcessor, FocalNetConfig, FocalNetForMaskedImageModeling
  604. >>> import torch
  605. >>> from PIL import Image
  606. >>> import requests
  607. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  608. >>> image = Image.open(requests.get(url, stream=True).raw)
  609. >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-base-simmim-window6-192")
  610. >>> config = FocalNetConfig()
  611. >>> model = FocalNetForMaskedImageModeling(config)
  612. >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
  613. >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
  614. >>> # create random boolean mask of shape (batch_size, num_patches)
  615. >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
  616. >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
  617. >>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits
  618. >>> list(reconstructed_pixel_values.shape)
  619. [1, 3, 192, 192]
  620. ```"""
  621. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  622. outputs = self.focalnet(
  623. pixel_values,
  624. bool_masked_pos=bool_masked_pos,
  625. output_hidden_states=output_hidden_states,
  626. return_dict=return_dict,
  627. )
  628. sequence_output = outputs[0]
  629. # Reshape to (batch_size, num_channels, height, width)
  630. sequence_output = sequence_output.transpose(1, 2)
  631. batch_size, num_channels, sequence_length = sequence_output.shape
  632. height = width = math.floor(sequence_length**0.5)
  633. sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)
  634. # Reconstruct pixel values
  635. reconstructed_pixel_values = self.decoder(sequence_output)
  636. masked_im_loss = None
  637. if bool_masked_pos is not None:
  638. size = self.config.image_size // self.config.patch_size
  639. bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
  640. mask = (
  641. bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
  642. .repeat_interleave(self.config.patch_size, 2)
  643. .unsqueeze(1)
  644. .contiguous()
  645. )
  646. reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
  647. masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels
  648. if not return_dict:
  649. output = (reconstructed_pixel_values,) + outputs[2:]
  650. return ((masked_im_loss,) + output) if masked_im_loss is not None else output
  651. return FocalNetMaskedImageModelingOutput(
  652. loss=masked_im_loss,
  653. reconstruction=reconstructed_pixel_values,
  654. hidden_states=outputs.hidden_states,
  655. reshaped_hidden_states=outputs.reshaped_hidden_states,
  656. )
  657. @auto_docstring(
  658. custom_intro="""
  659. FocalNet Model with an image classification head on top (a linear layer on top of the pooled output) e.g. for
  660. ImageNet.
  661. """
  662. )
  663. class FocalNetForImageClassification(FocalNetPreTrainedModel):
  664. # Copied from transformers.models.swin.modeling_swin.SwinForImageClassification.__init__ with Swin->FocalNet, swin->focalnet
  665. def __init__(self, config):
  666. super().__init__(config)
  667. self.num_labels = config.num_labels
  668. self.focalnet = FocalNetModel(config)
  669. # Classifier head
  670. self.classifier = (
  671. nn.Linear(self.focalnet.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
  672. )
  673. # Initialize weights and apply final processing
  674. self.post_init()
  675. @auto_docstring
  676. def forward(
  677. self,
  678. pixel_values: Optional[torch.FloatTensor] = None,
  679. labels: Optional[torch.LongTensor] = None,
  680. output_hidden_states: Optional[bool] = None,
  681. return_dict: Optional[bool] = None,
  682. ) -> Union[tuple, FocalNetImageClassifierOutput]:
  683. r"""
  684. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  685. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  686. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  687. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  688. """
  689. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  690. outputs = self.focalnet(
  691. pixel_values,
  692. output_hidden_states=output_hidden_states,
  693. return_dict=return_dict,
  694. )
  695. pooled_output = outputs[1]
  696. logits = self.classifier(pooled_output)
  697. loss = None
  698. if labels is not None:
  699. loss = self.loss_function(labels, logits, self.config)
  700. if not return_dict:
  701. output = (logits,) + outputs[2:]
  702. return ((loss,) + output) if loss is not None else output
  703. return FocalNetImageClassifierOutput(
  704. loss=loss,
  705. logits=logits,
  706. hidden_states=outputs.hidden_states,
  707. reshaped_hidden_states=outputs.reshaped_hidden_states,
  708. )
  709. @auto_docstring(
  710. custom_intro="""
  711. FocalNet backbone, to be used with frameworks like X-Decoder.
  712. """
  713. )
  714. class FocalNetBackbone(FocalNetPreTrainedModel, BackboneMixin):
  715. has_attentions = False
  716. def __init__(self, config: FocalNetConfig):
  717. super().__init__(config)
  718. super()._init_backbone(config)
  719. self.num_features = [config.embed_dim] + config.hidden_sizes
  720. self.focalnet = FocalNetModel(config)
  721. # initialize weights and apply final processing
  722. self.post_init()
  723. @auto_docstring
  724. def forward(
  725. self,
  726. pixel_values: torch.Tensor,
  727. output_hidden_states: Optional[bool] = None,
  728. return_dict: Optional[bool] = None,
  729. ) -> BackboneOutput:
  730. r"""
  731. Examples:
  732. ```python
  733. >>> from transformers import AutoImageProcessor, AutoBackbone
  734. >>> import torch
  735. >>> from PIL import Image
  736. >>> import requests
  737. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  738. >>> image = Image.open(requests.get(url, stream=True).raw)
  739. >>> processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-tiny-lrf")
  740. >>> model = AutoBackbone.from_pretrained("microsoft/focalnet-tiny-lrf")
  741. >>> inputs = processor(image, return_tensors="pt")
  742. >>> outputs = model(**inputs)
  743. ```"""
  744. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  745. output_hidden_states = (
  746. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  747. )
  748. outputs = self.focalnet(pixel_values, output_hidden_states=True, return_dict=True)
  749. hidden_states = outputs.reshaped_hidden_states
  750. feature_maps = ()
  751. for idx, stage in enumerate(self.stage_names):
  752. if stage in self.out_features:
  753. feature_maps += (hidden_states[idx],)
  754. if not return_dict:
  755. output = (feature_maps,)
  756. if output_hidden_states:
  757. output += (outputs.hidden_states,)
  758. return output
  759. return BackboneOutput(
  760. feature_maps=feature_maps,
  761. hidden_states=outputs.hidden_states if output_hidden_states else None,
  762. attentions=None,
  763. )
  764. __all__ = [
  765. "FocalNetForImageClassification",
  766. "FocalNetForMaskedImageModeling",
  767. "FocalNetBackbone",
  768. "FocalNetModel",
  769. "FocalNetPreTrainedModel",
  770. ]