rnn.py 73 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826
  1. # mypy: allow-untyped-decorators
  2. # mypy: allow-untyped-defs
  3. import math
  4. import numbers
  5. import warnings
  6. import weakref
  7. from typing import Optional, overload
  8. from typing_extensions import deprecated
  9. import torch
  10. from torch import _VF, Tensor
  11. from torch.nn import init
  12. from torch.nn.parameter import Parameter
  13. from torch.nn.utils.rnn import PackedSequence
  14. from .module import Module
  15. __all__ = [
  16. "RNNBase",
  17. "RNN",
  18. "LSTM",
  19. "GRU",
  20. "RNNCellBase",
  21. "RNNCell",
  22. "LSTMCell",
  23. "GRUCell",
  24. ]
  25. _rnn_impls = {
  26. "RNN_TANH": _VF.rnn_tanh,
  27. "RNN_RELU": _VF.rnn_relu,
  28. }
  29. def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
  30. return tensor.index_select(dim, permutation)
  31. @deprecated(
  32. "`apply_permutation` is deprecated, please use `tensor.index_select(dim, permutation)` instead",
  33. category=FutureWarning,
  34. )
  35. def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
  36. return _apply_permutation(tensor, permutation, dim)
  37. class RNNBase(Module):
  38. r"""Base class for RNN modules (RNN, LSTM, GRU).
  39. Implements aspects of RNNs shared by the RNN, LSTM, and GRU classes, such as module initialization
  40. and utility methods for parameter storage management.
  41. .. note::
  42. The forward method is not implemented by the RNNBase class.
  43. .. note::
  44. LSTM and GRU classes override some methods implemented by RNNBase.
  45. """
  46. __constants__ = [
  47. "mode",
  48. "input_size",
  49. "hidden_size",
  50. "num_layers",
  51. "bias",
  52. "batch_first",
  53. "dropout",
  54. "bidirectional",
  55. "proj_size",
  56. ]
  57. __jit_unused_properties__ = ["all_weights"]
  58. mode: str
  59. input_size: int
  60. hidden_size: int
  61. num_layers: int
  62. bias: bool
  63. batch_first: bool
  64. dropout: float
  65. bidirectional: bool
  66. proj_size: int
  67. def __init__(
  68. self,
  69. mode: str,
  70. input_size: int,
  71. hidden_size: int,
  72. num_layers: int = 1,
  73. bias: bool = True,
  74. batch_first: bool = False,
  75. dropout: float = 0.0,
  76. bidirectional: bool = False,
  77. proj_size: int = 0,
  78. device=None,
  79. dtype=None,
  80. ) -> None:
  81. factory_kwargs = {"device": device, "dtype": dtype}
  82. super().__init__()
  83. self.mode = mode
  84. self.input_size = input_size
  85. self.hidden_size = hidden_size
  86. self.num_layers = num_layers
  87. self.bias = bias
  88. self.batch_first = batch_first
  89. self.dropout = float(dropout)
  90. self.bidirectional = bidirectional
  91. self.proj_size = proj_size
  92. self._flat_weight_refs: list[Optional[weakref.ReferenceType[Parameter]]] = []
  93. num_directions = 2 if bidirectional else 1
  94. if (
  95. not isinstance(dropout, numbers.Number)
  96. or not 0 <= dropout <= 1
  97. or isinstance(dropout, bool)
  98. ):
  99. raise ValueError(
  100. "dropout should be a number in range [0, 1] "
  101. "representing the probability of an element being "
  102. "zeroed"
  103. )
  104. if dropout > 0 and num_layers == 1:
  105. warnings.warn(
  106. "dropout option adds dropout after all but last "
  107. "recurrent layer, so non-zero dropout expects "
  108. f"num_layers greater than 1, but got dropout={dropout} and "
  109. f"num_layers={num_layers}"
  110. )
  111. if not isinstance(hidden_size, int):
  112. raise TypeError(
  113. f"hidden_size should be of type int, got: {type(hidden_size).__name__}"
  114. )
  115. if hidden_size <= 0:
  116. raise ValueError("hidden_size must be greater than zero")
  117. if num_layers <= 0:
  118. raise ValueError("num_layers must be greater than zero")
  119. if proj_size < 0:
  120. raise ValueError(
  121. "proj_size should be a positive integer or zero to disable projections"
  122. )
  123. if proj_size >= hidden_size:
  124. raise ValueError("proj_size has to be smaller than hidden_size")
  125. if mode == "LSTM":
  126. gate_size = 4 * hidden_size
  127. elif mode == "GRU":
  128. gate_size = 3 * hidden_size
  129. elif mode == "RNN_TANH":
  130. gate_size = hidden_size
  131. elif mode == "RNN_RELU":
  132. gate_size = hidden_size
  133. else:
  134. raise ValueError("Unrecognized RNN mode: " + mode)
  135. self._flat_weights_names = []
  136. self._all_weights = []
  137. for layer in range(num_layers):
  138. for direction in range(num_directions):
  139. real_hidden_size = proj_size if proj_size > 0 else hidden_size
  140. layer_input_size = (
  141. input_size if layer == 0 else real_hidden_size * num_directions
  142. )
  143. w_ih = Parameter(
  144. torch.empty((gate_size, layer_input_size), **factory_kwargs)
  145. )
  146. w_hh = Parameter(
  147. torch.empty((gate_size, real_hidden_size), **factory_kwargs)
  148. )
  149. b_ih = Parameter(torch.empty(gate_size, **factory_kwargs))
  150. # Second bias vector included for CuDNN compatibility. Only one
  151. # bias vector is needed in standard definition.
  152. b_hh = Parameter(torch.empty(gate_size, **factory_kwargs))
  153. layer_params: tuple[Tensor, ...] = ()
  154. if self.proj_size == 0:
  155. if bias:
  156. layer_params = (w_ih, w_hh, b_ih, b_hh)
  157. else:
  158. layer_params = (w_ih, w_hh)
  159. else:
  160. w_hr = Parameter(
  161. torch.empty((proj_size, hidden_size), **factory_kwargs)
  162. )
  163. if bias:
  164. layer_params = (w_ih, w_hh, b_ih, b_hh, w_hr)
  165. else:
  166. layer_params = (w_ih, w_hh, w_hr)
  167. suffix = "_reverse" if direction == 1 else ""
  168. param_names = ["weight_ih_l{}{}", "weight_hh_l{}{}"]
  169. if bias:
  170. param_names += ["bias_ih_l{}{}", "bias_hh_l{}{}"]
  171. if self.proj_size > 0:
  172. param_names += ["weight_hr_l{}{}"]
  173. param_names = [x.format(layer, suffix) for x in param_names]
  174. for name, param in zip(param_names, layer_params):
  175. setattr(self, name, param)
  176. self._flat_weights_names.extend(param_names)
  177. self._all_weights.append(param_names)
  178. self._init_flat_weights()
  179. self.reset_parameters()
  180. def _init_flat_weights(self) -> None:
  181. self._flat_weights = [
  182. getattr(self, wn) if hasattr(self, wn) else None
  183. for wn in self._flat_weights_names
  184. ]
  185. self._flat_weight_refs = [
  186. weakref.ref(w) if w is not None else None for w in self._flat_weights
  187. ]
  188. self.flatten_parameters()
  189. def __setattr__(self, attr, value) -> None:
  190. if hasattr(self, "_flat_weights_names") and attr in self._flat_weights_names:
  191. # keep self._flat_weights up to date if you do self.weight = ...
  192. idx = self._flat_weights_names.index(attr)
  193. self._flat_weights[idx] = value
  194. super().__setattr__(attr, value)
  195. def flatten_parameters(self) -> None:
  196. """Reset parameter data pointer so that they can use faster code paths.
  197. Right now, this works only if the module is on the GPU and cuDNN is enabled.
  198. Otherwise, it's a no-op.
  199. """
  200. # Short-circuits if _flat_weights is only partially instantiated
  201. if len(self._flat_weights) != len(self._flat_weights_names):
  202. return
  203. for w in self._flat_weights:
  204. if not isinstance(w, Tensor):
  205. return
  206. # Short-circuits if any tensor in self._flat_weights is not acceptable to cuDNN
  207. # or the tensors in _flat_weights are of different dtypes
  208. first_fw = self._flat_weights[0] # type: ignore[union-attr]
  209. dtype = first_fw.dtype # type: ignore[union-attr]
  210. for fw in self._flat_weights:
  211. if (
  212. not isinstance(fw, Tensor)
  213. or not (fw.dtype == dtype)
  214. or not fw.is_cuda
  215. or not torch.backends.cudnn.is_acceptable(fw)
  216. ):
  217. return
  218. # If any parameters alias, we fall back to the slower, copying code path. This is
  219. # a sufficient check, because overlapping parameter buffers that don't completely
  220. # alias would break the assumptions of the uniqueness check in
  221. # Module.named_parameters().
  222. unique_data_ptrs = {
  223. p.data_ptr() # type: ignore[union-attr]
  224. for p in self._flat_weights
  225. }
  226. if len(unique_data_ptrs) != len(self._flat_weights):
  227. return
  228. with torch.cuda.device_of(first_fw):
  229. import torch.backends.cudnn.rnn as rnn
  230. # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
  231. # an inplace operation on self._flat_weights
  232. with torch.no_grad():
  233. if torch._use_cudnn_rnn_flatten_weight():
  234. num_weights = 4 if self.bias else 2
  235. if self.proj_size > 0:
  236. num_weights += 1
  237. torch._cudnn_rnn_flatten_weight(
  238. self._flat_weights, # type: ignore[arg-type]
  239. num_weights,
  240. self.input_size,
  241. rnn.get_cudnn_mode(self.mode),
  242. self.hidden_size,
  243. self.proj_size,
  244. self.num_layers,
  245. self.batch_first,
  246. bool(self.bidirectional),
  247. )
  248. def _apply(self, fn, recurse=True):
  249. self._flat_weight_refs = []
  250. ret = super()._apply(fn, recurse)
  251. # Resets _flat_weights
  252. # Note: be v. careful before removing this, as 3rd party device types
  253. # likely rely on this behavior to properly .to() modules like LSTM.
  254. self._init_flat_weights()
  255. return ret
  256. def reset_parameters(self) -> None:
  257. stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0
  258. for weight in self.parameters():
  259. init.uniform_(weight, -stdv, stdv)
  260. def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
  261. if not torch.jit.is_scripting():
  262. if (
  263. input.dtype != self._flat_weights[0].dtype # type: ignore[union-attr]
  264. and not torch._C._is_any_autocast_enabled()
  265. ):
  266. raise ValueError(
  267. f"input must have the type {self._flat_weights[0].dtype}, got type {input.dtype}" # type: ignore[union-attr]
  268. )
  269. expected_input_dim = 2 if batch_sizes is not None else 3
  270. if input.dim() != expected_input_dim:
  271. raise RuntimeError(
  272. f"input must have {expected_input_dim} dimensions, got {input.dim()}"
  273. )
  274. if self.input_size != input.size(-1):
  275. raise RuntimeError(
  276. f"input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}"
  277. )
  278. def get_expected_hidden_size(
  279. self, input: Tensor, batch_sizes: Optional[Tensor]
  280. ) -> tuple[int, int, int]:
  281. if batch_sizes is not None:
  282. mini_batch = int(batch_sizes[0])
  283. else:
  284. mini_batch = input.size(0) if self.batch_first else input.size(1)
  285. num_directions = 2 if self.bidirectional else 1
  286. if self.proj_size > 0:
  287. expected_hidden_size = (
  288. self.num_layers * num_directions,
  289. mini_batch,
  290. self.proj_size,
  291. )
  292. else:
  293. expected_hidden_size = (
  294. self.num_layers * num_directions,
  295. mini_batch,
  296. self.hidden_size,
  297. )
  298. return expected_hidden_size
  299. def check_hidden_size(
  300. self,
  301. hx: Tensor,
  302. expected_hidden_size: tuple[int, int, int],
  303. msg: str = "Expected hidden size {}, got {}",
  304. ) -> None:
  305. if hx.size() != expected_hidden_size:
  306. raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
  307. def _weights_have_changed(self):
  308. # Returns True if the weight tensors have changed since the last forward pass.
  309. # This is the case when used with torch.func.functional_call(), for example.
  310. weights_changed = False
  311. for ref, name in zip(self._flat_weight_refs, self._flat_weights_names):
  312. weight = getattr(self, name) if hasattr(self, name) else None
  313. if weight is not None and ref is not None and ref() is not weight:
  314. weights_changed = True
  315. break
  316. return weights_changed
  317. def check_forward_args(
  318. self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]
  319. ) -> None:
  320. self.check_input(input, batch_sizes)
  321. expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
  322. self.check_hidden_size(hidden, expected_hidden_size)
  323. def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]):
  324. if permutation is None:
  325. return hx
  326. return _apply_permutation(hx, permutation)
  327. def extra_repr(self) -> str:
  328. s = "{input_size}, {hidden_size}"
  329. if self.proj_size != 0:
  330. s += ", proj_size={proj_size}"
  331. if self.num_layers != 1:
  332. s += ", num_layers={num_layers}"
  333. if self.bias is not True:
  334. s += ", bias={bias}"
  335. if self.batch_first is not False:
  336. s += ", batch_first={batch_first}"
  337. if self.dropout != 0:
  338. s += ", dropout={dropout}"
  339. if self.bidirectional is not False:
  340. s += ", bidirectional={bidirectional}"
  341. return s.format(**self.__dict__)
  342. def _update_flat_weights(self) -> None:
  343. if not torch.jit.is_scripting():
  344. if self._weights_have_changed():
  345. self._init_flat_weights()
  346. def __getstate__(self):
  347. # If weights have been changed, update the _flat_weights in __getstate__ here.
  348. self._update_flat_weights()
  349. # Don't serialize the weight references.
  350. state = self.__dict__.copy()
  351. del state["_flat_weight_refs"]
  352. return state
  353. def __setstate__(self, d):
  354. super().__setstate__(d)
  355. if "all_weights" in d:
  356. self._all_weights = d["all_weights"]
  357. # In PyTorch 1.8 we added a proj_size member variable to LSTM.
  358. # LSTMs that were serialized via torch.save(module) before PyTorch 1.8
  359. # don't have it, so to preserve compatibility we set proj_size here.
  360. if "proj_size" not in d:
  361. self.proj_size = 0
  362. if not isinstance(self._all_weights[0][0], str):
  363. num_layers = self.num_layers
  364. num_directions = 2 if self.bidirectional else 1
  365. self._flat_weights_names = []
  366. self._all_weights = []
  367. for layer in range(num_layers):
  368. for direction in range(num_directions):
  369. suffix = "_reverse" if direction == 1 else ""
  370. weights = [
  371. "weight_ih_l{}{}",
  372. "weight_hh_l{}{}",
  373. "bias_ih_l{}{}",
  374. "bias_hh_l{}{}",
  375. "weight_hr_l{}{}",
  376. ]
  377. weights = [x.format(layer, suffix) for x in weights]
  378. if self.bias:
  379. if self.proj_size > 0:
  380. self._all_weights += [weights]
  381. self._flat_weights_names.extend(weights)
  382. else:
  383. self._all_weights += [weights[:4]]
  384. self._flat_weights_names.extend(weights[:4])
  385. else:
  386. if self.proj_size > 0:
  387. self._all_weights += [weights[:2]] + [weights[-1:]]
  388. self._flat_weights_names.extend(
  389. weights[:2] + [weights[-1:]]
  390. )
  391. else:
  392. self._all_weights += [weights[:2]]
  393. self._flat_weights_names.extend(weights[:2])
  394. self._flat_weights = [
  395. getattr(self, wn) if hasattr(self, wn) else None
  396. for wn in self._flat_weights_names
  397. ]
  398. self._flat_weight_refs = [
  399. weakref.ref(w) if w is not None else None for w in self._flat_weights
  400. ]
  401. @property
  402. def all_weights(self) -> list[list[Parameter]]:
  403. return [
  404. [getattr(self, weight) for weight in weights]
  405. for weights in self._all_weights
  406. ]
  407. def _replicate_for_data_parallel(self):
  408. replica = super()._replicate_for_data_parallel()
  409. # Need to copy these caches, otherwise the replica will share the same
  410. # flat weights list.
  411. replica._flat_weights = replica._flat_weights[:]
  412. replica._flat_weights_names = replica._flat_weights_names[:]
  413. return replica
