backbone.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900
  1. # Copyright 2021-2022 The Alibaba DAMO Team Authors.
  2. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  3. # All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """PyTorch PoNet model. """
  17. import math
  18. from distutils.version import LooseVersion
  19. import torch
  20. import torch.utils.checkpoint
  21. from packaging import version
  22. from torch import nn
  23. from transformers.activations import ACT2FN
  24. from transformers.modeling_utils import PreTrainedModel
  25. from modelscope.metainfo import Models
  26. from modelscope.models import Model, TorchModel
  27. from modelscope.models.builder import MODELS
  28. from modelscope.outputs import AttentionBackboneModelOutput
  29. from modelscope.utils.constant import Tasks
  30. from modelscope.utils.logger import get_logger
  31. from modelscope.utils.torch_utils import (apply_chunking_to_forward,
  32. find_pruneable_heads_and_indices,
  33. prune_linear_layer)
  34. from .configuration import PoNetConfig
  35. logger = get_logger()
  36. is_pytorch_12plus = LooseVersion(torch.__version__) >= LooseVersion('1.12.0')
  37. CLS_ID = 101
  38. EOS_ID = 102
def segment_max(src, index, dim=1):
    """Max-pool ``src`` within segments and broadcast the result back to every token.

    For every position ``i`` the output holds the element-wise maximum of
    ``src`` over all positions ``j`` that share the same segment id
    (``index[:, j] == index[:, i]``).

    Args:
        src: Tensor to pool. ``index[:, :, None].expand_as(src)`` implies a
            rank-3 tensor, presumably (batch, seq_len, hidden) — TODO confirm.
        index: Integer segment ids whose first two dims match ``src`` (for
            example the output of ``get_segment_index``).
        dim: Dimension used for the scatter/gather. NOTE(review): the
            pre-1.12 fallback hard-codes ``max(dim=1)``, so only ``dim=1``
            behaves consistently across both branches.

    Returns:
        A tensor shaped like ``src`` with ``src``'s dtype.
    """
    if is_pytorch_12plus:
        # Scatter every token into the slot addressed by its segment id,
        # keeping the element-wise max; include_self=False ignores the zeros
        # used to initialise the buffer.
        out = torch.zeros_like(src).scatter_reduce(
            dim,
            index[:, :, None].expand_as(src),
            src,
            reduce='amax',
            include_self=False)
    else:
        # Fallback for older PyTorch without scatter_reduce: build one copy of
        # src per segment id, overwrite tokens that do not belong to that
        # segment with a value smaller than any real entry, then take the max
        # over the sequence dimension. The masked tensor has shape
        # (*src.shape[:-1], num_segments, src.shape[-1]).
        dummy_scatter_index = index[:, :, None].expand_as(src)
        min_value = src.min() - 1
        dummpy_scatter_shape = (*src.shape[:-1], index.max() + 1,
                                src.shape[-1])
        dummy_scatter_index_expand = dummy_scatter_index.unsqueeze(-2).expand(
            *dummpy_scatter_shape)
        index_reconstruct_expand = torch.arange(
            index.max() + 1,
            device=src.device)[None, None, :,
                               None].expand(*dummpy_scatter_shape)
        src_expand = src.unsqueeze(-2).expand(*dummpy_scatter_shape)
        # The scattered source is a constant tensor, so this behaves like a
        # masked_fill with min_value at positions outside each segment.
        out, _ = src_expand.masked_scatter(
            dummy_scatter_index_expand != index_reconstruct_expand,
            torch.full_like(src_expand, min_value.item())).max(dim=1)
    # Read each token's segment maximum back into that token's position.
    dummy = index.unsqueeze(-1).expand(*index.shape[:2], out.size(-1))
    return torch.gather(out, dim, dummy).to(dtype=src.dtype)
  64. def get_segment_index(input_ids, cls_id=CLS_ID, eos_id=EOS_ID):
  65. mask = (input_ids == cls_id).to(
  66. dtype=torch.long) + (input_ids == eos_id).to(dtype=torch.long)
  67. mask = mask + torch.cat([torch.zeros_like(mask[:, 0:1]), mask[:, :-1]],
  68. dim=1)
  69. return mask.cumsum(dim=1) - 1
  70. def get_token_type_mask(input_ids, cls_id=CLS_ID, eos_id=EOS_ID):
  71. mask = (input_ids == cls_id) | (input_ids == eos_id)
  72. return mask
  73. def get_win_max(hidden_states, kernel_size=3):
  74. m = nn.MaxPool1d(kernel_size, stride=1, padding=kernel_size // 2)
  75. out = m(hidden_states.permute(0, 2, 1)).permute(0, 2, 1)
  76. return out
  77. class PoNetEmbeddings(nn.Module):
  78. """Construct the embeddings from word, position and token_type embeddings."""
  79. def __init__(self, config):
  80. super().__init__()
  81. self.word_embeddings = nn.Embedding(
  82. config.vocab_size,
  83. config.hidden_size,
  84. padding_idx=config.pad_token_id)
  85. self.position_embeddings = nn.Embedding(config.max_position_embeddings,
  86. config.hidden_size)
  87. self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
  88. config.hidden_size)
  89. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  90. # any TensorFlow checkpoint file
  91. self.LayerNorm = nn.LayerNorm(
  92. config.hidden_size, eps=config.layer_norm_eps)
  93. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  94. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  95. self.position_embedding_type = getattr(config,
  96. 'position_embedding_type',
  97. 'absolute')
  98. self.register_buffer(
  99. 'position_ids',
  100. torch.arange(config.max_position_embeddings).expand((1, -1)))
  101. if version.parse(torch.__version__) > version.parse('1.6.0'):
  102. self.register_buffer(
  103. 'token_type_ids',
  104. torch.zeros(
  105. self.position_ids.size(),
  106. dtype=torch.long,
  107. device=self.position_ids.device),
  108. persistent=False,
  109. )
  110. def forward(self,
  111. input_ids=None,
  112. token_type_ids=None,
  113. position_ids=None,
  114. inputs_embeds=None,
  115. past_key_values_length=0):
  116. if input_ids is not None:
  117. input_shape = input_ids.size()
  118. else:
  119. input_shape = inputs_embeds.size()[:-1]
  120. seq_length = input_shape[1]
  121. if position_ids is None:
  122. position_ids = self.position_ids[:,
  123. past_key_values_length:seq_length
  124. + past_key_values_length]
  125. if token_type_ids is None:
  126. if hasattr(self, 'token_type_ids'):
  127. buffered_token_type_ids = self.token_type_ids[:, :seq_length]
  128. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
  129. input_shape[0], seq_length)
  130. token_type_ids = buffered_token_type_ids_expanded
  131. else:
  132. token_type_ids = torch.zeros(
  133. input_shape,
  134. dtype=torch.long,
  135. device=self.position_ids.device)
  136. if inputs_embeds is None:
  137. inputs_embeds = self.word_embeddings(input_ids)
  138. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  139. embeddings = inputs_embeds + token_type_embeddings
  140. if self.position_embedding_type == 'absolute':
  141. position_embeddings = self.position_embeddings(position_ids)
  142. embeddings += position_embeddings
  143. embeddings = self.LayerNorm(embeddings)
  144. embeddings = self.dropout(embeddings)
  145. return embeddings
