| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os
- from typing import Any, Dict, Optional, Union
- import torch
- from omegaconf import OmegaConf
- from paint_ldm.util import instantiate_from_config
- from modelscope.metainfo import Models
- from modelscope.models.base.base_torch_model import TorchModel
- from modelscope.models.builder import MODELS
- from modelscope.utils.constant import ModelFile, Tasks
- from modelscope.utils.logger import get_logger
- LOGGER = get_logger()
- def load_model_from_config(config, ckpt, verbose=False):
- LOGGER.info(f'Loading model from {ckpt}')
- pl_sd = torch.load(ckpt, map_location='cpu')
- if 'global_step' in pl_sd:
- LOGGER.info(f"Global Step: {pl_sd['global_step']}")
- sd = pl_sd['state_dict']
- model = instantiate_from_config(config.model)
- m, u = model.load_state_dict(sd, strict=False)
- if len(m) > 0 and verbose:
- LOGGER.info('missing keys:')
- LOGGER.info(m)
- if len(u) > 0 and verbose:
- LOGGER.info('unexpected keys:')
- LOGGER.info(u)
- return model
- @MODELS.register_module(
- Tasks.image_paintbyexample, module_name=Models.image_paintbyexample)
- class StablediffusionPaintbyexample(TorchModel):
- def __init__(self, model_dir: str, **kwargs):
- super().__init__(model_dir, **kwargs)
- config = OmegaConf.load(os.path.join(model_dir, 'v1.yaml'))
- model = load_model_from_config(
- config, os.path.join(model_dir, 'pytorch_model.pt'))
- self.model = model
- def forward(self, inputs):
- return self.model(inputs)
|