utils.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import logging
  3. from collections import OrderedDict
  4. import json
  5. import numpy as np
  6. from modelscope.utils.logger import get_logger
  7. from . import ontology
# Module-level logger, obtained from modelscope's logging helper.
logger = get_logger()
  9. def max_lens(X):
  10. lens = [len(X)]
  11. while isinstance(X[0], list):
  12. lens.append(max(map(len, X)))
  13. X = [x for xs in X for x in xs]
  14. return lens
  15. def list2np(X: object, padding: object = 0, dtype: object = 'int64') -> object:
  16. shape = max_lens(X)
  17. ret = np.full(shape, padding, dtype=np.int32)
  18. if len(shape) == 1:
  19. ret = np.array(X)
  20. elif len(shape) == 2:
  21. for i, x in enumerate(X):
  22. ret[i, :len(x)] = np.array(x)
  23. elif len(shape) == 3:
  24. for i, xs in enumerate(X):
  25. for j, x in enumerate(xs):
  26. ret[i, j, :len(x)] = np.array(x)
  27. return ret.astype(dtype)
  28. def clean_replace(s, r, t, forward=True, backward=False):
  29. def clean_replace_single(s, r, t, forward, backward, sidx=0):
  30. # idx = s[sidx:].find(r)
  31. idx = s.find(r)
  32. if idx == -1:
  33. return s, -1
  34. idx_r = idx + len(r)
  35. if backward:
  36. while idx > 0 and s[idx - 1]:
  37. idx -= 1
  38. elif idx > 0 and s[idx - 1] != ' ':
  39. return s, -1
  40. if forward:
  41. while \
  42. idx_r < len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()):
  43. idx_r += 1
  44. elif idx_r != len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()):
  45. return s, -1
  46. return s[:idx] + t + s[idx_r:], idx_r
  47. sidx = 0
  48. while sidx != -1:
  49. s, sidx = clean_replace_single(s, r, t, forward, backward, sidx)
  50. return s
  51. def py2np(list):
  52. return np.array(list)
  53. def write_dict(fn, dic):
  54. with open(fn, 'w') as f:
  55. json.dump(dic, f, indent=2)
  56. def f1_score(label_list, pred_list):
  57. tp = len([t for t in pred_list if t in label_list])
  58. fp = max(0, len(pred_list) - tp)
  59. fn = max(0, len(label_list) - tp)
  60. precision = tp / (tp + fp + 1e-10)
  61. recall = tp / (tp + fn + 1e-10)
  62. f1 = 2 * precision * recall / (precision + recall + 1e-10)
  63. return f1
  64. class MultiWOZVocab(object):
  65. def __init__(self, vocab_size=0):
  66. """
  67. vocab for multiwoz dataset
  68. """
  69. self.vocab_size = vocab_size
  70. self.vocab_size_oov = 0 # get after construction
  71. self._idx2word = {} # word + oov
  72. self._word2idx = {} # word
  73. self._freq_dict = {} # word + oov
  74. for w in [
  75. '[PAD]', '<go_r>', '[UNK]', '<go_b>', '<go_a>', '<eos_u>',
  76. '<eos_r>', '<eos_b>', '<eos_a>', '<go_d>', '<eos_d>'
  77. ]:
  78. self._absolute_add_word(w)
  79. def _absolute_add_word(self, w):
  80. idx = len(self._idx2word)
  81. self._idx2word[idx] = w
  82. self._word2idx[w] = idx
  83. def add_word(self, word):
  84. if word not in self._freq_dict:
  85. self._freq_dict[word] = 0
  86. self._freq_dict[word] += 1
  87. def has_word(self, word):
  88. return self._freq_dict.get(word)
  89. def _add_to_vocab(self, word):
  90. if word not in self._word2idx:
  91. idx = len(self._idx2word)
  92. self._idx2word[idx] = word
  93. self._word2idx[word] = idx
  94. def construct(self):
  95. freq_dict_sorted = sorted(
  96. self._freq_dict.keys(), key=lambda x: -self._freq_dict[x])
  97. logger.info('Vocabulary size including oov: %d' %
  98. (len(freq_dict_sorted) + len(self._idx2word)))
  99. if len(freq_dict_sorted) + len(self._idx2word) < self.vocab_size:
  100. logging.warning(
  101. 'actual label set smaller than that configured: {}/{}'.format(
  102. len(freq_dict_sorted) + len(self._idx2word),
  103. self.vocab_size))
  104. for word in ontology.all_domains + ['general']:
  105. word = '[' + word + ']'
  106. self._add_to_vocab(word)
  107. for word in ontology.all_acts:
  108. word = '[' + word + ']'
  109. self._add_to_vocab(word)
  110. for word in ontology.all_slots:
  111. self._add_to_vocab(word)
  112. for word in freq_dict_sorted:
  113. if word.startswith('[value_') and word.endswith(']'):
  114. self._add_to_vocab(word)
  115. for word in freq_dict_sorted:
  116. self._add_to_vocab(word)
  117. self.vocab_size_oov = len(self._idx2word)
  118. def load_vocab(self, vocab_path):
  119. self._freq_dict = json.loads(
  120. open(vocab_path + '.freq.json', 'r', encoding='utf-8').read())
  121. self._word2idx = json.loads(
  122. open(vocab_path + '.word2idx.json', 'r', encoding='utf-8').read())
  123. self._idx2word = {}
  124. for w, idx in self._word2idx.items():
  125. self._idx2word[idx] = w
  126. self.vocab_size_oov = len(self._idx2word)
  127. logger.info('vocab file loaded from "' + vocab_path + '"')
  128. logger.info('Vocabulary size including oov: %d' %
  129. (self.vocab_size_oov))
  130. def save_vocab(self, vocab_path):
  131. _freq_dict = OrderedDict(
  132. sorted(
  133. self._freq_dict.items(), key=lambda kv: kv[1], reverse=True))
  134. write_dict(vocab_path + '.word2idx.json', self._word2idx)
  135. write_dict(vocab_path + '.freq.json', _freq_dict)
  136. def encode(self, word, include_oov=True):
  137. if include_oov:
  138. if self._word2idx.get(word, None) is None:
  139. raise ValueError(
  140. 'Unknown word: %s. Vocabulary should include oovs here.'
  141. % word)
  142. return self._word2idx[word]
  143. else:
  144. word = '<unk>' if word not in self._word2idx else word
  145. return self._word2idx[word]
  146. def sentence_encode(self, word_list):
  147. return [self.encode(_) for _ in word_list]
  148. def oov_idx_map(self, idx):
  149. return 2 if idx > self.vocab_size else idx
  150. def sentence_oov_map(self, index_list):
  151. return [self.oov_idx_map(_) for _ in index_list]
  152. def decode(self, idx, indicate_oov=False):
  153. if not self._idx2word.get(idx):
  154. raise ValueError(
  155. 'Error idx: %d. Vocabulary should include oovs here.' % idx)
  156. if not indicate_oov or idx < self.vocab_size:
  157. return self._idx2word[idx]
  158. else:
  159. return self._idx2word[idx] + '(o)'