sas_utils.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. import random
  2. import nltk
  3. import numpy as np
  4. import torch
  5. def get_random_states(device=None):
  6. random_states = {}
  7. random_states['rng_state_torch'] = torch.get_rng_state()
  8. random_states['rng_state_np'] = np.random.get_state()
  9. random_states['rng_state_rnd'] = random.getstate()
  10. if device is not None and device.type == 'cuda':
  11. random_states['rng_state_torch_cuda'] = torch.cuda.get_rng_state(
  12. device)
  13. return random_states
  14. def set_random_states(random_states, device=None):
  15. torch.set_rng_state(random_states['rng_state_torch'])
  16. np.random.set_state(random_states['rng_state_np'])
  17. random.setstate(random_states['rng_state_rnd'])
  18. if device is not None and device.type == 'cuda':
  19. torch.cuda.set_rng_state(random_states['rng_state_torch_cuda'])
  20. # Check any nan or inf in the data. Return an array of two elements for nan and inf, respectively.
  21. # Inputs
  22. # data: a tensor or a tuple of multiple tensors
  23. # Outputs:
  24. # results: Each element shows the # of tensors that includes nan or inf.
  25. # If data is a "tuple" (instead of a single tensor),
  26. # we add 10 to the count if any nan or inf is detected.
  27. def check_nan_inf(data):
  28. if data is None:
  29. return None
  30. result = [0, 0]
  31. if torch.is_tensor(data):
  32. if torch.isnan(data).any():
  33. result[0] = 1
  34. if torch.isinf(data).any():
  35. result[1] = 1
  36. elif type(data) is tuple:
  37. for i in range(len(data)):
  38. if torch.is_tensor(data[i]):
  39. if torch.isnan(data[i]).any():
  40. result[0] += 1
  41. if torch.isinf(data[i]).any():
  42. result[1] += 1
  43. if result[0] > 0:
  44. result[0] += 10
  45. if result[1] > 0:
  46. result[1] += 10
  47. return result if sum(result) > 0 else None
  48. class SequenceSideInfo():
  49. def __init__(self, tokenizer=None):
  50. if tokenizer is not None:
  51. self.tokenizer = tokenizer
  52. else:
  53. from transformers import ElectraTokenizer
  54. self.tokenizer = ElectraTokenizer.from_pretrained(
  55. 'google/electra-small-generator')
  56. self.sen_tokenizer = nltk.tokenize.punkt.PunktSentenceTokenizer()
  57. tokens = [
  58. self.tokenizer.decode([i])
  59. for i in range(self.tokenizer.vocab_size)
  60. ]
  61. self.ind_subtokens = set(
  62. [i for i in range(len(tokens)) if tokens[i][0:2] == '##'])
  63. tmp = [
  64. 0 if t[0] == '[' and t[-1] == ']' else
  65. (10 + min(5,
  66. len(t) - 2) if t[0:2] == '##' else min(10, len(t)))
  67. for t in tokens
  68. ]
  69. self.len_tokens = torch.tensor(tmp, dtype=torch.int8)
  70. def getSenTokIdx(self, sentence_position_embedding, inputs_str,
  71. seq_len_total):
  72. sentences = self.sen_tokenizer.tokenize(inputs_str)
  73. sen_lengths = np.array([
  74. len(x) - 2
  75. for x in self.tokenizer.batch_encode_plus(sentences)['input_ids']
  76. ]) # -2: to drop the extra [CLS] and [SEP] added by sen_tokenizer
  77. sen_lengths[0] = seq_len_total - sen_lengths[1:].sum()
  78. idx_sen = np.concatenate([
  79. i * np.ones(sen_lengths[i], dtype=np.int8)
  80. for i in range(len(sen_lengths))
  81. ])
  82. idx_tok = np.concatenate([
  83. np.arange(sen_lengths[i], dtype=np.int8)
  84. for i in range(len(sen_lengths))
  85. ])
  86. return np.concatenate((idx_sen, idx_tok))
  87. def generate_seq_side_info(self, sentence_position_embedding, inputs_id):
  88. is_np_array = False
  89. if isinstance(inputs_id[0], (list, np.ndarray)):
  90. is_np_array = True
  91. inputs_id = torch.tensor(inputs_id)
  92. if hasattr(self.tokenizer, 'batch_decode'):
  93. inputs_str = self.tokenizer.batch_decode(inputs_id)
  94. sen_tok_idx = torch.tensor(
  95. np.array([
  96. self.getSenTokIdx(sentence_position_embedding, input_str,
  97. inputs_id.shape[1])
  98. for input_str in inputs_str
  99. ]),
  100. device=inputs_id.device)
  101. else:
  102. sen_tok_idx = torch.tensor(
  103. np.array([
  104. self.getSenTokIdx(sentence_position_embedding,
  105. self.tokenizer.decode(input_ori),
  106. inputs_id.shape[1])
  107. for input_ori in inputs_id.numpy()
  108. ]),
  109. device=inputs_id.device)
  110. side_info_dict = dict()
  111. seq_length = inputs_id.shape[1]
  112. side_info_dict[
  113. 'ss_sentence_position_in_sequence'] = sen_tok_idx[:, 0:seq_length]
  114. side_info_dict[
  115. 'ss_token_position_in_sentence'] = sen_tok_idx[:, 1 * seq_length:2
  116. * seq_length]
  117. if sentence_position_embedding >= 2:
  118. # consider sub-word tokens
  119. unique, _ = np.unique(inputs_id, return_inverse=True)
  120. ind_subtokens = self.ind_subtokens.intersection(set(unique))
  121. if len(ind_subtokens) > 0:
  122. idx_tok_ww = torch.stack([
  123. inputs_id == st for st in ind_subtokens
  124. ]).any(axis=0).char()
  125. else:
  126. idx_tok_ww = torch.zeros(inputs_id.shape, dtype=torch.int8)
  127. idx_tok_ww[:, 0] = 0
  128. idx_tok_ww_1 = idx_tok_ww[:, 1:]
  129. for i in range(1, 11):
  130. pos = torch.logical_and(idx_tok_ww_1 == i,
  131. idx_tok_ww[:, 0:-1] == i)
  132. if len(pos) == 0:
  133. break
  134. idx_tok_ww_1[pos] = i + 1
  135. side_info_dict['ss_token_position_in_whole_word'] = idx_tok_ww
  136. inputs_str_len = self.len_tokens[inputs_id.long()]
  137. side_info_dict['ss_token_string_length'] = inputs_str_len
  138. if is_np_array:
  139. for key in side_info_dict.keys():
  140. side_info_dict[key] = side_info_dict[key].numpy()
  141. return side_info_dict