| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152 |
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import numbers
- from collections.abc import Mapping, Sequence
- import numpy as np
- import paddle
- FIELD_PREFIX = "_paddle_field_"  # Prefix of the placeholder strings (prefix + index into the flat list) that stand in for extracted arrays/tensors in a batch structure.
def _flatten_batch(batch):
    """
    Split a batch into a flat list of array/tensor fields and a structure.

    lod_blocking_queue only receives tensor arrays, so every
    ``np.ndarray`` / ``paddle.Tensor`` found in ``batch`` (including inside
    nested sequences and mappings) is moved into a flat list and replaced in
    the structure by the placeholder string ``f'{FIELD_PREFIX}{index}'``,
    where ``index`` is the item's position in the flat list. All other
    fields (str, bytes, numbers, arbitrary objects) are kept in the
    structure unchanged.

    Args:
        batch: A sequence or mapping of fields, or a single field.

    Returns:
        A ``(flat_batch, structure)`` tuple. ``flat_batch`` is the list of
        extracted arrays/tensors. ``structure`` mirrors ``batch`` with
        sequences rebuilt as lists, mappings rebuilt as dicts, and each
        extracted item replaced by its placeholder. For a single-field
        batch, ``structure`` is that field's entry itself rather than a
        one-element list.
    """

    def _flatten_field(field, flat_batch, field_idx):
        # Return (structure entry for ``field``, next free flat index).
        if isinstance(
            field,
            (np.ndarray, paddle.Tensor, paddle.base.core.eager.Tensor),
        ):
            flat_batch.append(field)
            return f'{FIELD_PREFIX}{field_idx}', field_idx + 1
        # str/bytes must be tested before Sequence: both are Sequences but
        # are stored as plain leaf values.
        if isinstance(field, (str, bytes, numbers.Number)):
            return field, field_idx
        if isinstance(field, Sequence):
            return _flatten(field, flat_batch, [], field_idx)
        if isinstance(field, Mapping):
            return _flatten(field, flat_batch, {}, field_idx)
        # Any other object is carried through the structure untouched.
        return field, field_idx

    def _flatten(batch, flat_batch, structure, field_idx):
        # Fill ``structure`` (a list or dict supplied by the caller) with
        # the entries for every field of ``batch``.
        if isinstance(batch, Sequence):
            for field in batch:
                entry, field_idx = _flatten_field(field, flat_batch, field_idx)
                structure.append(entry)
        elif isinstance(batch, Mapping):
            for k, field in batch.items():
                entry, field_idx = _flatten_field(field, flat_batch, field_idx)
                structure[k] = entry
        else:
            # Defensive: callers in this function only pass sequences or
            # mappings here.
            raise TypeError(f"wrong flat data type: {type(batch)}")
        return structure, field_idx

    # A sample that is a single field is wrapped in a list for flattening
    # and unwrapped afterwards. A bare str/bytes is a Sequence too, but it
    # is one field, not a collection of characters/ints, so it must take
    # this path as well.
    single_field = isinstance(batch, (str, bytes)) or not isinstance(
        batch, Sequence
    )

    flat_batch = []
    if single_field:
        structure, _ = _flatten([batch], flat_batch, [], 0)
        return flat_batch, structure[0]

    structure, _ = _flatten(batch, flat_batch, [], 0)
    return flat_batch, structure
- def _restore_batch(flat_batch, structure):
- """
- After reading list of Tensor data from lod_blocking_queue outputs,
- use this function to restore the batch data structure, replace
- :attr:`_paddle_field_x` with data from flat_batch
- """
- def _restore(structure, field_idx):
- if isinstance(structure, Sequence):
- for i, field in enumerate(structure):
- if isinstance(field, str) and field.startswith(FIELD_PREFIX):
- cur_field_idx = int(field.replace(FIELD_PREFIX, ''))
- field_idx = max(field_idx, cur_field_idx)
- assert (
- flat_batch[cur_field_idx] is not None
- ), "flat_batch[{}] parsed repeatly"
- structure[i] = flat_batch[cur_field_idx]
- flat_batch[cur_field_idx] = None
- elif isinstance(field, (str, bytes, numbers.Number)):
- continue
- elif isinstance(field, (Sequence, Mapping)):
- field_idx = _restore(structure[i], field_idx)
- elif isinstance(structure, Mapping):
- for k, field in structure.items():
- if isinstance(field, str) and field.startswith(FIELD_PREFIX):
- cur_field_idx = int(field.replace(FIELD_PREFIX, ''))
- field_idx = max(field_idx, cur_field_idx)
- assert (
- flat_batch[cur_field_idx] is not None
- ), "flat_batch[{}] parsed repeatly"
- structure[k] = flat_batch[cur_field_idx]
- flat_batch[cur_field_idx] = None
- elif isinstance(field, (str, bytes, numbers.Number)):
- continue
- elif isinstance(field, (Sequence, Mapping)):
- field_idx = _restore(structure[k], field_idx)
- else:
- raise TypeError(f"wrong flat data type: {type(structure)}")
- return field_idx
- assert isinstance(flat_batch, Sequence), "flat_batch is not a list or tuple"
- # no np.array in dataset, no output tensor from blocking queue
- # simply return structure
- if len(flat_batch) == 0:
- return structure
- # sample only contains single fields
- if isinstance(structure, (str, bytes)):
- assert (
- structure == f'{FIELD_PREFIX}{0}'
- ), f"invalid structure: {structure}"
- return flat_batch[0]
- field_idx = _restore(structure, 0)
- assert field_idx + 1 == len(flat_batch), "Tensor parse incomplete"
- return structure
|