backbone.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919
  1. # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
  2. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. # All rights reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. """PyTorch StructBERT model. mainly copied from :module:`~transformers.modeling_bert`"""
  18. import math
  19. from dataclasses import dataclass
  20. from typing import Optional, Union
  21. import torch
  22. import torch.nn as nn
  23. import torch.utils.checkpoint
  24. from packaging import version
  25. from transformers.activations import ACT2FN
  26. from transformers.modeling_utils import PreTrainedModel
  27. from modelscope.metainfo import Models
  28. from modelscope.models import Model, TorchModel
  29. from modelscope.models.builder import MODELS
  30. from modelscope.outputs import AttentionBackboneModelOutput
  31. from modelscope.utils.constant import Tasks
  32. from modelscope.utils.logger import get_logger
  33. from modelscope.utils.nlp.utils import parse_labels_in_order
  34. from modelscope.utils.torch_utils import (apply_chunking_to_forward,
  35. find_pruneable_heads_and_indices,
  36. prune_linear_layer)
  37. from .configuration import SbertConfig
  38. logger = get_logger()
  39. class SbertEmbeddings(nn.Module):
  40. """Construct the embeddings from word, position and token_type embeddings."""
  41. def __init__(self, config):
  42. super().__init__()
  43. self.word_embeddings = nn.Embedding(
  44. config.vocab_size,
  45. config.hidden_size,
  46. padding_idx=config.pad_token_id)
  47. self.position_embeddings = nn.Embedding(config.max_position_embeddings,
  48. config.hidden_size)
  49. self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
  50. config.hidden_size)
  51. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  52. # any TensorFlow checkpoint file
  53. self.LayerNorm = nn.LayerNorm(
  54. config.hidden_size, eps=config.layer_norm_eps)
  55. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  56. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  57. self.position_embedding_type = getattr(config,
  58. 'position_embedding_type',
  59. 'absolute')
  60. self.register_buffer(
  61. 'position_ids',
  62. torch.arange(config.max_position_embeddings).expand((1, -1)))
  63. if version.parse(torch.__version__) > version.parse('1.6.0'):
  64. self.register_buffer(
  65. 'token_type_ids',
  66. torch.zeros(
  67. self.position_ids.size(),
  68. dtype=torch.long,
  69. device=self.position_ids.device),
  70. persistent=False,
  71. )
  72. def forward(self,
  73. input_ids=None,
  74. token_type_ids=None,
  75. position_ids=None,
  76. inputs_embeds=None,
  77. past_key_values_length=0,
  78. return_inputs_embeds=False):
  79. if input_ids is not None:
  80. input_shape = input_ids.size()
  81. else:
  82. input_shape = inputs_embeds.size()[:-1]
  83. seq_length = input_shape[1]
  84. if position_ids is None:
  85. position_ids = self.position_ids[:,
  86. past_key_values_length:seq_length
  87. + past_key_values_length]
  88. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  89. # when its auto-generated, registered buffer helps users
  90. # when tracing the model without passing token_type_ids, solves
  91. # issue #5664
  92. if token_type_ids is None:
  93. if hasattr(self, 'token_type_ids'):
  94. buffered_token_type_ids = self.token_type_ids[:, :seq_length]
  95. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
  96. input_shape[0], seq_length)
  97. token_type_ids = buffered_token_type_ids_expanded
  98. else:
  99. token_type_ids = torch.zeros(
  100. input_shape,
  101. dtype=torch.long,
  102. device=self.position_ids.device)
  103. if inputs_embeds is None:
  104. inputs_embeds = self.word_embeddings(input_ids)
  105. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  106. embeddings = inputs_embeds + token_type_embeddings
  107. if self.position_embedding_type == 'absolute':
  108. position_embeddings = self.position_embeddings(position_ids)
  109. embeddings += position_embeddings
  110. embeddings = self.LayerNorm(embeddings)
  111. embeddings = self.dropout(embeddings)
  112. if not return_inputs_embeds:
  113. return embeddings
  114. else:
  115. return embeddings, inputs_embeds
  116. class SbertSelfAttention(nn.Module):
  117. def __init__(self, config):
  118. super().__init__()
  119. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
  120. config, 'embedding_size'):
  121. raise ValueError(
  122. f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention '
  123. f'heads ({config.num_attention_heads})')
  124. self.num_attention_heads = config.num_attention_heads
  125. self.attention_head_size = int(config.hidden_size
  126. / config.num_attention_heads)
  127. self.all_head_size = self.num_attention_heads * self.attention_head_size
  128. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  129. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  130. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  131. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  132. self.position_embedding_type = getattr(config,
  133. 'position_embedding_type',
  134. 'absolute')
  135. if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query':
  136. self.max_position_embeddings = config.max_position_embeddings
  137. self.distance_embedding = nn.Embedding(
  138. 2 * config.max_position_embeddings - 1,
  139. self.attention_head_size)
  140. self.is_decoder = config.is_decoder
  141. def transpose_for_scores(self, x):
  142. new_x_shape = x.size()[:-1] + (self.num_attention_heads,
  143. self.attention_head_size)
  144. x = x.view(*new_x_shape)
  145. return x.permute(0, 2, 1, 3)
  146. def forward(
  147. self,
  148. hidden_states,
  149. attention_mask=None,
  150. head_mask=None,
  151. encoder_hidden_states=None,
  152. encoder_attention_mask=None,
  153. past_key_value=None,
  154. output_attentions=False,
  155. ):
  156. mixed_query_layer = self.query(hidden_states)
  157. # If this is instantiated as a cross-attention module, the keys
  158. # and values come from an encoder; the attention mask needs to be
  159. # such that the encoder's padding tokens are not attended to.
  160. is_cross_attention = encoder_hidden_states is not None
  161. if is_cross_attention and past_key_value is not None:
  162. # reuse k,v, cross_attentions
  163. key_layer = past_key_value[0]
  164. value_layer = past_key_value[1]
  165. attention_mask = encoder_attention_mask
  166. elif is_cross_attention:
  167. key_layer = self.transpose_for_scores(
  168. self.key(encoder_hidden_states))
  169. value_layer = self.transpose_for_scores(
  170. self.value(encoder_hidden_states))
  171. attention_mask = encoder_attention_mask
  172. elif past_key_value is not None:
  173. key_layer = self.transpose_for_scores(self.key(hidden_states))
  174. value_layer = self.transpose_for_scores(self.value(hidden_states))
  175. key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
  176. value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
  177. else:
  178. key_layer = self.transpose_for_scores(self.key(hidden_states))
  179. value_layer = self.transpose_for_scores(self.value(hidden_states))
  180. query_layer = self.transpose_for_scores(mixed_query_layer)
  181. if self.is_decoder:
  182. # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
  183. # Further calls to cross_attention layer can then reuse all cross-attention
  184. # key/value_states (first "if" case)
  185. # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
  186. # all previous decoder key/value_states. Further calls to uni-directional self-attention
  187. # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
  188. # if encoder bi-directional self-attention `past_key_value` is always `None`
  189. past_key_value = (key_layer, value_layer)
  190. # Take the dot product between "query" and "key" to get the raw attention scores.
  191. attention_scores = torch.matmul(query_layer,
  192. key_layer.transpose(-1, -2))
  193. if self.position_embedding_type == 'relative_key' or self.position_embedding_type == 'relative_key_query':
  194. seq_length = hidden_states.size()[1]
  195. position_ids_l = torch.arange(
  196. seq_length, dtype=torch.long,
  197. device=hidden_states.device).view(-1, 1)
  198. position_ids_r = torch.arange(
  199. seq_length, dtype=torch.long,
  200. device=hidden_states.device).view(1, -1)
  201. distance = position_ids_l - position_ids_r
  202. positional_embedding = self.distance_embedding(
  203. distance + self.max_position_embeddings - 1)
  204. positional_embedding = positional_embedding.to(
  205. dtype=query_layer.dtype) # fp16 compatibility
  206. if self.position_embedding_type == 'relative_key':
  207. relative_position_scores = torch.einsum(
  208. 'bhld,lrd->bhlr', query_layer, positional_embedding)
  209. attention_scores = attention_scores + relative_position_scores
  210. elif self.position_embedding_type == 'relative_key_query':
  211. relative_position_scores_query = torch.einsum(
  212. 'bhld,lrd->bhlr', query_layer, positional_embedding)
  213. relative_position_scores_key = torch.einsum(
  214. 'bhrd,lrd->bhlr', key_layer, positional_embedding)
  215. attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
  216. attention_scores = attention_scores / math.sqrt(
  217. self.attention_head_size)
  218. if attention_mask is not None:
  219. # Apply the attention mask is (precomputed for all layers in SbertModel forward() function)
  220. attention_scores = attention_scores + attention_mask
  221. # Normalize the attention scores to probabilities.
  222. attention_probs = nn.Softmax(dim=-1)(attention_scores)
  223. # This is actually dropping out entire tokens to attend to, which might
  224. # seem a bit unusual, but is taken from the original Transformer paper.
  225. attention_probs = self.dropout(attention_probs)
  226. # Mask heads if we want to
  227. if head_mask is not None:
  228. attention_probs = attention_probs * head_mask
  229. context_layer = torch.matmul(attention_probs, value_layer)
  230. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  231. new_context_layer_shape = context_layer.size()[:-2] + (
  232. self.all_head_size, )
  233. context_layer = context_layer.view(*new_context_layer_shape)
  234. outputs = (context_layer,
  235. attention_probs) if output_attentions else (context_layer, )
  236. if self.is_decoder:
  237. outputs = outputs + (past_key_value, )
  238. return outputs
  239. class SbertSelfOutput(nn.Module):
  240. def __init__(self, config):
  241. super().__init__()
  242. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  243. self.LayerNorm = nn.LayerNorm(
  244. config.hidden_size, eps=config.layer_norm_eps)
  245. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  246. def forward(self, hidden_states, input_tensor):
  247. hidden_states = self.dense(hidden_states)
  248. hidden_states = self.dropout(hidden_states)
  249. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  250. return hidden_states
  251. class SbertAttention(nn.Module):
  252. def __init__(self, config):
  253. super().__init__()
  254. self.self = SbertSelfAttention(config)
  255. self.output = SbertSelfOutput(config)
  256. self.pruned_heads = set()
  257. def prune_heads(self, heads):
  258. if len(heads) == 0:
  259. return
  260. heads, index = find_pruneable_heads_and_indices(
  261. heads, self.self.num_attention_heads,
  262. self.self.attention_head_size, self.pruned_heads)
  263. # Prune linear layers
  264. self.self.query = prune_linear_layer(self.self.query, index)
  265. self.self.key = prune_linear_layer(self.self.key, index)
  266. self.self.value = prune_linear_layer(self.self.value, index)
  267. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  268. # Update hyper params and store pruned heads
  269. self.self.num_attention_heads = self.self.num_attention_heads - len(
  270. heads)
  271. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  272. self.pruned_heads = self.pruned_heads.union(heads)
  273. def forward(
  274. self,
  275. hidden_states,
  276. attention_mask=None,
  277. head_mask=None,
  278. encoder_hidden_states=None,
  279. encoder_attention_mask=None,
  280. past_key_value=None,
  281. output_attentions=False,
  282. ):
  283. self_outputs = self.self(
  284. hidden_states,
  285. attention_mask,
  286. head_mask,
  287. encoder_hidden_states,
  288. encoder_attention_mask,
  289. past_key_value,
  290. output_attentions,
  291. )
  292. attention_output = self.output(self_outputs[0], hidden_states)
  293. outputs = (attention_output,
  294. ) + self_outputs[1:] # add attentions if we output them
  295. return outputs
  296. class SbertIntermediate(nn.Module):
  297. def __init__(self, config):
  298. super().__init__()
  299. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  300. if isinstance(config.hidden_act, str):
  301. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  302. else:
  303. self.intermediate_act_fn = config.hidden_act
  304. def forward(self, hidden_states):
  305. hidden_states = self.dense(hidden_states)
  306. hidden_states = self.intermediate_act_fn(hidden_states)
  307. return hidden_states
  308. class SbertOutput(nn.Module):
  309. def __init__(self, config):
  310. super().__init__()
  311. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  312. self.LayerNorm = nn.LayerNorm(
  313. config.hidden_size, eps=config.layer_norm_eps)
  314. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  315. def forward(self, hidden_states, input_tensor):
  316. hidden_states = self.dense(hidden_states)
  317. hidden_states = self.dropout(hidden_states)
  318. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  319. return hidden_states
  320. class SbertLayer(nn.Module):
  321. def __init__(self, config):
  322. super().__init__()
  323. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  324. self.seq_len_dim = 1
  325. self.attention = SbertAttention(config)
  326. self.is_decoder = config.is_decoder
  327. self.add_cross_attention = config.add_cross_attention
  328. if self.add_cross_attention:
  329. if not self.is_decoder:
  330. raise ValueError(
  331. f'{self} should be used as a decoder model if cross attention is added'
  332. )
  333. self.crossattention = SbertAttention(config)
  334. self.intermediate = SbertIntermediate(config)
  335. self.output = SbertOutput(config)
  336. def forward(
  337. self,
  338. hidden_states,
  339. attention_mask=None,
  340. head_mask=None,
  341. encoder_hidden_states=None,
  342. encoder_attention_mask=None,
  343. past_key_value=None,
  344. output_attentions=False,
  345. ):
  346. # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
  347. self_attn_past_key_value = past_key_value[:
  348. 2] if past_key_value is not None else None
  349. self_attention_outputs = self.attention(
  350. hidden_states,
  351. attention_mask,
  352. head_mask,
  353. output_attentions=output_attentions,
  354. past_key_value=self_attn_past_key_value,
  355. )
  356. attention_output = self_attention_outputs[0]
  357. # if decoder, the last output is tuple of self-attn cache
  358. if self.is_decoder:
  359. outputs = self_attention_outputs[1:-1]
  360. present_key_value = self_attention_outputs[-1]
  361. else:
  362. outputs = self_attention_outputs[
  363. 1:] # add self attentions if we output attention weights
  364. cross_attn_present_key_value = None
  365. if self.is_decoder and encoder_hidden_states is not None:
  366. if not hasattr(self, 'crossattention'):
  367. raise ValueError(
  368. f'If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention '
  369. f'layers by setting `config.add_cross_attention=True`')
  370. # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
  371. cross_attn_past_key_value = past_key_value[
  372. -2:] if past_key_value is not None else None
  373. cross_attention_outputs = self.crossattention(
  374. attention_output,
  375. attention_mask,
  376. head_mask,
  377. encoder_hidden_states,
  378. encoder_attention_mask,
  379. cross_attn_past_key_value,
  380. output_attentions,
  381. )
  382. attention_output = cross_attention_outputs[0]
  383. outputs = outputs + cross_attention_outputs[
  384. 1:-1] # add cross attentions if we output attention weights
  385. # add cross-attn cache to positions 3,4 of present_key_value tuple
  386. cross_attn_present_key_value = cross_attention_outputs[-1]
  387. present_key_value = present_key_value + cross_attn_present_key_value
  388. layer_output = apply_chunking_to_forward(self.feed_forward_chunk,
  389. self.chunk_size_feed_forward,
  390. self.seq_len_dim,
  391. attention_output)
  392. outputs = (layer_output, ) + outputs
  393. # if decoder, return the attn key/values as the last output
  394. if self.is_decoder:
  395. outputs = outputs + (present_key_value, )
  396. return outputs
  397. def feed_forward_chunk(self, attention_output):
  398. intermediate_output = self.intermediate(attention_output)
  399. layer_output = self.output(intermediate_output, attention_output)
  400. return layer_output
  401. class SbertEncoder(nn.Module):
  402. def __init__(self, config):
  403. super().__init__()
  404. self.config = config
  405. self.layer = nn.ModuleList(
  406. [SbertLayer(config) for _ in range(config.num_hidden_layers)])
  407. self.gradient_checkpointing = False
  408. def forward(
  409. self,
  410. hidden_states,
  411. attention_mask=None,
  412. head_mask=None,
  413. encoder_hidden_states=None,
  414. encoder_attention_mask=None,
  415. past_key_values=None,
  416. use_cache=None,
  417. output_attentions=False,
  418. output_hidden_states=False,
  419. return_dict=True,
  420. ):
  421. all_hidden_states = () if output_hidden_states else None
  422. all_self_attentions = () if output_attentions else None
  423. all_cross_attentions = (
  424. ) if output_attentions and self.config.add_cross_attention else None
  425. next_decoder_cache = () if use_cache else None
  426. for i, layer_module in enumerate(self.layer):
  427. if output_hidden_states:
  428. all_hidden_states = all_hidden_states + (hidden_states, )
  429. layer_head_mask = head_mask[i] if head_mask is not None else None
  430. past_key_value = past_key_values[
  431. i] if past_key_values is not None else None
  432. if self.gradient_checkpointing and self.training:
  433. if use_cache:
  434. logger.warning(
  435. '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
  436. )
  437. use_cache = False
  438. def create_custom_forward(module):
  439. def custom_forward(*inputs):
  440. return module(*inputs, past_key_value,
  441. output_attentions)
  442. return custom_forward
  443. layer_outputs = torch.utils.checkpoint.checkpoint(
  444. create_custom_forward(layer_module),
  445. hidden_states,
  446. attention_mask,
  447. layer_head_mask,
  448. encoder_hidden_states,
  449. encoder_attention_mask,
  450. )
  451. else:
  452. layer_outputs = layer_module(
  453. hidden_states,
  454. attention_mask,
  455. layer_head_mask,
  456. encoder_hidden_states,
  457. encoder_attention_mask,
  458. past_key_value,
  459. output_attentions,
  460. )
  461. hidden_states = layer_outputs[0]
  462. if use_cache:
  463. next_decoder_cache += (layer_outputs[-1], )
  464. if output_attentions:
  465. all_self_attentions = all_self_attentions + (
  466. layer_outputs[1], )
  467. if self.config.add_cross_attention:
  468. all_cross_attentions = all_cross_attentions + (
  469. layer_outputs[2], )
  470. if output_hidden_states:
  471. all_hidden_states = all_hidden_states + (hidden_states, )
  472. if not return_dict:
  473. return tuple(v for v in [
  474. hidden_states,
  475. next_decoder_cache,
  476. all_hidden_states,
  477. all_self_attentions,
  478. all_cross_attentions,
  479. ] if v is not None)
  480. return AttentionBackboneModelOutput(
  481. last_hidden_state=hidden_states,
  482. past_key_values=next_decoder_cache,
  483. hidden_states=all_hidden_states,
  484. attentions=all_self_attentions,
  485. cross_attentions=all_cross_attentions,
  486. )
  487. class SbertPooler(nn.Module):
  488. def __init__(self, config):
  489. super().__init__()
  490. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  491. self.activation = nn.Tanh()
  492. def forward(self, hidden_states):
  493. # We "pool" the model by simply taking the hidden state corresponding
  494. # to the first token.
  495. first_token_tensor = hidden_states[:, 0]
  496. pooled_output = self.dense(first_token_tensor)
  497. pooled_output = self.activation(pooled_output)
  498. return pooled_output
  499. class SbertPreTrainedModel(TorchModel, PreTrainedModel):
  500. """
  501. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  502. models.
  503. """
  504. config_class = SbertConfig
  505. base_model_prefix = 'bert'
  506. supports_gradient_checkpointing = True
  507. _keys_to_ignore_on_load_missing = [r'position_ids']
  508. def __init__(self, config, **kwargs):
  509. super().__init__(config.name_or_path, **kwargs)
  510. super(Model, self).__init__(config)
  511. def _init_weights(self, module):
  512. """Initialize the weights"""
  513. if isinstance(module, nn.Linear):
  514. # Slightly different from the TF version which uses truncated_normal for initialization
  515. # cf https://github.com/pytorch/pytorch/pull/5617
  516. module.weight.data.normal_(
  517. mean=0.0, std=self.config.initializer_range)
  518. if module.bias is not None:
  519. module.bias.data.zero_()
  520. elif isinstance(module, nn.Embedding):
  521. module.weight.data.normal_(
  522. mean=0.0, std=self.config.initializer_range)
  523. if module.padding_idx is not None:
  524. module.weight.data[module.padding_idx].zero_()
  525. elif isinstance(module, nn.LayerNorm):
  526. module.bias.data.zero_()
  527. module.weight.data.fill_(1.0)
  528. def _set_gradient_checkpointing(self, module, value=False):
  529. if isinstance(module, SbertEncoder):
  530. module.gradient_checkpointing = value
  531. @classmethod
  532. def _instantiate(cls, **kwargs):
  533. """Instantiate the model.
  534. Args:
  535. kwargs: Input args.
  536. model_dir: The model dir used to load the checkpoint and the label information.
  537. num_labels: An optional arg to tell the model how many classes to initialize.
  538. Method will call utils.parse_label_mapping if num_labels is not input.
  539. label2id: An optional label2id mapping, which will cover the label2id in configuration (if exists).
  540. Returns:
  541. The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
  542. """
  543. model_dir = kwargs.pop('model_dir', None)
  544. cfg = kwargs.pop('cfg', None)
  545. model_args = parse_labels_in_order(model_dir, cfg, **kwargs)
  546. if model_dir is None:
  547. config = SbertConfig(**model_args)
  548. model = cls(config)
  549. else:
  550. model = super(Model, cls).from_pretrained(
  551. pretrained_model_name_or_path=model_dir, **model_args)
  552. return model
