modeling_textnet.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. # coding=utf-8
  2. # Copyright 2024 the Fast authors and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch TextNet model."""
  16. from typing import Any, Optional, Union
  17. import torch
  18. import torch.nn as nn
  19. from torch import Tensor
  20. from transformers import PreTrainedModel
  21. from transformers.activations import ACT2CLS
  22. from transformers.modeling_outputs import (
  23. BackboneOutput,
  24. BaseModelOutputWithNoAttention,
  25. BaseModelOutputWithPoolingAndNoAttention,
  26. ImageClassifierOutputWithNoAttention,
  27. )
  28. from transformers.models.textnet.configuration_textnet import TextNetConfig
  29. from transformers.utils import logging
  30. from transformers.utils.backbone_utils import BackboneMixin
  31. from ...utils import auto_docstring
  32. logger = logging.get_logger(__name__)
  33. class TextNetConvLayer(nn.Module):
  34. def __init__(self, config: TextNetConfig):
  35. super().__init__()
  36. self.kernel_size = config.stem_kernel_size
  37. self.stride = config.stem_stride
  38. self.activation_function = config.stem_act_func
  39. padding = (
  40. (config.kernel_size[0] // 2, config.kernel_size[1] // 2)
  41. if isinstance(config.stem_kernel_size, tuple)
  42. else config.stem_kernel_size // 2
  43. )
  44. self.conv = nn.Conv2d(
  45. config.stem_num_channels,
  46. config.stem_out_channels,
  47. kernel_size=config.stem_kernel_size,
  48. stride=config.stem_stride,
  49. padding=padding,
  50. bias=False,
  51. )
  52. self.batch_norm = nn.BatchNorm2d(config.stem_out_channels, config.batch_norm_eps)
  53. self.activation = nn.Identity()
  54. if self.activation_function is not None:
  55. self.activation = ACT2CLS[self.activation_function]()
  56. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  57. hidden_states = self.conv(hidden_states)
  58. hidden_states = self.batch_norm(hidden_states)
  59. return self.activation(hidden_states)
  60. class TextNetRepConvLayer(nn.Module):
  61. r"""
  62. This layer supports re-parameterization by combining multiple convolutional branches
  63. (e.g., main convolution, vertical, horizontal, and identity branches) during training.
  64. At inference time, these branches can be collapsed into a single convolution for
  65. efficiency, as per the re-parameterization paradigm.
  66. The "Rep" in the name stands for "re-parameterization" (introduced by RepVGG).
  67. """
  68. def __init__(self, config: TextNetConfig, in_channels: int, out_channels: int, kernel_size: int, stride: int):
  69. super().__init__()
  70. self.num_channels = in_channels
  71. self.out_channels = out_channels
  72. self.kernel_size = kernel_size
  73. self.stride = stride
  74. padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
  75. self.activation_function = nn.ReLU()
  76. self.main_conv = nn.Conv2d(
  77. in_channels=in_channels,
  78. out_channels=out_channels,
  79. kernel_size=kernel_size,
  80. stride=stride,
  81. padding=padding,
  82. bias=False,
  83. )
  84. self.main_batch_norm = nn.BatchNorm2d(num_features=out_channels, eps=config.batch_norm_eps)
  85. vertical_padding = ((kernel_size[0] - 1) // 2, 0)
  86. horizontal_padding = (0, (kernel_size[1] - 1) // 2)
  87. if kernel_size[1] != 1:
  88. self.vertical_conv = nn.Conv2d(
  89. in_channels=in_channels,
  90. out_channels=out_channels,
  91. kernel_size=(kernel_size[0], 1),
  92. stride=stride,
  93. padding=vertical_padding,
  94. bias=False,
  95. )
  96. self.vertical_batch_norm = nn.BatchNorm2d(num_features=out_channels, eps=config.batch_norm_eps)
  97. else:
  98. self.vertical_conv, self.vertical_batch_norm = None, None
  99. if kernel_size[0] != 1:
  100. self.horizontal_conv = nn.Conv2d(
  101. in_channels=in_channels,
  102. out_channels=out_channels,
  103. kernel_size=(1, kernel_size[1]),
  104. stride=stride,
  105. padding=horizontal_padding,
  106. bias=False,
  107. )
  108. self.horizontal_batch_norm = nn.BatchNorm2d(num_features=out_channels, eps=config.batch_norm_eps)
  109. else:
  110. self.horizontal_conv, self.horizontal_batch_norm = None, None
  111. self.rbr_identity = (
  112. nn.BatchNorm2d(num_features=in_channels, eps=config.batch_norm_eps)
  113. if out_channels == in_channels and stride == 1
  114. else None
  115. )
  116. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  117. main_outputs = self.main_conv(hidden_states)
  118. main_outputs = self.main_batch_norm(main_outputs)
  119. # applies a convolution with a vertical kernel
  120. if self.vertical_conv is not None:
  121. vertical_outputs = self.vertical_conv(hidden_states)
  122. vertical_outputs = self.vertical_batch_norm(vertical_outputs)
  123. main_outputs = main_outputs + vertical_outputs
  124. # applies a convolution with a horizontal kernel
  125. if self.horizontal_conv is not None:
  126. horizontal_outputs = self.horizontal_conv(hidden_states)
  127. horizontal_outputs = self.horizontal_batch_norm(horizontal_outputs)
  128. main_outputs = main_outputs + horizontal_outputs
  129. if self.rbr_identity is not None:
  130. id_out = self.rbr_identity(hidden_states)
  131. main_outputs = main_outputs + id_out
  132. return self.activation_function(main_outputs)
  133. class TextNetStage(nn.Module):
  134. def __init__(self, config: TextNetConfig, depth: int):
  135. super().__init__()
  136. kernel_size = config.conv_layer_kernel_sizes[depth]
  137. stride = config.conv_layer_strides[depth]
  138. num_layers = len(kernel_size)
  139. stage_in_channel_size = config.hidden_sizes[depth]
  140. stage_out_channel_size = config.hidden_sizes[depth + 1]
  141. in_channels = [stage_in_channel_size] + [stage_out_channel_size] * (num_layers - 1)
  142. out_channels = [stage_out_channel_size] * num_layers
  143. stage = []
  144. for stage_config in zip(in_channels, out_channels, kernel_size, stride):
  145. stage.append(TextNetRepConvLayer(config, *stage_config))
  146. self.stage = nn.ModuleList(stage)
  147. def forward(self, hidden_state):
  148. for block in self.stage:
  149. hidden_state = block(hidden_state)
  150. return hidden_state
  151. class TextNetEncoder(nn.Module):
  152. def __init__(self, config: TextNetConfig):
  153. super().__init__()
  154. stages = []
  155. num_stages = len(config.conv_layer_kernel_sizes)
  156. for stage_ix in range(num_stages):
  157. stages.append(TextNetStage(config, stage_ix))
  158. self.stages = nn.ModuleList(stages)
  159. def forward(
  160. self,
  161. hidden_state: torch.Tensor,
  162. output_hidden_states: Optional[bool] = None,
  163. return_dict: Optional[bool] = None,
  164. ) -> BaseModelOutputWithNoAttention:
  165. hidden_states = [hidden_state]
  166. for stage in self.stages:
  167. hidden_state = stage(hidden_state)
  168. hidden_states.append(hidden_state)
  169. if not return_dict:
  170. output = (hidden_state,)
  171. return output + (hidden_states,) if output_hidden_states else output
  172. return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states)
  173. @auto_docstring
  174. class TextNetPreTrainedModel(PreTrainedModel):
  175. config: TextNetConfig
  176. base_model_prefix = "textnet"
  177. main_input_name = "pixel_values"
  178. def _init_weights(self, module):
  179. if isinstance(module, (nn.Linear, nn.Conv2d)):
  180. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  181. if module.bias is not None:
  182. module.bias.data.zero_()
  183. elif isinstance(module, nn.BatchNorm2d):
  184. module.weight.data.fill_(1.0)
  185. if module.bias is not None:
  186. module.bias.data.zero_()
  187. @auto_docstring
  188. class TextNetModel(TextNetPreTrainedModel):
  189. def __init__(self, config):
  190. super().__init__(config)
  191. self.stem = TextNetConvLayer(config)
  192. self.encoder = TextNetEncoder(config)
  193. self.pooler = nn.AdaptiveAvgPool2d((2, 2))
  194. self.post_init()
  195. @auto_docstring
  196. def forward(
  197. self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
  198. ) -> Union[tuple[Any, list[Any]], tuple[Any], BaseModelOutputWithPoolingAndNoAttention]:
  199. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  200. output_hidden_states = (
  201. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  202. )
  203. hidden_state = self.stem(pixel_values)
  204. encoder_outputs = self.encoder(
  205. hidden_state, output_hidden_states=output_hidden_states, return_dict=return_dict
  206. )
  207. last_hidden_state = encoder_outputs[0]
  208. pooled_output = self.pooler(last_hidden_state)
  209. if not return_dict:
  210. output = (last_hidden_state, pooled_output)
  211. return output + (encoder_outputs[1],) if output_hidden_states else output
  212. return BaseModelOutputWithPoolingAndNoAttention(
  213. last_hidden_state=last_hidden_state,
  214. pooler_output=pooled_output,
  215. hidden_states=encoder_outputs[1] if output_hidden_states else None,
  216. )
  217. @auto_docstring(
  218. custom_intro="""
  219. TextNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
  220. ImageNet.
  221. """
  222. )
  223. class TextNetForImageClassification(TextNetPreTrainedModel):
  224. def __init__(self, config):
  225. super().__init__(config)
  226. self.num_labels = config.num_labels
  227. self.textnet = TextNetModel(config)
  228. self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
  229. self.flatten = nn.Flatten()
  230. self.fc = nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
  231. # classification head
  232. self.classifier = nn.ModuleList([self.avg_pool, self.flatten])
  233. # initialize weights and apply final processing
  234. self.post_init()
  235. @auto_docstring
  236. def forward(
  237. self,
  238. pixel_values: Optional[torch.FloatTensor] = None,
  239. labels: Optional[torch.LongTensor] = None,
  240. output_hidden_states: Optional[bool] = None,
  241. return_dict: Optional[bool] = None,
  242. ) -> ImageClassifierOutputWithNoAttention:
  243. r"""
  244. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  245. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  246. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  247. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  248. Examples:
  249. ```python
  250. >>> import torch
  251. >>> import requests
  252. >>> from transformers import TextNetForImageClassification, TextNetImageProcessor
  253. >>> from PIL import Image
  254. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  255. >>> image = Image.open(requests.get(url, stream=True).raw)
  256. >>> processor = TextNetImageProcessor.from_pretrained("czczup/textnet-base")
  257. >>> model = TextNetForImageClassification.from_pretrained("czczup/textnet-base")
  258. >>> inputs = processor(images=image, return_tensors="pt")
  259. >>> with torch.no_grad():
  260. ... outputs = model(**inputs)
  261. >>> outputs.logits.shape
  262. torch.Size([1, 2])
  263. ```"""
  264. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  265. outputs = self.textnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
  266. last_hidden_state = outputs[0]
  267. for layer in self.classifier:
  268. last_hidden_state = layer(last_hidden_state)
  269. logits = self.fc(last_hidden_state)
  270. loss = None
  271. if labels is not None:
  272. loss = self.loss_function(labels, logits, self.config)
  273. if not return_dict:
  274. output = (logits,) + outputs[2:]
  275. return (loss,) + output if loss is not None else output
  276. return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
  277. @auto_docstring(
  278. custom_intro="""
  279. TextNet backbone, to be used with frameworks like DETR and MaskFormer.
  280. """
  281. )
  282. class TextNetBackbone(TextNetPreTrainedModel, BackboneMixin):
  283. has_attentions = False
  284. def __init__(self, config):
  285. super().__init__(config)
  286. super()._init_backbone(config)
  287. self.textnet = TextNetModel(config)
  288. self.num_features = config.hidden_sizes
  289. # initialize weights and apply final processing
  290. self.post_init()
  291. @auto_docstring
  292. def forward(
  293. self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
  294. ) -> Union[tuple[tuple], BackboneOutput]:
  295. r"""
  296. Examples:
  297. ```python
  298. >>> import torch
  299. >>> import requests
  300. >>> from PIL import Image
  301. >>> from transformers import AutoImageProcessor, AutoBackbone
  302. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  303. >>> image = Image.open(requests.get(url, stream=True).raw)
  304. >>> processor = AutoImageProcessor.from_pretrained("czczup/textnet-base")
  305. >>> model = AutoBackbone.from_pretrained("czczup/textnet-base")
  306. >>> inputs = processor(image, return_tensors="pt")
  307. >>> with torch.no_grad():
  308. >>> outputs = model(**inputs)
  309. ```"""
  310. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  311. output_hidden_states = (
  312. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  313. )
  314. outputs = self.textnet(pixel_values, output_hidden_states=True, return_dict=return_dict)
  315. hidden_states = outputs.hidden_states if return_dict else outputs[2]
  316. feature_maps = ()
  317. for idx, stage in enumerate(self.stage_names):
  318. if stage in self.out_features:
  319. feature_maps += (hidden_states[idx],)
  320. if not return_dict:
  321. output = (feature_maps,)
  322. if output_hidden_states:
  323. hidden_states = outputs.hidden_states if return_dict else outputs[2]
  324. output += (hidden_states,)
  325. return output
  326. return BackboneOutput(
  327. feature_maps=feature_maps,
  328. hidden_states=outputs.hidden_states if output_hidden_states else None,
  329. attentions=None,
  330. )
  331. __all__ = ["TextNetBackbone", "TextNetModel", "TextNetPreTrainedModel", "TextNetForImageClassification"]