interactive_translation_pipeline.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os.path as osp
  3. from typing import Any, Dict
  4. import jieba
  5. import numpy as np
  6. import tensorflow as tf
  7. from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer
  8. from subword_nmt import apply_bpe
  9. from modelscope.metainfo import Pipelines
  10. from modelscope.models.base import Model
  11. from modelscope.outputs import OutputKeys
  12. from modelscope.pipelines.base import Pipeline
  13. from modelscope.pipelines.builder import PIPELINES
  14. from modelscope.pipelines.nlp.translation_pipeline import TranslationPipeline
  15. from modelscope.utils.config import Config
  16. from modelscope.utils.constant import ModelFile, Tasks
  17. from modelscope.utils.logger import get_logger
  18. if tf.__version__ >= '2.0':
  19. tf = tf.compat.v1
  20. tf.disable_eager_execution()
  21. logger = get_logger()
  22. __all__ = ['InteractiveTranslationPipeline']
  23. @PIPELINES.register_module(
  24. Tasks.translation, module_name=Pipelines.interactive_translation)
  25. class InteractiveTranslationPipeline(TranslationPipeline):
  26. def __init__(self, model: Model, **kwargs):
  27. """Build a interactive translation pipeline with a model dir or a model id in the model hub.
  28. Args:
  29. model (`str` or `Model` or module instance): A model instance or a model local dir
  30. or a model id in the model hub.
  31. Example:
  32. >>> from modelscope.pipelines import pipeline
  33. >>> pipeline_ins = pipeline(task=Tasks.translation,
  34. model='damo/nlp_imt_translation_zh2en')
  35. >>> input_sequence = 'Elon Musk, co-founder and chief executive officer of Tesla Motors.'
  36. >>> input_prefix = "特斯拉汽车公司"
  37. >>> print(pipeline_ins(input_sequence + "<PREFIX_SPLIT>" + input_prefix))
  38. """
  39. super().__init__(model=model, **kwargs)
  40. model = self.model.model_dir
  41. tf.reset_default_graph()
  42. model_path = osp.join(
  43. osp.join(model, ModelFile.TF_CHECKPOINT_FOLDER), 'ckpt-0')
  44. self._trg_vocab = dict([
  45. (w.strip(), i) for i, w in enumerate(open(self._trg_vocab_path))
  46. ])
  47. self._len_tgt_vocab = len(self._trg_rvocab)
  48. self.input_wids = tf.placeholder(
  49. dtype=tf.int64, shape=[None, None], name='input_wids')
  50. self.prefix_wids = tf.placeholder(
  51. dtype=tf.int64, shape=[None, None], name='prefix_wids')
  52. self.prefix_hit = tf.placeholder(
  53. dtype=tf.bool, shape=[None, None], name='prefix_hit')
  54. self.output = {}
  55. # preprocess
  56. if self._tgt_lang == 'zh':
  57. self._tgt_tok = jieba
  58. else:
  59. self._tgt_punct_normalizer = MosesPunctNormalizer(
  60. lang=self._tgt_lang)
  61. self._tgt_tok = MosesTokenizer(lang=self._tgt_lang)
  62. # model
  63. output = self.model(self.input_wids, None, self.prefix_wids,
  64. self.prefix_hit)
  65. self.output.update(output)
  66. tf_config = tf.ConfigProto(allow_soft_placement=True)
  67. tf_config.gpu_options.allow_growth = True
  68. self._session = tf.Session(config=tf_config)
  69. with self._session.as_default() as sess:
  70. logger.info(f'loading model from {model_path}')
  71. # load model
  72. model_loader = tf.train.Saver(tf.global_variables())
  73. model_loader.restore(sess, model_path)
  74. def preprocess(self, input: str) -> Dict[str, Any]:
  75. input_src, prefix = input.split('<PREFIX_SPLIT>', 1)
  76. if self._src_lang == 'zh':
  77. input_tok = self._tok.cut(input_src)
  78. input_tok = ' '.join(list(input_tok))
  79. else:
  80. input_src = self._punct_normalizer.normalize(input_src)
  81. input_tok = self._tok.tokenize(
  82. input_src, return_str=True, aggressive_dash_splits=True)
  83. if self._tgt_lang == 'zh':
  84. prefix = self._tgt_tok.lcut(prefix)
  85. prefix_tok = ' '.join(list(prefix)[:-1])
  86. else:
  87. prefix = self._tgt_punct_normalizer.normalize(prefix)
  88. prefix = self._tgt_tok.tokenize(
  89. prefix, return_str=True, aggressive_dash_splits=True).split()
  90. prefix_tok = ' '.join(prefix[:-1])
  91. if len(list(prefix)) > 0:
  92. subword = list(prefix)[-1]
  93. else:
  94. subword = ''
  95. input_bpe = self._bpe.process_line(input_tok)
  96. prefix_bpe = self._bpe.process_line(prefix_tok)
  97. input_ids = np.array([[
  98. self._src_vocab[w]
  99. if w in self._src_vocab else self.cfg['model']['src_vocab_size']
  100. for w in input_bpe.strip().split()
  101. ]])
  102. prefix_ids = np.array([[
  103. self._trg_vocab[w]
  104. if w in self._trg_vocab else self.cfg['model']['trg_vocab_size']
  105. for w in prefix_bpe.strip().split()
  106. ]])
  107. prefix_hit = [[0] * (self._len_tgt_vocab + 1)]
  108. if subword != '':
  109. hit_state = False
  110. for i, w in self._trg_rvocab.items():
  111. if w.startswith(subword):
  112. prefix_hit[0][i] = 1
  113. hit_state = True
  114. if hit_state is False:
  115. prefix_hit = [[1] * (self._len_tgt_vocab + 1)]
  116. result = {
  117. 'input_ids': input_ids,
  118. 'prefix_ids': prefix_ids,
  119. 'prefix_hit': np.array(prefix_hit)
  120. }
  121. return result
  122. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  123. with self._session.as_default():
  124. feed_dict = {
  125. self.input_wids: input['input_ids'],
  126. self.prefix_wids: input['prefix_ids'],
  127. self.prefix_hit: input['prefix_hit']
  128. }
  129. sess_outputs = self._session.run(self.output, feed_dict=feed_dict)
  130. return sess_outputs
  131. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  132. output_seqs = inputs['output_seqs'][0]
  133. wids = list(output_seqs[0]) + [0]
  134. wids = wids[:wids.index(0)]
  135. translation_out = ' '.join([
  136. self._trg_rvocab[wid] if wid in self._trg_rvocab else '<unk>'
  137. for wid in wids
  138. ]).replace('@@ ', '').replace('@@', '')
  139. translation_out = self._detok.detokenize(translation_out.split())
  140. result = {OutputKeys.TRANSLATION: translation_out}
  141. return result