__init__.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import importlib
  3. from typing import TYPE_CHECKING
  4. from modelscope.utils.import_utils import (LazyImportModule,
  5. is_transformers_available)
  6. if TYPE_CHECKING:
  7. from .exporters import Exporter, TfModelExporter, TorchModelExporter
  8. from .hub.api import HubApi
  9. from .hub.check_model import check_local_model_is_latest, check_model_is_id
  10. from .hub.push_to_hub import push_to_hub, push_to_hub_async
  11. from .hub.snapshot_download import snapshot_download, dataset_snapshot_download
  12. from .hub.file_download import model_file_download, dataset_file_download
  13. from .metrics import (
  14. AccuracyMetric, AudioNoiseMetric, BleuMetric, ImageColorEnhanceMetric,
  15. ImageColorizationMetric, ImageDenoiseMetric, ImageInpaintingMetric,
  16. ImageInstanceSegmentationCOCOMetric, ImagePortraitEnhancementMetric,
  17. ImageQualityAssessmentDegradationMetric,
  18. ImageQualityAssessmentMosMetric, LossMetric, Metric,
  19. MovieSceneSegmentationMetric, OCRRecognitionMetric, PplMetric,
  20. ReferringVideoObjectSegmentationMetric, SequenceClassificationMetric,
  21. TextGenerationMetric, TextRankingMetric, TokenClassificationMetric,
  22. VideoFrameInterpolationMetric, VideoStabilizationMetric,
  23. VideoSummarizationMetric, VideoSuperResolutionMetric,
  24. task_default_metrics)
  25. from .models import Model, TorchModel
  26. from .msdatasets import MsDataset
  27. from .pipelines import Pipeline, pipeline
  28. from .preprocessors import Preprocessor
  29. from .trainers import (EpochBasedTrainer, Hook, Priority, TrainingArgs,
  30. build_dataset_from_file)
  31. from .utils.constant import Tasks
  32. from .utils.hf_util import patch_hub, patch_context, unpatch_hub
  33. if is_transformers_available():
  34. from .utils.hf_util import (
  35. AutoModel, AutoProcessor, AutoFeatureExtractor, GenerationConfig,
  36. AutoConfig, GPTQConfig, AwqConfig, BitsAndBytesConfig,
  37. AutoModelForCausalLM, AutoModelForSeq2SeqLM,
  38. AutoModelForVision2Seq, AutoModelForSequenceClassification,
  39. AutoModelForTokenClassification, AutoModelForImageClassification,
  40. AutoModelForImageTextToText,
  41. AutoModelForZeroShotImageClassification,
  42. AutoModelForKeypointDetection,
  43. AutoModelForDocumentQuestionAnswering,
  44. AutoModelForSemanticSegmentation,
  45. AutoModelForUniversalSegmentation,
  46. AutoModelForInstanceSegmentation, AutoModelForObjectDetection,
  47. AutoModelForZeroShotObjectDetection,
  48. AutoModelForAudioClassification, AutoModelForSpeechSeq2Seq,
  49. AutoModelForMaskedImageModeling,
  50. AutoModelForVisualQuestionAnswering,
  51. AutoModelForTableQuestionAnswering, AutoModelForImageToImage,
  52. AutoModelForImageSegmentation, AutoModelForQuestionAnswering,
  53. AutoModelForMaskedLM, AutoTokenizer, AutoModelForMaskGeneration,
  54. AutoModelForPreTraining, AutoModelForTextEncoding,
  55. AutoImageProcessor, BatchFeature, Qwen2VLForConditionalGeneration,
  56. T5EncoderModel, Qwen2_5_VLForConditionalGeneration, LlamaModel,
  57. LlamaPreTrainedModel, LlamaForCausalLM, hf_pipeline)
  58. else:
  59. print(
  60. 'transformer is not installed, please install it if you want to use related modules'
  61. )
  62. from .utils.hub import create_model_if_not_exist, read_config
  63. from .utils.logger import get_logger
  64. from .version import __release_datetime__, __version__
  65. else:
  66. _import_structure = {
  67. 'version': ['__release_datetime__', '__version__'],
  68. 'trainers': [
  69. 'EpochBasedTrainer', 'TrainingArgs', 'Hook', 'Priority',
  70. 'build_dataset_from_file'
  71. ],
  72. 'exporters': [
  73. 'Exporter',
  74. 'TfModelExporter',
  75. 'TorchModelExporter',
  76. ],
  77. 'hub.api': ['HubApi'],
  78. 'hub.snapshot_download':
  79. ['snapshot_download', 'dataset_snapshot_download'],
  80. 'hub.file_download': ['model_file_download', 'dataset_file_download'],
  81. 'hub.push_to_hub': ['push_to_hub', 'push_to_hub_async'],
  82. 'hub.check_model':
  83. ['check_model_is_id', 'check_local_model_is_latest'],
  84. 'metrics': [
  85. 'AudioNoiseMetric', 'Metric', 'task_default_metrics',
  86. 'ImageColorEnhanceMetric', 'ImageDenoiseMetric',
  87. 'ImageInstanceSegmentationCOCOMetric',
  88. 'ImagePortraitEnhancementMetric', 'SequenceClassificationMetric',
  89. 'TextGenerationMetric', 'TokenClassificationMetric',
  90. 'VideoSummarizationMetric', 'MovieSceneSegmentationMetric',
  91. 'AccuracyMetric', 'BleuMetric', 'ImageInpaintingMetric',
  92. 'ReferringVideoObjectSegmentationMetric',
  93. 'VideoFrameInterpolationMetric', 'VideoStabilizationMetric',
  94. 'VideoSuperResolutionMetric', 'PplMetric',
  95. 'ImageQualityAssessmentDegradationMetric',
  96. 'ImageQualityAssessmentMosMetric', 'TextRankingMetric',
  97. 'LossMetric', 'ImageColorizationMetric', 'OCRRecognitionMetric'
  98. ],
  99. 'models': ['Model', 'TorchModel'],
  100. 'preprocessors': ['Preprocessor'],
  101. 'pipelines': ['Pipeline', 'pipeline'],
  102. 'utils.hub': ['read_config', 'create_model_if_not_exist'],
  103. 'utils.logger': ['get_logger'],
  104. 'utils.constant': ['Tasks'],
  105. 'msdatasets': ['MsDataset']
  106. }
  107. from modelscope.utils import hf_util
  108. from modelscope.utils.hf_util.patcher import _patch_pretrained_class
  109. extra_objects = {}
  110. attributes = dir(hf_util)
  111. imports = [attr for attr in attributes if not attr.startswith('__')]
  112. for _import in imports:
  113. extra_objects[_import] = getattr(hf_util, _import)
  114. def try_import_from_hf(name):
  115. hf_pkgs = ['transformers', 'peft', 'diffusers']
  116. module = None
  117. for pkg in hf_pkgs:
  118. try:
  119. module = getattr(importlib.import_module(pkg), name)
  120. break
  121. except Exception: # noqa
  122. pass
  123. if module is not None:
  124. module = _patch_pretrained_class([module], wrap=True)
  125. else:
  126. raise AttributeError(
  127. f'Cannot import available module of {name} in modelscope,'
  128. f' or related packages({hf_pkgs})')
  129. return module[0]
  130. import sys
  131. sys.modules[__name__] = LazyImportModule(
  132. __name__,
  133. globals()['__file__'],
  134. _import_structure,
  135. module_spec=__spec__,
  136. extra_objects=extra_objects,
  137. extra_import_func=try_import_from_hf,
  138. )