speaker.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict, List, Tuple, Union
  3. import torch
  4. from modelscope.metainfo import Preprocessors
  5. from modelscope.preprocessors import Preprocessor
  6. from modelscope.preprocessors.builder import PREPROCESSORS
  7. from modelscope.preprocessors.nlp.text_classification_preprocessor import \
  8. TextClassificationPreprocessorBase
  9. from modelscope.preprocessors.nlp.token_classification_preprocessor import (
  10. NLPTokenizerForLSTM, TokenClassificationPreprocessorBase)
  11. from modelscope.preprocessors.nlp.transformers_tokenizer import NLPTokenizer
  12. from modelscope.utils.constant import Fields, ModeKeys
  13. from modelscope.utils.hub import get_model_type, parse_label_mapping
  14. from modelscope.utils.logger import get_logger
# Module-level logger obtained from modelscope's logging utility.
logger = get_logger()
  16. @PREPROCESSORS.register_module(
  17. Fields.audio, module_name=Preprocessors.sen_cls_tokenizer)
  18. class SpeakerDiarizationDialogueDetectionPreprocessor(
  19. TextClassificationPreprocessorBase):
  20. def _tokenize_text(self, sequence1, sequence2=None, **kwargs):
  21. if 'return_tensors' not in kwargs:
  22. kwargs[
  23. 'return_tensors'] = 'pt' if self.mode == ModeKeys.INFERENCE else None
  24. return self.nlp_tokenizer(sequence1, sequence2, **kwargs)
  25. def __init__(self,
  26. model_dir=None,
  27. first_sequence: str = None,
  28. second_sequence: str = None,
  29. label: Union[str, List] = 'label',
  30. label2id: Dict = None,
  31. mode: str = ModeKeys.INFERENCE,
  32. max_length: int = None,
  33. use_fast: bool = None,
  34. keep_original_columns=None,
  35. **kwargs):
  36. kwargs['truncation'] = kwargs.get('truncation', True)
  37. kwargs['padding'] = kwargs.get('padding', 'max_length')
  38. kwargs[
  39. 'max_length'] = max_length if max_length is not None else kwargs.get(
  40. 'sequence_length', 128)
  41. kwargs.pop('sequence_length', None)
  42. model_type = None
  43. if model_dir is not None:
  44. model_type = get_model_type(model_dir)
  45. self.nlp_tokenizer = NLPTokenizer(
  46. model_dir, model_type, use_fast=use_fast, tokenize_kwargs=kwargs)
  47. super().__init__(model_dir, first_sequence, second_sequence, label,
  48. label2id, mode, keep_original_columns)
  49. @PREPROCESSORS.register_module(
  50. Fields.audio, module_name=Preprocessors.token_cls_tokenizer)
  51. class SpeakerDiarizationSemanticSpeakerTurnDetectionPreprocessor(
  52. TokenClassificationPreprocessorBase):
  53. def __init__(self,
  54. model_dir: str = None,
  55. first_sequence: str = 'text',
  56. label: str = 'label',
  57. label2id: Dict = None,
  58. label_all_tokens: bool = False,
  59. mode: str = ModeKeys.INFERENCE,
  60. max_length=None,
  61. use_fast=None,
  62. keep_original_columns=None,
  63. return_text=True,
  64. **kwargs):
  65. super().__init__(model_dir, first_sequence, label, label2id,
  66. label_all_tokens, mode, keep_original_columns,
  67. return_text)
  68. model_type = None
  69. if model_dir is not None:
  70. model_type = get_model_type(model_dir)
  71. kwargs['truncation'] = kwargs.get('truncation', True)
  72. kwargs['padding'] = kwargs.get('padding', 'max_length')
  73. kwargs[
  74. 'max_length'] = max_length if max_length is not None else kwargs.get(
  75. 'sequence_length', 128)
  76. kwargs.pop('sequence_length', None)
  77. kwargs['add_special_tokens'] = model_type != 'lstm'
  78. self.nlp_tokenizer = NLPTokenizerForLSTM(
  79. model_dir=model_dir,
  80. model_type=model_type,
  81. use_fast=use_fast,
  82. tokenize_kwargs=kwargs)
  83. def _tokenize_text(self, text: Union[str, List[str]], **kwargs):
  84. tokens = text
  85. if self.mode != ModeKeys.INFERENCE:
  86. assert isinstance(tokens, list), 'Input needs to be lists in training and evaluating,' \
  87. 'because the length of the words and the labels need to be equal.'
  88. is_split_into_words = self.nlp_tokenizer.get_tokenizer_kwarg(
  89. 'is_split_into_words', False)
  90. if is_split_into_words:
  91. # for supporting prompt seperator, should split twice. [SEP] for default.
  92. sep_idx = tokens.find('[SEP]')
  93. if sep_idx == -1 or self.is_lstm_model:
  94. tokens = list(tokens)
  95. else:
  96. tmp_tokens = []
  97. tmp_tokens.extend(list(tokens[:sep_idx]))
  98. tmp_tokens.append('[SEP]')
  99. tmp_tokens.extend(list(tokens[sep_idx + 5:]))
  100. tokens = tmp_tokens
  101. if is_split_into_words and self.mode == ModeKeys.INFERENCE:
  102. encodings, word_ids = self._tokenize_text_by_words(
  103. tokens, **kwargs)
  104. elif self.nlp_tokenizer.tokenizer.is_fast:
  105. encodings, word_ids = self._tokenize_text_with_fast_tokenizer(
  106. tokens, **kwargs)
  107. else:
  108. encodings, word_ids = self._tokenize_text_with_slow_tokenizer(
  109. tokens, **kwargs)
  110. sep_idx = -1
  111. for idx, token_id in enumerate(encodings['input_ids']):
  112. if token_id == self.nlp_tokenizer.tokenizer.sep_token_id:
  113. sep_idx = idx
  114. break
  115. if sep_idx != -1:
  116. for i in range(sep_idx, len(encodings['label_mask'])):
  117. encodings['label_mask'][i] = False
  118. if self.mode == ModeKeys.INFERENCE:
  119. for key in encodings.keys():
  120. encodings[key] = torch.tensor(encodings[key]).unsqueeze(0)
  121. else:
  122. encodings.pop('offset_mapping', None)
  123. return encodings, word_ids
  124. def _tokenize_text_by_words(self, tokens, **kwargs):
  125. input_ids = []
  126. label_mask = []
  127. offset_mapping = []
  128. attention_mask = []
  129. for offset, token in enumerate(tokens):
  130. subtoken_ids = self.nlp_tokenizer.tokenizer.encode(
  131. token, add_special_tokens=False)
  132. if len(subtoken_ids) == 0:
  133. subtoken_ids = [self.nlp_tokenizer.tokenizer.unk_token_id]
  134. input_ids.extend(subtoken_ids)
  135. attention_mask.extend([1] * len(subtoken_ids))
  136. label_mask.extend([True] + [False] * (len(subtoken_ids) - 1))
  137. offset_mapping.extend([(offset, offset + 1)])
  138. padding = kwargs.get('padding',
  139. self.nlp_tokenizer.get_tokenizer_kwarg('padding'))
  140. max_length = kwargs.get(
  141. 'max_length',
  142. kwargs.get('sequence_length',
  143. self.nlp_tokenizer.get_tokenizer_kwarg('max_length')))
  144. special_token = 1 if self.nlp_tokenizer.get_tokenizer_kwarg(
  145. 'add_special_tokens') else 0
  146. if len(label_mask) > max_length - 2 * special_token:
  147. label_mask = label_mask[:(max_length - 2 * special_token)]
  148. input_ids = input_ids[:(max_length - 2 * special_token)]
  149. offset_mapping = offset_mapping[:sum(label_mask)]
  150. if padding == 'max_length':
  151. label_mask = [False] * special_token + label_mask + \
  152. [False] * (max_length - len(label_mask) - special_token)
  153. offset_mapping = offset_mapping + [(0, 0)] * (
  154. max_length - len(offset_mapping))
  155. input_ids = [self.nlp_tokenizer.tokenizer.cls_token_id] * special_token + input_ids + \
  156. [self.nlp_tokenizer.tokenizer.sep_token_id] * special_token + \
  157. [self.nlp_tokenizer.tokenizer.pad_token_id] * (max_length - len(input_ids) - 2 * special_token)
  158. attention_mask = attention_mask + [1] * (
  159. special_token * 2) + [0] * (
  160. max_length - len(attention_mask) - 2 * special_token)
  161. else:
  162. label_mask = [False] * special_token + label_mask + \
  163. [False] * special_token
  164. input_ids = [self.nlp_tokenizer.tokenizer.cls_token_id] * special_token + input_ids + \
  165. [self.nlp_tokenizer.tokenizer.sep_token_id] * special_token
  166. attention_mask = attention_mask + [1] * (special_token * 2)
  167. encodings = {
  168. 'input_ids': input_ids,
  169. 'attention_mask': attention_mask,
  170. 'label_mask': label_mask,
  171. 'offset_mapping': offset_mapping,
  172. }
  173. return encodings, None
  174. def _tokenize_text_with_fast_tokenizer(self, tokens, **kwargs):
  175. is_split_into_words = isinstance(tokens, list)
  176. encodings = self.nlp_tokenizer(
  177. tokens,
  178. return_offsets_mapping=True,
  179. is_split_into_words=is_split_into_words,
  180. **kwargs)
  181. label_mask = []
  182. word_ids = encodings.word_ids()
  183. offset_mapping = []
  184. for i in range(len(word_ids)):
  185. if word_ids[i] is None:
  186. label_mask.append(False)
  187. elif word_ids[i] == word_ids[i - 1]:
  188. label_mask.append(False)
  189. if not is_split_into_words:
  190. offset_mapping[-1] = (offset_mapping[-1][0],
  191. encodings['offset_mapping'][i][1])
  192. else:
  193. label_mask.append(True)
  194. if is_split_into_words:
  195. offset_mapping.append((word_ids[i], word_ids[i] + 1))
  196. else:
  197. offset_mapping.append(encodings['offset_mapping'][i])
  198. padding = self.nlp_tokenizer.get_tokenizer_kwarg('padding')
  199. if padding == 'max_length':
  200. offset_mapping = offset_mapping + [(0, 0)] * (
  201. len(label_mask) - len(offset_mapping))
  202. encodings['offset_mapping'] = offset_mapping
  203. encodings['label_mask'] = label_mask
  204. return encodings, word_ids
  205. def _tokenize_text_with_slow_tokenizer(self, tokens, **kwargs):
  206. assert self.mode == ModeKeys.INFERENCE and isinstance(tokens, str), \
  207. 'Slow tokenizer now only support str input in inference mode. If you are training models, ' \
  208. 'please consider using the fast tokenizer.'
  209. word_ids = None
  210. encodings = self.nlp_tokenizer(
  211. tokens, is_split_into_words=False, **kwargs)
  212. tokenizer_name = self.nlp_tokenizer.get_tokenizer_class()
  213. method = 'get_label_mask_and_offset_mapping_' + tokenizer_name
  214. if not hasattr(self, method):
  215. raise RuntimeError(
  216. f'No `{method}` method defined for '
  217. f'tokenizer {tokenizer_name}, please use a fast tokenizer instead, or '
  218. f'try to implement a `{method}` method')
  219. label_mask, offset_mapping = getattr(self, method)(tokens)
  220. padding = kwargs.get('padding',
  221. self.nlp_tokenizer.get_tokenizer_kwarg('padding'))
  222. max_length = kwargs.get(
  223. 'max_length', self.nlp_tokenizer.get_tokenizer_kwarg('max_length'))
  224. special_token = 1 if kwargs.get(
  225. 'add_special_tokens',
  226. self.nlp_tokenizer.get_tokenizer_kwarg(
  227. 'add_special_tokens')) else 0
  228. if len(label_mask) > max_length - 2 * special_token:
  229. label_mask = label_mask[:(max_length - 2 * special_token)]
  230. offset_mapping = offset_mapping[:sum(label_mask)]
  231. if padding == 'max_length':
  232. label_mask = [False] * special_token + label_mask + \
  233. [False] * (max_length - len(label_mask) - special_token)
  234. offset_mapping = offset_mapping + [(0, 0)] * (
  235. max_length - len(offset_mapping))
  236. else:
  237. label_mask = [False] * special_token + label_mask + \
  238. [False] * special_token
  239. encodings['offset_mapping'] = offset_mapping
  240. encodings['label_mask'] = label_mask
  241. return encodings, word_ids
  242. def get_label_mask_and_offset_mapping_BertTokenizer(self, text):
  243. label_mask = []
  244. offset_mapping = []
  245. tokens = self.nlp_tokenizer.tokenizer.tokenize(text)
  246. offset = 0
  247. for token in tokens:
  248. is_start = (token[:2] != '##')
  249. if is_start:
  250. label_mask.append(True)
  251. else:
  252. token = token[2:]
  253. label_mask.append(False)
  254. start = offset + text[offset:].index(token)
  255. end = start + len(token)
  256. if is_start:
  257. offset_mapping.append((start, end))
  258. else:
  259. offset_mapping[-1] = (offset_mapping[-1][0], end)
  260. offset = end
  261. return label_mask, offset_mapping
  262. def get_label_mask_and_offset_mapping_XLMRobertaTokenizer(self, text):
  263. label_mask = []
  264. offset_mapping = []
  265. tokens = self.nlp_tokenizer.tokenizer.tokenize(text)
  266. offset = 0
  267. last_is_blank = False
  268. for token in tokens:
  269. is_start = (token[0] == '_')
  270. if is_start:
  271. token = token[1:]
  272. label_mask.append(True)
  273. if len(token) == 0:
  274. last_is_blank = True
  275. continue
  276. else:
  277. label_mask.append(False)
  278. start = offset + text[offset:].index(token)
  279. end = start + len(token)
  280. if last_is_blank or is_start:
  281. offset_mapping.append((start, end))
  282. else:
  283. offset_mapping[-1] = (offset_mapping[-1][0], end)
  284. offset = end
  285. last_is_blank = False
  286. return label_mask, offset_mapping