text_classification.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict
  3. import torch
  4. from modelscope.utils.constant import ModeKeys
  5. from .base import OfaBasePreprocessor
  6. class OfaTextClassificationPreprocessor(OfaBasePreprocessor):
  7. r"""
  8. OFA preprocessor for text classification tasks.
  9. """
  10. def __init__(self,
  11. cfg,
  12. model_dir,
  13. mode=ModeKeys.INFERENCE,
  14. *args,
  15. **kwargs):
  16. """preprocess the data
  17. Args:
  18. cfg(modelscope.utils.config.ConfigDict) : model config
  19. model_dir (str): model path,
  20. mode: preprocessor mode (model mode)
  21. """
  22. super(OfaTextClassificationPreprocessor,
  23. self).__init__(cfg, model_dir, mode, *args, **kwargs)
  24. def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
  25. if self.mode == ModeKeys.TRAIN:
  26. return self._build_train_sample(data)
  27. else:
  28. return self._build_infer_sample(data)
  29. def _build_instruction(self, data):
  30. r"""
  31. Building text classification task's instruction.
  32. The `data` should contains key `text` and `text2`, and the final instruction
  33. is like ` can text1 " {} " imply text2 " {} "?`, the first `{}` refer to
  34. the value of `text` and the latter refer to `text2`
  35. step 1. Preprocess for input text `text` and `text2` in `data`.
  36. - Do lower, stripe and restrict the maximum length as `max_src_length`.
  37. step 2. Using instruction template to generate the final instruction.
  38. step 3. Tokenize the instruction as result.
  39. """
  40. text1 = ' '.join(
  41. data['text'].lower().strip().split()[:self.max_src_length])
  42. text2 = ' '.join(
  43. data['text2'].lower().strip().split()[:self.max_src_length])
  44. prompt = ' can text1 " {} " imply text2 " {} "?'
  45. text = prompt.format(text1, text2)
  46. instruction_itm = self.tokenize_text(text)
  47. return instruction_itm
  48. def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
  49. r"""
  50. Building training samples.
  51. step 1. Building instruction for text classification using `_build_instruction`.
  52. step 2. If the `label` is not text, transfer it to text using `label2ans`.
  53. step 3. Tokenize the label data.
  54. step 4. Concatenate the instruction and label tokens as the target item.
  55. - padding the instruction tokens from target item as `target`.
  56. - remove the eos token from target item as `prev_output_tokens`.
  57. step 5. Add constraint mask.
  58. Args:
  59. data (`Dict[str, Any]`): Input data, should contains the key of `text`, `text2`
  60. and `label`, both of them refer to a text input, and the target of this job
  61. is to find whether or not `text` imply `text2`, the `label` is the supervised
  62. data for training.
  63. Return:
  64. A dict object, contains source text input, target tokens and previous output
  65. tokens and constraint mask.
  66. """
  67. instruction_itm = self._build_instruction(data)
  68. assert 'label' in data, 'there must has `label` column in train phase '
  69. label = data['label']
  70. if self.label2ans:
  71. label = self.label2ans[label] # ans
  72. label_itm = self.tokenize_text(f' {label}', add_bos=False)
  73. if self.prompt_type == 'none':
  74. target_itm = label_itm
  75. elif self.prompt_type == 'prev_output':
  76. target_itm = torch.cat([instruction_itm[1:-1], label_itm])
  77. else:
  78. raise NotImplementedError
  79. prev_output_itm = torch.cat([self.bos_item, target_itm[:-1]])
  80. target_itm[:-len(label_itm)] = self.pad_item
  81. sample = {
  82. 'source': instruction_itm,
  83. 'target': target_itm,
  84. 'prev_output_tokens': prev_output_itm,
  85. }
  86. self.add_constraint_mask(sample)
  87. return sample
  88. def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
  89. r"""
  90. Building inference samples.
  91. step 1. Building instruction for text classification using `_build_instruction`.
  92. step 2. Whether or not to add `prefix_token`.
  93. step 3. Whether or not to add `label` data.
  94. Args:
  95. data (`Dict[str, Any]`): Input data, should contains the key of `text` and `text2`,
  96. both of them refer to a text input, and the target of this job is to find
  97. whether or not `text` imply `text2`.
  98. Return:
  99. A dict object, contains source text input, prefix tokens and label data.
  100. """
  101. instruction_itm = self._build_instruction(data)
  102. if self.prompt_type == 'none':
  103. prefix_token = []
  104. decoder_prompt = self.bos_item
  105. elif self.prompt_type == 'prev_output':
  106. prefix_token = instruction_itm[:-1] # remove eos
  107. decoder_prompt = instruction_itm[:-1]
  108. else:
  109. raise NotImplementedError
  110. sample = {
  111. 'source': instruction_itm,
  112. 'prefix_token': prefix_token,
  113. 'decoder_prompt': decoder_prompt,
  114. }
  115. if 'label' in data:
  116. sample['label'] = self.label2ans[data['label']]
  117. return sample