modeling_dpt.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225
  1. # coding=utf-8
  2. # Copyright 2022 Intel Labs, OpenMMLab and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch DPT (Dense Prediction Transformers) model.
  16. This implementation is heavily inspired by OpenMMLab's implementation, found here:
  17. https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/dpt_head.py.
  18. """
  19. import collections.abc
  20. from dataclasses import dataclass
  21. from typing import Callable, Optional
  22. import torch
  23. from torch import nn
  24. from torch.nn import CrossEntropyLoss
  25. from ...activations import ACT2FN
  26. from ...modeling_layers import GradientCheckpointingLayer
  27. from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput
  28. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  29. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  30. from ...utils import ModelOutput, auto_docstring, logging, torch_int
  31. from ...utils.backbone_utils import load_backbone
  32. from ...utils.generic import can_return_tuple, check_model_inputs
  33. from .configuration_dpt import DPTConfig
  34. logger = logging.get_logger(__name__)
  35. @dataclass
  36. @auto_docstring(
  37. custom_intro="""
  38. Base class for model's outputs that also contains intermediate activations that can be used at later stages. Useful
  39. in the context of Vision models.:
  40. """
  41. )
  42. class BaseModelOutputWithIntermediateActivations(ModelOutput):
  43. r"""
  44. last_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  45. Sequence of hidden-states at the output of the last layer of the model.
  46. intermediate_activations (`tuple(torch.FloatTensor)`, *optional*):
  47. Intermediate activations that can be used to compute hidden states of the model at various layers.
  48. """
  49. last_hidden_states: Optional[torch.FloatTensor] = None
  50. intermediate_activations: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for model's outputs that also contains a pooling of the last hidden states as well as intermediate
    activations that can be used by the model at later stages.
    """
)
class BaseModelOutputWithPoolingAndIntermediateActivations(ModelOutput):
    r"""
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
        Last layer hidden-state of the first token of the sequence (classification token) after further processing
        through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
        the classification token after processing through a linear layer and a tanh activation function. The linear
        layer weights are trained from the next sentence prediction (classification) objective during pretraining.
    intermediate_activations (`tuple(torch.FloatTensor)`, *optional*):
        Intermediate activations that can be used to compute hidden states of the model at various layers.
    """

    # NOTE(review): `last_hidden_state`, `hidden_states` and `attentions` are not described in the docstring above;
    # presumably `auto_docstring` supplies the standard descriptions for these common names — verify.
    # NOTE(review): `ModelOutput` presumably derives its tuple/index order from the field order below, so keep the
    # fields in this order.
    last_hidden_state: Optional[torch.FloatTensor] = None
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    intermediate_activations: Optional[tuple[torch.FloatTensor, ...]] = None
  73. class DPTViTHybridEmbeddings(nn.Module):
  74. """
  75. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  76. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  77. Transformer.
  78. """
  79. def __init__(self, config: DPTConfig, feature_size: Optional[tuple[int, int]] = None):
  80. super().__init__()
  81. image_size, patch_size = config.image_size, config.patch_size
  82. num_channels, hidden_size = config.num_channels, config.hidden_size
  83. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  84. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  85. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  86. self.backbone = load_backbone(config)
  87. feature_dim = self.backbone.channels[-1]
  88. if len(self.backbone.channels) != 3:
  89. raise ValueError(f"Expected backbone to have 3 output features, got {len(self.backbone.channels)}")
  90. self.residual_feature_map_index = [0, 1] # Always take the output of the first and second backbone stage
  91. if feature_size is None:
  92. feat_map_shape = config.backbone_featmap_shape
  93. feature_size = feat_map_shape[-2:]
  94. feature_dim = feat_map_shape[1]
  95. else:
  96. feature_size = (
  97. feature_size if isinstance(feature_size, collections.abc.Iterable) else (feature_size, feature_size)
  98. )
  99. feature_dim = self.backbone.channels[-1]
  100. self.image_size = image_size
  101. self.patch_size = patch_size[0]
  102. self.num_channels = num_channels
  103. self.projection = nn.Conv2d(feature_dim, hidden_size, kernel_size=1)
  104. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  105. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
  106. def _resize_pos_embed(self, posemb, grid_size_height, grid_size_width, start_index=1):
  107. posemb_tok = posemb[:, :start_index]
  108. posemb_grid = posemb[0, start_index:]
  109. old_grid_size = torch_int(len(posemb_grid) ** 0.5)
  110. posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
  111. posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear")
  112. posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, grid_size_height * grid_size_width, -1)
  113. posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
  114. return posemb
  115. def forward(
  116. self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False
  117. ) -> BaseModelOutputWithIntermediateActivations:
  118. batch_size, num_channels, height, width = pixel_values.shape
  119. if num_channels != self.num_channels:
  120. raise ValueError(
  121. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  122. )
  123. if not interpolate_pos_encoding:
  124. if height != self.image_size[0] or width != self.image_size[1]:
  125. raise ValueError(
  126. f"Input image size ({height}*{width}) doesn't match model"
  127. f" ({self.image_size[0]}*{self.image_size[1]})."
  128. )
  129. position_embeddings = self._resize_pos_embed(
  130. self.position_embeddings, height // self.patch_size, width // self.patch_size
  131. )
  132. backbone_output = self.backbone(pixel_values)
  133. features = backbone_output.feature_maps[-1]
  134. # Retrieve also the intermediate activations to use them at later stages
  135. output_hidden_states = [backbone_output.feature_maps[index] for index in self.residual_feature_map_index]
  136. embeddings = self.projection(features).flatten(2).transpose(1, 2)
  137. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  138. embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  139. # add positional encoding to each token
  140. embeddings = embeddings + position_embeddings
  141. # Return hidden states and intermediate activations
  142. return BaseModelOutputWithIntermediateActivations(
  143. last_hidden_states=embeddings,
  144. intermediate_activations=output_hidden_states,
  145. )
  146. class DPTViTEmbeddings(nn.Module):
  147. """
  148. Construct the CLS token, position and patch embeddings.
  149. """
  150. def __init__(self, config):
  151. super().__init__()
  152. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  153. self.patch_embeddings = DPTViTPatchEmbeddings(config)
  154. num_patches = self.patch_embeddings.num_patches
  155. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
  156. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  157. self.config = config
  158. def _resize_pos_embed(self, posemb, grid_size_height, grid_size_width, start_index=1):
  159. posemb_tok = posemb[:, :start_index]
  160. posemb_grid = posemb[0, start_index:]
  161. old_grid_size = torch_int(posemb_grid.size(0) ** 0.5)
  162. posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
  163. posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear")
  164. posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, grid_size_height * grid_size_width, -1)
  165. posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
  166. return posemb
  167. def forward(self, pixel_values: torch.Tensor) -> BaseModelOutputWithIntermediateActivations:
  168. batch_size, num_channels, height, width = pixel_values.shape
  169. # possibly interpolate position encodings to handle varying image sizes
  170. patch_size = self.config.patch_size
  171. position_embeddings = self._resize_pos_embed(
  172. self.position_embeddings, height // patch_size, width // patch_size
  173. )
  174. embeddings = self.patch_embeddings(pixel_values)
  175. batch_size, seq_len, _ = embeddings.size()
  176. # add the [CLS] token to the embedded patch tokens
  177. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  178. embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  179. # add positional encoding to each token
  180. embeddings = embeddings + position_embeddings
  181. embeddings = self.dropout(embeddings)
  182. return BaseModelOutputWithIntermediateActivations(last_hidden_states=embeddings)
  183. class DPTViTPatchEmbeddings(nn.Module):
  184. """
  185. Image to Patch Embedding.
  186. """
  187. def __init__(self, config: DPTConfig):
  188. super().__init__()
  189. image_size, patch_size = config.image_size, config.patch_size
  190. num_channels, hidden_size = config.num_channels, config.hidden_size
  191. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  192. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  193. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  194. self.image_size = image_size
  195. self.patch_size = patch_size
  196. self.num_channels = num_channels
  197. self.num_patches = num_patches
  198. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  199. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  200. batch_size, num_channels, height, width = pixel_values.shape
  201. if num_channels != self.num_channels:
  202. raise ValueError(
  203. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  204. )
  205. embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
  206. return embeddings
  207. # Copied from transformers.models.vit.modeling_vit.eager_attention_forward
  208. def eager_attention_forward(
  209. module: nn.Module,
  210. query: torch.Tensor,
  211. key: torch.Tensor,
  212. value: torch.Tensor,
  213. attention_mask: Optional[torch.Tensor],
  214. scaling: float,
  215. dropout: float = 0.0,
  216. **kwargs,
  217. ):
  218. # Take the dot product between "query" and "key" to get the raw attention scores.
  219. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  220. # Normalize the attention scores to probabilities.
  221. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  222. # This is actually dropping out entire tokens to attend to, which might
  223. # seem a bit unusual, but is taken from the original Transformer paper.
  224. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  225. # Mask heads if we want to
  226. if attention_mask is not None:
  227. attn_weights = attn_weights * attention_mask
  228. attn_output = torch.matmul(attn_weights, value)
  229. attn_output = attn_output.transpose(1, 2).contiguous()
  230. return attn_output, attn_weights
  231. # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->DPT
  232. class DPTSelfAttention(nn.Module):
  233. def __init__(self, config: DPTConfig):
  234. super().__init__()
  235. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  236. raise ValueError(
  237. f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
  238. f"heads {config.num_attention_heads}."
  239. )
  240. self.config = config
  241. self.num_attention_heads = config.num_attention_heads
  242. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  243. self.all_head_size = self.num_attention_heads * self.attention_head_size
  244. self.dropout_prob = config.attention_probs_dropout_prob
  245. self.scaling = self.attention_head_size**-0.5
  246. self.is_causal = False
  247. self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  248. self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  249. self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  250. def forward(
  251. self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None
  252. ) -> tuple[torch.Tensor, torch.Tensor]:
  253. batch_size = hidden_states.shape[0]
  254. new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
  255. key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
  256. value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
  257. query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
  258. attention_interface: Callable = eager_attention_forward
  259. if self.config._attn_implementation != "eager":
  260. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  261. context_layer, attention_probs = attention_interface(
  262. self,
  263. query_layer,
  264. key_layer,
  265. value_layer,
  266. head_mask,
  267. is_causal=self.is_causal,
  268. scaling=self.scaling,
  269. dropout=0.0 if not self.training else self.dropout_prob,
  270. )
  271. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  272. context_layer = context_layer.reshape(new_context_layer_shape)
  273. return context_layer, attention_probs
  274. # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViTConfig->DPTConfig, ViTSelfOutput->DPTViTSelfOutput
  275. class DPTViTSelfOutput(nn.Module):
  276. """
  277. The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
  278. layernorm applied before each block.
  279. """
  280. def __init__(self, config: DPTConfig):
  281. super().__init__()
  282. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  283. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  284. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  285. hidden_states = self.dense(hidden_states)
  286. hidden_states = self.dropout(hidden_states)
  287. return hidden_states
  288. # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViTConfig->DPTConfig, ViTSelfAttention->DPTSelfAttention, ViTSelfOutput->DPTViTSelfOutput
  289. class DPTViTAttention(nn.Module):
  290. def __init__(self, config: DPTConfig):
  291. super().__init__()
  292. self.attention = DPTSelfAttention(config)
  293. self.output = DPTViTSelfOutput(config)
  294. self.pruned_heads = set()
  295. def prune_heads(self, heads: set[int]):
  296. if len(heads) == 0:
  297. return
  298. heads, index = find_pruneable_heads_and_indices(
  299. heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
  300. )
  301. # Prune linear layers
  302. self.attention.query = prune_linear_layer(self.attention.query, index)
  303. self.attention.key = prune_linear_layer(self.attention.key, index)
  304. self.attention.value = prune_linear_layer(self.attention.value, index)
  305. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  306. # Update hyper params and store pruned heads
  307. self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
  308. self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
  309. self.pruned_heads = self.pruned_heads.union(heads)
  310. def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  311. self_attn_output, _ = self.attention(hidden_states, head_mask)
  312. output = self.output(self_attn_output, hidden_states)
  313. return output
  314. # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViTConfig->DPTConfig, ViTIntermediate->DPTViTIntermediate
  315. class DPTViTIntermediate(nn.Module):
  316. def __init__(self, config: DPTConfig):
  317. super().__init__()
  318. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  319. if isinstance(config.hidden_act, str):
  320. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  321. else:
  322. self.intermediate_act_fn = config.hidden_act
  323. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  324. hidden_states = self.dense(hidden_states)
  325. hidden_states = self.intermediate_act_fn(hidden_states)
  326. return hidden_states
  327. # Copied from transformers.models.vit.modeling_vit.ViTOutput with ViTConfig->DPTConfig, ViTOutput->DPTViTOutput
  328. class DPTViTOutput(nn.Module):
  329. def __init__(self, config: DPTConfig):
  330. super().__init__()
  331. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  332. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  333. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  334. hidden_states = self.dense(hidden_states)
  335. hidden_states = self.dropout(hidden_states)
  336. hidden_states = hidden_states + input_tensor
  337. return hidden_states
  338. # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViTConfig->DPTConfig, ViTAttention->DPTViTAttention, ViTIntermediate->DPTViTIntermediate, ViTOutput->DPTViTOutput, ViTLayer->DPTViTLayer
  339. class DPTViTLayer(GradientCheckpointingLayer):
  340. """This corresponds to the Block class in the timm implementation."""
  341. def __init__(self, config: DPTConfig):
  342. super().__init__()
  343. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  344. self.seq_len_dim = 1
  345. self.attention = DPTViTAttention(config)
  346. self.intermediate = DPTViTIntermediate(config)
  347. self.output = DPTViTOutput(config)
  348. self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  349. self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  350. def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  351. hidden_states_norm = self.layernorm_before(hidden_states)
  352. attention_output = self.attention(hidden_states_norm, head_mask)
  353. # first residual connection
  354. hidden_states = attention_output + hidden_states
  355. # in ViT, layernorm is also applied after self-attention
  356. layer_output = self.layernorm_after(hidden_states)
  357. layer_output = self.intermediate(layer_output)
  358. # second residual connection is done here
  359. layer_output = self.output(layer_output, hidden_states)
  360. return layer_output
  361. # Copied from transformers.models.dinov2.modeling_dinov2.Dinov2Encoder with Dinov2Config->DPTConfig, Dinov2->DPTViT
  362. class DPTViTEncoder(nn.Module):
  363. def __init__(self, config: DPTConfig):
  364. super().__init__()
  365. self.config = config
  366. self.layer = nn.ModuleList([DPTViTLayer(config) for _ in range(config.num_hidden_layers)])
  367. self.gradient_checkpointing = False
  368. def forward(
  369. self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, output_hidden_states: bool = False
  370. ) -> BaseModelOutput:
  371. all_hidden_states = [hidden_states] if output_hidden_states else None
  372. for i, layer_module in enumerate(self.layer):
  373. layer_head_mask = head_mask[i] if head_mask is not None else None
  374. hidden_states = layer_module(hidden_states, layer_head_mask)
  375. if all_hidden_states:
  376. all_hidden_states.append(hidden_states)
  377. return BaseModelOutput(
  378. last_hidden_state=hidden_states,
  379. hidden_states=tuple(all_hidden_states) if all_hidden_states else None,
  380. )
  381. class DPTReassembleStage(nn.Module):
  382. """
  383. This class reassembles the hidden states of the backbone into image-like feature representations at various
  384. resolutions.
  385. This happens in 3 stages:
  386. 1. Map the N + 1 tokens to a set of N tokens, by taking into account the readout ([CLS]) token according to
  387. `config.readout_type`.
  388. 2. Project the channel dimension of the hidden states according to `config.neck_hidden_sizes`.
  389. 3. Resizing the spatial dimensions (height, width).
  390. Args:
  391. config (`[DPTConfig]`):
  392. Model configuration class defining the model architecture.
  393. """
  394. def __init__(self, config):
  395. super().__init__()
  396. self.config = config
  397. self.layers = nn.ModuleList()
  398. if config.is_hybrid:
  399. self._init_reassemble_dpt_hybrid(config)
  400. else:
  401. self._init_reassemble_dpt(config)
  402. self.neck_ignore_stages = config.neck_ignore_stages
  403. def _init_reassemble_dpt_hybrid(self, config):
  404. r""" "
  405. For DPT-Hybrid the first 2 reassemble layers are set to `nn.Identity()`, please check the official
  406. implementation: https://github.com/isl-org/DPT/blob/f43ef9e08d70a752195028a51be5e1aff227b913/dpt/vit.py#L438
  407. for more details.
  408. """
  409. for i, factor in zip(range(len(config.neck_hidden_sizes)), config.reassemble_factors):
  410. if i <= 1:
  411. self.layers.append(nn.Identity())
  412. elif i > 1:
  413. self.layers.append(DPTReassembleLayer(config, channels=config.neck_hidden_sizes[i], factor=factor))
  414. if config.readout_type != "project":
  415. raise ValueError(f"Readout type {config.readout_type} is not supported for DPT-Hybrid.")
  416. # When using DPT-Hybrid the readout type is set to "project". The sanity check is done on the config file
  417. self.readout_projects = nn.ModuleList()
  418. hidden_size = _get_backbone_hidden_size(config)
  419. for i in range(len(config.neck_hidden_sizes)):
  420. if i <= 1:
  421. self.readout_projects.append(nn.Sequential(nn.Identity()))
  422. elif i > 1:
  423. self.readout_projects.append(
  424. nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), ACT2FN[config.hidden_act])
  425. )
  426. def _init_reassemble_dpt(self, config):
  427. for i, factor in zip(range(len(config.neck_hidden_sizes)), config.reassemble_factors):
  428. self.layers.append(DPTReassembleLayer(config, channels=config.neck_hidden_sizes[i], factor=factor))
  429. if config.readout_type == "project":
  430. self.readout_projects = nn.ModuleList()
  431. hidden_size = _get_backbone_hidden_size(config)
  432. for _ in range(len(config.neck_hidden_sizes)):
  433. self.readout_projects.append(
  434. nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), ACT2FN[config.hidden_act])
  435. )
  436. def forward(self, hidden_states: list[torch.Tensor], patch_height=None, patch_width=None) -> list[torch.Tensor]:
  437. """
  438. Args:
  439. hidden_states (`list[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`):
  440. List of hidden states from the backbone.
  441. """
  442. out = []
  443. for i, hidden_state in enumerate(hidden_states):
  444. if i not in self.neck_ignore_stages:
  445. # reshape to (batch_size, num_channels, height, width)
  446. cls_token, hidden_state = hidden_state[:, 0], hidden_state[:, 1:]
  447. batch_size, sequence_length, num_channels = hidden_state.shape
  448. if patch_height is not None and patch_width is not None:
  449. hidden_state = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels)
  450. else:
  451. size = torch_int(sequence_length**0.5)
  452. hidden_state = hidden_state.reshape(batch_size, size, size, num_channels)
  453. hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
  454. feature_shape = hidden_state.shape
  455. if self.config.readout_type == "project":
  456. # reshape to (batch_size, height*width, num_channels)
  457. hidden_state = hidden_state.flatten(2).permute((0, 2, 1))
  458. readout = cls_token.unsqueeze(1).expand_as(hidden_state)
  459. # concatenate the readout token to the hidden states and project
  460. hidden_state = self.readout_projects[i](torch.cat((hidden_state, readout), -1))
  461. # reshape back to (batch_size, num_channels, height, width)
  462. hidden_state = hidden_state.permute(0, 2, 1).reshape(feature_shape)
  463. elif self.config.readout_type == "add":
  464. hidden_state = hidden_state.flatten(2) + cls_token.unsqueeze(-1)
  465. hidden_state = hidden_state.reshape(feature_shape)
  466. hidden_state = self.layers[i](hidden_state)
  467. out.append(hidden_state)
  468. return out
  469. def _get_backbone_hidden_size(config):
  470. if config.backbone_config is not None and config.is_hybrid is False:
  471. return config.backbone_config.hidden_size
  472. else:
  473. return config.hidden_size
  474. class DPTReassembleLayer(nn.Module):
  475. def __init__(self, config: DPTConfig, channels: int, factor: int):
  476. super().__init__()
  477. # projection
  478. hidden_size = _get_backbone_hidden_size(config)
  479. self.projection = nn.Conv2d(in_channels=hidden_size, out_channels=channels, kernel_size=1)
  480. # up/down sampling depending on factor
  481. if factor > 1:
  482. self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0)
  483. elif factor == 1:
  484. self.resize = nn.Identity()
  485. elif factor < 1:
  486. # so should downsample
  487. self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=int(1 / factor), padding=1)
  488. def forward(self, hidden_state):
  489. hidden_state = self.projection(hidden_state)
  490. hidden_state = self.resize(hidden_state)
  491. return hidden_state
  492. class DPTFeatureFusionStage(nn.Module):
  493. def __init__(self, config: DPTConfig):
  494. super().__init__()
  495. self.layers = nn.ModuleList()
  496. for _ in range(len(config.neck_hidden_sizes)):
  497. self.layers.append(DPTFeatureFusionLayer(config))
  498. def forward(self, hidden_states):
  499. # reversing the hidden_states, we start from the last
  500. hidden_states = hidden_states[::-1]
  501. fused_hidden_states = []
  502. fused_hidden_state = None
  503. for hidden_state, layer in zip(hidden_states, self.layers):
  504. if fused_hidden_state is None:
  505. # first layer only uses the last hidden_state
  506. fused_hidden_state = layer(hidden_state)
  507. else:
  508. fused_hidden_state = layer(fused_hidden_state, hidden_state)
  509. fused_hidden_states.append(fused_hidden_state)
  510. return fused_hidden_states
  511. class DPTPreActResidualLayer(nn.Module):
  512. """
  513. ResidualConvUnit, pre-activate residual unit.
  514. Args:
  515. config (`[DPTConfig]`):
  516. Model configuration class defining the model architecture.
  517. """
  518. def __init__(self, config: DPTConfig):
  519. super().__init__()
  520. self.use_batch_norm = config.use_batch_norm_in_fusion_residual
  521. use_bias_in_fusion_residual = (
  522. config.use_bias_in_fusion_residual
  523. if config.use_bias_in_fusion_residual is not None
  524. else not self.use_batch_norm
  525. )
  526. self.activation1 = nn.ReLU()
  527. self.convolution1 = nn.Conv2d(
  528. config.fusion_hidden_size,
  529. config.fusion_hidden_size,
  530. kernel_size=3,
  531. stride=1,
  532. padding=1,
  533. bias=use_bias_in_fusion_residual,
  534. )
  535. self.activation2 = nn.ReLU()
  536. self.convolution2 = nn.Conv2d(
  537. config.fusion_hidden_size,
  538. config.fusion_hidden_size,
  539. kernel_size=3,
  540. stride=1,
  541. padding=1,
  542. bias=use_bias_in_fusion_residual,
  543. )
  544. if self.use_batch_norm:
  545. self.batch_norm1 = nn.BatchNorm2d(config.fusion_hidden_size)
  546. self.batch_norm2 = nn.BatchNorm2d(config.fusion_hidden_size)
  547. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  548. residual = hidden_state
  549. hidden_state = self.activation1(hidden_state)
  550. hidden_state = self.convolution1(hidden_state)
  551. if self.use_batch_norm:
  552. hidden_state = self.batch_norm1(hidden_state)
  553. hidden_state = self.activation2(hidden_state)
  554. hidden_state = self.convolution2(hidden_state)
  555. if self.use_batch_norm:
  556. hidden_state = self.batch_norm2(hidden_state)
  557. return hidden_state + residual
  558. class DPTFeatureFusionLayer(nn.Module):
  559. """Feature fusion layer, merges feature maps from different stages.
  560. Args:
  561. config (`[DPTConfig]`):
  562. Model configuration class defining the model architecture.
  563. align_corners (`bool`, *optional*, defaults to `True`):
  564. The align_corner setting for bilinear upsample.
  565. """
  566. def __init__(self, config: DPTConfig, align_corners: bool = True):
  567. super().__init__()
  568. self.align_corners = align_corners
  569. self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)
  570. self.residual_layer1 = DPTPreActResidualLayer(config)
  571. self.residual_layer2 = DPTPreActResidualLayer(config)
  572. def forward(self, hidden_state: torch.Tensor, residual: Optional[torch.Tensor] = None) -> torch.Tensor:
  573. if residual is not None:
  574. if hidden_state.shape != residual.shape:
  575. residual = nn.functional.interpolate(
  576. residual, size=(hidden_state.shape[2], hidden_state.shape[3]), mode="bilinear", align_corners=False
  577. )
  578. hidden_state = hidden_state + self.residual_layer1(residual)
  579. hidden_state = self.residual_layer2(hidden_state)
  580. hidden_state = nn.functional.interpolate(
  581. hidden_state, scale_factor=2, mode="bilinear", align_corners=self.align_corners
  582. )
  583. hidden_state = self.projection(hidden_state)
  584. return hidden_state
  585. @auto_docstring
  586. class DPTPreTrainedModel(PreTrainedModel):
  587. config: DPTConfig
  588. base_model_prefix = "dpt"
  589. main_input_name = "pixel_values"
  590. supports_gradient_checkpointing = True
  591. _supports_sdpa = True
  592. _supports_flash_attn = True
  593. _supports_flex_attn = True
  594. _supports_attention_backend = True
  595. _can_record_outputs = {
  596. "attentions": DPTSelfAttention,
  597. }
  598. def _init_weights(self, module):
  599. """Initialize the weights"""
  600. if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
  601. # Slightly different from the TF version which uses truncated_normal for initialization
  602. # cf https://github.com/pytorch/pytorch/pull/5617
  603. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  604. if module.bias is not None:
  605. module.bias.data.zero_()
  606. elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
  607. module.bias.data.zero_()
  608. module.weight.data.fill_(1.0)
  609. if isinstance(module, (DPTViTEmbeddings, DPTViTHybridEmbeddings)):
  610. module.cls_token.data.zero_()
  611. module.position_embeddings.data.zero_()
  612. @auto_docstring
  613. class DPTModel(DPTPreTrainedModel):
  614. def __init__(self, config: DPTConfig, add_pooling_layer: bool = True):
  615. r"""
  616. add_pooling_layer (bool, *optional*, defaults to `True`):
  617. Whether to add a pooling layer
  618. """
  619. super().__init__(config)
  620. self.config = config
  621. # vit encoder
  622. if config.is_hybrid:
  623. self.embeddings = DPTViTHybridEmbeddings(config)
  624. else:
  625. self.embeddings = DPTViTEmbeddings(config)
  626. self.encoder = DPTViTEncoder(config)
  627. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  628. self.pooler = DPTViTPooler(config) if add_pooling_layer else None
  629. # Initialize weights and apply final processing
  630. self.post_init()
  631. def get_input_embeddings(self):
  632. if self.config.is_hybrid:
  633. return self.embeddings
  634. else:
  635. return self.embeddings.patch_embeddings
  636. def _prune_heads(self, heads_to_prune):
  637. """
  638. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  639. class PreTrainedModel
  640. """
  641. for layer, heads in heads_to_prune.items():
  642. self.encoder.layer[layer].attention.prune_heads(heads)
  643. @check_model_inputs(tie_last_hidden_states=False)
  644. @auto_docstring
  645. def forward(
  646. self,
  647. pixel_values: torch.FloatTensor,
  648. head_mask: Optional[torch.FloatTensor] = None,
  649. output_hidden_states: Optional[bool] = None,
  650. **kwargs,
  651. ) -> BaseModelOutputWithPoolingAndIntermediateActivations:
  652. if output_hidden_states is None:
  653. output_hidden_states = self.config.output_hidden_states
  654. # Prepare head mask if needed
  655. # 1.0 in head_mask indicate we keep the head
  656. # attention_probs has shape bsz x n_heads x N x N
  657. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  658. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  659. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  660. embedding_output: BaseModelOutputWithIntermediateActivations = self.embeddings(pixel_values)
  661. embedding_last_hidden_states = embedding_output.last_hidden_states
  662. encoder_outputs: BaseModelOutput = self.encoder(
  663. embedding_last_hidden_states, head_mask=head_mask, output_hidden_states=output_hidden_states
  664. )
  665. sequence_output = encoder_outputs.last_hidden_state
  666. sequence_output = self.layernorm(sequence_output)
  667. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  668. return BaseModelOutputWithPoolingAndIntermediateActivations(
  669. last_hidden_state=sequence_output,
  670. pooler_output=pooled_output,
  671. intermediate_activations=embedding_output.intermediate_activations,
  672. hidden_states=encoder_outputs.hidden_states,
  673. )
