language_identification_pipline.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import os.path as osp
  4. import re
  5. from typing import Any, Dict
  6. import numpy as np
  7. import tensorflow as tf
  8. from modelscope.metainfo import Pipelines
  9. from modelscope.models.base import Model
  10. from modelscope.outputs import OutputKeys
  11. from modelscope.pipelines.base import Pipeline
  12. from modelscope.pipelines.builder import PIPELINES
  13. from modelscope.utils.config import Config, ConfigFields
  14. from modelscope.utils.constant import ModelFile, Tasks
  15. from modelscope.utils.logger import get_logger
  16. if tf.__version__ >= '2.0':
  17. tf = tf.compat.v1
  18. tf.disable_eager_execution()
  19. logger = get_logger()
  20. __all__ = ['LanguageIdentificationPipeline']
  21. @PIPELINES.register_module(
  22. Tasks.text_classification, module_name=Pipelines.language_identification)
  23. class LanguageIdentificationPipeline(Pipeline):
  24. r""" Language Identification Pipeline.
    Examples:
    >>> from modelscope.pipelines import pipeline
    >>> from modelscope.utils.constant import Tasks
    >>> pipeline_ins = pipeline(Tasks.text_classification, 'damo/nlp_language_identification-classification-base')
    >>> pipeline_ins('Elon Musk, co-founder and chief executive officer of Tesla Motors.\n'
    ...              'Gleichzeitig nahm die Legion an der Befriedung Algeriens teil, die von.\n'
    ...              '使用pipeline推理及在线体验功能的时候,尽量输入单句文本,如果是多句长文本建议人工分句。')
    {
        "labels": [
            "en",
            "de",
            "zh"
        ],
        "scores": [
            [('en', 0.99)],
            [('de', 1.0)],
            [('zh', 1.0)]
        ]
    }
  44. """
  45. def __init__(self, model: str, **kwargs):
  46. """Build a language identification pipeline with a model dir or a model id in the model hub.
  47. Args:
  48. model: A Model instance.
  49. """
  50. super().__init__(model=model, **kwargs)
  51. export_dir = model
  52. self.debug = False
  53. self.cfg = Config.from_file(
  54. os.path.join(export_dir, ModelFile.CONFIGURATION))
  55. joint_vocab_file = os.path.join(
  56. export_dir, self.cfg[ConfigFields.preprocessor]['vocab'])
  57. vocabfiles = []
  58. vocabfiles_reverse = []
  59. for i, w in enumerate(open(joint_vocab_file, 'rb')):
  60. w = w.strip()
  61. try:
  62. w = w.decode('utf-8')
  63. vocabfiles.append((w, i))
  64. vocabfiles_reverse.append((i, w))
  65. except UnicodeDecodeError:
  66. # [debug] print error info
  67. if self.debug:
  68. print('error vocab:', w, i)
  69. pass
  70. self.vocab = dict(vocabfiles)
  71. self.vocab_reverse = dict(vocabfiles_reverse)
  72. self.unk_id = self.vocab.get('<UNK>', 1)
  73. self.pad_id = self.vocab.get('</S>', 0)
  74. joint_label_file = os.path.join(
  75. export_dir, self.cfg[ConfigFields.preprocessor]['label'])
  76. self.label = dict([(i, w.strip()) for i, w in enumerate(
  77. open(joint_label_file, 'r', encoding='utf8'))])
  78. self.unk_label = 'unk'
  79. tf.reset_default_graph()
  80. tf_config = tf.ConfigProto(allow_soft_placement=True)
  81. tf_config.gpu_options.allow_growth = True
  82. self._session = tf.Session(config=tf_config)
  83. tf.saved_model.loader.load(self._session,
  84. [tf.saved_model.tag_constants.SERVING],
  85. export_dir)
  86. default_graph = tf.get_default_graph()
  87. # [debug] print graph ops
  88. if self.debug:
  89. for op in default_graph.get_operations():
  90. print(op.name, op.values())
  91. self.input_ids = default_graph.get_tensor_by_name('src_cid:0')
  92. output_label = default_graph.get_tensor_by_name('output_label:0')
  93. output_score = default_graph.get_tensor_by_name('predict_score:0')
  94. self.output = {
  95. 'output_ids': output_label,
  96. 'output_score': output_score
  97. }
  98. init = tf.global_variables_initializer()
  99. local_init = tf.local_variables_initializer()
  100. self._session.run([init, local_init])
  101. tf.saved_model.loader.load(self._session,
  102. [tf.saved_model.tag_constants.SERVING],
  103. export_dir)
  104. def _lid_preprocess(self, input: str) -> list:
  105. sentence = input.lower()
  106. # HtmlToText
  107. CLEANR = r'<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});'
  108. sentence = re.sub(CLEANR, '', sentence)
  109. # RemoveLinks
  110. URLRE = r'\S+[./]\S+\s?'
  111. sentence = re.sub(URLRE, '', sentence)
  112. EMAILRE = r'\S*@\S*\s?'
  113. sentence = re.sub(EMAILRE, '', sentence)
  114. # SBC2DBC
  115. def stringpartQ2B(uchar):
  116. inside_code = ord(uchar)
  117. if 0xFF00 < inside_code or inside_code > 0xFF5F:
  118. inside_code -= 0xFEE0
  119. elif inside_code == 0x3000:
  120. inside_code = 0x0020
  121. elif inside_code in [
  122. 0x301D, 0x301E, 0x201C, 0x201D, 0x201E, 0x201F
  123. ]:
  124. inside_code = 0x0022
  125. elif inside_code in [0x2018, 0x2019, 0x201A, 0x201B]:
  126. inside_code = 0x0027
  127. return chr(inside_code)
  128. # RemoveNoisyChars
  129. m_noisyChars = ",-+\"\'\\&.!=:;°·$«»|±[]{}_?<>~^*/%#@(),。!《》?、`\xc2\xa0…‼️"
  130. sentence = ''.join([
  131. stringpartQ2B(c) if c not in m_noisyChars else ' '
  132. for c in sentence
  133. ])
  134. EMOJIRE = re.compile(
  135. '['
  136. u'\U0001F600-\U0001F64F' # emoticons
  137. u'\U0001F300-\U0001F5FF' # symbols & pictographs
  138. u'\U0001F680-\U0001F6FF' # transport & map symbols
  139. u'\U0001F1E0-\U0001F1FF' # flags (iOS)
  140. u'\U0001f926-\U0001f937' # emoji
  141. u'\U00010000-\U0010ffff' # char emoji
  142. u'\U00002702-\U000027B0' # char emoji
  143. u'\u2640-\u2642\u2600-\u2B55'
  144. u'\u200d\u23cf\u23e9\u231a\ufe0f\u3030' # dingbats
  145. ']+',
  146. re.UNICODE)
  147. sentence = re.sub(EMOJIRE, '', sentence)
  148. # RemoveDigitalWords
  149. sentence = ' '.join([
  150. item for item in sentence.split()
  151. if (not bool(re.search(r'\d', item))
  152. or not bool(re.match(r'^[a-z0-9+-_]+$', item)))
  153. ])
  154. # replaceBrandWords
  155. # wordCorrection
  156. # removeSpaces
  157. outids = []
  158. for w in sentence.strip():
  159. tmp = self.vocab.get(w, self.unk_id)
  160. if len(outids
  161. ) > 0 and tmp == self.unk_id and outids[-1] == self.unk_id:
  162. continue
  163. outids.append(tmp)
  164. if len(outids) > 0 and outids[0] == self.unk_id:
  165. outids = outids[1:]
  166. if len(outids) > 0 and outids[-1] == self.unk_id:
  167. outids = outids[:-1]
  168. return outids
  169. def preprocess(self, input: str) -> Dict[str, Any]:
  170. sentencelt = input.split('\n')
  171. input_ids_lt = [
  172. self._lid_preprocess(sentence) for sentence in sentencelt
  173. if sentence.strip() != ''
  174. ]
  175. # [debug] print info example:
  176. if self.debug:
  177. for sentence, input_ids in zip(sentencelt, input_ids_lt):
  178. print('raw:', sentence)
  179. print(
  180. 'res:', ''.join([
  181. self.vocab_reverse.get(wid, self.unk_id).replace(
  182. '<UNK>', ' ') for wid in input_ids
  183. ]))
  184. maxlen = max([len(ids) for ids in input_ids_lt])
  185. for ids in input_ids_lt:
  186. ids.extend([self.pad_id] * (maxlen - len(ids)))
  187. input_ids = np.array(input_ids_lt)
  188. result = {'input_ids': input_ids}
  189. return result
  190. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  191. with self._session.as_default():
  192. feed_dict = {self.input_ids: input['input_ids']}
  193. sess_outputs = self._session.run(self.output, feed_dict=feed_dict)
  194. return sess_outputs
  195. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  196. output_scores_raw = inputs['output_score']
  197. supported_104_lang = set([
  198. 'af', 'am', 'ar', 'az', 'be', 'bg', 'bn', 'bs', 'ca', 'ce', 'co',
  199. 'cs', 'cy', 'da', 'de', 'el', 'en', 'eo', 'es', 'et', 'eu', 'fa',
  200. 'fi', 'fr', 'fy', 'ga', 'gd', 'gl', 'gu', 'ha', 'haw', 'he', 'hi',
  201. 'hmn', 'hr', 'ht', 'hu', 'hy', 'id', 'ig', 'is', 'it', 'ja', 'jv',
  202. 'ka', 'kk', 'km', 'kn', 'ko', 'ku', 'ky', 'la', 'lo', 'lt', 'lv',
  203. 'mg', 'mi', 'mk', 'ml', 'mn', 'mr', 'ms', 'mt', 'my', 'ne', 'nl',
  204. 'no', 'ny', 'pa', 'pl', 'ps', 'pt', 'ro', 'ru', 'sd', 'si', 'sk',
  205. 'sl', 'sm', 'sn', 'so', 'sq', 'sr', 'st', 'su', 'sv', 'sw', 'ta',
  206. 'te', 'tg', 'th', 'tl', 'tr', 'ug', 'uk', 'ur', 'uz', 'vi', 'xh',
  207. 'yi', 'yo', 'zh', 'zh-tw', 'zu'
  208. ])
  209. labels_scores_lt = []
  210. output_labels = []
  211. for output_score in output_scores_raw:
  212. tmplt = []
  213. for s, l in zip(output_score, self.label.values()):
  214. if l not in supported_104_lang:
  215. continue
  216. tmplt.append((l, s))
  217. tmplt = sorted(tmplt, key=lambda i: i[1], reverse=True)[:3]
  218. if len(tmplt) == 0:
  219. tmplt = [(0, 1.00)]
  220. labels_scores_lt.append(tmplt)
  221. output_labels.append(tmplt[0][0])
  222. output_scores = [[(label, round(score, 2))
  223. for label, score in labels_scores if score > 0.01]
  224. for labels_scores in labels_scores_lt]
  225. result = {
  226. OutputKeys.LABELS: output_labels,
  227. OutputKeys.SCORES: output_scores
  228. }
  229. return result