random.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import random
  15. from typing import Optional, Union
  16. import numpy as np
  17. import torch
  18. from ..state import AcceleratorState
  19. from .constants import CUDA_DISTRIBUTED_TYPES
  20. from .dataclasses import DistributedType, RNGType
  21. from .imports import (
  22. is_hpu_available,
  23. is_mlu_available,
  24. is_musa_available,
  25. is_npu_available,
  26. is_sdaa_available,
  27. is_torch_xla_available,
  28. is_xpu_available,
  29. )
  30. if is_torch_xla_available():
  31. import torch_xla.core.xla_model as xm
  32. def set_seed(seed: int, device_specific: bool = False, deterministic: bool = False):
  33. """
  34. Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
  35. Args:
  36. seed (`int`):
  37. The seed to set.
  38. device_specific (`bool`, *optional*, defaults to `False`):
  39. Whether to differ the seed on each device slightly with `self.process_index`.
  40. deterministic (`bool`, *optional*, defaults to `False`):
  41. Whether to use deterministic algorithms where available. Can slow down training.
  42. """
  43. if device_specific:
  44. seed += AcceleratorState().process_index
  45. random.seed(seed)
  46. np.random.seed(seed)
  47. torch.manual_seed(seed)
  48. if is_xpu_available():
  49. torch.xpu.manual_seed_all(seed)
  50. elif is_npu_available():
  51. torch.npu.manual_seed_all(seed)
  52. elif is_mlu_available():
  53. torch.mlu.manual_seed_all(seed)
  54. elif is_sdaa_available():
  55. torch.sdaa.manual_seed_all(seed)
  56. elif is_musa_available():
  57. torch.musa.manual_seed_all(seed)
  58. elif is_hpu_available():
  59. torch.hpu.manual_seed_all(seed)
  60. else:
  61. torch.cuda.manual_seed_all(seed)
  62. # ^^ safe to call this function even if cuda is not available
  63. if is_torch_xla_available():
  64. xm.set_rng_state(seed)
  65. if deterministic:
  66. torch.use_deterministic_algorithms(True)
  67. def synchronize_rng_state(rng_type: Optional[RNGType] = None, generator: Optional[torch.Generator] = None):
  68. # Get the proper rng state
  69. if rng_type == RNGType.TORCH:
  70. rng_state = torch.get_rng_state()
  71. elif rng_type == RNGType.CUDA:
  72. rng_state = torch.cuda.get_rng_state()
  73. elif rng_type == RNGType.XLA:
  74. assert is_torch_xla_available(), "Can't synchronize XLA seeds as torch_xla is unavailable."
  75. rng_state = torch.tensor(xm.get_rng_state())
  76. elif rng_type == RNGType.NPU:
  77. assert is_npu_available(), "Can't synchronize NPU seeds on an environment without NPUs."
  78. rng_state = torch.npu.get_rng_state()
  79. elif rng_type == RNGType.MLU:
  80. assert is_mlu_available(), "Can't synchronize MLU seeds on an environment without MLUs."
  81. rng_state = torch.mlu.get_rng_state()
  82. elif rng_type == RNGType.SDAA:
  83. assert is_sdaa_available(), "Can't synchronize SDAA seeds on an environment without SDAAs."
  84. rng_state = torch.sdaa.get_rng_state()
  85. elif rng_type == RNGType.MUSA:
  86. assert is_musa_available(), "Can't synchronize MUSA seeds on an environment without MUSAs."
  87. rng_state = torch.musa.get_rng_state()
  88. elif rng_type == RNGType.XPU:
  89. assert is_xpu_available(), "Can't synchronize XPU seeds on an environment without XPUs."
  90. rng_state = torch.xpu.get_rng_state()
  91. elif rng_type == RNGType.HPU:
  92. assert is_hpu_available(), "Can't synchronize HPU seeds on an environment without HPUs."
  93. rng_state = torch.hpu.get_rng_state()
  94. elif rng_type == RNGType.GENERATOR:
  95. assert generator is not None, "Need a generator to synchronize its seed."
  96. rng_state = generator.get_state()
  97. # Broadcast the rng state from device 0 to other devices
  98. state = AcceleratorState()
  99. if state.distributed_type == DistributedType.XLA:
  100. rng_state = rng_state.to(xm.xla_device())
  101. xm.collective_broadcast([rng_state])
  102. xm.mark_step()
  103. rng_state = rng_state.cpu()
  104. elif (
  105. state.distributed_type in CUDA_DISTRIBUTED_TYPES
  106. or state.distributed_type == DistributedType.MULTI_MLU
  107. or state.distributed_type == DistributedType.MULTI_SDAA
  108. or state.distributed_type == DistributedType.MULTI_MUSA
  109. or state.distributed_type == DistributedType.MULTI_NPU
  110. or state.distributed_type == DistributedType.MULTI_XPU
  111. or state.distributed_type == DistributedType.MULTI_HPU
  112. ):
  113. rng_state = rng_state.to(state.device)
  114. torch.distributed.broadcast(rng_state, 0)
  115. rng_state = rng_state.cpu()
  116. elif state.distributed_type == DistributedType.MULTI_CPU:
  117. torch.distributed.broadcast(rng_state, 0)
  118. # Set the broadcast rng state
  119. if rng_type == RNGType.TORCH:
  120. torch.set_rng_state(rng_state)
  121. elif rng_type == RNGType.CUDA:
  122. torch.cuda.set_rng_state(rng_state)
  123. elif rng_type == RNGType.NPU:
  124. torch.npu.set_rng_state(rng_state)
  125. elif rng_type == RNGType.MLU:
  126. torch.mlu.set_rng_state(rng_state)
  127. elif rng_type == RNGType.SDAA:
  128. torch.sdaa.set_rng_state(rng_state)
  129. elif rng_type == RNGType.MUSA:
  130. torch.musa.set_rng_state(rng_state)
  131. elif rng_type == RNGType.XPU:
  132. torch.xpu.set_rng_state(rng_state)
  133. elif rng_state == RNGType.HPU:
  134. torch.hpu.set_rng_state(rng_state)
  135. elif rng_type == RNGType.XLA:
  136. xm.set_rng_state(rng_state.item())
  137. elif rng_type == RNGType.GENERATOR:
  138. generator.set_state(rng_state)
  139. def synchronize_rng_states(rng_types: list[Union[str, RNGType]], generator: Optional[torch.Generator] = None):
  140. for rng_type in rng_types:
  141. synchronize_rng_state(RNGType(rng_type), generator=generator)