dialog_modeling_pipeline.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Dict, Union
  3. from modelscope.metainfo import Pipelines
  4. from modelscope.models import Model
  5. from modelscope.models.nlp import SpaceForDialogModeling
  6. from modelscope.outputs import OutputKeys
  7. from modelscope.pipelines.base import Pipeline, Tensor
  8. from modelscope.pipelines.builder import PIPELINES
  9. from modelscope.preprocessors import DialogModelingPreprocessor
  10. from modelscope.utils.constant import Tasks
  11. __all__ = ['DialogModelingPipeline']
  12. @PIPELINES.register_module(
  13. Tasks.task_oriented_conversation, module_name=Pipelines.dialog_modeling)
  14. class DialogModelingPipeline(Pipeline):
  15. def __init__(self,
  16. model: Union[SpaceForDialogModeling, str],
  17. preprocessor: DialogModelingPreprocessor = None,
  18. config_file: str = None,
  19. device: str = 'gpu',
  20. auto_collate=True,
  21. **kwargs):
  22. """Use `model` and `preprocessor` to create a dialog modeling pipeline for dialog response generation
  23. Args:
  24. model (str or SpaceForDialogModeling): Supply either a local model dir or a model id from the model hub,
  25. or a SpaceForDialogModeling instance.
  26. preprocessor (DialogModelingPreprocessor): An optional preprocessor instance.
  27. kwargs (dict, `optional`):
  28. Extra kwargs passed into the preprocessor's constructor.
  29. """
  30. super().__init__(
  31. model=model,
  32. preprocessor=preprocessor,
  33. config_file=config_file,
  34. device=device,
  35. auto_collate=auto_collate)
  36. if preprocessor is None:
  37. self.preprocessor = DialogModelingPreprocessor(
  38. self.model.model_dir, **kwargs)
  39. def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]:
  40. """process the prediction results
  41. Args:
  42. inputs (Dict[str, Any]): _description_
  43. Returns:
  44. Dict[str, str]: the prediction results
  45. """
  46. sys_rsp = self.preprocessor.text_field.tokenizer.convert_ids_to_tokens(
  47. inputs['resp'])
  48. assert len(sys_rsp) > 2
  49. sys_rsp = sys_rsp[1:len(sys_rsp) - 1]
  50. inputs[OutputKeys.OUTPUT] = sys_rsp
  51. return inputs