base.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from abc import ABC, abstractmethod
  4. from typing import Dict, Union
  5. from modelscope.models import Model
  6. from modelscope.utils.config import Config, ConfigDict
  7. from modelscope.utils.constant import ModelFile
  8. from modelscope.utils.logger import get_logger
  9. from .builder import build_exporter
# Module-level logger obtained from modelscope's logging utilities.
logger = get_logger()
  11. class Exporter(ABC):
  12. """Exporter base class to output model to onnx, torch_script, graphdef, etc.
  13. """
  14. def __init__(self, model=None):
  15. self.model = model
  16. @classmethod
  17. def from_model(cls, model: Union[Model, str], **kwargs):
  18. """Build the Exporter instance.
  19. Args:
  20. model: A Model instance or a model id or a model dir, the configuration.json file besides to which
  21. will be used to create the exporter instance.
  22. kwargs: Extra kwargs used to create the Exporter instance.
  23. Returns:
  24. The Exporter instance
  25. """
  26. if isinstance(model, str):
  27. model = Model.from_pretrained(model)
  28. assert hasattr(model, 'model_dir')
  29. model_dir = model.model_dir
  30. cfg = Config.from_file(
  31. os.path.join(model_dir, ModelFile.CONFIGURATION))
  32. task_name = cfg.task
  33. if hasattr(model, 'group_key'):
  34. task_name = model.group_key
  35. model_cfg = cfg.model
  36. if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
  37. model_cfg.type = model_cfg.model_type
  38. export_cfg = ConfigDict({'type': model_cfg.type})
  39. if hasattr(cfg, 'export'):
  40. export_cfg.update(cfg.export)
  41. export_cfg['model'] = model
  42. try:
  43. exporter = build_exporter(export_cfg, task_name, kwargs)
  44. except KeyError as e:
  45. raise KeyError(
  46. f'The exporting of model \'{model_cfg.type}\' with task: \'{task_name}\' '
  47. f'is not supported currently.') from e
  48. return exporter
  49. @abstractmethod
  50. def export_onnx(self, output_dir: str, opset=13, **kwargs):
  51. """Export the model as onnx format files.
  52. In some cases, several files may be generated,
  53. So please return a dict which contains the generated name with the file path.
  54. Args:
  55. opset: The version of the ONNX operator set to use.
  56. output_dir: The output dir.
  57. kwargs: In this default implementation,
  58. kwargs will be carried to generate_dummy_inputs as extra arguments (like input shape).
  59. Returns:
  60. A dict contains the model name with the model file path.
  61. """
  62. pass