# container.py (36 KB)
  1. # mypy: allow-untyped-defs
  2. from __future__ import annotations
  3. import operator
  4. from collections import abc as container_abcs, OrderedDict
  5. from itertools import chain, islice
  6. from typing import Any, Optional, overload, TYPE_CHECKING, TypeVar, Union
  7. from typing_extensions import deprecated, Self
  8. import torch
  9. from torch._jit_internal import _copy_to_script_wrapper
  10. from torch.nn.parameter import Parameter
  11. from .module import Module
  12. if TYPE_CHECKING:
  13. from collections.abc import Iterable, Iterator, Mapping
  14. __all__ = [
  15. "Container",
  16. "Sequential",
  17. "ModuleList",
  18. "ModuleDict",
  19. "ParameterList",
  20. "ParameterDict",
  21. ]
  22. T = TypeVar("T", bound=Module)
  23. _V = TypeVar("_V")
  24. # Copied from torch.nn.modules.module, required for a custom __repr__ for ModuleList
  25. def _addindent(s_, numSpaces):
  26. s = s_.split("\n")
  27. # don't do anything for single-line stuff
  28. if len(s) == 1:
  29. return s_
  30. first = s.pop(0)
  31. s = [(numSpaces * " ") + line for line in s]
  32. s = "\n".join(s)
  33. s = first + "\n" + s
  34. return s
  35. @deprecated(
  36. "`nn.Container` is deprecated. "
  37. "All of it's functionality is now implemented in `nn.Module`. Subclass that instead.",
  38. category=FutureWarning,
  39. )
  40. class Container(Module):
  41. def __init__(self, **kwargs: Any) -> None:
  42. super().__init__()
  43. for key, value in kwargs.items():
  44. self.add_module(key, value)
  45. class Sequential(Module):
  46. r"""A sequential container.
  47. Modules will be added to it in the order they are passed in the
  48. constructor. Alternatively, an ``OrderedDict`` of modules can be
  49. passed in. The ``forward()`` method of ``Sequential`` accepts any
  50. input and forwards it to the first module it contains. It then
  51. "chains" outputs to inputs sequentially for each subsequent module,
  52. finally returning the output of the last module.
  53. The value a ``Sequential`` provides over manually calling a sequence
  54. of modules is that it allows treating the whole container as a
  55. single module, such that performing a transformation on the
  56. ``Sequential`` applies to each of the modules it stores (which are
  57. each a registered submodule of the ``Sequential``).
  58. What's the difference between a ``Sequential`` and a
  59. :class:`torch.nn.ModuleList`? A ``ModuleList`` is exactly what it
  60. sounds like--a list for storing ``Module`` s! On the other hand,
  61. the layers in a ``Sequential`` are connected in a cascading way.
  62. Example::
  63. # Using Sequential to create a small model. When `model` is run,
  64. # input will first be passed to `Conv2d(1,20,5)`. The output of
  65. # `Conv2d(1,20,5)` will be used as the input to the first
  66. # `ReLU`; the output of the first `ReLU` will become the input
  67. # for `Conv2d(20,64,5)`. Finally, the output of
  68. # `Conv2d(20,64,5)` will be used as input to the second `ReLU`
  69. model = nn.Sequential(
  70. nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU()
  71. )
  72. # Using Sequential with OrderedDict. This is functionally the
  73. # same as the above code
  74. model = nn.Sequential(
  75. OrderedDict(
  76. [
  77. ("conv1", nn.Conv2d(1, 20, 5)),
  78. ("relu1", nn.ReLU()),
  79. ("conv2", nn.Conv2d(20, 64, 5)),
  80. ("relu2", nn.ReLU()),
  81. ]
  82. )
  83. )
  84. """
  85. _modules: dict[str, Module] # type: ignore[assignment]
  86. @overload
  87. def __init__(self, *args: Module) -> None: ...
  88. @overload
  89. def __init__(self, arg: OrderedDict[str, Module]) -> None: ...
  90. def __init__(self, *args):
  91. super().__init__()
  92. if len(args) == 1 and isinstance(args[0], OrderedDict):
  93. for key, module in args[0].items():
  94. self.add_module(key, module)
  95. else:
  96. for idx, module in enumerate(args):
  97. self.add_module(str(idx), module)
  98. def _get_item_by_idx(self, iterator: Iterable[_V], idx: int) -> _V:
  99. """Get the idx-th item of the iterator."""
  100. size = len(self)
  101. idx = operator.index(idx)
  102. if not -size <= idx < size:
  103. raise IndexError(f"index {idx} is out of range")
  104. idx %= size
  105. return next(islice(iterator, idx, None))
  106. @_copy_to_script_wrapper
  107. def __getitem__(self, idx: Union[slice, int]) -> Union[Sequential, Module]:
  108. if isinstance(idx, slice):
  109. return self.__class__(OrderedDict(list(self._modules.items())[idx]))
  110. else:
  111. return self._get_item_by_idx(self._modules.values(), idx)
  112. def __setitem__(self, idx: int, module: Module) -> None:
  113. key: str = self._get_item_by_idx(self._modules.keys(), idx)
  114. return setattr(self, key, module)
  115. def __delitem__(self, idx: Union[slice, int]) -> None:
  116. if isinstance(idx, slice):
  117. for key in list(self._modules.keys())[idx]:
  118. delattr(self, key)
  119. else:
  120. key = self._get_item_by_idx(self._modules.keys(), idx)
  121. delattr(self, key)
  122. # To preserve numbering
  123. str_indices = [str(i) for i in range(len(self._modules))]
  124. self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
  125. @_copy_to_script_wrapper
  126. def __len__(self) -> int:
  127. return len(self._modules)
  128. def __add__(self, other) -> Sequential:
  129. if isinstance(other, Sequential):
  130. ret = Sequential()
  131. for layer in self:
  132. ret.append(layer)
  133. for layer in other:
  134. ret.append(layer)
  135. return ret
  136. else:
  137. raise ValueError(
  138. "add operator supports only objects "
  139. f"of Sequential class, but {str(type(other))} is given."
  140. )
  141. def pop(self, key: Union[int, slice]) -> Module:
  142. """
  143. Pop ``key`` from self.
  144. """
  145. v = self[key]
  146. del self[key]
  147. return v
  148. def __iadd__(self, other) -> Self:
  149. if isinstance(other, Sequential):
  150. offset = len(self)
  151. for i, module in enumerate(other):
  152. self.add_module(str(i + offset), module)
  153. return self
  154. else:
  155. raise ValueError(
  156. "add operator supports only objects "
  157. f"of Sequential class, but {str(type(other))} is given."
  158. )
  159. def __mul__(self, other: int) -> Sequential:
  160. if not isinstance(other, int):
  161. raise TypeError(
  162. f"unsupported operand type(s) for *: {type(self)} and {type(other)}"
  163. )
  164. elif other <= 0:
  165. raise ValueError(
  166. f"Non-positive multiplication factor {other} for {type(self)}"
  167. )
  168. else:
  169. combined = Sequential()
  170. offset = 0
  171. for _ in range(other):
  172. for module in self:
  173. combined.add_module(str(offset), module)
  174. offset += 1
  175. return combined
  176. def __rmul__(self, other: int) -> Sequential:
  177. return self.__mul__(other)
  178. def __imul__(self, other: int) -> Self:
  179. if not isinstance(other, int):
  180. raise TypeError(
  181. f"unsupported operand type(s) for *: {type(self)} and {type(other)}"
  182. )
  183. elif other <= 0:
  184. raise ValueError(
  185. f"Non-positive multiplication factor {other} for {type(self)}"
  186. )
  187. else:
  188. len_original = len(self)
  189. offset = len(self)
  190. for _ in range(other - 1):
  191. for i in range(len_original):
  192. self.add_module(str(i + offset), self._modules[str(i)])
  193. offset += len_original
  194. return self
  195. @_copy_to_script_wrapper
  196. def __dir__(self) -> list[str]:
  197. keys = super().__dir__()
  198. keys = [key for key in keys if not key.isdigit()]
  199. return keys
  200. @_copy_to_script_wrapper
  201. def __iter__(self) -> Iterator[Module]:
  202. return iter(self._modules.values())
  203. # NB: We can't really type check this function as the type of input
  204. # may change dynamically (as is tested in
  205. # TestScript.test_sequential_intermediary_types). Cannot annotate
  206. # with Any as TorchScript expects a more precise type
  207. def forward(self, input):
  208. """
  209. Runs the forward pass.
  210. """
  211. for module in self:
  212. input = module(input)
  213. return input
  214. def append(self, module: Module) -> Self:
  215. r"""Append a given module to the end.
  216. Args:
  217. module (nn.Module): module to append
  218. Example::
  219. >>> import torch.nn as nn
  220. >>> n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3))
  221. >>> n.append(nn.Linear(3, 4))
  222. Sequential(
  223. (0): Linear(in_features=1, out_features=2, bias=True)
  224. (1): Linear(in_features=2, out_features=3, bias=True)
  225. (2): Linear(in_features=3, out_features=4, bias=True)
  226. )
  227. """
  228. self.add_module(str(len(self)), module)
  229. return self
  230. def insert(self, index: int, module: Module) -> Self:
  231. """
  232. Inserts a module into the Sequential container at the specified index.
  233. Args:
  234. index (int): The index to insert the module.
  235. module (Module): The module to be inserted.
  236. Example::
  237. >>> import torch.nn as nn
  238. >>> n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3))
  239. >>> n.insert(0, nn.Linear(3, 4))
  240. Sequential(
  241. (0): Linear(in_features=3, out_features=4, bias=True)
  242. (1): Linear(in_features=1, out_features=2, bias=True)
  243. (2): Linear(in_features=2, out_features=3, bias=True)
  244. )
  245. """
  246. if not isinstance(module, Module):
  247. raise AssertionError(f"module should be of type: {Module}")
  248. n = len(self._modules)
  249. if not (-n <= index <= n):
  250. raise IndexError(f"Index out of range: {index}")
  251. if index < 0:
  252. index += n
  253. for i in range(n, index, -1):
  254. self._modules[str(i)] = self._modules[str(i - 1)]
  255. self._modules[str(index)] = module
  256. return self
  257. def extend(self, sequential: Iterable[Module]) -> Self:
  258. """
  259. Extends the current Sequential container with layers from another Sequential container.
  260. Args:
  261. sequential (Sequential): A Sequential container whose layers will be added to the current container.
  262. Example::
  263. >>> import torch.nn as nn
  264. >>> n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3))
  265. >>> other = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 5))
  266. >>> n.extend(other) # or `n + other`
  267. Sequential(
  268. (0): Linear(in_features=1, out_features=2, bias=True)
  269. (1): Linear(in_features=2, out_features=3, bias=True)
  270. (2): Linear(in_features=3, out_features=4, bias=True)
  271. (3): Linear(in_features=4, out_features=5, bias=True)
  272. )
  273. """
  274. for layer in sequential:
  275. self.append(layer)
  276. return self
  277. class ModuleList(Module):
  278. r"""Holds submodules in a list.
  279. :class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but
  280. modules it contains are properly registered, and will be visible by all
  281. :class:`~torch.nn.Module` methods.
  282. Args:
  283. modules (iterable, optional): an iterable of modules to add
  284. Example::
  285. class MyModule(nn.Module):
  286. def __init__(self) -> None:
  287. super().__init__()
  288. self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
  289. def forward(self, x):
  290. # ModuleList can act as an iterable, or be indexed using ints
  291. for i, l in enumerate(self.linears):
  292. x = self.linears[i // 2](x) + l(x)
  293. return x
  294. """
  295. _modules: dict[str, Module] # type: ignore[assignment]
  296. def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
  297. super().__init__()
  298. if modules is not None:
  299. self += modules
  300. def _get_abs_string_index(self, idx):
  301. """Get the absolute index for the list of modules."""
  302. idx = operator.index(idx)
  303. if not (-len(self) <= idx < len(self)):
  304. raise IndexError(f"index {idx} is out of range")
  305. if idx < 0:
  306. idx += len(self)
  307. return str(idx)
  308. @overload
  309. def __getitem__(self, idx: slice) -> ModuleList: ...
  310. @overload
  311. def __getitem__(self, idx: int) -> Module: ...
  312. @_copy_to_script_wrapper
  313. def __getitem__(self, idx: Union[int, slice]) -> Union[Module, ModuleList]:
  314. if isinstance(idx, slice):
  315. return self.__class__(list(self._modules.values())[idx])
  316. else:
  317. return self._modules[self._get_abs_string_index(idx)]
  318. def __setitem__(self, idx: int, module: Module) -> None:
  319. idx = self._get_abs_string_index(idx)
  320. return setattr(self, str(idx), module)
  321. def __delitem__(self, idx: Union[int, slice]) -> None:
  322. if isinstance(idx, slice):
  323. for k in range(len(self._modules))[idx]:
  324. delattr(self, str(k))
  325. else:
  326. delattr(self, self._get_abs_string_index(idx))
  327. # To preserve numbering, self._modules is being reconstructed with modules after deletion
  328. str_indices = [str(i) for i in range(len(self._modules))]
  329. self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
  330. @_copy_to_script_wrapper
  331. def __len__(self) -> int:
  332. return len(self._modules)
  333. @_copy_to_script_wrapper
  334. def __iter__(self) -> Iterator[Module]:
  335. return iter(self._modules.values())
  336. def __iadd__(self, modules: Iterable[Module]) -> Self:
  337. return self.extend(modules)
  338. def __add__(self, other: Iterable[Module]) -> ModuleList:
  339. combined = ModuleList()
  340. for i, module in enumerate(chain(self, other)):
  341. combined.add_module(str(i), module)
  342. return combined
  343. def __repr__(self) -> str:
  344. """Return a custom repr for ModuleList that compresses repeated module representations."""
  345. list_of_reprs = [repr(item) for item in self]
  346. if len(list_of_reprs) == 0:
  347. return self._get_name() + "()"
  348. start_end_indices = [[0, 0]]
  349. repeated_blocks = [list_of_reprs[0]]
  350. for i, r in enumerate(list_of_reprs[1:], 1):
  351. if r == repeated_blocks[-1]:
  352. start_end_indices[-1][1] += 1
  353. continue
  354. start_end_indices.append([i, i])
  355. repeated_blocks.append(r)
  356. lines = []
  357. main_str = self._get_name() + "("
  358. for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
  359. local_repr = f"({start_id}): {b}" # default repr
  360. if start_id != end_id:
  361. n = end_id - start_id + 1
  362. local_repr = f"({start_id}-{end_id}): {n} x {b}"
  363. local_repr = _addindent(local_repr, 2)
  364. lines.append(local_repr)
  365. main_str += "\n " + "\n ".join(lines) + "\n"
  366. main_str += ")"
  367. return main_str
  368. @_copy_to_script_wrapper
  369. def __dir__(self) -> list[str]:
  370. keys = super().__dir__()
  371. keys = [key for key in keys if not key.isdigit()]
  372. return keys
  373. def insert(self, index: int, module: Module) -> None:
  374. r"""Insert a given module before a given index in the list.
  375. Args:
  376. index (int): index to insert.
  377. module (nn.Module): module to insert
  378. """
  379. for i in range(len(self._modules), index, -1):
  380. self._modules[str(i)] = self._modules[str(i - 1)]
  381. self._modules[str(index)] = module
  382. def append(self, module: Module) -> Self:
  383. r"""Append a given module to the end of the list.
  384. Args:
  385. module (nn.Module): module to append
  386. """
  387. self.add_module(str(len(self)), module)
  388. return self
  389. def pop(self, key: Union[int, slice]) -> Module:
  390. v = self[key]
  391. del self[key]
  392. return v
  393. def extend(self, modules: Iterable[Module]) -> Self:
  394. r"""Append modules from a Python iterable to the end of the list.
  395. Args:
  396. modules (iterable): iterable of modules to append
  397. """
  398. if not isinstance(modules, container_abcs.Iterable):
  399. raise TypeError(
  400. "ModuleList.extend should be called with an "
  401. "iterable, but got " + type(modules).__name__
  402. )
  403. offset = len(self)
  404. for i, module in enumerate(modules):
  405. self.add_module(str(offset + i), module)
  406. return self
  407. # remove forward altogether to fallback on Module's _forward_unimplemented
  408. class ModuleDict(Module):
  409. r"""Holds submodules in a dictionary.
  410. :class:`~torch.nn.ModuleDict` can be indexed like a regular Python dictionary,
  411. but modules it contains are properly registered, and will be visible by all
  412. :class:`~torch.nn.Module` methods.
  413. :class:`~torch.nn.ModuleDict` is an **ordered** dictionary that respects
  414. * the order of insertion, and
  415. * in :meth:`~torch.nn.ModuleDict.update`, the order of the merged
  416. ``OrderedDict``, ``dict`` (started from Python 3.6) or another
  417. :class:`~torch.nn.ModuleDict` (the argument to
  418. :meth:`~torch.nn.ModuleDict.update`).
  419. Note that :meth:`~torch.nn.ModuleDict.update` with other unordered mapping
  420. types (e.g., Python's plain ``dict`` before Python version 3.6) does not
  421. preserve the order of the merged mapping.
  422. Args:
  423. modules (iterable, optional): a mapping (dictionary) of (string: module)
  424. or an iterable of key-value pairs of type (string, module)
  425. Example::
  426. class MyModule(nn.Module):
  427. def __init__(self) -> None:
  428. super().__init__()
  429. self.choices = nn.ModuleDict(
  430. {"conv": nn.Conv2d(10, 10, 3), "pool": nn.MaxPool2d(3)}
  431. )
  432. self.activations = nn.ModuleDict(
  433. [["lrelu", nn.LeakyReLU()], ["prelu", nn.PReLU()]]
  434. )
  435. def forward(self, x, choice, act):
  436. x = self.choices[choice](x)
  437. x = self.activations[act](x)
  438. return x
  439. """
  440. _modules: dict[str, Module] # type: ignore[assignment]
  441. def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
  442. super().__init__()
  443. if modules is not None:
  444. self.update(modules)
  445. @_copy_to_script_wrapper
  446. def __getitem__(self, key: str) -> Module:
  447. return self._modules[key]
  448. def __setitem__(self, key: str, module: Module) -> None:
  449. self.add_module(key, module)
  450. def __delitem__(self, key: str) -> None:
  451. del self._modules[key]
  452. @_copy_to_script_wrapper
  453. def __len__(self) -> int:
  454. return len(self._modules)
  455. @_copy_to_script_wrapper
  456. def __iter__(self) -> Iterator[str]:
  457. return iter(self._modules)
  458. @_copy_to_script_wrapper
  459. def __contains__(self, key: str) -> bool:
  460. return key in self._modules
  461. def clear(self) -> None:
  462. """Remove all items from the ModuleDict."""
  463. self._modules.clear()
  464. def pop(self, key: str) -> Module:
  465. r"""Remove key from the ModuleDict and return its module.
  466. Args:
  467. key (str): key to pop from the ModuleDict
  468. """
  469. v = self[key]
  470. del self[key]
  471. return v
  472. @_copy_to_script_wrapper
  473. def keys(self) -> container_abcs.KeysView[str]:
  474. r"""Return an iterable of the ModuleDict keys."""
  475. return self._modules.keys()
  476. @_copy_to_script_wrapper
  477. def items(self) -> container_abcs.ItemsView[str, Module]:
  478. r"""Return an iterable of the ModuleDict key/value pairs."""
  479. return self._modules.items()
  480. @_copy_to_script_wrapper
  481. def values(self) -> container_abcs.ValuesView[Module]:
  482. r"""Return an iterable of the ModuleDict values."""
  483. return self._modules.values()
  484. def update(self, modules: Mapping[str, Module]) -> None:
  485. r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys.
  486. .. note::
  487. If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or
  488. an iterable of key-value pairs, the order of new elements in it is preserved.
  489. Args:
  490. modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`,
  491. or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`)
  492. """
  493. if not isinstance(modules, container_abcs.Iterable):
  494. raise TypeError(
  495. "ModuleDict.update should be called with an "
  496. "iterable of key/value pairs, but got " + type(modules).__name__
  497. )
  498. if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
  499. for key, module in modules.items():
  500. self[key] = module
  501. else:
  502. # modules here can be a list with two items
  503. for j, m in enumerate(modules):
  504. if not isinstance(m, container_abcs.Iterable):
  505. raise TypeError(
  506. "ModuleDict update sequence element "
  507. "#" + str(j) + " should be Iterable; is" + type(m).__name__
  508. )
  509. if not len(m) == 2:
  510. raise ValueError(
  511. "ModuleDict update sequence element "
  512. "#" + str(j) + " has length " + str(len(m)) + "; 2 is required"
  513. )
  514. # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
  515. # that's too cumbersome to type correctly with overloads, so we add an ignore here
  516. self[m[0]] = m[1] # type: ignore[assignment]
  517. # remove forward altogether to fallback on Module's _forward_unimplemented
  518. class ParameterList(Module):
  519. r"""Holds parameters in a list.
  520. :class:`~torch.nn.ParameterList` can be used like a regular Python
  521. list, but Tensors that are :class:`~torch.nn.Parameter` are properly registered,
  522. and will be visible by all :class:`~torch.nn.Module` methods.
  523. Note that the constructor, assigning an element of the list, the
  524. :meth:`~torch.nn.ParameterList.append` method and the :meth:`~torch.nn.ParameterList.extend`
  525. method will convert any :class:`~torch.Tensor` into :class:`~torch.nn.Parameter`.
  526. Args:
  527. parameters (iterable, optional): an iterable of elements to add to the list.
  528. Example::
  529. class MyModule(nn.Module):
  530. def __init__(self) -> None:
  531. super().__init__()
  532. self.params = nn.ParameterList(
  533. [nn.Parameter(torch.randn(10, 10)) for i in range(10)]
  534. )
  535. def forward(self, x):
  536. # ParameterList can act as an iterable, or be indexed using ints
  537. for i, p in enumerate(self.params):
  538. x = self.params[i // 2].mm(x) + p.mm(x)
  539. return x
  540. """
  541. def __init__(self, values: Optional[Iterable[Any]] = None) -> None:
  542. super().__init__()
  543. self._size = 0
  544. if values is not None:
  545. self += values
  546. def _get_abs_string_index(self, idx):
  547. """Get the absolute index for the list of modules."""
  548. idx = operator.index(idx)
  549. if not (-len(self) <= idx < len(self)):
  550. raise IndexError(f"index {idx} is out of range")
  551. if idx < 0:
  552. idx += len(self)
  553. return str(idx)
  554. @overload
  555. def __getitem__(self, idx: int) -> Any: ...
  556. @overload
  557. def __getitem__(self: T, idx: slice) -> T: ...
  558. def __getitem__(self, idx):
  559. if isinstance(idx, slice):
  560. start, stop, step = idx.indices(len(self))
  561. out = self.__class__()
  562. for i in range(start, stop, step):
  563. out.append(self[i])
  564. return out
  565. else:
  566. idx = self._get_abs_string_index(idx)
  567. return getattr(self, str(idx))
  568. def __setitem__(self, idx: int, param: Any) -> None:
  569. # Note that all other function that add an entry to the list part of
  570. # the ParameterList end up here. So this is the only place where we need
  571. # to wrap things into Parameter if needed.
  572. # Objects added via setattr() are not in the list part and thus won't
  573. # call into this function.
  574. idx = self._get_abs_string_index(idx)
  575. if isinstance(param, torch.Tensor) and not isinstance(param, Parameter):
  576. param = Parameter(param)
  577. return setattr(self, str(idx), param)
  578. def __len__(self) -> int:
  579. return self._size
  580. def __iter__(self) -> Iterator[Any]:
  581. return iter(self[i] for i in range(len(self)))
  582. def __iadd__(self, parameters: Iterable[Any]) -> Self:
  583. return self.extend(parameters)
  584. def __dir__(self) -> list[str]:
  585. keys = super().__dir__()
  586. keys = [key for key in keys if not key.isdigit()]
  587. return keys
  588. def append(self, value: Any) -> Self:
  589. """Append a given value at the end of the list.
  590. Args:
  591. value (Any): value to append
  592. """
  593. new_idx = len(self)
  594. self._size += 1
  595. self[new_idx] = value
  596. return self
  597. def extend(self, values: Iterable[Any]) -> Self:
  598. """Append values from a Python iterable to the end of the list.
  599. Args:
  600. values (iterable): iterable of values to append
  601. """
  602. # Tensor is an iterable but we never want to unpack it here
  603. if not isinstance(values, container_abcs.Iterable) or isinstance(
  604. values, torch.Tensor
  605. ):
  606. raise TypeError(
  607. "ParameterList.extend should be called with an "
  608. "iterable, but got " + type(values).__name__
  609. )
  610. for value in values:
  611. self.append(value)
  612. return self
  613. def extra_repr(self) -> str:
  614. """
  615. Return the extra representation of the module.
  616. """
  617. child_lines = []
  618. for k, p in enumerate(self):
  619. if isinstance(p, torch.Tensor):
  620. size_str = "x".join(str(size) for size in p.size())
  621. if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
  622. device_str = f" ({p.device})"
  623. else:
  624. device_str = ""
  625. parastr = "{} containing: [{} of size {}{}]".format(
  626. "Parameter" if isinstance(p, Parameter) else "Tensor",
  627. p.dtype,
  628. size_str,
  629. device_str,
  630. )
  631. child_lines.append(" (" + str(k) + "): " + parastr)
  632. else:
  633. child_lines.append(
  634. " (" + str(k) + "): Object of type: " + type(p).__name__
  635. )
  636. tmpstr = "\n".join(child_lines)
  637. return tmpstr
  638. def __call__(self, *args, **kwargs):
  639. raise RuntimeError("ParameterList should not be called.")
  640. class ParameterDict(Module):
  641. r"""Holds parameters in a dictionary.
  642. ParameterDict can be indexed like a regular Python dictionary, but Parameters it
  643. contains are properly registered, and will be visible by all Module methods.
  644. Other objects are treated as would be done by a regular Python dictionary
  645. :class:`~torch.nn.ParameterDict` is an **ordered** dictionary.
  646. :meth:`~torch.nn.ParameterDict.update` with other unordered mapping
  647. types (e.g., Python's plain ``dict``) does not preserve the order of the
  648. merged mapping. On the other hand, ``OrderedDict`` or another :class:`~torch.nn.ParameterDict`
  649. will preserve their ordering.
  650. Note that the constructor, assigning an element of the dictionary and the
  651. :meth:`~torch.nn.ParameterDict.update` method will convert any :class:`~torch.Tensor` into
  652. :class:`~torch.nn.Parameter`.
  653. Args:
  654. values (iterable, optional): a mapping (dictionary) of
  655. (string : Any) or an iterable of key-value pairs
  656. of type (string, Any)
  657. Example::
  658. class MyModule(nn.Module):
  659. def __init__(self) -> None:
  660. super().__init__()
  661. self.params = nn.ParameterDict(
  662. {
  663. "left": nn.Parameter(torch.randn(5, 10)),
  664. "right": nn.Parameter(torch.randn(5, 10)),
  665. }
  666. )
  667. def forward(self, x, choice):
  668. x = self.params[choice].mm(x)
  669. return x
  670. """
  671. def __init__(self, parameters: Any = None) -> None:
  672. super().__init__()
  673. self._keys: dict[str, None] = {}
  674. if parameters is not None:
  675. self.update(parameters)
  676. def _key_to_attr(self, key: str) -> str:
  677. if not isinstance(key, str):
  678. raise TypeError(
  679. "Index given to ParameterDict cannot be used as a key as it is "
  680. f"not a string (type is '{type(key).__name__}'). Open an issue on "
  681. "github if you need non-string keys."
  682. )
  683. else:
  684. # Use the key as-is so that `.named_parameters()` returns the right thing
  685. return key
  686. def __getitem__(self, key: str) -> Any:
  687. attr = self._key_to_attr(key)
  688. return getattr(self, attr)
  689. def __setitem__(self, key: str, value: Any) -> None:
  690. # Note that all other function that add an entry to the dictionary part of
  691. # the ParameterDict end up here. So this is the only place where we need
  692. # to wrap things into Parameter if needed.
  693. # Objects added via setattr() are not in the dictionary part and thus won't
  694. # call into this function.
  695. self._keys[key] = None
  696. attr = self._key_to_attr(key)
  697. if isinstance(value, torch.Tensor) and not isinstance(value, Parameter):
  698. value = Parameter(value)
  699. setattr(self, attr, value)
  700. def __delitem__(self, key: str) -> None:
  701. del self._keys[key]
  702. attr = self._key_to_attr(key)
  703. delattr(self, attr)
  def __len__(self) -> int:
      """Return how many entries the ParameterDict holds."""
      entry_count = len(self._keys)
      return entry_count
  def __iter__(self) -> Iterator[str]:
      """Iterate over the keys in insertion order."""
      return iter(self._keys.keys())
  def __reversed__(self) -> Iterator[str]:
      """Iterate over the keys in reverse insertion order."""
      return reversed(self._keys.keys())
  def copy(self) -> ParameterDict:
      """Return a shallow copy of this :class:`~torch.nn.ParameterDict` instance.

      Only the container is new; the stored objects are shared with ``self``.
      """
      # Build an OrderedDict on purpose: the constructor sorts the keys of a
      # plain dict but keeps the order of an OrderedDict, and the copy must
      # preserve this dict's insertion order.
      ordered = OrderedDict()
      for name in self._keys:
          ordered[name] = self[name]
      return ParameterDict(ordered)
  def __contains__(self, key: str) -> bool:
      """Return ``True`` when ``key`` is an entry of the ParameterDict."""
      return key in self._keys.keys()
  def setdefault(self, key: str, default: Optional[Any] = None) -> Any:
      """Return the value for ``key``, inserting ``default`` first if it is missing.

      Args:
          key (str): key to look up or insert
          default (Any): value stored under ``key`` when it is absent;
              defaults to ``None``

      Returns:
          The value stored under ``key`` after the call. It is read back from
          the dict rather than returning ``default`` directly, so a plain
          tensor ``default`` comes back in its stored (Parameter-wrapped) form.
      """
      if key in self:
          return self[key]
      self[key] = default
      return self[key]
  def clear(self) -> None:
      """Delete every entry from the ParameterDict."""
      # Walk a snapshot of the keys, because ``del self[...]`` mutates ``_keys``.
      for name in list(self._keys):
          del self[name]
  def pop(self, key: str) -> Any:
      r"""Remove ``key`` from the ParameterDict and hand back the value it held.

      Args:
          key (str): key to remove
      """
      popped = self[key]
      del self[key]
      return popped
  def popitem(self) -> tuple[str, Any]:
      """Remove and return the most recently inserted ``(key, parameter)`` pair.

      Raises:
          KeyError: if the ParameterDict is empty.
      """
      # ``dict.popitem`` picks the last key and raises KeyError when empty.
      last_key = self._keys.popitem()[0]
      # Put the key back so that ``del self[...]``, which removes it from
      # ``_keys`` itself, can process the entry normally.
      self._keys[last_key] = None
      last_value = self[last_key]
      del self[last_key]
      return last_key, last_value
  def get(self, key: str, default: Optional[Any] = None) -> Any:
      r"""Return the value for ``key`` if it is present, otherwise ``default``.

      Args:
          key (str): key to look up
          default (Any, optional): value returned when ``key`` is absent;
              defaults to ``None``
      """
      if key in self:
          return self[key]
      return default
  def fromkeys(
      self, keys: Iterable[str], default: Optional[Any] = None
  ) -> ParameterDict:
      r"""Build a new ParameterDict whose keys are ``keys``, each mapped to ``default``.

      ``self`` is left untouched.

      Args:
          keys (iterable of str): keys of the new ParameterDict, in order
          default (Any, optional): value stored under every key
      """
      pairs = ((name, default) for name in keys)
      return ParameterDict(pairs)
  def keys(self) -> container_abcs.KeysView[str]:
      r"""Return a live view of the ParameterDict keys in insertion order."""
      key_view = self._keys.keys()
      return key_view
  def items(self) -> Iterable[tuple[str, Any]]:
      r"""Return a lazy iterable of ``(key, value)`` pairs in insertion order."""
      # The key iterator is created now, as the original generator expression did.
      names = iter(self._keys)
      return ((name, self[name]) for name in names)
  def values(self) -> Iterable[Any]:
      r"""Return a lazy iterable over the stored values, in insertion order."""
      # The key iterator is created now, as the original generator expression did.
      names = iter(self._keys)
      return (self[name] for name in names)
  def update(self, parameters: Union[Mapping[str, Any], ParameterDict]) -> None:
      r"""Update the :class:`~torch.nn.ParameterDict` with key-value pairs from ``parameters``, overwriting existing keys.

      .. note::
          If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch.nn.ParameterDict`, or
          an iterable of key-value pairs, the order of new elements in it is preserved.
          Any other mapping (e.g. a plain ``dict``) is inserted in sorted key order.

      Args:
          parameters (iterable): a mapping (dictionary) from string to
              :class:`~torch.nn.Parameter`, or an iterable of
              key-value pairs of type (string, :class:`~torch.nn.Parameter`)

      Raises:
          TypeError: if ``parameters`` is not iterable, or an element of a
              key-value sequence is not iterable.
          ValueError: if an element of a key-value sequence does not have
              exactly two items.
      """
      if not isinstance(parameters, container_abcs.Iterable):
          raise TypeError(
              "ParameterDict.update should be called with an "
              "iterable of key/value pairs, but got " + type(parameters).__name__
          )

      if isinstance(parameters, (OrderedDict, ParameterDict)):
          for key, parameter in parameters.items():
              self[key] = parameter
      elif isinstance(parameters, container_abcs.Mapping):
          # Plain mappings are inserted in sorted key order, not in their own
          # iteration order.
          for key, parameter in sorted(parameters.items()):
              self[key] = parameter
      else:
          for j, p in enumerate(parameters):
              if not isinstance(p, container_abcs.Iterable):
                  raise TypeError(
                      "ParameterDict update sequence element "
                      "#" + str(j) + " should be Iterable; is " + type(p).__name__
                  )
              if len(p) != 2:
                  raise ValueError(
                      "ParameterDict update sequence element "
                      "#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
                  )
              # parameters as length-2 list too cumbersome to type, see ModuleDict.update comment
              self[p[0]] = p[1]  # type: ignore[assignment]
  809. def extra_repr(self) -> str:
  810. child_lines = []
  811. for k, p in self.items():
  812. if isinstance(p, torch.Tensor):
  813. size_str = "x".join(str(size) for size in p.size())
  814. if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
  815. device_str = f" ({p.device})"
  816. else:
  817. device_str = ""
  818. parastr = "{} containing: [{} of size {}{}]".format(
  819. "Parameter" if isinstance(p, Parameter) else "Tensor",
  820. torch.typename(p),
  821. size_str,
  822. device_str,
  823. )
  824. child_lines.append(" (" + str(k) + "): " + parastr)
  825. else:
  826. child_lines.append(
  827. " (" + str(k) + "): Object of type: " + type(p).__name__
  828. )
  829. tmpstr = "\n".join(child_lines)
  830. return tmpstr
  def __call__(self, input):
      """Reject calls: a ParameterDict is a container, not a callable layer.

      Raises:
          RuntimeError: always.
      """
      message = "ParameterDict should not be called."
      raise RuntimeError(message)
  def __or__(self, other: ParameterDict) -> ParameterDict:
      """Return ``self | other``: a new ParameterDict where ``other``'s entries win.

      Neither operand is modified.
      """
      merged = self.copy()
      merged.update(other)
      return merged
  def __ror__(self, other: ParameterDict) -> ParameterDict:
      """Return ``other | self``: a copy of ``other`` updated with this dict's entries.

      The result is whatever ``other.copy()`` produces, so when the left operand
      is a plain ``dict`` the result is a plain ``dict`` as well.
      """
      merged = other.copy()
      merged.update(self)
      return merged
  def __ior__(self, other: ParameterDict) -> Self:
      """Apply ``self |= other`` in place and return this same ParameterDict."""
      self.update(other)
      return self