| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173 |
- import random
- import nltk
- import numpy as np
- import torch
def get_random_states(device=None):
    """Capture the current random-number-generator states.

    Args:
        device: Optional object with a ``type`` attribute (e.g. a
            ``torch.device``). When its type is ``'cuda'`` the CUDA RNG state
            of that device is captured as well.

    Returns:
        A dict with the keys ``'rng_state_torch'``, ``'rng_state_np'`` and
        ``'rng_state_rnd'`` and, for a CUDA device only,
        ``'rng_state_torch_cuda'``.
    """
    snapshot = {
        'rng_state_torch': torch.get_rng_state(),
        'rng_state_np': np.random.get_state(),
        'rng_state_rnd': random.getstate(),
    }
    uses_cuda = device is not None and device.type == 'cuda'
    if uses_cuda:
        snapshot['rng_state_torch_cuda'] = torch.cuda.get_rng_state(device)
    return snapshot
def set_random_states(random_states, device=None):
    """Restore random-number-generator states from a snapshot dict.

    Args:
        random_states: Dict holding ``'rng_state_torch'``, ``'rng_state_np'``
            and ``'rng_state_rnd'``; ``'rng_state_torch_cuda'`` is also read
            when ``device`` is a CUDA device.
        device: Optional object with a ``type`` attribute (e.g. a
            ``torch.device``). When its type is ``'cuda'`` the CUDA RNG state
            is restored on that device.

    Raises:
        KeyError: If a required entry is missing from ``random_states``.
    """
    torch.set_rng_state(random_states['rng_state_torch'])
    np.random.set_state(random_states['rng_state_np'])
    random.setstate(random_states['rng_state_rnd'])
    if device is not None and device.type == 'cuda':
        # Pass the device explicitly: without it the state would be written
        # to the *current* CUDA device, which need not be ``device``.
        torch.cuda.set_rng_state(random_states['rng_state_torch_cuda'],
                                 device)
# Check for NaN or Inf values in the data. Returns a two-element list of
# counts, [nan_count, inf_count], or None when data is None or when neither
# NaN nor Inf is detected.
# Inputs:
#   data: a tensor or a tuple of multiple tensors
# Outputs:
#   result: each element is the number of tensors that contain NaN or Inf,
#           respectively. If data is a "tuple" (instead of a single tensor),
#           10 is added to a count whenever that count is non-zero.
def check_nan_inf(data):
    """Report whether ``data`` contains NaN or Inf values.

    Args:
        data: ``None``, a tensor, or a tuple/list (including tuple subclasses
            such as namedtuples) whose tensor members are inspected. Members
            that are not tensors are skipped.

    Returns:
        ``None`` when ``data`` is ``None`` or nothing is detected. Otherwise a
        two-element list ``[nan_count, inf_count]``:

        * single tensor: each entry is 0 or 1;
        * tuple/list: each entry is the number of member tensors containing
          NaN / Inf, with 10 added to every non-zero entry so that the
          container case can be told apart from the single-tensor case.
    """
    if data is None:
        return None

    if torch.is_tensor(data):
        result = [
            int(torch.isnan(data).any()),
            int(torch.isinf(data).any()),
        ]
    elif isinstance(data, (tuple, list)):
        tensors = [item for item in data if torch.is_tensor(item)]
        nan_count = sum(int(torch.isnan(t).any()) for t in tensors)
        inf_count = sum(int(torch.isinf(t).any()) for t in tensors)
        # Offset non-zero counts by 10 to flag "found inside a container".
        result = [
            count + 10 if count > 0 else count
            for count in (nan_count, inf_count)
        ]
    else:
        # Unsupported input types are reported as "nothing detected".
        result = [0, 0]

    return result if sum(result) > 0 else None
