canmt_translation.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os.path as osp
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy
import torch
import torch.nn as nn
from torch import Tensor

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
  14. __all__ = ['CanmtForTranslation']
  15. @MODELS.register_module(
  16. Tasks.competency_aware_translation, module_name=Models.canmt)
  17. class CanmtForTranslation(TorchModel):
  18. def __init__(self, model_dir, **args):
  19. """
  20. CanmtForTranslation implements a Competency-Aware Neural Machine Translaton,
  21. which has both translation and self-estimation abilities.
  22. For more details, please refer to https://aclanthology.org/2022.emnlp-main.330.pdf
  23. """
  24. super().__init__(model_dir=model_dir, **args)
  25. self.args = args
  26. cfg_file = osp.join(model_dir, ModelFile.CONFIGURATION)
  27. self.cfg = Config.from_file(cfg_file)
  28. from fairseq.data import Dictionary
  29. self.vocab_src = Dictionary.load(osp.join(model_dir, 'dict.src.txt'))
  30. self.vocab_tgt = Dictionary.load(osp.join(model_dir, 'dict.tgt.txt'))
  31. self.model = self.build_model(model_dir)
  32. self.generator = self.build_generator(self.model, self.vocab_tgt,
  33. self.cfg['decode'])
  34. def build_model(self, model_dir):
  35. from .canmt_model import CanmtModel
  36. state = self.load_checkpoint(
  37. osp.join(model_dir, ModelFile.TORCH_MODEL_FILE), 'cpu')
  38. cfg = state['cfg']
  39. model = CanmtModel.build_model(cfg['model'], self)
  40. model.load_state_dict(state['model'], model_cfg=cfg['model'])
  41. return model
  42. def build_generator(cls, model, vocab_tgt, args):
  43. from .sequence_generator import SequenceGenerator
  44. return SequenceGenerator(
  45. model,
  46. vocab_tgt,
  47. beam_size=args['beam'],
  48. len_penalty=args['lenpen'])
  49. def load_checkpoint(self, path: str, device: torch.device):
  50. state_dict = torch.load(path, map_location=device)
  51. self.load_state_dict(state_dict, strict=False)
  52. return state_dict
  53. def forward(self, input: Dict[str, Dict]):
  54. """return the result by the model
  55. Args:
  56. input (Dict[str, Tensor]): the preprocessed data which contains following:
  57. - src_tokens: tensor with shape (2478,242,24,4),
  58. - src_lengths: tensor with shape (4)
  59. Returns:
  60. Dict[str, Tensor]: results which contains following:
  61. - predictions: tokens need to be decode by tokenizer with shape [1377, 4959, 2785, 6392...]
  62. """
  63. input = {'net_input': input}
  64. return self.generator.generate(input)