# File: translation_evaluation_pipeline.py
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os.path as osp
  3. from enum import Enum
  4. from typing import Any, Dict, List, Optional, Union
  5. import numpy as np
  6. import torch
  7. from modelscope.metainfo import Pipelines
  8. from modelscope.models.base import Model
  9. from modelscope.models.nlp.unite.configuration import InputFormat
  10. from modelscope.outputs import OutputKeys
  11. from modelscope.pipelines.base import InputModel, Pipeline
  12. from modelscope.pipelines.builder import PIPELINES
  13. from modelscope.preprocessors import Preprocessor
  14. from modelscope.utils.config import Config
  15. from modelscope.utils.constant import ModelFile, Tasks
  16. from modelscope.utils.logger import get_logger
  17. logger = get_logger()
  18. __all__ = ['TranslationEvaluationPipeline']
  19. @PIPELINES.register_module(
  20. Tasks.translation_evaluation, module_name=Pipelines.translation_evaluation)
  21. class TranslationEvaluationPipeline(Pipeline):
  22. def __init__(self,
  23. model: InputModel,
  24. preprocessor: Optional[Preprocessor] = None,
  25. input_format: InputFormat = InputFormat.SRC_REF,
  26. device: str = 'gpu',
  27. **kwargs):
  28. r"""Build a translation evaluation pipeline with a model dir or a model id in the model hub.
  29. Args:
  30. model: A Model instance.
  31. preprocessor: The preprocessor for this pipeline.
  32. input_format: Input format, choosing one from `"InputFormat.SRC_REF"`,
  33. `"InputFormat.SRC"`, `"InputFormat.REF"`. Aside from hypothesis, the
  34. source/reference/source+reference can be presented during evaluation.
  35. device: Used device for this pipeline.
  36. """
  37. super().__init__(
  38. model=model,
  39. preprocessor=preprocessor,
  40. compile=kwargs.pop('compile', False),
  41. compile_options=kwargs.pop('compile_options', {}))
  42. self.input_format = input_format
  43. self.checking_input_format()
  44. assert isinstance(self.model, Model), \
  45. f'please check whether model config exists in {ModelFile.CONFIGURATION}'
  46. self.model.load_checkpoint(
  47. osp.join(self.model.model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
  48. device=self.device,
  49. plm_only=False)
  50. self.model.eval()
  51. return
  52. def checking_input_format(self):
  53. if self.input_format == InputFormat.SRC:
  54. logger.info('Evaluation mode: source-only')
  55. elif self.input_format == InputFormat.REF:
  56. logger.info('Evaluation mode: reference-only')
  57. elif self.input_format == InputFormat.SRC_REF:
  58. logger.info('Evaluation mode: source-reference-combined')
  59. else:
  60. raise ValueError('Evaluation mode should be one choice among'
  61. '\'InputFormat.SRC\', \'InputFormat.REF\', and'
  62. '\'InputFormat.SRC_REF\'.')
  63. def change_input_format(self,
  64. input_format: InputFormat = InputFormat.SRC_REF):
  65. logger.info('Changing the evaluation mode.')
  66. self.input_format = input_format
  67. self.checking_input_format()
  68. self.preprocessor.change_input_format(input_format)
  69. return
  70. def __call__(self, input_dict: Dict[str, Union[str, List[str]]], **kwargs):
  71. r"""Implementation of __call__ function.
  72. Args:
  73. input: The formatted dict containing the inputted sentences.
  74. An example of the formatted dict:
  75. ```
  76. input = {
  77. 'hyp': [
  78. 'This is a sentence.',
  79. 'This is another sentence.',
  80. ],
  81. 'src': [
  82. '这是个句子。',
  83. '这是另一个句子。',
  84. ],
  85. 'ref': [
  86. 'It is a sentence.',
  87. 'It is another sentence.',
  88. ]
  89. }
  90. ```
  91. """
  92. return super().__call__(input=input_dict, **kwargs)
  93. def forward(
  94. self, input_dict: Dict[str,
  95. torch.Tensor]) -> Dict[str, torch.Tensor]:
  96. return self.model(**input_dict)
  97. def postprocess(self, output: torch.Tensor) -> Dict[str, Any]:
  98. return output