class PoNetSelfAttention(nn.Module):
    """PoNet token-mixing layer used in place of softmax self-attention.

    It combines three pooled views of the sequence (see ``forward``):
    - a global context vector ``v``, obtained by attending with a mean-pooled
      query;
    - a segment-level max pool (``segment_max``);
    - a local sliding-window max pool (``get_win_max``).
    """

    def __init__(self, config):
        super().__init__()
        # Projections feeding the local (windowed) and segment max-pools.
        self.dense_local = nn.Linear(config.hidden_size, config.hidden_size)
        self.dense_segment = nn.Linear(config.hidden_size, config.hidden_size)
        self.num_attention_heads = config.num_attention_heads
        # When True (the default), special tokens flagged by token_type_mask
        # are kept out of the local pool and zeroed in both pooled outputs.
        self.clsgsepg = getattr(config, 'clsgsepg', True)
        self.attention_head_size = int(config.hidden_size
                                       / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # dense_q / dense_k drive the global pooling; dense_o produces the
        # multiplicative gate applied to (global + segment) in forward().
        self.dense_q = nn.Linear(config.hidden_size, self.all_head_size)
        self.dense_k = nn.Linear(config.hidden_size, self.all_head_size)
        self.dense_o = nn.Linear(config.hidden_size, self.all_head_size)

    def transpose_for_scores(self, x):
        """Split the last dim into heads: (bz, len, all_head) -> (bz, head, len, head_size)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads,
                                       self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)  # bz, head, len, head_size

    def forward(
        self,
        hidden_states,
        segment_index,
        token_type_mask,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Mix tokens with global, segment and local pooling.

        Args:
            hidden_states: (bz, len, hidden) input (rank implied by
                ``hidden_states.shape[:2]`` and ``transpose_for_scores``).
            segment_index: Segment ids passed to ``segment_max``.
            token_type_mask: Boolean special-token mask, used only when
                ``self.clsgsepg`` is set.
            attention_mask: Additive mask. The ``squeeze(1)`` calls and the
                ``< -1`` test suggest the extended (bz, 1, 1, len) form with
                large negative values at padding — TODO confirm with caller.
            head_mask, encoder_hidden_states, encoder_attention_mask,
            past_key_value: Accepted for interface compatibility; not used.
            output_attentions: Also return ``att_prob`` when True.

        Returns:
            ``(context_layer,)``, or ``(context_layer, att_prob)`` when
            ``output_attentions`` is True.
        """
        context_layer_q = self.transpose_for_scores(
            self.dense_q(hidden_states))
        context_layer_k = self.transpose_for_scores(
            self.dense_k(hidden_states))
        # Values share the key projection.
        context_layer_v = context_layer_k
        context_layer_o = self.transpose_for_scores(
            self.dense_o(hidden_states))

        if attention_mask is not None:
            # True at masked (padding) positions.
            _attention_mask = (attention_mask.squeeze(1).unsqueeze(-1) < -1)

        if attention_mask is not None:
            # Mean of the queries over unmasked positions only.
            context_layer_q.masked_fill_(_attention_mask, 0.0)
            q = context_layer_q.sum(dim=-2) / torch.ones_like(
                _attention_mask).to(dtype=context_layer_q.dtype).masked_fill(
                    _attention_mask, 0.0).sum(dim=-2)
        else:
            q = context_layer_q.mean(dim=-2)
        # Scaled dot product of the pooled query with every key, per head.
        att = torch.einsum('bdh,bdlh -> bdl', q, context_layer_k) / math.sqrt(
            context_layer_q.shape[-1])
        if attention_mask is not None:
            att = att + attention_mask.squeeze(1)
        att_prob = att.softmax(dim=-1)
        # Global context vector: attention-weighted sum of the values.
        v = torch.einsum('bdlh,bdl->bdh', context_layer_v, att_prob)

        context_layer_segment = self.dense_segment(hidden_states)
        context_layer_local = self.dense_local(hidden_states)
        if attention_mask is not None:
            # Push padded positions to a large negative value so they are
            # unlikely to win the max-pools below.
            context_layer_local.masked_fill_(
                _attention_mask.squeeze(1), -10000)
            context_layer_segment.masked_fill_(
                _attention_mask.squeeze(1), -10000)

        if self.clsgsepg:
            # XXX: a trick to make sure the segment and local information will not leak
            # Special tokens are filled with -10000 before the window max so
            # they do not feed their neighbours' local pool. segment_max is
            # not pre-masked; if segment_index comes from get_segment_index,
            # each special token already forms its own segment.
            context_layer_local = get_win_max(
                context_layer_local.masked_fill(
                    token_type_mask.unsqueeze(dim=-1), -10000))
            context_layer_segment = segment_max(
                context_layer_segment, index=segment_index)

            # Special-token positions receive no segment/local contribution.
            context_layer_segment.masked_fill_(
                token_type_mask.unsqueeze(dim=-1), 0.0)
            context_layer_local.masked_fill_(
                token_type_mask.unsqueeze(dim=-1), 0.0)
        else:
            context_layer_local = get_win_max(context_layer_local)
            context_layer_segment = segment_max(
                context_layer_segment, index=segment_index)

        context_layer_local = self.transpose_for_scores(context_layer_local)
        context_layer_segment = self.transpose_for_scores(
            context_layer_segment)

        # (global + segment) gated by dense_o, plus the local term.
        context_layer = (v.unsqueeze(dim=-2) + context_layer_segment
                         ) * context_layer_o + context_layer_local
        # Merge heads back: (bz, head, len, head_size) -> (bz, len, all_head).
        context_layer = context_layer.permute(0, 2, 1, 3).reshape(
            *hidden_states.shape[:2], -1)

        if attention_mask is not None:
            # Zero the output at padded positions.
            context_layer.masked_fill_(_attention_mask.squeeze(1), 0.0)

        outputs = (context_layer,
                   att_prob) if output_attentions else (context_layer, )
        return outputs
  232. class PoNetSelfOutput(nn.Module):
  233. def __init__(self, config):
  234. super().__init__()
  235. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  236. self.LayerNorm = nn.LayerNorm(
  237. config.hidden_size, eps=config.layer_norm_eps)
  238. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  239. def forward(self, hidden_states, input_tensor):
  240. hidden_states = self.dense(hidden_states)
  241. hidden_states = self.dropout(hidden_states)
  242. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  243. return hidden_states
  244. class PoNetIntermediate(nn.Module):
  245. def __init__(self, config):
  246. super().__init__()
  247. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  248. if isinstance(config.hidden_act, str):
  249. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  250. else:
  251. self.intermediate_act_fn = config.hidden_act
  252. def forward(self, hidden_states):
  253. hidden_states = self.dense(hidden_states)
  254. hidden_states = self.intermediate_act_fn(hidden_states)
  255. return hidden_states
  256. class PoNetOutput(nn.Module):
  257. def __init__(self, config):
  258. super().__init__()
  259. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  260. self.LayerNorm = nn.LayerNorm(
  261. config.hidden_size, eps=config.layer_norm_eps)
  262. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  263. def forward(self, hidden_states, input_tensor):
  264. hidden_states = self.dense(hidden_states)
  265. hidden_states = self.dropout(hidden_states)
  266. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  267. return hidden_states
  268. class PoNetAttention(nn.Module):
  269. def __init__(self, config):
  270. super().__init__()
  271. self.self = PoNetSelfAttention(config)
  272. self.output = PoNetSelfOutput(config)
  273. self.pruned_heads = set()
  274. def prune_heads(self, heads):
  275. if len(heads) == 0:
  276. return
  277. heads, index = find_pruneable_heads_and_indices(
  278. heads, self.self.num_attention_heads,
  279. self.self.attention_head_size, self.pruned_heads)
  280. # Prune linear layers
  281. self.self.query = prune_linear_layer(self.self.query, index)
  282. self.self.key = prune_linear_layer(self.self.key, index)
  283. self.self.value = prune_linear_layer(self.self.value, index)
  284. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  285. # Update hyper params and store pruned heads
  286. self.self.num_attention_heads = self.self.num_attention_heads - len(
  287. heads)
  288. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  289. self.pruned_heads = self.pruned_heads.union(heads)
  290. def forward(
  291. self,
  292. hidden_states,
  293. segment_index,
  294. token_type_mask,
  295. attention_mask=None,
  296. head_mask=None,
  297. encoder_hidden_states=None,
  298. encoder_attention_mask=None,
  299. past_key_value=None,
  300. output_attentions=False,
  301. ):
  302. self_outputs = self.self(
  303. hidden_states,
  304. segment_index,
  305. token_type_mask,
  306. attention_mask,
  307. head_mask,
  308. encoder_hidden_states,
  309. encoder_attention_mask,
  310. past_key_value,
  311. output_attentions,
  312. )
  313. attention_output = self.output(self_outputs[0], hidden_states)
  314. outputs = (attention_output,
  315. ) + self_outputs[1:] # add attentions if we output them
  316. return outputs
class PoNetLayer(nn.Module):
    """One PoNet encoder layer: token mixing (PoNetAttention) followed by a feed-forward block.

    Decoder / cross-attention support is stubbed out: ``__init__`` forces
    ``config.is_decoder = False``.
    """

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Dimension handed to apply_chunking_to_forward for the feed-forward.
        self.seq_len_dim = 1
        self.attention = PoNetAttention(config)
        # NOTE: this mutates the caller's config object, so anything that
        # later reads config.is_decoder from the same object sees False.
        config.is_decoder = False  # XXX: Decoder is not yet implemented.
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            # is_decoder was just forced to False, so this assert rejects
            # add_cross_attention=True (unless Python runs with -O).
            assert self.is_decoder, f'{self} should be used as a decoder model if cross attention is added'
            self.crossattention = PoNetAttention(config)
        self.intermediate = PoNetIntermediate(config)
        self.output = PoNetOutput(config)

    def forward(
        self,
        hidden_states,
        segment_index,
        token_type_mask,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run token mixing, then the (possibly chunked) feed-forward block.

        Returns:
            ``(layer_output,) + extra``, where ``extra`` holds the attention
            probabilities when ``output_attentions`` is True. The decoder
            branches below are unreachable while ``self.is_decoder`` keeps the
            False value set in ``__init__``.
        """
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:
                                                  2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            segment_index,
            token_type_mask,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            # NOTE(review): PoNetAttention never appends a cache entry, so if
            # this were reached it would treat the attention probabilities
            # (or nothing) as the cache.
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[
                1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            assert hasattr(
                self, 'crossattention'
            ), f'If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`'  # noqa *

            cross_attn_past_key_value = past_key_value[
                -2:] if past_key_value is not None else None
            # NOTE(review): these positional arguments do not line up with
            # PoNetAttention.forward(hidden_states, segment_index,
            # token_type_mask, attention_mask, ...). The branch is
            # unreachable today because is_decoder is forced to False.
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[
                1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        # Feed-forward via feed_forward_chunk; presumably split into chunks
        # along seq_len_dim when chunk_size_feed_forward is non-zero.
        layer_output = apply_chunking_to_forward(self.feed_forward_chunk,
                                                 self.chunk_size_feed_forward,
                                                 self.seq_len_dim,
                                                 attention_output)
        outputs = (layer_output, ) + outputs

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value, )

        return outputs

    def feed_forward_chunk(self, attention_output):
        """Apply the intermediate expansion, then the output projection with its residual LayerNorm."""
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
  398. class PoNetEncoder(nn.Module):
  399. def __init__(self, config):
  400. super().__init__()
  401. self.config = config
  402. self.layer = nn.ModuleList(
  403. [PoNetLayer(config) for _ in range(config.num_hidden_layers)])
  404. def forward(
  405. self,
  406. hidden_states,
  407. segment_index,
  408. token_type_mask,
  409. attention_mask=None,
  410. head_mask=None,
  411. encoder_hidden_states=None,
  412. encoder_attention_mask=None,
  413. past_key_values=None,
  414. use_cache=None,
  415. output_attentions=False,
  416. output_hidden_states=False,
  417. return_dict=True,
  418. ):
  419. all_hidden_states = () if output_hidden_states else None
  420. all_self_attentions = () if output_attentions else None
  421. all_cross_attentions = (
  422. ) if output_attentions and self.config.add_cross_attention else None
  423. next_decoder_cache = () if use_cache else None
  424. for i, layer_module in enumerate(self.layer):
  425. if output_hidden_states:
  426. all_hidden_states = all_hidden_states + (hidden_states, )
  427. layer_head_mask = head_mask[i] if head_mask is not None else None
  428. past_key_value = past_key_values[
  429. i] if past_key_values is not None else None
  430. if getattr(self.config, 'gradient_checkpointing',
  431. False) and self.training:
  432. if use_cache:
  433. logger.warning(
  434. '`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting '
  435. '`use_cache=False`...')
  436. use_cache = False
  437. def create_custom_forward(module):
  438. def custom_forward(*inputs):
  439. return module(*inputs, past_key_value,
  440. output_attentions)
  441. return custom_forward
  442. layer_outputs = torch.utils.checkpoint.checkpoint(
  443. create_custom_forward(layer_module),
  444. hidden_states,
  445. segment_index,
  446. token_type_mask,
  447. attention_mask,
  448. layer_head_mask,
  449. encoder_hidden_states,
  450. encoder_attention_mask,
  451. )
  452. else:
  453. layer_outputs = layer_module(
  454. hidden_states,
  455. segment_index,
  456. token_type_mask,
  457. attention_mask,
  458. layer_head_mask,
  459. encoder_hidden_states,
  460. encoder_attention_mask,
  461. past_key_value,
  462. output_attentions,
  463. )
  464. hidden_states = layer_outputs[0]
  465. if use_cache:
  466. next_decoder_cache += (layer_outputs[-1], )
  467. if output_attentions:
  468. all_self_attentions = all_self_attentions + (
  469. layer_outputs[1], )
  470. if self.config.add_cross_attention:
  471. all_cross_attentions = all_cross_attentions + (
  472. layer_outputs[2], )
  473. if output_hidden_states:
  474. all_hidden_states = all_hidden_states + (hidden_states, )
  475. if not return_dict:
  476. return tuple(v for v in [
  477. hidden_states,
  478. next_decoder_cache,
  479. all_hidden_states,
  480. all_self_attentions,
  481. all_cross_attentions,
  482. ] if v is not None)
  483. return AttentionBackboneModelOutput(
  484. last_hidden_state=hidden_states,
  485. past_key_values=next_decoder_cache,
  486. hidden_states=all_hidden_states,
  487. attentions=all_self_attentions,
  488. cross_attentions=all_cross_attentions,
  489. )
  490. class PoNetPooler(nn.Module):
  491. def __init__(self, config):
  492. super().__init__()
  493. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  494. self.activation = nn.Tanh()
  495. def forward(self, hidden_states):
  496. # We "pool" the model by simply taking the hidden state corresponding
  497. # to the first token.
  498. first_token_tensor = hidden_states[:, 0]
  499. pooled_output = self.dense(first_token_tensor)
  500. pooled_output = self.activation(pooled_output)
  501. return pooled_output
class PoNetPreTrainedModel(TorchModel, PreTrainedModel):
    """
    A base class to handle weights initialization and a simple interface for loading pretrained models.
    """

    config_class = PoNetConfig
    base_model_prefix = 'ponet'
    # `position_ids` is a buffer built in PoNetEmbeddings.__init__; listing it
    # here presumably stops transformers from reporting it as a missing key
    # when a checkpoint does not contain it.
    _keys_to_ignore_on_load_missing = [r'position_ids']

    def __init__(self, config, **kwargs):
        # First initialiser in the MRO after this class (TorchModel), given
        # the model path.
        super().__init__(config.name_or_path, **kwargs)
        # Then the initialiser that follows `Model` in the MRO, given the
        # config (presumably transformers' PreTrainedModel — verify MRO).
        super(Model, self).__init__(config)

    def _init_weights(self, module):
        """Initialize the weights"""
        # Linear / Embedding weights ~ N(0, config.initializer_range); Linear
        # biases and the embedding padding row are zeroed; LayerNorm affine
        # parameters are reset to weight=1, bias=0.
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    @classmethod
    def _instantiate(cls, **kwargs):
        """Build a model from keyword config values or from a model directory.

        Args:
            **kwargs: If ``model_dir`` is absent or None, all kwargs are passed
                to ``PoNetConfig`` and a freshly initialised ``cls`` instance
                is returned. Otherwise the model is loaded from ``model_dir``
                via the ``from_pretrained`` that follows ``Model`` in the MRO,
                and the remaining kwargs are ignored.

        Returns:
            The constructed or loaded model.
        """
        model_dir = kwargs.pop('model_dir', None)
        if model_dir is None:
            ponet_config = PoNetConfig(**kwargs)
            model = cls(ponet_config)
        else:
            model = super(
                Model,
                cls).from_pretrained(pretrained_model_name_or_path=model_dir)
        return model
  540. @MODELS.register_module(Tasks.backbone, module_name=Models.ponet)
  541. class PoNetModel(PoNetPreTrainedModel):
  542. """The bare PoNet Model transformer outputting raw hidden-states without any specific head on top.
  543. This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
  544. methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
  545. pruning heads etc.)
  546. This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
  547. subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
  548. general usage and behavior.
  549. Parameters:
  550. config (:class:`~modelscope.models.nlp.ponet.PoNetConfig`):
  551. Model configuration class with all the parameters of the model.
  552. Initializing with a config file does not load the weights associated with the model, only the
  553. configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
  554. weights.
  555. The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
  556. cross-attention is added between the self-attention layers, following the architecture described in `Attention is
  557. all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
  558. Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
  559. To behave as an decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration
  560. set to :obj:`True`. To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder`
  561. argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
  562. input to the forward pass.
  563. """
  564. def __init__(self, config, add_pooling_layer=True, **kwargs):
  565. super().__init__(config, **kwargs)
  566. self.config = config
  567. self.embeddings = PoNetEmbeddings(config)
  568. self.encoder = PoNetEncoder(config)
  569. self.pooler = PoNetPooler(config) if add_pooling_layer else None
  570. self.init_weights()
  571. def get_input_embeddings(self):
  572. return self.embeddings.word_embeddings
  573. def set_input_embeddings(self, value):
  574. self.embeddings.word_embeddings = value
  575. def _prune_heads(self, heads_to_prune):
  576. """
  577. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  578. class PreTrainedModel
  579. """
  580. for layer, heads in heads_to_prune.items():
  581. self.encoder.layer[layer].attention.prune_heads(heads)
  582. def forward(
  583. self,
  584. input_ids=None,
  585. attention_mask=None,
  586. token_type_ids=None,
  587. segment_ids=None,
  588. position_ids=None,
  589. head_mask=None,
  590. inputs_embeds=None,
  591. encoder_hidden_states=None,
  592. encoder_attention_mask=None,
  593. past_key_values=None,
  594. use_cache=None,
  595. output_attentions=None,
  596. output_hidden_states=None,
  597. return_dict=None,
  598. ):
  599. r"""
  600. Args:
  601. input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
  602. Indices of input sequence tokens in the vocabulary.
  603. Indices can be obtained using :class:`~modelscope.models.nlp.ponet.PoNetTokenizer`. See
  604. :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
  605. for details.
  606. attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
  607. Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
  608. - 1 for tokens that are **not masked**,
  609. - 0 for tokens that are **masked**.
  610. token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
  611. Segment token indices to indicate first and second portions of the inputs. Indices are selected in
  612. ``[0,1]``:
  613. - 0 corresponds to a `sentence A` token,
  614. - 1 corresponds to a `sentence B` token.
  615. position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
  616. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
  617. ``[0,config.max_position_embeddings - 1]``.
  618. head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`,
  619. `optional`):
  620. Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
  621. - 1 indicates the head is **not masked**,
  622. - 0 indicates the head is **masked**.
  623. inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
  624. `optional`):
  625. Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
  626. representation. This is useful if you want more control over how to convert :obj:`input_ids`
  627. indices into associated vectors than the model's internal embedding lookup matrix.
  628. output_attentions (:obj:`bool`, `optional`):
  629. Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
  630. returned tensors for more detail.
  631. output_hidden_states (:obj:`bool`, `optional`):
  632. Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
  633. for more detail.
  634. return_dict (:obj:`bool`, `optional`):
  635. Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
  636. encoder_hidden_states
  637. (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
  638. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
  639. the model is configured as a decoder.
  640. encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
  641. Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
  642. in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
  643. - 1 for tokens that are **not masked**,
  644. - 0 for tokens that are **masked**.
  645. past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers`
  646. with each tuple having 4 tensors of shape :obj:
  647. `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
  648. Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up
  649. decoding.
  650. If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
  651. (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
  652. instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
  653. use_cache (:obj:`bool`, `optional`):
  654. If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
  655. decoding (see :obj:`past_key_values`).
  656. Returns:
  657. Returns `modelscope.outputs.AttentionBackboneModelOutput`
  658. Examples:
  659. >>> from modelscope.models import Model
  660. >>> from modelscope.preprocessors import Preprocessor
  661. >>> model = Model.from_pretrained('damo/nlp_ponet_fill-mask_chinese-base', task='backbone')
  662. >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_ponet_fill-mask_chinese-base')
  663. >>> print(model(**preprocessor('这是个测试')))
  664. """
  665. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  666. output_hidden_states = (
  667. output_hidden_states if output_hidden_states is not None else
  668. self.config.output_hidden_states)
  669. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  670. if self.config.is_decoder:
  671. use_cache = use_cache if use_cache is not None else self.config.use_cache
  672. else:
  673. use_cache = False
  674. if input_ids is not None and inputs_embeds is not None:
  675. raise ValueError(
  676. 'You cannot specify both input_ids and inputs_embeds at the same time'
  677. )
  678. elif input_ids is not None:
  679. input_shape = input_ids.size()
  680. batch_size, seq_length = input_shape
  681. elif inputs_embeds is not None:
  682. input_shape = inputs_embeds.size()[:-1]
  683. batch_size, seq_length = input_shape
  684. else:
  685. raise ValueError(
  686. 'You have to specify either input_ids or inputs_embeds')
  687. device = input_ids.device if input_ids is not None else inputs_embeds.device
  688. # past_key_values_length
  689. past_key_values_length = past_key_values[0][0].shape[
  690. 2] if past_key_values is not None else 0
  691. if attention_mask is None:
  692. attention_mask = torch.ones(
  693. ((batch_size, seq_length + past_key_values_length)),
  694. device=device)
  695. if token_type_ids is None:
  696. token_type_ids = torch.zeros(
  697. input_shape, dtype=torch.long, device=device)
  698. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  699. # ourselves in which case we just need to make it broadcastable to all heads.
  700. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
  701. attention_mask, input_shape, device)
  702. # If a 2D or 3D attention mask is provided for the cross-attention
  703. # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  704. if self.config.is_decoder and encoder_hidden_states is not None:
  705. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size(
  706. )
  707. encoder_hidden_shape = (encoder_batch_size,
  708. encoder_sequence_length)
  709. if encoder_attention_mask is None:
  710. encoder_attention_mask = torch.ones(
  711. encoder_hidden_shape, device=device)
  712. encoder_extended_attention_mask = self.invert_attention_mask(
  713. encoder_attention_mask)
  714. else:
  715. encoder_extended_attention_mask = None
  716. # Prepare head mask if needed
  717. # 1.0 in head_mask indicate we keep the head
  718. # attention_probs has shape bsz x n_heads x N x N
  719. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  720. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  721. head_mask = self.get_head_mask(head_mask,
  722. self.config.num_hidden_layers)
  723. embedding_output = self.embeddings(
  724. input_ids=input_ids,
  725. position_ids=position_ids,
  726. token_type_ids=token_type_ids,
  727. inputs_embeds=inputs_embeds,
  728. past_key_values_length=past_key_values_length,
  729. )
  730. segment_index = get_segment_index(
  731. input_ids) if segment_ids is None else segment_ids
  732. token_type_mask = get_token_type_mask(input_ids)
  733. encoder_outputs = self.encoder(
  734. embedding_output,
  735. segment_index,
  736. token_type_mask,
  737. attention_mask=extended_attention_mask,
  738. head_mask=head_mask,
  739. encoder_hidden_states=encoder_hidden_states,
  740. encoder_attention_mask=encoder_extended_attention_mask,
  741. past_key_values=past_key_values,
  742. use_cache=use_cache,
  743. output_attentions=output_attentions,
  744. output_hidden_states=output_hidden_states,
  745. return_dict=return_dict,
  746. )
  747. sequence_output = encoder_outputs[0]
  748. pooled_output = self.pooler(
  749. sequence_output) if self.pooler is not None else None
  750. if not return_dict:
  751. return (sequence_output, pooled_output) + encoder_outputs[1:]
  752. return AttentionBackboneModelOutput(
  753. last_hidden_state=sequence_output,
  754. pooler_output=pooled_output,
  755. past_key_values=encoder_outputs.past_key_values,
  756. hidden_states=encoder_outputs.hidden_states,
  757. attentions=encoder_outputs.attentions,
  758. cross_attentions=encoder_outputs.cross_attentions,
  759. )