backbone.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. # Copyright 2019 Facebook AI Research and the HuggingFace Inc. team.
  2. # Copyright (c) 2018, NVIDIA CORPORATION.
  3. # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
  4. # All rights reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
"""PyTorch Veco model, mainly copied from :mod:`~transformers.modeling_xlm_roberta`."""
  18. from transformers import RobertaModel
  19. from modelscope.metainfo import Models
  20. from modelscope.models import Model, TorchModel
  21. from modelscope.models.builder import MODELS
  22. from modelscope.outputs import AttentionBackboneModelOutput
  23. from modelscope.utils import logger as logging
  24. from modelscope.utils.constant import Tasks
  25. from .configuration import VecoConfig
  26. logger = logging.get_logger()
  27. VECO_PRETRAINED_MODEL_ARCHIVE_LIST = []
  28. @MODELS.register_module(Tasks.backbone, module_name=Models.veco)
  29. class VecoModel(TorchModel, RobertaModel):
  30. """The bare Veco Model transformer outputting raw hidden-states without any specific head on top.
  31. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic
  32. methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
  33. pruning heads etc.)
  34. This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)
  35. subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
  36. general usage and behavior.
  37. Parameters:
  38. config ([`VecoConfig`]): Model configuration class with all the parameters of the
  39. model. Initializing with a config file does not load the weights associated with the model, only the
  40. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model
  41. weights.
  42. This class overrides [`RobertaModel`]. Please check the superclass for the appropriate
  43. documentation alongside usage examples.
  44. """
  45. config_class = VecoConfig
  46. def __init__(self, config, **kwargs):
  47. super().__init__(config.name_or_path, **kwargs)
  48. super(Model, self).__init__(config)
  49. def forward(self, *args, **kwargs):
  50. """
  51. Returns:
  52. Returns `modelscope.outputs.AttentionBackboneModelOutputWithEmbedding`
  53. Examples:
  54. >>> from modelscope.models import Model
  55. >>> from modelscope.preprocessors import Preprocessor
  56. >>> model = Model.from_pretrained('damo/nlp_veco_fill-mask-large', task='backbone')
  57. >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_veco_fill-mask-large')
  58. >>> print(model(**preprocessor('这是个测试')))
  59. """
  60. kwargs['return_dict'] = True
  61. outputs = super(Model, self).forward(*args, **kwargs)
  62. return AttentionBackboneModelOutput(
  63. last_hidden_state=outputs.last_hidden_state,
  64. pooler_output=outputs.pooler_output,
  65. past_key_values=outputs.past_key_values,
  66. hidden_states=outputs.hidden_states,
  67. attentions=outputs.attentions,
  68. cross_attentions=outputs.cross_attentions,
  69. )
  70. @classmethod
  71. def _instantiate(cls, **kwargs):
  72. model_dir = kwargs.pop('model_dir', None)
  73. if model_dir is None:
  74. ponet_config = VecoConfig(**kwargs)
  75. model = cls(ponet_config)
  76. else:
  77. model = super(
  78. Model,
  79. cls).from_pretrained(pretrained_model_name_or_path=model_dir)
  80. return model