dialog_modeling.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from typing import Dict
  4. from modelscope.metainfo import Models
  5. from modelscope.models import TorchModel
  6. from modelscope.models.base import Tensor
  7. from modelscope.models.builder import MODELS
  8. from modelscope.models.nlp.space import SpaceGenerator, SpaceModelBase
  9. from modelscope.preprocessors.nlp import MultiWOZBPETextField
  10. from modelscope.utils.config import Config
  11. from modelscope.utils.constant import ModelFile, Tasks
  12. __all__ = ['SpaceForDialogModeling']
  @MODELS.register_module(
      Tasks.task_oriented_conversation, module_name=Models.space_modeling)
  class SpaceForDialogModeling(TorchModel):
      """Dialog-modeling model registered for the task-oriented conversation task.

      The instance wires together a text field (reader), a `SpaceGenerator`,
      a `SpaceModelBase` and a `MultiWOZTrainer`; `forward` delegates to the
      trainer.
      """

      def __init__(self, model_dir: str, *args, **kwargs):
          """initialize the text generation model from the `model_dir` path.

          Args:
              model_dir (`str`):
                  The model path.
              text_field (`BPETextField`, *optional*, defaults to `MultiWOZBPETextField`):
                  The text field. Built from `config` and `model_dir` only
                  when not supplied (or supplied as `None`).
              config (`Config`, *optional*, defaults to config in model hub):
                  The config. Read from the configuration file inside
                  `model_dir` only when not supplied (or supplied as `None`).
              device (`str`, *optional*):
                  GPU is used only when this is absent or equal to `'gpu'`
                  and CUDA is available.
          """
          super().__init__(model_dir, *args, **kwargs)
          # Imported locally, as in the original code, rather than at module
          # import time.
          import torch
          from modelscope.trainers.nlp.space.trainer.gen_trainer import MultiWOZTrainer

          self.model_dir = model_dir

          # Build the default config lazily: passing it as the default of
          # `kwargs.pop` would read the file even when the caller already
          # supplied a config.
          config = kwargs.pop('config', None)
          if config is None:
              config = Config.from_file(
                  os.path.join(self.model_dir, ModelFile.CONFIGURATION))
          self.config = config

          gpu_requested = kwargs.get('device', 'gpu') == 'gpu'
          self.config.use_gpu = gpu_requested and torch.cuda.is_available()

          # Same reasoning as for the config: only construct the default text
          # field when none was handed in.
          text_field = kwargs.pop('text_field', None)
          if text_field is None:
              text_field = MultiWOZBPETextField(
                  config=self.config, model_dir=self.model_dir)
          self.text_field = text_field

          self.generator = SpaceGenerator.create(
              self.config, reader=self.text_field)
          self.model = SpaceModelBase.create(
              model_dir=model_dir,
              config=self.config,
              reader=self.text_field,
              generator=self.generator)

          def to_tensor(array):
              """
              numpy array -> tensor
              """
              tensor = torch.tensor(array)
              # `use_gpu` is looked up at call time, not captured here.
              return tensor.cuda() if self.config.use_gpu else tensor

          self.trainer = MultiWOZTrainer(
              model=self.model,
              to_tensor=to_tensor,
              config=self.config,
              reader=self.text_field,
              evaluator=None)
          self.trainer.load()

      def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
          """return the result by the model

          Args:
              input (Dict[str, Tensor]): the preprocessed data. The keys
                  `first_turn`, `batch`, `prompt_id`, `labels` and `history`
                  are required; `history` is passed to the trainer as the
                  previous turn (`old_pv_turn`).

          Returns:
              Dict[str, Tensor]: results, i.e. whatever the trainer's
              `forward` returns for this turn.

          Example of a result:
              {
                  'labels': array([1,192,321,12]),  # label
                  'resp': array([293,1023,123,1123]),  # vocab label for response
                  'bspn': array([123,321,2,24,1 ]),
                  'aspn': array([47,8345,32,29,1983]),
                  'db': array([19, 24, 20]),
              }

          Examples:
              >>> from modelscope.hub.snapshot_download import snapshot_download
              >>> from modelscope.models.nlp import SpaceForDialogModeling
              >>> from modelscope.preprocessors import DialogModelingPreprocessor
              >>> cache_path = snapshot_download('damo/nlp_space_dialog-modeling')
              >>> preprocessor = DialogModelingPreprocessor(model_dir=cache_path)
              >>> model = SpaceForDialogModeling(model_dir=cache_path,
                      text_field=preprocessor.text_field,
                      config=preprocessor.config)
              >>> print(model(preprocessor({
                      'user_input': 'i would like a taxi from saint john \'s college to pizza hut fen ditton .',
                      'history': {}
                  })))
          """
          # Missing keys raise KeyError, exactly as direct indexing did before.
          return self.trainer.forward(
              first_turn=input['first_turn'],
              batch=input['batch'],
              prompt_id=input['prompt_id'],
              labels=input['labels'],
              old_pv_turn=input['history'])