translation_pipeline.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os.path as osp
  3. from typing import Any, Dict
  4. import jieba
  5. import numpy as np
  6. import tensorflow as tf
  7. from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer
  8. from subword_nmt import apply_bpe
  9. from modelscope.metainfo import Pipelines
  10. from modelscope.models.base import Model
  11. from modelscope.outputs import OutputKeys
  12. from modelscope.pipelines.base import Pipeline
  13. from modelscope.pipelines.builder import PIPELINES
  14. from modelscope.utils.config import Config
  15. from modelscope.utils.constant import ModelFile, Tasks
  16. from modelscope.utils.logger import get_logger
  17. if tf.__version__ >= '2.0':
  18. tf = tf.compat.v1
  19. tf.disable_eager_execution()
  20. logger = get_logger()
  21. __all__ = ['TranslationPipeline']
  22. @PIPELINES.register_module(
  23. Tasks.translation, module_name=Pipelines.csanmt_translation)
  24. class TranslationPipeline(Pipeline):
  25. def __init__(self, model: Model, **kwargs):
  26. """Build a translation pipeline with a model dir or a model id in the model hub.
  27. Args:
  28. model: A Model instance.
  29. """
  30. super().__init__(model=model, **kwargs)
  31. assert isinstance(self.model, Model), \
  32. f'please check whether model config exists in {ModelFile.CONFIGURATION}'
  33. model = self.model.model_dir
  34. tf.reset_default_graph()
  35. self.model_path = osp.join(
  36. osp.join(model, ModelFile.TF_CHECKPOINT_FOLDER), 'ckpt-0')
  37. self.cfg = Config.from_file(osp.join(model, ModelFile.CONFIGURATION))
  38. self._src_vocab_path = osp.join(
  39. model, self.cfg['dataset']['src_vocab']['file'])
  40. self._src_vocab = dict([(w.strip(), i) for i, w in enumerate(
  41. open(self._src_vocab_path, encoding='utf-8'))])
  42. self._trg_vocab_path = osp.join(
  43. model, self.cfg['dataset']['trg_vocab']['file'])
  44. self._trg_rvocab = dict([(i, w.strip()) for i, w in enumerate(
  45. open(self._trg_vocab_path, encoding='utf-8'))])
  46. tf_config = tf.ConfigProto(allow_soft_placement=True)
  47. tf_config.gpu_options.allow_growth = True
  48. self._session = tf.Session(config=tf_config)
  49. self.input_wids = tf.placeholder(
  50. dtype=tf.int64, shape=[None, None], name='input_wids')
  51. self.output = {}
  52. # preprocess
  53. self._src_lang = self.cfg['preprocessor']['src_lang']
  54. self._tgt_lang = self.cfg['preprocessor']['tgt_lang']
  55. self._src_bpe_path = osp.join(
  56. model, self.cfg['preprocessor']['src_bpe']['file'])
  57. if self._src_lang == 'zh':
  58. self._tok = jieba
  59. else:
  60. self._punct_normalizer = MosesPunctNormalizer(lang=self._src_lang)
  61. self._tok = MosesTokenizer(lang=self._src_lang)
  62. self._detok = MosesDetokenizer(lang=self._tgt_lang)
  63. self._bpe = apply_bpe.BPE(open(self._src_bpe_path, encoding='utf-8'))
  64. # model
  65. output = self.model(self.input_wids)
  66. self.output.update(output)
  67. with self._session.as_default() as sess:
  68. logger.info(f'loading model from {self.model_path}')
  69. # load model
  70. self.model_loader = tf.train.Saver(tf.global_variables())
  71. self.model_loader.restore(sess, self.model_path)
  72. def preprocess(self, input: str) -> Dict[str, Any]:
  73. input = input.split('<SENT_SPLIT>')
  74. if self._src_lang == 'zh':
  75. input_tok = [self._tok.cut(item) for item in input]
  76. input_tok = [' '.join(list(item)) for item in input_tok]
  77. else:
  78. input = [self._punct_normalizer.normalize(item) for item in input]
  79. aggressive_dash_splits = True
  80. if (self._src_lang in ['es', 'fr'] and self._tgt_lang == 'en') or (
  81. self._src_lang == 'en' and self._tgt_lang in ['es', 'fr']):
  82. aggressive_dash_splits = False
  83. input_tok = [
  84. self._tok.tokenize(
  85. item,
  86. return_str=True,
  87. aggressive_dash_splits=aggressive_dash_splits)
  88. for item in input
  89. ]
  90. input_bpe = [
  91. self._bpe.process_line(item).strip().split() for item in input_tok
  92. ]
  93. MAX_LENGTH = max([len(item) for item in input_bpe])
  94. input_ids = np.array([[
  95. self._src_vocab[w] if w in self._src_vocab else
  96. self.cfg['model']['src_vocab_size'] - 1 for w in item
  97. ] + [0] * (MAX_LENGTH - len(item)) for item in input_bpe])
  98. result = {'input_ids': input_ids}
  99. return result
  100. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  101. with self._session.as_default():
  102. feed_dict = {self.input_wids: input['input_ids']}
  103. sess_outputs = self._session.run(self.output, feed_dict=feed_dict)
  104. return sess_outputs
  105. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  106. x, y, z = inputs['output_seqs'].shape
  107. translation_out = []
  108. for i in range(x):
  109. output_seqs = inputs['output_seqs'][i]
  110. wids = list(output_seqs[0]) + [0]
  111. wids = wids[:wids.index(0)]
  112. translation = ' '.join([
  113. self._trg_rvocab[wid] if wid in self._trg_rvocab else '<unk>'
  114. for wid in wids
  115. ]).replace('@@ ', '').replace('@@', '')
  116. translation_out.append(self._detok.detokenize(translation.split()))
  117. translation_out = '<SENT_SPLIT>'.join(translation_out)
  118. result = {OutputKeys.TRANSLATION: translation_out}
  119. return result