init.py 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752
  1. """This file contains utilities for initializing neural network parameters."""
  2. import math
  3. import warnings
  4. from typing import Callable, Literal, Optional as _Optional, TypeVar, Union
  5. from typing_extensions import ParamSpec
  6. import torch
  7. from torch import Tensor
# Names exported by ``from <this module> import *``: the gain helper and the
# in-place initializers (trailing underscore) first, followed by the
# un-suffixed aliases built by ``_make_deprecate`` at the bottom of this file.
__all__ = [
    "calculate_gain",
    "uniform_",
    "normal_",
    "trunc_normal_",
    "constant_",
    "ones_",
    "zeros_",
    "eye_",
    "dirac_",
    "xavier_uniform_",
    "xavier_normal_",
    "kaiming_uniform_",
    "kaiming_normal_",
    "orthogonal_",
    "sparse_",
    # Deprecated aliases (for backward compatibility)
    "uniform",
    "normal",
    "constant",
    "eye",
    "dirac",
    "xavier_uniform",
    "xavier_normal",
    "kaiming_uniform",
    "kaiming_normal",
    "orthogonal",
    "sparse",
]
# Return type and parameter specification of the function wrapped by
# ``_make_deprecate``; they let the wrapper keep the wrapped signature.
_R = TypeVar("_R")
_P = ParamSpec("_P")

# Nonlinearity names accepted by ``calculate_gain``.
_NonlinearityType = Literal[
    "linear",
    "conv1d",
    "conv2d",
    "conv3d",
    "conv_transpose1d",
    "conv_transpose2d",
    "conv_transpose3d",
    "sigmoid",
    "tanh",
    "relu",
    "leaky_relu",
    "selu",
]