class SequenceSideInfo():
    """Derive per-token side information from batches of token ids.

    The side information consists of the sentence index of every token, the
    token's position inside its sentence and, optionally, the token's position
    inside its whole word plus a coarse string-length code.

    Attributes set by ``__init__``:
        tokenizer: Sub-word tokenizer; must provide ``decode``,
            ``batch_encode_plus`` and ``vocab_size`` (``batch_decode`` is used
            when available).
        sen_tokenizer: NLTK Punkt sentence tokenizer.
        ind_subtokens: Set of vocabulary ids whose decoded text starts with
            ``'##'``.
        len_tokens: int8 tensor mapping every vocabulary id to its length
            code (see ``_token_length_code``).
    """

    def __init__(self, tokenizer=None):
        """Build the vocabulary lookup tables.

        Args:
            tokenizer: Tokenizer to use. When ``None`` the
                ``'google/electra-small-generator'`` ``ElectraTokenizer`` is
                loaded from ``transformers``.
        """
        if tokenizer is not None:
            self.tokenizer = tokenizer
        else:
            from transformers import ElectraTokenizer
            self.tokenizer = ElectraTokenizer.from_pretrained(
                'google/electra-small-generator')
        self.sen_tokenizer = nltk.tokenize.punkt.PunktSentenceTokenizer()

        # Decode every vocabulary id once; both lookup tables derive from it.
        tokens = [
            self.tokenizer.decode([i])
            for i in range(self.tokenizer.vocab_size)
        ]
        self.ind_subtokens = {
            i for i, token in enumerate(tokens) if token.startswith('##')
        }
        self.len_tokens = torch.tensor(
            [self._token_length_code(token) for token in tokens],
            dtype=torch.int8)

    @staticmethod
    def _token_length_code(token):
        """Map a decoded token string to its length code.

        * ``0`` for tokens wrapped in square brackets (e.g. ``[CLS]``);
        * ``10 + min(5, len(token) - 2)`` for ``'##'`` sub-word tokens;
        * ``min(10, len(token))`` for everything else.

        ``startswith``/``endswith`` are used so an empty token cannot raise.
        """
        if token.startswith('[') and token.endswith(']'):
            return 0
        if token.startswith('##'):
            return 10 + min(5, len(token) - 2)
        return min(10, len(token))

    def getSenTokIdx(self, sentence_position_embedding, inputs_str,
                     seq_len_total):
        """Compute sentence and in-sentence token indices for one sequence.

        Args:
            sentence_position_embedding: Unused here; kept for interface
                compatibility.
            inputs_str: Decoded text of the sequence.
            seq_len_total: Number of tokens in the original sequence.

        Returns:
            int8 array: the sentence index of every token followed by the
            position of every token inside its sentence. Its length is
            ``2 * seq_len_total`` as long as the first sentence's length
            stays non-negative after the adjustment below. Indices saturate
            at 127 instead of wrapping around.

        Raises:
            ValueError: If the sentences after the first one already contain
                more than ``seq_len_total`` tokens (negative first length).
        """
        sentences = self.sen_tokenizer.tokenize(inputs_str)
        if sentences:
            encoded = self.tokenizer.batch_encode_plus(sentences)['input_ids']
            # -2: drop the two extra tokens ([CLS] and [SEP]) that the
            # tokenizer adds around every encoded sentence.
            sen_lengths = np.array([len(ids) - 2 for ids in encoded],
                                   dtype=np.int64)
        else:
            # No sentence detected (e.g. empty text): treat the whole
            # sequence as a single sentence instead of failing on index 0.
            sen_lengths = np.zeros(1, dtype=np.int64)
        # The first sentence absorbs whatever is not covered by the others so
        # that the lengths add up to the original sequence length.
        sen_lengths[0] = seq_len_total - sen_lengths[1:].sum()

        # Results are int8; clamp to the int8 maximum so that long sentences
        # or many sentences do not overflow into negative indices.
        int8_max = np.iinfo(np.int8).max
        idx_sen = np.concatenate([
            np.full(length, min(i, int8_max), dtype=np.int8)
            for i, length in enumerate(sen_lengths)
        ])
        idx_tok = np.concatenate([
            np.minimum(np.arange(length), int8_max).astype(np.int8)
            for length in sen_lengths
        ])
        return np.concatenate((idx_sen, idx_tok))

    def generate_seq_side_info(self, sentence_position_embedding, inputs_id):
        """Generate the side-information dict for a batch of sequences.

        Args:
            sentence_position_embedding: Level of detail. Values ``>= 2``
                additionally produce the whole-word position and the token
                string-length code.
            inputs_id: 2-D batch of token ids: a tensor, or a list / array
                whose rows are lists or NumPy arrays.

        Returns:
            Dict with ``'ss_sentence_position_in_sequence'`` and
            ``'ss_token_position_in_sentence'`` and, for levels ``>= 2``,
            ``'ss_token_position_in_whole_word'`` and
            ``'ss_token_string_length'``. Values are tensors on the device of
            ``inputs_id``, or NumPy arrays when the input was not a tensor.
        """
        is_np_array = isinstance(inputs_id[0], (list, np.ndarray))
        if is_np_array:
            inputs_id = torch.tensor(inputs_id)
        device = inputs_id.device
        seq_length = inputs_id.shape[1]

        if hasattr(self.tokenizer, 'batch_decode'):
            inputs_str = self.tokenizer.batch_decode(inputs_id)
        else:
            # Move to CPU first so decoding also works for non-CPU tensors.
            inputs_str = [
                self.tokenizer.decode(input_ori)
                for input_ori in inputs_id.cpu().numpy()
            ]
        sen_tok_idx = torch.tensor(
            np.array([
                self.getSenTokIdx(sentence_position_embedding, input_str,
                                  seq_length) for input_str in inputs_str
            ]),
            device=device)

        side_info_dict = dict()
        # Each row holds the sentence indices followed by the token indices.
        side_info_dict[
            'ss_sentence_position_in_sequence'] = sen_tok_idx[:, 0:seq_length]
        side_info_dict[
            'ss_token_position_in_sentence'] = sen_tok_idx[:, seq_length:2 *
                                                           seq_length]

        if sentence_position_embedding >= 2:
            # Consider sub-word tokens: mark every '##' token with 1.
            present_ids = set(inputs_id.unique().tolist())
            ind_subtokens = self.ind_subtokens.intersection(present_ids)
            if ind_subtokens:
                idx_tok_ww = torch.stack([
                    inputs_id == st for st in ind_subtokens
                ]).any(dim=0).char()
            else:
                idx_tok_ww = torch.zeros(
                    inputs_id.shape, dtype=torch.int8, device=device)
            # The first token never continues a previous word.
            idx_tok_ww[:, 0] = 0
            # View on columns 1..end; writes below update idx_tok_ww itself.
            idx_tok_ww_1 = idx_tok_ww[:, 1:]
            # Turn runs of 1s into 1, 2, 3, ...: a token whose left neighbour
            # carries the same value i is bumped to i + 1.
            for i in range(1, 11):
                pos = torch.logical_and(idx_tok_ww_1 == i,
                                        idx_tok_ww[:, 0:-1] == i)
                if not pos.any():
                    # Nothing was bumped, so later rounds cannot match.
                    break
                idx_tok_ww_1[pos] = i + 1
            side_info_dict['ss_token_position_in_whole_word'] = idx_tok_ww

            # NOTE(review): the string-length code is emitted only for levels
            # >= 2, next to the whole-word position - confirm with callers.
            len_tokens = self.len_tokens.to(device)
            side_info_dict['ss_token_string_length'] = len_tokens[
                inputs_id.long()]

        if is_np_array:
            # Hand back the same container kind the caller passed in.
            for key in side_info_dict:
                side_info_dict[key] = side_info_dict[key].numpy()
        return side_info_dict
|