rnn.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. # mypy: allow-untyped-defs
  2. import torch.cuda
  3. try:
  4. from torch._C import _cudnn
  5. except ImportError:
  6. # Uses of all the functions below should be guarded by torch.backends.cudnn.is_available(),
  7. # so it's safe to not emit any checks here.
  8. _cudnn = None # type: ignore[assignment]
  9. def get_cudnn_mode(mode):
  10. if mode == "RNN_RELU":
  11. return int(_cudnn.RNNMode.rnn_relu)
  12. elif mode == "RNN_TANH":
  13. return int(_cudnn.RNNMode.rnn_tanh)
  14. elif mode == "LSTM":
  15. return int(_cudnn.RNNMode.lstm)
  16. elif mode == "GRU":
  17. return int(_cudnn.RNNMode.gru)
  18. else:
  19. raise Exception(f"Unknown mode: {mode}") # noqa: TRY002
  20. # NB: We don't actually need this class anymore (in fact, we could serialize the
  21. # dropout state for even better reproducibility), but it is kept for backwards
  22. # compatibility for old models.
  23. class Unserializable:
  24. def __init__(self, inner):
  25. self.inner = inner
  26. def get(self):
  27. return self.inner
  28. def __getstate__(self):
  29. # Note: can't return {}, because python2 won't call __setstate__
  30. # if the value evaluates to False
  31. return "<unserializable>"
  32. def __setstate__(self, state):
  33. self.inner = None
  34. def init_dropout_state(dropout, train, dropout_seed, dropout_state):
  35. dropout_desc_name = "desc_" + str(torch.cuda.current_device())
  36. dropout_p = dropout if train else 0
  37. if (dropout_desc_name not in dropout_state) or (
  38. dropout_state[dropout_desc_name].get() is None
  39. ):
  40. if dropout_p == 0:
  41. dropout_state[dropout_desc_name] = Unserializable(None)
  42. else:
  43. dropout_state[dropout_desc_name] = Unserializable(
  44. torch._cudnn_init_dropout_state( # type: ignore[call-arg]
  45. dropout_p,
  46. train,
  47. dropout_seed,
  48. self_ty=torch.uint8,
  49. device=torch.device("cuda"),
  50. )
  51. )
  52. dropout_ts = dropout_state[dropout_desc_name].get()
  53. return dropout_ts