fill_mask.py 512 B

12345678910111213141516
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from modelscope.metainfo import Heads, Models
  3. from modelscope.models.builder import MODELS
  4. from modelscope.models.nlp.task_models.fill_mask import ModelForFillMask
  5. from modelscope.utils import logger as logging
  6. from modelscope.utils.constant import Tasks
# Module-level logger obtained from modelscope's logging helper
# (`modelscope.utils.logger`, imported in this file as `logging`).
logger = logging.get_logger()
  8. @MODELS.register_module(Tasks.fill_mask, module_name=Models.bert)
  9. class BertForMaskedLM(ModelForFillMask):
  10. base_model_type = Models.bert
  11. head_type = Heads.bert_mlm