modeling_bit.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821
  1. # coding=utf-8
  2. # Copyright 2022 Google AI and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch BiT model. Also supports backbone for ViT hybrid."""
  16. import collections
  17. import math
  18. from typing import Optional
  19. import numpy as np
  20. import torch
  21. from torch import Tensor, nn
  22. from ...activations import ACT2FN
  23. from ...modeling_outputs import (
  24. BackboneOutput,
  25. BaseModelOutputWithNoAttention,
  26. BaseModelOutputWithPoolingAndNoAttention,
  27. ImageClassifierOutputWithNoAttention,
  28. )
  29. from ...modeling_utils import PreTrainedModel
  30. from ...utils import auto_docstring, logging
  31. from ...utils.backbone_utils import BackboneMixin
  32. from .configuration_bit import BitConfig
# Module-level logger for this file, obtained through the library's logging utilities.
logger = logging.get_logger(__name__)
  34. def get_padding_value(padding=None, kernel_size=7, stride=1, dilation=1) -> tuple[tuple, bool]:
  35. r"""
  36. Utility function to get the tuple padding value given the kernel_size and padding.
  37. Args:
  38. padding (Union[`str`, `int`], *optional*):
  39. Padding value, can be either `"same"`, `"valid"`. If a different value is provided the default padding from
  40. PyTorch is used.
  41. kernel_size (`int`, *optional*, defaults to 7):
  42. Kernel size of the convolution layers.
  43. stride (`int`, *optional*, defaults to 1):
  44. Stride value of the convolution layers.
  45. dilation (`int`, *optional*, defaults to 1):
  46. Dilation value of the convolution layers.
  47. """
  48. dynamic = False
  49. if padding is None:
  50. padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
  51. return padding, dynamic
  52. if isinstance(padding, str):
  53. # for any string padding, the padding will be calculated for you, one of three ways
  54. padding = padding.lower()
  55. if padding == "same":
  56. # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
  57. if stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0:
  58. # static case, no extra overhead
  59. padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
  60. else:
  61. # dynamic 'SAME' padding, has runtime/GPU memory overhead
  62. padding = 0
  63. dynamic = True
  64. elif padding == "valid":
  65. # 'VALID' padding, same as padding=0
  66. padding = 0
  67. else:
  68. # Default to PyTorch style 'same'-ish symmetric padding
  69. padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
  70. return padding, dynamic
  71. class WeightStandardizedConv2d(nn.Conv2d):
  72. """Conv2d with Weight Standardization. Includes TensorFlow compatible SAME padding. Used for ViT Hybrid model.
  73. Paper: [Micro-Batch Training with Batch-Channel Normalization and Weight
  74. Standardization](https://huggingface.co/papers/1903.10520v2)
  75. """
  76. def __init__(
  77. self,
  78. in_channel,
  79. out_channels,
  80. kernel_size,
  81. stride=1,
  82. padding="SAME",
  83. dilation=1,
  84. groups=1,
  85. bias=False,
  86. eps=1e-6,
  87. ):
  88. padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
  89. super().__init__(
  90. in_channel,
  91. out_channels,
  92. kernel_size,
  93. stride=stride,
  94. padding=padding,
  95. dilation=dilation,
  96. groups=groups,
  97. bias=bias,
  98. )
  99. if is_dynamic:
  100. self.pad = DynamicPad2d(kernel_size, stride, dilation)
  101. else:
  102. self.pad = None
  103. self.eps = eps
  104. def forward(self, hidden_state):
  105. if self.pad is not None:
  106. hidden_state = self.pad(hidden_state)
  107. weight = nn.functional.batch_norm(
  108. self.weight.reshape(1, self.out_channels, -1), None, None, training=True, momentum=0.0, eps=self.eps
  109. ).reshape_as(self.weight)
  110. hidden_state = nn.functional.conv2d(
  111. hidden_state, weight, self.bias, self.stride, self.padding, self.dilation, self.groups
  112. )
  113. return hidden_state
  114. class BitGroupNormActivation(nn.GroupNorm):
  115. r"""
  116. A module that combines group normalization with an activation function.
  117. """
  118. def __init__(self, config, num_channels, eps=1e-5, affine=True, apply_activation=True):
  119. super().__init__(config.num_groups, num_channels, eps=eps, affine=affine)
  120. if apply_activation:
  121. self.activation = ACT2FN[config.hidden_act]
  122. else:
  123. self.activation = nn.Identity()
  124. def forward(self, hidden_state):
  125. hidden_state = nn.functional.group_norm(hidden_state, self.num_groups, self.weight, self.bias, self.eps)
  126. hidden_state = self.activation(hidden_state)
  127. return hidden_state
  128. class DynamicPad2d(nn.Module):
  129. r"""
  130. A module that wraps dynamic padding of any input, given the parameters of the convolutional layer and the input
  131. hidden states.
  132. """
  133. def __init__(self, kernel_size, stride, dilation, value=0):
  134. super().__init__()
  135. # Safety checkers
  136. if isinstance(kernel_size, int):
  137. kernel_size = (kernel_size, kernel_size)
  138. if isinstance(stride, int):
  139. stride = (stride, stride)
  140. if isinstance(dilation, int):
  141. dilation = (dilation, dilation)
  142. self.kernel_size = kernel_size
  143. self.stride = stride
  144. self.dilation = dilation
  145. self.value = value
  146. def compute_padding(x, kernel_size, stride, dilation):
  147. return max((math.ceil(x / stride) - 1) * stride + (kernel_size - 1) * dilation + 1 - x, 0)
  148. self.compute_padding = compute_padding
  149. def forward(self, input):
  150. # Get width and height
  151. input_height, input_width = input.size()[-2:]
  152. # Compute the padding values
  153. padding_height = self.compute_padding(input_height, self.kernel_size[0], self.stride[0], self.dilation[0])
  154. padding_width = self.compute_padding(input_width, self.kernel_size[1], self.stride[1], self.dilation[1])
  155. # apply pad
  156. if padding_height > 0 or padding_width > 0:
  157. input = nn.functional.pad(
  158. input,
  159. [
  160. padding_width // 2,
  161. padding_width - padding_width // 2,
  162. padding_height // 2,
  163. padding_height - padding_height // 2,
  164. ],
  165. value=self.value,
  166. )
  167. return input
  168. class BitMaxPool2d(nn.MaxPool2d):
  169. """Tensorflow like 'SAME' wrapper for 2D max pooling"""
  170. def __init__(
  171. self,
  172. kernel_size: int,
  173. stride=None,
  174. dilation=1,
  175. ceil_mode=False,
  176. padding=(0, 0),
  177. padding_value=0,
  178. use_dynamic_padding=True,
  179. ):
  180. kernel_size = kernel_size if isinstance(kernel_size, collections.abc.Iterable) else (kernel_size, kernel_size)
  181. stride = stride if isinstance(stride, collections.abc.Iterable) else (stride, stride)
  182. dilation = dilation if isinstance(dilation, collections.abc.Iterable) else (dilation, dilation)
  183. super().__init__(kernel_size, stride, padding, dilation, ceil_mode)
  184. if use_dynamic_padding:
  185. self.pad = DynamicPad2d(kernel_size, stride, dilation, padding_value)
  186. else:
  187. self.pad = nn.Identity()
  188. def forward(self, hidden_states):
  189. hidden_states = self.pad(hidden_states)
  190. return nn.functional.max_pool2d(
  191. hidden_states, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode
  192. )
  193. class BitEmbeddings(nn.Module):
  194. """
  195. BiT Embeddings (stem) composed of a single aggressive convolution.
  196. """
  197. def __init__(self, config: BitConfig):
  198. super().__init__()
  199. self.convolution = WeightStandardizedConv2d(
  200. config.num_channels,
  201. config.embedding_size,
  202. kernel_size=7,
  203. stride=2,
  204. eps=1e-8,
  205. padding=config.global_padding,
  206. )
  207. self.pooler = BitMaxPool2d(kernel_size=3, stride=2, use_dynamic_padding=config.embedding_dynamic_padding)
  208. # Use the same padding strategy as convolutional layers
  209. if config.global_padding is not None and config.global_padding.upper() == "SAME":
  210. self.pad = nn.Identity()
  211. else:
  212. self.pad = nn.ConstantPad2d(padding=(1, 1, 1, 1), value=0.0)
  213. if config.layer_type != "preactivation":
  214. self.norm = BitGroupNormActivation(config, num_channels=config.embedding_size)
  215. else:
  216. self.norm = nn.Identity()
  217. self.num_channels = config.num_channels
  218. def forward(self, pixel_values: Tensor) -> Tensor:
  219. num_channels = pixel_values.shape[1]
  220. if num_channels != self.num_channels:
  221. raise ValueError(
  222. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  223. )
  224. embedding = self.convolution(pixel_values)
  225. embedding = self.pad(embedding)
  226. embedding = self.norm(embedding)
  227. embedding = self.pooler(embedding)
  228. return embedding