class RNN(RNNBase):
    r"""__init__(input_size,hidden_size,num_layers=1,nonlinearity='tanh',bias=True,batch_first=False,dropout=0.0,bidirectional=False,device=None,dtype=None)

    Apply a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}`
    non-linearity to an input sequence. For each element in the input sequence,
    each layer computes the following function:

    .. math::
        h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1}W_{hh}^T + b_{hh})

    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
    the input at time `t` (for every layer after the first, this is the
    hidden state :math:`h_t` produced by the previous layer), and
    :math:`h_{(t-1)}` is the hidden state of the same layer at time `t-1`
    or the initial hidden state at time `0`.
    If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`.

    .. code-block:: python

        # Efficient implementation equivalent to the following with bidirectional=False
        rnn = nn.RNN(input_size, hidden_size, num_layers)
        params = dict(rnn.named_parameters())
        def forward(x, hx=None, batch_first=False):
            if batch_first:
                x = x.transpose(0, 1)
            seq_len, batch_size, _ = x.size()
            if hx is None:
                hx = torch.zeros(rnn.num_layers, batch_size, rnn.hidden_size)
            h_t_minus_1 = hx.clone()
            h_t = hx.clone()
            output = []
            for t in range(seq_len):
                for layer in range(rnn.num_layers):
                    input_t = x[t] if layer == 0 else h_t[layer - 1]
                    h_t[layer] = torch.tanh(
                        input_t @ params[f"weight_ih_l{layer}"].T
                        + h_t_minus_1[layer] @ params[f"weight_hh_l{layer}"].T
                        + params[f"bias_hh_l{layer}"]
                        + params[f"bias_ih_l{layer}"]
                    )
                output.append(h_t[-1].clone())
                h_t_minus_1 = h_t.clone()
            output = torch.stack(output)
            if batch_first:
                output = output.transpose(0, 1)
            return output, h_t

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two RNNs together to form a `stacked RNN`,
            with the second RNN taking in outputs of the first RNN and
            computing the final results. Default: 1
        nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
            Note that this does not apply to hidden or cell states. See the
            Inputs/Outputs sections below for details. Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            RNN layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False``

    Inputs: input, hx
        * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
          :math:`(L, N, H_{in})` when ``batch_first=False`` or
          :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
          the input sequence. The input can also be a packed variable length sequence.
          See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
          :func:`torch.nn.utils.rnn.pack_sequence` for details.
        * **hx**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
          :math:`(D * \text{num\_layers}, N, H_{out})` containing the initial hidden
          state for the input sequence batch. Defaults to zeros if not provided.

        where:

        .. math::
            \begin{aligned}
                N ={} & \text{batch size} \\
                L ={} & \text{sequence length} \\
                D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
                H_{in} ={} & \text{input\_size} \\
                H_{out} ={} & \text{hidden\_size}
            \end{aligned}

    Outputs: output, h_n
        * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
          :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
          :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
          `(h_t)` from the last layer of the RNN, for each `t`. If a
          :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
          will also be a packed sequence.
        * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
          :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state
          for each element in the batch.

    Attributes:
        weight_ih_l[k]: the learnable input-hidden weights of the k-th layer,
            of shape `(hidden_size, input_size)` for `k = 0`. Otherwise, the shape is
            `(hidden_size, num_directions * hidden_size)`
        weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer,
            of shape `(hidden_size, hidden_size)`
        bias_ih_l[k]: the learnable input-hidden bias of the k-th layer,
            of shape `(hidden_size)`
        bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer,
            of shape `(hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    .. note::
        For bidirectional RNNs, forward and backward are directions 0 and 1 respectively.
        Example of splitting the output layers when ``batch_first=False``:
        ``output.view(seq_len, batch, num_directions, hidden_size)``.

    .. note::
        ``batch_first`` argument is ignored for unbatched inputs.

    .. include:: ../cudnn_rnn_determinism.rst

    .. include:: ../cudnn_persistent_rnn.rst

    Examples::

        >>> rnn = nn.RNN(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> output, hn = rnn(input, h0)
    """

    # Typed overload for static checkers; the real implementation below takes
    # *args/**kwargs so that ``nonlinearity`` can be extracted before delegating.
    @overload
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        nonlinearity: str = "tanh",
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
        bidirectional: bool = False,
        device=None,
        dtype=None,
    ) -> None: ...

    @overload
    def __init__(self, *args, **kwargs) -> None: ...

    def __init__(self, *args, **kwargs):
        # Projections are an LSTM-only feature; reject them up front.
        if "proj_size" in kwargs:
            raise ValueError(
                "proj_size argument is only supported for LSTM, not RNN or GRU"
            )
        # ``nonlinearity`` is the 4th positional parameter of RNN's public
        # signature but does not exist in RNNBase.__init__, so it is removed
        # from args/kwargs before the rest is forwarded after ``mode``.
        if len(args) > 3:
            self.nonlinearity = args[3]
            args = args[:3] + args[4:]
        else:
            self.nonlinearity = kwargs.pop("nonlinearity", "tanh")
        if self.nonlinearity == "tanh":
            mode = "RNN_TANH"
        elif self.nonlinearity == "relu":
            mode = "RNN_RELU"
        else:
            raise ValueError(
                f"Unknown nonlinearity '{self.nonlinearity}'. Select from 'tanh' or 'relu'."
            )
        super().__init__(mode, *args, **kwargs)

    # TorchScript overloads: a Tensor input yields a Tensor output, a
    # PackedSequence input yields a PackedSequence output.
    @overload
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(
        self, input: Tensor, hx: Optional[Tensor] = None
    ) -> tuple[Tensor, Tensor]:
        pass

    @overload
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(
        self, input: PackedSequence, hx: Optional[Tensor] = None
    ) -> tuple[PackedSequence, Tensor]:
        pass

    def forward(self, input, hx=None):  # noqa: F811
        """
        Runs the forward pass.

        Args:
            input: A 2-D (unbatched) or 3-D (batched) tensor, or a
                :class:`PackedSequence`.
            hx: Optional initial hidden state. When ``None``, zeros of shape
                ``(num_layers * num_directions, batch, hidden_size)`` are used.

        Returns:
            ``(output, h_n)``. ``output`` is a :class:`PackedSequence` when
            ``input`` is one; for unbatched input the batch dimension that was
            added internally is squeezed back out of both results.

        Raises:
            ValueError: if a tensor ``input`` is not 2-D or 3-D.
            RuntimeError: if ``hx`` has the wrong rank for ``input``. Errors
                raised by ``self.check_forward_args`` also propagate.
        """
        # Refresh cached flat weights if they were swapped (e.g. functional_call).
        self._update_flat_weights()

        num_directions = 2 if self.bidirectional else 1
        orig_input = input

        if isinstance(orig_input, PackedSequence):
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = batch_sizes[0]
            # script() is unhappy when max_batch_size is different type in cond branches, so we duplicate
            if hx is None:
                hx = torch.zeros(
                    self.num_layers * num_directions,
                    max_batch_size,
                    self.hidden_size,
                    dtype=input.dtype,
                    device=input.device,
                )
            else:
                # Each batch of the hidden state should match the input sequence that
                # the user believes they are passing in.
                hx = self.permute_hidden(hx, sorted_indices)
        else:
            batch_sizes = None
            if input.dim() not in (2, 3):
                raise ValueError(
                    f"RNN: Expected input to be 2D or 3D, got {input.dim()}D tensor instead"
                )
            is_batched = input.dim() == 3
            batch_dim = 0 if self.batch_first else 1
            if not is_batched:
                # Give unbatched input (and hx) a batch dimension of size 1;
                # it is squeezed out again before returning.
                input = input.unsqueeze(batch_dim)
                if hx is not None:
                    if hx.dim() != 2:
                        raise RuntimeError(
                            f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor"
                        )
                    hx = hx.unsqueeze(1)
            else:
                if hx is not None and hx.dim() != 3:
                    raise RuntimeError(
                        f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor"
                    )
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None
            if hx is None:
                hx = torch.zeros(
                    self.num_layers * num_directions,
                    max_batch_size,
                    self.hidden_size,
                    dtype=input.dtype,
                    device=input.device,
                )
            else:
                # Each batch of the hidden state should match the input sequence that
                # the user believes they are passing in.
                hx = self.permute_hidden(hx, sorted_indices)

        assert hx is not None
        self.check_forward_args(input, hx, batch_sizes)
        assert self.mode == "RNN_TANH" or self.mode == "RNN_RELU"
        # The _VF kernels are called directly (not through a lookup table) so
        # the code stays scriptable; the packed variant takes batch_sizes and
        # no batch_first flag.
        if batch_sizes is None:
            if self.mode == "RNN_TANH":
                result = _VF.rnn_tanh(
                    input,
                    hx,
                    self._flat_weights,  # type: ignore[arg-type]
                    self.bias,
                    self.num_layers,
                    self.dropout,
                    self.training,
                    self.bidirectional,
                    self.batch_first,
                )
            else:
                result = _VF.rnn_relu(
                    input,
                    hx,
                    self._flat_weights,  # type: ignore[arg-type]
                    self.bias,
                    self.num_layers,
                    self.dropout,
                    self.training,
                    self.bidirectional,
                    self.batch_first,
                )
        else:
            if self.mode == "RNN_TANH":
                result = _VF.rnn_tanh(
                    input,
                    batch_sizes,
                    hx,
                    self._flat_weights,  # type: ignore[arg-type]
                    self.bias,
                    self.num_layers,
                    self.dropout,
                    self.training,
                    self.bidirectional,
                )
            else:
                result = _VF.rnn_relu(
                    input,
                    batch_sizes,
                    hx,
                    self._flat_weights,  # type: ignore[arg-type]
                    self.bias,
                    self.num_layers,
                    self.dropout,
                    self.training,
                    self.bidirectional,
                )

        output = result[0]
        hidden = result[1]

        if isinstance(orig_input, PackedSequence):
            output_packed = PackedSequence(
                output, batch_sizes, sorted_indices, unsorted_indices
            )
            return output_packed, self.permute_hidden(hidden, unsorted_indices)

        if not is_batched:  # type: ignore[possibly-undefined]
            output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
            hidden = hidden.squeeze(1)

        return output, self.permute_hidden(hidden, unsorted_indices)
