pooling.py 59 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540
  1. from typing import Optional
  2. import torch.nn.functional as F
  3. from torch import Tensor
  4. from torch.nn.common_types import (
  5. _ratio_2_t,
  6. _ratio_3_t,
  7. _size_1_t,
  8. _size_2_opt_t,
  9. _size_2_t,
  10. _size_3_opt_t,
  11. _size_3_t,
  12. _size_any_opt_t,
  13. _size_any_t,
  14. )
  15. from .module import Module
  16. from .utils import _pair, _single, _triple
  17. __all__ = [
  18. "MaxPool1d",
  19. "MaxPool2d",
  20. "MaxPool3d",
  21. "MaxUnpool1d",
  22. "MaxUnpool2d",
  23. "MaxUnpool3d",
  24. "AvgPool1d",
  25. "AvgPool2d",
  26. "AvgPool3d",
  27. "FractionalMaxPool2d",
  28. "FractionalMaxPool3d",
  29. "LPPool1d",
  30. "LPPool2d",
  31. "LPPool3d",
  32. "AdaptiveMaxPool1d",
  33. "AdaptiveMaxPool2d",
  34. "AdaptiveMaxPool3d",
  35. "AdaptiveAvgPool1d",
  36. "AdaptiveAvgPool2d",
  37. "AdaptiveAvgPool3d",
  38. ]
  39. class _MaxPoolNd(Module):
  40. __constants__ = [
  41. "kernel_size",
  42. "stride",
  43. "padding",
  44. "dilation",
  45. "return_indices",
  46. "ceil_mode",
  47. ]
  48. return_indices: bool
  49. ceil_mode: bool
  50. def __init__(
  51. self,
  52. kernel_size: _size_any_t,
  53. stride: Optional[_size_any_t] = None,
  54. padding: _size_any_t = 0,
  55. dilation: _size_any_t = 1,
  56. return_indices: bool = False,
  57. ceil_mode: bool = False,
  58. ) -> None:
  59. super().__init__()
  60. self.kernel_size = kernel_size
  61. self.stride = stride if (stride is not None) else kernel_size
  62. self.padding = padding
  63. self.dilation = dilation
  64. self.return_indices = return_indices
  65. self.ceil_mode = ceil_mode
  66. def extra_repr(self) -> str:
  67. return (
  68. "kernel_size={kernel_size}, stride={stride}, padding={padding}"
  69. ", dilation={dilation}, ceil_mode={ceil_mode}".format(**self.__dict__)
  70. )
  71. class MaxPool1d(_MaxPoolNd):
  72. r"""Applies a 1D max pooling over an input signal composed of several input planes.
  73. In the simplest case, the output value of the layer with input size :math:`(N, C, L)`
  74. and output :math:`(N, C, L_{out})` can be precisely described as:
  75. .. math::
  76. out(N_i, C_j, k) = \max_{m=0, \ldots, \text{kernel\_size} - 1}
  77. input(N_i, C_j, stride \times k + m)
  78. If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
  79. for :attr:`padding` number of points. :attr:`dilation` is the stride between the elements within the
  80. sliding window. This `link`_ has a nice visualization of the pooling parameters.
  81. Note:
  82. When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
  83. or the input. Sliding windows that would start in the right padded region are ignored.
  84. Args:
  85. kernel_size: The size of the sliding window, must be > 0.
  86. stride: The stride of the sliding window, must be > 0. Default value is :attr:`kernel_size`.
  87. padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2.
  88. dilation: The stride between elements within a sliding window, must be > 0.
  89. return_indices: If ``True``, will return the argmax along with the max values.
  90. Useful for :class:`torch.nn.MaxUnpool1d` later
  91. ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This
  92. ensures that every element in the input tensor is covered by a sliding window.
  93. Shape:
  94. - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`.
  95. - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`,
  96. where ``ceil_mode = False``
  97. .. math::
  98. L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{dilation}
  99. \times (\text{kernel\_size} - 1) - 1}{\text{stride}}\right\rfloor + 1
  100. where ``ceil_mode = True``
  101. .. math::
  102. L_{out} = \left\lceil \frac{L_{in} + 2 \times \text{padding} - \text{dilation}
  103. \times (\text{kernel\_size} - 1) - 1 + (stride - 1)}{\text{stride}}\right\rceil + 1
  104. - Ensure that the last pooling starts inside the image, make :math:`L_{out} = L_{out} - 1`
  105. when :math:`(L_{out} - 1) * \text{stride} >= L_{in} + \text{padding}`.
  106. Examples::
  107. >>> # pool of size=3, stride=2
  108. >>> m = nn.MaxPool1d(3, stride=2)
  109. >>> input = torch.randn(20, 16, 50)
  110. >>> output = m(input)
  111. .. _link:
  112. https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
  113. """
  114. kernel_size: _size_1_t
  115. stride: _size_1_t
  116. padding: _size_1_t
  117. dilation: _size_1_t
  118. def forward(self, input: Tensor):
  119. """Runs the forward pass."""
  120. return F.max_pool1d(
  121. input,
  122. self.kernel_size,
  123. self.stride,
  124. self.padding,
  125. self.dilation,
  126. ceil_mode=self.ceil_mode,
  127. return_indices=self.return_indices,
  128. )
  129. class MaxPool2d(_MaxPoolNd):
  130. r"""Applies a 2D max pooling over an input signal composed of several input planes.
  131. In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`,
  132. output :math:`(N, C, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kH, kW)`
  133. can be precisely described as:
  134. .. math::
  135. \begin{aligned}
  136. out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
  137. & \text{input}(N_i, C_j, \text{stride[0]} \times h + m,
  138. \text{stride[1]} \times w + n)
  139. \end{aligned}
  140. If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
  141. for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
  142. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
  143. Note:
  144. When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
  145. or the input. Sliding windows that would start in the right padded region are ignored.
  146. The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:
  147. - a single ``int`` -- in which case the same value is used for the height and width dimension
  148. - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
  149. and the second `int` for the width dimension
  150. Args:
  151. kernel_size: the size of the window to take a max over
  152. stride: the stride of the window. Default value is :attr:`kernel_size`
  153. padding: Implicit negative infinity padding to be added on both sides
  154. dilation: a parameter that controls the stride of elements in the window
  155. return_indices: if ``True``, will return the max indices along with the outputs.
  156. Useful for :class:`torch.nn.MaxUnpool2d` later
  157. ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
  158. Shape:
  159. - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`
  160. - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where
  161. .. math::
  162. H_{out} = \left\lfloor\frac{H_{in} + 2 * \text{padding[0]} - \text{dilation[0]}
  163. \times (\text{kernel\_size[0]} - 1) - 1}{\text{stride[0]}} + 1\right\rfloor
  164. .. math::
  165. W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding[1]} - \text{dilation[1]}
  166. \times (\text{kernel\_size[1]} - 1) - 1}{\text{stride[1]}} + 1\right\rfloor
  167. Examples::
  168. >>> # pool of square window of size=3, stride=2
  169. >>> m = nn.MaxPool2d(3, stride=2)
  170. >>> # pool of non-square window
  171. >>> m = nn.MaxPool2d((3, 2), stride=(2, 1))
  172. >>> input = torch.randn(20, 16, 50, 32)
  173. >>> output = m(input)
  174. .. _link:
  175. https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
  176. """
  177. kernel_size: _size_2_t
  178. stride: _size_2_t
  179. padding: _size_2_t
  180. dilation: _size_2_t
  181. def forward(self, input: Tensor):
  182. """Runs the forward pass."""
  183. return F.max_pool2d(
  184. input,
  185. self.kernel_size,
  186. self.stride,
  187. self.padding,
  188. self.dilation,
  189. ceil_mode=self.ceil_mode,
  190. return_indices=self.return_indices,
  191. )
  192. class MaxPool3d(_MaxPoolNd):
  193. r"""Applies a 3D max pooling over an input signal composed of several input planes.
  194. In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`,
  195. output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)`
  196. can be precisely described as:
  197. .. math::
  198. \begin{aligned}
  199. \text{out}(N_i, C_j, d, h, w) ={} & \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
  200. & \text{input}(N_i, C_j, \text{stride[0]} \times d + k,
  201. \text{stride[1]} \times h + m, \text{stride[2]} \times w + n)
  202. \end{aligned}
  203. If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
  204. for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
  205. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
  206. Note:
  207. When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
  208. or the input. Sliding windows that would start in the right padded region are ignored.
  209. The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:
  210. - a single ``int`` -- in which case the same value is used for the depth, height and width dimension
  211. - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
  212. the second `int` for the height dimension and the third `int` for the width dimension
  213. Args:
  214. kernel_size: the size of the window to take a max over
  215. stride: the stride of the window. Default value is :attr:`kernel_size`
  216. padding: Implicit negative infinity padding to be added on all three sides
  217. dilation: a parameter that controls the stride of elements in the window
  218. return_indices: if ``True``, will return the max indices along with the outputs.
  219. Useful for :class:`torch.nn.MaxUnpool3d` later
  220. ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
  221. Shape:
  222. - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
  223. - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`, where
  224. .. math::
  225. D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times
  226. (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor
  227. .. math::
  228. H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times
  229. (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor
  230. .. math::
  231. W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times
  232. (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor
  233. Examples::
  234. >>> # pool of square window of size=3, stride=2
  235. >>> m = nn.MaxPool3d(3, stride=2)
  236. >>> # pool of non-square window
  237. >>> m = nn.MaxPool3d((3, 2, 2), stride=(2, 1, 2))
  238. >>> input = torch.randn(20, 16, 50, 44, 31)
  239. >>> output = m(input)
  240. .. _link:
  241. https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
  242. """ # noqa: E501
  243. kernel_size: _size_3_t
  244. stride: _size_3_t
  245. padding: _size_3_t
  246. dilation: _size_3_t
  247. def forward(self, input: Tensor):
  248. """Runs the forward pass."""
  249. return F.max_pool3d(
  250. input,
  251. self.kernel_size,
  252. self.stride,
  253. self.padding,
  254. self.dilation,
  255. ceil_mode=self.ceil_mode,
  256. return_indices=self.return_indices,
  257. )
  258. class _MaxUnpoolNd(Module):
  259. def extra_repr(self) -> str:
  260. return f"kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding}"
  261. class MaxUnpool1d(_MaxUnpoolNd):
  262. r"""Computes a partial inverse of :class:`MaxPool1d`.
  263. :class:`MaxPool1d` is not fully invertible, since the non-maximal values are lost.
  264. :class:`MaxUnpool1d` takes in as input the output of :class:`MaxPool1d`
  265. including the indices of the maximal values and computes a partial inverse
  266. in which all non-maximal values are set to zero.
  267. Note:
  268. This operation may behave nondeterministically when the input indices has repeat values.
  269. See https://github.com/pytorch/pytorch/issues/80827 and :doc:`/notes/randomness` for more information.
  270. .. note:: :class:`MaxPool1d` can map several input sizes to the same output
  271. sizes. Hence, the inversion process can get ambiguous.
  272. To accommodate this, you can provide the needed output size
  273. as an additional argument :attr:`output_size` in the forward call.
  274. See the Inputs and Example below.
  275. Args:
  276. kernel_size (int or tuple): Size of the max pooling window.
  277. stride (int or tuple): Stride of the max pooling window.
  278. It is set to :attr:`kernel_size` by default.
  279. padding (int or tuple): Padding that was added to the input
  280. Inputs:
  281. - `input`: the input Tensor to invert
  282. - `indices`: the indices given out by :class:`~torch.nn.MaxPool1d`
  283. - `output_size` (optional): the targeted output size
  284. Shape:
  285. - Input: :math:`(N, C, H_{in})` or :math:`(C, H_{in})`.
  286. - Output: :math:`(N, C, H_{out})` or :math:`(C, H_{out})`, where
  287. .. math::
  288. H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{kernel\_size}[0]
  289. or as given by :attr:`output_size` in the call operator
  290. Example::
  291. >>> # xdoctest: +IGNORE_WANT("do other tests modify the global state?")
  292. >>> pool = nn.MaxPool1d(2, stride=2, return_indices=True)
  293. >>> unpool = nn.MaxUnpool1d(2, stride=2)
  294. >>> input = torch.tensor([[[1., 2, 3, 4, 5, 6, 7, 8]]])
  295. >>> output, indices = pool(input)
  296. >>> unpool(output, indices)
  297. tensor([[[ 0., 2., 0., 4., 0., 6., 0., 8.]]])
  298. >>> # Example showcasing the use of output_size
  299. >>> input = torch.tensor([[[1., 2, 3, 4, 5, 6, 7, 8, 9]]])
  300. >>> output, indices = pool(input)
  301. >>> unpool(output, indices, output_size=input.size())
  302. tensor([[[ 0., 2., 0., 4., 0., 6., 0., 8., 0.]]])
  303. >>> unpool(output, indices)
  304. tensor([[[ 0., 2., 0., 4., 0., 6., 0., 8.]]])
  305. """
  306. kernel_size: _size_1_t
  307. stride: _size_1_t
  308. padding: _size_1_t
  309. def __init__(
  310. self,
  311. kernel_size: _size_1_t,
  312. stride: Optional[_size_1_t] = None,
  313. padding: _size_1_t = 0,
  314. ) -> None:
  315. super().__init__()
  316. self.kernel_size = _single(kernel_size)
  317. self.stride = _single(stride if (stride is not None) else kernel_size)
  318. self.padding = _single(padding)
  319. def forward(
  320. self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None
  321. ) -> Tensor:
  322. """Runs the forward pass."""
  323. return F.max_unpool1d(
  324. input, indices, self.kernel_size, self.stride, self.padding, output_size
  325. )
  326. class MaxUnpool2d(_MaxUnpoolNd):
  327. r"""Computes a partial inverse of :class:`MaxPool2d`.
  328. :class:`MaxPool2d` is not fully invertible, since the non-maximal values are lost.
  329. :class:`MaxUnpool2d` takes in as input the output of :class:`MaxPool2d`
  330. including the indices of the maximal values and computes a partial inverse
  331. in which all non-maximal values are set to zero.
  332. Note:
  333. This operation may behave nondeterministically when the input indices has repeat values.
  334. See https://github.com/pytorch/pytorch/issues/80827 and :doc:`/notes/randomness` for more information.
  335. .. note:: :class:`MaxPool2d` can map several input sizes to the same output
  336. sizes. Hence, the inversion process can get ambiguous.
  337. To accommodate this, you can provide the needed output size
  338. as an additional argument :attr:`output_size` in the forward call.
  339. See the Inputs and Example below.
  340. Args:
  341. kernel_size (int or tuple): Size of the max pooling window.
  342. stride (int or tuple): Stride of the max pooling window.
  343. It is set to :attr:`kernel_size` by default.
  344. padding (int or tuple): Padding that was added to the input
  345. Inputs:
  346. - `input`: the input Tensor to invert
  347. - `indices`: the indices given out by :class:`~torch.nn.MaxPool2d`
  348. - `output_size` (optional): the targeted output size
  349. Shape:
  350. - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
  351. - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where
  352. .. math::
  353. H_{out} = (H_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]}
  354. .. math::
  355. W_{out} = (W_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]}
  356. or as given by :attr:`output_size` in the call operator
  357. Example::
  358. >>> pool = nn.MaxPool2d(2, stride=2, return_indices=True)
  359. >>> unpool = nn.MaxUnpool2d(2, stride=2)
  360. >>> input = torch.tensor([[[[ 1., 2., 3., 4.],
  361. [ 5., 6., 7., 8.],
  362. [ 9., 10., 11., 12.],
  363. [13., 14., 15., 16.]]]])
  364. >>> output, indices = pool(input)
  365. >>> unpool(output, indices)
  366. tensor([[[[ 0., 0., 0., 0.],
  367. [ 0., 6., 0., 8.],
  368. [ 0., 0., 0., 0.],
  369. [ 0., 14., 0., 16.]]]])
  370. >>> # Now using output_size to resolve an ambiguous size for the inverse
  371. >>> input = torch.tensor([[[[ 1., 2., 3., 4., 5.],
  372. [ 6., 7., 8., 9., 10.],
  373. [11., 12., 13., 14., 15.],
  374. [16., 17., 18., 19., 20.]]]])
  375. >>> output, indices = pool(input)
  376. >>> # This call will not work without specifying output_size
  377. >>> unpool(output, indices, output_size=input.size())
  378. tensor([[[[ 0., 0., 0., 0., 0.],
  379. [ 0., 7., 0., 9., 0.],
  380. [ 0., 0., 0., 0., 0.],
  381. [ 0., 17., 0., 19., 0.]]]])
  382. """
  383. kernel_size: _size_2_t
  384. stride: _size_2_t
  385. padding: _size_2_t
  386. def __init__(
  387. self,
  388. kernel_size: _size_2_t,
  389. stride: Optional[_size_2_t] = None,
  390. padding: _size_2_t = 0,
  391. ) -> None:
  392. super().__init__()
  393. self.kernel_size = _pair(kernel_size)
  394. self.stride = _pair(stride if (stride is not None) else kernel_size)
  395. self.padding = _pair(padding)
  396. def forward(
  397. self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None
  398. ) -> Tensor:
  399. """Runs the forward pass."""
  400. return F.max_unpool2d(
  401. input, indices, self.kernel_size, self.stride, self.padding, output_size
  402. )
  403. class MaxUnpool3d(_MaxUnpoolNd):
  404. r"""Computes a partial inverse of :class:`MaxPool3d`.
  405. :class:`MaxPool3d` is not fully invertible, since the non-maximal values are lost.
  406. :class:`MaxUnpool3d` takes in as input the output of :class:`MaxPool3d`
  407. including the indices of the maximal values and computes a partial inverse
  408. in which all non-maximal values are set to zero.
  409. Note:
  410. This operation may behave nondeterministically when the input indices has repeat values.
  411. See https://github.com/pytorch/pytorch/issues/80827 and :doc:`/notes/randomness` for more information.
  412. .. note:: :class:`MaxPool3d` can map several input sizes to the same output
  413. sizes. Hence, the inversion process can get ambiguous.
  414. To accommodate this, you can provide the needed output size
  415. as an additional argument :attr:`output_size` in the forward call.
  416. See the Inputs section below.
  417. Args:
  418. kernel_size (int or tuple): Size of the max pooling window.
  419. stride (int or tuple): Stride of the max pooling window.
  420. It is set to :attr:`kernel_size` by default.
  421. padding (int or tuple): Padding that was added to the input
  422. Inputs:
  423. - `input`: the input Tensor to invert
  424. - `indices`: the indices given out by :class:`~torch.nn.MaxPool3d`
  425. - `output_size` (optional): the targeted output size
  426. Shape:
  427. - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
  428. - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`, where
  429. .. math::
  430. D_{out} = (D_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]}
  431. .. math::
  432. H_{out} = (H_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]}
  433. .. math::
  434. W_{out} = (W_{in} - 1) \times \text{stride[2]} - 2 \times \text{padding[2]} + \text{kernel\_size[2]}
  435. or as given by :attr:`output_size` in the call operator
  436. Example::
  437. >>> # pool of square window of size=3, stride=2
  438. >>> pool = nn.MaxPool3d(3, stride=2, return_indices=True)
  439. >>> unpool = nn.MaxUnpool3d(3, stride=2)
  440. >>> output, indices = pool(torch.randn(20, 16, 51, 33, 15))
  441. >>> unpooled_output = unpool(output, indices)
  442. >>> unpooled_output.size()
  443. torch.Size([20, 16, 51, 33, 15])
  444. """
  445. kernel_size: _size_3_t
  446. stride: _size_3_t
  447. padding: _size_3_t
  448. def __init__(
  449. self,
  450. kernel_size: _size_3_t,
  451. stride: Optional[_size_3_t] = None,
  452. padding: _size_3_t = 0,
  453. ) -> None:
  454. super().__init__()
  455. self.kernel_size = _triple(kernel_size)
  456. self.stride = _triple(stride if (stride is not None) else kernel_size)
  457. self.padding = _triple(padding)
  458. def forward(
  459. self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None
  460. ) -> Tensor:
  461. """Runs the forward pass."""
  462. return F.max_unpool3d(
  463. input, indices, self.kernel_size, self.stride, self.padding, output_size
  464. )
  465. class _AvgPoolNd(Module):
  466. __constants__ = [
  467. "kernel_size",
  468. "stride",
  469. "padding",
  470. "ceil_mode",
  471. "count_include_pad",
  472. ]
  473. def extra_repr(self) -> str:
  474. return f"kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding}"
  475. class AvgPool1d(_AvgPoolNd):
  476. r"""Applies a 1D average pooling over an input signal composed of several input planes.
  477. In the simplest case, the output value of the layer with input size :math:`(N, C, L)`,
  478. output :math:`(N, C, L_{out})` and :attr:`kernel_size` :math:`k`
  479. can be precisely described as:
  480. .. math::
  481. \text{out}(N_i, C_j, l) = \frac{1}{k} \sum_{m=0}^{k-1}
  482. \text{input}(N_i, C_j, \text{stride} \times l + m)
  483. If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
  484. for :attr:`padding` number of points.
  485. Note:
  486. When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
  487. or the input. Sliding windows that would start in the right padded region are ignored.
  488. .. note::
  489. pad should be at most half of effective kernel size.
  490. The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can each be
  491. an ``int`` or a one-element tuple.
  492. Args:
  493. kernel_size: the size of the window
  494. stride: the stride of the window. Default value is :attr:`kernel_size`
  495. padding: implicit zero padding to be added on both sides
  496. ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
  497. count_include_pad: when True, will include the zero-padding in the averaging calculation
  498. Shape:
  499. - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`.
  500. - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where
  501. .. math::
  502. L_{out} = \left\lfloor \frac{L_{in} +
  503. 2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor
  504. Per the note above, if ``ceil_mode`` is True and :math:`(L_{out} - 1) \times \text{stride} \geq L_{in}
  505. + \text{padding}`, we skip the last window as it would start in the right padded region, resulting in
  506. :math:`L_{out}` being reduced by one.
  507. Examples::
  508. >>> # pool with window of size=3, stride=2
  509. >>> m = nn.AvgPool1d(3, stride=2)
  510. >>> m(torch.tensor([[[1., 2, 3, 4, 5, 6, 7]]]))
  511. tensor([[[2., 4., 6.]]])
  512. """
  513. kernel_size: _size_1_t
  514. stride: _size_1_t
  515. padding: _size_1_t
  516. ceil_mode: bool
  517. count_include_pad: bool
  518. def __init__(
  519. self,
  520. kernel_size: _size_1_t,
  521. stride: _size_1_t = None,
  522. padding: _size_1_t = 0,
  523. ceil_mode: bool = False,
  524. count_include_pad: bool = True,
  525. ) -> None:
  526. super().__init__()
  527. self.kernel_size = _single(kernel_size)
  528. self.stride = _single(stride if stride is not None else kernel_size)
  529. self.padding = _single(padding)
  530. self.ceil_mode = ceil_mode
  531. self.count_include_pad = count_include_pad
  532. def forward(self, input: Tensor) -> Tensor:
  533. """Runs the forward pass."""
  534. return F.avg_pool1d(
  535. input,
  536. self.kernel_size,
  537. self.stride,
  538. self.padding,
  539. self.ceil_mode,
  540. self.count_include_pad,
  541. )
  542. class AvgPool2d(_AvgPoolNd):
  543. r"""Applies a 2D average pooling over an input signal composed of several input planes.
  544. In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`,
  545. output :math:`(N, C, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kH, kW)`
  546. can be precisely described as:
  547. .. math::
  548. out(N_i, C_j, h, w) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1}
  549. input(N_i, C_j, stride[0] \times h + m, stride[1] \times w + n)
  550. If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
  551. for :attr:`padding` number of points.
  552. Note:
  553. When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
  554. or the input. Sliding windows that would start in the right padded region are ignored.
  555. .. note::
  556. pad should be at most half of effective kernel size.
  557. The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can either be:
  558. - a single ``int`` or a single-element tuple -- in which case the same value is used for the height and width dimension
  559. - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
  560. and the second `int` for the width dimension
  561. Args:
  562. kernel_size: the size of the window
  563. stride: the stride of the window. Default value is :attr:`kernel_size`
  564. padding: implicit zero padding to be added on both sides
  565. ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
  566. count_include_pad: when True, will include the zero-padding in the averaging calculation
  567. divisor_override: if specified, it will be used as divisor, otherwise size of the pooling region will be used.
  568. Shape:
  569. - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
  570. - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where
  571. .. math::
  572. H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] -
  573. \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
  574. .. math::
  575. W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] -
  576. \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
  577. Per the note above, if ``ceil_mode`` is True and :math:`(H_{out} - 1)\times \text{stride}[0]\geq H_{in}
  578. + \text{padding}[0]`, we skip the last window as it would start in the bottom padded region,
  579. resulting in :math:`H_{out}` being reduced by one.
  580. The same applies for :math:`W_{out}`.
  581. Examples::
  582. >>> # pool of square window of size=3, stride=2
  583. >>> m = nn.AvgPool2d(3, stride=2)
  584. >>> # pool of non-square window
  585. >>> m = nn.AvgPool2d((3, 2), stride=(2, 1))
  586. >>> input = torch.randn(20, 16, 50, 32)
  587. >>> output = m(input)
  588. """
  589. __constants__ = [
  590. "kernel_size",
  591. "stride",
  592. "padding",
  593. "ceil_mode",
  594. "count_include_pad",
  595. "divisor_override",
  596. ]
  597. kernel_size: _size_2_t
  598. stride: _size_2_t
  599. padding: _size_2_t
  600. ceil_mode: bool
  601. count_include_pad: bool
  602. def __init__(
  603. self,
  604. kernel_size: _size_2_t,
  605. stride: Optional[_size_2_t] = None,
  606. padding: _size_2_t = 0,
  607. ceil_mode: bool = False,
  608. count_include_pad: bool = True,
  609. divisor_override: Optional[int] = None,
  610. ) -> None:
  611. super().__init__()
  612. self.kernel_size = kernel_size
  613. self.stride = stride if (stride is not None) else kernel_size
  614. self.padding = padding
  615. self.ceil_mode = ceil_mode
  616. self.count_include_pad = count_include_pad
  617. self.divisor_override = divisor_override
  618. def forward(self, input: Tensor) -> Tensor:
  619. """Runs the forward pass."""
  620. return F.avg_pool2d(
  621. input,
  622. self.kernel_size,
  623. self.stride,
  624. self.padding,
  625. self.ceil_mode,
  626. self.count_include_pad,
  627. self.divisor_override,
  628. )
  629. class AvgPool3d(_AvgPoolNd):
  630. r"""Applies a 3D average pooling over an input signal composed of several input planes.
  631. In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`,
  632. output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)`
  633. can be precisely described as:
  634. .. math::
  635. \begin{aligned}
  636. \text{out}(N_i, C_j, d, h, w) ={} & \sum_{k=0}^{kD-1} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} \\
  637. & \frac{\text{input}(N_i, C_j, \text{stride}[0] \times d + k,
  638. \text{stride}[1] \times h + m, \text{stride}[2] \times w + n)}
  639. {kD \times kH \times kW}
  640. \end{aligned}
  641. If :attr:`padding` is non-zero, then the input is implicitly zero-padded on all three sides
  642. for :attr:`padding` number of points.
  643. Note:
  644. When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
  645. or the input. Sliding windows that would start in the right padded region are ignored.
  646. .. note::
  647. pad should be at most half of effective kernel size.
  648. The parameters :attr:`kernel_size`, :attr:`stride` can either be:
  649. - a single ``int`` -- in which case the same value is used for the depth, height and width dimension
  650. - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
  651. the second `int` for the height dimension and the third `int` for the width dimension
  652. Args:
  653. kernel_size: the size of the window
  654. stride: the stride of the window. Default value is :attr:`kernel_size`
  655. padding: implicit zero padding to be added on all three sides
  656. ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
  657. count_include_pad: when True, will include the zero-padding in the averaging calculation
  658. divisor_override: if specified, it will be used as divisor, otherwise :attr:`kernel_size` will be used
  659. Shape:
  660. - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
  661. - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or
  662. :math:`(C, D_{out}, H_{out}, W_{out})`, where
  663. .. math::
  664. D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] -
  665. \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
  666. .. math::
  667. H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] -
  668. \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
  669. .. math::
  670. W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] -
  671. \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
  672. Per the note above, if ``ceil_mode`` is True and :math:`(D_{out} - 1)\times \text{stride}[0]\geq D_{in}
  673. + \text{padding}[0]`, we skip the last window as it would start in the padded region,
  674. resulting in :math:`D_{out}` being reduced by one.
  675. The same applies for :math:`W_{out}` and :math:`H_{out}`.
  676. Examples::
  677. >>> # pool of square window of size=3, stride=2
  678. >>> m = nn.AvgPool3d(3, stride=2)
  679. >>> # pool of non-square window
  680. >>> m = nn.AvgPool3d((3, 2, 2), stride=(2, 1, 2))
  681. >>> input = torch.randn(20, 16, 50, 44, 31)
  682. >>> output = m(input)
  683. """
  684. __constants__ = [
  685. "kernel_size",
  686. "stride",
  687. "padding",
  688. "ceil_mode",
  689. "count_include_pad",
  690. "divisor_override",
  691. ]
  692. kernel_size: _size_3_t
  693. stride: _size_3_t
  694. padding: _size_3_t
  695. ceil_mode: bool
  696. count_include_pad: bool
  697. def __init__(
  698. self,
  699. kernel_size: _size_3_t,
  700. stride: Optional[_size_3_t] = None,
  701. padding: _size_3_t = 0,
  702. ceil_mode: bool = False,
  703. count_include_pad: bool = True,
  704. divisor_override: Optional[int] = None,
  705. ) -> None:
  706. super().__init__()
  707. self.kernel_size = kernel_size
  708. self.stride = stride if (stride is not None) else kernel_size
  709. self.padding = padding
  710. self.ceil_mode = ceil_mode
  711. self.count_include_pad = count_include_pad
  712. self.divisor_override = divisor_override
  713. def forward(self, input: Tensor) -> Tensor:
  714. """Runs the forward pass."""
  715. return F.avg_pool3d(
  716. input,
  717. self.kernel_size,
  718. self.stride,
  719. self.padding,
  720. self.ceil_mode,
  721. self.count_include_pad,
  722. self.divisor_override,
  723. )
  724. def __setstate__(self, d):
  725. super().__setstate__(d)
  726. self.__dict__.setdefault("padding", 0)
  727. self.__dict__.setdefault("ceil_mode", False)
  728. self.__dict__.setdefault("count_include_pad", True)
  729. class FractionalMaxPool2d(Module):
  730. r"""Applies a 2D fractional max pooling over an input signal composed of several input planes.
  731. Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham
  732. The max-pooling operation is applied in :math:`kH \times kW` regions by a stochastic
  733. step size determined by the target output size.
  734. The number of output features is equal to the number of input planes.
  735. .. note:: Exactly one of ``output_size`` or ``output_ratio`` must be defined.
  736. Args:
  737. kernel_size: the size of the window to take a max over.
  738. Can be a single number k (for a square kernel of k x k) or a tuple `(kh, kw)`
  739. output_size: the target output size of the image of the form `oH x oW`.
  740. Can be a tuple `(oH, oW)` or a single number oH for a square image `oH x oH`.
  741. Note that we must have :math:`kH + oH - 1 <= H_{in}` and :math:`kW + oW - 1 <= W_{in}`
  742. output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given.
  743. This has to be a number or tuple in the range (0, 1).
  744. Note that we must have :math:`kH + (output\_ratio\_H * H_{in}) - 1 <= H_{in}`
  745. and :math:`kW + (output\_ratio\_W * W_{in}) - 1 <= W_{in}`
  746. return_indices: if ``True``, will return the indices along with the outputs.
  747. Useful to pass to :meth:`nn.MaxUnpool2d`. Default: ``False``
  748. Shape:
  749. - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
  750. - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where
  751. :math:`(H_{out}, W_{out})=\text{output\_size}` or
  752. :math:`(H_{out}, W_{out})=\text{output\_ratio} \times (H_{in}, W_{in})`.
  753. Examples:
  754. >>> # pool of square window of size=3, and target output size 13x12
  755. >>> m = nn.FractionalMaxPool2d(3, output_size=(13, 12))
  756. >>> # pool of square window and target output size being half of input image size
  757. >>> m = nn.FractionalMaxPool2d(3, output_ratio=(0.5, 0.5))
  758. >>> input = torch.randn(20, 16, 50, 32)
  759. >>> output = m(input)
  760. .. _Fractional MaxPooling:
  761. https://arxiv.org/abs/1412.6071
  762. """
  763. __constants__ = ["kernel_size", "return_indices", "output_size", "output_ratio"]
  764. kernel_size: _size_2_t
  765. return_indices: bool
  766. output_size: _size_2_t
  767. output_ratio: _ratio_2_t
  768. def __init__(
  769. self,
  770. kernel_size: _size_2_t,
  771. output_size: Optional[_size_2_t] = None,
  772. output_ratio: Optional[_ratio_2_t] = None,
  773. return_indices: bool = False,
  774. _random_samples=None,
  775. ) -> None:
  776. super().__init__()
  777. self.kernel_size = _pair(kernel_size)
  778. self.return_indices = return_indices
  779. self.register_buffer("_random_samples", _random_samples)
  780. self.output_size = _pair(output_size) if output_size is not None else None
  781. self.output_ratio = _pair(output_ratio) if output_ratio is not None else None
  782. if output_size is None and output_ratio is None:
  783. raise ValueError(
  784. "FractionalMaxPool2d requires specifying either "
  785. "an output size, or a pooling ratio"
  786. )
  787. if output_size is not None and output_ratio is not None:
  788. raise ValueError(
  789. "only one of output_size and output_ratio may be specified"
  790. )
  791. if self.output_ratio is not None:
  792. if not (0 < self.output_ratio[0] < 1 and 0 < self.output_ratio[1] < 1):
  793. raise ValueError(
  794. f"output_ratio must be between 0 and 1 (got {output_ratio})"
  795. )
  796. def forward(self, input: Tensor):
  797. return F.fractional_max_pool2d(
  798. input,
  799. self.kernel_size,
  800. self.output_size,
  801. self.output_ratio,
  802. self.return_indices,
  803. _random_samples=self._random_samples,
  804. )
  805. class FractionalMaxPool3d(Module):
  806. r"""Applies a 3D fractional max pooling over an input signal composed of several input planes.
  807. Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham
  808. The max-pooling operation is applied in :math:`kT \times kH \times kW` regions by a stochastic
  809. step size determined by the target output size.
  810. The number of output features is equal to the number of input planes.
  811. .. note:: Exactly one of ``output_size`` or ``output_ratio`` must be defined.
  812. Args:
  813. kernel_size: the size of the window to take a max over.
  814. Can be a single number `k` (for a square kernel of `k x k x k`) or a tuple `(kt x kh x kw)`,
  815. `k` must greater than 0.
  816. output_size: the target output size of the image of the form `oT x oH x oW`.
  817. Can be a tuple `(oT, oH, oW)` or a single number oH for a square image `oH x oH x oH`
  818. output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given.
  819. This has to be a number or tuple in the range (0, 1)
  820. return_indices: if ``True``, will return the indices along with the outputs.
  821. Useful to pass to :meth:`nn.MaxUnpool3d`. Default: ``False``
  822. Shape:
  823. - Input: :math:`(N, C, T_{in}, H_{in}, W_{in})` or :math:`(C, T_{in}, H_{in}, W_{in})`.
  824. - Output: :math:`(N, C, T_{out}, H_{out}, W_{out})` or :math:`(C, T_{out}, H_{out}, W_{out})`, where
  825. :math:`(T_{out}, H_{out}, W_{out})=\text{output\_size}` or
  826. :math:`(T_{out}, H_{out}, W_{out})=\text{output\_ratio} \times (T_{in}, H_{in}, W_{in})`
  827. Examples:
  828. >>> # pool of cubic window of size=3, and target output size 13x12x11
  829. >>> m = nn.FractionalMaxPool3d(3, output_size=(13, 12, 11))
  830. >>> # pool of cubic window and target output size being half of input size
  831. >>> m = nn.FractionalMaxPool3d(3, output_ratio=(0.5, 0.5, 0.5))
  832. >>> input = torch.randn(20, 16, 50, 32, 16)
  833. >>> output = m(input)
  834. .. _Fractional MaxPooling:
  835. https://arxiv.org/abs/1412.6071
  836. """
  837. __constants__ = ["kernel_size", "return_indices", "output_size", "output_ratio"]
  838. kernel_size: _size_3_t
  839. return_indices: bool
  840. output_size: _size_3_t
  841. output_ratio: _ratio_3_t
  842. def __init__(
  843. self,
  844. kernel_size: _size_3_t,
  845. output_size: Optional[_size_3_t] = None,
  846. output_ratio: Optional[_ratio_3_t] = None,
  847. return_indices: bool = False,
  848. _random_samples=None,
  849. ) -> None:
  850. super().__init__()
  851. if (isinstance(kernel_size, int) and kernel_size <= 0) or (
  852. isinstance(kernel_size, (tuple, list))
  853. and not all(k > 0 for k in kernel_size)
  854. ):
  855. raise ValueError(f"kernel_size must greater than 0, but got {kernel_size}")
  856. self.kernel_size = _triple(kernel_size)
  857. self.return_indices = return_indices
  858. self.register_buffer("_random_samples", _random_samples)
  859. self.output_size = _triple(output_size) if output_size is not None else None
  860. self.output_ratio = _triple(output_ratio) if output_ratio is not None else None
  861. if output_size is None and output_ratio is None:
  862. raise ValueError(
  863. "FractionalMaxPool3d requires specifying either "
  864. "an output size, or a pooling ratio"
  865. )
  866. if output_size is not None and output_ratio is not None:
  867. raise ValueError(
  868. "only one of output_size and output_ratio may be specified"
  869. )
  870. if self.output_ratio is not None:
  871. if not (
  872. 0 < self.output_ratio[0] < 1
  873. and 0 < self.output_ratio[1] < 1
  874. and 0 < self.output_ratio[2] < 1
  875. ):
  876. raise ValueError(
  877. f"output_ratio must be between 0 and 1 (got {output_ratio})"
  878. )
  879. def forward(self, input: Tensor):
  880. return F.fractional_max_pool3d(
  881. input,
  882. self.kernel_size,
  883. self.output_size,
  884. self.output_ratio,
  885. self.return_indices,
  886. _random_samples=self._random_samples,
  887. )
  888. class _LPPoolNd(Module):
  889. __constants__ = ["norm_type", "kernel_size", "stride", "ceil_mode"]
  890. norm_type: float
  891. ceil_mode: bool
  892. def __init__(
  893. self,
  894. norm_type: float,
  895. kernel_size: _size_any_t,
  896. stride: Optional[_size_any_t] = None,
  897. ceil_mode: bool = False,
  898. ) -> None:
  899. super().__init__()
  900. self.norm_type = norm_type
  901. self.kernel_size = kernel_size
  902. self.stride = stride
  903. self.ceil_mode = ceil_mode
  904. def extra_repr(self) -> str:
  905. return (
  906. "norm_type={norm_type}, kernel_size={kernel_size}, stride={stride}, "
  907. "ceil_mode={ceil_mode}".format(**self.__dict__)
  908. )
  909. class LPPool1d(_LPPoolNd):
  910. r"""Applies a 1D power-average pooling over an input signal composed of several input planes.
  911. On each window, the function computed is:
  912. .. math::
  913. f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}
  914. - At p = :math:`\infty`, one gets Max Pooling
  915. - At p = 1, one gets Sum Pooling (which is proportional to Average Pooling)
  916. .. note:: If the sum to the power of `p` is zero, the gradient of this function is
  917. not defined. This implementation will set the gradient to zero in this case.
  918. Args:
  919. kernel_size: a single int, the size of the window
  920. stride: a single int, the stride of the window. Default value is :attr:`kernel_size`
  921. ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
  922. Shape:
  923. - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`.
  924. - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where
  925. .. math::
  926. L_{out} = \left\lfloor\frac{L_{in} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor
  927. Examples::
  928. >>> # power-2 pool of window of length 3, with stride 2.
  929. >>> m = nn.LPPool1d(2, 3, stride=2)
  930. >>> input = torch.randn(20, 16, 50)
  931. >>> output = m(input)
  932. """
  933. kernel_size: _size_1_t
  934. stride: _size_1_t
  935. def forward(self, input: Tensor) -> Tensor:
  936. """Runs the forward pass."""
  937. return F.lp_pool1d(
  938. input, float(self.norm_type), self.kernel_size, self.stride, self.ceil_mode
  939. )
  940. class LPPool2d(_LPPoolNd):
  941. r"""Applies a 2D power-average pooling over an input signal composed of several input planes.
  942. On each window, the function computed is:
  943. .. math::
  944. f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}
  945. - At p = :math:`\infty`, one gets Max Pooling
  946. - At p = 1, one gets Sum Pooling (which is proportional to average pooling)
  947. The parameters :attr:`kernel_size`, :attr:`stride` can either be:
  948. - a single ``int`` -- in which case the same value is used for the height and width dimension
  949. - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
  950. and the second `int` for the width dimension
  951. .. note:: If the sum to the power of `p` is zero, the gradient of this function is
  952. not defined. This implementation will set the gradient to zero in this case.
  953. Args:
  954. kernel_size: the size of the window
  955. stride: the stride of the window. Default value is :attr:`kernel_size`
  956. ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
  957. Shape:
  958. - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
  959. - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where
  960. .. math::
  961. H_{out} = \left\lfloor\frac{H_{in} - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
  962. .. math::
  963. W_{out} = \left\lfloor\frac{W_{in} - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
  964. Examples::
  965. >>> # power-2 pool of square window of size=3, stride=2
  966. >>> m = nn.LPPool2d(2, 3, stride=2)
  967. >>> # pool of non-square window of power 1.2
  968. >>> m = nn.LPPool2d(1.2, (3, 2), stride=(2, 1))
  969. >>> input = torch.randn(20, 16, 50, 32)
  970. >>> output = m(input)
  971. """
  972. kernel_size: _size_2_t
  973. stride: _size_2_t
  974. def forward(self, input: Tensor) -> Tensor:
  975. """Runs the forward pass."""
  976. return F.lp_pool2d(
  977. input, float(self.norm_type), self.kernel_size, self.stride, self.ceil_mode
  978. )
  979. class LPPool3d(_LPPoolNd):
  980. r"""Applies a 3D power-average pooling over an input signal composed of several input planes.
  981. On each window, the function computed is:
  982. .. math::
  983. f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}
  984. - At p = :math:`\infty`, one gets Max Pooling
  985. - At p = 1, one gets Sum Pooling (which is proportional to average pooling)
  986. The parameters :attr:`kernel_size`, :attr:`stride` can either be:
  987. - a single ``int`` -- in which case the same value is used for the height, width and depth dimension
  988. - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
  989. the second `int` for the height dimension and the third `int` for the width dimension
  990. .. note:: If the sum to the power of `p` is zero, the gradient of this function is
  991. not defined. This implementation will set the gradient to zero in this case.
  992. Args:
  993. kernel_size: the size of the window
  994. stride: the stride of the window. Default value is :attr:`kernel_size`
  995. ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
  996. Shape:
  997. - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
  998. - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or
  999. :math:`(C, D_{out}, H_{out}, W_{out})`, where
  1000. .. math::
  1001. D_{out} = \left\lfloor\frac{D_{in} - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
  1002. .. math::
  1003. H_{out} = \left\lfloor\frac{H_{in} - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
  1004. .. math::
  1005. W_{out} = \left\lfloor\frac{W_{in} - \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
  1006. Examples::
  1007. >>> # power-2 pool of square window of size=3, stride=2
  1008. >>> m = nn.LPPool3d(2, 3, stride=2)
  1009. >>> # pool of non-square window of power 1.2
  1010. >>> m = nn.LPPool3d(1.2, (3, 2, 2), stride=(2, 1, 2))
  1011. >>> input = torch.randn(20, 16, 50, 44, 31)
  1012. >>> output = m(input)
  1013. """
  1014. kernel_size: _size_3_t
  1015. stride: _size_3_t
  1016. def forward(self, input: Tensor) -> Tensor:
  1017. """Runs the forward pass."""
  1018. return F.lp_pool3d(
  1019. input, float(self.norm_type), self.kernel_size, self.stride, self.ceil_mode
  1020. )
  1021. class _AdaptiveMaxPoolNd(Module):
  1022. __constants__ = ["output_size", "return_indices"]
  1023. return_indices: bool
  1024. def __init__(
  1025. self, output_size: _size_any_opt_t, return_indices: bool = False
  1026. ) -> None:
  1027. super().__init__()
  1028. self.output_size = output_size
  1029. self.return_indices = return_indices
  1030. def extra_repr(self) -> str:
  1031. return f"output_size={self.output_size}"
  1032. # FIXME (by @ssnl): Improve adaptive pooling docs: specify what the input and
  1033. # output shapes are, and how the operation computes output.
  1034. class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd):
  1035. r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes.
  1036. The output size is :math:`L_{out}`, for any input size.
  1037. The number of output features is equal to the number of input planes.
  1038. Args:
  1039. output_size: the target output size :math:`L_{out}`.
  1040. return_indices: if ``True``, will return the indices along with the outputs.
  1041. Useful to pass to nn.MaxUnpool1d. Default: ``False``
  1042. Shape:
  1043. - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`.
  1044. - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where
  1045. :math:`L_{out}=\text{output\_size}`.
  1046. Examples:
  1047. >>> # target output size of 5
  1048. >>> m = nn.AdaptiveMaxPool1d(5)
  1049. >>> input = torch.randn(1, 64, 8)
  1050. >>> output = m(input)
  1051. """
  1052. output_size: _size_1_t
  1053. def forward(self, input: Tensor):
  1054. """Runs the forward pass."""
  1055. return F.adaptive_max_pool1d(input, self.output_size, self.return_indices)
  1056. class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd):
  1057. r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes.
  1058. The output is of size :math:`H_{out} \times W_{out}`, for any input size.
  1059. The number of output features is equal to the number of input planes.
  1060. Args:
  1061. output_size: the target output size of the image of the form :math:`H_{out} \times W_{out}`.
  1062. Can be a tuple :math:`(H_{out}, W_{out})` or a single :math:`H_{out}` for a
  1063. square image :math:`H_{out} \times H_{out}`. :math:`H_{out}` and :math:`W_{out}`
  1064. can be either a ``int``, or ``None`` which means the size will be the same as that
  1065. of the input.
  1066. return_indices: if ``True``, will return the indices along with the outputs.
  1067. Useful to pass to nn.MaxUnpool2d. Default: ``False``
  1068. Shape:
  1069. - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
  1070. - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where
  1071. :math:`(H_{out}, W_{out})=\text{output\_size}`.
  1072. Examples:
  1073. >>> # target output size of 5x7
  1074. >>> m = nn.AdaptiveMaxPool2d((5, 7))
  1075. >>> input = torch.randn(1, 64, 8, 9)
  1076. >>> output = m(input)
  1077. >>> # target output size of 7x7 (square)
  1078. >>> m = nn.AdaptiveMaxPool2d(7)
  1079. >>> input = torch.randn(1, 64, 10, 9)
  1080. >>> output = m(input)
  1081. >>> # target output size of 10x7
  1082. >>> m = nn.AdaptiveMaxPool2d((None, 7))
  1083. >>> input = torch.randn(1, 64, 10, 9)
  1084. >>> output = m(input)
  1085. """
  1086. output_size: _size_2_opt_t
  1087. def forward(self, input: Tensor):
  1088. """Runs the forward pass."""
  1089. return F.adaptive_max_pool2d(input, self.output_size, self.return_indices)
  1090. class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd):
  1091. r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes.
  1092. The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size.
  1093. The number of output features is equal to the number of input planes.
  1094. Args:
  1095. output_size: the target output size of the image of the form :math:`D_{out} \times H_{out} \times W_{out}`.
  1096. Can be a tuple :math:`(D_{out}, H_{out}, W_{out})` or a single
  1097. :math:`D_{out}` for a cube :math:`D_{out} \times D_{out} \times D_{out}`.
  1098. :math:`D_{out}`, :math:`H_{out}` and :math:`W_{out}` can be either a
  1099. ``int``, or ``None`` which means the size will be the same as that of the input.
  1100. return_indices: if ``True``, will return the indices along with the outputs.
  1101. Useful to pass to nn.MaxUnpool3d. Default: ``False``
  1102. Shape:
  1103. - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
  1104. - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`,
  1105. where :math:`(D_{out}, H_{out}, W_{out})=\text{output\_size}`.
  1106. Examples:
  1107. >>> # target output size of 5x7x9
  1108. >>> m = nn.AdaptiveMaxPool3d((5, 7, 9))
  1109. >>> input = torch.randn(1, 64, 8, 9, 10)
  1110. >>> output = m(input)
  1111. >>> # target output size of 7x7x7 (cube)
  1112. >>> m = nn.AdaptiveMaxPool3d(7)
  1113. >>> input = torch.randn(1, 64, 10, 9, 8)
  1114. >>> output = m(input)
  1115. >>> # target output size of 7x9x8
  1116. >>> m = nn.AdaptiveMaxPool3d((7, None, None))
  1117. >>> input = torch.randn(1, 64, 10, 9, 8)
  1118. >>> output = m(input)
  1119. """
  1120. output_size: _size_3_opt_t
  1121. def forward(self, input: Tensor):
  1122. """Runs the forward pass."""
  1123. return F.adaptive_max_pool3d(input, self.output_size, self.return_indices)
  1124. class _AdaptiveAvgPoolNd(Module):
  1125. __constants__ = ["output_size"]
  1126. def __init__(self, output_size: _size_any_opt_t) -> None:
  1127. super().__init__()
  1128. self.output_size = output_size
  1129. def extra_repr(self) -> str:
  1130. return f"output_size={self.output_size}"
  1131. class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd):
  1132. r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.
  1133. The output size is :math:`L_{out}`, for any input size.
  1134. The number of output features is equal to the number of input planes.
  1135. Args:
  1136. output_size: the target output size :math:`L_{out}`.
  1137. Shape:
  1138. - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`.
  1139. - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where
  1140. :math:`L_{out}=\text{output\_size}`.
  1141. Examples:
  1142. >>> # target output size of 5
  1143. >>> m = nn.AdaptiveAvgPool1d(5)
  1144. >>> input = torch.randn(1, 64, 8)
  1145. >>> output = m(input)
  1146. """
  1147. output_size: _size_1_t
  1148. def forward(self, input: Tensor) -> Tensor:
  1149. """
  1150. Runs the forward pass.
  1151. """
  1152. return F.adaptive_avg_pool1d(input, self.output_size)
  1153. class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd):
  1154. r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.
  1155. The output is of size H x W, for any input size.
  1156. The number of output features is equal to the number of input planes.
  1157. Args:
  1158. output_size: the target output size of the image of the form H x W.
  1159. Can be a tuple (H, W) or a single H for a square image H x H.
  1160. H and W can be either a ``int``, or ``None`` which means the size will
  1161. be the same as that of the input.
  1162. Shape:
  1163. - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
  1164. - Output: :math:`(N, C, S_{0}, S_{1})` or :math:`(C, S_{0}, S_{1})`, where
  1165. :math:`S=\text{output\_size}`.
  1166. Examples:
  1167. >>> # target output size of 5x7
  1168. >>> m = nn.AdaptiveAvgPool2d((5, 7))
  1169. >>> input = torch.randn(1, 64, 8, 9)
  1170. >>> output = m(input)
  1171. >>> # target output size of 7x7 (square)
  1172. >>> m = nn.AdaptiveAvgPool2d(7)
  1173. >>> input = torch.randn(1, 64, 10, 9)
  1174. >>> output = m(input)
  1175. >>> # target output size of 10x7
  1176. >>> m = nn.AdaptiveAvgPool2d((None, 7))
  1177. >>> input = torch.randn(1, 64, 10, 9)
  1178. >>> output = m(input)
  1179. """
  1180. output_size: _size_2_opt_t
  1181. def forward(self, input: Tensor) -> Tensor:
  1182. """Runs the forward pass."""
  1183. return F.adaptive_avg_pool2d(input, self.output_size)
  1184. class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd):
  1185. r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.
  1186. The output is of size D x H x W, for any input size.
  1187. The number of output features is equal to the number of input planes.
  1188. Args:
  1189. output_size: the target output size of the form D x H x W.
  1190. Can be a tuple (D, H, W) or a single number D for a cube D x D x D.
  1191. D, H and W can be either a ``int``, or ``None`` which means the size will
  1192. be the same as that of the input.
  1193. Shape:
  1194. - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
  1195. - Output: :math:`(N, C, S_{0}, S_{1}, S_{2})` or :math:`(C, S_{0}, S_{1}, S_{2})`,
  1196. where :math:`S=\text{output\_size}`.
  1197. Examples:
  1198. >>> # target output size of 5x7x9
  1199. >>> m = nn.AdaptiveAvgPool3d((5, 7, 9))
  1200. >>> input = torch.randn(1, 64, 8, 9, 10)
  1201. >>> output = m(input)
  1202. >>> # target output size of 7x7x7 (cube)
  1203. >>> m = nn.AdaptiveAvgPool3d(7)
  1204. >>> input = torch.randn(1, 64, 10, 9, 8)
  1205. >>> output = m(input)
  1206. >>> # target output size of 7x9x8
  1207. >>> m = nn.AdaptiveAvgPool3d((7, None, None))
  1208. >>> input = torch.randn(1, 64, 10, 9, 8)
  1209. >>> output = m(input)
  1210. """
  1211. output_size: _size_3_opt_t
  1212. def forward(self, input: Tensor) -> Tensor:
  1213. """Runs the forward pass."""
  1214. return F.adaptive_avg_pool3d(input, self.output_size)