# Copied from transformers.models.convnext.modeling_convnext.drop_path
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.

    Args:
        input (`torch.Tensor`): Tensor whose first dimension indexes the samples of the batch.
        drop_prob (`float`, defaults to 0.0): Probability of zeroing out an entire sample.
        training (`bool`, defaults to `False`): Nothing is dropped unless this is `True`.

    Returns:
        `torch.Tensor`: `input` itself when `drop_prob == 0.0` or not training; otherwise a tensor of the same shape in
        which every sample is either all zeros or `input / (1 - drop_prob)`.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    # keep_prob + U[0, 1) lies in [keep_prob, 1 + keep_prob); flooring gives 1 with probability keep_prob, else 0.
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    # Rescale surviving samples by 1 / keep_prob so the expected value matches the input.
    output = input.div(keep_prob) * random_tensor
    return output
# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Bit
class BitDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        # Forwarded to `drop_path`. NOTE(review): `None` only works in eval mode; in training mode `drop_path`
        # computes `1 - drop_prob`, which fails for `None`.
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # `self.training` (toggled by `.train()` / `.eval()`) decides whether any sample is dropped.
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        # Extra text shown inside the module's repr.
        return f"p={self.drop_prob}"
  257. def make_div(value, divisor=8):
  258. min_value = divisor
  259. new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
  260. if new_value < 0.9 * value:
  261. new_value += divisor
  262. return new_value
  263. class BitPreActivationBottleneckLayer(nn.Module):
  264. """Pre-activation (v2) bottleneck block.
  265. Follows the implementation of "Identity Mappings in Deep Residual Networks":
  266. https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua
  267. Except it puts the stride on 3x3 conv when available.
  268. """
  269. def __init__(
  270. self,
  271. config,
  272. in_channels,
  273. out_channels=None,
  274. bottle_ratio=0.25,
  275. stride=1,
  276. dilation=1,
  277. first_dilation=None,
  278. groups=1,
  279. drop_path_rate=0.0,
  280. is_first_layer=False,
  281. ):
  282. super().__init__()
  283. first_dilation = first_dilation or dilation
  284. out_channels = out_channels or in_channels
  285. mid_channels = make_div(out_channels * bottle_ratio)
  286. if is_first_layer:
  287. self.downsample = BitDownsampleConv(
  288. config,
  289. in_channels,
  290. out_channels,
  291. stride=stride,
  292. preact=True,
  293. )
  294. else:
  295. self.downsample = None
  296. self.norm1 = BitGroupNormActivation(config, in_channels)
  297. self.conv1 = WeightStandardizedConv2d(in_channels, mid_channels, 1, eps=1e-8, padding=config.global_padding)
  298. self.norm2 = BitGroupNormActivation(config, num_channels=mid_channels)
  299. self.conv2 = WeightStandardizedConv2d(
  300. mid_channels, mid_channels, 3, stride=stride, groups=groups, eps=1e-8, padding=config.global_padding
  301. )
  302. self.norm3 = BitGroupNormActivation(config, mid_channels)
  303. self.conv3 = WeightStandardizedConv2d(mid_channels, out_channels, 1, eps=1e-8, padding=config.global_padding)
  304. self.drop_path = BitDropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
  305. def forward(self, hidden_states):
  306. hidden_states_preact = self.norm1(hidden_states)
  307. # shortcut branch
  308. shortcut = hidden_states
  309. if self.downsample is not None:
  310. shortcut = self.downsample(hidden_states_preact)
  311. # residual branch
  312. hidden_states = self.conv1(hidden_states_preact)
  313. hidden_states = self.conv2(self.norm2(hidden_states))
  314. hidden_states = self.conv3(self.norm3(hidden_states))
  315. hidden_states = self.drop_path(hidden_states)
  316. return hidden_states + shortcut
  317. class BitBottleneckLayer(nn.Module):
  318. """Non Pre-activation bottleneck block, equivalent to V1.5/V1b bottleneck. Used for ViT Hybrid."""
  319. def __init__(
  320. self,
  321. config,
  322. in_channels,
  323. out_channels=None,
  324. bottle_ratio=0.25,
  325. stride=1,
  326. dilation=1,
  327. first_dilation=None,
  328. groups=1,
  329. drop_path_rate=0.0,
  330. is_first_layer=False,
  331. ):
  332. super().__init__()
  333. first_dilation = first_dilation or dilation
  334. out_channels = out_channels or in_channels
  335. mid_chs = make_div(out_channels * bottle_ratio)
  336. if is_first_layer:
  337. self.downsample = BitDownsampleConv(
  338. config,
  339. in_channels,
  340. out_channels,
  341. stride=stride,
  342. preact=False,
  343. )
  344. else:
  345. self.downsample = None
  346. self.conv1 = WeightStandardizedConv2d(in_channels, mid_chs, 1, eps=1e-8, padding=config.global_padding)
  347. self.norm1 = BitGroupNormActivation(config, num_channels=mid_chs)
  348. self.conv2 = WeightStandardizedConv2d(
  349. mid_chs,
  350. mid_chs,
  351. 3,
  352. stride=stride,
  353. dilation=first_dilation,
  354. groups=groups,
  355. eps=1e-8,
  356. padding=config.global_padding,
  357. )
  358. self.norm2 = BitGroupNormActivation(config, num_channels=mid_chs)
  359. self.conv3 = WeightStandardizedConv2d(mid_chs, out_channels, 1, eps=1e-8, padding=config.global_padding)
  360. self.norm3 = BitGroupNormActivation(config, num_channels=out_channels, apply_activation=False)
  361. self.drop_path = BitDropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
  362. self.activation = ACT2FN[config.hidden_act]
  363. def forward(self, hidden_states):
  364. # shortcut branch
  365. shortcut = hidden_states
  366. if self.downsample is not None:
  367. shortcut = self.downsample(hidden_states)
  368. # residual
  369. hidden_states = self.conv1(hidden_states)
  370. hidden_states = self.norm1(hidden_states)
  371. hidden_states = self.conv2(hidden_states)
  372. hidden_states = self.norm2(hidden_states)
  373. hidden_states = self.conv3(hidden_states)
  374. hidden_states = self.norm3(hidden_states)
  375. hidden_states = self.drop_path(hidden_states)
  376. hidden_states = self.activation(hidden_states + shortcut)
  377. return hidden_states
  378. class BitDownsampleConv(nn.Module):
  379. def __init__(
  380. self,
  381. config,
  382. in_channels,
  383. out_channels,
  384. stride=1,
  385. preact=True,
  386. ):
  387. super().__init__()
  388. self.conv = WeightStandardizedConv2d(
  389. in_channels, out_channels, 1, stride=stride, eps=1e-8, padding=config.global_padding
  390. )
  391. self.norm = (
  392. nn.Identity()
  393. if preact
  394. else BitGroupNormActivation(config, num_channels=out_channels, apply_activation=False)
  395. )
  396. def forward(self, x):
  397. return self.norm(self.conv(x))
  398. class BitStage(nn.Module):
  399. """
  400. A ResNet v2 stage composed by stacked layers.
  401. """
  402. def __init__(
  403. self,
  404. config,
  405. in_channels,
  406. out_channels,
  407. stride,
  408. dilation,
  409. depth,
  410. bottle_ratio=0.25,
  411. layer_dropout=None,
  412. ):
  413. super().__init__()
  414. first_dilation = 1 if dilation in (1, 2) else 2
  415. # Get the layer type
  416. if config.layer_type == "bottleneck":
  417. layer_cls = BitBottleneckLayer
  418. else:
  419. layer_cls = BitPreActivationBottleneckLayer
  420. prev_chs = in_channels
  421. self.layers = nn.Sequential()
  422. for layer_idx in range(depth):
  423. # Get the current hyper-parameters
  424. stride, drop_path_rate, is_first_layer = self._get_updated_hyperparameters(
  425. layer_idx, stride, layer_dropout
  426. )
  427. self.layers.add_module(
  428. str(layer_idx),
  429. layer_cls(
  430. config,
  431. prev_chs,
  432. out_channels,
  433. stride=stride,
  434. dilation=dilation,
  435. bottle_ratio=bottle_ratio,
  436. first_dilation=first_dilation,
  437. drop_path_rate=drop_path_rate,
  438. is_first_layer=is_first_layer,
  439. ),
  440. )
  441. prev_chs = out_channels
  442. first_dilation = dilation
  443. def _get_updated_hyperparameters(self, layer_idx, stride, layer_dropout):
  444. r"""
  445. Get the new hyper-parameters with respect to the previous ones and the index of the current layer.
  446. """
  447. if layer_dropout:
  448. drop_path_rate = layer_dropout[layer_idx]
  449. else:
  450. drop_path_rate = 0.0
  451. if layer_idx != 0:
  452. stride = 1
  453. is_first_layer = layer_idx == 0
  454. return stride, drop_path_rate, is_first_layer
  455. def forward(self, input: Tensor) -> Tensor:
  456. hidden_state = input
  457. for _, layer in enumerate(self.layers):
  458. hidden_state = layer(hidden_state)
  459. return hidden_state
  460. class BitEncoder(nn.Module):
  461. def __init__(self, config: BitConfig):
  462. super().__init__()
  463. self.stages = nn.ModuleList([])
  464. prev_chs = config.embedding_size
  465. # These needs to stay hardcoded
  466. current_stride = 4
  467. dilation = 1
  468. layer_dropouts = [
  469. x.tolist()
  470. for x in torch.Tensor(np.linspace(0, config.drop_path_rate, sum(config.depths))).split(config.depths)
  471. ]
  472. for stage_idx, (current_depth, current_hidden_size, layer_dropout) in enumerate(
  473. zip(config.depths, config.hidden_sizes, layer_dropouts)
  474. ):
  475. # Get the updated hyper params
  476. out_channels, stride, dilation = self._get_updated_hyperparameters(
  477. stage_idx, current_stride, current_hidden_size, dilation, config
  478. )
  479. stage = BitStage(
  480. config,
  481. prev_chs,
  482. out_channels,
  483. stride=stride,
  484. dilation=dilation,
  485. depth=current_depth,
  486. layer_dropout=layer_dropout,
  487. )
  488. prev_chs = out_channels
  489. current_stride *= stride
  490. self.stages.add_module(str(stage_idx), stage)
  491. def _get_updated_hyperparameters(self, stage_idx, current_stride, current_hidden_size, dilation, config):
  492. out_channels = make_div(current_hidden_size * config.width_factor)
  493. stride = 1 if stage_idx == 0 else 2
  494. if current_stride >= config.output_stride:
  495. dilation *= stride
  496. stride = 1
  497. return out_channels, stride, dilation
  498. def forward(
  499. self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True
  500. ) -> BaseModelOutputWithNoAttention:
  501. hidden_states = () if output_hidden_states else None
  502. for stage_module in self.stages:
  503. if output_hidden_states:
  504. hidden_states = hidden_states + (hidden_state,)
  505. hidden_state = stage_module(hidden_state)
  506. if output_hidden_states:
  507. hidden_states = hidden_states + (hidden_state,)
  508. if not return_dict:
  509. return tuple(v for v in [hidden_state, hidden_states] if v is not None)
  510. return BaseModelOutputWithNoAttention(
  511. last_hidden_state=hidden_state,
  512. hidden_states=hidden_states,
  513. )
  514. @auto_docstring
  515. class BitPreTrainedModel(PreTrainedModel):
  516. config: BitConfig
  517. base_model_prefix = "bit"
  518. main_input_name = "pixel_values"
  519. _no_split_modules = ["BitEmbeddings"]
  520. def _init_weights(self, module):
  521. if isinstance(module, nn.Conv2d):
  522. nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
  523. # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
  524. elif isinstance(module, nn.Linear):
  525. nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
  526. if module.bias is not None:
  527. fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
  528. bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
  529. nn.init.uniform_(module.bias, -bound, bound)
  530. elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
  531. nn.init.constant_(module.weight, 1)
  532. nn.init.constant_(module.bias, 0)
  533. @auto_docstring
  534. class BitModel(BitPreTrainedModel):
  535. def __init__(self, config):
  536. super().__init__(config)
  537. self.config = config
  538. self.embedder = BitEmbeddings(config)
  539. self.encoder = BitEncoder(config)
  540. self.norm = (
  541. BitGroupNormActivation(config, num_channels=config.hidden_sizes[-1])
  542. if config.layer_type == "preactivation"
  543. else nn.Identity()
  544. )
  545. self.pooler = nn.AdaptiveAvgPool2d((1, 1))
  546. # Initialize weights and apply final processing
  547. self.post_init()
  548. @auto_docstring
  549. def forward(
  550. self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
  551. ) -> BaseModelOutputWithPoolingAndNoAttention:
  552. output_hidden_states = (
  553. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  554. )
  555. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  556. embedding_output = self.embedder(pixel_values)
  557. encoder_outputs = self.encoder(
  558. embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict
  559. )
  560. last_hidden_state = encoder_outputs[0]
  561. last_hidden_state = self.norm(last_hidden_state)
  562. pooled_output = self.pooler(last_hidden_state)
  563. if not return_dict:
  564. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  565. return BaseModelOutputWithPoolingAndNoAttention(
  566. last_hidden_state=last_hidden_state,
  567. pooler_output=pooled_output,
  568. hidden_states=encoder_outputs.hidden_states,
  569. )
  570. @auto_docstring(
  571. custom_intro="""
  572. BiT Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
  573. ImageNet.
  574. """
  575. )
  576. class BitForImageClassification(BitPreTrainedModel):
  577. def __init__(self, config):
  578. super().__init__(config)
  579. self.num_labels = config.num_labels
  580. self.bit = BitModel(config)
  581. # classification head
  582. self.classifier = nn.Sequential(
  583. nn.Flatten(),
  584. nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity(),
  585. )
  586. # initialize weights and apply final processing
  587. self.post_init()
  588. @auto_docstring
  589. def forward(
  590. self,
  591. pixel_values: Optional[torch.FloatTensor] = None,
  592. labels: Optional[torch.LongTensor] = None,
  593. output_hidden_states: Optional[bool] = None,
  594. return_dict: Optional[bool] = None,
  595. ) -> ImageClassifierOutputWithNoAttention:
  596. r"""
  597. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  598. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  599. config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  600. """
  601. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  602. outputs = self.bit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
  603. pooled_output = outputs.pooler_output if return_dict else outputs[1]
  604. logits = self.classifier(pooled_output)
  605. loss = None
  606. if labels is not None:
  607. loss = self.loss_function(labels, logits, self.config)
  608. if not return_dict:
  609. output = (logits,) + outputs[2:]
  610. return (loss,) + output if loss is not None else output
  611. return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
  612. @auto_docstring(
  613. custom_intro="""
  614. BiT backbone, to be used with frameworks like DETR and MaskFormer.
  615. """
  616. )
  617. class BitBackbone(BitPreTrainedModel, BackboneMixin):
  618. has_attentions = False
  619. def __init__(self, config):
  620. super().__init__(config)
  621. super()._init_backbone(config)
  622. self.bit = BitModel(config)
  623. self.num_features = [config.embedding_size] + config.hidden_sizes
  624. # initialize weights and apply final processing
  625. self.post_init()
  626. @auto_docstring
  627. def forward(
  628. self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
  629. ) -> BackboneOutput:
  630. r"""
  631. Examples:
  632. ```python
  633. >>> from transformers import AutoImageProcessor, AutoBackbone
  634. >>> import torch
  635. >>> from PIL import Image
  636. >>> import requests
  637. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  638. >>> image = Image.open(requests.get(url, stream=True).raw)
  639. >>> processor = AutoImageProcessor.from_pretrained("google/bit-50")
  640. >>> model = AutoBackbone.from_pretrained("google/bit-50")
  641. >>> inputs = processor(image, return_tensors="pt")
  642. >>> outputs = model(**inputs)
  643. ```"""
  644. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  645. output_hidden_states = (
  646. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  647. )
  648. outputs = self.bit(pixel_values, output_hidden_states=True, return_dict=True)
  649. hidden_states = outputs.hidden_states
  650. feature_maps = ()
  651. for idx, stage in enumerate(self.stage_names):
  652. if stage in self.out_features:
  653. feature_maps += (hidden_states[idx],)
  654. if not return_dict:
  655. output = (feature_maps,)
  656. if output_hidden_states:
  657. output += (outputs.hidden_states,)
  658. return output
  659. return BackboneOutput(
  660. feature_maps=feature_maps,
  661. hidden_states=outputs.hidden_states if output_hidden_states else None,
  662. attentions=None,
  663. )
# Public names exported by `from <this module> import *`.
__all__ = ["BitForImageClassification", "BitModel", "BitPreTrainedModel", "BitBackbone"]