tensor_utils.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. # Part of the implementation is borrowed from huggingface/transformers.
  3. from collections.abc import Mapping
  4. def torch_nested_numpify(tensors):
  5. """ Numpify nested torch tensors.
  6. NOTE: If the type of input tensors is dict-like(Mapping, dict, OrderedDict, etc.), the return type will be dict.
  7. Args:
  8. tensors: Nested torch tensors.
  9. Returns:
  10. The numpify tensors.
  11. """
  12. import torch
  13. "Numpify `tensors` (even if it's a nested list/tuple of tensors)."
  14. if isinstance(tensors, (list, tuple)):
  15. return type(tensors)(torch_nested_numpify(t) for t in tensors)
  16. if isinstance(tensors, Mapping):
  17. # return dict
  18. return {k: torch_nested_numpify(t) for k, t in tensors.items()}
  19. if isinstance(tensors, torch.Tensor):
  20. t = tensors.cpu()
  21. return t.numpy()
  22. return tensors
  23. def torch_nested_detach(tensors):
  24. """ Detach nested torch tensors.
  25. NOTE: If the type of input tensors is dict-like(Mapping, dict, OrderedDict, etc.), the return type will be dict.
  26. Args:
  27. tensors: Nested torch tensors.
  28. Returns:
  29. The detached tensors.
  30. """
  31. import torch
  32. "Detach `tensors` (even if it's a nested list/tuple of tensors)."
  33. if isinstance(tensors, (list, tuple)):
  34. return type(tensors)(torch_nested_detach(t) for t in tensors)
  35. if isinstance(tensors, Mapping):
  36. return {k: torch_nested_detach(t) for k, t in tensors.items()}
  37. if isinstance(tensors, torch.Tensor):
  38. return tensors.detach()
  39. return tensors