blocklm_utils.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620
  1. # Copyright (c) 2022 Zhipu.AI
  2. import copy
  3. import math
  4. import random
  5. import numpy as np
  6. import torch
  7. import torch.utils.data
  8. from megatron_util import mpu, print_rank_0
  9. from scipy.stats import poisson
  10. def rindex(lst, val, start=None):
  11. if start is None:
  12. start = len(lst) - 1
  13. for i in range(start, -1, -1):
  14. if lst[i] == val:
  15. return i
  16. return -1
  17. def index_in_list(lst, val, start=None):
  18. if start is None:
  19. start = 0
  20. for i in range(start, len(lst)):
  21. if lst[i] == val:
  22. return i
  23. return -1
  24. class ConstructBlockStrategy:
  25. def __init__(self,
  26. args,
  27. tokenizer,
  28. max_seq_length,
  29. bert_prob=1.0,
  30. gap_sentence_prob=0.0,
  31. gpt_infill_prob=0.5,
  32. gpt_min_ratio=0.5,
  33. bert_ratio=0.15,
  34. gap_sentence_ratio=0.15,
  35. average_block_length=3,
  36. max_block_length=40,
  37. block_mask_prob=0.0,
  38. context_mask_ratio=0.0,
  39. context_mask_range=3,
  40. short_seq_prob=0.0,
  41. single_span_prob=0.0,
  42. block_position_encoding=True,
  43. encoder_decoder=False,
  44. shuffle_blocks=True,
  45. sentinel_token=False,
  46. task_mask=False,
  47. random_position=False,
  48. masked_lm=False):
  49. self.eod_token = args.eod_token
  50. self.tokenizer = tokenizer
  51. self.count = 0
  52. self.max_seq_length = max_seq_length
  53. self.rank = mpu.get_data_parallel_rank()
  54. self.world_size = mpu.get_data_parallel_world_size()
  55. # self.rank = 0
  56. # self.world_size = 1
  57. assert 0.0 <= bert_prob <= 1.0
  58. self.bert_prob = bert_prob
  59. self.gap_sentence_prob = gap_sentence_prob
  60. self.gpt_prob = 1 - bert_prob - gap_sentence_prob
  61. assert self.gpt_prob >= -1e-10
  62. self.infill_prob = gpt_infill_prob
  63. self.gpt_min_ratio = gpt_min_ratio
  64. self.bert_ratio = bert_ratio
  65. self.gap_sentence_ratio = gap_sentence_ratio
  66. self.block_length_distribution = [
  67. poisson.pmf(i, average_block_length)
  68. for i in range(1, max_block_length)
  69. ]
  70. self.block_mask_prob = block_mask_prob
  71. self.context_mask_ratio = context_mask_ratio
  72. self.context_mask_range = context_mask_range
  73. self.short_seq_prob = short_seq_prob
  74. self.single_span_prob = single_span_prob
  75. self.block_position_encoding = block_position_encoding
  76. self.encoder_decoder = encoder_decoder
  77. self.shuffle_blocks = shuffle_blocks
  78. self.sentinel_token = sentinel_token
  79. self.generation_mask = 'gMASK' if task_mask else 'MASK'
  80. self.generation_mask = self.tokenizer.get_command(
  81. self.generation_mask).Id
  82. self.gap_sentence_mask = 'sMASK' if task_mask else 'MASK'
  83. self.gap_sentence_mask = self.tokenizer.get_command(
  84. self.gap_sentence_mask).Id
  85. self.random_position = random_position
  86. self.masked_lm = masked_lm
  87. print_rank_0(
  88. f'BERT prob {self.bert_prob}, gap sent prob {self.gap_sentence_prob}, GPT prob {self.gpt_prob}, infill prob {self.infill_prob}' # noqa
  89. )
  90. print_rank_0(
  91. f'generation min ratio {self.gpt_min_ratio}, block ratio {self.bert_ratio}, gap sent ratio {self.gap_sentence_ratio}' # noqa
  92. )
  93. print_rank_0(
  94. f'block length distribution {self.block_length_distribution}')
  95. print_rank_0(
  96. f'block mask prob {self.block_mask_prob}, context mask ratio {self.context_mask_ratio}'
  97. )
  98. def contains_sentence_end(self, tok):
  99. tok = self.tokenizer.IdToToken(tok)
  100. if '.' in tok:
  101. return True
  102. if '?' in tok:
  103. return True
  104. if '!' in tok:
  105. return True
  106. if ';' in tok:
  107. return True
  108. if ':' in tok:
  109. return True
  110. if '。' in tok:
  111. return True
  112. if '?' in tok:
  113. return True
  114. if '!' in tok:
  115. return True
  116. if ';' in tok:
  117. return True
  118. if '…' in tok:
  119. return True
  120. if '\n' in tok:
  121. return True
  122. return False
  123. @staticmethod
  124. def sample_spans(span_lengths, total_length, rng, offset=0):
  125. blank_length = total_length - sum(span_lengths)
  126. m = blank_length - len(span_lengths) + 1
  127. places = [rng.randrange(m + 1) for _ in range(len(span_lengths))]
  128. places.sort()
  129. spans = []
  130. for place, span_length in zip(places, span_lengths):
  131. start = offset + place
  132. end = offset + place + span_length
  133. spans.append((start, end))
  134. offset += span_length + 1
  135. return spans
  136. def sample_span_in_document(self, tokens, masked_lengths, rng):
  137. rng.shuffle(masked_lengths)
  138. mask_spans = []
  139. mask_index = 0
  140. indices = [-1] + np.where(tokens == self.eod_token)[0].tolist()
  141. last_index = len(tokens)
  142. documents = []
  143. for index in reversed(indices):
  144. start_index = index
  145. if start_index + 1 < len(tokens) and tokens[
  146. start_index + 1] == self.tokenizer.get_command('ENC').Id:
  147. start_index += 1
  148. length = last_index - start_index - 1
  149. if last_index == len(tokens) and length > 0:
  150. length -= 1
  151. documents.append((start_index + 1, length))
  152. last_index = index
  153. documents.sort(key=lambda x: x[1])
  154. for i, (offset, length) in enumerate(documents):
  155. if i == len(documents) - 1:
  156. current_masked_length, current_count = 0, 0
  157. while mask_index + current_count < len(
  158. masked_lengths
  159. ) and masked_lengths[
  160. mask_index + # noqa
  161. current_count] + current_masked_length + current_count <= length:
  162. current_masked_length += masked_lengths[mask_index
  163. + current_count]
  164. current_count += 1
  165. if current_count > 0:
  166. spans = self.sample_spans(
  167. masked_lengths[mask_index:mask_index + current_count],
  168. length,
  169. rng,
  170. offset=offset)
  171. mask_spans += spans
  172. if mask_index + current_count < len(masked_lengths) - 1:
  173. print(length, masked_lengths[mask_index:],
  174. masked_lengths[:mask_index], indices)
  175. else:
  176. current_masked_total = int(length * self.bert_ratio)
  177. current_masked_length, current_count = 0, 0
  178. while mask_index + current_count < len(
  179. masked_lengths
  180. ) and masked_lengths[
  181. mask_index + # noqa
  182. current_count] + current_masked_length <= current_masked_total:
  183. current_masked_length += masked_lengths[mask_index
  184. + current_count]
  185. current_count += 1
  186. if current_count > 0:
  187. spans = self.sample_spans(
  188. masked_lengths[mask_index:mask_index + current_count],
  189. length,
  190. rng,
  191. offset=offset)
  192. mask_spans += spans
  193. mask_index += current_count
  194. return mask_spans
  195. def make_masked_data(self,
  196. tokens,
  197. loss_masks,
  198. attention_mask,
  199. block_spans,
  200. rng,
  201. task='bert'):
  202. position_ids = np.arange(len(tokens), dtype=int)
  203. targets = copy.deepcopy(tokens)
  204. mask_id = self.tokenizer.get_command('MASK').Id
  205. mlm_masks = np.zeros(len(tokens), dtype=int)
  206. for start, end in block_spans:
  207. for idx in range(start, end):
  208. tokens[idx] = mask_id
  209. mlm_masks[start:end] = 1
  210. loss_masks = loss_masks * mlm_masks
  211. return tokens, targets, loss_masks, position_ids
  212. def make_block_data(self,
  213. tokens,
  214. loss_masks,
  215. attention_mask,
  216. block_spans,
  217. rng,
  218. task='bert'):
  219. text_length = len(tokens)
  220. position_ids = np.ones(len(tokens), dtype=int)
  221. for start, end in block_spans:
  222. position_ids[start + 1:end] = 0
  223. position_ids = np.cumsum(position_ids) - 1
  224. if self.random_position and position_ids[-1] < self.max_seq_length - 1:
  225. position_bias = self.max_seq_length - position_ids[-1]
  226. position_bias = rng.randrange(0, position_bias)
  227. position_ids = position_ids + position_bias
  228. if self.encoder_decoder or not self.shuffle_blocks:
  229. block_spans.sort(key=lambda x: x[0])
  230. else:
  231. rng.shuffle(block_spans)
  232. if self.sentinel_token:
  233. block_spans = [(start, end, idx)
  234. for idx, (start, end) in enumerate(block_spans)]
  235. else:
  236. block_spans = [(start, end, 0) for start, end in block_spans]
  237. target_tokens, target_position_ids, target_block_position_ids, targets = [], [], [], []
  238. for start, end, idx in block_spans:
  239. sop_token = 'sop' if idx == 0 else f'sop{idx}'
  240. target_tokens.append([self.tokenizer.get_command(sop_token).Id])
  241. span_tokens = copy.deepcopy(tokens[start:end])
  242. if self.block_mask_prob > 0.0 and task == 'bert':
  243. for sub_idx in range(len(span_tokens)):
  244. if random.random() < self.block_mask_prob:
  245. span_tokens[sub_idx] = self.tokenizer.get_command(
  246. 'dBLOCK').Id
  247. target_tokens.append(span_tokens)
  248. targets.append(tokens[start:end])
  249. targets.append([self.tokenizer.get_command('eop').Id])
  250. if not self.sentinel_token:
  251. target_position_id = position_ids[start:end]
  252. target_position_ids.append(target_position_id)
  253. target_position_ids.append([target_position_id[0]])
  254. else:
  255. target_position_ids.append([self.max_seq_length] * # noqa
  256. (end - start + 1))
  257. if self.block_position_encoding:
  258. target_block_position_ids.append(
  259. np.arange(1, end - start + 2, dtype=int))
  260. else:
  261. target_block_position_ids.append([1] * (end - start + 1))
  262. block_spans.sort(key=lambda x: x[0])
  263. source_tokens, source_position_ids, local_spans = [], [], []
  264. last, current_length = 0, 0
  265. for start, end, idx in block_spans:
  266. if task == 'generation':
  267. mask_id = self.generation_mask
  268. elif task == 'gap_sentence':
  269. mask_id = self.gap_sentence_mask
  270. else:
  271. mask_token = 'MASK' if idx == 0 else f'MASK{idx}'
  272. mask_id = self.tokenizer.get_command(mask_token).Id
  273. local_spans.append((current_length, current_length + start - last))
  274. source_tokens.append(tokens[last:start])
  275. source_tokens.append([mask_id])
  276. source_position_ids.append(position_ids[last:start])
  277. source_position_ids.append([position_ids[start]])
  278. current_length += start - last + 1
  279. last = end
  280. if last < len(tokens):
  281. local_spans.append(
  282. (current_length, current_length + len(tokens) - last))
  283. source_tokens.append(tokens[last:])
  284. source_position_ids.append(position_ids[last:])
  285. source_length = sum(map(len, source_tokens))
  286. if attention_mask is not None:
  287. assert source_length == attention_mask
  288. if target_tokens and self.eod_token in np.concatenate(
  289. target_tokens).tolist():
  290. print('Found EOS in target', self.tokenizer.DecodeIds(tokens))
  291. raise RuntimeError
  292. if self.encoder_decoder:
  293. target_tokens = target_tokens + [
  294. self.tokenizer.get_command('eop').Id
  295. ]
  296. loss_masks = np.ones(len(target_tokens), dtype=int)
  297. return source_tokens, target_tokens, loss_masks
  298. else:
  299. tokens = np.concatenate(source_tokens + target_tokens)
  300. if task == 'bert' and self.context_mask_ratio > 0:
  301. mask_candidates = set()
  302. for start, end in local_spans:
  303. if start != 0:
  304. local_end = min(end, start + self.context_mask_range)
  305. mask_candidates.update(range(start, local_end))
  306. if end != 0:
  307. local_start = max(start, end - self.context_mask_range)
  308. mask_candidates.update(range(local_start, end))
  309. mask_pos = rng.sample(
  310. mask_candidates,
  311. int(self.context_mask_ratio * text_length))
  312. for pos in mask_pos:
  313. tokens[pos] = self.tokenizer.get_command('dBLOCK').Id
  314. targets = np.concatenate(source_tokens + targets)
  315. loss_masks = np.ones(len(tokens), dtype=int)
  316. loss_masks[:source_length] = 0
  317. position_ids = np.concatenate(source_position_ids
  318. + target_position_ids)
  319. block_position_ids = np.concatenate(
  320. [np.zeros(source_length, dtype=int)]
  321. + target_block_position_ids)
  322. position_ids = np.stack([position_ids, block_position_ids], axis=0)
  323. if attention_mask is not None:
  324. return tokens, targets, loss_masks, position_ids
  325. else:
  326. return tokens, targets, loss_masks, position_ids, source_length
  327. def generate_blank_data(self,
  328. sample,
  329. masked_lengths,
  330. attention_mask,
  331. rng,
  332. task='bert'):
  333. rng.shuffle(masked_lengths)
  334. tokens, loss_masks = sample['text'], sample['loss_mask']
  335. assert tokens[0] == self.tokenizer.get_command('ENC').Id
  336. block_spans = self.sample_span_in_document(tokens, masked_lengths, rng)
  337. if len(block_spans) < len(masked_lengths):
  338. return None
  339. if self.masked_lm:
  340. data = self.make_masked_data(tokens, loss_masks, attention_mask,
  341. block_spans, rng)
  342. else:
  343. data = self.make_block_data(
  344. tokens,
  345. loss_masks,
  346. attention_mask,
  347. block_spans,
  348. rng,
  349. task=task)
  350. return data
  351. def split_samples(self, samples, rng):
  352. target_length = rng.randrange(32, self.max_seq_length - 1)
  353. num_splits = (self.max_seq_length - 1) // target_length
  354. new_samples = []
  355. cls_id = self.tokenizer.get_command('ENC').Id
  356. eos_id = self.tokenizer.get_command('eos').Id
  357. for sample in samples:
  358. tokens, loss_masks = sample['text'][1:], sample['loss_mask'][1:]
  359. for _ in range(num_splits):
  360. if target_length >= len(tokens):
  361. new_tokens, new_loss_masks = tokens, loss_masks
  362. else:
  363. random_start = rng.randrange(0,
  364. len(tokens) - target_length)
  365. while random_start > 0 and (
  366. tokens[random_start] == eos_id or # noqa
  367. not (self.contains_sentence_end( # noqa
  368. tokens[random_start - 1]) or # noqa
  369. tokens[random_start - 1] == eos_id)): # noqa
  370. random_start -= 1
  371. random_end = random_start + target_length
  372. while random_end > random_start and not (
  373. self.contains_sentence_end(tokens[random_end - 1])
  374. or tokens[random_end - 1] == eos_id):
  375. random_end -= 1
  376. if random_end - random_start < target_length // 2:
  377. random_end = random_start + target_length
  378. new_tokens, new_loss_masks = tokens[
  379. random_start:random_end], loss_masks[
  380. random_start:random_end]
  381. new_tokens = np.concatenate(([cls_id], new_tokens))
  382. new_loss_masks = np.concatenate(([0], new_loss_masks))
  383. new_samples.append({
  384. 'text': new_tokens,
  385. 'loss_mask': new_loss_masks
  386. })
  387. return new_samples
  388. def construct_blocks(self, samples):
  389. worker_info = torch.utils.data.get_worker_info()
  390. if worker_info is not None:
  391. worker_id, num_workers = worker_info.id, worker_info.num_workers
  392. else:
  393. worker_id, num_workers = 0, 1
  394. rng = random.Random((self.count * num_workers + worker_id)
  395. * self.world_size + self.rank)
  396. self.count += 1
  397. token_batch, target_batch, loss_mask_batch, position_id_batch = [], [], [], []
  398. source_batch, target_batch = [], []
  399. if rng.random() < self.short_seq_prob:
  400. samples = self.split_samples(samples, rng)
  401. rand = rng.random()
  402. single_span = rand < self.single_span_prob
  403. rand = 0.0 if single_span else rng.random()
  404. attention_mask = []
  405. if rand < self.bert_prob:
  406. mode = 'bert'
  407. for sample in samples:
  408. if single_span:
  409. masked_lengths = [
  410. rng.choices(
  411. range(1,
  412. len(self.block_length_distribution) + 1),
  413. weights=self.block_length_distribution)[0]
  414. ]
  415. masked_count = masked_lengths[0]
  416. else:
  417. masked_lengths, masked_count = [], 0
  418. while masked_count < int(
  419. self.bert_ratio * len(sample['text'])):
  420. block_length = rng.choices(
  421. range(1,
  422. len(self.block_length_distribution) + 1),
  423. weights=self.block_length_distribution)[0]
  424. masked_lengths.append(block_length)
  425. masked_count += block_length
  426. if self.masked_lm:
  427. sep = len(sample['text'])
  428. else:
  429. sep = len(
  430. sample['text']) - masked_count + len(masked_lengths)
  431. data = self.generate_blank_data(
  432. sample, masked_lengths, sep, rng, task='bert')
  433. if data is not None:
  434. if self.encoder_decoder:
  435. source_tokens, target_tokens, loss_masks = data
  436. source_batch.append(source_tokens)
  437. target_batch.append(target_tokens)
  438. loss_mask_batch.append(loss_masks)
  439. else:
  440. tokens, targets, loss_masks, position_ids = data
  441. token_batch.append(tokens)
  442. target_batch.append(targets)
  443. loss_mask_batch.append(loss_masks)
  444. position_id_batch.append(position_ids)
  445. attention_mask.append(sep)
  446. elif rand < self.bert_prob + self.gap_sentence_prob:
  447. mode = 'sentence'
  448. for sample in samples:
  449. tokens, loss_masks = sample['text'], sample['loss_mask']
  450. sentence_spans = []
  451. last_index = 1 if tokens[0] == self.tokenizer.get_command(
  452. 'ENC').Id else 0
  453. for i in range(len(tokens)):
  454. if self.contains_sentence_end(tokens[i]):
  455. if last_index < i + 1:
  456. sentence_spans.append((last_index, i + 1))
  457. last_index = i + 1
  458. elif tokens[i] == self.tokenizer.get_command('eos').Id:
  459. last_index = i + 1
  460. if last_index < len(tokens):
  461. sentence_spans.append((last_index, len(tokens)))
  462. if not sentence_spans and torch.distributed.get_rank() == 0:
  463. try:
  464. print(self.tokenizer.DecodeIds(tokens[1:]))
  465. except IndexError:
  466. print(tokens[1:])
  467. rng.shuffle(sentence_spans)
  468. block_spans, block_length = [], 0
  469. for start, end in sentence_spans:
  470. block_spans.append((start, end))
  471. block_length += end - start
  472. if block_length >= int(
  473. self.gap_sentence_ratio * len(tokens)):
  474. break
  475. data = self.make_block_data(
  476. tokens,
  477. loss_masks,
  478. None,
  479. block_spans,
  480. rng,
  481. task='gap_sentence')
  482. tokens, targets, loss_masks, position_ids, sep = data
  483. token_batch.append(tokens)
  484. target_batch.append(targets)
  485. loss_mask_batch.append(loss_masks)
  486. position_id_batch.append(position_ids)
  487. attention_mask.append(sep)
  488. else:
  489. # start_indices = [index_in_list(sample['loss_mask'], 1) for sample in samples]
  490. # end_indices = [rindex(sample['loss_mask'], 1) for sample in samples]
  491. # start_index, end_index = max(start_indices), min(end_indices) - self.min_generation_length
  492. # if end_index < start_index + 1:
  493. # end_index = start_index + 1
  494. # division = rng.randrange(start_index, end_index)
  495. mode = 'gpt'
  496. max_generation_length = rng.randint(
  497. int(self.gpt_min_ratio
  498. * min(map(lambda x: len(x['text']), samples))),
  499. max(map(lambda x: len(x['text']), samples)) - 2)
  500. for sample in samples:
  501. generation_length = min(max_generation_length,
  502. len(sample['text']) - 2)
  503. attention_mask.append(
  504. len(sample['text']) - generation_length + 1)
  505. multiple_doc = index_in_list(
  506. sample['text'],
  507. self.tokenizer.get_command('eos').Id) not in [
  508. -1, len(sample['text']) - 1
  509. ] # noqa
  510. if multiple_doc or rng.random() < self.infill_prob:
  511. division = len(sample['text']) - generation_length
  512. tokens, loss_masks = sample['text'], sample['loss_mask']
  513. source_tokens, target_tokens = tokens[:division], tokens[
  514. division:]
  515. target_masks = loss_masks[division:]
  516. tokens = np.concatenate((source_tokens, [
  517. self.generation_mask,
  518. self.tokenizer.get_command('sop').Id
  519. ], target_tokens[:-1]))
  520. targets = np.concatenate(
  521. (source_tokens, [self.generation_mask], target_tokens))
  522. loss_masks = np.concatenate(
  523. (np.zeros(len(source_tokens) + 1,
  524. dtype=int), target_masks))
  525. token_batch.append(tokens)
  526. target_batch.append(targets)
  527. loss_mask_batch.append(loss_masks)
  528. position_ids = np.arange(
  529. len(source_tokens) + len(target_tokens) + 1, dtype=int)
  530. position_ids[len(source_tokens) + 1:] = len(source_tokens)
  531. if self.block_position_encoding:
  532. block_position_ids = np.concatenate(
  533. (np.zeros(len(source_tokens), dtype=int),
  534. np.arange(len(target_tokens) + 1, dtype=int)))
  535. else:
  536. block_position_ids = np.concatenate(
  537. (np.zeros(len(source_tokens) + 1, dtype=int),
  538. np.ones(len(target_tokens) + 1, dtype=int)))
  539. position_id_batch.append(
  540. np.stack([position_ids, block_position_ids], axis=0))
  541. else:
  542. tokens, targets, loss_masks, position_ids = self.generate_blank_data(
  543. sample, [generation_length],
  544. attention_mask[-1],
  545. rng,
  546. task='generation')
  547. token_batch.append(tokens)
  548. target_batch.append(targets)
  549. loss_mask_batch.append(loss_masks)
  550. position_id_batch.append(position_ids)
  551. if tokens is None:
  552. print(sample, generation_length, multiple_doc)
  553. if self.encoder_decoder:
  554. return {
  555. 'text': torch.tensor(source_batch, dtype=torch.long),
  556. 'target': torch.tensor(target_batch, dtype=torch.long),
  557. 'loss_mask': torch.tensor(loss_mask_batch, dtype=torch.long)
  558. }
  559. else:
  560. token_batch, target_batch, loss_mask_batch, position_id_batch = self.pad_batch(
  561. token_batch, target_batch, loss_mask_batch, position_id_batch)
  562. return {
  563. 'text': torch.tensor(token_batch, dtype=torch.long),
  564. 'target': torch.tensor(target_batch, dtype=torch.long),
  565. 'loss_mask': torch.tensor(loss_mask_batch, dtype=torch.long),
  566. 'position_id':
  567. torch.tensor(position_id_batch, dtype=torch.long),
  568. 'attention_mask':
  569. torch.tensor(attention_mask, dtype=torch.long),
  570. 'mode': mode
  571. }
  572. @staticmethod
  573. def pad_batch(token_batch, target_batch, loss_mask_batch,
  574. position_id_batch):
  575. seq_lengths = list(map(len, token_batch))
  576. if seq_lengths.count(seq_lengths[0]) != len(seq_lengths):
  577. max_length = max(seq_lengths)
  578. token_batch = [
  579. np.concatenate(
  580. (tokens, np.zeros(max_length - len(tokens), dtype=int)))
  581. for tokens in token_batch
  582. ]
  583. target_batch = [
  584. np.concatenate(
  585. (targets, np.zeros(max_length - len(targets), dtype=int)))
  586. for targets in target_batch
  587. ]
  588. loss_mask_batch = [
  589. np.concatenate(
  590. (loss_masks,
  591. np.zeros(max_length - len(loss_masks), dtype=int)))
  592. for loss_masks in loss_mask_batch
  593. ]
  594. position_id_batch = [
  595. np.concatenate(
  596. (position_ids,
  597. np.zeros(
  598. (2, max_length - position_ids.shape[1]), dtype=int)),
  599. axis=1) for position_ids in position_id_batch
  600. ]
  601. return token_batch, target_batch, loss_mask_batch, position_id_batch