@dataclass
class AttentionBackboneModelOutputWithEmbedding(AttentionBackboneModelOutput):
    """``AttentionBackboneModelOutput`` extended with three optional fields.

    Every added field defaults to ``None``.
    """
    # Presumably the output of the embedding layer; verify against the code
    # that fills this field.
    embedding_output: torch.FloatTensor = None
    # Annotated as either a tuple or a float tensor.
    logits: Optional[Union[tuple, torch.FloatTensor]] = None
    # Free-form dictionary; its expected keys are not visible here.
    kwargs: dict = None
  558. @MODELS.register_module(Tasks.backbone, module_name=Models.structbert)
  559. class SbertModel(SbertPreTrainedModel):
  560. """The StructBERT Model transformer outputting raw hidden-states without any specific head on top.
  561. This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
  562. methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
  563. pruning heads etc.)
  564. This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
  565. subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
  566. general usage and behavior.
  567. Parameters:
  568. config (:class:`~modelscope.models.nlp.structbert.SbertConfig`): Model configuration class with
  569. all the parameters of the model.
  570. Initializing with a config file does not load the weights associated with the model, only the
  571. configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
  572. weights.
  573. The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
  574. cross-attention is added between the self-attention layers, following the architecture described in `Attention is
  575. all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
  576. Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
  577. To behave as an decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration
  578. set to :obj:`True`. To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder`
  579. argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
  580. input to the forward pass.
  581. """
  582. def __init__(self, config: SbertConfig, add_pooling_layer=True, **kwargs):
  583. super().__init__(config)
  584. self.config = config
  585. self.embeddings = SbertEmbeddings(config)
  586. self.encoder = SbertEncoder(config)
  587. self.pooler = SbertPooler(config) if add_pooling_layer else None
  588. self.init_weights()
    def get_input_embeddings(self):
        """Return the word-embedding module stored on ``self.embeddings``."""
        return self.embeddings.word_embeddings
    def set_input_embeddings(self, value):
        """Replace the word-embedding module stored on ``self.embeddings`` with ``value``."""
        self.embeddings.word_embeddings = value
  593. def _prune_heads(self, heads_to_prune):
  594. """
  595. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  596. class PreTrainedModel
  597. """
  598. for layer, heads in heads_to_prune.items():
  599. self.encoder.layer[layer].attention.prune_heads(heads)
  def forward(self,
              input_ids=None,
              attention_mask=None,
              token_type_ids=None,
              position_ids=None,
              head_mask=None,
              inputs_embeds=None,
              encoder_hidden_states=None,
              encoder_attention_mask=None,
              past_key_values=None,
              use_cache=None,
              output_attentions=None,
              output_hidden_states=None,
              return_dict=None,
              **kwargs):
      r"""Run embeddings, encoder and (optional) pooler over one batch.

      Exactly one of :obj:`input_ids` and :obj:`inputs_embeds` must be given.
      Extra keyword arguments collected in ``**kwargs`` are accepted and not
      used by this method.

      Args:
          input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
              Vocabulary indices of the input tokens. They can be produced
              with :class:`~modelscope.models.nlp.structbert.SbertTokenizer`;
              see :meth:`transformers.PreTrainedTokenizer.encode` and
              :meth:`transformers.PreTrainedTokenizer.__call__`.
          attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
              Padding mask with values in ``[0, 1]``:

              - 1 for tokens that are **not masked**,
              - 0 for tokens that are **masked**.

              When omitted, an all-ones mask covering the current tokens plus
              any cached past length is used.
          token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
              Segment indices with values in ``[0, 1]``:

              - 0 corresponds to a `sentence A` token,
              - 1 corresponds to a `sentence B` token.

              When omitted, the ``token_type_ids`` attribute of the
              embeddings module is used if it exists, otherwise zeros.
          position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
              Position index of every token, in the range
              ``[0, config.max_position_embeddings - 1]``.
          head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
              Mask that nullifies selected self-attention heads, with values
              in ``[0, 1]``:

              - 1 indicates the head is **not masked**,
              - 0 indicates the head is **masked**.
          inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
              Pre-computed embeddings passed instead of :obj:`input_ids`, for
              callers that want control over how indices are turned into
              vectors rather than using the internal embedding lookup.
          encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
              Last-layer hidden states of an encoder; used for
              cross-attention when the model is configured as a decoder.
          encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
              Padding mask for the encoder input, used in cross-attention when
              the model is configured as a decoder. Values in ``[0, 1]``:

              - 1 for tokens that are **not masked**,
              - 0 for tokens that are **masked**.
          past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers`, each tuple holding 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
              Cached key/value states of the attention blocks, usable to
              speed up decoding. When given, the caller may pass only the
              last :obj:`decoder_input_ids` (shape :obj:`(batch_size, 1)`)
              instead of the full :obj:`(batch_size, sequence_length)` input.
          use_cache (:obj:`bool`, `optional`):
              If :obj:`True`, key/value states are returned for reuse (see
              :obj:`past_key_values`). Forced to :obj:`False` unless the
              model is configured as a decoder.
          output_attentions (:obj:`bool`, `optional`):
              Whether to return the attention tensors of all attention
              layers. Defaults to ``config.output_attentions``.
          output_hidden_states (:obj:`bool`, `optional`):
              Whether to return the hidden states of all layers. Defaults to
              ``config.output_hidden_states``.
          return_dict (:obj:`bool`, `optional`):
              Whether to return a :class:`~transformers.ModelOutput` instead
              of a plain tuple. Defaults to ``config.use_return_dict``.

      Returns:
          A `modelscope.outputs.AttentionBackboneModelOutputWithEmbedding`
          when ``return_dict`` is true; otherwise a tuple made of the
          sequence output, the pooled output, the remaining encoder outputs
          and finally the original input embeddings.

      Raises:
          ValueError: If both or neither of :obj:`input_ids` and
              :obj:`inputs_embeds` are provided.

      Examples:
          >>> from modelscope.models import Model
          >>> from modelscope.preprocessors import Preprocessor
          >>> model = Model.from_pretrained('damo/nlp_structbert_backbone_base_std', task='backbone')
          >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_structbert_backbone_base_std')
          >>> # The sample sentence is Chinese for "this is a test".
          >>> print(model(**preprocessor('这是个测试')))
      """
      cfg = self.config

      # Fall back to the configuration for every flag the caller left unset.
      if output_attentions is None:
          output_attentions = cfg.output_attentions
      if output_hidden_states is None:
          output_hidden_states = cfg.output_hidden_states
      if return_dict is None:
          return_dict = cfg.use_return_dict

      # Caching is only honoured when the model is configured as a decoder.
      if not cfg.is_decoder:
          use_cache = False
      elif use_cache is None:
          use_cache = cfg.use_cache

      # Exactly one input source is allowed.
      has_ids = input_ids is not None
      has_embeds = inputs_embeds is not None
      if has_ids and has_embeds:
          raise ValueError(
              'You cannot specify both input_ids and inputs_embeds at the same time'
          )
      if not has_ids and not has_embeds:
          raise ValueError(
              'You have to specify either input_ids or inputs_embeds')

      if has_ids:
          input_shape = input_ids.size()
          device = input_ids.device
      else:
          # Drop the trailing hidden dimension of the embeddings.
          input_shape = inputs_embeds.size()[:-1]
          device = inputs_embeds.device
      batch_size, seq_length = input_shape

      # Number of positions already covered by the key/value cache.
      past_len = 0
      if past_key_values is not None:
          past_len = past_key_values[0][0].shape[2]

      if attention_mask is None:
          attention_mask = torch.ones(
              (batch_size, seq_length + past_len), device=device)

      if token_type_ids is None:
          if hasattr(self.embeddings, 'token_type_ids'):
              # Reuse the ids stored on the embeddings module, cut to the
              # current length and expanded over the batch.
              stored_ids = self.embeddings.token_type_ids[:, :seq_length]
              token_type_ids = stored_ids.expand(batch_size, seq_length)
          else:
              token_type_ids = torch.zeros(
                  input_shape, dtype=torch.long, device=device)

      # A caller may supply a [batch_size, from_seq_length, to_seq_length]
      # self-attention mask; either way it is made broadcastable to all heads.
      extended_attention_mask = self.get_extended_attention_mask(
          attention_mask, input_shape, device)

      # A 2D or 3D cross-attention mask must be made broadcastable to
      # [batch_size, num_heads, seq_length, seq_length].
      encoder_extended_attention_mask = None
      if cfg.is_decoder and encoder_hidden_states is not None:
          enc_batch, enc_len, _ = encoder_hidden_states.size()
          if encoder_attention_mask is None:
              encoder_attention_mask = torch.ones(
                  (enc_batch, enc_len), device=device)
          encoder_extended_attention_mask = self.invert_attention_mask(
              encoder_attention_mask)

      # Prepare the head mask if needed (1.0 means the head is kept).
      # attention_probs has shape bsz x n_heads x N x N; the input head_mask
      # has shape [num_heads] or [num_hidden_layers x num_heads] and is
      # converted to
      # [num_hidden_layers x batch x num_heads x seq_length x seq_length].
      head_mask = self.get_head_mask(head_mask, cfg.num_hidden_layers)

      embedding_output, original_embeds = self.embeddings(
          input_ids=input_ids,
          position_ids=position_ids,
          token_type_ids=token_type_ids,
          inputs_embeds=inputs_embeds,
          past_key_values_length=past_len,
          return_inputs_embeds=True,
      )

      encoder_outputs = self.encoder(
          embedding_output,
          attention_mask=extended_attention_mask,
          head_mask=head_mask,
          encoder_hidden_states=encoder_hidden_states,
          encoder_attention_mask=encoder_extended_attention_mask,
          past_key_values=past_key_values,
          use_cache=use_cache,
          output_attentions=output_attentions,
          output_hidden_states=output_hidden_states,
          return_dict=return_dict,
      )

      sequence_output = encoder_outputs[0]
      if self.pooler is not None:
          pooled_output = self.pooler(sequence_output)
      else:
          pooled_output = None

      if return_dict:
          return AttentionBackboneModelOutputWithEmbedding(
              last_hidden_state=sequence_output,
              pooler_output=pooled_output,
              past_key_values=encoder_outputs.past_key_values,
              hidden_states=encoder_outputs.hidden_states,
              attentions=encoder_outputs.attentions,
              cross_attentions=encoder_outputs.cross_attentions,
              embedding_output=original_embeds)

      # Tuple form: main outputs, remaining encoder outputs, then embeddings.
      leading = (sequence_output, pooled_output)
      trailing = (original_embeds, )
      return leading + encoder_outputs[1:] + trailing