_utils.py 968 B

12345678910111213141516171819202122232425262728
  1. from typing import Optional
  2. import torch
  3. from torch.types import Device as _device_t
  4. def _get_device_index(device: _device_t, optional: bool = False) -> int:
  5. if isinstance(device, int):
  6. return device
  7. if isinstance(device, str):
  8. device = torch.device(device)
  9. device_index: Optional[int] = None
  10. if isinstance(device, torch.device):
  11. acc = torch.accelerator.current_accelerator()
  12. if acc is None:
  13. raise RuntimeError("Accelerator expected")
  14. if acc.type != device.type:
  15. raise ValueError(
  16. f"{device.type} doesn't match the current accelerator {acc}."
  17. )
  18. device_index = device.index
  19. if device_index is None:
  20. if not optional:
  21. raise ValueError(
  22. f"Expected a torch.device with a specified index or an integer, but got:{device}"
  23. )
  24. return torch.accelerator.current_device_index()
  25. return device_index