modeling_efficientnet.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. # coding=utf-8
  2. # Copyright 2023 Google Research, Inc. and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch EfficientNet model."""
  16. import math
  17. from typing import Optional, Union
  18. import torch
  19. from torch import nn
  20. from ...activations import ACT2FN
  21. from ...modeling_outputs import (
  22. BaseModelOutputWithNoAttention,
  23. BaseModelOutputWithPoolingAndNoAttention,
  24. ImageClassifierOutputWithNoAttention,
  25. )
  26. from ...modeling_utils import PreTrainedModel
  27. from ...utils import auto_docstring, logging
  28. from .configuration_efficientnet import EfficientNetConfig
  29. logger = logging.get_logger(__name__)
  30. def round_filters(config: EfficientNetConfig, num_channels: int):
  31. r"""
  32. Round number of filters based on depth multiplier.
  33. """
  34. divisor = config.depth_divisor
  35. num_channels *= config.width_coefficient
  36. new_dim = max(divisor, int(num_channels + divisor / 2) // divisor * divisor)
  37. # Make sure that round down does not go down by more than 10%.
  38. if new_dim < 0.9 * num_channels:
  39. new_dim += divisor
  40. return int(new_dim)
  41. def correct_pad(kernel_size: Union[int, tuple], adjust: bool = True):
  42. r"""
  43. Utility function to get the tuple padding value for the depthwise convolution.
  44. Args:
  45. kernel_size (`int` or `tuple`):
  46. Kernel size of the convolution layers.
  47. adjust (`bool`, *optional*, defaults to `True`):
  48. Adjusts padding value to apply to right and bottom sides of the input.
  49. """
  50. if isinstance(kernel_size, int):
  51. kernel_size = (kernel_size, kernel_size)
  52. correct = (kernel_size[0] // 2, kernel_size[1] // 2)
  53. if adjust:
  54. return (correct[1] - 1, correct[1], correct[0] - 1, correct[0])
  55. else:
  56. return (correct[1], correct[1], correct[0], correct[0])
  57. class EfficientNetEmbeddings(nn.Module):
  58. r"""
  59. A module that corresponds to the stem module of the original work.
  60. """
  61. def __init__(self, config: EfficientNetConfig):
  62. super().__init__()
  63. self.out_dim = round_filters(config, 32)
  64. self.padding = nn.ZeroPad2d(padding=(0, 1, 0, 1))
  65. self.convolution = nn.Conv2d(
  66. config.num_channels, self.out_dim, kernel_size=3, stride=2, padding="valid", bias=False
  67. )
  68. self.batchnorm = nn.BatchNorm2d(self.out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum)
  69. self.activation = ACT2FN[config.hidden_act]
  70. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  71. features = self.padding(pixel_values)
  72. features = self.convolution(features)
  73. features = self.batchnorm(features)
  74. features = self.activation(features)
  75. return features
  76. class EfficientNetDepthwiseConv2d(nn.Conv2d):
  77. def __init__(
  78. self,
  79. in_channels,
  80. depth_multiplier=1,
  81. kernel_size=3,
  82. stride=1,
  83. padding=0,
  84. dilation=1,
  85. bias=True,
  86. padding_mode="zeros",
  87. ):
  88. out_channels = in_channels * depth_multiplier
  89. super().__init__(
  90. in_channels=in_channels,
  91. out_channels=out_channels,
  92. kernel_size=kernel_size,
  93. stride=stride,
  94. padding=padding,
  95. dilation=dilation,
  96. groups=in_channels,
  97. bias=bias,
  98. padding_mode=padding_mode,
  99. )
  100. class EfficientNetExpansionLayer(nn.Module):
  101. r"""
  102. This corresponds to the expansion phase of each block in the original implementation.
  103. """
  104. def __init__(self, config: EfficientNetConfig, in_dim: int, out_dim: int, stride: int):
  105. super().__init__()
  106. self.expand_conv = nn.Conv2d(
  107. in_channels=in_dim,
  108. out_channels=out_dim,
  109. kernel_size=1,
  110. padding="same",
  111. bias=False,
  112. )
  113. self.expand_bn = nn.BatchNorm2d(num_features=out_dim, eps=config.batch_norm_eps)
  114. self.expand_act = ACT2FN[config.hidden_act]
  115. def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
  116. # Expand phase
  117. hidden_states = self.expand_conv(hidden_states)
  118. hidden_states = self.expand_bn(hidden_states)
  119. hidden_states = self.expand_act(hidden_states)
  120. return hidden_states
  121. class EfficientNetDepthwiseLayer(nn.Module):
  122. r"""
  123. This corresponds to the depthwise convolution phase of each block in the original implementation.
  124. """
  125. def __init__(
  126. self,
  127. config: EfficientNetConfig,
  128. in_dim: int,
  129. stride: int,
  130. kernel_size: int,
  131. adjust_padding: bool,
  132. ):
  133. super().__init__()
  134. self.stride = stride
  135. conv_pad = "valid" if self.stride == 2 else "same"
  136. padding = correct_pad(kernel_size, adjust=adjust_padding)
  137. self.depthwise_conv_pad = nn.ZeroPad2d(padding=padding)
  138. self.depthwise_conv = EfficientNetDepthwiseConv2d(
  139. in_dim, kernel_size=kernel_size, stride=stride, padding=conv_pad, bias=False
  140. )
  141. self.depthwise_norm = nn.BatchNorm2d(
  142. num_features=in_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
  143. )
  144. self.depthwise_act = ACT2FN[config.hidden_act]
  145. def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
  146. # Depthwise convolution
  147. if self.stride == 2:
  148. hidden_states = self.depthwise_conv_pad(hidden_states)
  149. hidden_states = self.depthwise_conv(hidden_states)
  150. hidden_states = self.depthwise_norm(hidden_states)
  151. hidden_states = self.depthwise_act(hidden_states)
  152. return hidden_states
  153. class EfficientNetSqueezeExciteLayer(nn.Module):
  154. r"""
  155. This corresponds to the Squeeze and Excitement phase of each block in the original implementation.
  156. """
  157. def __init__(self, config: EfficientNetConfig, in_dim: int, expand_dim: int, expand: bool = False):
  158. super().__init__()
  159. self.dim = expand_dim if expand else in_dim
  160. self.dim_se = max(1, int(in_dim * config.squeeze_expansion_ratio))
  161. self.squeeze = nn.AdaptiveAvgPool2d(output_size=1)
  162. self.reduce = nn.Conv2d(
  163. in_channels=self.dim,
  164. out_channels=self.dim_se,
  165. kernel_size=1,
  166. padding="same",
  167. )
  168. self.expand = nn.Conv2d(
  169. in_channels=self.dim_se,
  170. out_channels=self.dim,
  171. kernel_size=1,
  172. padding="same",
  173. )
  174. self.act_reduce = ACT2FN[config.hidden_act]
  175. self.act_expand = nn.Sigmoid()
  176. def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
  177. inputs = hidden_states
  178. hidden_states = self.squeeze(hidden_states)
  179. hidden_states = self.reduce(hidden_states)
  180. hidden_states = self.act_reduce(hidden_states)
  181. hidden_states = self.expand(hidden_states)
  182. hidden_states = self.act_expand(hidden_states)
  183. hidden_states = torch.mul(inputs, hidden_states)
  184. return hidden_states
  185. class EfficientNetFinalBlockLayer(nn.Module):
  186. r"""
  187. This corresponds to the final phase of each block in the original implementation.
  188. """
  189. def __init__(
  190. self, config: EfficientNetConfig, in_dim: int, out_dim: int, stride: int, drop_rate: float, id_skip: bool
  191. ):
  192. super().__init__()
  193. self.apply_dropout = stride == 1 and not id_skip
  194. self.project_conv = nn.Conv2d(
  195. in_channels=in_dim,
  196. out_channels=out_dim,
  197. kernel_size=1,
  198. padding="same",
  199. bias=False,
  200. )
  201. self.project_bn = nn.BatchNorm2d(
  202. num_features=out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
  203. )
  204. self.dropout = nn.Dropout(p=drop_rate)
  205. def forward(self, embeddings: torch.FloatTensor, hidden_states: torch.FloatTensor) -> torch.Tensor:
  206. hidden_states = self.project_conv(hidden_states)
  207. hidden_states = self.project_bn(hidden_states)
  208. if self.apply_dropout:
  209. hidden_states = self.dropout(hidden_states)
  210. hidden_states = hidden_states + embeddings
  211. return hidden_states
  212. class EfficientNetBlock(nn.Module):
  213. r"""
  214. This corresponds to the expansion and depthwise convolution phase of each block in the original implementation.
  215. Args:
  216. config ([`EfficientNetConfig`]):
  217. Model configuration class.
  218. in_dim (`int`):
  219. Number of input channels.
  220. out_dim (`int`):
  221. Number of output channels.
  222. stride (`int`):
  223. Stride size to be used in convolution layers.
  224. expand_ratio (`int`):
  225. Expand ratio to set the output dimensions for the expansion and squeeze-excite layers.
  226. kernel_size (`int`):
  227. Kernel size for the depthwise convolution layer.
  228. drop_rate (`float`):
  229. Dropout rate to be used in the final phase of each block.
  230. id_skip (`bool`):
  231. Whether to apply dropout and sum the final hidden states with the input embeddings during the final phase
  232. of each block. Set to `True` for the first block of each stage.
  233. adjust_padding (`bool`):
  234. Whether to apply padding to only right and bottom side of the input kernel before the depthwise convolution
  235. operation, set to `True` for inputs with odd input sizes.
  236. """
  237. def __init__(
  238. self,
  239. config: EfficientNetConfig,
  240. in_dim: int,
  241. out_dim: int,
  242. stride: int,
  243. expand_ratio: int,
  244. kernel_size: int,
  245. drop_rate: float,
  246. id_skip: bool,
  247. adjust_padding: bool,
  248. ):
  249. super().__init__()
  250. self.expand_ratio = expand_ratio
  251. self.expand = self.expand_ratio != 1
  252. expand_in_dim = in_dim * expand_ratio
  253. if self.expand:
  254. self.expansion = EfficientNetExpansionLayer(
  255. config=config, in_dim=in_dim, out_dim=expand_in_dim, stride=stride
  256. )
  257. self.depthwise_conv = EfficientNetDepthwiseLayer(
  258. config=config,
  259. in_dim=expand_in_dim if self.expand else in_dim,
  260. stride=stride,
  261. kernel_size=kernel_size,
  262. adjust_padding=adjust_padding,
  263. )
  264. self.squeeze_excite = EfficientNetSqueezeExciteLayer(
  265. config=config, in_dim=in_dim, expand_dim=expand_in_dim, expand=self.expand
  266. )
  267. self.projection = EfficientNetFinalBlockLayer(
  268. config=config,
  269. in_dim=expand_in_dim if self.expand else in_dim,
  270. out_dim=out_dim,
  271. stride=stride,
  272. drop_rate=drop_rate,
  273. id_skip=id_skip,
  274. )
  275. def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
  276. embeddings = hidden_states
  277. # Expansion and depthwise convolution phase
  278. if self.expand_ratio != 1:
  279. hidden_states = self.expansion(hidden_states)
  280. hidden_states = self.depthwise_conv(hidden_states)
  281. # Squeeze and excite phase
  282. hidden_states = self.squeeze_excite(hidden_states)
  283. hidden_states = self.projection(embeddings, hidden_states)
  284. return hidden_states
  285. class EfficientNetEncoder(nn.Module):
  286. r"""
  287. Forward propagates the embeddings through each EfficientNet block.
  288. Args:
  289. config ([`EfficientNetConfig`]):
  290. Model configuration class.
  291. """
  292. def __init__(self, config: EfficientNetConfig):
  293. super().__init__()
  294. self.config = config
  295. self.depth_coefficient = config.depth_coefficient
  296. def round_repeats(repeats):
  297. # Round number of block repeats based on depth multiplier.
  298. return int(math.ceil(self.depth_coefficient * repeats))
  299. num_base_blocks = len(config.in_channels)
  300. num_blocks = sum(round_repeats(n) for n in config.num_block_repeats)
  301. curr_block_num = 0
  302. blocks = []
  303. for i in range(num_base_blocks):
  304. in_dim = round_filters(config, config.in_channels[i])
  305. out_dim = round_filters(config, config.out_channels[i])
  306. stride = config.strides[i]
  307. kernel_size = config.kernel_sizes[i]
  308. expand_ratio = config.expand_ratios[i]
  309. for j in range(round_repeats(config.num_block_repeats[i])):
  310. id_skip = j == 0
  311. stride = 1 if j > 0 else stride
  312. in_dim = out_dim if j > 0 else in_dim
  313. adjust_padding = curr_block_num not in config.depthwise_padding
  314. drop_rate = config.drop_connect_rate * curr_block_num / num_blocks
  315. block = EfficientNetBlock(
  316. config=config,
  317. in_dim=in_dim,
  318. out_dim=out_dim,
  319. stride=stride,
  320. kernel_size=kernel_size,
  321. expand_ratio=expand_ratio,
  322. drop_rate=drop_rate,
  323. id_skip=id_skip,
  324. adjust_padding=adjust_padding,
  325. )
  326. blocks.append(block)
  327. curr_block_num += 1
  328. self.blocks = nn.ModuleList(blocks)
  329. self.top_conv = nn.Conv2d(
  330. in_channels=out_dim,
  331. out_channels=round_filters(config, 1280),
  332. kernel_size=1,
  333. padding="same",
  334. bias=False,
  335. )
  336. self.top_bn = nn.BatchNorm2d(
  337. num_features=config.hidden_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
  338. )
  339. self.top_activation = ACT2FN[config.hidden_act]
  340. def forward(
  341. self,
  342. hidden_states: torch.FloatTensor,
  343. output_hidden_states: Optional[bool] = False,
  344. return_dict: Optional[bool] = True,
  345. ) -> BaseModelOutputWithNoAttention:
  346. all_hidden_states = (hidden_states,) if output_hidden_states else None
  347. for block in self.blocks:
  348. hidden_states = block(hidden_states)
  349. if output_hidden_states:
  350. all_hidden_states += (hidden_states,)
  351. hidden_states = self.top_conv(hidden_states)
  352. hidden_states = self.top_bn(hidden_states)
  353. hidden_states = self.top_activation(hidden_states)
  354. if not return_dict:
  355. return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
  356. return BaseModelOutputWithNoAttention(
  357. last_hidden_state=hidden_states,
  358. hidden_states=all_hidden_states,
  359. )
  360. @auto_docstring
  361. class EfficientNetPreTrainedModel(PreTrainedModel):
  362. config: EfficientNetConfig
  363. base_model_prefix = "efficientnet"
  364. main_input_name = "pixel_values"
  365. _no_split_modules = []
  366. def _init_weights(self, module: nn.Module):
  367. """Initialize the weights"""
  368. if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
  369. # Slightly different from the TF version which uses truncated_normal for initialization
  370. # cf https://github.com/pytorch/pytorch/pull/5617
  371. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  372. if module.bias is not None:
  373. module.bias.data.zero_()
  374. @auto_docstring
  375. class EfficientNetModel(EfficientNetPreTrainedModel):
  376. def __init__(self, config: EfficientNetConfig):
  377. super().__init__(config)
  378. self.config = config
  379. self.embeddings = EfficientNetEmbeddings(config)
  380. self.encoder = EfficientNetEncoder(config)
  381. # Final pooling layer
  382. if config.pooling_type == "mean":
  383. self.pooler = nn.AvgPool2d(config.hidden_dim, ceil_mode=True)
  384. elif config.pooling_type == "max":
  385. self.pooler = nn.MaxPool2d(config.hidden_dim, ceil_mode=True)
  386. else:
  387. raise ValueError(f"config.pooling must be one of ['mean', 'max'] got {config.pooling}")
  388. # Initialize weights and apply final processing
  389. self.post_init()
  390. @auto_docstring
  391. def forward(
  392. self,
  393. pixel_values: Optional[torch.FloatTensor] = None,
  394. output_hidden_states: Optional[bool] = None,
  395. return_dict: Optional[bool] = None,
  396. ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
  397. output_hidden_states = (
  398. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  399. )
  400. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  401. if pixel_values is None:
  402. raise ValueError("You have to specify pixel_values")
  403. embedding_output = self.embeddings(pixel_values)
  404. encoder_outputs = self.encoder(
  405. embedding_output,
  406. output_hidden_states=output_hidden_states,
  407. return_dict=return_dict,
  408. )
  409. # Apply pooling
  410. last_hidden_state = encoder_outputs[0]
  411. pooled_output = self.pooler(last_hidden_state)
  412. # Reshape (batch_size, 1280, 1 , 1) -> (batch_size, 1280)
  413. pooled_output = pooled_output.reshape(pooled_output.shape[:2])
  414. if not return_dict:
  415. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  416. return BaseModelOutputWithPoolingAndNoAttention(
  417. last_hidden_state=last_hidden_state,
  418. pooler_output=pooled_output,
  419. hidden_states=encoder_outputs.hidden_states,
  420. )
  421. @auto_docstring(
  422. custom_intro="""
  423. EfficientNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g.
  424. for ImageNet.
  425. """
  426. )
  427. class EfficientNetForImageClassification(EfficientNetPreTrainedModel):
  428. def __init__(self, config):
  429. super().__init__(config)
  430. self.num_labels = config.num_labels
  431. self.config = config
  432. self.efficientnet = EfficientNetModel(config)
  433. # Classifier head
  434. self.dropout = nn.Dropout(p=config.dropout_rate)
  435. self.classifier = nn.Linear(config.hidden_dim, self.num_labels) if self.num_labels > 0 else nn.Identity()
  436. # Initialize weights and apply final processing
  437. self.post_init()
  438. @auto_docstring
  439. def forward(
  440. self,
  441. pixel_values: Optional[torch.FloatTensor] = None,
  442. labels: Optional[torch.LongTensor] = None,
  443. output_hidden_states: Optional[bool] = None,
  444. return_dict: Optional[bool] = None,
  445. ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
  446. r"""
  447. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  448. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  449. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  450. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  451. """
  452. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  453. outputs = self.efficientnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
  454. pooled_output = outputs.pooler_output if return_dict else outputs[1]
  455. pooled_output = self.dropout(pooled_output)
  456. logits = self.classifier(pooled_output)
  457. loss = None
  458. if labels is not None:
  459. loss = self.loss_function(labels, logits, self.config)
  460. if not return_dict:
  461. output = (logits,) + outputs[2:]
  462. return ((loss,) + output) if loss is not None else output
  463. return ImageClassifierOutputWithNoAttention(
  464. loss=loss,
  465. logits=logits,
  466. hidden_states=outputs.hidden_states,
  467. )
  468. __all__ = ["EfficientNetForImageClassification", "EfficientNetModel", "EfficientNetPreTrainedModel"]