| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- from typing import Any, Dict
- import torch
- from modelscope.utils.constant import ModeKeys
- from .base import OfaBasePreprocessor
class OfaTextClassificationPreprocessor(OfaBasePreprocessor):
    r"""
    OFA preprocessor for text classification (text-pair implication) tasks.

    Each sample holds two texts, ``data['text']`` and ``data['text2']``, and
    the model is asked whether the first one implies the second one.
    """

    def __init__(self,
                 cfg,
                 model_dir,
                 mode=ModeKeys.INFERENCE,
                 *args,
                 **kwargs):
        """Initialize the preprocessor.

        Args:
            cfg (modelscope.utils.config.ConfigDict): model config.
            model_dir (str): model path.
            mode: preprocessor mode (model mode). ``ModeKeys.TRAIN`` makes
                ``__call__`` build training samples; any other value builds
                inference samples.
        """
        super().__init__(cfg, model_dir, mode, *args, **kwargs)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Build a training sample in train mode, otherwise an inference sample."""
        if self.mode == ModeKeys.TRAIN:
            return self._build_train_sample(data)
        return self._build_infer_sample(data)

    def _normalize_text(self, text: str) -> str:
        """Lower-case ``text``, collapse runs of whitespace into single spaces
        and keep at most ``self.max_src_length`` whitespace-separated words.
        """
        # str.split() with no argument already discards leading/trailing
        # whitespace, so no separate strip() is needed.
        words = text.lower().split()
        return ' '.join(words[:self.max_src_length])

    def _label_to_answer(self, label: Any) -> Any:
        """Map ``label`` to its answer text through ``self.label2ans``.

        When ``self.label2ans`` is falsy (no mapping configured), ``label`` is
        returned unchanged, i.e. it is treated as answer text already. Training
        and inference both use this so they interpret labels identically.
        """
        if self.label2ans:
            return self.label2ans[label]
        return label

    def _build_instruction(self, data: Dict[str, Any]):
        r"""
        Build and tokenize the text classification instruction.

        `data` must contain the keys `text` and `text2`. The final instruction
        has the form ` can text1 " {} " imply text2 " {} "?`, where the first
        `{}` is filled with the processed `text` and the second one with the
        processed `text2`.

        step 1. Normalize `text` and `text2`: lower-case them, collapse
            whitespace and keep at most `max_src_length` words of each.
        step 2. Fill the instruction template with the normalized texts.
        step 3. Tokenize the instruction and return the result.
        """
        text1 = self._normalize_text(data['text'])
        text2 = self._normalize_text(data['text2'])
        prompt = ' can text1 " {} " imply text2 " {} "?'
        return self.tokenize_text(prompt.format(text1, text2))

    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
        r"""
        Build a training sample.

        step 1. Build the instruction with `_build_instruction`.
        step 2. Map `label` to answer text with `label2ans` when a mapping is
            configured; otherwise use `label` as-is.
        step 3. Tokenize the answer text without a bos token.
        step 4. Build the target item according to `prompt_type`:
            - `none`: the answer tokens only.
            - `prev_output`: the instruction tokens without their first and
              last tokens, followed by the answer tokens.
            `prev_output_tokens` is the bos item followed by the target minus
            its last token. Afterwards every position of `target` before the
            answer tokens is overwritten with the pad item.
        step 5. Add the constraint mask with `add_constraint_mask`.

        Args:
            data (`Dict[str, Any]`): input data containing the keys `text`,
                `text2` and `label`. `text` and `text2` are text inputs, the
                task is to decide whether `text` implies `text2`, and `label`
                is the supervision used for training.

        Returns:
            A dict containing the source text input (`source`), the target
            tokens (`target`), the previous output tokens
            (`prev_output_tokens`) and the constraint mask.

        Raises:
            AssertionError: if `data` has no `label`.
            NotImplementedError: if `prompt_type` is neither `none` nor
                `prev_output`.
        """
        instruction_itm = self._build_instruction(data)
        assert 'label' in data, 'there must has `label` column in train phase '
        answer = self._label_to_answer(data['label'])
        label_itm = self.tokenize_text(f' {answer}', add_bos=False)
        if self.prompt_type == 'none':
            target_itm = label_itm
        elif self.prompt_type == 'prev_output':
            # Drop the instruction's first token (presumably bos, since the
            # answer is tokenized with add_bos=False) and its trailing eos.
            target_itm = torch.cat([instruction_itm[1:-1], label_itm])
        else:
            raise NotImplementedError(
                f'unsupported prompt_type: {self.prompt_type!r}')
        # Build the decoder input BEFORE the in-place padding below; otherwise
        # the instruction part of prev_output_tokens would become pad tokens.
        prev_output_itm = torch.cat([self.bos_item, target_itm[:-1]])
        # Pad every non-answer position of the target. For prompt_type 'none'
        # the slice is empty (target_itm is label_itm), so nothing changes.
        target_itm[:-len(label_itm)] = self.pad_item
        sample = {
            'source': instruction_itm,
            'target': target_itm,
            'prev_output_tokens': prev_output_itm,
        }
        self.add_constraint_mask(sample)
        return sample

    def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
        r"""
        Build an inference sample.

        step 1. Build the instruction with `_build_instruction`.
        step 2. Choose `prefix_token` and `decoder_prompt` by `prompt_type`:
            - `none`: an empty prefix and the bos item as decoder prompt.
            - `prev_output`: the instruction without its trailing eos for both.
        step 3. If `label` is present, add it, mapped to answer text with
            `label2ans` when a mapping is configured (as in training).

        Args:
            data (`Dict[str, Any]`): input data containing the keys `text` and
                `text2` (and optionally `label`). Both are text inputs, and
                the task is to decide whether `text` implies `text2`.

        Returns:
            A dict containing the source text input (`source`), the prefix
            tokens (`prefix_token`), the decoder prompt (`decoder_prompt`) and,
            when available, the label (`label`).

        Raises:
            NotImplementedError: if `prompt_type` is neither `none` nor
                `prev_output`.
        """
        instruction_itm = self._build_instruction(data)
        if self.prompt_type == 'none':
            prefix_token = []
            decoder_prompt = self.bos_item
        elif self.prompt_type == 'prev_output':
            prefix_token = instruction_itm[:-1]  # remove eos
            decoder_prompt = instruction_itm[:-1]
        else:
            raise NotImplementedError(
                f'unsupported prompt_type: {self.prompt_type!r}')
        sample = {
            'source': instruction_itm,
            'prefix_token': prefix_token,
            'decoder_prompt': decoder_prompt,
        }
        if 'label' in data:
            # Previously indexed label2ans unconditionally, which crashed when
            # no label mapping was configured; now consistent with training.
            sample['label'] = self._label_to_answer(data['label'])
        return sample
|