random.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. # mypy: allow-untyped-defs
  2. from collections.abc import Iterable
  3. from typing import Union
  4. import torch
  5. from torch import Tensor
  6. from . import _lazy_call, _lazy_init, current_device, device_count
  7. def get_rng_state(device: Union[int, str, torch.device] = "xpu") -> Tensor:
  8. r"""Return the random number generator state of the specified GPU as a ByteTensor.
  9. Args:
  10. device (torch.device or int, optional): The device to return the RNG state of.
  11. Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).
  12. .. warning::
  13. This function eagerly initializes XPU.
  14. """
  15. _lazy_init()
  16. if isinstance(device, str):
  17. device = torch.device(device)
  18. elif isinstance(device, int):
  19. device = torch.device("xpu", device)
  20. idx = device.index
  21. if idx is None:
  22. idx = current_device()
  23. default_generator = torch.xpu.default_generators[idx]
  24. return default_generator.get_state()
  25. def get_rng_state_all() -> list[Tensor]:
  26. r"""Return a list of ByteTensor representing the random number states of all devices."""
  27. results = [get_rng_state(i) for i in range(device_count())]
  28. return results
  29. def set_rng_state(
  30. new_state: Tensor, device: Union[int, str, torch.device] = "xpu"
  31. ) -> None:
  32. r"""Set the random number generator state of the specified GPU.
  33. Args:
  34. new_state (torch.ByteTensor): The desired state
  35. device (torch.device or int, optional): The device to set the RNG state.
  36. Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).
  37. """
  38. with torch._C._DisableFuncTorch():
  39. new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
  40. if isinstance(device, str):
  41. device = torch.device(device)
  42. elif isinstance(device, int):
  43. device = torch.device("xpu", device)
  44. def cb():
  45. idx = device.index
  46. if idx is None:
  47. idx = current_device()
  48. default_generator = torch.xpu.default_generators[idx]
  49. default_generator.set_state(new_state_copy)
  50. _lazy_call(cb)
  51. def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
  52. r"""Set the random number generator state of all devices.
  53. Args:
  54. new_states (Iterable of torch.ByteTensor): The desired state for each device.
  55. """
  56. for i, state in enumerate(new_states):
  57. set_rng_state(state, i)
  58. def manual_seed(seed: int) -> None:
  59. r"""Set the seed for generating random numbers for the current GPU.
  60. It's safe to call this function if XPU is not available; in that case, it is silently ignored.
  61. Args:
  62. seed (int): The desired seed.
  63. .. warning::
  64. If you are working with a multi-GPU model, this function is insufficient
  65. to get determinism. To seed all GPUs, use :func:`manual_seed_all`.
  66. """
  67. seed = int(seed)
  68. def cb():
  69. idx = current_device()
  70. default_generator = torch.xpu.default_generators[idx]
  71. default_generator.manual_seed(seed)
  72. _lazy_call(cb, seed=True)
  73. def manual_seed_all(seed: int) -> None:
  74. r"""Set the seed for generating random numbers on all GPUs.
  75. It's safe to call this function if XPU is not available; in that case, it is silently ignored.
  76. Args:
  77. seed (int): The desired seed.
  78. """
  79. seed = int(seed)
  80. def cb():
  81. for i in range(device_count()):
  82. default_generator = torch.xpu.default_generators[i]
  83. default_generator.manual_seed(seed)
  84. _lazy_call(cb, seed_all=True)
  85. def seed() -> None:
  86. r"""Set the seed for generating random numbers to a random number for the current GPU.
  87. It's safe to call this function if XPU is not available; in that case, it is silently ignored.
  88. .. warning::
  89. If you are working with a multi-GPU model, this function will only initialize
  90. the seed on one GPU. To initialize all GPUs, use :func:`seed_all`.
  91. """
  92. def cb():
  93. idx = current_device()
  94. default_generator = torch.xpu.default_generators[idx]
  95. default_generator.seed()
  96. _lazy_call(cb)
  97. def seed_all() -> None:
  98. r"""Set the seed for generating random numbers to a random number on all GPUs.
  99. It's safe to call this function if XPU is not available; in that case, it is silently ignored.
  100. """
  101. def cb():
  102. random_seed = 0
  103. seeded = False
  104. for i in range(device_count()):
  105. default_generator = torch.xpu.default_generators[i]
  106. if not seeded:
  107. default_generator.seed()
  108. random_seed = default_generator.initial_seed()
  109. seeded = True
  110. else:
  111. default_generator.manual_seed(random_seed)
  112. _lazy_call(cb)
  113. def initial_seed() -> int:
  114. r"""Return the current random seed of the current GPU.
  115. .. warning::
  116. This function eagerly initializes XPU.
  117. """
  118. _lazy_init()
  119. idx = current_device()
  120. default_generator = torch.xpu.default_generators[idx]
  121. return default_generator.initial_seed()
# Public API of this module: the per-device and all-device RNG state and
# seeding helpers defined above.
__all__ = [
    "get_rng_state",
    "get_rng_state_all",
    "set_rng_state",
    "set_rng_state_all",
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "initial_seed",
]