# XXX: The LSTM and GRU implementations differ from RNNBase because:
# 1. We want to support nn.LSTM and nn.GRU in TorchScript, and TorchScript in
#    its current state cannot support the Python Union or Any types.
# 2. TorchScript static typing does not allow Function or Callable types as
#    Dict values, so we have to call _VF directly instead of using _rnn_impls.
# 3. This is a temporary, transitional measure so that the change could make
#    it in time for the release.
#
# More discussion details in https://github.com/pytorch/pytorch/pull/23266
#
# TODO: remove the overriding implementations for LSTM and GRU when TorchScript
# supports expressing these two modules generically.
  709. class LSTM(RNNBase):
  710. r"""__init__(input_size,hidden_size,num_layers=1,bias=True,batch_first=False,dropout=0.0,bidirectional=False,proj_size=0,device=None,dtype=None)
  711. Apply a multi-layer long short-term memory (LSTM) RNN to an input sequence.
  712. For each element in the input sequence, each layer computes the following
  713. function:
  714. .. math::
  715. \begin{array}{ll} \\
  716. i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
  717. f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
  718. g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
  719. o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
  720. c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
  721. h_t = o_t \odot \tanh(c_t) \\
  722. \end{array}
  723. where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
  724. state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}`
  725. is the hidden state of the layer at time `t-1` or the initial hidden
  726. state at time `0`, and :math:`i_t`, :math:`f_t`, :math:`g_t`,
  727. :math:`o_t` are the input, forget, cell, and output gates, respectively.
  728. :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.
  729. In a multilayer LSTM, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
  730. (:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
  731. dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
  732. variable which is :math:`0` with probability :attr:`dropout`.
  733. If ``proj_size > 0`` is specified, LSTM with projections will be used. This changes
  734. the LSTM cell in the following way. First, the dimension of :math:`h_t` will be changed from
  735. ``hidden_size`` to ``proj_size`` (dimensions of :math:`W_{hi}` will be changed accordingly).
  736. Second, the output hidden state of each layer will be multiplied by a learnable projection
  737. matrix: :math:`h_t = W_{hr}h_t`. Note that as a consequence of this, the output
  738. of LSTM network will be of different shape as well. See Inputs/Outputs sections below for exact
  739. dimensions of all variables. You can find more details in https://arxiv.org/abs/1402.1128.
  740. Args:
  741. input_size: The number of expected features in the input `x`
  742. hidden_size: The number of features in the hidden state `h`
  743. num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
  744. would mean stacking two LSTMs together to form a `stacked LSTM`,
  745. with the second LSTM taking in outputs of the first LSTM and
  746. computing the final results. Default: 1
  747. bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
  748. Default: ``True``
  749. batch_first: If ``True``, then the input and output tensors are provided
  750. as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
  751. Note that this does not apply to hidden or cell states. See the
  752. Inputs/Outputs sections below for details. Default: ``False``
  753. dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
  754. LSTM layer except the last layer, with dropout probability equal to
  755. :attr:`dropout`. Default: 0
  756. bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False``
  757. proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0
  758. Inputs: input, (h_0, c_0)
  759. * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
  760. :math:`(L, N, H_{in})` when ``batch_first=False`` or
  761. :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
  762. the input sequence. The input can also be a packed variable length sequence.
  763. See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
  764. :func:`torch.nn.utils.rnn.pack_sequence` for details.
  765. * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
  766. :math:`(D * \text{num\_layers}, N, H_{out})` containing the
  767. initial hidden state for each element in the input sequence.
  768. Defaults to zeros if (h_0, c_0) is not provided.
  769. * **c_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or
  770. :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
  771. initial cell state for each element in the input sequence.
  772. Defaults to zeros if (h_0, c_0) is not provided.
  773. where:
  774. .. math::
  775. \begin{aligned}
  776. N ={} & \text{batch size} \\
  777. L ={} & \text{sequence length} \\
  778. D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
  779. H_{in} ={} & \text{input\_size} \\
  780. H_{cell} ={} & \text{hidden\_size} \\
  781. H_{out} ={} & \text{proj\_size if } \text{proj\_size}>0 \text{ otherwise hidden\_size} \\
  782. \end{aligned}
  783. Outputs: output, (h_n, c_n)
  784. * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
  785. :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
  786. :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
  787. `(h_t)` from the last layer of the LSTM, for each `t`. If a
  788. :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
  789. will also be a packed sequence. When ``bidirectional=True``, `output` will contain
  790. a concatenation of the forward and reverse hidden states at each time step in the sequence.
  791. * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
  792. :math:`(D * \text{num\_layers}, N, H_{out})` containing the
  793. final hidden state for each element in the sequence. When ``bidirectional=True``,
  794. `h_n` will contain a concatenation of the final forward and reverse hidden states, respectively.
  795. * **c_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or
  796. :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
  797. final cell state for each element in the sequence. When ``bidirectional=True``,
  798. `c_n` will contain a concatenation of the final forward and reverse cell states, respectively.
  799. Attributes:
  800. weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
  801. `(W_ii|W_if|W_ig|W_io)`, of shape `(4*hidden_size, input_size)` for `k = 0`.
  802. Otherwise, the shape is `(4*hidden_size, num_directions * hidden_size)`. If
  803. ``proj_size > 0`` was specified, the shape will be
  804. `(4*hidden_size, num_directions * proj_size)` for `k > 0`
  805. weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
  806. `(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size, hidden_size)`. If ``proj_size > 0``
  807. was specified, the shape will be `(4*hidden_size, proj_size)`.
  808. bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
  809. `(b_ii|b_if|b_ig|b_io)`, of shape `(4*hidden_size)`
  810. bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
  811. `(b_hi|b_hf|b_hg|b_ho)`, of shape `(4*hidden_size)`
  812. weight_hr_l[k] : the learnable projection weights of the :math:`\text{k}^{th}` layer
  813. of shape `(proj_size, hidden_size)`. Only present when ``proj_size > 0`` was
  814. specified.
  815. weight_ih_l[k]_reverse: Analogous to `weight_ih_l[k]` for the reverse direction.
  816. Only present when ``bidirectional=True``.
  817. weight_hh_l[k]_reverse: Analogous to `weight_hh_l[k]` for the reverse direction.
  818. Only present when ``bidirectional=True``.
  819. bias_ih_l[k]_reverse: Analogous to `bias_ih_l[k]` for the reverse direction.
  820. Only present when ``bidirectional=True``.
  821. bias_hh_l[k]_reverse: Analogous to `bias_hh_l[k]` for the reverse direction.
  822. Only present when ``bidirectional=True``.
  823. weight_hr_l[k]_reverse: Analogous to `weight_hr_l[k]` for the reverse direction.
  824. Only present when ``bidirectional=True`` and ``proj_size > 0`` was specified.
  825. .. note::
  826. All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
  827. where :math:`k = \frac{1}{\text{hidden\_size}}`
  828. .. note::
  829. For bidirectional LSTMs, forward and backward are directions 0 and 1 respectively.
  830. Example of splitting the output layers when ``batch_first=False``:
  831. ``output.view(seq_len, batch, num_directions, hidden_size)``.
  832. .. note::
  833. For bidirectional LSTMs, `h_n` is not equivalent to the last element of `output`; the
  834. former contains the final forward and reverse hidden states, while the latter contains the
  835. final forward hidden state and the initial reverse hidden state.
  836. .. note::
  837. ``batch_first`` argument is ignored for unbatched inputs.
  838. .. note::
  839. ``proj_size`` should be smaller than ``hidden_size``.
  840. .. include:: ../cudnn_rnn_determinism.rst
  841. .. include:: ../cudnn_persistent_rnn.rst
  842. Examples::
  843. >>> rnn = nn.LSTM(10, 20, 2)
  844. >>> input = torch.randn(5, 3, 10)
  845. >>> h0 = torch.randn(2, 3, 20)
  846. >>> c0 = torch.randn(2, 3, 20)
  847. >>> output, (hn, cn) = rnn(input, (h0, c0))
  848. """
  849. @overload
  850. def __init__(
  851. self,
  852. input_size: int,
  853. hidden_size: int,
  854. num_layers: int = 1,
  855. bias: bool = True,
  856. batch_first: bool = False,
  857. dropout: float = 0.0,
  858. bidirectional: bool = False,
  859. proj_size: int = 0,
  860. device=None,
  861. dtype=None,
  862. ) -> None: ...
  863. @overload
  864. def __init__(self, *args, **kwargs) -> None: ...
  865. def __init__(self, *args, **kwargs):
  866. super().__init__("LSTM", *args, **kwargs)
  867. def get_expected_cell_size(
  868. self, input: Tensor, batch_sizes: Optional[Tensor]
  869. ) -> tuple[int, int, int]:
  870. if batch_sizes is not None:
  871. mini_batch = int(batch_sizes[0])
  872. else:
  873. mini_batch = input.size(0) if self.batch_first else input.size(1)
  874. num_directions = 2 if self.bidirectional else 1
  875. expected_hidden_size = (
  876. self.num_layers * num_directions,
  877. mini_batch,
  878. self.hidden_size,
  879. )
  880. return expected_hidden_size
  881. # In the future, we should prevent mypy from applying contravariance rules here.
  882. # See torch/nn/modules/module.py::_forward_unimplemented
  883. def check_forward_args(
  884. self,
  885. input: Tensor,
  886. hidden: tuple[Tensor, Tensor], # type: ignore[override]
  887. batch_sizes: Optional[Tensor],
  888. ) -> None:
  889. self.check_input(input, batch_sizes)
  890. self.check_hidden_size(
  891. hidden[0],
  892. self.get_expected_hidden_size(input, batch_sizes),
  893. "Expected hidden[0] size {}, got {}",
  894. )
  895. self.check_hidden_size(
  896. hidden[1],
  897. self.get_expected_cell_size(input, batch_sizes),
  898. "Expected hidden[1] size {}, got {}",
  899. )
  900. # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
  901. def permute_hidden( # type: ignore[override]
  902. self,
  903. hx: tuple[Tensor, Tensor],
  904. permutation: Optional[Tensor],
  905. ) -> tuple[Tensor, Tensor]:
  906. if permutation is None:
  907. return hx
  908. return _apply_permutation(hx[0], permutation), _apply_permutation(
  909. hx[1], permutation
  910. )
  911. # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
  912. @overload # type: ignore[override]
  913. @torch._jit_internal._overload_method # noqa: F811
  914. def forward(
  915. self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None
  916. ) -> tuple[Tensor, tuple[Tensor, Tensor]]: # noqa: F811
  917. pass
  918. # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
  919. @overload
  920. @torch._jit_internal._overload_method # noqa: F811
  921. def forward(
  922. self, input: PackedSequence, hx: Optional[tuple[Tensor, Tensor]] = None
  923. ) -> tuple[PackedSequence, tuple[Tensor, Tensor]]: # noqa: F811
  924. pass
  925. def forward(self, input, hx=None): # noqa: F811
  926. self._update_flat_weights()
  927. orig_input = input
  928. # xxx: isinstance check needs to be in conditional for TorchScript to compile
  929. batch_sizes = None
  930. num_directions = 2 if self.bidirectional else 1
  931. real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
  932. if isinstance(orig_input, PackedSequence):
  933. input, batch_sizes, sorted_indices, unsorted_indices = input
  934. max_batch_size = batch_sizes[0]
  935. if hx is None:
  936. h_zeros = torch.zeros(
  937. self.num_layers * num_directions,
  938. max_batch_size,
  939. real_hidden_size,
  940. dtype=input.dtype,
  941. device=input.device,
  942. )
  943. c_zeros = torch.zeros(
  944. self.num_layers * num_directions,
  945. max_batch_size,
  946. self.hidden_size,
  947. dtype=input.dtype,
  948. device=input.device,
  949. )
  950. hx = (h_zeros, c_zeros)
  951. else:
  952. # Each batch of the hidden state should match the input sequence that
  953. # the user believes he/she is passing in.
  954. hx = self.permute_hidden(hx, sorted_indices)
  955. else:
  956. if input.dim() not in (2, 3):
  957. raise ValueError(
  958. f"LSTM: Expected input to be 2D or 3D, got {input.dim()}D instead"
  959. )
  960. is_batched = input.dim() == 3
  961. batch_dim = 0 if self.batch_first else 1
  962. if not is_batched:
  963. input = input.unsqueeze(batch_dim)
  964. max_batch_size = input.size(0) if self.batch_first else input.size(1)
  965. sorted_indices = None
  966. unsorted_indices = None
  967. if hx is None:
  968. h_zeros = torch.zeros(
  969. self.num_layers * num_directions,
  970. max_batch_size,
  971. real_hidden_size,
  972. dtype=input.dtype,
  973. device=input.device,
  974. )
  975. c_zeros = torch.zeros(
  976. self.num_layers * num_directions,
  977. max_batch_size,
  978. self.hidden_size,
  979. dtype=input.dtype,
  980. device=input.device,
  981. )
  982. hx = (h_zeros, c_zeros)
  983. self.check_forward_args(input, hx, batch_sizes)
  984. else:
  985. if is_batched:
  986. if hx[0].dim() != 3 or hx[1].dim() != 3:
  987. msg = (
  988. "For batched 3-D input, hx and cx should "
  989. f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors"
  990. )
  991. raise RuntimeError(msg)
  992. else:
  993. if hx[0].dim() != 2 or hx[1].dim() != 2:
  994. msg = (
  995. "For unbatched 2-D input, hx and cx should "
  996. f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors"
  997. )
  998. raise RuntimeError(msg)
  999. hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1))
  1000. # Each batch of the hidden state should match the input sequence that
  1001. # the user believes he/she is passing in.
  1002. self.check_forward_args(input, hx, batch_sizes)
  1003. hx = self.permute_hidden(hx, sorted_indices)
  1004. if batch_sizes is None:
  1005. result = _VF.lstm(
  1006. input,
  1007. hx,
  1008. self._flat_weights, # type: ignore[arg-type]
  1009. self.bias,
  1010. self.num_layers,
  1011. self.dropout,
  1012. self.training,
  1013. self.bidirectional,
  1014. self.batch_first,
  1015. )
  1016. else:
  1017. result = _VF.lstm(
  1018. input,
  1019. batch_sizes,
  1020. hx,
  1021. self._flat_weights, # type: ignore[arg-type]
  1022. self.bias,
  1023. self.num_layers,
  1024. self.dropout,
  1025. self.training,
  1026. self.bidirectional,
  1027. )
  1028. output = result[0]
  1029. hidden = result[1:]
  1030. # xxx: isinstance check needs to be in conditional for TorchScript to compile
  1031. if isinstance(orig_input, PackedSequence):
  1032. output_packed = PackedSequence(
  1033. output, batch_sizes, sorted_indices, unsorted_indices
  1034. )
  1035. return output_packed, self.permute_hidden(hidden, unsorted_indices)
  1036. else:
  1037. if not is_batched: # type: ignore[possibly-undefined]
  1038. output = output.squeeze(batch_dim) # type: ignore[possibly-undefined]
  1039. hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
  1040. return output, self.permute_hidden(hidden, unsorted_indices)
  1041. class GRU(RNNBase):
  1042. r"""__init__(input_size,hidden_size,num_layers=1,bias=True,batch_first=False,dropout=0.0,bidirectional=False,device=None,dtype=None)
  1043. Apply a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
  1044. For each element in the input sequence, each layer computes the following
  1045. function:
  1046. .. math::
  1047. \begin{array}{ll}
  1048. r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
  1049. z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
  1050. n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) \\
  1051. h_t = (1 - z_t) \odot n_t + z_t \odot h_{(t-1)}
  1052. \end{array}
  1053. where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input
  1054. at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer
  1055. at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`,
  1056. :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively.
  1057. :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.
  1058. In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
  1059. (:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
  1060. dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
  1061. variable which is :math:`0` with probability :attr:`dropout`.
  1062. Args:
  1063. input_size: The number of expected features in the input `x`
  1064. hidden_size: The number of features in the hidden state `h`
  1065. num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
  1066. would mean stacking two GRUs together to form a `stacked GRU`,
  1067. with the second GRU taking in outputs of the first GRU and
  1068. computing the final results. Default: 1
  1069. bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
  1070. Default: ``True``
  1071. batch_first: If ``True``, then the input and output tensors are provided
  1072. as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
  1073. Note that this does not apply to hidden or cell states. See the
  1074. Inputs/Outputs sections below for details. Default: ``False``
  1075. dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
  1076. GRU layer except the last layer, with dropout probability equal to
  1077. :attr:`dropout`. Default: 0
  1078. bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False``
  1079. Inputs: input, h_0
  1080. * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
  1081. :math:`(L, N, H_{in})` when ``batch_first=False`` or
  1082. :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
  1083. the input sequence. The input can also be a packed variable length sequence.
  1084. See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
  1085. :func:`torch.nn.utils.rnn.pack_sequence` for details.
  1086. * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or
  1087. :math:`(D * \text{num\_layers}, N, H_{out})`
  1088. containing the initial hidden state for the input sequence. Defaults to zeros if not provided.
  1089. where:
  1090. .. math::
  1091. \begin{aligned}
  1092. N ={} & \text{batch size} \\
  1093. L ={} & \text{sequence length} \\
  1094. D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
  1095. H_{in} ={} & \text{input\_size} \\
  1096. H_{out} ={} & \text{hidden\_size}
  1097. \end{aligned}
  1098. Outputs: output, h_n
  1099. * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
  1100. :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
  1101. :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
  1102. `(h_t)` from the last layer of the GRU, for each `t`. If a
  1103. :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
  1104. will also be a packed sequence.
  1105. * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or
  1106. :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state
  1107. for the input sequence.
  1108. Attributes:
  1109. weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
  1110. (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`.
  1111. Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)`
  1112. weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
  1113. (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)`
  1114. bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
  1115. (b_ir|b_iz|b_in), of shape `(3*hidden_size)`
  1116. bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
  1117. (b_hr|b_hz|b_hn), of shape `(3*hidden_size)`
  1118. .. note::
  1119. All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
  1120. where :math:`k = \frac{1}{\text{hidden\_size}}`
  1121. .. note::
  1122. For bidirectional GRUs, forward and backward are directions 0 and 1 respectively.
  1123. Example of splitting the output layers when ``batch_first=False``:
  1124. ``output.view(seq_len, batch, num_directions, hidden_size)``.
  1125. .. note::
  1126. ``batch_first`` argument is ignored for unbatched inputs.
  1127. .. note::
  1128. The calculation of new gate :math:`n_t` subtly differs from the original paper and other frameworks.
  1129. In the original implementation, the Hadamard product :math:`(\odot)` between :math:`r_t` and the
  1130. previous hidden state :math:`h_{(t-1)}` is done before the multiplication with the weight matrix
  1131. `W` and addition of bias:
  1132. .. math::
  1133. \begin{aligned}
  1134. n_t = \tanh(W_{in} x_t + b_{in} + W_{hn} ( r_t \odot h_{(t-1)} ) + b_{hn})
  1135. \end{aligned}
  1136. This is in contrast to PyTorch implementation, which is done after :math:`W_{hn} h_{(t-1)}`
  1137. .. math::
  1138. \begin{aligned}
  1139. n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn}))
  1140. \end{aligned}
  1141. This implementation differs on purpose for efficiency.
  1142. .. include:: ../cudnn_persistent_rnn.rst
  1143. Examples::
  1144. >>> rnn = nn.GRU(10, 20, 2)
  1145. >>> input = torch.randn(5, 3, 10)
  1146. >>> h0 = torch.randn(2, 3, 20)
  1147. >>> output, hn = rnn(input, h0)
  1148. """
  1149. @overload
  1150. def __init__(
  1151. self,
  1152. input_size: int,
  1153. hidden_size: int,
  1154. num_layers: int = 1,
  1155. bias: bool = True,
  1156. batch_first: bool = False,
  1157. dropout: float = 0.0,
  1158. bidirectional: bool = False,
  1159. device=None,
  1160. dtype=None,
  1161. ) -> None: ...
  1162. @overload
  1163. def __init__(self, *args, **kwargs) -> None: ...
  1164. def __init__(self, *args, **kwargs):
  1165. if "proj_size" in kwargs:
  1166. raise ValueError(
  1167. "proj_size argument is only supported for LSTM, not RNN or GRU"
  1168. )
  1169. super().__init__("GRU", *args, **kwargs)
  1170. @overload # type: ignore[override]
  1171. @torch._jit_internal._overload_method # noqa: F811
  1172. def forward(
  1173. self, input: Tensor, hx: Optional[Tensor] = None
  1174. ) -> tuple[Tensor, Tensor]: # noqa: F811
  1175. pass
  1176. @overload
  1177. @torch._jit_internal._overload_method # noqa: F811
  1178. def forward(
  1179. self, input: PackedSequence, hx: Optional[Tensor] = None
  1180. ) -> tuple[PackedSequence, Tensor]: # noqa: F811
  1181. pass
  1182. def forward(self, input, hx=None): # noqa: F811
  1183. self._update_flat_weights()
  1184. orig_input = input
  1185. # xxx: isinstance check needs to be in conditional for TorchScript to compile
  1186. if isinstance(orig_input, PackedSequence):
  1187. input, batch_sizes, sorted_indices, unsorted_indices = input
  1188. max_batch_size = batch_sizes[0]
  1189. if hx is None:
  1190. num_directions = 2 if self.bidirectional else 1
  1191. hx = torch.zeros(
  1192. self.num_layers * num_directions,
  1193. max_batch_size,
  1194. self.hidden_size,
  1195. dtype=input.dtype,
  1196. device=input.device,
  1197. )
  1198. else:
  1199. # Each batch of the hidden state should match the input sequence that
  1200. # the user believes he/she is passing in.
  1201. hx = self.permute_hidden(hx, sorted_indices)
  1202. else:
  1203. batch_sizes = None
  1204. if input.dim() not in (2, 3):
  1205. raise ValueError(
  1206. f"GRU: Expected input to be 2D or 3D, got {input.dim()}D instead"
  1207. )
  1208. is_batched = input.dim() == 3
  1209. batch_dim = 0 if self.batch_first else 1
  1210. if not is_batched:
  1211. input = input.unsqueeze(batch_dim)
  1212. if hx is not None:
  1213. if hx.dim() != 2:
  1214. raise RuntimeError(
  1215. f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor"
  1216. )
  1217. hx = hx.unsqueeze(1)
  1218. else:
  1219. if hx is not None and hx.dim() != 3:
  1220. raise RuntimeError(
  1221. f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor"
  1222. )
  1223. max_batch_size = input.size(0) if self.batch_first else input.size(1)
  1224. sorted_indices = None
  1225. unsorted_indices = None
  1226. if hx is None:
  1227. num_directions = 2 if self.bidirectional else 1
  1228. hx = torch.zeros(
  1229. self.num_layers * num_directions,
  1230. max_batch_size,
  1231. self.hidden_size,
  1232. dtype=input.dtype,
  1233. device=input.device,
  1234. )
  1235. else:
  1236. # Each batch of the hidden state should match the input sequence that
  1237. # the user believes he/she is passing in.
  1238. hx = self.permute_hidden(hx, sorted_indices)
  1239. self.check_forward_args(input, hx, batch_sizes)
  1240. if batch_sizes is None:
  1241. result = _VF.gru(
  1242. input,
  1243. hx,
  1244. self._flat_weights, # type: ignore[arg-type]
  1245. self.bias,
  1246. self.num_layers,
  1247. self.dropout,
  1248. self.training,
  1249. self.bidirectional,
  1250. self.batch_first,
  1251. )
  1252. else:
  1253. result = _VF.gru(
  1254. input,
  1255. batch_sizes,
  1256. hx,
  1257. self._flat_weights, # type: ignore[arg-type]
  1258. self.bias,
  1259. self.num_layers,
  1260. self.dropout,
  1261. self.training,
  1262. self.bidirectional,
  1263. )
  1264. output = result[0]
  1265. hidden = result[1]
  1266. # xxx: isinstance check needs to be in conditional for TorchScript to compile
  1267. if isinstance(orig_input, PackedSequence):
  1268. output_packed = PackedSequence(
  1269. output, batch_sizes, sorted_indices, unsorted_indices
  1270. )
  1271. return output_packed, self.permute_hidden(hidden, unsorted_indices)
  1272. else:
  1273. if not is_batched: # type: ignore[possibly-undefined]
  1274. output = output.squeeze(batch_dim) # type: ignore[possibly-undefined]
  1275. hidden = hidden.squeeze(1)
  1276. return output, self.permute_hidden(hidden, unsorted_indices)
  1277. class RNNCellBase(Module):
  1278. __constants__ = ["input_size", "hidden_size", "bias"]
  1279. input_size: int
  1280. hidden_size: int
  1281. bias: bool
  1282. weight_ih: Tensor
  1283. weight_hh: Tensor
  1284. # WARNING: bias_ih and bias_hh purposely not defined here.
  1285. # See https://github.com/pytorch/pytorch/issues/39670
  1286. def __init__(
  1287. self,
  1288. input_size: int,
  1289. hidden_size: int,
  1290. bias: bool,
  1291. num_chunks: int,
  1292. device=None,
  1293. dtype=None,
  1294. ) -> None:
  1295. factory_kwargs = {"device": device, "dtype": dtype}
  1296. super().__init__()
  1297. self.input_size = input_size
  1298. self.hidden_size = hidden_size
  1299. self.bias = bias
  1300. self.weight_ih = Parameter(
  1301. torch.empty((num_chunks * hidden_size, input_size), **factory_kwargs)
  1302. )
  1303. self.weight_hh = Parameter(
  1304. torch.empty((num_chunks * hidden_size, hidden_size), **factory_kwargs)
  1305. )
  1306. if bias:
  1307. self.bias_ih = Parameter(
  1308. torch.empty(num_chunks * hidden_size, **factory_kwargs)
  1309. )
  1310. self.bias_hh = Parameter(
  1311. torch.empty(num_chunks * hidden_size, **factory_kwargs)
  1312. )
  1313. else:
  1314. self.register_parameter("bias_ih", None)
  1315. self.register_parameter("bias_hh", None)
  1316. self.reset_parameters()
  1317. def extra_repr(self) -> str:
  1318. s = "{input_size}, {hidden_size}"
  1319. if "bias" in self.__dict__ and self.bias is not True:
  1320. s += ", bias={bias}"
  1321. if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh":
  1322. s += ", nonlinearity={nonlinearity}"
  1323. return s.format(**self.__dict__)
  1324. def reset_parameters(self) -> None:
  1325. stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0
  1326. for weight in self.parameters():
  1327. init.uniform_(weight, -stdv, stdv)
  1328. class RNNCell(RNNCellBase):
  1329. r"""An Elman RNN cell with tanh or ReLU non-linearity.
  1330. .. math::
  1331. h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh})
  1332. If :attr:`nonlinearity` is `'relu'`, then ReLU is used in place of tanh.
  1333. Args:
  1334. input_size: The number of expected features in the input `x`
  1335. hidden_size: The number of features in the hidden state `h`
  1336. bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
  1337. Default: ``True``
  1338. nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
  1339. Inputs: input, hidden
  1340. - **input**: tensor containing input features
  1341. - **hidden**: tensor containing the initial hidden state
  1342. Defaults to zero if not provided.
  1343. Outputs: h'
  1344. - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state
  1345. for each element in the batch
  1346. Shape:
  1347. - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where
  1348. :math:`H_{in}` = `input_size`.
  1349. - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden
  1350. state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided.
  1351. - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state.
  1352. Attributes:
  1353. weight_ih: the learnable input-hidden weights, of shape
  1354. `(hidden_size, input_size)`
  1355. weight_hh: the learnable hidden-hidden weights, of shape
  1356. `(hidden_size, hidden_size)`
  1357. bias_ih: the learnable input-hidden bias, of shape `(hidden_size)`
  1358. bias_hh: the learnable hidden-hidden bias, of shape `(hidden_size)`
  1359. .. note::
  1360. All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
  1361. where :math:`k = \frac{1}{\text{hidden\_size}}`
  1362. Examples::
  1363. >>> rnn = nn.RNNCell(10, 20)
  1364. >>> input = torch.randn(6, 3, 10)
  1365. >>> hx = torch.randn(3, 20)
  1366. >>> output = []
  1367. >>> for i in range(6):
  1368. ... hx = rnn(input[i], hx)
  1369. ... output.append(hx)
  1370. """
  1371. __constants__ = ["input_size", "hidden_size", "bias", "nonlinearity"]
  1372. nonlinearity: str
  1373. def __init__(
  1374. self,
  1375. input_size: int,
  1376. hidden_size: int,
  1377. bias: bool = True,
  1378. nonlinearity: str = "tanh",
  1379. device=None,
  1380. dtype=None,
  1381. ) -> None:
  1382. factory_kwargs = {"device": device, "dtype": dtype}
  1383. super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs)
  1384. self.nonlinearity = nonlinearity
  1385. def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
  1386. if input.dim() not in (1, 2):
  1387. raise ValueError(
  1388. f"RNNCell: Expected input to be 1D or 2D, got {input.dim()}D instead"
  1389. )
  1390. if hx is not None and hx.dim() not in (1, 2):
  1391. raise ValueError(
  1392. f"RNNCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead"
  1393. )
  1394. is_batched = input.dim() == 2
  1395. if not is_batched:
  1396. input = input.unsqueeze(0)
  1397. if hx is None:
  1398. hx = torch.zeros(
  1399. input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
  1400. )
  1401. else:
  1402. hx = hx.unsqueeze(0) if not is_batched else hx
  1403. if self.nonlinearity == "tanh":
  1404. ret = _VF.rnn_tanh_cell(
  1405. input,
  1406. hx,
  1407. self.weight_ih,
  1408. self.weight_hh,
  1409. self.bias_ih,
  1410. self.bias_hh,
  1411. )
  1412. elif self.nonlinearity == "relu":
  1413. ret = _VF.rnn_relu_cell(
  1414. input,
  1415. hx,
  1416. self.weight_ih,
  1417. self.weight_hh,
  1418. self.bias_ih,
  1419. self.bias_hh,
  1420. )
  1421. else:
  1422. ret = input # TODO: remove when jit supports exception flow
  1423. raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}")
  1424. if not is_batched:
  1425. ret = ret.squeeze(0)
  1426. return ret
  1427. class LSTMCell(RNNCellBase):
  1428. r"""A long short-term memory (LSTM) cell.
  1429. .. math::
  1430. \begin{array}{ll}
  1431. i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
  1432. f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
  1433. g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
  1434. o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
  1435. c' = f \odot c + i \odot g \\
  1436. h' = o \odot \tanh(c') \\
  1437. \end{array}
  1438. where :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.
  1439. Args:
  1440. input_size: The number of expected features in the input `x`
  1441. hidden_size: The number of features in the hidden state `h`
  1442. bias: If ``False``, then the layer does not use bias weights `b_ih` and
  1443. `b_hh`. Default: ``True``
  1444. Inputs: input, (h_0, c_0)
  1445. - **input** of shape `(batch, input_size)` or `(input_size)`: tensor containing input features
  1446. - **h_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial hidden state
  1447. - **c_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial cell state
  1448. If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.
  1449. Outputs: (h_1, c_1)
  1450. - **h_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next hidden state
  1451. - **c_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next cell state
  1452. Attributes:
  1453. weight_ih: the learnable input-hidden weights, of shape
  1454. `(4*hidden_size, input_size)`
  1455. weight_hh: the learnable hidden-hidden weights, of shape
  1456. `(4*hidden_size, hidden_size)`
  1457. bias_ih: the learnable input-hidden bias, of shape `(4*hidden_size)`
  1458. bias_hh: the learnable hidden-hidden bias, of shape `(4*hidden_size)`
  1459. .. note::
  1460. All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
  1461. where :math:`k = \frac{1}{\text{hidden\_size}}`
  1462. On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
  1463. Examples::
  1464. >>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size)
  1465. >>> input = torch.randn(2, 3, 10) # (time_steps, batch, input_size)
  1466. >>> hx = torch.randn(3, 20) # (batch, hidden_size)
  1467. >>> cx = torch.randn(3, 20)
  1468. >>> output = []
  1469. >>> for i in range(input.size()[0]):
  1470. ... hx, cx = rnn(input[i], (hx, cx))
  1471. ... output.append(hx)
  1472. >>> output = torch.stack(output, dim=0)
  1473. """
  1474. def __init__(
  1475. self,
  1476. input_size: int,
  1477. hidden_size: int,
  1478. bias: bool = True,
  1479. device=None,
  1480. dtype=None,
  1481. ) -> None:
  1482. factory_kwargs = {"device": device, "dtype": dtype}
  1483. super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)
  1484. def forward(
  1485. self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None
  1486. ) -> tuple[Tensor, Tensor]:
  1487. if input.dim() not in (1, 2):
  1488. raise ValueError(
  1489. f"LSTMCell: Expected input to be 1D or 2D, got {input.dim()}D instead"
  1490. )
  1491. if hx is not None:
  1492. for idx, value in enumerate(hx):
  1493. if value.dim() not in (1, 2):
  1494. raise ValueError(
  1495. f"LSTMCell: Expected hx[{idx}] to be 1D or 2D, got {value.dim()}D instead"
  1496. )
  1497. is_batched = input.dim() == 2
  1498. if not is_batched:
  1499. input = input.unsqueeze(0)
  1500. if hx is None:
  1501. zeros = torch.zeros(
  1502. input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
  1503. )
  1504. hx = (zeros, zeros)
  1505. else:
  1506. hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx
  1507. ret = _VF.lstm_cell(
  1508. input,
  1509. hx,
  1510. self.weight_ih,
  1511. self.weight_hh,
  1512. self.bias_ih,
  1513. self.bias_hh,
  1514. )
  1515. if not is_batched:
  1516. ret = (ret[0].squeeze(0), ret[1].squeeze(0))
  1517. return ret
  1518. class GRUCell(RNNCellBase):
  1519. r"""A gated recurrent unit (GRU) cell.
  1520. .. math::
  1521. \begin{array}{ll}
  1522. r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
  1523. z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
  1524. n = \tanh(W_{in} x + b_{in} + r \odot (W_{hn} h + b_{hn})) \\
  1525. h' = (1 - z) \odot n + z \odot h
  1526. \end{array}
  1527. where :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.
  1528. Args:
  1529. input_size: The number of expected features in the input `x`
  1530. hidden_size: The number of features in the hidden state `h`
  1531. bias: If ``False``, then the layer does not use bias weights `b_ih` and
  1532. `b_hh`. Default: ``True``
  1533. Inputs: input, hidden
  1534. - **input** : tensor containing input features
  1535. - **hidden** : tensor containing the initial hidden
  1536. state for each element in the batch.
  1537. Defaults to zero if not provided.
  1538. Outputs: h'
  1539. - **h'** : tensor containing the next hidden state
  1540. for each element in the batch
  1541. Shape:
  1542. - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where
  1543. :math:`H_{in}` = `input_size`.
  1544. - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden
  1545. state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided.
  1546. - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state.
  1547. Attributes:
  1548. weight_ih: the learnable input-hidden weights, of shape
  1549. `(3*hidden_size, input_size)`
  1550. weight_hh: the learnable hidden-hidden weights, of shape
  1551. `(3*hidden_size, hidden_size)`
  1552. bias_ih: the learnable input-hidden bias, of shape `(3*hidden_size)`
  1553. bias_hh: the learnable hidden-hidden bias, of shape `(3*hidden_size)`
  1554. .. note::
  1555. All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
  1556. where :math:`k = \frac{1}{\text{hidden\_size}}`
  1557. On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
  1558. Examples::
  1559. >>> rnn = nn.GRUCell(10, 20)
  1560. >>> input = torch.randn(6, 3, 10)
  1561. >>> hx = torch.randn(3, 20)
  1562. >>> output = []
  1563. >>> for i in range(6):
  1564. ... hx = rnn(input[i], hx)
  1565. ... output.append(hx)
  1566. """
  1567. def __init__(
  1568. self,
  1569. input_size: int,
  1570. hidden_size: int,
  1571. bias: bool = True,
  1572. device=None,
  1573. dtype=None,
  1574. ) -> None:
  1575. factory_kwargs = {"device": device, "dtype": dtype}
  1576. super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs)
  1577. def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
  1578. if input.dim() not in (1, 2):
  1579. raise ValueError(
  1580. f"GRUCell: Expected input to be 1D or 2D, got {input.dim()}D instead"
  1581. )
  1582. if hx is not None and hx.dim() not in (1, 2):
  1583. raise ValueError(
  1584. f"GRUCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead"
  1585. )
  1586. is_batched = input.dim() == 2
  1587. if not is_batched:
  1588. input = input.unsqueeze(0)
  1589. if hx is None:
  1590. hx = torch.zeros(
  1591. input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
  1592. )
  1593. else:
  1594. hx = hx.unsqueeze(0) if not is_batched else hx
  1595. ret = _VF.gru_cell(
  1596. input,
  1597. hx,
  1598. self.weight_ih,
  1599. self.weight_hh,
  1600. self.bias_ih,
  1601. self.bias_hh,
  1602. )
  1603. if not is_batched:
  1604. ret = ret.squeeze(0)
  1605. return ret