| 123456789101112131415161718192021222324252627282930313233343536 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os
- from typing import Any, Dict, Optional, Union
- import torch
- from modelscope.metainfo import Models
- from modelscope.models.base.base_torch_model import TorchModel
- from modelscope.models.builder import MODELS
- from modelscope.utils.constant import ModelFile, Tasks
- from modelscope.utils.logger import get_logger
LOGGER = get_logger()


@MODELS.register_module(
    Tasks.image_inpainting, module_name=Models.image_inpainting)
class FFTInpainting(TorchModel):
    """ModelScope model wrapping ``DefaultInpaintingTrainingModule``.

    Builds the inpainting network from ``model_dir`` and, unless disabled,
    initialises it from the checkpoint file stored in that directory.
    """

    def __init__(self, model_dir: str, **kwargs):
        """Build the network and optionally load pretrained weights.

        Args:
            model_dir: Directory containing the model files. It is passed to
                ``DefaultInpaintingTrainingModule`` and used to locate the
                checkpoint (``ModelFile.TORCH_MODEL_FILE``).
            **kwargs: Forwarded to ``TorchModel.__init__``. Two keys are also
                read here:
                ``pretrained`` (default ``True``): load the checkpoint from
                    ``model_dir`` when truthy.
                ``predict_only`` (default ``False``): forwarded to
                    ``DefaultInpaintingTrainingModule``.
        """
        super().__init__(model_dir, **kwargs)
        # Local (deferred) import kept as in the original. Presumably this
        # avoids import cost/cycles at registration time -- TODO confirm.
        from .default import DefaultInpaintingTrainingModule

        pretrained = kwargs.get('pretrained', True)
        predict_only = kwargs.get('predict_only', False)

        net = DefaultInpaintingTrainingModule(
            model_dir=model_dir, predict_only=predict_only)

        if pretrained:
            path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
            LOGGER.info(f'loading pretrained model from {path}')
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code. Only load checkpoints from a trusted model_dir.
            state = torch.load(path, map_location='cpu')
            # strict=False tolerates key mismatches. Report them rather than
            # silently leaving those parameters at their initial values.
            result = net.load_state_dict(state, strict=False)
            missing = getattr(result, 'missing_keys', None) or []
            unexpected = getattr(result, 'unexpected_keys', None) or []
            if missing:
                LOGGER.warning(
                    f'{len(missing)} key(s) missing from checkpoint '
                    f'{path}: {missing}')
            if unexpected:
                LOGGER.warning(
                    f'{len(unexpected)} unexpected key(s) in checkpoint '
                    f'{path}: {unexpected}')

        self.model = net

    def forward(self, inputs):
        """Run the wrapped network on ``inputs`` and return its output as-is.

        The expected structure of ``inputs`` is defined by
        ``DefaultInpaintingTrainingModule`` (not visible here).
        """
        return self.model(inputs)
|