# Copyright (c) 2022 Zhipu.AI
import copy
import math
import random

import numpy as np
import torch
import torch.utils.data
from megatron_util import mpu, print_rank_0
from scipy.stats import poisson
def rindex(lst, val, start=None):
    """Return the highest index ``i <= start`` with ``lst[i] == val``.

    Args:
        lst: Indexable sequence to search.
        val: Value compared against each element with ``==``.
        start: Index at which the backwards scan begins; defaults to the
            last index of ``lst``.

    Returns:
        The matching index, or ``-1`` when no element matches.
    """
    position = len(lst) - 1 if start is None else start
    while position >= 0:
        if lst[position] == val:
            return position
        position -= 1
    return -1
def index_in_list(lst, val, start=None):
    """Return the lowest index ``i >= start`` with ``lst[i] == val``.

    Args:
        lst: Indexable sequence to search.
        val: Value compared against each element with ``==``.
        start: Index at which the forward scan begins; defaults to ``0``.

    Returns:
        The matching index, or ``-1`` when no element matches.
    """
    first = 0 if start is None else start
    matches = (pos for pos in range(first, len(lst)) if lst[pos] == val)
    return next(matches, -1)
class ConstructBlockStrategy:
    """Turn lists of token samples into blank-infilling training batches.

    Every call to :meth:`construct_blocks` draws one of three modes:

    * ``'bert'``: several short spans are blanked out,
    * ``'sentence'``: whole sentences are blanked out,
    * ``'gpt'``: the tail of each sample becomes the generation target.

    Unless ``encoder_decoder`` is set, each sample is laid out as the
    context (spans replaced by one mask token each) followed by the blanked
    spans, each introduced by a start-of-piece token.

    A sample is a dict with the keys ``'text'`` (token ids) and
    ``'loss_mask'``. The code relies on numpy-style slicing and comparison
    of ``sample['text']``.
    """

    # Substrings whose presence in a token's text marks a sentence boundary.
    _SENTENCE_END_MARKS = ('.', '?', '!', ';', ':', '。', '？', '！', '；',
                           '…', '\n')

    def __init__(self,
                 args,
                 tokenizer,
                 max_seq_length,
                 bert_prob=1.0,
                 gap_sentence_prob=0.0,
                 gpt_infill_prob=0.5,
                 gpt_min_ratio=0.5,
                 bert_ratio=0.15,
                 gap_sentence_ratio=0.15,
                 average_block_length=3,
                 max_block_length=40,
                 block_mask_prob=0.0,
                 context_mask_ratio=0.0,
                 context_mask_range=3,
                 short_seq_prob=0.0,
                 single_span_prob=0.0,
                 block_position_encoding=True,
                 encoder_decoder=False,
                 shuffle_blocks=True,
                 sentinel_token=False,
                 task_mask=False,
                 random_position=False,
                 masked_lm=False):
        """Store the sampling configuration.

        Args:
            args: Object exposing ``eod_token``, the end-of-document id.
            tokenizer: Object providing ``get_command(name).Id``,
                ``IdToToken`` and ``DecodeIds``.
            max_seq_length: Maximum sequence length; bounds random position
                offsets and the lengths drawn by :meth:`split_samples`.
            bert_prob: Probability of the ``'bert'`` mode.
            gap_sentence_prob: Probability of the ``'sentence'`` mode. The
                remaining probability mass selects the ``'gpt'`` mode.
            gpt_infill_prob: In ``'gpt'`` mode, probability that a
                single-document sample is split directly into prefix and
                suffix instead of going through
                :meth:`generate_blank_data`.
            gpt_min_ratio: Lower bound of the generation length, as a
                fraction of the shortest sample in the batch.
            bert_ratio: Fraction of tokens blanked out in ``'bert'`` mode.
            gap_sentence_ratio: Fraction of tokens blanked out in
                ``'sentence'`` mode.
            average_block_length: Mean of the Poisson distribution used for
                span lengths.
            max_block_length: Exclusive upper bound of the span lengths;
                lengths ``1 .. max_block_length - 1`` can be drawn.
            block_mask_prob: Per-token probability of replacing a
                target-side input token with ``dBLOCK`` (``'bert'`` task).
            context_mask_ratio: Fraction of the text length of context
                tokens next to blanks that are replaced with ``dBLOCK``.
            context_mask_range: How many context tokens on each side of a
                context segment are candidates for that replacement.
            short_seq_prob: Probability of cutting the samples into shorter
                ones before building the batch.
            single_span_prob: Probability of blanking exactly one span
                (this also forces the ``'bert'`` mode).
            block_position_encoding: If true, the second position row counts
                ``1, 2, ...`` inside each block; otherwise it is ``1``.
            encoder_decoder: If true, :meth:`make_block_data` returns
                separate source and target pieces.
            shuffle_blocks: If true (and not ``encoder_decoder``), blocks
                are emitted in random order.
            sentinel_token: If true, blocks use numbered ``MASK{i}`` /
                ``sop{i}`` commands.
            task_mask: If true, generation uses ``gMASK`` and gap sentences
                use ``sMASK`` instead of ``MASK``.
            random_position: If true, a random offset is added to the
                position ids when there is room below ``max_seq_length``.
            masked_lm: If true, spans are replaced in place by ``MASK``
                (see :meth:`make_masked_data`).
        """
        self.eod_token = args.eod_token
        self.tokenizer = tokenizer
        # Number of batches built so far; part of the per-batch RNG seed.
        self.count = 0
        self.max_seq_length = max_seq_length
        self.rank = mpu.get_data_parallel_rank()
        self.world_size = mpu.get_data_parallel_world_size()
        assert 0.0 <= bert_prob <= 1.0
        self.bert_prob = bert_prob
        self.gap_sentence_prob = gap_sentence_prob
        self.gpt_prob = 1 - bert_prob - gap_sentence_prob
        # Small tolerance for floating point error in the subtraction.
        assert self.gpt_prob >= -1e-10
        self.infill_prob = gpt_infill_prob
        self.gpt_min_ratio = gpt_min_ratio
        self.bert_ratio = bert_ratio
        self.gap_sentence_ratio = gap_sentence_ratio
        # Unnormalised weights for the lengths 1 .. max_block_length - 1.
        self.block_length_distribution = [
            poisson.pmf(i, average_block_length)
            for i in range(1, max_block_length)
        ]
        self.block_mask_prob = block_mask_prob
        self.context_mask_ratio = context_mask_ratio
        self.context_mask_range = context_mask_range
        self.short_seq_prob = short_seq_prob
        self.single_span_prob = single_span_prob
        self.block_position_encoding = block_position_encoding
        self.encoder_decoder = encoder_decoder
        self.shuffle_blocks = shuffle_blocks
        self.sentinel_token = sentinel_token
        self.generation_mask = self.tokenizer.get_command(
            'gMASK' if task_mask else 'MASK').Id
        self.gap_sentence_mask = self.tokenizer.get_command(
            'sMASK' if task_mask else 'MASK').Id
        self.random_position = random_position
        self.masked_lm = masked_lm
        print_rank_0(
            f'BERT prob {self.bert_prob}, gap sent prob {self.gap_sentence_prob}, GPT prob {self.gpt_prob}, infill prob {self.infill_prob}'  # noqa
        )
        print_rank_0(
            f'generation min ratio {self.gpt_min_ratio}, block ratio {self.bert_ratio}, gap sent ratio {self.gap_sentence_ratio}'  # noqa
        )
        print_rank_0(
            f'block length distribution {self.block_length_distribution}')
        print_rank_0(
            f'block mask prob {self.block_mask_prob}, context mask ratio {self.context_mask_ratio}'
        )

    def contains_sentence_end(self, tok):
        """Return True if the text of token id ``tok`` ends a sentence.

        The token text (``tokenizer.IdToToken(tok)``) is checked for ASCII
        and full-width sentence punctuation, an ellipsis or a newline.
        """
        tok = self.tokenizer.IdToToken(tok)
        return any(mark in tok for mark in self._SENTENCE_END_MARKS)

    @staticmethod
    def sample_spans(span_lengths, total_length, rng, offset=0):
        """Place non-adjacent spans at random inside a window.

        Args:
            span_lengths: Lengths of the spans, in the order in which they
                are laid out from left to right.
            total_length: Size of the window.
            rng: ``random.Random``-like generator.
            offset: Absolute index of the first position of the window.

        Returns:
            A list of ``(start, end)`` pairs (end exclusive) in absolute
            coordinates. Consecutive spans are separated by at least one
            position.

        Raises:
            ValueError: If the spans plus the mandatory gaps do not fit.
        """
        blank_length = total_length - sum(span_lengths)
        # Free positions left after reserving one gap between neighbours.
        m = blank_length - len(span_lengths) + 1
        if span_lengths and m < 0:
            raise ValueError(
                f'spans {list(span_lengths)} do not fit into a window of '
                f'length {total_length}')
        places = sorted(rng.randrange(m + 1) for _ in span_lengths)
        spans = []
        for place, span_length in zip(places, span_lengths):
            start = offset + place
            spans.append((start, start + span_length))
            # Skip the span itself and the gap that follows it.
            offset += span_length + 1
        return spans

    @staticmethod
    def _count_fitting_spans(masked_lengths, first, budget, reserve_gaps):
        """Count how many lengths starting at ``first`` fit into ``budget``.

        With ``reserve_gaps`` each span already accepted additionally
        consumes one position of the budget.
        """
        total, count = 0, 0
        while first + count < len(masked_lengths):
            needed = masked_lengths[first + count] + total
            if reserve_gaps:
                needed += count
            if needed > budget:
                break
            total += masked_lengths[first + count]
            count += 1
        return count

    def sample_span_in_document(self, tokens, masked_lengths, rng):
        """Distribute the spans to blank out over the documents of a sample.

        ``tokens`` is split into documents at every ``eod_token``. Documents
        are processed from shortest to longest: each one except the longest
        receives spans covering at most ``bert_ratio`` of its length, and
        the longest receives as many of the remaining spans as fit.

        Args:
            tokens: Token ids of the sample.
            masked_lengths: Span lengths; shuffled in place.
            rng: ``random.Random``-like generator.

        Returns:
            A list of ``(start, end)`` spans. It can hold fewer spans than
            ``masked_lengths`` when not all of them fit.
        """
        rng.shuffle(masked_lengths)
        enc_id = self.tokenizer.get_command('ENC').Id
        num_tokens = len(tokens)
        indices = [-1] + np.where(
            np.asarray(tokens) == self.eod_token)[0].tolist()
        documents = []
        last_index = num_tokens
        for index in reversed(indices):
            start_index = index
            # A leading ENC token is not part of the maskable document.
            if start_index + 1 < num_tokens and tokens[start_index
                                                       + 1] == enc_id:
                start_index += 1
            length = last_index - start_index - 1
            # The document that runs to the end of the text is shortened by
            # one, so its final token is never inside a span.
            if last_index == num_tokens and length > 0:
                length -= 1
            documents.append((start_index + 1, length))
            last_index = index
        documents.sort(key=lambda doc: doc[1])
        mask_spans = []
        mask_index = 0
        final_doc = len(documents) - 1
        for i, (offset, length) in enumerate(documents):
            is_last = i == final_doc
            budget = length if is_last else int(length * self.bert_ratio)
            current_count = self._count_fitting_spans(
                masked_lengths, mask_index, budget, reserve_gaps=is_last)
            if current_count > 0:
                mask_spans += self.sample_spans(
                    masked_lengths[mask_index:mask_index + current_count],
                    length,
                    rng,
                    offset=offset)
            if is_last:
                # Diagnostic output when more than one span stays unplaced.
                if mask_index + current_count < len(masked_lengths) - 1:
                    print(length, masked_lengths[mask_index:],
                          masked_lengths[:mask_index], indices)
            else:
                mask_index += current_count
        return mask_spans

    def make_masked_data(self,
                         tokens,
                         loss_masks,
                         attention_mask,
                         block_spans,
                         rng,
                         task='bert'):
        """Build a masked-LM example by overwriting spans with ``MASK``.

        ``attention_mask``, ``rng`` and ``task`` are accepted for signature
        parity with :meth:`make_block_data` and are not used.

        Args:
            tokens: Token ids of the sample; left unmodified.
            loss_masks: Per-token loss mask of the sample.
            block_spans: ``(start, end)`` spans to mask.

        Returns:
            ``(tokens, targets, loss_masks, position_ids)`` where ``tokens``
            is a masked copy, ``targets`` a copy of the original ids,
            ``loss_masks`` is restricted to the masked positions and
            ``position_ids`` is ``0 .. len(tokens) - 1``.
        """
        position_ids = np.arange(len(tokens), dtype=int)
        targets = copy.deepcopy(tokens)
        # Mask a copy so the caller's sample is not corrupted for later use.
        masked_tokens = copy.deepcopy(tokens)
        mask_id = self.tokenizer.get_command('MASK').Id
        mlm_masks = np.zeros(len(tokens), dtype=int)
        for start, end in block_spans:
            for idx in range(start, end):
                masked_tokens[idx] = mask_id
            mlm_masks[start:end] = 1
        loss_masks = loss_masks * mlm_masks
        return masked_tokens, targets, loss_masks, position_ids

    def make_block_data(self,
                        tokens,
                        loss_masks,
                        attention_mask,
                        block_spans,
                        rng,
                        task='bert'):
        """Build a blank-infilling example from ``tokens`` and its spans.

        The source part is the text with every span replaced by a single
        mask token. The target part lists the spans, each preceded by a
        start-of-piece token on the input side and followed by ``eop`` on
        the label side.

        Args:
            tokens: Token ids of the sample.
            loss_masks: Unused; the loss mask is rebuilt from scratch.
            attention_mask: Expected length of the source part, or ``None``.
                When given it is asserted against the computed length.
            block_spans: List of ``(start, end)`` spans; reordered in place.
            rng: ``random.Random``-like generator.
            task: ``'bert'``, ``'generation'`` or ``'gap_sentence'``;
                selects the mask token and enables the ``dBLOCK`` noise.

        Returns:
            With ``encoder_decoder``: ``(source_tokens, target_tokens,
            loss_masks)``. Otherwise ``(tokens, targets, loss_masks,
            position_ids)``, with ``source_length`` appended as a fifth
            item when ``attention_mask`` is ``None``. ``position_ids`` has
            two rows: positions in the text and positions inside the block.

        Raises:
            RuntimeError: If a span contains ``eod_token``.
        """
        text_length = len(tokens)
        # All tokens of a span share the position of the span's first token.
        position_ids = np.ones(len(tokens), dtype=int)
        for start, end in block_spans:
            position_ids[start + 1:end] = 0
        position_ids = np.cumsum(position_ids) - 1
        if self.random_position and position_ids[-1] < self.max_seq_length - 1:
            position_bias = self.max_seq_length - position_ids[-1]
            position_bias = rng.randrange(0, position_bias)
            position_ids = position_ids + position_bias
        if self.encoder_decoder or not self.shuffle_blocks:
            block_spans.sort(key=lambda span: span[0])
        else:
            rng.shuffle(block_spans)
        if self.sentinel_token:
            block_spans = [(start, end, idx)
                           for idx, (start, end) in enumerate(block_spans)]
        else:
            block_spans = [(start, end, 0) for start, end in block_spans]

        target_tokens, target_position_ids = [], []
        target_block_position_ids, targets = [], []
        for start, end, idx in block_spans:
            sop_token = 'sop' if idx == 0 else f'sop{idx}'
            target_tokens.append([self.tokenizer.get_command(sop_token).Id])
            span_tokens = copy.deepcopy(tokens[start:end])
            if self.block_mask_prob > 0.0 and task == 'bert':
                dblock_id = self.tokenizer.get_command('dBLOCK').Id
                for sub_idx in range(len(span_tokens)):
                    # Use the seeded generator so batches stay reproducible.
                    if rng.random() < self.block_mask_prob:
                        span_tokens[sub_idx] = dblock_id
            target_tokens.append(span_tokens)
            targets.append(tokens[start:end])
            targets.append([self.tokenizer.get_command('eop').Id])
            if not self.sentinel_token:
                target_position_id = position_ids[start:end]
                target_position_ids.append(target_position_id)
                target_position_ids.append([target_position_id[0]])
            else:
                target_position_ids.append([self.max_seq_length]
                                           * (end - start + 1))
            if self.block_position_encoding:
                target_block_position_ids.append(
                    np.arange(1, end - start + 2, dtype=int))
            else:
                target_block_position_ids.append([1] * (end - start + 1))

        # The source part always follows the order of the text.
        block_spans.sort(key=lambda span: span[0])
        source_tokens, source_position_ids, local_spans = [], [], []
        last, current_length = 0, 0
        for start, end, idx in block_spans:
            if task == 'generation':
                mask_id = self.generation_mask
            elif task == 'gap_sentence':
                mask_id = self.gap_sentence_mask
            else:
                mask_token = 'MASK' if idx == 0 else f'MASK{idx}'
                mask_id = self.tokenizer.get_command(mask_token).Id
            # Context segment in source coordinates (mask tokens excluded).
            local_spans.append((current_length, current_length + start - last))
            source_tokens.append(tokens[last:start])
            source_tokens.append([mask_id])
            source_position_ids.append(position_ids[last:start])
            source_position_ids.append([position_ids[start]])
            current_length += start - last + 1
            last = end
        if last < len(tokens):
            local_spans.append(
                (current_length, current_length + len(tokens) - last))
            source_tokens.append(tokens[last:])
            source_position_ids.append(position_ids[last:])
        source_length = sum(map(len, source_tokens))
        if attention_mask is not None:
            assert source_length == attention_mask
        if target_tokens and self.eod_token in np.concatenate(
                target_tokens).tolist():
            print('Found EOS in target', self.tokenizer.DecodeIds(tokens))
            raise RuntimeError('Found EOS in target')

        if self.encoder_decoder:
            # NOTE(review): this appends a bare id to a list of sequences and
            # returns the pieces unconcatenated; kept as in the original.
            target_tokens = target_tokens + [
                self.tokenizer.get_command('eop').Id
            ]
            loss_masks = np.ones(len(target_tokens), dtype=int)
            return source_tokens, target_tokens, loss_masks

        tokens = np.concatenate(source_tokens + target_tokens)
        if task == 'bert' and self.context_mask_ratio > 0:
            mask_candidates = set()
            for start, end in local_spans:
                if start != 0:
                    local_end = min(end, start + self.context_mask_range)
                    mask_candidates.update(range(start, local_end))
                if end != 0:
                    local_start = max(start, end - self.context_mask_range)
                    mask_candidates.update(range(local_start, end))
            # random.sample() needs a sequence (sets are rejected since
            # Python 3.11); sorting also makes the draw reproducible.
            candidates = sorted(mask_candidates)
            num_masked = min(
                int(self.context_mask_ratio * text_length), len(candidates))
            dblock_id = self.tokenizer.get_command('dBLOCK').Id
            for pos in rng.sample(candidates, num_masked):
                tokens[pos] = dblock_id
        targets = np.concatenate(source_tokens + targets)
        loss_masks = np.ones(len(tokens), dtype=int)
        loss_masks[:source_length] = 0
        position_ids = np.concatenate(source_position_ids
                                      + target_position_ids)
        block_position_ids = np.concatenate(
            [np.zeros(source_length, dtype=int)] + target_block_position_ids)
        position_ids = np.stack([position_ids, block_position_ids], axis=0)
        if attention_mask is not None:
            return tokens, targets, loss_masks, position_ids
        return tokens, targets, loss_masks, position_ids, source_length

    def generate_blank_data(self,
                            sample,
                            masked_lengths,
                            attention_mask,
                            rng,
                            task='bert'):
        """Sample spans for ``sample`` and build the training example.

        Args:
            sample: Dict with ``'text'`` (must start with the ``ENC`` id)
                and ``'loss_mask'``.
            masked_lengths: Lengths of the spans to blank; shuffled in place.
            attention_mask: Passed through to the data builder.
            rng: ``random.Random``-like generator.
            task: Task name passed to :meth:`make_block_data`.

        Returns:
            The result of :meth:`make_masked_data` (when ``masked_lm``) or
            :meth:`make_block_data`, or ``None`` when not every span could
            be placed.
        """
        rng.shuffle(masked_lengths)
        tokens, loss_masks = sample['text'], sample['loss_mask']
        assert tokens[0] == self.tokenizer.get_command('ENC').Id
        block_spans = self.sample_span_in_document(tokens, masked_lengths, rng)
        if len(block_spans) < len(masked_lengths):
            return None
        if self.masked_lm:
            return self.make_masked_data(tokens, loss_masks, attention_mask,
                                         block_spans, rng)
        return self.make_block_data(
            tokens, loss_masks, attention_mask, block_spans, rng, task=task)

    def split_samples(self, samples, rng):
        """Cut every sample into shorter pieces of a common target length.

        One target length is drawn per call. Each sample then yields
        ``(max_seq_length - 1) // target_length`` pieces whose boundaries
        are moved back to sentence ends or ``eos`` tokens where possible.
        Each piece is prefixed with the ``ENC`` id and a zero loss mask.

        Args:
            samples: List of sample dicts; the first token of each
                ``'text'`` is dropped before splitting.
            rng: ``random.Random``-like generator.

        Returns:
            A new list of sample dicts.
        """
        target_length = rng.randrange(32, self.max_seq_length - 1)
        num_splits = (self.max_seq_length - 1) // target_length
        cls_id = self.tokenizer.get_command('ENC').Id
        eos_id = self.tokenizer.get_command('eos').Id

        def closes_sentence(token):
            return self.contains_sentence_end(token) or token == eos_id

        new_samples = []
        for sample in samples:
            tokens, loss_masks = sample['text'][1:], sample['loss_mask'][1:]
            for _ in range(num_splits):
                if target_length >= len(tokens):
                    new_tokens, new_loss_masks = tokens, loss_masks
                else:
                    random_start = rng.randrange(0,
                                                 len(tokens) - target_length)
                    # Move the start back to just after a sentence end.
                    while random_start > 0 and (
                            tokens[random_start] == eos_id
                            or not closes_sentence(tokens[random_start - 1])):
                        random_start -= 1
                    random_end = random_start + target_length
                    # Move the end back to just after a sentence end.
                    while random_end > random_start and not closes_sentence(
                            tokens[random_end - 1]):
                        random_end -= 1
                    # Give up on a clean ending if it costs half the piece.
                    if random_end - random_start < target_length // 2:
                        random_end = random_start + target_length
                    new_tokens = tokens[random_start:random_end]
                    new_loss_masks = loss_masks[random_start:random_end]
                new_samples.append({
                    'text': np.concatenate(([cls_id], new_tokens)),
                    'loss_mask': np.concatenate(([0], new_loss_masks))
                })
        return new_samples

    def _sample_block_length(self, rng):
        """Draw one span length weighted by ``block_length_distribution``."""
        return rng.choices(
            range(1,
                  len(self.block_length_distribution) + 1),
            weights=self.block_length_distribution)[0]

    def construct_blocks(self, samples):
        """Build one training batch from ``samples``.

        The random generator is seeded from the call counter, the data
        loader worker and the data-parallel rank, so every call on every
        worker draws from its own deterministic stream.

        Args:
            samples: List of dicts with ``'text'`` and ``'loss_mask'``.

        Returns:
            With ``encoder_decoder``: a dict with ``'text'``, ``'target'``
            and ``'loss_mask'`` tensors. Otherwise a dict with ``'text'``,
            ``'target'``, ``'loss_mask'``, ``'position_id'`` and
            ``'attention_mask'`` tensors plus ``'mode'`` (``'bert'``,
            ``'sentence'`` or ``'gpt'``). ``'attention_mask'`` holds one
            integer per sample: the length of the context part, or the
            full length when ``masked_lm`` is set.

        Raises:
            RuntimeError: In ``'gpt'`` mode, if the generation span could
                not be placed in a sample.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker_id, num_workers = worker_info.id, worker_info.num_workers
        else:
            worker_id, num_workers = 0, 1
        rng = random.Random((self.count * num_workers + worker_id)
                            * self.world_size + self.rank)
        self.count += 1
        token_batch, target_batch = [], []
        loss_mask_batch, position_id_batch = [], []
        source_batch = []
        if rng.random() < self.short_seq_prob:
            samples = self.split_samples(samples, rng)
        rand = rng.random()
        single_span = rand < self.single_span_prob
        # A single-span batch is always built in 'bert' mode.
        rand = 0.0 if single_span else rng.random()
        attention_mask = []
        if rand < self.bert_prob:
            mode = 'bert'
            for sample in samples:
                text_length = len(sample['text'])
                if single_span:
                    masked_lengths = [self._sample_block_length(rng)]
                    masked_count = masked_lengths[0]
                else:
                    masked_lengths, masked_count = [], 0
                    while masked_count < int(self.bert_ratio * text_length):
                        block_length = self._sample_block_length(rng)
                        masked_lengths.append(block_length)
                        masked_count += block_length
                if self.masked_lm:
                    sep = text_length
                else:
                    # Each span collapses to one mask token in the context.
                    sep = text_length - masked_count + len(masked_lengths)
                data = self.generate_blank_data(
                    sample, masked_lengths, sep, rng, task='bert')
                if data is None:
                    # The spans did not fit; the sample is dropped.
                    continue
                if self.encoder_decoder:
                    source_tokens, target_tokens, loss_masks = data
                    source_batch.append(source_tokens)
                    target_batch.append(target_tokens)
                    loss_mask_batch.append(loss_masks)
                else:
                    tokens, targets, loss_masks, position_ids = data
                    token_batch.append(tokens)
                    target_batch.append(targets)
                    loss_mask_batch.append(loss_masks)
                    position_id_batch.append(position_ids)
                attention_mask.append(sep)
        elif rand < self.bert_prob + self.gap_sentence_prob:
            mode = 'sentence'
            enc_id = self.tokenizer.get_command('ENC').Id
            eos_id = self.tokenizer.get_command('eos').Id
            for sample in samples:
                tokens, loss_masks = sample['text'], sample['loss_mask']
                sentence_spans = []
                last_index = 1 if tokens[0] == enc_id else 0
                for i, token in enumerate(tokens):
                    if self.contains_sentence_end(token):
                        if last_index < i + 1:
                            sentence_spans.append((last_index, i + 1))
                        last_index = i + 1
                    elif token == eos_id:
                        last_index = i + 1
                if last_index < len(tokens):
                    sentence_spans.append((last_index, len(tokens)))
                if not sentence_spans and torch.distributed.get_rank() == 0:
                    try:
                        print(self.tokenizer.DecodeIds(tokens[1:]))
                    except IndexError:
                        print(tokens[1:])
                rng.shuffle(sentence_spans)
                block_spans, block_length = [], 0
                for start, end in sentence_spans:
                    block_spans.append((start, end))
                    block_length += end - start
                    if block_length >= int(
                            self.gap_sentence_ratio * len(tokens)):
                        break
                data = self.make_block_data(
                    tokens,
                    loss_masks,
                    None,
                    block_spans,
                    rng,
                    task='gap_sentence')
                tokens, targets, loss_masks, position_ids, sep = data
                token_batch.append(tokens)
                target_batch.append(targets)
                loss_mask_batch.append(loss_masks)
                position_id_batch.append(position_ids)
                attention_mask.append(sep)
        else:
            mode = 'gpt'
            eos_id = self.tokenizer.get_command('eos').Id
            text_lengths = [len(sample['text']) for sample in samples]
            max_generation_length = rng.randint(
                int(self.gpt_min_ratio * min(text_lengths)),
                max(text_lengths) - 2)
            for sample in samples:
                text = sample['text']
                generation_length = min(max_generation_length, len(text) - 2)
                attention_mask.append(len(text) - generation_length + 1)
                # More than one document: the first eos is neither missing
                # nor the very last token.
                eos_positions = np.flatnonzero(np.asarray(text) == eos_id)
                first_eos = int(eos_positions[0]) if eos_positions.size else -1
                multiple_doc = first_eos not in (-1, len(text) - 1)
                if multiple_doc or rng.random() < self.infill_prob:
                    division = len(text) - generation_length
                    tokens, loss_masks = text, sample['loss_mask']
                    source_tokens = tokens[:division]
                    target_tokens = tokens[division:]
                    target_masks = loss_masks[division:]
                    num_source = len(source_tokens)
                    num_target = len(target_tokens)
                    tokens = np.concatenate((source_tokens, [
                        self.generation_mask,
                        self.tokenizer.get_command('sop').Id
                    ], target_tokens[:-1]))
                    targets = np.concatenate(
                        (source_tokens, [self.generation_mask], target_tokens))
                    loss_masks = np.concatenate(
                        (np.zeros(num_source + 1, dtype=int), target_masks))
                    token_batch.append(tokens)
                    target_batch.append(targets)
                    loss_mask_batch.append(loss_masks)
                    position_ids = np.arange(
                        num_source + num_target + 1, dtype=int)
                    position_ids[num_source + 1:] = num_source
                    if self.block_position_encoding:
                        block_position_ids = np.concatenate(
                            (np.zeros(num_source, dtype=int),
                             np.arange(num_target + 1, dtype=int)))
                    else:
                        # Zeros cover the context and the mask token, ones
                        # cover sop plus the shifted target; the total must
                        # equal len(position_ids) for np.stack to work.
                        block_position_ids = np.concatenate(
                            (np.zeros(num_source + 1, dtype=int),
                             np.ones(num_target, dtype=int)))
                    position_id_batch.append(
                        np.stack([position_ids, block_position_ids], axis=0))
                else:
                    data = self.generate_blank_data(
                        sample, [generation_length],
                        attention_mask[-1],
                        rng,
                        task='generation')
                    if data is None:
                        # Check before unpacking so the diagnostic is shown.
                        print(sample, generation_length, multiple_doc)
                        raise RuntimeError(
                            'could not place a generation span of length '
                            f'{generation_length}')
                    tokens, targets, loss_masks, position_ids = data
                    token_batch.append(tokens)
                    target_batch.append(targets)
                    loss_mask_batch.append(loss_masks)
                    position_id_batch.append(position_ids)
        if self.encoder_decoder:
            return {
                'text': torch.tensor(source_batch, dtype=torch.long),
                'target': torch.tensor(target_batch, dtype=torch.long),
                'loss_mask': torch.tensor(loss_mask_batch, dtype=torch.long)
            }
        token_batch, target_batch, loss_mask_batch, position_id_batch = self.pad_batch(
            token_batch, target_batch, loss_mask_batch, position_id_batch)

        def to_tensor(batch):
            # Stack into one array first; building a tensor from a list of
            # arrays is much slower.
            return torch.tensor(
                np.array(batch, dtype=np.int64), dtype=torch.long)

        return {
            'text': to_tensor(token_batch),
            'target': to_tensor(target_batch),
            'loss_mask': to_tensor(loss_mask_batch),
            'position_id': to_tensor(position_id_batch),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'mode': mode
        }

    @staticmethod
    def pad_batch(token_batch, target_batch, loss_mask_batch,
                  position_id_batch):
        """Right-pad all sequences of a batch with zeros to a common length.

        Args:
            token_batch: List of 1-D token arrays; defines the lengths.
            target_batch: List of 1-D target arrays.
            loss_mask_batch: List of 1-D loss mask arrays.
            position_id_batch: List of numpy arrays padded along their last
                axis (``(2, length)`` for block data, ``(length,)`` for
                masked-LM data).

        Returns:
            The four lists. They are returned unchanged when the batch is
            empty or all token sequences already have the same length.
        """
        seq_lengths = list(map(len, token_batch))
        if not seq_lengths or seq_lengths.count(
                seq_lengths[0]) == len(seq_lengths):
            return token_batch, target_batch, loss_mask_batch, position_id_batch
        max_length = max(seq_lengths)

        def pad_rows(batch):
            return [
                np.concatenate(
                    (row, np.zeros(max_length - len(row), dtype=int)))
                for row in batch
            ]

        position_id_batch = [
            np.concatenate(
                (position_ids,
                 np.zeros(
                     position_ids.shape[:-1]
                     + (max_length - position_ids.shape[-1], ),
                     dtype=int)),
                axis=-1) for position_ids in position_id_batch
        ]
        return (pad_rows(token_batch), pad_rows(target_batch),
                pad_rows(loss_mask_batch), position_id_batch)
|