canmt_translation.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os.path as osp
  3. from typing import Any, Dict
  4. import jieba
  5. import torch
  6. from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer
  7. from subword_nmt import apply_bpe
  8. from modelscope.metainfo import Preprocessors
  9. from modelscope.preprocessors.base import Preprocessor
  10. from modelscope.preprocessors.builder import PREPROCESSORS
  11. from modelscope.utils.config import Config
  12. from modelscope.utils.constant import Fields, ModelFile
  13. from .text_clean import TextClean
  @PREPROCESSORS.register_module(
      Fields.nlp, module_name=Preprocessors.canmt_translation)
  class CanmtTranslationPreprocessor(Preprocessor):
      """The preprocessor used in the ``canmt_translation`` task.

      A raw source sentence is cleaned (Chinese) or punctuation-normalized
      (other languages), word-tokenized (jieba for Chinese, Moses otherwise),
      BPE-segmented with the codes shipped in ``model_dir`` and finally
      mapped to ids with the source dictionary.
      """

      def __init__(self,
                   model_dir: str,
                   max_length: int = None,
                   *args,
                   **kwargs):
          """Load the config, vocabularies, tokenizer and BPE codes.

          Args:
              model_dir (str): model path. It must contain the configuration
                  file, ``dict.src.txt``, ``dict.tgt.txt`` and the source BPE
                  codes file named by ``preprocessor.src_bpe.file`` in the
                  configuration.
              max_length (int, optional): maximum number of source tokens,
                  excluding the eos token. Defaults to 128.
          """
          # Imported here rather than at module level, presumably so fairseq
          # is only needed once this preprocessor is instantiated.
          from fairseq.data import Dictionary

          super().__init__(*args, **kwargs)
          self.cfg = Config.from_file(
              osp.join(model_dir, ModelFile.CONFIGURATION))
          self.vocab_src = Dictionary.load(osp.join(model_dir, 'dict.src.txt'))
          self.vocab_tgt = Dictionary.load(osp.join(model_dir, 'dict.tgt.txt'))
          self.padding_value = self.vocab_src.pad()
          # +1 leaves room for the eos token that __call__ appends.
          self.max_length = max_length + 1 if max_length is not None else 129
          preprocessor_cfg = self.cfg['preprocessor']
          self.src_lang = preprocessor_cfg['src_lang']
          self.tgt_lang = preprocessor_cfg['tgt_lang']
          self.tc = TextClean()
          if self.src_lang == 'zh':
              # The jieba module itself acts as the tokenizer (jieba.cut).
              self.tok = jieba
          else:
              self.punct_normalizer = MosesPunctNormalizer(lang=self.src_lang)
              self.tok = MosesTokenizer(lang=self.src_lang)
          self.src_bpe_path = osp.join(model_dir,
                                       preprocessor_cfg['src_bpe']['file'])
          # apply_bpe.BPE reads the codes in its constructor, so the file can
          # be closed right away instead of leaking the handle. The codes may
          # contain non-ASCII subwords, hence the explicit UTF-8 encoding.
          with open(self.src_bpe_path, encoding='utf-8') as codes:
              self.bpe = apply_bpe.BPE(codes)

      def _tokenize(self, text: str) -> str:
          """Clean/normalize ``text`` and return it as space-joined tokens."""
          if self.src_lang == 'zh':
              text = self.tc.clean(text)
              return ' '.join(self.tok.cut(text))
          # Normalize and tokenize the sentence as a whole: the BPE step in
          # __call__ consumes a single string, not a list of pieces.
          text = self.punct_normalizer.normalize(text)
          return self.tok.tokenize(
              text, return_str=True, aggressive_dash_splits=True)

      def __call__(self, input: str) -> Dict[str, Any]:
          """Convert one raw source sentence into model inputs.

          Args:
              input (str): a single source sentence, e.g.
                  '随着中国经济突飞猛进,建造工业与日俱增'

          Returns:
              Dict[str, Any]: the preprocessed data; every tensor has a
              leading dimension of 1, e.g.
                  {'src_tokens': tensor([[w1, w2, w3, eos]]),
                   'src_lengths': tensor([4]),
                   'prev_src_tokens': tensor([[eos, w1, w2, w3]]),
                   'sources': tensor([[w1, w2, w3, eos]])}
          """
          input_tok = self._tokenize(input)
          # strip()/split() followed by join() collapses any irregular
          # whitespace left by BPE into single spaces.
          input_bpe = self.bpe.process_line(input_tok).strip().split()
          text = ' '.join(input_bpe)
          inputs = self.vocab_src.encode_line(
              text, append_eos=True, add_if_not_exist=False)
          # Rotating right by one moves the trailing eos to the front.
          prev_inputs = torch.roll(inputs, shifts=1)
          lengths = inputs.size(0)
          # NOTE(review): max_len can never exceed `lengths`, so the padding
          # below is always empty and over-long inputs are NOT truncated to
          # self.max_length. Kept as-is to preserve the current outputs;
          # confirm the intended padding/truncation behavior.
          max_len = min(self.max_length, lengths)
          padding = torch.tensor(
              [self.padding_value] * (max_len - lengths), dtype=inputs.dtype)
          sources = torch.unsqueeze(torch.cat([inputs, padding]), dim=0)
          inputs = torch.unsqueeze(torch.cat([padding, inputs]), dim=0)
          prev_inputs = torch.unsqueeze(
              torch.cat([prev_inputs, padding]), dim=0)
          return {
              'src_tokens': inputs,
              'src_lengths': torch.tensor([lengths]),
              'prev_src_tokens': prev_inputs,
              'sources': sources
          }