modeling_zoedepth.py 53 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352
  1. # coding=utf-8
  2. # Copyright 2024 Intel Labs and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch ZoeDepth model."""
  16. import math
  17. from dataclasses import dataclass
  18. from typing import Optional, Union
  19. import torch
  20. from torch import nn
  21. from ...activations import ACT2FN
  22. from ...modeling_outputs import DepthEstimatorOutput
  23. from ...modeling_utils import PreTrainedModel
  24. from ...utils import ModelOutput, auto_docstring, logging
  25. from ...utils.backbone_utils import load_backbone
  26. from .configuration_zoedepth import ZoeDepthConfig
  27. logger = logging.get_logger(__name__)
# Output container for ZoeDepth depth estimation. Every field is optional and defaults to `None`.
@dataclass
@auto_docstring(
    custom_intro="""
Extension of `DepthEstimatorOutput` to include domain logits (ZoeDepth specific).
"""
)
class ZoeDepthDepthEstimatorOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification (or regression if config.num_labels==1) loss.
    domain_logits (`torch.FloatTensor` of shape `(batch_size, num_domains)`):
        Logits for each domain (e.g. NYU and KITTI) in case multiple metric heads are used.
    """

    # Loss value; see the class docstring for when it is populated.
    loss: Optional[torch.FloatTensor] = None
    # Depth prediction tensor. NOTE(review): its shape is not established in this block -- confirm against
    # the model's forward pass.
    predicted_depth: Optional[torch.FloatTensor] = None
    # Per-domain logits; see the class docstring.
    domain_logits: Optional[torch.FloatTensor] = None
    # Optional tuples of tensors. NOTE(review): presumably forwarded from the backbone -- confirm against
    # the model's forward pass.
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  46. class ZoeDepthReassembleStage(nn.Module):
  47. """
  48. This class reassembles the hidden states of the backbone into image-like feature representations at various
  49. resolutions.
  50. This happens in 3 stages:
  51. 1. Map the N + 1 tokens to a set of N tokens, by taking into account the readout ([CLS]) token according to
  52. `config.readout_type`.
  53. 2. Project the channel dimension of the hidden states according to `config.neck_hidden_sizes`.
  54. 3. Resizing the spatial dimensions (height, width).
  55. Args:
  56. config (`[ZoeDepthConfig]`):
  57. Model configuration class defining the model architecture.
  58. """
  59. def __init__(self, config):
  60. super().__init__()
  61. self.readout_type = config.readout_type
  62. self.layers = nn.ModuleList()
  63. for neck_hidden_size, factor in zip(config.neck_hidden_sizes, config.reassemble_factors):
  64. self.layers.append(ZoeDepthReassembleLayer(config, channels=neck_hidden_size, factor=factor))
  65. if config.readout_type == "project":
  66. self.readout_projects = nn.ModuleList()
  67. hidden_size = config.backbone_hidden_size
  68. for _ in config.neck_hidden_sizes:
  69. self.readout_projects.append(
  70. nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), ACT2FN[config.hidden_act])
  71. )
  72. def forward(self, hidden_states: list[torch.Tensor], patch_height, patch_width) -> list[torch.Tensor]:
  73. """
  74. Args:
  75. hidden_states (`list[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`):
  76. List of hidden states from the backbone.
  77. """
  78. batch_size = hidden_states[0].shape[0]
  79. # stack along batch dimension
  80. # shape (batch_size*num_stages, sequence_length + 1, hidden_size)
  81. hidden_states = torch.cat(hidden_states, dim=0)
  82. cls_token, hidden_states = hidden_states[:, 0], hidden_states[:, 1:]
  83. # reshape hidden_states to (batch_size*num_stages, num_channels, height, width)
  84. total_batch_size, sequence_length, num_channels = hidden_states.shape
  85. hidden_states = hidden_states.reshape(total_batch_size, patch_height, patch_width, num_channels)
  86. hidden_states = hidden_states.permute(0, 3, 1, 2).contiguous()
  87. if self.readout_type == "project":
  88. # reshape to (batch_size*num_stages, height*width, num_channels)
  89. hidden_states = hidden_states.flatten(2).permute((0, 2, 1))
  90. readout = cls_token.unsqueeze(dim=1).expand_as(hidden_states)
  91. # concatenate the readout token to the hidden states
  92. # to get (batch_size*num_stages, height*width, 2*num_channels)
  93. hidden_states = torch.cat((hidden_states, readout), -1)
  94. elif self.readout_type == "add":
  95. hidden_states = hidden_states + cls_token.unsqueeze(-1)
  96. out = []
  97. for stage_idx, hidden_state in enumerate(hidden_states.split(batch_size, dim=0)):
  98. if self.readout_type == "project":
  99. hidden_state = self.readout_projects[stage_idx](hidden_state)
  100. # reshape back to (batch_size, num_channels, height, width)
  101. hidden_state = hidden_state.permute(0, 2, 1).reshape(batch_size, -1, patch_height, patch_width)
  102. hidden_state = self.layers[stage_idx](hidden_state)
  103. out.append(hidden_state)
  104. return out
  105. class ZoeDepthReassembleLayer(nn.Module):
  106. def __init__(self, config, channels, factor):
  107. super().__init__()
  108. # projection
  109. hidden_size = config.backbone_hidden_size
  110. self.projection = nn.Conv2d(in_channels=hidden_size, out_channels=channels, kernel_size=1)
  111. # up/down sampling depending on factor
  112. if factor > 1:
  113. self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0)
  114. elif factor == 1:
  115. self.resize = nn.Identity()
  116. elif factor < 1:
  117. # so should downsample
  118. self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=int(1 / factor), padding=1)
  119. # Copied from transformers.models.dpt.modeling_dpt.DPTReassembleLayer.forward with DPT->ZoeDepth
  120. def forward(self, hidden_state):
  121. hidden_state = self.projection(hidden_state)
  122. hidden_state = self.resize(hidden_state)
  123. return hidden_state
  124. # Copied from transformers.models.dpt.modeling_dpt.DPTFeatureFusionStage with DPT->ZoeDepth
  125. class ZoeDepthFeatureFusionStage(nn.Module):
  126. def __init__(self, config: ZoeDepthConfig):
  127. super().__init__()
  128. self.layers = nn.ModuleList()
  129. for _ in range(len(config.neck_hidden_sizes)):
  130. self.layers.append(ZoeDepthFeatureFusionLayer(config))
  131. def forward(self, hidden_states):
  132. # reversing the hidden_states, we start from the last
  133. hidden_states = hidden_states[::-1]
  134. fused_hidden_states = []
  135. fused_hidden_state = None
  136. for hidden_state, layer in zip(hidden_states, self.layers):
  137. if fused_hidden_state is None:
  138. # first layer only uses the last hidden_state
  139. fused_hidden_state = layer(hidden_state)
  140. else:
  141. fused_hidden_state = layer(fused_hidden_state, hidden_state)
  142. fused_hidden_states.append(fused_hidden_state)
  143. return fused_hidden_states
  144. # Copied from transformers.models.dpt.modeling_dpt.DPTPreActResidualLayer with DPT->ZoeDepth
  145. class ZoeDepthPreActResidualLayer(nn.Module):
  146. """
  147. ResidualConvUnit, pre-activate residual unit.
  148. Args:
  149. config (`[ZoeDepthConfig]`):
  150. Model configuration class defining the model architecture.
  151. """
  152. # Ignore copy
  153. def __init__(self, config):
  154. super().__init__()
  155. self.use_batch_norm = config.use_batch_norm_in_fusion_residual
  156. use_bias_in_fusion_residual = (
  157. config.use_bias_in_fusion_residual
  158. if config.use_bias_in_fusion_residual is not None
  159. else not self.use_batch_norm
  160. )
  161. self.activation1 = nn.ReLU()
  162. self.convolution1 = nn.Conv2d(
  163. config.fusion_hidden_size,
  164. config.fusion_hidden_size,
  165. kernel_size=3,
  166. stride=1,
  167. padding=1,
  168. bias=use_bias_in_fusion_residual,
  169. )
  170. self.activation2 = nn.ReLU()
  171. self.convolution2 = nn.Conv2d(
  172. config.fusion_hidden_size,
  173. config.fusion_hidden_size,
  174. kernel_size=3,
  175. stride=1,
  176. padding=1,
  177. bias=use_bias_in_fusion_residual,
  178. )
  179. if self.use_batch_norm:
  180. self.batch_norm1 = nn.BatchNorm2d(config.fusion_hidden_size, eps=config.batch_norm_eps)
  181. self.batch_norm2 = nn.BatchNorm2d(config.fusion_hidden_size, eps=config.batch_norm_eps)
  182. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  183. residual = hidden_state
  184. hidden_state = self.activation1(hidden_state)
  185. hidden_state = self.convolution1(hidden_state)
  186. if self.use_batch_norm:
  187. hidden_state = self.batch_norm1(hidden_state)
  188. hidden_state = self.activation2(hidden_state)
  189. hidden_state = self.convolution2(hidden_state)
  190. if self.use_batch_norm:
  191. hidden_state = self.batch_norm2(hidden_state)
  192. return hidden_state + residual
  193. # Copied from transformers.models.dpt.modeling_dpt.DPTFeatureFusionLayer with DPT->ZoeDepth
  194. class ZoeDepthFeatureFusionLayer(nn.Module):
  195. """Feature fusion layer, merges feature maps from different stages.
  196. Args:
  197. config (`[ZoeDepthConfig]`):
  198. Model configuration class defining the model architecture.
  199. align_corners (`bool`, *optional*, defaults to `True`):
  200. The align_corner setting for bilinear upsample.
  201. """
  202. def __init__(self, config: ZoeDepthConfig, align_corners: bool = True):
  203. super().__init__()
  204. self.align_corners = align_corners
  205. self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)
  206. self.residual_layer1 = ZoeDepthPreActResidualLayer(config)
  207. self.residual_layer2 = ZoeDepthPreActResidualLayer(config)
  208. def forward(self, hidden_state: torch.Tensor, residual: Optional[torch.Tensor] = None) -> torch.Tensor:
  209. if residual is not None:
  210. if hidden_state.shape != residual.shape:
  211. residual = nn.functional.interpolate(
  212. residual, size=(hidden_state.shape[2], hidden_state.shape[3]), mode="bilinear", align_corners=False
  213. )
  214. hidden_state = hidden_state + self.residual_layer1(residual)
  215. hidden_state = self.residual_layer2(hidden_state)
  216. hidden_state = nn.functional.interpolate(
  217. hidden_state, scale_factor=2, mode="bilinear", align_corners=self.align_corners
  218. )
  219. hidden_state = self.projection(hidden_state)
  220. return hidden_state
  221. class ZoeDepthNeck(nn.Module):
  222. """
  223. ZoeDepthNeck. A neck is a module that is normally used between the backbone and the head. It takes a list of tensors as
  224. input and produces another list of tensors as output. For ZoeDepth, it includes 2 stages:
  225. * ZoeDepthReassembleStage
  226. * ZoeDepthFeatureFusionStage.
  227. Args:
  228. config (dict): config dict.
  229. """
  230. # Copied from transformers.models.dpt.modeling_dpt.DPTNeck.__init__ with DPT->ZoeDepth
  231. def __init__(self, config: ZoeDepthConfig):
  232. super().__init__()
  233. self.config = config
  234. # postprocessing: only required in case of a non-hierarchical backbone (e.g. ViT, BEiT)
  235. if config.backbone_config is not None and config.backbone_config.model_type == "swinv2":
  236. self.reassemble_stage = None
  237. else:
  238. self.reassemble_stage = ZoeDepthReassembleStage(config)
  239. self.convs = nn.ModuleList()
  240. for channel in config.neck_hidden_sizes:
  241. self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False))
  242. # fusion
  243. self.fusion_stage = ZoeDepthFeatureFusionStage(config)
  244. def forward(self, hidden_states: list[torch.Tensor], patch_height, patch_width) -> list[torch.Tensor]:
  245. """
  246. Args:
  247. hidden_states (`list[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`):
  248. List of hidden states from the backbone.
  249. """
  250. if not isinstance(hidden_states, (tuple, list)):
  251. raise TypeError("hidden_states should be a tuple or list of tensors")
  252. if len(hidden_states) != len(self.config.neck_hidden_sizes):
  253. raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.")
  254. # postprocess hidden states
  255. if self.reassemble_stage is not None:
  256. hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width)
  257. features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)]
  258. # fusion blocks
  259. output = self.fusion_stage(features)
  260. return output, features[-1]
  261. class ZoeDepthRelativeDepthEstimationHead(nn.Module):
  262. """
  263. Relative depth estimation head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples
  264. the predictions to the input resolution after the first convolutional layer (details can be found in DPT's paper's
  265. supplementary material).
  266. """
  267. def __init__(self, config):
  268. super().__init__()
  269. self.head_in_index = config.head_in_index
  270. self.projection = None
  271. if config.add_projection:
  272. self.projection = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  273. features = config.fusion_hidden_size
  274. self.conv1 = nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1)
  275. self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
  276. self.conv2 = nn.Conv2d(features // 2, config.num_relative_features, kernel_size=3, stride=1, padding=1)
  277. self.conv3 = nn.Conv2d(config.num_relative_features, 1, kernel_size=1, stride=1, padding=0)
  278. def forward(self, hidden_states: list[torch.Tensor]) -> torch.Tensor:
  279. # use last features
  280. hidden_states = hidden_states[self.head_in_index]
  281. if self.projection is not None:
  282. hidden_states = self.projection(hidden_states)
  283. hidden_states = nn.ReLU()(hidden_states)
  284. hidden_states = self.conv1(hidden_states)
  285. hidden_states = self.upsample(hidden_states)
  286. hidden_states = self.conv2(hidden_states)
  287. hidden_states = nn.ReLU()(hidden_states)
  288. # we need the features here (after second conv + ReLu)
  289. features = hidden_states
  290. hidden_states = self.conv3(hidden_states)
  291. hidden_states = nn.ReLU()(hidden_states)
  292. predicted_depth = hidden_states.squeeze(dim=1)
  293. return predicted_depth, features
  294. def log_binom(n, k, eps=1e-7):
  295. """log(nCk) using stirling approximation"""
  296. n = n + eps
  297. k = k + eps
  298. return n * torch.log(n) - k * torch.log(k) - (n - k) * torch.log(n - k + eps)
  299. class LogBinomialSoftmax(nn.Module):
  300. def __init__(self, n_classes=256, act=torch.softmax):
  301. """Compute log binomial distribution for n_classes
  302. Args:
  303. n_classes (`int`, *optional*, defaults to 256):
  304. Number of output classes.
  305. act (`torch.nn.Module`, *optional*, defaults to `torch.softmax`):
  306. Activation function to apply to the output.
  307. """
  308. super().__init__()
  309. self.k = n_classes
  310. self.act = act
  311. self.register_buffer("k_idx", torch.arange(0, n_classes).view(1, -1, 1, 1), persistent=False)
  312. self.register_buffer("k_minus_1", torch.tensor([self.k - 1]).view(1, -1, 1, 1), persistent=False)
  313. def forward(self, probabilities, temperature=1.0, eps=1e-4):
  314. """Compute the log binomial distribution for probabilities.
  315. Args:
  316. probabilities (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
  317. Tensor containing probabilities of each class.
  318. temperature (`float` or `torch.Tensor` of shape `(batch_size, num_channels, height, width)`, *optional*, defaults to 1):
  319. Temperature of distribution.
  320. eps (`float`, *optional*, defaults to 1e-4):
  321. Small number for numerical stability.
  322. Returns:
  323. `torch.Tensor` of shape `(batch_size, num_channels, height, width)`:
  324. Log binomial distribution logbinomial(p;t).
  325. """
  326. if probabilities.ndim == 3:
  327. probabilities = probabilities.unsqueeze(1) # make it (batch_size, num_channels, height, width)
  328. one_minus_probabilities = torch.clamp(1 - probabilities, eps, 1)
  329. probabilities = torch.clamp(probabilities, eps, 1)
  330. y = (
  331. log_binom(self.k_minus_1, self.k_idx)
  332. + self.k_idx * torch.log(probabilities)
  333. + (self.k_minus_1 - self.k_idx) * torch.log(one_minus_probabilities)
  334. )
  335. return self.act(y / temperature, dim=1)
  336. class ZoeDepthConditionalLogBinomialSoftmax(nn.Module):
  337. def __init__(
  338. self,
  339. config,
  340. in_features,
  341. condition_dim,
  342. n_classes=256,
  343. bottleneck_factor=2,
  344. ):
  345. """Per-pixel MLP followed by a Conditional Log Binomial softmax.
  346. Args:
  347. in_features (`int`):
  348. Number of input channels in the main feature.
  349. condition_dim (`int`):
  350. Number of input channels in the condition feature.
  351. n_classes (`int`, *optional*, defaults to 256):
  352. Number of classes.
  353. bottleneck_factor (`int`, *optional*, defaults to 2):
  354. Hidden dim factor.
  355. """
  356. super().__init__()
  357. bottleneck = (in_features + condition_dim) // bottleneck_factor
  358. self.mlp = nn.Sequential(
  359. nn.Conv2d(in_features + condition_dim, bottleneck, kernel_size=1, stride=1, padding=0),
  360. nn.GELU(),
  361. # 2 for probabilities linear norm, 2 for temperature linear norm
  362. nn.Conv2d(bottleneck, 2 + 2, kernel_size=1, stride=1, padding=0),
  363. nn.Softplus(),
  364. )
  365. self.p_eps = 1e-4
  366. self.max_temp = config.max_temp
  367. self.min_temp = config.min_temp
  368. self.log_binomial_transform = LogBinomialSoftmax(n_classes, act=torch.softmax)
  369. def forward(self, main_feature, condition_feature):
  370. """
  371. Args:
  372. main_feature (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
  373. Main feature.
  374. condition_feature (torch.Tensor of shape `(batch_size, num_channels, height, width)`):
  375. Condition feature.
  376. Returns:
  377. `torch.Tensor`:
  378. Output log binomial distribution
  379. """
  380. probabilities_and_temperature = self.mlp(torch.concat((main_feature, condition_feature), dim=1))
  381. probabilities, temperature = (
  382. probabilities_and_temperature[:, :2, ...],
  383. probabilities_and_temperature[:, 2:, ...],
  384. )
  385. probabilities = probabilities + self.p_eps
  386. probabilities = probabilities[:, 0, ...] / (probabilities[:, 0, ...] + probabilities[:, 1, ...])
  387. temperature = temperature + self.p_eps
  388. temperature = temperature[:, 0, ...] / (temperature[:, 0, ...] + temperature[:, 1, ...])
  389. temperature = temperature.unsqueeze(1)
  390. temperature = (self.max_temp - self.min_temp) * temperature + self.min_temp
  391. return self.log_binomial_transform(probabilities, temperature)
  392. class ZoeDepthSeedBinRegressor(nn.Module):
  393. def __init__(self, config, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10):
  394. """Bin center regressor network.
  395. Can be "normed" or "unnormed". If "normed", bin centers are bounded on the (min_depth, max_depth) interval.
  396. Args:
  397. config (`int`):
  398. Model configuration.
  399. n_bins (`int`, *optional*, defaults to 16):
  400. Number of bin centers.
  401. mlp_dim (`int`, *optional*, defaults to 256):
  402. Hidden dimension.
  403. min_depth (`float`, *optional*, defaults to 1e-3):
  404. Min depth value.
  405. max_depth (`float`, *optional*, defaults to 10):
  406. Max depth value.
  407. """
  408. super().__init__()
  409. self.in_features = config.bottleneck_features
  410. self.bin_centers_type = config.bin_centers_type
  411. self.min_depth = min_depth
  412. self.max_depth = max_depth
  413. self.conv1 = nn.Conv2d(self.in_features, mlp_dim, 1, 1, 0)
  414. self.act1 = nn.ReLU(inplace=True)
  415. self.conv2 = nn.Conv2d(mlp_dim, n_bins, 1, 1, 0)
  416. self.act2 = nn.ReLU(inplace=True) if self.bin_centers_type == "normed" else nn.Softplus()
  417. def forward(self, x):
  418. """
  419. Returns tensor of bin_width vectors (centers). One vector b for every pixel
  420. """
  421. x = self.conv1(x)
  422. x = self.act1(x)
  423. x = self.conv2(x)
  424. bin_centers = self.act2(x)
  425. if self.bin_centers_type == "normed":
  426. bin_centers = bin_centers + 1e-3
  427. bin_widths_normed = bin_centers / bin_centers.sum(dim=1, keepdim=True)
  428. # shape (batch_size, num_channels, height, width)
  429. bin_widths = (self.max_depth - self.min_depth) * bin_widths_normed
  430. # pad has the form (left, right, top, bottom, front, back)
  431. bin_widths = nn.functional.pad(bin_widths, (0, 0, 0, 0, 1, 0), mode="constant", value=self.min_depth)
  432. # shape (batch_size, num_channels, height, width)
  433. bin_edges = torch.cumsum(bin_widths, dim=1)
  434. bin_centers = 0.5 * (bin_edges[:, :-1, ...] + bin_edges[:, 1:, ...])
  435. return bin_widths_normed, bin_centers
  436. else:
  437. return bin_centers, bin_centers
  438. @torch.jit.script
  439. def inv_attractor(dx, alpha: float = 300, gamma: int = 2):
  440. """Inverse attractor: dc = dx / (1 + alpha*dx^gamma), where dx = a - c, a = attractor point, c = bin center, dc = shift in bin center
  441. This is the default one according to the accompanying paper.
  442. Args:
  443. dx (`torch.Tensor`):
  444. The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center.
  445. alpha (`float`, *optional*, defaults to 300):
  446. Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction.
  447. gamma (`int`, *optional*, defaults to 2):
  448. Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected.
  449. Lower gamma = farther reach.
  450. Returns:
  451. torch.Tensor: Delta shifts - dc; New bin centers = Old bin centers + dc
  452. """
  453. return dx.div(1 + alpha * dx.pow(gamma))
  454. class ZoeDepthAttractorLayer(nn.Module):
  455. def __init__(
  456. self,
  457. config,
  458. n_bins,
  459. n_attractors=16,
  460. min_depth=1e-3,
  461. max_depth=10,
  462. memory_efficient=False,
  463. ):
  464. """
  465. Attractor layer for bin centers. Bin centers are bounded on the interval (min_depth, max_depth)
  466. """
  467. super().__init__()
  468. self.alpha = config.attractor_alpha
  469. self.gemma = config.attractor_gamma
  470. self.kind = config.attractor_kind
  471. self.n_attractors = n_attractors
  472. self.n_bins = n_bins
  473. self.min_depth = min_depth
  474. self.max_depth = max_depth
  475. self.memory_efficient = memory_efficient
  476. # MLP to predict attractor points
  477. in_features = mlp_dim = config.bin_embedding_dim
  478. self.conv1 = nn.Conv2d(in_features, mlp_dim, 1, 1, 0)
  479. self.act1 = nn.ReLU(inplace=True)
  480. self.conv2 = nn.Conv2d(mlp_dim, n_attractors * 2, 1, 1, 0) # x2 for linear norm
  481. self.act2 = nn.ReLU(inplace=True)
  482. def forward(self, x, prev_bin, prev_bin_embedding=None, interpolate=True):
  483. """
  484. The forward pass of the attractor layer. This layer predicts the new bin centers based on the previous bin centers
  485. and the attractor points (the latter are predicted by the MLP).
  486. Args:
  487. x (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
  488. Feature block.
  489. prev_bin (`torch.Tensor` of shape `(batch_size, prev_number_of_bins, height, width)`):
  490. Previous bin centers normed.
  491. prev_bin_embedding (`torch.Tensor`, *optional*):
  492. Optional previous bin embeddings.
  493. interpolate (`bool`, *optional*, defaults to `True`):
  494. Whether to interpolate the previous bin embeddings to the size of the input features.
  495. Returns:
  496. `tuple[`torch.Tensor`, `torch.Tensor`]:
  497. New bin centers normed and scaled.
  498. """
  499. if prev_bin_embedding is not None:
  500. if interpolate:
  501. prev_bin_embedding = nn.functional.interpolate(
  502. prev_bin_embedding, x.shape[-2:], mode="bilinear", align_corners=True
  503. )
  504. x = x + prev_bin_embedding
  505. x = self.conv1(x)
  506. x = self.act1(x)
  507. x = self.conv2(x)
  508. attractors = self.act2(x)
  509. attractors = attractors + 1e-3
  510. batch_size, _, height, width = attractors.shape
  511. attractors = attractors.view(batch_size, self.n_attractors, 2, height, width)
  512. # batch_size, num_attractors, 2, height, width
  513. # note: original repo had a bug here: https://github.com/isl-org/ZoeDepth/blame/edb6daf45458569e24f50250ef1ed08c015f17a7/zoedepth/models/layers/attractor.py#L105C9-L106C50
  514. # we include the bug to maintain compatibility with the weights
  515. attractors_normed = attractors[:, :, 0, ...] # batch_size, batch_size*num_attractors, height, width
  516. bin_centers = nn.functional.interpolate(prev_bin, (height, width), mode="bilinear", align_corners=True)
  517. # note: only attractor_type = "exp" is supported here, since no checkpoints were released with other attractor types
  518. if not self.memory_efficient:
  519. func = {"mean": torch.mean, "sum": torch.sum}[self.kind]
  520. # shape (batch_size, num_bins, height, width)
  521. delta_c = func(inv_attractor(attractors_normed.unsqueeze(2) - bin_centers.unsqueeze(1)), dim=1)
  522. else:
  523. delta_c = torch.zeros_like(bin_centers, device=bin_centers.device)
  524. for i in range(self.n_attractors):
  525. # shape (batch_size, num_bins, height, width)
  526. delta_c += inv_attractor(attractors_normed[:, i, ...].unsqueeze(1) - bin_centers)
  527. if self.kind == "mean":
  528. delta_c = delta_c / self.n_attractors
  529. bin_new_centers = bin_centers + delta_c
  530. bin_centers = (self.max_depth - self.min_depth) * bin_new_centers + self.min_depth
  531. bin_centers, _ = torch.sort(bin_centers, dim=1)
  532. bin_centers = torch.clip(bin_centers, self.min_depth, self.max_depth)
  533. return bin_new_centers, bin_centers
  534. class ZoeDepthAttractorLayerUnnormed(nn.Module):
  535. def __init__(
  536. self,
  537. config,
  538. n_bins,
  539. n_attractors=16,
  540. min_depth=1e-3,
  541. max_depth=10,
  542. memory_efficient=True,
  543. ):
  544. """
  545. Attractor layer for bin centers. Bin centers are unbounded
  546. """
  547. super().__init__()
  548. self.n_attractors = n_attractors
  549. self.n_bins = n_bins
  550. self.min_depth = min_depth
  551. self.max_depth = max_depth
  552. self.alpha = config.attractor_alpha
  553. self.gamma = config.attractor_alpha
  554. self.kind = config.attractor_kind
  555. self.memory_efficient = memory_efficient
  556. in_features = mlp_dim = config.bin_embedding_dim
  557. self.conv1 = nn.Conv2d(in_features, mlp_dim, 1, 1, 0)
  558. self.act1 = nn.ReLU(inplace=True)
  559. self.conv2 = nn.Conv2d(mlp_dim, n_attractors, 1, 1, 0)
  560. self.act2 = nn.Softplus()
  561. def forward(self, x, prev_bin, prev_bin_embedding=None, interpolate=True):
  562. """
  563. The forward pass of the attractor layer. This layer predicts the new bin centers based on the previous bin centers
  564. and the attractor points (the latter are predicted by the MLP).
  565. Args:
  566. x (`torch.Tensor` of shape (batch_size, num_channels, height, width)`):
  567. Feature block.
  568. prev_bin (`torch.Tensor` of shape (batch_size, prev_num_bins, height, width)`):
  569. Previous bin centers normed.
  570. prev_bin_embedding (`torch.Tensor`, *optional*):
  571. Optional previous bin embeddings.
  572. interpolate (`bool`, *optional*, defaults to `True`):
  573. Whether to interpolate the previous bin embeddings to the size of the input features.
  574. Returns:
  575. `tuple[`torch.Tensor`, `torch.Tensor`]:
  576. New bin centers unbounded. Two outputs just to keep the API consistent with the normed version.
  577. """
  578. if prev_bin_embedding is not None:
  579. if interpolate:
  580. prev_bin_embedding = nn.functional.interpolate(
  581. prev_bin_embedding, x.shape[-2:], mode="bilinear", align_corners=True
  582. )
  583. x = x + prev_bin_embedding
  584. x = self.conv1(x)
  585. x = self.act1(x)
  586. x = self.conv2(x)
  587. attractors = self.act2(x)
  588. height, width = attractors.shape[-2:]
  589. bin_centers = nn.functional.interpolate(prev_bin, (height, width), mode="bilinear", align_corners=True)
  590. if not self.memory_efficient:
  591. func = {"mean": torch.mean, "sum": torch.sum}[self.kind]
  592. # shape batch_size, num_bins, height, width
  593. delta_c = func(inv_attractor(attractors.unsqueeze(2) - bin_centers.unsqueeze(1)), dim=1)
  594. else:
  595. delta_c = torch.zeros_like(bin_centers, device=bin_centers.device)
  596. for i in range(self.n_attractors):
  597. # shape batch_size, num_bins, height, width
  598. delta_c += inv_attractor(attractors[:, i, ...].unsqueeze(1) - bin_centers)
  599. if self.kind == "mean":
  600. delta_c = delta_c / self.n_attractors
  601. bin_new_centers = bin_centers + delta_c
  602. bin_centers = bin_new_centers
  603. return bin_new_centers, bin_centers
  604. class ZoeDepthProjector(nn.Module):
  605. def __init__(self, in_features, out_features, mlp_dim=128):
  606. """Projector MLP.
  607. Args:
  608. in_features (`int`):
  609. Number of input channels.
  610. out_features (`int`):
  611. Number of output channels.
  612. mlp_dim (`int`, *optional*, defaults to 128):
  613. Hidden dimension.
  614. """
  615. super().__init__()
  616. self.conv1 = nn.Conv2d(in_features, mlp_dim, 1, 1, 0)
  617. self.act = nn.ReLU(inplace=True)
  618. self.conv2 = nn.Conv2d(mlp_dim, out_features, 1, 1, 0)
  619. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  620. hidden_state = self.conv1(hidden_state)
  621. hidden_state = self.act(hidden_state)
  622. hidden_state = self.conv2(hidden_state)
  623. return hidden_state
  624. # Copied from transformers.models.grounding_dino.modeling_grounding_dino.GroundingDinoMultiheadAttention with GroundingDino->ZoeDepth
  625. class ZoeDepthMultiheadAttention(nn.Module):
  626. """Equivalent implementation of nn.MultiheadAttention with `batch_first=True`."""
  627. # Ignore copy
  628. def __init__(self, hidden_size, num_attention_heads, dropout):
  629. super().__init__()
  630. if hidden_size % num_attention_heads != 0:
  631. raise ValueError(
  632. f"The hidden size ({hidden_size}) is not a multiple of the number of attention "
  633. f"heads ({num_attention_heads})"
  634. )
  635. self.num_attention_heads = num_attention_heads
  636. self.attention_head_size = int(hidden_size / num_attention_heads)
  637. self.all_head_size = self.num_attention_heads * self.attention_head_size
  638. self.query = nn.Linear(hidden_size, self.all_head_size)
  639. self.key = nn.Linear(hidden_size, self.all_head_size)
  640. self.value = nn.Linear(hidden_size, self.all_head_size)
  641. self.out_proj = nn.Linear(hidden_size, hidden_size)
  642. self.dropout = nn.Dropout(dropout)
  643. def forward(
  644. self,
  645. queries: torch.Tensor,
  646. keys: torch.Tensor,
  647. values: torch.Tensor,
  648. attention_mask: Optional[torch.FloatTensor] = None,
  649. output_attentions: Optional[bool] = False,
  650. ) -> tuple[torch.Tensor]:
  651. batch_size, seq_length, _ = queries.shape
  652. query_layer = (
  653. self.query(queries)
  654. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  655. .transpose(1, 2)
  656. )
  657. key_layer = (
  658. self.key(keys).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
  659. )
  660. value_layer = (
  661. self.value(values).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
  662. )
  663. # Take the dot product between "query" and "key" to get the raw attention scores.
  664. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  665. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  666. if attention_mask is not None:
  667. # Apply the attention mask is (precomputed for all layers in ZoeDepthModel forward() function)
  668. attention_scores = attention_scores + attention_mask
  669. # Normalize the attention scores to probabilities.
  670. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  671. # This is actually dropping out entire tokens to attend to, which might
  672. # seem a bit unusual, but is taken from the original Transformer paper.
  673. attention_probs = self.dropout(attention_probs)
  674. context_layer = torch.matmul(attention_probs, value_layer)
  675. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  676. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  677. context_layer = context_layer.view(new_context_layer_shape)
  678. context_layer = self.out_proj(context_layer)
  679. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  680. return outputs
  681. class ZoeDepthTransformerEncoderLayer(nn.Module):
  682. def __init__(self, config, dropout=0.1, activation="relu"):
  683. super().__init__()
  684. hidden_size = config.patch_transformer_hidden_size
  685. intermediate_size = config.patch_transformer_intermediate_size
  686. num_attention_heads = config.patch_transformer_num_attention_heads
  687. self.self_attn = ZoeDepthMultiheadAttention(hidden_size, num_attention_heads, dropout=dropout)
  688. self.linear1 = nn.Linear(hidden_size, intermediate_size)
  689. self.dropout = nn.Dropout(dropout)
  690. self.linear2 = nn.Linear(intermediate_size, hidden_size)
  691. self.norm1 = nn.LayerNorm(hidden_size)
  692. self.norm2 = nn.LayerNorm(hidden_size)
  693. self.dropout1 = nn.Dropout(dropout)
  694. self.dropout2 = nn.Dropout(dropout)
  695. self.activation = ACT2FN[activation]
  696. def forward(
  697. self,
  698. src,
  699. src_mask: Optional[torch.Tensor] = None,
  700. ):
  701. queries = keys = src
  702. src2 = self.self_attn(queries=queries, keys=keys, values=src, attention_mask=src_mask)[0]
  703. src = src + self.dropout1(src2)
  704. src = self.norm1(src)
  705. src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
  706. src = src + self.dropout2(src2)
  707. src = self.norm2(src)
  708. return src
  709. class ZoeDepthPatchTransformerEncoder(nn.Module):
  710. def __init__(self, config):
  711. """ViT-like transformer block
  712. Args:
  713. config (`ZoeDepthConfig`):
  714. Model configuration class defining the model architecture.
  715. """
  716. super().__init__()
  717. in_channels = config.bottleneck_features
  718. self.transformer_encoder = nn.ModuleList(
  719. [ZoeDepthTransformerEncoderLayer(config) for _ in range(config.num_patch_transformer_layers)]
  720. )
  721. self.embedding_convPxP = nn.Conv2d(
  722. in_channels, config.patch_transformer_hidden_size, kernel_size=1, stride=1, padding=0
  723. )
  724. def positional_encoding_1d(self, batch_size, sequence_length, embedding_dim, device="cpu", dtype=torch.float32):
  725. """Generate positional encodings
  726. Args:
  727. sequence_length (int): Sequence length
  728. embedding_dim (int): Embedding dimension
  729. Returns:
  730. torch.Tensor: Positional encodings.
  731. """
  732. position = torch.arange(0, sequence_length, dtype=dtype, device=device).unsqueeze(1)
  733. index = torch.arange(0, embedding_dim, 2, dtype=dtype, device=device).unsqueeze(0)
  734. div_term = torch.exp(index * (-torch.log(torch.tensor(10000.0, device=device)) / embedding_dim))
  735. pos_encoding = position * div_term
  736. pos_encoding = torch.cat([torch.sin(pos_encoding), torch.cos(pos_encoding)], dim=1)
  737. pos_encoding = pos_encoding.unsqueeze(dim=0).repeat(batch_size, 1, 1)
  738. return pos_encoding
  739. def forward(self, x):
  740. """Forward pass
  741. Args:
  742. x (torch.Tensor - NCHW): Input feature tensor
  743. Returns:
  744. torch.Tensor - Transformer output embeddings of shape (batch_size, sequence_length, embedding_dim)
  745. """
  746. embeddings = self.embedding_convPxP(x).flatten(2) # shape (batch_size, num_channels, sequence_length)
  747. # add an extra special CLS token at the start for global accumulation
  748. embeddings = nn.functional.pad(embeddings, (1, 0))
  749. embeddings = embeddings.permute(0, 2, 1)
  750. batch_size, sequence_length, embedding_dim = embeddings.shape
  751. embeddings = embeddings + self.positional_encoding_1d(
  752. batch_size, sequence_length, embedding_dim, device=embeddings.device, dtype=embeddings.dtype
  753. )
  754. for i in range(4):
  755. embeddings = self.transformer_encoder[i](embeddings)
  756. return embeddings
  757. class ZoeDepthMLPClassifier(nn.Module):
  758. def __init__(self, in_features, out_features) -> None:
  759. super().__init__()
  760. hidden_features = in_features
  761. self.linear1 = nn.Linear(in_features, hidden_features)
  762. self.activation = nn.ReLU()
  763. self.linear2 = nn.Linear(hidden_features, out_features)
  764. def forward(self, hidden_state):
  765. hidden_state = self.linear1(hidden_state)
  766. hidden_state = self.activation(hidden_state)
  767. domain_logits = self.linear2(hidden_state)
  768. return domain_logits
  769. class ZoeDepthMultipleMetricDepthEstimationHeads(nn.Module):
  770. """
  771. Multiple metric depth estimation heads. A MLP classifier is used to route between 2 different heads.
  772. """
  773. def __init__(self, config):
  774. super().__init__()
  775. bin_embedding_dim = config.bin_embedding_dim
  776. n_attractors = config.num_attractors
  777. self.bin_configurations = config.bin_configurations
  778. self.bin_centers_type = config.bin_centers_type
  779. # Bottleneck convolution
  780. bottleneck_features = config.bottleneck_features
  781. self.conv2 = nn.Conv2d(bottleneck_features, bottleneck_features, kernel_size=1, stride=1, padding=0)
  782. # Transformer classifier on the bottleneck
  783. self.patch_transformer = ZoeDepthPatchTransformerEncoder(config)
  784. # MLP classifier
  785. self.mlp_classifier = ZoeDepthMLPClassifier(in_features=128, out_features=2)
  786. # Regressor and attractor
  787. if self.bin_centers_type == "normed":
  788. Attractor = ZoeDepthAttractorLayer
  789. elif self.bin_centers_type == "softplus":
  790. Attractor = ZoeDepthAttractorLayerUnnormed
  791. # We have bins for each bin configuration
  792. # Create a map (ModuleDict) of 'name' -> seed_bin_regressor
  793. self.seed_bin_regressors = nn.ModuleDict(
  794. {
  795. conf["name"]: ZoeDepthSeedBinRegressor(
  796. config,
  797. n_bins=conf["n_bins"],
  798. mlp_dim=bin_embedding_dim // 2,
  799. min_depth=conf["min_depth"],
  800. max_depth=conf["max_depth"],
  801. )
  802. for conf in config.bin_configurations
  803. }
  804. )
  805. self.seed_projector = ZoeDepthProjector(
  806. in_features=bottleneck_features, out_features=bin_embedding_dim, mlp_dim=bin_embedding_dim // 2
  807. )
  808. self.projectors = nn.ModuleList(
  809. [
  810. ZoeDepthProjector(
  811. in_features=config.fusion_hidden_size,
  812. out_features=bin_embedding_dim,
  813. mlp_dim=bin_embedding_dim // 2,
  814. )
  815. for _ in range(4)
  816. ]
  817. )
  818. # Create a map (ModuleDict) of 'name' -> attractors (ModuleList)
  819. self.attractors = nn.ModuleDict(
  820. {
  821. configuration["name"]: nn.ModuleList(
  822. [
  823. Attractor(
  824. config,
  825. n_bins=n_attractors[i],
  826. min_depth=configuration["min_depth"],
  827. max_depth=configuration["max_depth"],
  828. )
  829. for i in range(len(n_attractors))
  830. ]
  831. )
  832. for configuration in config.bin_configurations
  833. }
  834. )
  835. last_in = config.num_relative_features
  836. # conditional log binomial for each bin configuration
  837. self.conditional_log_binomial = nn.ModuleDict(
  838. {
  839. configuration["name"]: ZoeDepthConditionalLogBinomialSoftmax(
  840. config,
  841. last_in,
  842. bin_embedding_dim,
  843. configuration["n_bins"],
  844. bottleneck_factor=4,
  845. )
  846. for configuration in config.bin_configurations
  847. }
  848. )
  849. def forward(self, outconv_activation, bottleneck, feature_blocks, relative_depth):
  850. x = self.conv2(bottleneck)
  851. # Predict which path to take
  852. # Embedding is of shape (batch_size, hidden_size)
  853. embedding = self.patch_transformer(x)[:, 0, :]
  854. # MLP classifier to get logits of shape (batch_size, 2)
  855. domain_logits = self.mlp_classifier(embedding)
  856. domain_vote = torch.softmax(domain_logits.sum(dim=0, keepdim=True), dim=-1)
  857. # Get the path
  858. names = [configuration["name"] for configuration in self.bin_configurations]
  859. bin_configurations_name = names[torch.argmax(domain_vote, dim=-1).squeeze().item()]
  860. try:
  861. conf = [config for config in self.bin_configurations if config["name"] == bin_configurations_name][0]
  862. except IndexError:
  863. raise ValueError(f"bin_configurations_name {bin_configurations_name} not found in bin_configurationss")
  864. min_depth = conf["min_depth"]
  865. max_depth = conf["max_depth"]
  866. seed_bin_regressor = self.seed_bin_regressors[bin_configurations_name]
  867. _, seed_bin_centers = seed_bin_regressor(x)
  868. if self.bin_centers_type in ["normed", "hybrid2"]:
  869. prev_bin = (seed_bin_centers - min_depth) / (max_depth - min_depth)
  870. else:
  871. prev_bin = seed_bin_centers
  872. prev_bin_embedding = self.seed_projector(x)
  873. attractors = self.attractors[bin_configurations_name]
  874. for projector, attractor, feature in zip(self.projectors, attractors, feature_blocks):
  875. bin_embedding = projector(feature)
  876. bin, bin_centers = attractor(bin_embedding, prev_bin, prev_bin_embedding, interpolate=True)
  877. prev_bin = bin
  878. prev_bin_embedding = bin_embedding
  879. last = outconv_activation
  880. bin_centers = nn.functional.interpolate(bin_centers, last.shape[-2:], mode="bilinear", align_corners=True)
  881. bin_embedding = nn.functional.interpolate(bin_embedding, last.shape[-2:], mode="bilinear", align_corners=True)
  882. conditional_log_binomial = self.conditional_log_binomial[bin_configurations_name]
  883. x = conditional_log_binomial(last, bin_embedding)
  884. # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor
  885. out = torch.sum(x * bin_centers, dim=1, keepdim=True)
  886. return out, domain_logits
  887. class ZoeDepthMetricDepthEstimationHead(nn.Module):
  888. def __init__(self, config):
  889. super().__init__()
  890. bin_configuration = config.bin_configurations[0]
  891. n_bins = bin_configuration["n_bins"]
  892. min_depth = bin_configuration["min_depth"]
  893. max_depth = bin_configuration["max_depth"]
  894. bin_embedding_dim = config.bin_embedding_dim
  895. n_attractors = config.num_attractors
  896. bin_centers_type = config.bin_centers_type
  897. self.min_depth = min_depth
  898. self.max_depth = max_depth
  899. self.bin_centers_type = bin_centers_type
  900. # Bottleneck convolution
  901. bottleneck_features = config.bottleneck_features
  902. self.conv2 = nn.Conv2d(bottleneck_features, bottleneck_features, kernel_size=1, stride=1, padding=0)
  903. # Regressor and attractor
  904. if self.bin_centers_type == "normed":
  905. Attractor = ZoeDepthAttractorLayer
  906. elif self.bin_centers_type == "softplus":
  907. Attractor = ZoeDepthAttractorLayerUnnormed
  908. self.seed_bin_regressor = ZoeDepthSeedBinRegressor(
  909. config, n_bins=n_bins, min_depth=min_depth, max_depth=max_depth
  910. )
  911. self.seed_projector = ZoeDepthProjector(in_features=bottleneck_features, out_features=bin_embedding_dim)
  912. self.projectors = nn.ModuleList(
  913. [
  914. ZoeDepthProjector(in_features=config.fusion_hidden_size, out_features=bin_embedding_dim)
  915. for _ in range(4)
  916. ]
  917. )
  918. self.attractors = nn.ModuleList(
  919. [
  920. Attractor(
  921. config,
  922. n_bins=n_bins,
  923. n_attractors=n_attractors[i],
  924. min_depth=min_depth,
  925. max_depth=max_depth,
  926. )
  927. for i in range(4)
  928. ]
  929. )
  930. last_in = config.num_relative_features + 1 # +1 for relative depth
  931. # use log binomial instead of softmax
  932. self.conditional_log_binomial = ZoeDepthConditionalLogBinomialSoftmax(
  933. config,
  934. last_in,
  935. bin_embedding_dim,
  936. n_classes=n_bins,
  937. )
  938. def forward(self, outconv_activation, bottleneck, feature_blocks, relative_depth):
  939. x = self.conv2(bottleneck)
  940. _, seed_bin_centers = self.seed_bin_regressor(x)
  941. if self.bin_centers_type in ["normed", "hybrid2"]:
  942. prev_bin = (seed_bin_centers - self.min_depth) / (self.max_depth - self.min_depth)
  943. else:
  944. prev_bin = seed_bin_centers
  945. prev_bin_embedding = self.seed_projector(x)
  946. # unroll this loop for better performance
  947. for projector, attractor, feature in zip(self.projectors, self.attractors, feature_blocks):
  948. bin_embedding = projector(feature)
  949. bin, bin_centers = attractor(bin_embedding, prev_bin, prev_bin_embedding, interpolate=True)
  950. prev_bin = bin.clone()
  951. prev_bin_embedding = bin_embedding.clone()
  952. last = outconv_activation
  953. # concatenative relative depth with last. First interpolate relative depth to last size
  954. relative_conditioning = relative_depth.unsqueeze(1)
  955. relative_conditioning = nn.functional.interpolate(
  956. relative_conditioning, size=last.shape[2:], mode="bilinear", align_corners=True
  957. )
  958. last = torch.cat([last, relative_conditioning], dim=1)
  959. bin_embedding = nn.functional.interpolate(bin_embedding, last.shape[-2:], mode="bilinear", align_corners=True)
  960. x = self.conditional_log_binomial(last, bin_embedding)
  961. # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor
  962. bin_centers = nn.functional.interpolate(bin_centers, x.shape[-2:], mode="bilinear", align_corners=True)
  963. out = torch.sum(x * bin_centers, dim=1, keepdim=True)
  964. return out, None
  965. # Modified from transformers.models.dpt.modeling_dpt.DPTPreTrainedModel with DPT->ZoeDepth,dpt->zoedepth
# avoiding sdpa and flash_attn_2 support, it's done in the backend
  967. @auto_docstring
  968. class ZoeDepthPreTrainedModel(PreTrainedModel):
  969. config: ZoeDepthConfig
  970. base_model_prefix = "zoedepth"
  971. main_input_name = "pixel_values"
  972. supports_gradient_checkpointing = True
  973. def _init_weights(self, module):
  974. """Initialize the weights"""
  975. if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
  976. # Slightly different from the TF version which uses truncated_normal for initialization
  977. # cf https://github.com/pytorch/pytorch/pull/5617
  978. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  979. if module.bias is not None:
  980. module.bias.data.zero_()
  981. elif isinstance(module, nn.LayerNorm):
  982. module.bias.data.zero_()
  983. module.weight.data.fill_(1.0)
  984. @auto_docstring(
  985. custom_intro="""
  986. ZoeDepth model with one or multiple metric depth estimation head(s) on top.
  987. """
  988. )
  989. class ZoeDepthForDepthEstimation(ZoeDepthPreTrainedModel):
  990. def __init__(self, config):
  991. super().__init__(config)
  992. self.backbone = load_backbone(config)
  993. if hasattr(self.backbone.config, "hidden_size") and hasattr(self.backbone.config, "patch_size"):
  994. config.backbone_hidden_size = self.backbone.config.hidden_size
  995. self.patch_size = self.backbone.config.patch_size
  996. else:
  997. raise ValueError(
  998. "ZoeDepth assumes the backbone's config to have `hidden_size` and `patch_size` attributes"
  999. )
  1000. self.neck = ZoeDepthNeck(config)
  1001. self.relative_head = ZoeDepthRelativeDepthEstimationHead(config)
  1002. self.metric_head = (
  1003. ZoeDepthMultipleMetricDepthEstimationHeads(config)
  1004. if len(config.bin_configurations) > 1
  1005. else ZoeDepthMetricDepthEstimationHead(config)
  1006. )
  1007. # Initialize weights and apply final processing
  1008. self.post_init()
  1009. @auto_docstring
  1010. def forward(
  1011. self,
  1012. pixel_values: torch.FloatTensor,
  1013. labels: Optional[torch.LongTensor] = None,
  1014. output_attentions: Optional[bool] = None,
  1015. output_hidden_states: Optional[bool] = None,
  1016. return_dict: Optional[bool] = None,
  1017. ) -> Union[tuple[torch.Tensor], DepthEstimatorOutput]:
  1018. r"""
  1019. labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
  1020. Ground truth depth estimation maps for computing the loss.
  1021. Examples:
  1022. ```python
  1023. >>> from transformers import AutoImageProcessor, ZoeDepthForDepthEstimation
  1024. >>> import torch
  1025. >>> import numpy as np
  1026. >>> from PIL import Image
  1027. >>> import requests
  1028. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1029. >>> image = Image.open(requests.get(url, stream=True).raw)
  1030. >>> image_processor = AutoImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti")
  1031. >>> model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti")
  1032. >>> # prepare image for the model
  1033. >>> inputs = image_processor(images=image, return_tensors="pt")
  1034. >>> with torch.no_grad():
  1035. ... outputs = model(**inputs)
  1036. >>> # interpolate to original size
  1037. >>> post_processed_output = image_processor.post_process_depth_estimation(
  1038. ... outputs,
  1039. ... source_sizes=[(image.height, image.width)],
  1040. ... )
  1041. >>> # visualize the prediction
  1042. >>> predicted_depth = post_processed_output[0]["predicted_depth"]
  1043. >>> depth = predicted_depth * 255 / predicted_depth.max()
  1044. >>> depth = depth.detach().cpu().numpy()
  1045. >>> depth = Image.fromarray(depth.astype("uint8"))
  1046. ```"""
  1047. loss = None
  1048. if labels is not None:
  1049. raise NotImplementedError("Training is not implemented yet")
  1050. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1051. output_hidden_states = (
  1052. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1053. )
  1054. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1055. outputs = self.backbone.forward_with_filtered_kwargs(
  1056. pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
  1057. )
  1058. hidden_states = outputs.feature_maps
  1059. _, _, height, width = pixel_values.shape
  1060. patch_size = self.patch_size
  1061. patch_height = height // patch_size
  1062. patch_width = width // patch_size
  1063. hidden_states, features = self.neck(hidden_states, patch_height, patch_width)
  1064. out = [features] + hidden_states
  1065. relative_depth, features = self.relative_head(hidden_states)
  1066. out = [features] + out
  1067. metric_depth, domain_logits = self.metric_head(
  1068. outconv_activation=out[0], bottleneck=out[1], feature_blocks=out[2:], relative_depth=relative_depth
  1069. )
  1070. metric_depth = metric_depth.squeeze(dim=1)
  1071. if not return_dict:
  1072. if domain_logits is not None:
  1073. output = (metric_depth, domain_logits) + outputs[1:]
  1074. else:
  1075. output = (metric_depth,) + outputs[1:]
  1076. return ((loss,) + output) if loss is not None else output
  1077. return ZoeDepthDepthEstimatorOutput(
  1078. loss=loss,
  1079. predicted_depth=metric_depth,
  1080. domain_logits=domain_logits,
  1081. hidden_states=outputs.hidden_states,
  1082. attentions=outputs.attentions,
  1083. )
# Names exported by `from <this module> import *`.
__all__ = ["ZoeDepthForDepthEstimation", "ZoeDepthPreTrainedModel"]