# Fan selection accepted by ``_calculate_correct_fan`` and the Kaiming
# initializers.
_FanMode = Literal["fan_in", "fan_out"]
  54. # These no_grad_* functions are necessary as wrappers around the parts of these
  55. # functions that use `with torch.no_grad()`. The JIT doesn't support context
  56. # managers, so these need to be implemented as builtins. Using these wrappers
  57. # lets us keep those builtins small and reusable.
  58. def _no_grad_uniform_(
  59. tensor: Tensor, a: float, b: float, generator: _Optional[torch.Generator] = None
  60. ) -> Tensor:
  61. with torch.no_grad():
  62. return tensor.uniform_(a, b, generator=generator)
  63. def _no_grad_normal_(
  64. tensor: Tensor,
  65. mean: float,
  66. std: float,
  67. generator: _Optional[torch.Generator] = None,
  68. ) -> Tensor:
  69. with torch.no_grad():
  70. return tensor.normal_(mean, std, generator=generator)
  71. def _no_grad_trunc_normal_(
  72. tensor: Tensor,
  73. mean: float,
  74. std: float,
  75. a: float,
  76. b: float,
  77. generator: _Optional[torch.Generator] = None,
  78. ) -> Tensor:
  79. # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
  80. def norm_cdf(x: float) -> float:
  81. # Computes standard normal cumulative distribution function
  82. return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
  83. if (mean < a - 2 * std) or (mean > b + 2 * std):
  84. warnings.warn(
  85. "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
  86. "The distribution of values may be incorrect.",
  87. stacklevel=2,
  88. )
  89. with torch.no_grad():
  90. # Values are generated by using a truncated uniform distribution and
  91. # then using the inverse CDF for the normal distribution.
  92. # Get upper and lower cdf values
  93. l = norm_cdf((a - mean) / std)
  94. u = norm_cdf((b - mean) / std)
  95. # Uniformly fill tensor with values from [l, u], then translate to
  96. # [2l-1, 2u-1].
  97. tensor.uniform_(2 * l - 1, 2 * u - 1, generator=generator)
  98. # Use inverse cdf transform for normal distribution to get truncated
  99. # standard normal
  100. tensor.erfinv_()
  101. # Transform to proper mean, std
  102. tensor.mul_(std * math.sqrt(2.0))
  103. tensor.add_(mean)
  104. # Clamp to ensure it's in the proper range
  105. tensor.clamp_(min=a, max=b)
  106. return tensor
  107. def _no_grad_fill_(tensor: Tensor, val: float) -> Tensor:
  108. with torch.no_grad():
  109. return tensor.fill_(val)
  110. def _no_grad_zero_(tensor: Tensor) -> Tensor:
  111. with torch.no_grad():
  112. return tensor.zero_()
  113. def calculate_gain(
  114. nonlinearity: _NonlinearityType, param: _Optional[Union[int, float]] = None
  115. ) -> float:
  116. r"""Return the recommended gain value for the given nonlinearity function.
  117. The values are as follows:
  118. ================= ====================================================
  119. nonlinearity gain
  120. ================= ====================================================
  121. Linear / Identity :math:`1`
  122. Conv{1,2,3}D :math:`1`
  123. Sigmoid :math:`1`
  124. Tanh :math:`\frac{5}{3}`
  125. ReLU :math:`\sqrt{2}`
  126. Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
  127. SELU :math:`\frac{3}{4}`
  128. ================= ====================================================
  129. .. warning::
  130. In order to implement `Self-Normalizing Neural Networks`_ ,
  131. you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
  132. This gives the initial weights a variance of ``1 / N``,
  133. which is necessary to induce a stable fixed point in the forward pass.
  134. In contrast, the default gain for ``SELU`` sacrifices the normalization
  135. effect for more stable gradient flow in rectangular layers.
  136. Args:
  137. nonlinearity: the non-linear function (`nn.functional` name)
  138. param: optional parameter for the non-linear function
  139. Examples:
  140. >>> gain = nn.init.calculate_gain(
  141. ... "leaky_relu", 0.2
  142. ... ) # leaky_relu with negative_slope=0.2
  143. .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
  144. """
  145. linear_fns = [
  146. "linear",
  147. "conv1d",
  148. "conv2d",
  149. "conv3d",
  150. "conv_transpose1d",
  151. "conv_transpose2d",
  152. "conv_transpose3d",
  153. ]
  154. if nonlinearity in linear_fns or nonlinearity == "sigmoid":
  155. return 1
  156. elif nonlinearity == "tanh":
  157. return 5.0 / 3
  158. elif nonlinearity == "relu":
  159. return math.sqrt(2.0)
  160. elif nonlinearity == "leaky_relu":
  161. if param is None:
  162. negative_slope = 0.01
  163. elif (
  164. not isinstance(param, bool)
  165. and isinstance(param, int)
  166. or isinstance(param, float)
  167. ):
  168. # True/False are instances of int, hence check above
  169. negative_slope = param
  170. else:
  171. raise ValueError(f"negative_slope {param} not a valid number")
  172. return math.sqrt(2.0 / (1 + negative_slope**2))
  173. elif nonlinearity == "selu":
  174. return (
  175. 3.0 / 4
  176. ) # Value found empirically (https://github.com/pytorch/pytorch/pull/50664)
  177. else:
  178. raise ValueError(f"Unsupported nonlinearity {nonlinearity}")
  179. def uniform_(
  180. tensor: Tensor,
  181. a: float = 0.0,
  182. b: float = 1.0,
  183. generator: _Optional[torch.Generator] = None,
  184. ) -> Tensor:
  185. r"""Fill the input Tensor with values drawn from the uniform distribution.
  186. :math:`\mathcal{U}(a, b)`.
  187. Args:
  188. tensor: an n-dimensional `torch.Tensor`
  189. a: the lower bound of the uniform distribution
  190. b: the upper bound of the uniform distribution
  191. generator: the torch Generator to sample from (default: None)
  192. Examples:
  193. >>> w = torch.empty(3, 5)
  194. >>> nn.init.uniform_(w)
  195. """
  196. if torch.overrides.has_torch_function_variadic(tensor):
  197. return torch.overrides.handle_torch_function(
  198. uniform_, (tensor,), tensor=tensor, a=a, b=b, generator=generator
  199. )
  200. return _no_grad_uniform_(tensor, a, b, generator)
  201. def normal_(
  202. tensor: Tensor,
  203. mean: float = 0.0,
  204. std: float = 1.0,
  205. generator: _Optional[torch.Generator] = None,
  206. ) -> Tensor:
  207. r"""Fill the input Tensor with values drawn from the normal distribution.
  208. :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
  209. Args:
  210. tensor: an n-dimensional `torch.Tensor`
  211. mean: the mean of the normal distribution
  212. std: the standard deviation of the normal distribution
  213. generator: the torch Generator to sample from (default: None)
  214. Examples:
  215. >>> w = torch.empty(3, 5)
  216. >>> nn.init.normal_(w)
  217. """
  218. if torch.overrides.has_torch_function_variadic(tensor):
  219. return torch.overrides.handle_torch_function(
  220. normal_, (tensor,), tensor=tensor, mean=mean, std=std, generator=generator
  221. )
  222. return _no_grad_normal_(tensor, mean, std, generator)
  223. def trunc_normal_(
  224. tensor: Tensor,
  225. mean: float = 0.0,
  226. std: float = 1.0,
  227. a: float = -2.0,
  228. b: float = 2.0,
  229. generator: _Optional[torch.Generator] = None,
  230. ) -> Tensor:
  231. r"""Fill the input Tensor with values drawn from a truncated normal distribution.
  232. The values are effectively drawn from the
  233. normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
  234. with values outside :math:`[a, b]` redrawn until they are within
  235. the bounds. The method used for generating the random values works
  236. best when :math:`a \leq \text{mean} \leq b`.
  237. Args:
  238. tensor: an n-dimensional `torch.Tensor`
  239. mean: the mean of the normal distribution
  240. std: the standard deviation of the normal distribution
  241. a: the minimum cutoff value
  242. b: the maximum cutoff value
  243. generator: the torch Generator to sample from (default: None)
  244. Examples:
  245. >>> w = torch.empty(3, 5)
  246. >>> nn.init.trunc_normal_(w)
  247. """
  248. return _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator)
  249. def constant_(tensor: Tensor, val: float) -> Tensor:
  250. r"""Fill the input Tensor with the value :math:`\text{val}`.
  251. Args:
  252. tensor: an n-dimensional `torch.Tensor`
  253. val: the value to fill the tensor with
  254. Examples:
  255. >>> w = torch.empty(3, 5)
  256. >>> nn.init.constant_(w, 0.3)
  257. """
  258. if torch.overrides.has_torch_function_variadic(tensor):
  259. return torch.overrides.handle_torch_function(
  260. constant_, (tensor,), tensor=tensor, val=val
  261. )
  262. return _no_grad_fill_(tensor, val)
  263. def ones_(tensor: Tensor) -> Tensor:
  264. r"""Fill the input Tensor with the scalar value `1`.
  265. Args:
  266. tensor: an n-dimensional `torch.Tensor`
  267. Examples:
  268. >>> w = torch.empty(3, 5)
  269. >>> nn.init.ones_(w)
  270. """
  271. return _no_grad_fill_(tensor, 1.0)
  272. def zeros_(tensor: Tensor) -> Tensor:
  273. r"""Fill the input Tensor with the scalar value `0`.
  274. Args:
  275. tensor: an n-dimensional `torch.Tensor`
  276. Examples:
  277. >>> w = torch.empty(3, 5)
  278. >>> nn.init.zeros_(w)
  279. """
  280. return _no_grad_zero_(tensor)
  281. def eye_(tensor: Tensor) -> Tensor:
  282. r"""Fill the 2-dimensional input `Tensor` with the identity matrix.
  283. Preserves the identity of the inputs in `Linear` layers, where as
  284. many inputs are preserved as possible.
  285. Args:
  286. tensor: a 2-dimensional `torch.Tensor`
  287. Examples:
  288. >>> w = torch.empty(3, 5)
  289. >>> nn.init.eye_(w)
  290. """
  291. if tensor.ndimension() != 2:
  292. raise ValueError("Only tensors with 2 dimensions are supported")
  293. with torch.no_grad():
  294. torch.eye(*tensor.shape, out=tensor, requires_grad=tensor.requires_grad)
  295. return tensor
  296. def dirac_(tensor: Tensor, groups: int = 1) -> Tensor:
  297. r"""Fill the {3, 4, 5}-dimensional input `Tensor` with the Dirac delta function.
  298. Preserves the identity of the inputs in `Convolutional`
  299. layers, where as many input channels are preserved as possible. In case
  300. of groups>1, each group of channels preserves identity
  301. Args:
  302. tensor: a {3, 4, 5}-dimensional `torch.Tensor`
  303. groups (int, optional): number of groups in the conv layer (default: 1)
  304. Examples:
  305. >>> w = torch.empty(3, 16, 5, 5)
  306. >>> nn.init.dirac_(w)
  307. >>> w = torch.empty(3, 24, 5, 5)
  308. >>> nn.init.dirac_(w, 3)
  309. """
  310. dimensions = tensor.ndimension()
  311. if dimensions not in [3, 4, 5]:
  312. raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")
  313. sizes = tensor.size()
  314. if sizes[0] % groups != 0:
  315. raise ValueError("dim 0 must be divisible by groups")
  316. out_chans_per_grp = sizes[0] // groups
  317. min_dim = min(out_chans_per_grp, sizes[1])
  318. with torch.no_grad():
  319. tensor.zero_()
  320. for g in range(groups):
  321. for d in range(min_dim):
  322. if dimensions == 3: # Temporal convolution
  323. tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2] = 1
  324. elif dimensions == 4: # Spatial convolution
  325. tensor[
  326. g * out_chans_per_grp + d,
  327. d,
  328. tensor.size(2) // 2,
  329. tensor.size(3) // 2,
  330. ] = 1
  331. else: # Volumetric convolution
  332. tensor[
  333. g * out_chans_per_grp + d,
  334. d,
  335. tensor.size(2) // 2,
  336. tensor.size(3) // 2,
  337. tensor.size(4) // 2,
  338. ] = 1
  339. return tensor
  340. def _calculate_fan_in_and_fan_out(tensor: Tensor) -> tuple[int, int]:
  341. dimensions = tensor.dim()
  342. if dimensions < 2:
  343. raise ValueError(
  344. "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
  345. )
  346. num_input_fmaps = tensor.size(1)
  347. num_output_fmaps = tensor.size(0)
  348. receptive_field_size = 1
  349. if tensor.dim() > 2:
  350. # math.prod is not always available, accumulate the product manually
  351. # we could use functools.reduce but that is not supported by TorchScript
  352. for s in tensor.shape[2:]:
  353. receptive_field_size *= s
  354. fan_in = num_input_fmaps * receptive_field_size
  355. fan_out = num_output_fmaps * receptive_field_size
  356. return fan_in, fan_out
  357. def xavier_uniform_(
  358. tensor: Tensor,
  359. gain: float = 1.0,
  360. generator: _Optional[torch.Generator] = None,
  361. ) -> Tensor:
  362. r"""Fill the input `Tensor` with values using a Xavier uniform distribution.
  363. The method is described in `Understanding the difficulty of training
  364. deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010).
  365. The resulting tensor will have values sampled from
  366. :math:`\mathcal{U}(-a, a)` where
  367. .. math::
  368. a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}
  369. Also known as Glorot initialization.
  370. Args:
  371. tensor: an n-dimensional `torch.Tensor`
  372. gain: an optional scaling factor
  373. generator: the torch Generator to sample from (default: None)
  374. Examples:
  375. >>> w = torch.empty(3, 5)
  376. >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain("relu"))
  377. """
  378. fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
  379. std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
  380. a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
  381. return _no_grad_uniform_(tensor, -a, a, generator)
  382. def xavier_normal_(
  383. tensor: Tensor,
  384. gain: float = 1.0,
  385. generator: _Optional[torch.Generator] = None,
  386. ) -> Tensor:
  387. r"""Fill the input `Tensor` with values using a Xavier normal distribution.
  388. The method is described in `Understanding the difficulty of training deep feedforward
  389. neural networks` - Glorot, X. & Bengio, Y. (2010). The resulting tensor
  390. will have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where
  391. .. math::
  392. \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}
  393. Also known as Glorot initialization.
  394. Args:
  395. tensor: an n-dimensional `torch.Tensor`
  396. gain: an optional scaling factor
  397. generator: the torch Generator to sample from (default: None)
  398. Examples:
  399. >>> w = torch.empty(3, 5)
  400. >>> nn.init.xavier_normal_(w)
  401. """
  402. fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
  403. std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
  404. return _no_grad_normal_(tensor, 0.0, std, generator)
  405. def _calculate_correct_fan(tensor: Tensor, mode: _FanMode) -> int:
  406. mode = mode.lower()
  407. valid_modes = ["fan_in", "fan_out"]
  408. if mode not in valid_modes:
  409. raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}")
  410. fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
  411. return fan_in if mode == "fan_in" else fan_out
  412. def kaiming_uniform_(
  413. tensor: Tensor,
  414. a: float = 0,
  415. mode: _FanMode = "fan_in",
  416. nonlinearity: _NonlinearityType = "leaky_relu",
  417. generator: _Optional[torch.Generator] = None,
  418. ) -> Tensor:
  419. r"""Fill the input `Tensor` with values using a Kaiming uniform distribution.
  420. The method is described in `Delving deep into rectifiers: Surpassing
  421. human-level performance on ImageNet classification` - He, K. et al. (2015).
  422. The resulting tensor will have values sampled from
  423. :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
  424. .. math::
  425. \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
  426. Also known as He initialization.
  427. Args:
  428. tensor: an n-dimensional `torch.Tensor`
  429. a: the negative slope of the rectifier used after this layer (only
  430. used with ``'leaky_relu'``)
  431. mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
  432. preserves the magnitude of the variance of the weights in the
  433. forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
  434. backwards pass.
  435. nonlinearity: the non-linear function (`nn.functional` name),
  436. recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
  437. generator: the torch Generator to sample from (default: None)
  438. Examples:
  439. >>> w = torch.empty(3, 5)
  440. >>> nn.init.kaiming_uniform_(w, mode="fan_in", nonlinearity="relu")
  441. Note:
  442. Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
  443. that the weight matrix is used in a transposed manner,
  444. (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``).
  445. This is important for correct initialization.
  446. If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``,
  447. pass in a transposed weight matrix, i.e. ``nn.init.kaiming_uniform_(w.T, ...)``.
  448. """
  449. if torch.overrides.has_torch_function_variadic(tensor):
  450. return torch.overrides.handle_torch_function(
  451. kaiming_uniform_,
  452. (tensor,),
  453. tensor=tensor,
  454. a=a,
  455. mode=mode,
  456. nonlinearity=nonlinearity,
  457. generator=generator,
  458. )
  459. if 0 in tensor.shape:
  460. warnings.warn("Initializing zero-element tensors is a no-op")
  461. return tensor
  462. fan = _calculate_correct_fan(tensor, mode)
  463. gain = calculate_gain(nonlinearity, a)
  464. std = gain / math.sqrt(fan)
  465. bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
  466. with torch.no_grad():
  467. return tensor.uniform_(-bound, bound, generator=generator)
  468. def kaiming_normal_(
  469. tensor: Tensor,
  470. a: float = 0,
  471. mode: _FanMode = "fan_in",
  472. nonlinearity: _NonlinearityType = "leaky_relu",
  473. generator: _Optional[torch.Generator] = None,
  474. ) -> Tensor:
  475. r"""Fill the input `Tensor` with values using a Kaiming normal distribution.
  476. The method is described in `Delving deep into rectifiers: Surpassing
  477. human-level performance on ImageNet classification` - He, K. et al. (2015).
  478. The resulting tensor will have values sampled from
  479. :math:`\mathcal{N}(0, \text{std}^2)` where
  480. .. math::
  481. \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}
  482. Also known as He initialization.
  483. Args:
  484. tensor: an n-dimensional `torch.Tensor`
  485. a: the negative slope of the rectifier used after this layer (only
  486. used with ``'leaky_relu'``)
  487. mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
  488. preserves the magnitude of the variance of the weights in the
  489. forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
  490. backwards pass.
  491. nonlinearity: the non-linear function (`nn.functional` name),
  492. recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
  493. generator: the torch Generator to sample from (default: None)
  494. Examples:
  495. >>> w = torch.empty(3, 5)
  496. >>> nn.init.kaiming_normal_(w, mode="fan_out", nonlinearity="relu")
  497. Note:
  498. Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
  499. that the weight matrix is used in a transposed manner,
  500. (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``).
  501. This is important for correct initialization.
  502. If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``,
  503. pass in a transposed weight matrix, i.e. ``nn.init.kaiming_normal_(w.T, ...)``.
  504. """
  505. if 0 in tensor.shape:
  506. warnings.warn("Initializing zero-element tensors is a no-op")
  507. return tensor
  508. fan = _calculate_correct_fan(tensor, mode)
  509. gain = calculate_gain(nonlinearity, a)
  510. std = gain / math.sqrt(fan)
  511. with torch.no_grad():
  512. return tensor.normal_(0, std, generator=generator)
  513. def orthogonal_(
  514. tensor: Tensor,
  515. gain: float = 1,
  516. generator: _Optional[torch.Generator] = None,
  517. ) -> Tensor:
  518. r"""Fill the input `Tensor` with a (semi) orthogonal matrix.
  519. Described in `Exact solutions to the nonlinear dynamics of learning in deep
  520. linear neural networks` - Saxe, A. et al. (2013). The input tensor must have
  521. at least 2 dimensions, and for tensors with more than 2 dimensions the
  522. trailing dimensions are flattened.
  523. Args:
  524. tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
  525. gain: optional scaling factor
  526. generator: the torch Generator to sample from (default: None)
  527. Examples:
  528. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
  529. >>> w = torch.empty(3, 5)
  530. >>> nn.init.orthogonal_(w)
  531. """
  532. if tensor.ndimension() < 2:
  533. raise ValueError("Only tensors with 2 or more dimensions are supported")
  534. if tensor.numel() == 0:
  535. # no-op
  536. return tensor
  537. rows = tensor.size(0)
  538. cols = tensor.numel() // rows
  539. flattened = tensor.new_empty((rows, cols)).normal_(0, 1, generator=generator)
  540. if rows < cols:
  541. flattened.t_()
  542. # Compute the qr factorization
  543. q, r = torch.linalg.qr(flattened)
  544. # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
  545. d = torch.diag(r, 0)
  546. ph = d.sign()
  547. q *= ph
  548. if rows < cols:
  549. q.t_()
  550. with torch.no_grad():
  551. tensor.view_as(q).copy_(q)
  552. tensor.mul_(gain)
  553. return tensor
  554. def sparse_(
  555. tensor: Tensor,
  556. sparsity: float,
  557. std: float = 0.01,
  558. generator: _Optional[torch.Generator] = None,
  559. ) -> Tensor:
  560. r"""Fill the 2D input `Tensor` as a sparse matrix.
  561. The non-zero elements will be drawn from the normal distribution
  562. :math:`\mathcal{N}(0, 0.01)`, as described in `Deep learning via
  563. Hessian-free optimization` - Martens, J. (2010).
  564. Args:
  565. tensor: an n-dimensional `torch.Tensor`
  566. sparsity: The fraction of elements in each column to be set to zero
  567. std: the standard deviation of the normal distribution used to generate
  568. the non-zero values
  569. generator: the torch Generator to sample from (default: None)
  570. Examples:
  571. >>> w = torch.empty(3, 5)
  572. >>> nn.init.sparse_(w, sparsity=0.1)
  573. """
  574. if tensor.ndimension() != 2:
  575. raise ValueError("Only tensors with 2 dimensions are supported")
  576. rows, cols = tensor.shape
  577. num_zeros = int(math.ceil(sparsity * rows))
  578. with torch.no_grad():
  579. tensor.normal_(0, std, generator=generator)
  580. for col_idx in range(cols):
  581. row_indices = torch.randperm(rows)
  582. zero_indices = row_indices[:num_zeros]
  583. tensor[zero_indices, col_idx] = 0
  584. return tensor
  585. # for backward compatibility
  586. def _make_deprecate(meth: Callable[_P, _R]) -> Callable[_P, _R]:
  587. new_name = meth.__name__
  588. old_name = new_name[:-1]
  589. def deprecated_init(*args: _P.args, **kwargs: _P.kwargs) -> _R:
  590. warnings.warn(
  591. f"`nn.init.{old_name}` is now deprecated in favor of `nn.init.{new_name}`.",
  592. FutureWarning,
  593. stacklevel=2,
  594. )
  595. return meth(*args, **kwargs)
  596. deprecated_init.__doc__ = rf"""
  597. {old_name}(...)
  598. .. warning::
  599. This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`.
  600. See :func:`~torch.nn.init.{new_name}` for details."""
  601. deprecated_init.__name__ = old_name
  602. return deprecated_init
# Un-suffixed aliases of the in-place initializers. Each one is the wrapper
# returned by ``_make_deprecate``: it warns with a FutureWarning and then
# forwards to the underscore-suffixed function. They are kept as explicit
# assignments so the names listed in ``__all__`` exist as plain module
# attributes.
uniform = _make_deprecate(uniform_)
normal = _make_deprecate(normal_)
constant = _make_deprecate(constant_)
eye = _make_deprecate(eye_)
dirac = _make_deprecate(dirac_)
xavier_uniform = _make_deprecate(xavier_uniform_)
xavier_normal = _make_deprecate(xavier_normal_)
kaiming_uniform = _make_deprecate(kaiming_uniform_)
kaiming_normal = _make_deprecate(kaiming_normal_)
orthogonal = _make_deprecate(orthogonal_)
sparse = _make_deprecate(sparse_)