| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os.path as osp
- from typing import Any, Dict
- import jieba
- import torch
- from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer
- from subword_nmt import apply_bpe
- from modelscope.metainfo import Preprocessors
- from modelscope.preprocessors.base import Preprocessor
- from modelscope.preprocessors.builder import PREPROCESSORS
- from modelscope.utils.config import Config
- from modelscope.utils.constant import Fields, ModelFile
- from .text_clean import TextClean
@PREPROCESSORS.register_module(
    Fields.nlp, module_name=Preprocessors.canmt_translation)
class CanmtTranslationPreprocessor(Preprocessor):
    """Preprocessor for the CanMT translation task.

    Cleans/normalizes a raw source sentence, tokenizes it (jieba for Chinese,
    Moses for every other source language), applies BPE and maps the result
    to ids of the source vocabulary.
    """

    def __init__(self,
                 model_dir: str,
                 max_length: int = None,
                 *args,
                 **kwargs):
        """Load vocabularies, tokenizer and BPE codes from ``model_dir``.

        Args:
            model_dir (str): directory holding the model configuration file,
                ``dict.src.txt``, ``dict.tgt.txt`` and the BPE codes file
                named by ``cfg['preprocessor']['src_bpe']['file']``.
            max_length (int, optional): maximum number of tokens excluding
                the eos token. When ``None`` the stored limit is 129.
        """
        # Local import: fairseq is only needed once an instance is created.
        from fairseq.data import Dictionary

        super().__init__(*args, **kwargs)
        self.cfg = Config.from_file(
            osp.join(model_dir, ModelFile.CONFIGURATION))
        self.vocab_src = Dictionary.load(osp.join(model_dir, 'dict.src.txt'))
        self.vocab_tgt = Dictionary.load(osp.join(model_dir, 'dict.tgt.txt'))
        self.padding_value = self.vocab_src.pad()
        # The extra 1 accounts for the eos token appended during encoding.
        self.max_length = max_length + 1 if max_length is not None else 129

        preprocessor_cfg = self.cfg['preprocessor']
        self.src_lang = preprocessor_cfg['src_lang']
        self.tgt_lang = preprocessor_cfg['tgt_lang']

        self.tc = TextClean()
        if self.src_lang == 'zh':
            # The jieba module itself is used as the tokenizer (``cut``).
            self.tok = jieba
        else:
            self.punct_normalizer = MosesPunctNormalizer(lang=self.src_lang)
            self.tok = MosesTokenizer(lang=self.src_lang)

        self.src_bpe_path = osp.join(model_dir,
                                     preprocessor_cfg['src_bpe']['file'])
        # BPE reads the codes while it is being constructed, so the file
        # handle can be released as soon as the object exists.
        with open(self.src_bpe_path, encoding='utf-8') as codes:
            self.bpe = apply_bpe.BPE(codes)

    def _tokenize(self, text: str) -> str:
        """Return ``text`` as a single space-separated string of tokens."""
        if self.src_lang == 'zh':
            cleaned = self.tc.clean(text)
            return ' '.join(self.tok.cut(cleaned))
        # Normalize and tokenize the sentence as a whole; ``return_str=True``
        # yields the plain string that ``BPE.process_line`` expects.
        normalized = self.punct_normalizer.normalize(text)
        return self.tok.tokenize(
            normalized, return_str=True, aggressive_dash_splits=True)

    def __call__(self, input: str) -> Dict[str, Any]:
        """Convert a raw sentence into the tensors consumed by the model.

        Args:
            input (str): a sentence in the source language.
                Example:
                    '随着中国经济突飞猛近,建造工业与日俱增'

        Returns:
            Dict[str, Any]: the preprocessed data with the keys
                ``src_tokens``: token ids (eos appended), with a leading
                    batch dimension of size 1,
                ``src_lengths``: 1-element tensor with the token count,
                ``prev_src_tokens``: ``src_tokens`` rolled right by one
                    position, with a leading batch dimension of size 1,
                ``sources``: token ids with a leading batch dimension of
                    size 1.
        """
        input_tok = self._tokenize(input)
        input_bpe = self.bpe.process_line(input_tok).strip().split()
        text = ' '.join(input_bpe)
        inputs = self.vocab_src.encode_line(
            text, append_eos=True, add_if_not_exist=False)
        # Rolling by one moves the trailing eos id to the first position.
        prev_inputs = torch.roll(inputs, shifts=1)
        lengths = inputs.size(0)

        # NOTE(review): ``min(self.max_length, lengths) - lengths`` is never
        # positive, so ``padding`` is always empty and over-long inputs are
        # not truncated. Kept as-is to preserve the tensors fed to the model;
        # confirm whether padding/truncation to ``max_length`` was intended.
        max_len = min(self.max_length, lengths)
        padding = torch.tensor(
            [self.padding_value] * (max_len - lengths), dtype=inputs.dtype)

        sources = torch.unsqueeze(torch.cat([inputs, padding]), dim=0)
        inputs = torch.unsqueeze(torch.cat([padding, inputs]), dim=0)
        prev_inputs = torch.unsqueeze(torch.cat([prev_inputs, padding]), dim=0)
        lengths = torch.tensor([lengths])
        out = {
            'src_tokens': inputs,
            'src_lengths': lengths,
            'prev_src_tokens': prev_inputs,
            'sources': sources
        }
        return out
|