model.py 1.2 KB

123456789101112131415161718192021222324252627282930313233343536
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from typing import Any, Dict, Optional, Union
  4. import torch
  5. from modelscope.metainfo import Models
  6. from modelscope.models.base.base_torch_model import TorchModel
  7. from modelscope.models.builder import MODELS
  8. from modelscope.utils.constant import ModelFile, Tasks
  9. from modelscope.utils.logger import get_logger
  10. LOGGER = get_logger()
  11. @MODELS.register_module(
  12. Tasks.image_inpainting, module_name=Models.image_inpainting)
  13. class FFTInpainting(TorchModel):
  14. def __init__(self, model_dir: str, **kwargs):
  15. super().__init__(model_dir, **kwargs)
  16. from .default import DefaultInpaintingTrainingModule
  17. pretrained = kwargs.get('pretrained', True)
  18. predict_only = kwargs.get('predict_only', False)
  19. net = DefaultInpaintingTrainingModule(
  20. model_dir=model_dir, predict_only=predict_only)
  21. if pretrained:
  22. path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
  23. LOGGER.info(f'loading pretrained model from {path}')
  24. state = torch.load(path, map_location='cpu')
  25. net.load_state_dict(state, strict=False)
  26. self.model = net
  27. def forward(self, inputs):
  28. return self.model(inputs)