# Copyright (c) Alibaba, Inc. and its affiliates. # Part of the implementation is borrowed from huggingface/transformers. from collections.abc import Mapping def torch_nested_numpify(tensors): """ Numpify nested torch tensors. NOTE: If the type of input tensors is dict-like(Mapping, dict, OrderedDict, etc.), the return type will be dict. Args: tensors: Nested torch tensors. Returns: The numpify tensors. """ import torch "Numpify `tensors` (even if it's a nested list/tuple of tensors)." if isinstance(tensors, (list, tuple)): return type(tensors)(torch_nested_numpify(t) for t in tensors) if isinstance(tensors, Mapping): # return dict return {k: torch_nested_numpify(t) for k, t in tensors.items()} if isinstance(tensors, torch.Tensor): t = tensors.cpu() return t.numpy() return tensors def torch_nested_detach(tensors): """ Detach nested torch tensors. NOTE: If the type of input tensors is dict-like(Mapping, dict, OrderedDict, etc.), the return type will be dict. Args: tensors: Nested torch tensors. Returns: The detached tensors. """ import torch "Detach `tensors` (even if it's a nested list/tuple of tensors)." if isinstance(tensors, (list, tuple)): return type(tensors)(torch_nested_detach(t) for t in tensors) if isinstance(tensors, Mapping): return {k: torch_nested_detach(t) for k, t in tensors.items()} if isinstance(tensors, torch.Tensor): return tensors.detach() return tensors