fill_mask.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. # Copyright 2021-2022 The Alibaba DAMO Team Authors.
  2. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  3. # All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import torch.utils.checkpoint
  17. from torch import nn
  18. from torch.nn import CrossEntropyLoss
  19. from transformers.activations import ACT2FN
  20. from modelscope.metainfo import Models
  21. from modelscope.models.builder import MODELS
  22. from modelscope.outputs import AttentionFillMaskModelOutput
  23. from modelscope.utils.constant import Tasks
  24. from modelscope.utils.logger import get_logger
  25. from .backbone import PoNetModel, PoNetPreTrainedModel
# Module-level logger; PoNetForMaskedLM.__init__ uses it to warn when the
# model is constructed with config.is_decoder=True.
logger = get_logger()
  27. class PoNetPredictionHeadTransform(nn.Module):
  28. def __init__(self, config):
  29. super().__init__()
  30. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  31. if isinstance(config.hidden_act, str):
  32. self.transform_act_fn = ACT2FN[config.hidden_act]
  33. else:
  34. self.transform_act_fn = config.hidden_act
  35. self.LayerNorm = nn.LayerNorm(
  36. config.hidden_size, eps=config.layer_norm_eps)
  37. def forward(self, hidden_states):
  38. hidden_states = self.dense(hidden_states)
  39. hidden_states = self.transform_act_fn(hidden_states)
  40. hidden_states = self.LayerNorm(hidden_states)
  41. return hidden_states
  42. class PoNetLMPredictionHead(nn.Module):
  43. def __init__(self, config):
  44. super().__init__()
  45. self.transform = PoNetPredictionHeadTransform(config)
  46. # The output weights are the same as the input embeddings, but there is
  47. # an output-only bias for each token.
  48. self.decoder = nn.Linear(
  49. config.hidden_size, config.vocab_size, bias=False)
  50. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  51. # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
  52. self.decoder.bias = self.bias
  53. def forward(self, hidden_states):
  54. hidden_states = self.transform(hidden_states)
  55. hidden_states = self.decoder(hidden_states)
  56. return hidden_states
  57. class PoNetOnlyMLMHead(nn.Module):
  58. def __init__(self, config):
  59. super().__init__()
  60. self.predictions = PoNetLMPredictionHead(config)
  61. def forward(self, sequence_output):
  62. prediction_scores = self.predictions(sequence_output)
  63. return prediction_scores
  64. @MODELS.register_module(Tasks.fill_mask, module_name=Models.ponet)
  65. class PoNetForMaskedLM(PoNetPreTrainedModel):
  66. r"""PoNet Model with a `language modeling` head on top.
  67. This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
  68. methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
  69. pruning heads etc.)
  70. This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
  71. subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
  72. general usage and behavior.
  73. Preprocessor:
  74. This is the fill_mask model of PoNet, the preprocessor of this model
  75. is `modelscope.preprocessors.FillMaskPoNetPreprocessor`.
  76. Parameters:
  77. config (:class:`~modelscope.models.nlp.ponet.PoNetConfig`):
  78. Model configuration class with all the parameters of the model.
  79. Initializing with a config file does not load the weights associated with the model, only the
  80. configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
  81. weights.
  82. """
  83. _keys_to_ignore_on_load_unexpected = [r'pooler']
  84. _keys_to_ignore_on_load_missing = [
  85. r'position_ids', r'predictions.decoder.bias'
  86. ]
  87. def __init__(self, config, **kwargs):
  88. super().__init__(config)
  89. if config.is_decoder:
  90. logger.warning(
  91. 'If you want to use `PoNetForMaskedLM` make sure `config.is_decoder=False` for '
  92. 'bi-directional self-attention.')
  93. self.ponet = PoNetModel(config, add_pooling_layer=False)
  94. self.cls = PoNetOnlyMLMHead(config)
  95. self.init_weights()
  96. def get_output_embeddings(self):
  97. return self.cls.predictions.decoder
  98. def set_output_embeddings(self, new_embeddings):
  99. self.cls.predictions.decoder = new_embeddings
  100. def forward(
  101. self,
  102. input_ids=None,
  103. attention_mask=None,
  104. token_type_ids=None,
  105. position_ids=None,
  106. segment_ids=None,
  107. head_mask=None,
  108. inputs_embeds=None,
  109. encoder_hidden_states=None,
  110. encoder_attention_mask=None,
  111. labels=None,
  112. output_attentions=None,
  113. output_hidden_states=None,
  114. return_dict=None,
  115. ):
  116. r"""
  117. Args:
  118. input_ids (:obj:`torch.LongTensor` of shape :obj:`('batch_size, sequence_length')`):
  119. Indices of input sequence tokens in the vocabulary.
  120. Indices can be obtained using :class:`~modelscope.models.nlp.ponet.PoNetTokenizer`. See
  121. :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
  122. for details.
  123. attention_mask (:obj:`torch.FloatTensor` of shape :obj:`('batch_size, sequence_length')`, `optional`):
  124. Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
  125. - 1 for tokens that are **not masked**,
  126. - 0 for tokens that are **masked**.
  127. token_type_ids (:obj:`torch.LongTensor` of shape :obj:`('batch_size, sequence_length')`, `optional`):
  128. Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
  129. 1]``:
  130. - 0 corresponds to a `sentence A` token,
  131. - 1 corresponds to a `sentence B` token.
  132. position_ids (:obj:`torch.LongTensor` of shape :obj:`('batch_size, sequence_length')`, `optional`):
  133. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
  134. ``[0, config.max_position_embeddings - 1]``.
  135. head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`,
  136. `optional`):
  137. Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
  138. - 1 indicates the head is **not masked**,
  139. - 0 indicates the head is **masked**.
  140. inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`('batch_size, sequence_length', hidden_size)`,
  141. `optional`):
  142. Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
  143. representation. This is useful if you want more control over how to convert :obj:`input_ids`
  144. indices into associated vectors than the model's internal embedding lookup matrix.
  145. output_attentions (:obj:`bool`, `optional`):
  146. Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
  147. returned tensors for more detail.
  148. output_hidden_states (:obj:`bool`, `optional`):
  149. Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
  150. for more detail.
  151. return_dict (:obj:`bool`, `optional`):
  152. Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
  153. labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
  154. Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
  155. config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
  156. (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
  157. Returns:
  158. Returns `modelscope.outputs.AttentionFillMaskModelOutput`
  159. Examples:
  160. >>> from modelscope.models import Model
  161. >>> from modelscope.preprocessors import Preprocessor
  162. >>> model = Model.from_pretrained('damo/nlp_ponet_fill-mask_chinese-base')
  163. >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_ponet_fill-mask_chinese-base')
  164. >>> # Call the model, return some tensors
  165. >>> print(model(**preprocessor('你师父差得动你,你师父可[MASK]不动我。')))
  166. >>> # Call the pipeline
  167. >>> from modelscope.pipelines import pipeline
  168. >>> pipeline_ins = pipeline('fill-mask', model=model, preprocessor=preprocessor)
  169. >>> print(pipeline_ins('你师父差得动你,你师父可[MASK]不动我。'))
  170. """
  171. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  172. outputs = self.ponet(
  173. input_ids,
  174. attention_mask=attention_mask,
  175. token_type_ids=token_type_ids,
  176. segment_ids=segment_ids,
  177. position_ids=position_ids,
  178. head_mask=head_mask,
  179. inputs_embeds=inputs_embeds,
  180. encoder_hidden_states=encoder_hidden_states,
  181. encoder_attention_mask=encoder_attention_mask,
  182. output_attentions=output_attentions,
  183. output_hidden_states=output_hidden_states,
  184. return_dict=return_dict,
  185. )
  186. sequence_output = outputs[0]
  187. prediction_scores = self.cls(sequence_output)
  188. masked_lm_loss = None
  189. if labels is not None:
  190. loss_fct = CrossEntropyLoss() # -100 index = padding token
  191. masked_lm_loss = loss_fct(
  192. prediction_scores.view(-1, self.config.vocab_size),
  193. labels.view(-1))
  194. if not return_dict:
  195. output = (prediction_scores, ) + outputs[2:]
  196. return ((masked_lm_loss, )
  197. + output) if masked_lm_loss is not None else output
  198. return AttentionFillMaskModelOutput(
  199. loss=masked_lm_loss,
  200. logits=prediction_scores,
  201. hidden_states=outputs.hidden_states,
  202. attentions=outputs.attentions,
  203. input_ids=input_ids,
  204. )