modeling_glpn.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724
  1. # coding=utf-8
  2. # Copyright 2022 KAIST and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch GLPN model."""
  16. import math
  17. from typing import Optional, Union
  18. import torch
  19. from torch import nn
  20. from ...activations import ACT2FN
  21. from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput
  22. from ...modeling_utils import PreTrainedModel
  23. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  24. from ...utils import auto_docstring, logging
  25. from .configuration_glpn import GLPNConfig
  26. logger = logging.get_logger(__name__)
  27. # Copied from transformers.models.beit.modeling_beit.drop_path
  def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
      """
      Apply stochastic depth ("drop path") to `input`, with one keep/drop decision per sample.

      When `training` is False or `drop_prob` is 0.0 the input tensor is returned as-is (the very same object).
      Otherwise a random 0/1 mask of shape `(input.shape[0], 1, ..., 1)` is drawn, the input is divided by the keep
      probability `1 - drop_prob` and multiplied by that mask.

      Note kept from the original (Ross Wightman): this is the same as the "DropConnect" implementation used for
      EfficientNet-style networks, but that name refers to a different technique from a separate paper, hence the
      "drop path" naming here. See https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956.

      Args:
          input: Tensor whose first dimension indexes the samples.
          drop_prob: Probability of dropping a sample.
          training: Whether the caller is in training mode.

      Returns:
          `input` itself when nothing is dropped, otherwise a new tensor of the same shape.
      """
      if not training or drop_prob == 0.0:
          return input

      survival_prob = 1 - drop_prob
      # One mask entry per sample, broadcast over every remaining dimension (works for any rank, not just 4D).
      mask_shape = (input.shape[0],) + (1,) * (input.ndim - 1)
      mask = survival_prob + torch.rand(mask_shape, dtype=input.dtype, device=input.device)
      mask.floor_()  # values in [survival_prob, 1 + survival_prob) become exactly 0 or 1
      return input.div(survival_prob) * mask
  45. # Copied from transformers.models.segformer.modeling_segformer.SegformerDropPath
  class GLPNDropPath(nn.Module):
      """Module wrapper around `drop_path`: stochastic depth per sample, for the main path of residual blocks."""

      def __init__(self, drop_prob: Optional[float] = None) -> None:
          super().__init__()
          # Stored as given; it is forwarded unchanged to `drop_path` on every call.
          self.drop_prob = drop_prob

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          """Run `drop_path` with this module's drop probability and current training flag."""
          return drop_path(hidden_states, drop_prob=self.drop_prob, training=self.training)

      def extra_repr(self) -> str:
          """Show the drop probability in the module's printed representation."""
          return f"p={self.drop_prob}"
  55. # Copied from transformers.models.segformer.modeling_segformer.SegformerOverlapPatchEmbeddings
  class GLPNOverlapPatchEmbeddings(nn.Module):
      """Turn an image-like tensor into a normalized token sequence using a strided convolution.

      The convolution pads by `patch_size // 2`, so neighbouring patches overlap whenever `stride < patch_size`.
      """

      def __init__(self, patch_size, stride, num_channels, hidden_size):
          super().__init__()
          self.proj = nn.Conv2d(
              num_channels, hidden_size, kernel_size=patch_size, stride=stride, padding=patch_size // 2
          )
          self.layer_norm = nn.LayerNorm(hidden_size)

      def forward(self, pixel_values):
          """Project `pixel_values` and flatten the spatial grid.

          Returns:
              A tuple `(embeddings, height, width)` where `embeddings` has shape
              `(batch_size, height * width, hidden_size)` and `height`/`width` are the spatial sizes of the
              projected feature map.
          """
          feature_map = self.proj(pixel_values)
          height, width = feature_map.shape[-2:]
          # (batch, channels, height, width) -> (batch, channels, height*width) -> (batch, height*width, channels),
          # i.e. a token sequence that a Transformer layer can consume.
          tokens = feature_map.flatten(2).transpose(1, 2)
          return self.layer_norm(tokens), height, width
  76. # Copied from transformers.models.segformer.modeling_segformer.SegformerEfficientSelfAttention
  class GLPNEfficientSelfAttention(nn.Module):
      """SegFormer's efficient self-attention mechanism. Employs the sequence reduction process introduced in the [PvT
      paper](https://huggingface.co/papers/2102.12122).

      Queries are computed from the full token sequence. When `sequence_reduction_ratio > 1`, keys and values are
      computed from a sequence that has first been shortened by a strided convolution followed by a layer norm.
      """

      def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
          super().__init__()
          self.hidden_size = hidden_size
          self.num_attention_heads = num_attention_heads

          if self.hidden_size % self.num_attention_heads != 0:
              raise ValueError(
                  f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
                  f"heads ({self.num_attention_heads})"
              )

          self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
          self.all_head_size = self.num_attention_heads * self.attention_head_size

          self.query = nn.Linear(self.hidden_size, self.all_head_size)
          self.key = nn.Linear(self.hidden_size, self.all_head_size)
          self.value = nn.Linear(self.hidden_size, self.all_head_size)
          self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

          self.sr_ratio = sequence_reduction_ratio
          if sequence_reduction_ratio > 1:
              # Non-overlapping strided convolution that shrinks the key/value sequence.
              self.sr = nn.Conv2d(
                  hidden_size, hidden_size, kernel_size=sequence_reduction_ratio, stride=sequence_reduction_ratio
              )
              self.layer_norm = nn.LayerNorm(hidden_size)

      def _split_heads(self, projected, batch_size):
          # (batch, seq, all_head_size) -> (batch, heads, seq, head_size)
          return projected.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)

      def _reduce_sequence(self, hidden_states, height, width):
          # Fold tokens back into a (batch, channels, height, width) map, shrink it with the strided convolution,
          # then unfold to (batch, reduced_seq, channels) and normalize.
          batch_size, _, num_channels = hidden_states.shape
          feature_map = hidden_states.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
          feature_map = self.sr(feature_map)
          reduced = feature_map.reshape(batch_size, num_channels, -1).permute(0, 2, 1)
          return self.layer_norm(reduced)

      def forward(
          self,
          hidden_states,
          height,
          width,
          output_attentions=False,
      ):
          """Compute multi-head attention over `hidden_states` of shape `(batch_size, seq_length, channels)`.

          `height` and `width` describe the spatial grid the sequence was flattened from; they are only used when
          sequence reduction is active.

          Returns:
              `(context_layer,)`, or `(context_layer, attention_probs)` when `output_attentions` is truthy. The
              returned probabilities are the ones obtained after dropout.
          """
          batch_size = hidden_states.shape[0]
          query_layer = self._split_heads(self.query(hidden_states), batch_size)

          if self.sr_ratio > 1:
              hidden_states = self._reduce_sequence(hidden_states, height, width)

          key_layer = self._split_heads(self.key(hidden_states), batch_size)
          value_layer = self._split_heads(self.value(hidden_states), batch_size)

          # Scaled dot-product between queries and keys gives the raw attention scores.
          attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
          attention_scores = attention_scores / math.sqrt(self.attention_head_size)

          # Normalize the scores to probabilities, then apply dropout. This drops entire tokens to attend to,
          # which might seem unusual but is taken from the original Transformer paper.
          attention_probs = self.dropout(nn.functional.softmax(attention_scores, dim=-1))

          # Weighted sum of values, then merge the heads back into a single channel dimension.
          context_layer = torch.matmul(attention_probs, value_layer).permute(0, 2, 1, 3).contiguous()
          merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
          context_layer = context_layer.view(merged_shape)

          if output_attentions:
              return (context_layer, attention_probs)
          return (context_layer,)
  147. # Copied from transformers.models.segformer.modeling_segformer.SegformerSelfOutput
  class GLPNSelfOutput(nn.Module):
      """Output projection applied to the attention result: a linear layer followed by dropout."""

      def __init__(self, config, hidden_size):
          super().__init__()
          self.dense = nn.Linear(hidden_size, hidden_size)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)

      def forward(self, hidden_states, input_tensor):
          """Project `hidden_states` and apply dropout.

          `input_tensor` is accepted but not used here; no residual is added inside this module.
          """
          projected = self.dense(hidden_states)
          return self.dropout(projected)
  157. # Copied from transformers.models.segformer.modeling_segformer.SegformerAttention with Segformer->GLPN
  class GLPNAttention(nn.Module):
      """Self-attention followed by its output projection, with support for pruning attention heads."""

      def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
          super().__init__()
          self.self = GLPNEfficientSelfAttention(
              config=config,
              hidden_size=hidden_size,
              num_attention_heads=num_attention_heads,
              sequence_reduction_ratio=sequence_reduction_ratio,
          )
          self.output = GLPNSelfOutput(config, hidden_size=hidden_size)
          # Heads removed so far; passed to the pruning helper on later calls.
          self.pruned_heads = set()

      def prune_heads(self, heads):
          """Remove the given attention heads from the query/key/value and output projections."""
          if len(heads) == 0:
              return

          attention = self.self
          heads, index = find_pruneable_heads_and_indices(
              heads, attention.num_attention_heads, attention.attention_head_size, self.pruned_heads
          )

          # Prune linear layers: outputs of query/key/value, inputs (dim=1) of the output projection.
          attention.query = prune_linear_layer(attention.query, index)
          attention.key = prune_linear_layer(attention.key, index)
          attention.value = prune_linear_layer(attention.value, index)
          self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

          # Update hyper params and store pruned heads
          attention.num_attention_heads = attention.num_attention_heads - len(heads)
          attention.all_head_size = attention.attention_head_size * attention.num_attention_heads
          self.pruned_heads = self.pruned_heads.union(heads)

      def forward(self, hidden_states, height, width, output_attentions=False):
          """Return `(attention_output,)`, plus the attention probabilities when they were requested."""
          attention_results = self.self(hidden_states, height, width, output_attentions)
          projected = self.output(attention_results[0], hidden_states)
          # Anything after the first element is the optional attention map.
          return (projected,) + attention_results[1:]
  189. # Copied from transformers.models.segformer.modeling_segformer.SegformerDWConv
  class GLPNDWConv(nn.Module):
      """Depthwise 3x3 convolution applied to a token sequence by viewing it as a 2D feature map."""

      def __init__(self, dim=768):
          super().__init__()
          # groups=dim makes the convolution depthwise: one 3x3 filter per channel, stride 1, padding 1.
          self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True, groups=dim)

      def forward(self, hidden_states, height, width):
          """Map `(batch_size, height * width, channels)` tokens to tokens of the same shape."""
          batch_size, _, num_channels = hidden_states.shape
          feature_map = hidden_states.transpose(1, 2).view(batch_size, num_channels, height, width)
          feature_map = self.dwconv(feature_map)
          # Back to (batch, seq, channels).
          return feature_map.flatten(2).transpose(1, 2)
  200. # Copied from transformers.models.segformer.modeling_segformer.SegformerMixFFN with Segformer->GLPN
  class GLPNMixFFN(nn.Module):
      """Mix-FFN block: linear -> depthwise convolution -> activation -> dropout -> linear -> dropout.

      Args:
          config: Configuration object; `hidden_act` (an activation name looked up in `ACT2FN`, or a callable) and
              `hidden_dropout_prob` are read from it.
          in_features: Channel size of the incoming tokens.
          hidden_features: Channel size between the two linear layers. Defaults to `in_features` when omitted.
          out_features: Channel size of the output tokens. Defaults to `in_features` when omitted.
      """

      def __init__(self, config, in_features, hidden_features=None, out_features=None):
          super().__init__()
          # `hidden_features=None` is the declared default, so it needs a fallback just like `out_features`;
          # without one, `nn.Linear(in_features, None)` fails.
          hidden_features = hidden_features or in_features
          out_features = out_features or in_features
          self.dense1 = nn.Linear(in_features, hidden_features)
          self.dwconv = GLPNDWConv(hidden_features)
          if isinstance(config.hidden_act, str):
              self.intermediate_act_fn = ACT2FN[config.hidden_act]
          else:
              self.intermediate_act_fn = config.hidden_act
          self.dense2 = nn.Linear(hidden_features, out_features)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)

      def forward(self, hidden_states, height, width):
          """Apply the feed-forward block to tokens laid out on a `height` x `width` grid."""
          hidden_states = self.dense1(hidden_states)
          hidden_states = self.dwconv(hidden_states, height, width)
          hidden_states = self.intermediate_act_fn(hidden_states)
          hidden_states = self.dropout(hidden_states)
          hidden_states = self.dense2(hidden_states)
          hidden_states = self.dropout(hidden_states)
          return hidden_states
  221. # Copied from transformers.models.segformer.modeling_segformer.SegformerLayer with Segformer->GLPN
  class GLPNLayer(nn.Module):
      """This corresponds to the Block class in the original implementation.

      A pre-norm Transformer block: layer norm -> attention -> residual, then layer norm -> Mix-FFN -> residual.
      Both residual branches go through the same drop-path module.
      """

      def __init__(self, config, hidden_size, num_attention_heads, drop_path, sequence_reduction_ratio, mlp_ratio):
          super().__init__()
          self.layer_norm_1 = nn.LayerNorm(hidden_size)
          self.attention = GLPNAttention(
              config,
              hidden_size=hidden_size,
              num_attention_heads=num_attention_heads,
              sequence_reduction_ratio=sequence_reduction_ratio,
          )
          # A drop rate of zero (or less) turns stochastic depth into a no-op.
          if drop_path > 0.0:
              self.drop_path = GLPNDropPath(drop_path)
          else:
              self.drop_path = nn.Identity()
          self.layer_norm_2 = nn.LayerNorm(hidden_size)
          self.mlp = GLPNMixFFN(config, in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio))

      def forward(self, hidden_states, height, width, output_attentions=False):
          """Return `(layer_output,)`, followed by the attention probabilities when they were requested."""
          # In GLPN, the layer norm is applied before self-attention.
          normed_input = self.layer_norm_1(hidden_states)
          attention_results = self.attention(normed_input, height, width, output_attentions=output_attentions)
          # Everything after the first element is the optional attention map.
          extras = attention_results[1:]

          # first residual connection (with stochastic depth)
          hidden_states = self.drop_path(attention_results[0]) + hidden_states

          # second residual connection (with stochastic depth)
          mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width)
          layer_output = self.drop_path(mlp_output) + hidden_states

          return (layer_output,) + extras
  class GLPNEncoder(nn.Module):
      """Hierarchical encoder made of `config.num_encoder_blocks` stages.

      Each stage consists of an overlapping patch embedding, `config.depths[i]` Transformer layers and a final
      layer norm; its output is reshaped to `(batch_size, channels, height, width)` and fed to the next stage.
      """

      def __init__(self, config):
          super().__init__()
          self.config = config
          num_stages = config.num_encoder_blocks

          # stochastic depth decay rule: one rate per layer, linearly spaced from 0 to `config.drop_path_rate`
          drop_rates = [
              rate.item() for rate in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")
          ]

          # patch embeddings: the first stage reads the image channels, later stages read the previous stage's output
          self.patch_embeddings = nn.ModuleList(
              [
                  GLPNOverlapPatchEmbeddings(
                      patch_size=config.patch_sizes[stage],
                      stride=config.strides[stage],
                      num_channels=config.hidden_sizes[stage - 1] if stage > 0 else config.num_channels,
                      hidden_size=config.hidden_sizes[stage],
                  )
                  for stage in range(num_stages)
              ]
          )

          # Transformer blocks
          stages = []
          first_layer = 0  # index into `drop_rates` of the current stage's first layer
          for stage in range(num_stages):
              stage_layers = [
                  GLPNLayer(
                      config,
                      hidden_size=config.hidden_sizes[stage],
                      num_attention_heads=config.num_attention_heads[stage],
                      drop_path=drop_rates[first_layer + offset],
                      sequence_reduction_ratio=config.sr_ratios[stage],
                      mlp_ratio=config.mlp_ratios[stage],
                  )
                  for offset in range(config.depths[stage])
              ]
              stages.append(nn.ModuleList(stage_layers))
              first_layer += config.depths[stage]
          self.block = nn.ModuleList(stages)

          # Layer norms
          self.layer_norm = nn.ModuleList([nn.LayerNorm(config.hidden_sizes[stage]) for stage in range(num_stages)])

      def forward(
          self,
          pixel_values,
          output_attentions=False,
          output_hidden_states=False,
          return_dict=True,
      ):
          """Run all stages on `pixel_values`.

          Returns:
              A `BaseModelOutput` when `return_dict` is truthy, otherwise a tuple holding the last hidden state
              followed by whichever of (all hidden states, all attentions) were requested. Hidden states are
              collected once per stage, attentions once per layer.
          """
          collected_hidden_states = () if output_hidden_states else None
          collected_attentions = () if output_attentions else None

          batch_size = pixel_values.shape[0]
          hidden_states = pixel_values

          for embed, layers, norm in zip(self.patch_embeddings, self.block, self.layer_norm):
              # first, obtain patch embeddings
              hidden_states, height, width = embed(hidden_states)

              # second, send embeddings through blocks
              for layer in layers:
                  layer_outputs = layer(hidden_states, height, width, output_attentions)
                  hidden_states = layer_outputs[0]
                  if output_attentions:
                      collected_attentions += (layer_outputs[1],)

              # third, apply layer norm
              hidden_states = norm(hidden_states)

              # fourth, reshape back to (batch_size, num_channels, height, width)
              hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()

              if output_hidden_states:
                  collected_hidden_states += (hidden_states,)

          if not return_dict:
              candidates = [hidden_states, collected_hidden_states, collected_attentions]
              return tuple(value for value in candidates if value is not None)

          return BaseModelOutput(
              last_hidden_state=hidden_states,
              hidden_states=collected_hidden_states,
              attentions=collected_attentions,
          )
  @auto_docstring
  class GLPNPreTrainedModel(PreTrainedModel):
      config: GLPNConfig
      base_model_prefix = "glpn"
      main_input_name = "pixel_values"
      _no_split_modules = []

      # Copied from transformers.models.segformer.modeling_segformer.SegformerPreTrainedModel._init_weights
      def _init_weights(self, module):
          """Initialize the weights"""
          if isinstance(module, (nn.Linear, nn.Conv2d)):
              # Slightly different from the TF version which uses truncated_normal for initialization
              # cf https://github.com/pytorch/pytorch/pull/5617
              module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
              if module.bias is not None:
                  module.bias.data.zero_()
              return

          if isinstance(module, nn.Embedding):
              module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
              # Keep the padding embedding at exactly zero.
              if module.padding_idx is not None:
                  module.weight.data[module.padding_idx].zero_()
              return

          if isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
              # Normalization layers start with zero bias and unit weight.
              module.bias.data.zero_()
              module.weight.data.fill_(1.0)
  @auto_docstring
  class GLPNModel(GLPNPreTrainedModel):
      # Copied from transformers.models.segformer.modeling_segformer.SegformerModel.__init__ with Segformer->GLPN
      def __init__(self, config):
          super().__init__(config)
          self.config = config

          # hierarchical Transformer encoder
          self.encoder = GLPNEncoder(config)

          # Initialize weights and apply final processing
          self.post_init()

      def _prune_heads(self, heads_to_prune):
          """
          Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
          class PreTrainedModel

          `layer_num` counts Transformer layers across all encoder stages, in order: the layers of the first stage
          come first, then those of the second stage, and so on.
          """
          # The encoder has no flat `layer` attribute: its layers live in `encoder.block`, a ModuleList holding one
          # ModuleList per stage. Flatten it so a single layer index can address any layer.
          layers = [layer for stage_layers in self.encoder.block for layer in stage_layers]
          for layer_index, heads in heads_to_prune.items():
              layers[layer_index].attention.prune_heads(heads)

      @auto_docstring
      # Copied from transformers.models.segformer.modeling_segformer.SegformerModel.forward
      def forward(
          self,
          pixel_values: torch.FloatTensor,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, BaseModelOutput]:
          # Arguments left as None fall back to the values stored in the config.
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          encoder_outputs = self.encoder(
              pixel_values,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          sequence_output = encoder_outputs[0]

          if not return_dict:
              return (sequence_output,) + encoder_outputs[1:]

          return BaseModelOutput(
              last_hidden_state=sequence_output,
              hidden_states=encoder_outputs.hidden_states,
              attentions=encoder_outputs.attentions,
          )
  class GLPNSelectiveFeatureFusion(nn.Module):
      """
      Selective Feature Fusion module, as explained in the [paper](https://huggingface.co/papers/2201.07436) (section 3.4). This
      module adaptively selects and integrates local and global features by attaining an attention map for each feature.
      """

      def __init__(self, in_channel=64):
          super().__init__()
          doubled_channels = int(in_channel * 2)
          halved_channels = int(in_channel / 2)

          # Three 3x3 convolutions narrow the concatenated features down to two attention channels.
          self.convolutional_layer1 = nn.Sequential(
              nn.Conv2d(in_channels=doubled_channels, out_channels=in_channel, kernel_size=3, stride=1, padding=1),
              nn.BatchNorm2d(in_channel),
              nn.ReLU(),
          )
          self.convolutional_layer2 = nn.Sequential(
              nn.Conv2d(in_channels=in_channel, out_channels=halved_channels, kernel_size=3, stride=1, padding=1),
              nn.BatchNorm2d(halved_channels),
              nn.ReLU(),
          )
          self.convolutional_layer3 = nn.Conv2d(
              in_channels=halved_channels, out_channels=2, kernel_size=3, stride=1, padding=1
          )
          self.sigmoid = nn.Sigmoid()

      def forward(self, local_features, global_features):
          """Blend the two feature maps with per-pixel weights predicted from their concatenation."""
          # concatenate features along the channel dimension and run the convolutional stack
          stacked = torch.cat((local_features, global_features), dim=1)
          logits = self.convolutional_layer3(self.convolutional_layer2(self.convolutional_layer1(stacked)))

          # two-channel attention map: channel 0 weights the local features, channel 1 the global ones
          attn = self.sigmoid(logits)
          local_weight = attn[:, 0, :, :].unsqueeze(1)
          global_weight = attn[:, 1, :, :].unsqueeze(1)

          # construct hybrid features by adding element-wise
          return local_features * local_weight + global_features * global_weight
  class GLPNDecoderStage(nn.Module):
      """One decoder step: project channels, optionally fuse with the previous stage's output, upsample by 2.

      Args:
          in_channels: Channel count of the incoming encoder feature map.
          out_channels: Channel count produced by this stage.
      """

      def __init__(self, in_channels, out_channels):
          super().__init__()
          # A 1x1 convolution is only needed when the channel counts differ.
          should_skip = in_channels == out_channels
          self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1) if not should_skip else nn.Identity()
          self.fusion = GLPNSelectiveFeatureFusion(out_channels)
          self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

      def forward(self, hidden_state, residual=None):
          """Project `hidden_state`, fuse it with `residual` when one is given, and upsample the result.

          `self.fusion` is only called when `residual` is not None.
          """
          hidden_state = self.convolution(hidden_state)
          if residual is not None:
              hidden_state = self.fusion(hidden_state, residual)
          # Single exit point: the result is upsampled exactly once whether or not fusion happened.
          return self.upsample(hidden_state)
  class GLPNDecoder(nn.Module):
      """Decoder that walks the encoder feature maps from last to first, one `GLPNDecoderStage` per map."""

      def __init__(self, config):
          super().__init__()
          # we use features from end -> start
          reversed_hidden_sizes = config.hidden_sizes[::-1]
          decoder_channels = config.decoder_hidden_size

          stages = [GLPNDecoderStage(hidden_size, decoder_channels) for hidden_size in reversed_hidden_sizes]
          self.stages = nn.ModuleList(stages)
          # don't fuse in first stage
          self.stages[0].fusion = None

          self.final_upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

      def forward(self, hidden_states: list[torch.Tensor]) -> list[torch.Tensor]:
          """Return the output of every stage; the last entry is additionally passed through `final_upsample`."""
          outputs = []
          previous = None  # the first stage has nothing to fuse with
          for feature_map, stage in zip(hidden_states[::-1], self.stages):
              previous = stage(feature_map, previous)
              outputs.append(previous)

          outputs[-1] = self.final_upsample(previous)
          return outputs
  class SiLogLoss(nn.Module):
      r"""
      Implements the Scale-invariant log scale loss [Eigen et al., 2014](https://huggingface.co/papers/1406.2283).

      As computed by `forward`:

      $$L=\sqrt{\frac{1}{n} \sum_{i} d_{i}^{2}-\lambda\left(\frac{1}{n} \sum_{i} d_{i}\right)^{2}}$$

      where $d_{i}=\log y_{i}^{*}-\log y_{i}$ is the difference between the log of the target $y_{i}^{*}$ and the log
      of the prediction $y_{i}$, $n$ is the number of positions where the target is strictly positive (all other
      positions are ignored), and $\lambda$ is `lambd`.
      """

      def __init__(self, lambd=0.5):
          super().__init__()
          self.lambd = lambd

      def forward(self, pred, target):
          """Return the scalar loss between `pred` and `target`, using only positions where `target > 0`."""
          valid_mask = (target > 0).detach()
          log_diff = torch.log(target[valid_mask]) - torch.log(pred[valid_mask])
          mean_of_squares = torch.pow(log_diff, 2).mean()
          square_of_mean = torch.pow(log_diff.mean(), 2)
          return torch.sqrt(mean_of_squares - self.lambd * square_of_mean)
  class GLPNDepthEstimationHead(nn.Module):
      """Two 3x3 convolutions that turn one decoder feature map into a depth map bounded by `config.max_depth`."""

      def __init__(self, config):
          super().__init__()
          self.config = config

          width = config.decoder_hidden_size
          layers = [
              nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1),
              nn.ReLU(inplace=False),
              nn.Conv2d(width, 1, kernel_size=3, stride=1, padding=1),
          ]
          self.head = nn.Sequential(*layers)

      def forward(self, hidden_states: list[torch.Tensor]) -> torch.Tensor:
          """Predict depth from the feature map selected by `config.head_in_index`.

          Returns:
              A tensor with the channel dimension squeezed away; the sigmoid keeps values in `(0, max_depth)`.
          """
          # pick the decoder feature map the head should read
          features = hidden_states[self.config.head_in_index]
          logits = self.head(features)
          scaled = torch.sigmoid(logits) * self.config.max_depth
          return scaled.squeeze(dim=1)
  @auto_docstring(
      custom_intro="""
      GLPN Model transformer with a lightweight depth estimation head on top e.g. for KITTI, NYUv2.
      """
  )
  class GLPNForDepthEstimation(GLPNPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)

          # encoder -> decoder -> depth head
          self.glpn = GLPNModel(config)
          self.decoder = GLPNDecoder(config)
          self.head = GLPNDepthEstimationHead(config)

          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      def forward(
          self,
          pixel_values: torch.FloatTensor,
          labels: Optional[torch.FloatTensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple[torch.Tensor], DepthEstimatorOutput]:
          r"""
          labels (`torch.FloatTensor` of shape `(batch_size, height, width)`, *optional*):
              Ground truth depth estimation maps for computing the loss.

          Examples:

          ```python
          >>> from transformers import AutoImageProcessor, GLPNForDepthEstimation
          >>> import torch
          >>> import numpy as np
          >>> from PIL import Image
          >>> import requests

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> image = Image.open(requests.get(url, stream=True).raw)

          >>> image_processor = AutoImageProcessor.from_pretrained("vinvino02/glpn-kitti")
          >>> model = GLPNForDepthEstimation.from_pretrained("vinvino02/glpn-kitti")

          >>> # prepare image for the model
          >>> inputs = image_processor(images=image, return_tensors="pt")

          >>> with torch.no_grad():
          ...     outputs = model(**inputs)

          >>> # interpolate to original size
          >>> post_processed_output = image_processor.post_process_depth_estimation(
          ...     outputs,
          ...     target_sizes=[(image.height, image.width)],
          ... )

          >>> # visualize the prediction
          >>> predicted_depth = post_processed_output[0]["predicted_depth"]
          >>> depth = predicted_depth * 255 / predicted_depth.max()
          >>> depth = depth.detach().cpu().numpy()
          >>> depth = Image.fromarray(depth.astype("uint8"))
          ```"""
          if return_dict is None:
              return_dict = self.config.use_return_dict
          if output_hidden_states is None:
              output_hidden_states = self.config.output_hidden_states

          encoder_outputs = self.glpn(
              pixel_values,
              output_attentions=output_attentions,
              output_hidden_states=True,  # we need the intermediate hidden states
              return_dict=return_dict,
          )

          # Hidden states were forced on above, so in tuple form they always sit at index 1.
          encoder_hidden_states = encoder_outputs.hidden_states if return_dict else encoder_outputs[1]

          decoder_features = self.decoder(encoder_hidden_states)
          predicted_depth = self.head(decoder_features)

          loss = None
          if labels is not None:
              loss = SiLogLoss()(predicted_depth, labels)

          if not return_dict:
              # Only expose the hidden states if the caller actually asked for them.
              extras = encoder_outputs[1:] if output_hidden_states else encoder_outputs[2:]
              output = (predicted_depth,) + extras
              return ((loss,) + output) if loss is not None else output

          return DepthEstimatorOutput(
              loss=loss,
              predicted_depth=predicted_depth,
              hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
              attentions=encoder_outputs.attentions,
          )
  580. __all__ = ["GLPNForDepthEstimation", "GLPNLayer", "GLPNModel", "GLPNPreTrainedModel"]