drop.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. """ DropBlock, DropPath
  2. PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
  3. Papers:
  4. DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
  5. Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
  6. Code:
  7. DropBlock impl inspired by two Tensorflow impl that I liked:
  8. - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
  9. - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
  10. Hacked together by / Copyright 2020 Ross Wightman
  11. """
  12. from typing import List, Union
  13. import torch
  14. import torch.nn as nn
  15. import torch.nn.functional as F
  16. from .grid import ndgrid
  17. def drop_block_2d(
  18. x,
  19. drop_prob: float = 0.1,
  20. block_size: int = 7,
  21. gamma_scale: float = 1.0,
  22. with_noise: bool = False,
  23. inplace: bool = False,
  24. batchwise: bool = False
  25. ):
  26. """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
  27. DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
  28. runs with success, but needs further validation and possibly optimization for lower runtime impact.
  29. """
  30. B, C, H, W = x.shape
  31. total_size = W * H
  32. clipped_block_size = min(block_size, min(W, H))
  33. # seed_drop_rate, the gamma parameter
  34. gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
  35. (W - block_size + 1) * (H - block_size + 1))
  36. # Forces the block to be inside the feature map.
  37. w_i, h_i = ndgrid(torch.arange(W, device=x.device), torch.arange(H, device=x.device))
  38. valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
  39. ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
  40. valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
  41. if batchwise:
  42. # one mask for whole batch, quite a bit faster
  43. uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
  44. else:
  45. uniform_noise = torch.rand_like(x)
  46. block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
  47. block_mask = -F.max_pool2d(
  48. -block_mask,
  49. kernel_size=clipped_block_size, # block_size,
  50. stride=1,
  51. padding=clipped_block_size // 2)
  52. if with_noise:
  53. normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
  54. if inplace:
  55. x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
  56. else:
  57. x = x * block_mask + normal_noise * (1 - block_mask)
  58. else:
  59. normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
  60. if inplace:
  61. x.mul_(block_mask * normalize_scale)
  62. else:
  63. x = x * block_mask * normalize_scale
  64. return x
  65. def drop_block_fast_2d(
  66. x: torch.Tensor,
  67. drop_prob: float = 0.1,
  68. block_size: int = 7,
  69. gamma_scale: float = 1.0,
  70. with_noise: bool = False,
  71. inplace: bool = False,
  72. ):
  73. """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
  74. DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid
  75. block mask at edges.
  76. """
  77. B, C, H, W = x.shape
  78. total_size = W * H
  79. clipped_block_size = min(block_size, min(W, H))
  80. gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
  81. (W - block_size + 1) * (H - block_size + 1))
  82. block_mask = torch.empty_like(x).bernoulli_(gamma)
  83. block_mask = F.max_pool2d(
  84. block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)
  85. if with_noise:
  86. normal_noise = torch.empty_like(x).normal_()
  87. if inplace:
  88. x.mul_(1. - block_mask).add_(normal_noise * block_mask)
  89. else:
  90. x = x * (1. - block_mask) + normal_noise * block_mask
  91. else:
  92. block_mask = 1 - block_mask
  93. normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype)
  94. if inplace:
  95. x.mul_(block_mask * normalize_scale)
  96. else:
  97. x = x * block_mask * normalize_scale
  98. return x
  99. class DropBlock2d(nn.Module):
  100. """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
  101. """
  102. def __init__(
  103. self,
  104. drop_prob: float = 0.1,
  105. block_size: int = 7,
  106. gamma_scale: float = 1.0,
  107. with_noise: bool = False,
  108. inplace: bool = False,
  109. batchwise: bool = False,
  110. fast: bool = True):
  111. super().__init__()
  112. self.drop_prob = drop_prob
  113. self.gamma_scale = gamma_scale
  114. self.block_size = block_size
  115. self.with_noise = with_noise
  116. self.inplace = inplace
  117. self.batchwise = batchwise
  118. self.fast = fast # FIXME finish comparisons of fast vs not
  119. def forward(self, x):
  120. if not self.training or not self.drop_prob:
  121. return x
  122. if self.fast:
  123. return drop_block_fast_2d(
  124. x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace)
  125. else:
  126. return drop_block_2d(
  127. x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
  128. def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
  129. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  130. This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
  131. the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  132. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
  133. changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
  134. 'survival rate' as the argument.
  135. """
  136. if drop_prob == 0. or not training:
  137. return x
  138. keep_prob = 1 - drop_prob
  139. shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  140. random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
  141. if keep_prob > 0.0 and scale_by_keep:
  142. random_tensor.div_(keep_prob)
  143. return x * random_tensor
  144. class DropPath(nn.Module):
  145. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  146. """
  147. def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
  148. super().__init__()
  149. self.drop_prob = drop_prob
  150. self.scale_by_keep = scale_by_keep
  151. def forward(self, x):
  152. return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
  153. def extra_repr(self):
  154. return f'drop_prob={round(self.drop_prob,3):0.3f}'
  155. def calculate_drop_path_rates(
  156. drop_path_rate: float,
  157. depths: Union[int, List[int]],
  158. stagewise: bool = False,
  159. ) -> Union[List[float], List[List[float]]]:
  160. """Generate drop path rates for stochastic depth.
  161. This function handles two common patterns for drop path rate scheduling:
  162. 1. Per-block: Linear increase from 0 to drop_path_rate across all blocks
  163. 2. Stage-wise: Linear increase across stages, with same rate within each stage
  164. Args:
  165. drop_path_rate: Maximum drop path rate (at the end).
  166. depths: Either a single int for total depth (per-block mode) or
  167. list of ints for depths per stage (stage-wise mode).
  168. stagewise: If True, use stage-wise pattern. If False, use per-block pattern.
  169. When depths is a list, stagewise defaults to True.
  170. Returns:
  171. For per-block mode: List of drop rates, one per block.
  172. For stage-wise mode: List of lists, drop rates per stage.
  173. """
  174. if isinstance(depths, int):
  175. # Single depth value - per-block pattern
  176. if stagewise:
  177. raise ValueError("stagewise=True requires depths to be a list of stage depths")
  178. dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths, device='cpu')]
  179. return dpr
  180. else:
  181. # List of depths - can be either pattern
  182. total_depth = sum(depths)
  183. if stagewise:
  184. # Stage-wise pattern: same drop rate within each stage
  185. dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, total_depth, device='cpu').split(depths)]
  186. return dpr
  187. else:
  188. # Per-block pattern across all stages
  189. dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth, device='cpu')]
  190. return dpr