modeling_vitpose.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. # coding=utf-8
  2. # Copyright 2024 University of Sydney and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch VitPose model."""
  16. from dataclasses import dataclass
  17. from typing import Optional, Union
  18. import torch
  19. from torch import nn
  20. from ...modeling_outputs import BackboneOutput
  21. from ...modeling_utils import PreTrainedModel
  22. from ...processing_utils import Unpack
  23. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging
  24. from ...utils.backbone_utils import load_backbone
  25. from ...utils.generic import can_return_tuple
  26. from .configuration_vitpose import VitPoseConfig
# Module-level logger obtained from the transformers `logging` utility (`...utils.logging`).
logger = logging.get_logger(__name__)
  29. @dataclass
  30. @auto_docstring(
  31. custom_intro="""
  32. Class for outputs of pose estimation models.
  33. """
  34. )
  35. class VitPoseEstimatorOutput(ModelOutput):
  36. r"""
  37. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  38. Loss is not supported at this moment. See https://github.com/ViTAE-Transformer/ViTPose/tree/main/mmpose/models/losses for further detail.
  39. heatmaps (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`):
  40. Heatmaps as predicted by the model.
  41. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  42. Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
  43. one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
  44. (also called feature maps) of the model at the output of each stage.
  45. """
  46. loss: Optional[torch.FloatTensor] = None
  47. heatmaps: Optional[torch.FloatTensor] = None
  48. hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  49. attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  50. @auto_docstring
  51. class VitPosePreTrainedModel(PreTrainedModel):
  52. config: VitPoseConfig
  53. base_model_prefix = "vit"
  54. main_input_name = "pixel_values"
  55. supports_gradient_checkpointing = True
  56. def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]):
  57. """Initialize the weights"""
  58. if isinstance(module, (nn.Linear, nn.Conv2d)):
  59. # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
  60. # `trunc_normal_cpu` not implemented in `half` issues
  61. module.weight.data = nn.init.trunc_normal_(
  62. module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
  63. ).to(module.weight.dtype)
  64. if module.bias is not None:
  65. module.bias.data.zero_()
  66. elif isinstance(module, nn.LayerNorm):
  67. module.bias.data.zero_()
  68. module.weight.data.fill_(1.0)
  69. def flip_back(output_flipped, flip_pairs, target_type="gaussian-heatmap"):
  70. """Flip the flipped heatmaps back to the original form.
  71. Args:
  72. output_flipped (`torch.tensor` of shape `(batch_size, num_keypoints, height, width)`):
  73. The output heatmaps obtained from the flipped images.
  74. flip_pairs (`torch.Tensor` of shape `(num_keypoints, 2)`):
  75. Pairs of keypoints which are mirrored (for example, left ear -- right ear).
  76. target_type (`str`, *optional*, defaults to `"gaussian-heatmap"`):
  77. Target type to use. Can be gaussian-heatmap or combined-target.
  78. gaussian-heatmap: Classification target with gaussian distribution.
  79. combined-target: The combination of classification target (response map) and regression target (offset map).
  80. Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
  81. Returns:
  82. torch.Tensor: heatmaps that flipped back to the original image
  83. """
  84. if target_type not in ["gaussian-heatmap", "combined-target"]:
  85. raise ValueError("target_type should be gaussian-heatmap or combined-target")
  86. if output_flipped.ndim != 4:
  87. raise ValueError("output_flipped should be [batch_size, num_keypoints, height, width]")
  88. batch_size, num_keypoints, height, width = output_flipped.shape
  89. channels = 1
  90. if target_type == "combined-target":
  91. channels = 3
  92. output_flipped[:, 1::3, ...] = -output_flipped[:, 1::3, ...]
  93. output_flipped = output_flipped.reshape(batch_size, -1, channels, height, width)
  94. output_flipped_back = output_flipped.clone()
  95. # Swap left-right parts
  96. for left, right in flip_pairs.tolist():
  97. output_flipped_back[:, left, ...] = output_flipped[:, right, ...]
  98. output_flipped_back[:, right, ...] = output_flipped[:, left, ...]
  99. output_flipped_back = output_flipped_back.reshape((batch_size, num_keypoints, height, width))
  100. # Flip horizontally
  101. output_flipped_back = output_flipped_back.flip(-1)
  102. return output_flipped_back
  103. class VitPoseSimpleDecoder(nn.Module):
  104. """
  105. Simple decoding head consisting of a ReLU activation, 4x upsampling and a 3x3 convolution, turning the
  106. feature maps into heatmaps.
  107. """
  108. def __init__(self, config: VitPoseConfig):
  109. super().__init__()
  110. self.activation = nn.ReLU()
  111. self.upsampling = nn.Upsample(scale_factor=config.scale_factor, mode="bilinear", align_corners=False)
  112. self.conv = nn.Conv2d(
  113. config.backbone_config.hidden_size, config.num_labels, kernel_size=3, stride=1, padding=1
  114. )
  115. def forward(self, hidden_state: torch.Tensor, flip_pairs: Optional[torch.Tensor] = None) -> torch.Tensor:
  116. # Transform input: ReLU + upsample
  117. hidden_state = self.activation(hidden_state)
  118. hidden_state = self.upsampling(hidden_state)
  119. heatmaps = self.conv(hidden_state)
  120. if flip_pairs is not None:
  121. heatmaps = flip_back(heatmaps, flip_pairs)
  122. return heatmaps
  123. class VitPoseClassicDecoder(nn.Module):
  124. """
  125. Classic decoding head consisting of a 2 deconvolutional blocks, followed by a 1x1 convolution layer,
  126. turning the feature maps into heatmaps.
  127. """
  128. def __init__(self, config: VitPoseConfig):
  129. super().__init__()
  130. self.deconv1 = nn.ConvTranspose2d(
  131. config.backbone_config.hidden_size, 256, kernel_size=4, stride=2, padding=1, bias=False
  132. )
  133. self.batchnorm1 = nn.BatchNorm2d(256)
  134. self.relu1 = nn.ReLU()
  135. self.deconv2 = nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1, bias=False)
  136. self.batchnorm2 = nn.BatchNorm2d(256)
  137. self.relu2 = nn.ReLU()
  138. self.conv = nn.Conv2d(256, config.num_labels, kernel_size=1, stride=1, padding=0)
  139. def forward(self, hidden_state: torch.Tensor, flip_pairs: Optional[torch.Tensor] = None):
  140. hidden_state = self.deconv1(hidden_state)
  141. hidden_state = self.batchnorm1(hidden_state)
  142. hidden_state = self.relu1(hidden_state)
  143. hidden_state = self.deconv2(hidden_state)
  144. hidden_state = self.batchnorm2(hidden_state)
  145. hidden_state = self.relu2(hidden_state)
  146. heatmaps = self.conv(hidden_state)
  147. if flip_pairs is not None:
  148. heatmaps = flip_back(heatmaps, flip_pairs)
  149. return heatmaps
  150. @auto_docstring(
  151. custom_intro="""
  152. The VitPose model with a pose estimation head on top.
  153. """
  154. )
  155. class VitPoseForPoseEstimation(VitPosePreTrainedModel):
  156. def __init__(self, config: VitPoseConfig):
  157. super().__init__(config)
  158. self.backbone = load_backbone(config)
  159. # add backbone attributes
  160. if not hasattr(self.backbone.config, "hidden_size"):
  161. raise ValueError("The backbone should have a hidden_size attribute")
  162. if not hasattr(self.backbone.config, "image_size"):
  163. raise ValueError("The backbone should have an image_size attribute")
  164. if not hasattr(self.backbone.config, "patch_size"):
  165. raise ValueError("The backbone should have a patch_size attribute")
  166. self.head = VitPoseSimpleDecoder(config) if config.use_simple_decoder else VitPoseClassicDecoder(config)
  167. # Initialize weights and apply final processing
  168. self.post_init()
  169. @can_return_tuple
  170. @auto_docstring
  171. def forward(
  172. self,
  173. pixel_values: torch.Tensor,
  174. dataset_index: Optional[torch.Tensor] = None,
  175. flip_pairs: Optional[torch.Tensor] = None,
  176. labels: Optional[torch.Tensor] = None,
  177. **kwargs: Unpack[TransformersKwargs],
  178. ) -> VitPoseEstimatorOutput:
  179. r"""
  180. dataset_index (`torch.Tensor` of shape `(batch_size,)`):
  181. Index to use in the Mixture-of-Experts (MoE) blocks of the backbone.
  182. This corresponds to the dataset index used during training, e.g. For the single dataset index 0 refers to the corresponding dataset. For the multiple datasets index 0 refers to dataset A (e.g. MPII) and index 1 refers to dataset B (e.g. CrowdPose).
  183. flip_pairs (`torch.tensor`, *optional*):
  184. Whether to mirror pairs of keypoints (for example, left ear -- right ear).
  185. Examples:
  186. ```python
  187. >>> from transformers import AutoImageProcessor, VitPoseForPoseEstimation
  188. >>> import torch
  189. >>> from PIL import Image
  190. >>> import requests
  191. >>> processor = AutoImageProcessor.from_pretrained("usyd-community/vitpose-base-simple")
  192. >>> model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple")
  193. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  194. >>> image = Image.open(requests.get(url, stream=True).raw)
  195. >>> boxes = [[[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]]]
  196. >>> inputs = processor(image, boxes=boxes, return_tensors="pt")
  197. >>> with torch.no_grad():
  198. ... outputs = model(**inputs)
  199. >>> heatmaps = outputs.heatmaps
  200. ```"""
  201. loss = None
  202. if labels is not None:
  203. raise NotImplementedError("Training is not yet supported")
  204. outputs: BackboneOutput = self.backbone.forward_with_filtered_kwargs(
  205. pixel_values,
  206. dataset_index=dataset_index,
  207. **kwargs,
  208. )
  209. # Turn output hidden states in tensor of shape (batch_size, num_channels, height, width)
  210. sequence_output = outputs.feature_maps[-1]
  211. batch_size = sequence_output.shape[0]
  212. patch_height = self.config.backbone_config.image_size[0] // self.config.backbone_config.patch_size[0]
  213. patch_width = self.config.backbone_config.image_size[1] // self.config.backbone_config.patch_size[1]
  214. sequence_output = sequence_output.permute(0, 2, 1)
  215. sequence_output = sequence_output.reshape(batch_size, -1, patch_height, patch_width).contiguous()
  216. heatmaps = self.head(sequence_output, flip_pairs=flip_pairs)
  217. return VitPoseEstimatorOutput(
  218. loss=loss,
  219. heatmaps=heatmaps,
  220. hidden_states=outputs.hidden_states,
  221. attentions=outputs.attentions,
  222. )
  223. __all__ = ["VitPosePreTrainedModel", "VitPoseForPoseEstimation"]