data_utils.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from collections.abc import Mapping
  3. import torch
  4. from modelscope.outputs import ModelOutputBase
  5. def to_device(batch, device, non_blocking=False):
  6. """Put the data to the target cuda device just before the forward function.
  7. Args:
  8. batch: The batch data out of the dataloader.
  9. device: (str | torch.device): The target device for the data.
  10. Returns: The data to the target device.
  11. """
  12. if isinstance(batch, ModelOutputBase):
  13. for idx in range(len(batch)):
  14. batch[idx] = to_device(batch[idx], device)
  15. return batch
  16. elif isinstance(batch, dict) or isinstance(batch, Mapping):
  17. if hasattr(batch, '__setitem__'):
  18. # Reuse mini-batch to keep attributes for prediction.
  19. for k, v in batch.items():
  20. batch[k] = to_device(v, device)
  21. return batch
  22. else:
  23. return type(batch)(
  24. {k: to_device(v, device)
  25. for k, v in batch.items()})
  26. elif isinstance(batch, (tuple, list)):
  27. return type(batch)(to_device(v, device) for v in batch)
  28. elif isinstance(batch, torch.Tensor):
  29. return batch.to(device, non_blocking=non_blocking)
  30. else:
  31. return batch