fill_mask.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
  2. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. # All rights reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import torch
  18. import torch.nn as nn
  19. import torch.utils.checkpoint
  20. from torch.nn import CrossEntropyLoss
  21. from transformers.activations import ACT2FN
  22. from modelscope.metainfo import Models
  23. from modelscope.models.builder import MODELS
  24. from modelscope.outputs import AttentionFillMaskModelOutput
  25. from modelscope.utils import logger as logging
  26. from modelscope.utils.constant import Tasks
  27. from .backbone import SbertModel, SbertPreTrainedModel
  28. from .configuration import SbertConfig
# Module-level logger obtained from modelscope's logging utilities; used by
# SbertForMaskedLM.__init__ to warn about decoder-mode configurations.
logger = logging.get_logger()
  30. class SbertPredictionHeadTransform(nn.Module):
  31. def __init__(self, config):
  32. super().__init__()
  33. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  34. if isinstance(config.hidden_act, str):
  35. self.transform_act_fn = ACT2FN[config.hidden_act]
  36. else:
  37. self.transform_act_fn = config.hidden_act
  38. self.LayerNorm = nn.LayerNorm(
  39. config.hidden_size, eps=config.layer_norm_eps)
  40. def forward(self, hidden_states):
  41. hidden_states = self.dense(hidden_states)
  42. hidden_states = self.transform_act_fn(hidden_states)
  43. hidden_states = self.LayerNorm(hidden_states)
  44. return hidden_states
  45. class SbertLMPredictionHead(nn.Module):
  46. def __init__(self, config):
  47. super().__init__()
  48. self.transform = SbertPredictionHeadTransform(config)
  49. # The output weights are the same as the input embeddings, but there is
  50. # an output-only bias for each token.
  51. self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
  52. def forward(self, hidden_states):
  53. hidden_states = self.transform(hidden_states)
  54. hidden_states = self.decoder(hidden_states)
  55. return hidden_states
  56. class SbertOnlyMLMHead(nn.Module):
  57. def __init__(self, config):
  58. super().__init__()
  59. self.predictions = SbertLMPredictionHead(config)
  60. def forward(self, sequence_output):
  61. prediction_scores = self.predictions(sequence_output)
  62. return prediction_scores
  63. class SbertPreTrainingHeads(nn.Module):
  64. def __init__(self, config):
  65. super().__init__()
  66. self.predictions = SbertLMPredictionHead(config)
  67. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  68. def forward(self, sequence_output, pooled_output):
  69. prediction_scores = self.predictions(sequence_output)
  70. seq_relationship_score = self.seq_relationship(pooled_output)
  71. return prediction_scores, seq_relationship_score
  72. @MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert)
  73. class SbertForMaskedLM(SbertPreTrainedModel):
  74. r"""StructBERT Model with a `language modeling` head on top.
  75. This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
  76. methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
  77. pruning heads etc.)
  78. This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
  79. subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
  80. general usage and behavior.
  81. Preprocessor:
  82. This is the fill_mask model of StructBERT, the preprocessor of this model
  83. is `modelscope.preprocessors.FillMaskTransformersPreprocessor`.
  84. Parameters:
  85. config (:class:`~modelscope.models.nlp.structbert.SbertConfig`): Model configuration class with
  86. all the parameters of the model.
  87. Initializing with a config file does not load the weights associated with the model, only the
  88. configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
  89. weights.
  90. """
  91. _keys_to_ignore_on_load_unexpected = [r'pooler']
  92. _keys_to_ignore_on_load_missing = [
  93. r'position_ids', r'predictions.decoder.bias'
  94. ]
  95. def __init__(self, config: SbertConfig, **kwargs):
  96. super().__init__(config)
  97. if config.is_decoder:
  98. logger.warning(
  99. 'If you want to use `SbertForMaskedLM` make sure `config.is_decoder=False` for '
  100. 'bi-directional self-attention.')
  101. self.bert = SbertModel(config)
  102. self.cls = SbertOnlyMLMHead(config)
  103. self.init_weights()
  104. def get_output_embeddings(self):
  105. return self.cls.predictions.decoder
  106. def set_output_embeddings(self, new_embeddings):
  107. self.cls.predictions.decoder = new_embeddings
  108. def forward(
  109. self,
  110. input_ids=None,
  111. attention_mask=None,
  112. token_type_ids=None,
  113. position_ids=None,
  114. head_mask=None,
  115. inputs_embeds=None,
  116. encoder_hidden_states=None,
  117. encoder_attention_mask=None,
  118. labels=None,
  119. output_attentions=None,
  120. output_hidden_states=None,
  121. return_dict=None,
  122. ):
  123. r"""
  124. Args:
  125. input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
  126. Indices of input sequence tokens in the vocabulary.
  127. Indices can be obtained using :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See
  128. :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
  129. for details.
  130. `What are input IDs? <../glossary.html#input-ids>`__
  131. attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
  132. Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
  133. - 1 for tokens that are **not masked**,
  134. - 0 for tokens that are **masked**.
  135. `What are attention masks? <../glossary.html#attention-mask>`__
  136. token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
  137. Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
  138. 1]``:
  139. - 0 corresponds to a `sentence A` token,
  140. - 1 corresponds to a `sentence B` token.
  141. `What are token type IDs? <../glossary.html#token-type-ids>`_
  142. position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
  143. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
  144. ``[0, config.max_position_embeddings - 1]``.
  145. head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`,
  146. `optional`):
  147. Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
  148. - 1 indicates the head is **not masked**,
  149. - 0 indicates the head is **masked**.
  150. inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
  151. `optional`):
  152. Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
  153. representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
  154. into associated vectors than the model's internal embedding lookup matrix.
  155. output_attentions (:obj:`bool`, `optional`):
  156. Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
  157. returned tensors for more detail.
  158. output_hidden_states (:obj:`bool`, `optional`):
  159. Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
  160. for more detail.
  161. return_dict (:obj:`bool`, `optional`):
  162. Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple.
  163. labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
  164. Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
  165. config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
  166. (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
  167. Returns:
  168. Returns `modelscope.outputs.AttentionFillMaskModelOutput`
  169. Examples:
  170. >>> from modelscope.models import Model
  171. >>> from modelscope.preprocessors import Preprocessor, FillMaskTransformersPreprocessor
  172. >>> model = Model.from_pretrained('damo/nlp_structbert_fill-mask_chinese-large')
  173. >>> preprocessor = FillMaskTransformersPreprocessor('damo/nlp_structbert_fill-mask_chinese-large')
  174. >>> # Call the model, return some tensors
  175. >>> print(model(**preprocessor('你师父差得动你,你师父可[MASK]不动我。')))
  176. >>> # Call the pipeline
  177. >>> from modelscope.pipelines import pipeline
  178. >>> pipeline_ins = pipeline('fill-mask', model=model, preprocessor=preprocessor)
  179. >>> print(pipeline_ins('你师父差得动你,你师父可[MASK]不动我。'))
  180. """
  181. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  182. outputs = self.bert(
  183. input_ids,
  184. attention_mask=attention_mask,
  185. token_type_ids=token_type_ids,
  186. position_ids=position_ids,
  187. head_mask=head_mask,
  188. inputs_embeds=inputs_embeds,
  189. encoder_hidden_states=encoder_hidden_states,
  190. encoder_attention_mask=encoder_attention_mask,
  191. output_attentions=output_attentions,
  192. output_hidden_states=output_hidden_states,
  193. return_dict=return_dict,
  194. )
  195. sequence_output = outputs[0]
  196. prediction_scores = self.cls(sequence_output)
  197. masked_lm_loss = None
  198. if labels is not None:
  199. loss_fct = CrossEntropyLoss() # -100 index = padding token
  200. masked_lm_loss = loss_fct(
  201. prediction_scores.view(-1, self.config.vocab_size),
  202. labels.view(-1))
  203. if not return_dict:
  204. output = (prediction_scores, ) + outputs[2:-1]
  205. return ((masked_lm_loss, )
  206. + output) if masked_lm_loss is not None else output
  207. return AttentionFillMaskModelOutput(
  208. loss=masked_lm_loss,
  209. logits=prediction_scores,
  210. hidden_states=outputs.hidden_states,
  211. attentions=outputs.attentions,
  212. input_ids=input_ids,
  213. )
  214. def prepare_inputs_for_generation(self,
  215. input_ids,
  216. attention_mask=None,
  217. **model_kwargs):
  218. input_shape = input_ids.shape
  219. effective_batch_size = input_shape[0]
  220. # add a dummy token
  221. assert self.config.pad_token_id is not None, 'The PAD token should be defined for generation'
  222. attention_mask_zero = attention_mask.new_zeros(
  223. (attention_mask.shape[0], 1))
  224. attention_mask = torch.cat([attention_mask, attention_mask_zero],
  225. dim=-1)
  226. dummy_token = torch.full((effective_batch_size, 1),
  227. self.config.pad_token_id,
  228. dtype=torch.long,
  229. device=input_ids.device)
  230. input_ids = torch.cat([input_ids, dummy_token], dim=1)
  231. return {'input_ids': input_ids, 'attention_mask': attention_mask}