random.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. # mypy: allow-untyped-defs
  2. from collections.abc import Iterable
  3. from typing import Union
  4. import torch
  5. from torch import Tensor
  6. from . import _lazy_call, _lazy_init, current_device, device_count, is_initialized
  7. __all__ = [
  8. "get_rng_state",
  9. "get_rng_state_all",
  10. "set_rng_state",
  11. "set_rng_state_all",
  12. "manual_seed",
  13. "manual_seed_all",
  14. "seed",
  15. "seed_all",
  16. "initial_seed",
  17. ]
  18. def get_rng_state(device: Union[int, str, torch.device] = "cuda") -> Tensor:
  19. r"""Return the random number generator state of the specified GPU as a ByteTensor.
  20. Args:
  21. device (torch.device or int, optional): The device to return the RNG state of.
  22. Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
  23. .. warning::
  24. This function eagerly initializes CUDA.
  25. """
  26. _lazy_init()
  27. if isinstance(device, str):
  28. device = torch.device(device)
  29. elif isinstance(device, int):
  30. device = torch.device("cuda", device)
  31. idx = device.index
  32. if idx is None:
  33. idx = current_device()
  34. default_generator = torch.cuda.default_generators[idx]
  35. return default_generator.get_state()
  36. def get_rng_state_all() -> list[Tensor]:
  37. r"""Return a list of ByteTensor representing the random number states of all devices."""
  38. results = [get_rng_state(i) for i in range(device_count())]
  39. return results
  40. def set_rng_state(
  41. new_state: Tensor, device: Union[int, str, torch.device] = "cuda"
  42. ) -> None:
  43. r"""Set the random number generator state of the specified GPU.
  44. Args:
  45. new_state (torch.ByteTensor): The desired state
  46. device (torch.device or int, optional): The device to set the RNG state.
  47. Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
  48. """
  49. if not is_initialized():
  50. with torch._C._DisableFuncTorch():
  51. # Clone the state because the callback will be triggered
  52. # later when CUDA is lazy initialized.
  53. new_state = new_state.clone(memory_format=torch.contiguous_format)
  54. if isinstance(device, str):
  55. device = torch.device(device)
  56. elif isinstance(device, int):
  57. device = torch.device("cuda", device)
  58. def cb():
  59. idx = device.index
  60. if idx is None:
  61. idx = current_device()
  62. default_generator = torch.cuda.default_generators[idx]
  63. default_generator.set_state(new_state)
  64. _lazy_call(cb)
  65. def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
  66. r"""Set the random number generator state of all devices.
  67. Args:
  68. new_states (Iterable of torch.ByteTensor): The desired state for each device.
  69. """
  70. for i, state in enumerate(new_states):
  71. set_rng_state(state, i)
  72. def manual_seed(seed: int) -> None:
  73. r"""Set the seed for generating random numbers for the current GPU.
  74. It's safe to call this function if CUDA is not available; in that
  75. case, it is silently ignored.
  76. Args:
  77. seed (int): The desired seed.
  78. .. warning::
  79. If you are working with a multi-GPU model, this function is insufficient
  80. to get determinism. To seed all GPUs, use :func:`manual_seed_all`.
  81. """
  82. seed = int(seed)
  83. def cb():
  84. idx = current_device()
  85. default_generator = torch.cuda.default_generators[idx]
  86. default_generator.manual_seed(seed)
  87. _lazy_call(cb, seed=True)
  88. def manual_seed_all(seed: int) -> None:
  89. r"""Set the seed for generating random numbers on all GPUs.
  90. It's safe to call this function if CUDA is not available; in that
  91. case, it is silently ignored.
  92. Args:
  93. seed (int): The desired seed.
  94. """
  95. seed = int(seed)
  96. def cb():
  97. for i in range(device_count()):
  98. default_generator = torch.cuda.default_generators[i]
  99. default_generator.manual_seed(seed)
  100. _lazy_call(cb, seed_all=True)
  101. def seed() -> None:
  102. r"""Set the seed for generating random numbers to a random number for the current GPU.
  103. It's safe to call this function if CUDA is not available; in that
  104. case, it is silently ignored.
  105. .. warning::
  106. If you are working with a multi-GPU model, this function will only initialize
  107. the seed on one GPU. To initialize all GPUs, use :func:`seed_all`.
  108. """
  109. def cb():
  110. idx = current_device()
  111. default_generator = torch.cuda.default_generators[idx]
  112. default_generator.seed()
  113. _lazy_call(cb)
  114. def seed_all() -> None:
  115. r"""Set the seed for generating random numbers to a random number on all GPUs.
  116. It's safe to call this function if CUDA is not available; in that
  117. case, it is silently ignored.
  118. """
  119. def cb():
  120. random_seed = 0
  121. seeded = False
  122. for i in range(device_count()):
  123. default_generator = torch.cuda.default_generators[i]
  124. if not seeded:
  125. default_generator.seed()
  126. random_seed = default_generator.initial_seed()
  127. seeded = True
  128. else:
  129. default_generator.manual_seed(random_seed)
  130. _lazy_call(cb)
  131. def initial_seed() -> int:
  132. r"""Return the current random seed of the current GPU.
  133. .. warning::
  134. This function eagerly initializes CUDA.
  135. """
  136. _lazy_init()
  137. idx = current_device()
  138. default_generator = torch.cuda.default_generators[idx]
  139. return default_generator.initial_seed()