distributed_plug_pipeline.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict
  3. import torch
  4. from modelscope.metainfo import Pipelines
  5. from modelscope.models.nlp.plug import DistributedPlug
  6. from modelscope.pipelines.base import DistributedPipeline
  7. from modelscope.pipelines.builder import PIPELINES
  8. from modelscope.preprocessors import TextGenerationTransformersPreprocessor
  9. from modelscope.utils.constant import Tasks
  @PIPELINES.register_module(
      Tasks.text_generation, module_name=Pipelines.plug_generation)
  class DistributedPlugPipeline(DistributedPipeline):
      """Text-generation pipeline backed by the distributed PLUG model.

      Registered in ``PIPELINES`` for ``Tasks.text_generation`` under the
      module name ``Pipelines.plug_generation``.
      """

      # Class-level slot for the model: assigned by `_instantiate_one` and
      # read back by `_forward_one` (both are classmethods).
      model = None

      def __init__(self,
                   model,
                   preprocessor=None,
                   first_sequence='sentence',
                   sequence_length=512,
                   **kwargs):
          """Create a plug pipeline instance.

          Args:
              model: The model_id of plug (damo/nlp_plug_text-generation_27B).
                  The default path to damo/nlp_plug_text-generation_27B can
                  be obtained by the function
                  get_cache_dir("damo/nlp_plug_text-generation_27B"); the
                  model should be downloaded to this path before calling
                  this class by model_id.
                  The model can be downloaded from the link on
                  https://modelscope.cn/models/damo/nlp_plug_text-generation_27B/summary.
                  After downloading, you should have a plug model structure
                  like this:

                      /your/path/to/damo/nlp_plug_text-generation_27B
                          |_ config.json
                          |_ configuration.json
                          |_ ds_zero-offload_10B_config.json
                          |_ vocab.txt
                          |_ model <-- an empty directory

                  Model binaries shall be downloaded separately to populate
                  the model directory, so that the model directory would
                  contain the following binaries:

                      |_ model
                          |_ mp_rank_00_model_states.pt
                          |_ mp_rank_01_model_states.pt
                          |_ mp_rank_02_model_states.pt
                          |_ mp_rank_03_model_states.pt
                          |_ mp_rank_04_model_states.pt
                          |_ mp_rank_05_model_states.pt
                          |_ mp_rank_06_model_states.pt
                          |_ mp_rank_07_model_states.pt
              preprocessor: The optional preprocessor. If not passed in, a
                  TextGenerationTransformersPreprocessor is built from
                  `model` and used as the default.
              first_sequence: Forwarded as `first_sequence` to the default
                  preprocessor; unused when `preprocessor` is supplied.
              sequence_length: Forwarded as `sequence_length` to the default
                  preprocessor; unused when `preprocessor` is supplied.
              kwargs (dict, `optional`): Extra kwargs passed into the default
                  preprocessor's constructor and into the base pipeline
                  constructor.
          """
          if preprocessor is None:
              default_args = dict(
                  first_sequence=first_sequence,
                  sequence_length=sequence_length,
                  **kwargs)
              preprocessor = TextGenerationTransformersPreprocessor(
                  model, **default_args)
          super().__init__(model, preprocessor=preprocessor, **kwargs)
          # Remember the CLS token id; `forward` uses it to fill the decoder
          # input tensor.
          tokenizer = preprocessor.nlp_tokenizer.tokenizer
          self.cls_token_id = tokenizer.cls_token_id

      @classmethod
      def _forward_one(cls, inputs: Dict[str, Any]) -> Dict[str, Any]:
          """Call ``cls.model.generate`` with gradient tracking disabled.

          Args:
              inputs: Mapping with an 'inputs' entry, passed positionally to
                  ``generate``, and a 'forward_params' entry, expanded as
                  keyword arguments.

          Returns:
              Whatever ``cls.model.generate`` returns.
          """
          with torch.no_grad():
              generate = cls.model.generate
              return generate(inputs['inputs'], **inputs['forward_params'])

      def _sanitize_parameters(self, **pipeline_parameters):
          """Hand every call-time parameter to the middle slot of the result.

          Returns:
              A 3-tuple of an empty dict, `pipeline_parameters`, and another
              empty dict (presumably the preprocess / forward / postprocess
              parameter sets -- confirm against the base pipeline).
          """
          first_params, last_params = {}, {}
          return first_params, pipeline_parameters, last_params

      def forward(self, inputs: Dict[str, Any],
                  **forward_params) -> Dict[str, Any]:
          """Attach decoder start ids to `inputs`, then run the base forward.

          Stores under 'dec_input_ids' a long tensor with one row per entry
          along the first dimension of ``inputs['input_ids']`` and a single
          column, filled with the CLS token id. `inputs` is modified in
          place.

          Args:
              inputs: Mapping holding at least an 'input_ids' tensor.
              forward_params: Passed through to the base class ``forward``.

          Returns:
              The result of the base class ``forward``.
          """
          num_rows = inputs['input_ids'].shape[0]
          inputs['dec_input_ids'] = torch.full(
              [num_rows, 1], self.cls_token_id, dtype=torch.long)
          return super().forward(inputs, **forward_params)

      @classmethod
      def _instantiate_one(cls, rank, model_dir, **kwargs):
          """Build the ``DistributedPlug`` for `rank` and store it on the class.

          Args:
              rank: Passed to ``DistributedPlug`` as its second argument.
              model_dir: Passed to ``DistributedPlug`` as its first argument.
              kwargs: Extra keyword arguments for ``DistributedPlug``.
          """
          plug = DistributedPlug(model_dir, rank, **kwargs)
          # Publish on the class before switching to eval mode, matching the
          # assignment-then-eval order callers may rely on.
          cls.model = plug
          plug.eval()

      def postprocess(self, inputs: Dict[str, Any],
                      **postprocess_params) -> Dict[str, str]:
          """Turn the generated token ids into the final text.

          Args:
              inputs (Dict[str, Any]): Mapping whose 'generate_context' entry
                  holds the ids handed to the tokenizer's
                  ``convert_ids_to_tokens``.
              postprocess_params: Accepted but not used here.

          Returns:
              Dict[str, str]: ``{OutputKeys.TEXT: text}``, where `text` is the
              tokens joined without a separator, with every '[UNK]' replaced
              by '“' and every '##' removed.
          """
          from modelscope.outputs import OutputKeys
          token_ids = inputs['generate_context']
          tokenizer = self.preprocessor.nlp_tokenizer.tokenizer
          tokens = tokenizer.convert_ids_to_tokens(token_ids)
          text = ''.join(tokens)
          text = text.replace('[UNK]', '“').replace('##', '')
          return {OutputKeys.TEXT: text}