# Copied from transformers.models.vit.modeling_vit.ViTPooler with ViTConfig->DPTConfig, ViTPooler->DPTViTPooler
class DPTViTPooler(nn.Module):
    """Pool a token sequence into one vector: dense projection plus activation applied to the first token.

    Args:
        config (`[DPTConfig]`):
            Model configuration; `hidden_size` and `pooler_output_size` set the projection shape and `pooler_act`
            selects the activation from `ACT2FN`.
    """

    def __init__(self, config: DPTConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.pooler_output_size)
        self.activation = ACT2FN[config.pooler_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        # NOTE(review): assumes hidden_states is (batch_size, sequence_length, hidden_size) so that index 0 on
        # dim 1 is the first token — confirm against the encoder output.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
  687. class DPTNeck(nn.Module):
  688. """
  689. DPTNeck. A neck is a module that is normally used between the backbone and the head. It takes a list of tensors as
  690. input and produces another list of tensors as output. For DPT, it includes 2 stages:
  691. * DPTReassembleStage
  692. * DPTFeatureFusionStage.
  693. Args:
  694. config (dict): config dict.
  695. """
  696. def __init__(self, config: DPTConfig):
  697. super().__init__()
  698. self.config = config
  699. # postprocessing: only required in case of a non-hierarchical backbone (e.g. ViT, BEiT)
  700. if config.backbone_config is not None and config.backbone_config.model_type == "swinv2":
  701. self.reassemble_stage = None
  702. else:
  703. self.reassemble_stage = DPTReassembleStage(config)
  704. self.convs = nn.ModuleList()
  705. for channel in config.neck_hidden_sizes:
  706. self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False))
  707. # fusion
  708. self.fusion_stage = DPTFeatureFusionStage(config)
  709. def forward(
  710. self,
  711. hidden_states: list[torch.Tensor],
  712. patch_height: Optional[int] = None,
  713. patch_width: Optional[int] = None,
  714. ) -> list[torch.Tensor]:
  715. """
  716. Args:
  717. hidden_states (`list[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`):
  718. List of hidden states from the backbone.
  719. """
  720. if not isinstance(hidden_states, (tuple, list)):
  721. raise TypeError("hidden_states should be a tuple or list of tensors")
  722. if len(hidden_states) != len(self.config.neck_hidden_sizes):
  723. raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.")
  724. # postprocess hidden states
  725. if self.reassemble_stage is not None:
  726. hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width)
  727. features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)]
  728. # fusion blocks
  729. output = self.fusion_stage(features)
  730. return output
  731. class DPTDepthEstimationHead(nn.Module):
  732. """
  733. Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples
  734. the predictions to the input resolution after the first convolutional layer (details can be found in the paper's
  735. supplementary material).
  736. """
  737. def __init__(self, config: DPTConfig):
  738. super().__init__()
  739. self.config = config
  740. self.projection = None
  741. if config.add_projection:
  742. self.projection = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  743. features = config.fusion_hidden_size
  744. self.head = nn.Sequential(
  745. nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
  746. nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
  747. nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
  748. nn.ReLU(),
  749. nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
  750. nn.ReLU(),
  751. )
  752. def forward(self, hidden_states: list[torch.Tensor]) -> torch.Tensor:
  753. # use last features
  754. hidden_states = hidden_states[self.config.head_in_index]
  755. if self.projection is not None:
  756. hidden_states = self.projection(hidden_states)
  757. hidden_states = nn.ReLU()(hidden_states)
  758. predicted_depth = self.head(hidden_states)
  759. predicted_depth = predicted_depth.squeeze(dim=1)
  760. return predicted_depth
  761. @auto_docstring(
  762. custom_intro="""
  763. DPT Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2.
  764. """
  765. )
  766. class DPTForDepthEstimation(DPTPreTrainedModel):
  767. def __init__(self, config):
  768. super().__init__(config)
  769. self.backbone = None
  770. if config.is_hybrid is False and (config.backbone_config is not None or config.backbone is not None):
  771. self.backbone = load_backbone(config)
  772. else:
  773. self.dpt = DPTModel(config, add_pooling_layer=False)
  774. # Neck
  775. self.neck = DPTNeck(config)
  776. # Depth estimation head
  777. self.head = DPTDepthEstimationHead(config)
  778. # Initialize weights and apply final processing
  779. self.post_init()
  780. @can_return_tuple
  781. @auto_docstring
  782. def forward(
  783. self,
  784. pixel_values: torch.FloatTensor,
  785. head_mask: Optional[torch.FloatTensor] = None,
  786. labels: Optional[torch.LongTensor] = None,
  787. output_hidden_states: Optional[bool] = None,
  788. **kwargs,
  789. ) -> DepthEstimatorOutput:
  790. r"""
  791. labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
  792. Ground truth depth estimation maps for computing the loss.
  793. Examples:
  794. ```python
  795. >>> from transformers import AutoImageProcessor, DPTForDepthEstimation
  796. >>> import torch
  797. >>> import numpy as np
  798. >>> from PIL import Image
  799. >>> import requests
  800. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  801. >>> image = Image.open(requests.get(url, stream=True).raw)
  802. >>> image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large")
  803. >>> model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
  804. >>> # prepare image for the model
  805. >>> inputs = image_processor(images=image, return_tensors="pt")
  806. >>> with torch.no_grad():
  807. ... outputs = model(**inputs)
  808. >>> # interpolate to original size
  809. >>> post_processed_output = image_processor.post_process_depth_estimation(
  810. ... outputs,
  811. ... target_sizes=[(image.height, image.width)],
  812. ... )
  813. >>> # visualize the prediction
  814. >>> predicted_depth = post_processed_output[0]["predicted_depth"]
  815. >>> depth = predicted_depth * 255 / predicted_depth.max()
  816. >>> depth = depth.detach().cpu().numpy()
  817. >>> depth = Image.fromarray(depth.astype("uint8"))
  818. ```"""
  819. if output_hidden_states is None:
  820. output_hidden_states = self.config.output_hidden_states
  821. loss = None
  822. if labels is not None:
  823. raise NotImplementedError("Training is not implemented yet")
  824. if self.backbone is not None:
  825. outputs = self.backbone.forward_with_filtered_kwargs(pixel_values, output_hidden_states=True, **kwargs)
  826. hidden_states = outputs.feature_maps
  827. else:
  828. outputs = self.dpt(pixel_values, head_mask=head_mask, output_hidden_states=True, **kwargs)
  829. hidden_states = outputs.hidden_states
  830. # only keep certain features based on config.backbone_out_indices
  831. # note that the hidden_states also include the initial embeddings
  832. if not self.config.is_hybrid:
  833. hidden_states = [
  834. feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices
  835. ]
  836. else:
  837. backbone_hidden_states = outputs.intermediate_activations
  838. backbone_hidden_states.extend(
  839. feature
  840. for idx, feature in enumerate(hidden_states[1:])
  841. if idx in self.config.backbone_out_indices[2:]
  842. )
  843. hidden_states = backbone_hidden_states
  844. patch_height, patch_width = None, None
  845. if self.config.backbone_config is not None and self.config.is_hybrid is False:
  846. _, _, height, width = pixel_values.shape
  847. patch_size = self.config.backbone_config.patch_size
  848. patch_height = height // patch_size
  849. patch_width = width // patch_size
  850. hidden_states = self.neck(hidden_states, patch_height, patch_width)
  851. predicted_depth = self.head(hidden_states)
  852. return DepthEstimatorOutput(
  853. loss=loss,
  854. predicted_depth=predicted_depth,
  855. hidden_states=outputs.hidden_states if output_hidden_states else None,
  856. attentions=outputs.attentions,
  857. )
  858. class DPTSemanticSegmentationHead(nn.Module):
  859. def __init__(self, config: DPTConfig):
  860. super().__init__()
  861. self.config = config
  862. features = config.fusion_hidden_size
  863. self.head = nn.Sequential(
  864. nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
  865. nn.BatchNorm2d(features),
  866. nn.ReLU(),
  867. nn.Dropout(config.semantic_classifier_dropout),
  868. nn.Conv2d(features, config.num_labels, kernel_size=1),
  869. nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
  870. )
  871. def forward(self, hidden_states: list[torch.Tensor]) -> torch.Tensor:
  872. # use last features
  873. hidden_states = hidden_states[self.config.head_in_index]
  874. logits = self.head(hidden_states)
  875. return logits
  876. class DPTAuxiliaryHead(nn.Module):
  877. def __init__(self, config: DPTConfig):
  878. super().__init__()
  879. features = config.fusion_hidden_size
  880. self.head = nn.Sequential(
  881. nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
  882. nn.BatchNorm2d(features),
  883. nn.ReLU(),
  884. nn.Dropout(0.1, False),
  885. nn.Conv2d(features, config.num_labels, kernel_size=1),
  886. )
  887. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  888. logits = self.head(hidden_states)
  889. return logits
  890. @auto_docstring
  891. class DPTForSemanticSegmentation(DPTPreTrainedModel):
  892. def __init__(self, config: DPTConfig):
  893. super().__init__(config)
  894. self.dpt = DPTModel(config, add_pooling_layer=False)
  895. # Neck
  896. self.neck = DPTNeck(config)
  897. # Segmentation head(s)
  898. self.head = DPTSemanticSegmentationHead(config)
  899. self.auxiliary_head = DPTAuxiliaryHead(config) if config.use_auxiliary_head else None
  900. # Initialize weights and apply final processing
  901. self.post_init()
  902. @can_return_tuple
  903. @auto_docstring
  904. def forward(
  905. self,
  906. pixel_values: Optional[torch.FloatTensor] = None,
  907. head_mask: Optional[torch.FloatTensor] = None,
  908. labels: Optional[torch.LongTensor] = None,
  909. output_hidden_states: Optional[bool] = None,
  910. **kwargs,
  911. ) -> SemanticSegmenterOutput:
  912. r"""
  913. labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
  914. Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
  915. config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
  916. Examples:
  917. ```python
  918. >>> from transformers import AutoImageProcessor, DPTForSemanticSegmentation
  919. >>> from PIL import Image
  920. >>> import requests
  921. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  922. >>> image = Image.open(requests.get(url, stream=True).raw)
  923. >>> image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large-ade")
  924. >>> model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade")
  925. >>> inputs = image_processor(images=image, return_tensors="pt")
  926. >>> outputs = model(**inputs)
  927. >>> logits = outputs.logits
  928. ```"""
  929. if output_hidden_states is None:
  930. output_hidden_states = self.config.output_hidden_states
  931. if labels is not None and self.config.num_labels == 1:
  932. raise ValueError("The number of labels should be greater than one")
  933. outputs: BaseModelOutputWithPoolingAndIntermediateActivations = self.dpt(
  934. pixel_values, head_mask=head_mask, output_hidden_states=True, **kwargs
  935. )
  936. hidden_states = outputs.hidden_states
  937. # only keep certain features based on config.backbone_out_indices
  938. # note that the hidden_states also include the initial embeddings
  939. if not self.config.is_hybrid:
  940. hidden_states = [
  941. feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices
  942. ]
  943. else:
  944. backbone_hidden_states = outputs.intermediate_activations
  945. backbone_hidden_states.extend(
  946. feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices[2:]
  947. )
  948. hidden_states = backbone_hidden_states
  949. hidden_states = self.neck(hidden_states=hidden_states)
  950. logits = self.head(hidden_states)
  951. auxiliary_logits = None
  952. if self.auxiliary_head is not None:
  953. auxiliary_logits = self.auxiliary_head(hidden_states[-1])
  954. loss = None
  955. if labels is not None:
  956. # upsample logits to the images' original size
  957. upsampled_logits = nn.functional.interpolate(
  958. logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
  959. )
  960. if auxiliary_logits is not None:
  961. upsampled_auxiliary_logits = nn.functional.interpolate(
  962. auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
  963. )
  964. # compute weighted loss
  965. loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
  966. main_loss = loss_fct(upsampled_logits, labels)
  967. auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
  968. loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss
  969. return SemanticSegmenterOutput(
  970. loss=loss,
  971. logits=logits,
  972. hidden_states=outputs.hidden_states if output_hidden_states else None,
  973. attentions=outputs.attentions,
  974. )
# Public names exported by `from ... import *` on this module.
__all__ = ["DPTForDepthEstimation", "DPTForSemanticSegmentation", "DPTModel", "DPTPreTrainedModel"]