# translation_evaluation.py (18 KB)
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. """PyTorch UniTE model."""
  3. import warnings
  4. from dataclasses import dataclass
  5. from math import ceil
  6. from typing import Dict, List, Optional, Tuple, Union
  7. import numpy as np
  8. import torch
  9. import torch.utils.checkpoint
  10. from packaging import version
  11. from torch.nn import (Dropout, Linear, Module, Parameter, ParameterList,
  12. Sequential)
  13. from torch.nn.functional import softmax
  14. from torch.nn.utils.rnn import pad_sequence
  15. from transformers import XLMRobertaConfig, XLMRobertaModel
  16. from transformers.activations import ACT2FN
  17. from modelscope.metainfo import Models
  18. from modelscope.models.base import TorchModel
  19. from modelscope.models.builder import MODELS
  20. from modelscope.models.nlp.unite.configuration import InputFormat
  21. from modelscope.outputs.nlp_outputs import TranslationEvaluationOutput
  22. from modelscope.utils.compatible_with_transformers import \
  23. compatible_position_ids
  24. from modelscope.utils.constant import Tasks
  25. from modelscope.utils.logger import get_logger
  26. logger = get_logger()
  27. __all__ = ['UniTEForTranslationEvaluation']
  28. def _layer_norm_all(tensor, mask_float):
  29. broadcast_mask = mask_float.unsqueeze(dim=-1)
  30. num_elements_not_masked = broadcast_mask.sum() * tensor.size(-1)
  31. tensor_masked = tensor * broadcast_mask
  32. mean = tensor_masked.sum([-1, -2, -3],
  33. keepdim=True) / num_elements_not_masked
  34. variance = (((tensor_masked - mean) * broadcast_mask)**2).sum(
  35. [-1, -2, -3], keepdim=True) / num_elements_not_masked
  36. return (tensor - mean) / torch.sqrt(variance + 1e-12)
  37. class LayerwiseAttention(Module):
  38. def __init__(
  39. self,
  40. num_layers: int,
  41. model_dim: int,
  42. dropout: float = None,
  43. ) -> None:
  44. super(LayerwiseAttention, self).__init__()
  45. self.num_layers = num_layers
  46. self.model_dim = model_dim
  47. self.dropout = dropout
  48. self.scalar_parameters = Parameter(
  49. torch.zeros((num_layers, ), requires_grad=True))
  50. self.gamma = Parameter(torch.FloatTensor([1.0]), requires_grad=True)
  51. if self.dropout:
  52. dropout_mask = torch.zeros(len(self.scalar_parameters))
  53. dropout_fill = torch.empty(len(
  54. self.scalar_parameters)).fill_(-1e20)
  55. self.register_buffer('dropout_mask', dropout_mask)
  56. self.register_buffer('dropout_fill', dropout_fill)
  57. def forward(
  58. self,
  59. tensors: List[torch.Tensor], # pylint: disable=arguments-differ
  60. mask: torch.Tensor = None,
  61. ) -> torch.Tensor:
  62. tensors = torch.cat(list(x.unsqueeze(dim=0) for x in tensors), dim=0)
  63. if self.training and self.dropout:
  64. normed_weights = softmax(
  65. torch.where(self.dropout_mask.uniform_() > self.dropout,
  66. self.scalar_parameters, self.dropout_fill),
  67. dim=-1)
  68. else:
  69. normed_weights = softmax(self.scalar_parameters, dim=-1)
  70. normed_weights = normed_weights.view(-1, 1, 1, 1)
  71. mask_float = mask.float()
  72. weighted_sum = (normed_weights
  73. * _layer_norm_all(tensors, mask_float)).sum(dim=0)
  74. weighted_sum = weighted_sum[:, 0, :]
  75. return self.gamma * weighted_sum
  76. class FeedForward(Module):
  77. def __init__(
  78. self,
  79. in_dim: int,
  80. out_dim: int = 1,
  81. hidden_sizes: List[int] = [3072, 768],
  82. activations: str = 'Sigmoid',
  83. final_activation: Optional[str] = None,
  84. dropout: float = 0.1,
  85. ) -> None:
  86. """
  87. Feed Forward Neural Network.
  88. Args:
  89. in_dim (:obj:`int`):
  90. Number of input features.
  91. out_dim (:obj:`int`, defaults to 1):
  92. Number of output features. Default is 1 -- a single scalar.
  93. hidden_sizes (:obj:`List[int]`, defaults to `[3072, 768]`):
  94. List with hidden layer sizes.
  95. activations (:obj:`str`, defaults to `Sigmoid`):
  96. Name of the activation function to be used in the hidden layers.
  97. final_activation (:obj:`str`, Optional, defaults to `None`):
  98. Name of the final activation function if any.
  99. dropout (:obj:`float`, defaults to 0.1):
  100. Dropout ratio to be used in the hidden layers.
  101. """
  102. super().__init__()
  103. modules = []
  104. modules.append(Linear(in_dim, hidden_sizes[0]))
  105. modules.append(self.build_activation(activations))
  106. modules.append(Dropout(dropout))
  107. for i in range(1, len(hidden_sizes)):
  108. modules.append(Linear(hidden_sizes[i - 1], hidden_sizes[i]))
  109. modules.append(self.build_activation(activations))
  110. modules.append(Dropout(dropout))
  111. modules.append(Linear(hidden_sizes[-1], int(out_dim)))
  112. if final_activation is not None:
  113. modules.append(self.build_activation(final_activation))
  114. self.ff = Sequential(*modules)
  115. def build_activation(self, activation: str) -> Module:
  116. return ACT2FN[activation]
  117. def forward(self, in_features: torch.Tensor) -> torch.Tensor:
  118. return self.ff(in_features)
  119. @MODELS.register_module(Tasks.translation_evaluation, module_name=Models.unite)
  120. class UniTEForTranslationEvaluation(TorchModel):
  121. def __init__(self,
  122. attention_probs_dropout_prob: float = 0.1,
  123. bos_token_id: int = 0,
  124. eos_token_id: int = 2,
  125. pad_token_id: int = 1,
  126. hidden_act: str = 'gelu',
  127. hidden_dropout_prob: float = 0.1,
  128. hidden_size: int = 1024,
  129. initializer_range: float = 0.02,
  130. intermediate_size: int = 4096,
  131. layer_norm_eps: float = 1e-05,
  132. max_position_embeddings: int = 512,
  133. num_attention_heads: int = 16,
  134. num_hidden_layers: int = 24,
  135. type_vocab_size: int = 1,
  136. use_cache: bool = True,
  137. vocab_size: int = 250002,
  138. mlp_hidden_sizes: List[int] = [3072, 1024],
  139. mlp_act: str = 'tanh',
  140. mlp_final_act: Optional[str] = None,
  141. mlp_dropout: float = 0.1,
  142. **kwargs):
  143. r"""The UniTE Model which outputs the scalar to describe the corresponding
  144. translation quality of hypothesis. The model architecture includes two
  145. modules: a pre-trained language model (PLM) to derive representations,
  146. and a multi-layer perceptron (MLP) to give predicted score.
  147. Args:
  148. attention_probs_dropout_prob (:obj:`float`, defaults to 0.1):
  149. The dropout ratio for attention weights inside PLM.
  150. bos_token_id (:obj:`int`, defaults to 0):
  151. The numeric id representing beginning-of-sentence symbol.
  152. eos_token_id (:obj:`int`, defaults to 2):
  153. The numeric id representing ending-of-sentence symbol.
  154. pad_token_id (:obj:`int`, defaults to 1):
  155. The numeric id representing padding symbol.
  156. hidden_act (:obj:`str`, defaults to :obj:`"gelu"`):
  157. Activation inside PLM.
  158. hidden_dropout_prob (:obj:`float`, defaults to 0.1):
  159. The dropout ratio for activation states inside PLM.
  160. hidden_size (:obj:`int`, defaults to 1024):
  161. The dimensionality of PLM.
  162. initializer_range (:obj:`float`, defaults to 0.02):
  163. The hyper-parameter for initializing PLM.
  164. intermediate_size (:obj:`int`, defaults to 4096):
  165. The dimensionality of PLM inside feed-forward block.
  166. layer_norm_eps (:obj:`float`, defaults to 1e-5):
  167. The value for setting epsilon to avoid zero-division inside
  168. layer normalization.
  169. max_position_embeddings: (:obj:`int`, defaults to 512):
  170. The maximum value for identifying the length of input sequence.
  171. num_attention_heads (:obj:`int`, defaults to 16):
  172. The number of attention heads inside multi-head attention layer.
  173. num_hidden_layers (:obj:`int`, defaults to 24):
  174. The number of layers inside PLM.
  175. type_vocab_size (:obj:`int`, defaults to 1):
  176. The number of type embeddings.
  177. use_cache (:obj:`bool`, defaults to :obj:`True`):
  178. Whether to use cached buffer to initialize PLM.
  179. vocab_size (:obj:`int`, defaults to 250002):
  180. The size of vocabulary.
  181. mlp_hidden_sizes (:obj:`List[int]`, defaults to `[3072, 1024]`):
  182. The size of hidden states inside MLP.
  183. mlp_act (:obj:`str`, defaults to :obj:`"tanh"`):
  184. Activation inside MLP.
  185. mlp_final_act (:obj:`str`, `optional`, defaults to :obj:`None`):
  186. Activation at the end of MLP.
  187. mlp_dropout (:obj:`float`, defaults to 0.1):
  188. The dropout ratio for MLP.
  189. """
  190. super().__init__(**kwargs)
  191. self.attention_probs_dropout_prob = attention_probs_dropout_prob
  192. self.bos_token_id = bos_token_id
  193. self.eos_token_id = eos_token_id
  194. self.pad_token_id = pad_token_id
  195. self.hidden_act = hidden_act
  196. self.hidden_dropout_prob = hidden_dropout_prob
  197. self.hidden_size = hidden_size
  198. self.initializer_range = initializer_range
  199. self.intermediate_size = intermediate_size
  200. self.layer_norm_eps = layer_norm_eps
  201. self.max_position_embeddings = max_position_embeddings
  202. self.num_attention_heads = num_attention_heads
  203. self.num_hidden_layers = num_hidden_layers
  204. self.type_vocab_size = type_vocab_size
  205. self.use_cache = use_cache
  206. self.vocab_size = vocab_size
  207. self.mlp_hidden_sizes = mlp_hidden_sizes
  208. self.mlp_act = mlp_act
  209. self.mlp_final_act = mlp_final_act
  210. self.mlp_dropout = mlp_dropout
  211. self.encoder_config = XLMRobertaConfig(
  212. bos_token_id=self.bos_token_id,
  213. eos_token_id=self.eos_token_id,
  214. pad_token_id=self.pad_token_id,
  215. vocab_size=self.vocab_size,
  216. hidden_size=self.hidden_size,
  217. num_hidden_layers=self.num_hidden_layers,
  218. num_attention_heads=self.num_attention_heads,
  219. intermediate_size=self.intermediate_size,
  220. hidden_act=self.hidden_act,
  221. hidden_dropout_prob=self.hidden_dropout_prob,
  222. attention_probs_dropout_prob=self.attention_probs_dropout_prob,
  223. max_position_embeddings=self.max_position_embeddings,
  224. type_vocab_size=self.type_vocab_size,
  225. initializer_range=self.initializer_range,
  226. layer_norm_eps=self.layer_norm_eps,
  227. use_cache=self.use_cache)
  228. self.encoder = XLMRobertaModel(
  229. self.encoder_config, add_pooling_layer=False)
  230. self.layerwise_attention = LayerwiseAttention(
  231. num_layers=self.num_hidden_layers + 1,
  232. model_dim=self.hidden_size,
  233. dropout=self.mlp_dropout)
  234. self.estimator = FeedForward(
  235. in_dim=self.hidden_size,
  236. out_dim=1,
  237. hidden_sizes=self.mlp_hidden_sizes,
  238. activations=self.mlp_act,
  239. final_activation=self.mlp_final_act,
  240. dropout=self.mlp_dropout)
  241. return
  242. def forward(self,
  243. input_ids: torch.Tensor,
  244. input_format: Optional[List[InputFormat]] = None,
  245. score: Optional[torch.Tensor] = None,
  246. **kwargs) -> TranslationEvaluationOutput:
  247. attention_mask = input_ids.ne(self.pad_token_id).long()
  248. outputs = self.encoder(
  249. input_ids=input_ids,
  250. attention_mask=attention_mask,
  251. output_hidden_states=True,
  252. return_dict=True)
  253. mix_states = self.layerwise_attention(outputs['hidden_states'],
  254. attention_mask)
  255. pred = self.estimator(mix_states).squeeze(dim=-1)
  256. output = TranslationEvaluationOutput(
  257. score=pred.cpu().tolist(), input_format=input_format)
  258. if score is not None:
  259. loss = (pred - score).pow(2).mean()
  260. output['loss'] = loss
  261. return output
  262. def load_checkpoint(self, path: str, device: torch.device, plm_only: bool):
  263. if plm_only:
  264. self.encoder = self.encoder.from_pretrained(path).to(device)
  265. self.encoder.pooler = None
  266. else:
  267. state_dict = torch.load(path, map_location=device)
  268. compatible_position_ids(state_dict,
  269. 'encoder.embeddings.position_ids')
  270. self.load_state_dict(state_dict)
  271. logger.info('Loading checkpoint parameters from %s' % path)
  272. return
  273. def combine_input_sentences(all_input_concat: List[List[torch.Tensor]],
  274. maximum_length: int = 512,
  275. pad_idx: int = 1,
  276. eos_idx: int = 2):
  277. for group in all_input_concat[1:]:
  278. group[:, 0] = eos_idx
  279. if len(all_input_concat) == 3:
  280. return cut_long_sequences3(all_input_concat, maximum_length, pad_idx)
  281. else:
  282. return cut_long_sequences2(all_input_concat, maximum_length, pad_idx)
  283. def cut_long_sequences2(all_input_concat: List[List[torch.Tensor]],
  284. maximum_length: int = 512,
  285. pad_idx: int = 1):
  286. all_input_concat = list(zip(*all_input_concat))
  287. collected_tuples = list()
  288. for tensor_tuple in all_input_concat:
  289. tensor_tuple = tuple(
  290. x.masked_select(x.ne(pad_idx)) for x in tensor_tuple)
  291. all_lens = tuple(len(x) for x in tensor_tuple)
  292. if sum(all_lens) > maximum_length:
  293. lengths = dict(enumerate(all_lens))
  294. lengths_sorted_idxes = list(x[0] for x in sorted(
  295. lengths.items(), key=lambda d: d[1], reverse=True))
  296. offset = ceil((sum(lengths.values()) - maximum_length) / 2)
  297. if min(all_lens) > (maximum_length
  298. // 2) and min(all_lens) > offset:
  299. lengths = dict((k, v - offset) for k, v in lengths.items())
  300. else:
  301. lengths[lengths_sorted_idxes[0]] = maximum_length - lengths[
  302. lengths_sorted_idxes[1]]
  303. new_lens = list(lengths[k] for k in range(0, len(tensor_tuple)))
  304. new_tensor_tuple = tuple(x[:y]
  305. for x, y in zip(tensor_tuple, new_lens))
  306. for x, y in zip(new_tensor_tuple, tensor_tuple):
  307. x[-1] = y[-1]
  308. collected_tuples.append(new_tensor_tuple)
  309. else:
  310. collected_tuples.append(tensor_tuple)
  311. concat_tensor = list(torch.cat(x, dim=0) for x in collected_tuples)
  312. all_input_concat_padded = pad_sequence(
  313. concat_tensor, batch_first=True, padding_value=pad_idx)
  314. return all_input_concat_padded
  315. def cut_long_sequences3(all_input_concat: List[List[torch.Tensor]],
  316. maximum_length: int = 512,
  317. pad_idx: int = 1):
  318. all_input_concat = list(zip(*all_input_concat))
  319. collected_tuples = list()
  320. for tensor_tuple in all_input_concat:
  321. tensor_tuple = tuple(
  322. x.masked_select(x.ne(pad_idx)) for x in tensor_tuple)
  323. all_lens = tuple(len(x) for x in tensor_tuple)
  324. if sum(all_lens) > maximum_length:
  325. lengths = dict(enumerate(all_lens))
  326. lengths_sorted_idxes = list(x[0] for x in sorted(
  327. lengths.items(), key=lambda d: d[1], reverse=True))
  328. offset = ceil((sum(lengths.values()) - maximum_length) / 3)
  329. if min(all_lens) > (maximum_length
  330. // 3) and min(all_lens) > offset:
  331. lengths = dict((k, v - offset) for k, v in lengths.items())
  332. else:
  333. while sum(lengths.values()) > maximum_length:
  334. if lengths[lengths_sorted_idxes[0]] > lengths[
  335. lengths_sorted_idxes[1]]:
  336. offset = maximum_length - lengths[lengths_sorted_idxes[
  337. 1]] - lengths[lengths_sorted_idxes[2]]
  338. if offset > lengths[lengths_sorted_idxes[1]]:
  339. lengths[lengths_sorted_idxes[0]] = offset
  340. else:
  341. lengths[lengths_sorted_idxes[0]] = lengths[
  342. lengths_sorted_idxes[1]]
  343. elif lengths[lengths_sorted_idxes[0]] == lengths[
  344. lengths_sorted_idxes[1]] > lengths[
  345. lengths_sorted_idxes[2]]:
  346. offset = (maximum_length
  347. - lengths[lengths_sorted_idxes[2]]) // 2
  348. if offset > lengths[lengths_sorted_idxes[2]]:
  349. lengths[lengths_sorted_idxes[0]] = lengths[
  350. lengths_sorted_idxes[1]] = offset
  351. else:
  352. lengths[lengths_sorted_idxes[0]] = lengths[
  353. lengths_sorted_idxes[1]] = lengths[
  354. lengths_sorted_idxes[2]]
  355. else:
  356. lengths[lengths_sorted_idxes[0]] = lengths[
  357. lengths_sorted_idxes[1]] = lengths[
  358. lengths_sorted_idxes[2]] = maximum_length // 3
  359. new_lens = list(lengths[k] for k in range(0, len(lengths)))
  360. new_tensor_tuple = tuple(x[:y]
  361. for x, y in zip(tensor_tuple, new_lens))
  362. for x, y in zip(new_tensor_tuple, tensor_tuple):
  363. x[-1] = y[-1]
  364. collected_tuples.append(new_tensor_tuple)
  365. else:
  366. collected_tuples.append(tensor_tuple)
  367. concat_tensor = list(torch.cat(x, dim=0) for x in collected_tuples)
  368. all_input_concat_padded = pad_sequence(
  369. concat_tensor, batch_first=True, padding_value=pad_idx)
  370. return all_input_concat_padded