| 12345678910111213141516171819202122232425262728293031323334353637 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- from collections.abc import Mapping
- import torch
- from modelscope.outputs import ModelOutputBase
def to_device(batch, device, non_blocking=False):
    """Recursively move every tensor found in ``batch`` to ``device``.

    Args:
        batch: The batch data out of the dataloader. May be a
            ``torch.Tensor``, a ``ModelOutputBase``, a mapping, a
            tuple/list (including namedtuples), or any nesting of these.
            Values of any other type are returned unchanged.
        device (str | torch.device): The target device for the data.
        non_blocking (bool): Forwarded to ``torch.Tensor.to`` for every
            tensor that is moved, including tensors nested inside
            containers.

    Returns:
        The data on the target device. ``ModelOutputBase`` instances and
        mappings that support item assignment are updated in place and
        returned as the same object; read-only mappings, tuples and lists
        are rebuilt as a new object of the same type.
    """
    if isinstance(batch, ModelOutputBase):
        # Update slot by slot so the output object itself (and any extra
        # attributes it carries) is preserved.
        for idx in range(len(batch)):
            batch[idx] = to_device(batch[idx], device, non_blocking)
        return batch

    # ``dict`` is itself a ``Mapping``, so a single check covers both.
    if isinstance(batch, Mapping):
        if hasattr(batch, '__setitem__'):
            # Reuse mini-batch to keep attributes for prediction.
            # Only existing keys are reassigned, so the mapping's size does
            # not change while it is being iterated.
            for key, value in batch.items():
                batch[key] = to_device(value, device, non_blocking)
            return batch
        return type(batch)({
            key: to_device(value, device, non_blocking)
            for key, value in batch.items()
        })

    if isinstance(batch, (tuple, list)):
        moved = (to_device(item, device, non_blocking) for item in batch)
        if isinstance(batch, tuple) and hasattr(batch, '_fields'):
            # namedtuple constructors take one positional argument per
            # field rather than a single iterable.
            return type(batch)(*moved)
        return type(batch)(moved)

    if isinstance(batch, torch.Tensor):
        return batch.to(device, non_blocking=non_blocking)

    # Non-tensor leaves (numbers, strings, None, ...) pass through untouched.
    return batch
|