collate.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import List
  3. import numpy as np
  4. import torch
  5. def collate_fn(samples, pad_idx, eos_idx):
  6. r"""
  7. convert the sample to batch tensor.
  8. """
  9. if len(samples) == 0:
  10. return {}
  11. def merge(key):
  12. return collate_tokens([s[key] for s in samples],
  13. pad_idx,
  14. eos_idx=eos_idx)
  15. batch = {
  16. 'nsentences': len(samples),
  17. 'net_input': {},
  18. }
  19. if samples[0].get('source', None) is not None:
  20. batch['net_input']['input_ids'] = merge('source')
  21. if samples[0].get('id', None) is not None:
  22. batch['id'] = np.array([s.get('id') for s in samples])
  23. if samples[0].get('target', None) is not None:
  24. batch['target'] = merge('target')
  25. tgt_lengths = torch.LongTensor(
  26. [s['target'].ne(pad_idx).long().sum() for s in samples])
  27. ntokens = tgt_lengths.sum().item()
  28. batch['ntokens'] = ntokens
  29. if samples[0].get('prev_output_tokens', None) is not None:
  30. batch['net_input']['decoder_input_ids'] = merge('prev_output_tokens')
  31. if samples[0].get('patch_image', None) is not None:
  32. batch['net_input']['patch_images'] = torch.stack(
  33. [sample['patch_image'] for sample in samples], dim=0)
  34. if samples[0].get('patch_mask', None) is not None:
  35. batch['net_input']['patch_masks'] = torch.cat(
  36. [sample['patch_mask'] for sample in samples])
  37. # image generation
  38. if samples[0].get('code_mask', None) is not None:
  39. batch['net_input']['code_masks'] = torch.cat(
  40. [sample['code_mask'] for sample in samples])
  41. if samples[0].get('code_image', None) is not None:
  42. batch['code_images'] = torch.cat(
  43. [sample['code_image'] for sample in samples])
  44. # For classification tasks (i.e., VQA, SNLI-VE, GLUE)
  45. if samples[0].get('conf', None) is not None:
  46. batch['conf'] = torch.cat([s['conf'] for s in samples], dim=0)
  47. if samples[0].get('ref_dict', None) is not None:
  48. batch['ref_dict'] = np.array([s['ref_dict'] for s in samples])
  49. if samples[0].get('label', None) is not None:
  50. batch['labels'] = np.array([s['label'] for s in samples]).tolist()
  51. if samples[0].get('constraint_mask', None) is not None:
  52. batch['constraint_masks'] = merge('constraint_mask')
  53. if samples[0].get('decoder_prompt', None) is not None:
  54. batch['decoder_prompts'] = np.array(
  55. [s['decoder_prompt'].tolist() for s in samples])
  56. if samples[0].get('prefix_token', None) is not None:
  57. batch['prefix_tokens'] = merge('prefix_token')
  58. # For detection and visual grounding
  59. if samples[0].get('w_resize_ratio', None) is not None:
  60. batch['w_resize_ratios'] = torch.stack(
  61. [s['w_resize_ratio'] for s in samples], dim=0)
  62. if samples[0].get('h_resize_ratio', None) is not None:
  63. batch['h_resize_ratios'] = torch.stack(
  64. [s['h_resize_ratio'] for s in samples], dim=0)
  65. if samples[0].get('region_coord', None) is not None:
  66. batch['region_coords'] = torch.stack(
  67. [s['region_coord'] for s in samples], dim=0)
  68. if samples[0].get('sample', None) is not None:
  69. batch['samples'] = [s['sample'] for s in samples]
  70. # For asr
  71. if samples[0].get('fbank', None) is not None:
  72. batch['net_input']['fbank'] = _collate_frames(
  73. [s['fbank'] for s in samples])
  74. batch['net_input']['fbank_length'] = torch.tensor(
  75. [s['fbank'].size(0) for s in samples], dtype=torch.long)
  76. if samples[0].get('fbank_mask', None) is not None:
  77. batch['net_input']['fbank_masks'] = torch.cat(
  78. [s['fbank_mask'] for s in samples])
  79. if samples[0].get('phone_item', None) is not None:
  80. batch['net_input']['phone_items'] = merge('phone_item')
  81. batch['net_input']['phone_masks'] = torch.cat(
  82. [s['phone_mask'] for s in samples])
  83. if samples[0].get('phone_target', None) is not None:
  84. batch['phone_target'] = merge('phone_target')
  85. batch['phone_length'] = torch.tensor(
  86. [s['phone_target'].size(0) for s in samples], dtype=torch.long)
  87. # for sudoku
  88. if samples[0].get('db_struct', None) is not None:
  89. db_struct = [sample['db_struct'] for sample in samples]
  90. batch['db_struct'] = db_struct
  91. if samples[0].get('mask_ratio', None) is not None:
  92. mask_ratio = [sample['mask_ratio'] for sample in samples]
  93. batch['mask_ratio'] = mask_ratio
  94. if samples[0].get('seg_col_tokens', None) is not None:
  95. seg_col_tokens = merge('seg_col_tokens')
  96. batch['net_input']['seg_col_tokens'] = seg_col_tokens
  97. if samples[0].get('seg_row_tokens', None) is not None:
  98. seg_row_tokens = merge('seg_row_tokens')
  99. batch['net_input']['seg_row_tokens'] = seg_row_tokens
  100. return batch
  101. def collate_tokens(
  102. values,
  103. pad_idx,
  104. eos_idx=None,
  105. left_pad=False,
  106. move_eos_to_beginning=False,
  107. pad_to_length=None,
  108. pad_to_multiple=1,
  109. pad_to_bsz=None,
  110. ):
  111. """Convert a list of 1d tensors into a padded 2d tensor."""
  112. size = max(v.size(0) for v in values)
  113. size = size if pad_to_length is None else max(size, pad_to_length)
  114. if pad_to_multiple != 1 and size % pad_to_multiple != 0:
  115. size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
  116. def copy_tensor(src, dst):
  117. assert dst.numel() == src.numel()
  118. if move_eos_to_beginning:
  119. if eos_idx is None:
  120. # if no eos_idx is specified, then use the last token in src
  121. dst[0] = src[-1]
  122. else:
  123. dst[0] = eos_idx
  124. dst[1:] = src[:-1]
  125. else:
  126. dst.copy_(src)
  127. if values[0].dim() == 1:
  128. res = values[0].new(len(values), size).fill_(pad_idx)
  129. elif values[0].dim() == 2:
  130. assert move_eos_to_beginning is False
  131. res = values[0].new(len(values), size,
  132. values[0].size(1)).fill_(pad_idx)
  133. else:
  134. raise NotImplementedError
  135. for i, v in enumerate(values):
  136. copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
  137. return res
  138. def _collate_frames(frames: List[torch.Tensor]):
  139. """
  140. Convert a list of 2D frames into a padded 3D tensor
  141. Args:
  142. frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
  143. length of i-th frame and f_dim is static dimension of features
  144. Returns:
  145. 3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
  146. """
  147. max_len = max(frame.size(0) for frame in frames)
  148. out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
  149. for i, v in enumerate(frames):
  150. out[i, :v.size(0)] = v
  151. return out