model.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from typing import Any, Dict, Optional, Union
  4. import torch
  5. from omegaconf import OmegaConf
  6. from paint_ldm.util import instantiate_from_config
  7. from modelscope.metainfo import Models
  8. from modelscope.models.base.base_torch_model import TorchModel
  9. from modelscope.models.builder import MODELS
  10. from modelscope.utils.constant import ModelFile, Tasks
  11. from modelscope.utils.logger import get_logger
  12. LOGGER = get_logger()
  13. def load_model_from_config(config, ckpt, verbose=False):
  14. LOGGER.info(f'Loading model from {ckpt}')
  15. pl_sd = torch.load(ckpt, map_location='cpu')
  16. if 'global_step' in pl_sd:
  17. LOGGER.info(f"Global Step: {pl_sd['global_step']}")
  18. sd = pl_sd['state_dict']
  19. model = instantiate_from_config(config.model)
  20. m, u = model.load_state_dict(sd, strict=False)
  21. if len(m) > 0 and verbose:
  22. LOGGER.info('missing keys:')
  23. LOGGER.info(m)
  24. if len(u) > 0 and verbose:
  25. LOGGER.info('unexpected keys:')
  26. LOGGER.info(u)
  27. return model
  28. @MODELS.register_module(
  29. Tasks.image_paintbyexample, module_name=Models.image_paintbyexample)
  30. class StablediffusionPaintbyexample(TorchModel):
  31. def __init__(self, model_dir: str, **kwargs):
  32. super().__init__(model_dir, **kwargs)
  33. config = OmegaConf.load(os.path.join(model_dir, 'v1.yaml'))
  34. model = load_model_from_config(
  35. config, os.path.join(model_dir, 'pytorch_model.pt'))
  36. self.model = model
  37. def forward(self, inputs):
  38. return self.model(inputs)