siamese_uie_trainer.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import random
  4. import time
  5. from collections import defaultdict
  6. from math import ceil
  7. from typing import Callable, Dict, List, Optional, Tuple, Union
  8. import json
  9. import numpy as np
  10. import torch
  11. from torch import distributed as dist
  12. from torch import nn
  13. from torch.utils.data import Dataset
  14. from modelscope.metainfo import Trainers
  15. from modelscope.models.base import TorchModel
  16. from modelscope.msdatasets import MsDataset
  17. from modelscope.pipelines import pipeline
  18. from modelscope.preprocessors.base import Preprocessor
  19. from modelscope.trainers import EpochBasedTrainer, NlpEpochBasedTrainer
  20. from modelscope.trainers.builder import TRAINERS
  21. from modelscope.trainers.optimizer.builder import build_optimizer
  22. from modelscope.utils.config import Config
  23. from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModeKeys, Tasks
  24. from modelscope.utils.file_utils import func_receive_dict_inputs
  25. from modelscope.utils.logger import get_logger
  26. from ..parallel.utils import is_parallel
  27. PATH = None
  28. logger = get_logger(PATH)
  29. os.environ['TOKENIZERS_PARALLELISM'] = 'true'
  30. @TRAINERS.register_module(module_name=Trainers.siamese_uie_trainer)
  31. class SiameseUIETrainer(EpochBasedTrainer):
  32. def __init__(
  33. self,
  34. model: Optional[Union[TorchModel, nn.Module, str]] = None,
  35. cfg_file: Optional[str] = None,
  36. cfg_modify_fn: Optional[Callable] = None,
  37. train_dataset: Optional[Union[MsDataset, Dataset]] = None,
  38. eval_dataset: Optional[Union[MsDataset, Dataset]] = None,
  39. preprocessor: Optional[Union[Preprocessor,
  40. Dict[str, Preprocessor]]] = None,
  41. optimizers: Tuple[torch.optim.Optimizer,
  42. torch.optim.lr_scheduler._LRScheduler] = (None,
  43. None),
  44. model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
  45. seed: int = 42,
  46. negative_sampling_rate=1,
  47. slide_len=352,
  48. max_len=384,
  49. hint_max_len=128,
  50. **kwargs):
  51. """Epoch based Trainer, a training helper for PyTorch.
  52. Args:
  53. model (:obj:`torch.nn.Module` or :obj:`TorchModel` or `str`): The model to be run, or a valid model dir
  54. or a model id. If model is None, build_model method will be called.
  55. cfg_file(str): The local config file.
  56. cfg_modify_fn (function): Optional[Callable] = None, config function
  57. train_dataset (`MsDataset` or `torch.utils.data.Dataset`, *optional*):
  58. The dataset to use for training.
  59. Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
  60. distributed fashion, your iterable dataset should either use a internal attribute `generator` that is a
  61. `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
  62. manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
  63. sets the seed of the RNGs used.
  64. eval_dataset (`MsDataset` or `torch.utils.data.Dataset`, *optional*): The dataset to use for evaluation.
  65. preprocessor (:obj:`Preprocessor`, *optional*): The optional preprocessor.
  66. NOTE: If the preprocessor has been called before the dataset fed into this
  67. trainer by user's custom code,
  68. this parameter should be None, meanwhile remove the 'preprocessor' key from the cfg_file.
  69. Else the preprocessor will be instantiated from the cfg_file or assigned from this parameter and
  70. this preprocessing action will be executed every time the dataset's __getitem__ is called.
  71. optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]`, *optional*): A tuple
  72. containing the optimizer and the scheduler to use.
  73. model_revision (str): The model version to use in modelhub.
  74. negative_sampling_rate (float): The rate to do negative sampling.
  75. slide_len (int): The length to slide.
  76. max_len (int): The max length of prompt + text.
  77. hint_max_len (int): The max length of prompt.
  78. seed (int): The optional random seed for torch, cuda, numpy and random.
  79. """
  80. print('*******************')
  81. self.slide_len = slide_len
  82. self.max_len = max_len
  83. self.hint_max_len = hint_max_len
  84. self.negative_sampling_rate = negative_sampling_rate
  85. super().__init__(
  86. model=model,
  87. cfg_file=cfg_file,
  88. cfg_modify_fn=cfg_modify_fn,
  89. data_collator=self._nn_collate_fn,
  90. train_dataset=train_dataset,
  91. eval_dataset=eval_dataset,
  92. preprocessor=preprocessor,
  93. optimizers=optimizers,
  94. model_revision=model_revision,
  95. seed=seed,
  96. **kwargs)
  97. def build_dataset(self,
  98. datasets: Union[torch.utils.data.Dataset, MsDataset,
  99. List[torch.utils.data.Dataset]],
  100. model_cfg: Config,
  101. mode: str,
  102. preprocessor: Optional[Preprocessor] = None,
  103. **kwargs):
  104. if mode == ModeKeys.TRAIN:
  105. datasets = self.load_dataset(datasets)
  106. return super(SiameseUIETrainer, self).build_dataset(
  107. datasets=datasets,
  108. model_cfg=self.cfg,
  109. mode=mode,
  110. preprocessor=preprocessor,
  111. **kwargs)
  112. def get_train_dataloader(self):
  113. """ Builder torch dataloader for training.
  114. We provide a reasonable default that works well. If you want to use something else, you can change
  115. the config for data.train in configuration file, or subclass and override this method
  116. (or `get_train_dataloader` in a subclass.
  117. """
  118. self.train_dataset.preprocessor = None
  119. data_loader = self._build_dataloader_with_dataset(
  120. self.train_dataset,
  121. dist=self._dist,
  122. seed=self._seed,
  123. collate_fn=self.train_data_collator,
  124. **self.cfg.train.get('dataloader', {}))
  125. return data_loader
  126. def get_brother_type_map(self, schema, brother_type_map, prefix_types):
  127. if not schema:
  128. return
  129. for k in schema:
  130. brother_type_map[tuple(prefix_types
  131. + [k])] += [v for v in schema if v != k]
  132. self.get_brother_type_map(schema[k], brother_type_map,
  133. prefix_types + [k])
  134. def load_dataset(self, raw_dataset):
  135. data = []
  136. for num_line, raw_sample in enumerate(raw_dataset):
  137. raw_sample['info_list'] = json.loads(raw_sample['info_list'])
  138. raw_sample['schema'] = json.loads(raw_sample['schema'])
  139. hint_spans_map = defaultdict(list)
  140. # positive sampling
  141. for info in raw_sample['info_list']:
  142. hint = ''
  143. for item in info:
  144. hint += f'{item["type"]}: '
  145. span = {'span': item['span'], 'offset': item['offset']}
  146. if span not in hint_spans_map[hint]:
  147. hint_spans_map[hint].append(span)
  148. hint += f'{item["span"]}, '
  149. # negative sampling
  150. brother_type_map = defaultdict(list)
  151. self.get_brother_type_map(raw_sample['schema'], brother_type_map,
  152. [])
  153. for info in raw_sample['info_list']:
  154. hint = ''
  155. for i, item in enumerate(info):
  156. key = tuple([info[j]['type'] for j in range(i + 1)])
  157. for st in brother_type_map.get(key, []):
  158. neg_hint = hint + f'{st}: '
  159. if neg_hint not in hint_spans_map and random.random(
  160. ) < self.negative_sampling_rate:
  161. hint_spans_map[neg_hint] = []
  162. hint += f'{item["type"]}: '
  163. hint += f'{item["span"]}, '
  164. # info list为空
  165. for k in raw_sample['schema']:
  166. neg_hint = f'{k}: '
  167. if neg_hint not in hint_spans_map and random.random(
  168. ) < self.negative_sampling_rate:
  169. hint_spans_map[neg_hint] = []
  170. for i, hint in enumerate(hint_spans_map):
  171. sample = {
  172. 'id': f'{raw_sample["id"]}-{i}',
  173. 'hint': hint,
  174. 'text': raw_sample['text'],
  175. 'spans': hint_spans_map[hint]
  176. }
  177. uuid = sample['id']
  178. text = sample['text']
  179. tokenized_input = self.train_preprocessor([text])[0]
  180. tokenized_hint = self.train_preprocessor(
  181. [hint], max_length=self.hint_max_len, truncation=True)[0]
  182. sample['offsets'] = tokenized_input.offsets
  183. entities = sample.get('spans', [])
  184. head_labels, tail_labels = self._get_labels(
  185. text, tokenized_input, sample['offsets'], entities)
  186. split_num = ceil(
  187. (len(tokenized_input) - self.max_len) / self.slide_len
  188. ) + 1 if len(tokenized_input) > self.max_len else 1
  189. for j in range(split_num):
  190. a, b = j * self.slide_len, j * self.slide_len + self.max_len
  191. item = {
  192. 'id': uuid,
  193. 'shift': a,
  194. 'tokens': tokenized_input.tokens[a:b],
  195. 'token_ids': tokenized_input.ids[a:b],
  196. 'hint_tokens': tokenized_hint.tokens,
  197. 'hint_token_ids': tokenized_hint.ids,
  198. 'attention_masks': tokenized_input.attention_mask[a:b],
  199. 'cross_attention_masks': tokenized_hint.attention_mask,
  200. 'head_labels': head_labels[a:b],
  201. 'tail_labels': tail_labels[a:b]
  202. }
  203. data.append(item)
  204. from datasets import Dataset
  205. train_dataset = Dataset.from_list(data)
  206. for index in random.sample(range(len(train_dataset)), 3):
  207. logger.info(
  208. f'Sample {index} of the training set: {train_dataset[index]}.')
  209. return train_dataset
  210. def _get_labels(self, text, tokenized_input, offsets, entities):
  211. num_tokens = len(tokenized_input)
  212. head_labels = [0] * num_tokens
  213. tail_labels = [0] * num_tokens
  214. char_index_to_token_index_map = {}
  215. for i in range(len(offsets)):
  216. offset = offsets[i]
  217. for j in range(offset[0], offset[1]):
  218. char_index_to_token_index_map[j] = i
  219. for e in entities:
  220. h, t = e['offset']
  221. t -= 1
  222. while h not in char_index_to_token_index_map:
  223. h += 1
  224. if h > len(text):
  225. print('h', e['offset'], e['span'],
  226. text[e['offset'][0]:e['offset'][1]])
  227. break
  228. while t not in char_index_to_token_index_map:
  229. t -= 1
  230. if t < 0:
  231. print('t', e['offset'], e['span'],
  232. text[e['offset'][0]:e['offset'][1]])
  233. break
  234. if h > len(text) or t < 0:
  235. continue
  236. token_head = char_index_to_token_index_map[h]
  237. token_tail = char_index_to_token_index_map[t]
  238. head_labels[token_head] = 1
  239. tail_labels[token_tail] = 1
  240. return head_labels, tail_labels
  241. def _padding(self, data, val=0):
  242. res = []
  243. for seq in data:
  244. res.append(seq + [val] * (self.max_len - len(seq)))
  245. return res
  246. def _nn_collate_fn(self, batch):
  247. token_ids = torch.tensor(
  248. self._padding([item['token_ids'] for item in batch]),
  249. dtype=torch.long)
  250. hint_token_ids = torch.tensor(
  251. self._padding([item['hint_token_ids'] for item in batch]),
  252. dtype=torch.long)
  253. attention_masks = torch.tensor(
  254. self._padding([item['attention_masks'] for item in batch]),
  255. dtype=torch.long)
  256. cross_attention_masks = torch.tensor(
  257. self._padding([item['cross_attention_masks'] for item in batch]),
  258. dtype=torch.long)
  259. head_labels = torch.tensor(
  260. self._padding([item['head_labels'] for item in batch]),
  261. dtype=torch.float)
  262. tail_labels = torch.tensor(
  263. self._padding([item['tail_labels'] for item in batch]),
  264. dtype=torch.float)
  265. # the content of `batch` is like batch_size * [token_ids, head_labels, tail_labels]
  266. # for fp16 acceleration, truncate seq_len to multiples of 8
  267. batch_max_len = token_ids.gt(0).sum(dim=-1).max().item()
  268. batch_max_len += (8 - batch_max_len % 8) % 8
  269. truncate_len = min(self.max_len, batch_max_len)
  270. token_ids = token_ids[:, :truncate_len]
  271. attention_masks = attention_masks[:, :truncate_len]
  272. head_labels = head_labels[:, :truncate_len]
  273. tail_labels = tail_labels[:, :truncate_len]
  274. # for fp16 acceleration, truncate seq_len to multiples of 8
  275. batch_max_len = hint_token_ids.gt(0).sum(dim=-1).max().item()
  276. batch_max_len += (8 - batch_max_len % 8) % 8
  277. hint_truncate_len = min(self.hint_max_len, batch_max_len)
  278. hint_token_ids = hint_token_ids[:, :hint_truncate_len]
  279. cross_attention_masks = cross_attention_masks[:, :hint_truncate_len]
  280. return {
  281. 'input_ids': token_ids,
  282. 'attention_masks': attention_masks,
  283. 'hint_ids': hint_token_ids,
  284. 'cross_attention_masks': cross_attention_masks,
  285. 'head_labels': head_labels,
  286. 'tail_labels': tail_labels
  287. }
  288. def evaluate(self,
  289. checkpoint_path: Optional[str] = None,
  290. *args,
  291. **kwargs) -> Dict[str, float]:
  292. """evaluate a dataset
  293. evaluate a dataset via a specific model from the `checkpoint_path` path, if the `checkpoint_path`
  294. does not exist, read from the config file.
  295. Args:
  296. checkpoint_path (Optional[str], optional): the model path. Defaults to None.
  297. Returns:
  298. Dict[str, float]: the results about the evaluation
  299. Example:
  300. {"accuracy": 0.5091743119266054, "f1": 0.673780487804878}
  301. """
  302. pipeline_uie = pipeline(
  303. Tasks.siamese_uie, self.model, device=str(self.device))
  304. if checkpoint_path is not None and os.path.isfile(checkpoint_path):
  305. from modelscope.trainers.hooks import LoadCheckpointHook
  306. LoadCheckpointHook.load_checkpoint(checkpoint_path, self)
  307. self.model.eval()
  308. self._mode = ModeKeys.EVAL
  309. self.eval_dataloader = self.train_dataloader
  310. num_pred = num_recall = num_correct = 1e-10
  311. self.eval_dataset.preprocessor = None
  312. for sample in self.eval_dataset:
  313. text = sample['text']
  314. schema = json.loads(sample['schema'])
  315. gold_info_list = json.loads(sample['info_list'])
  316. pred_info_list = pipeline_uie(input=text, schema=schema)['output']
  317. pred_info_list_set = set([str(item) for item in pred_info_list])
  318. gold_info_list_set = set([str(item) for item in gold_info_list])
  319. a, b, c = len(pred_info_list_set), len(gold_info_list_set), len(
  320. pred_info_list_set.intersection(gold_info_list_set))
  321. num_pred += a
  322. num_recall += b
  323. num_correct += c
  324. precision, recall, f1 = self.compute_metrics(num_pred, num_recall,
  325. num_correct)
  326. return {'precision': precision, 'recall': recall, 'f1': f1}
  327. def get_metrics(self) -> List[Union[str, Dict]]:
  328. """Get the metric class types.
  329. The first choice will be the metrics configured in the config file, if not found, the default metrics will be
  330. used.
  331. If no metrics is found and the eval dataset exists, the method will raise an error.
  332. Returns: The metric types.
  333. """
  334. return self.compute_metrics
  335. def compute_metrics(self, num_pred, num_recall, num_correct):
  336. if num_pred == num_recall == 1e-10:
  337. return 1, 1, 1
  338. precision = num_correct / float(num_pred)
  339. recall = num_correct / float(num_recall)
  340. f1 = 2 * precision * recall / (precision + recall)
  341. # print(num_pred, num_recall, num_correct)
  342. if num_correct == 1e-10:
  343. return 0, 0, 0
  344. return precision, recall, f1