aigc.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import glob
  3. import os
  4. from typing import List, Optional
  5. from urllib.parse import urlparse
  6. import requests
  7. from tqdm.auto import tqdm
  8. from modelscope.hub.utils.utils import (MODELSCOPE_URL_SCHEME,
  9. encode_media_to_base64, get_endpoint)
  10. from modelscope.utils.logger import get_logger
  11. logger = get_logger()
  12. # Default AIGC model cover image
  13. DEFAULT_AIGC_COVER_IMAGE = (
  14. 'https://modelscope.cn/models/modelscope/modelscope_aigc_default_logo/resolve/master/'
  15. 'aigc_default_logo.png')
  16. class AigcModel:
  17. """
  18. Helper class to encapsulate AIGC-specific model creation parameters.
  19. This class can be initialized directly with parameters, or loaded from a
  20. JSON configuration file using the `from_json_file` classmethod.
  21. Example of direct initialization:
  22. >>> aigc_model = AigcModel(
  23. ... aigc_type='Checkpoint',
  24. ... base_model_type='SD_XL',
  25. ... model_path='/path/to/your/model.safetensors'
  26. ... base_model_id='AI-ModelScope/FLUX.1-dev'
  27. ... )
  28. Example of loading from a JSON file:
  29. `config.json`:
  30. {
  31. "model_path": "/path/to/your/model.safetensors",
  32. "aigc_type": "Checkpoint",
  33. "base_model_type": "SD_XL",
  34. "base_model_id": "AI-ModelScope/FLUX.1-dev"
  35. }
  36. >>> aigc_model = AigcModel.from_json_file('config.json')
  37. """
  38. AIGC_TYPES = {'Checkpoint', 'LoRA', 'VAE'}
  39. # Supported base model types for reference
  40. BASE_MODEL_TYPES = {
  41. 'WAN_VIDEO_2_1_14_B', 'SD_XL', 'SD_2', 'SD_3',
  42. 'WAN_VIDEO_2_1_T2V_1_3_B', 'UNKNOWN', 'WAN_VIDEO_2_2_TI2V_5_B',
  43. 'WAN_VIDEO_2_2_I2V_A_14_B', 'WAN_VIDEO_2_1_I2V_14_B',
  44. 'QWEN_IMAGE_20_B', 'SD_2_1', 'SD_1_5', 'FLUX_1',
  45. 'WAN_VIDEO_2_2_T2V_A_14_B', 'WAN_VIDEO_2_1_T2V_14_B',
  46. 'WAN_VIDEO_2_1_FLF2V_14_B'
  47. }
  48. OFFICIAL_TAGS = {
  49. 'photography', 'illustration-design', 'e-commerce-design', 'dimension',
  50. '3d', 'hand-drawn-style', 'logo', 'commodity', 'toy-figurines',
  51. 'flat-abstraction', 'character-enhancement', 'scenery', 'animal',
  52. 'art-style-strong', 'other-styles', 'architectural-design',
  53. 'classic-painting-style', 'cg-fantasy', 'artware', 'construction',
  54. 'man', 'woman', 'food', 'automobile-traffic', 'sci-fi-mecha',
  55. 'clothing', 'plant', 'other-functions', 'picture-control',
  56. 'main-strong', 'character-strong'
  57. }
  58. def __init__(
  59. self,
  60. aigc_type: str,
  61. base_model_type: str,
  62. model_path: str,
  63. base_model_id: str = '',
  64. tag: Optional[str] = 'v1.0',
  65. description: Optional[str] = 'this is an aigc model',
  66. cover_images: Optional[List[str]] = None,
  67. path_in_repo: Optional[str] = '',
  68. trigger_words: Optional[List[str]] = None,
  69. official_tags: Optional[List[str]] = None,
  70. model_source: Optional[str] = 'USER_UPLOAD',
  71. base_model_sub_type: Optional[str] = '',
  72. ):
  73. """
  74. Initializes the AigcModel helper.
  75. Args:
  76. model_path (str): The path of checkpoint/LoRA weight file or folder.
  77. aigc_type (str): AIGC model type. Recommended: 'Checkpoint', 'LoRA', 'VAE'.
  78. base_model_type (str): Vision foundation model type. Recommended values are in BASE_MODEL_TYPES.
  79. tag (str, optional): Tag for the AIGC model. Defaults to 'v1.0'.
  80. description (str, optional): Model description. Defaults to 'this is an aigc model'.
  81. cover_images (List[str], optional): List of cover image URLs.
  82. base_model_id (str, optional): Base model name. e.g., 'AI-ModelScope/FLUX.1-dev'.
  83. path_in_repo (str, optional): Path in the repository.
  84. trigger_words (List[str], optional): Trigger words for the AIGC Lora model.
  85. official_tags (List[str], optional): Official tags for the AIGC model. Defaults to None.
  86. model_source (str, optional): Source of the model.
  87. `USER_UPLOAD`, `TRAINED_FROM_MODELSCOPE` or `TRAINED_FROM_ALIYUN_FC`. Defaults to 'USER_UPLOAD'.
  88. base_model_sub_type (str, Optional): Sub vision foundation model type. Defaults to ''. e.g. `SD_1_5`
  89. """
  90. self.model_path = model_path
  91. self.aigc_type = aigc_type
  92. self.base_model_type = base_model_type
  93. self.tag = tag
  94. self.description = description
  95. self.model_source = model_source
  96. self.base_model_sub_type = base_model_sub_type
  97. # Process cover images - convert local paths to base64 data URLs
  98. if cover_images is not None:
  99. processed_cover_images = []
  100. for img in cover_images:
  101. if isinstance(img, str):
  102. # Check if it's a local file path (not a URL)
  103. if not (img.startswith('http://')
  104. or img.startswith('https://')
  105. or img.startswith('data:')):
  106. try:
  107. # Convert local path to base64 data URL
  108. processed_img = encode_media_to_base64(img)
  109. processed_cover_images.append(processed_img)
  110. logger.info('Converted local image to base64: %s',
  111. os.path.basename(img))
  112. except (FileNotFoundError, ValueError) as e:
  113. logger.warning(
  114. 'Failed to process local image %s: %s. Using as-is.',
  115. img, e)
  116. processed_cover_images.append(img)
  117. else:
  118. # Keep URLs and data URLs as-is
  119. processed_cover_images.append(img)
  120. else:
  121. processed_cover_images.append(img)
  122. self.cover_images = processed_cover_images
  123. else:
  124. self.cover_images = [DEFAULT_AIGC_COVER_IMAGE]
  125. self.base_model_id = base_model_id
  126. self.path_in_repo = path_in_repo
  127. self.trigger_words = trigger_words
  128. # Validate types and provide warnings
  129. self._validate_aigc_type()
  130. self._validate_base_model_type()
  131. if official_tags:
  132. self.official_tags = official_tags
  133. self._validate_official_tags()
  134. else:
  135. self.official_tags = None
  136. # Process model path and calculate weights information
  137. self._process_model_path()
  138. def _validate_aigc_type(self):
  139. """Validate aigc_type and provide a warning for unsupported types."""
  140. if self.aigc_type not in self.AIGC_TYPES:
  141. supported_types = ', '.join(sorted(self.AIGC_TYPES))
  142. logger.warning(f'Unsupported aigc_type: "{self.aigc_type}". '
  143. f'Recommended values: {supported_types}. '
  144. 'Custom values are allowed but may cause issues.')
  145. def _validate_base_model_type(self):
  146. """Validate base_model_type and provide warning for unsupported types."""
  147. if self.base_model_type not in self.BASE_MODEL_TYPES:
  148. supported_types = ', '.join(sorted(self.BASE_MODEL_TYPES))
  149. logger.warning(
  150. f'Your base_model_type: "{self.base_model_type}" may not be supported. '
  151. f'Recommended values: {supported_types}. '
  152. f'Custom values are allowed but may cause issues. ')
  153. def _validate_official_tags(self):
  154. """Validate official tags and provide warning for unsupported tags."""
  155. invalid_tags = {
  156. tag
  157. for tag in self.official_tags if tag not in self.OFFICIAL_TAGS
  158. }
  159. if invalid_tags:
  160. supported_tags = ', '.join(self.OFFICIAL_TAGS)
  161. invalid_tags_str = ', '.join(f'"{tag}"' for tag in invalid_tags)
  162. logger.warning(
  163. f'Your tag(s): {invalid_tags_str} may not be supported. '
  164. f'Recommended values: {supported_tags}. ')
  165. def _process_model_path(self):
  166. """Process model_path to extract weight information"""
  167. from modelscope.utils.file_utils import get_file_hash
  168. # Expand user path
  169. self.model_path = os.path.expanduser(self.model_path)
  170. if not os.path.exists(self.model_path):
  171. raise ValueError(f'Model path does not exist: {self.model_path}')
  172. if os.path.isfile(self.model_path):
  173. target_file = self.model_path
  174. logger.info('Using file: %s', os.path.basename(target_file))
  175. elif os.path.isdir(self.model_path):
  176. # Validate top-level directory: it must not be empty; and if it has files,
  177. # they must not be only the common placeholder files
  178. top_entries = os.listdir(self.model_path)
  179. if len(top_entries) == 0:
  180. raise ValueError(
  181. f'Directory is empty: {self.model_path}. '
  182. f'Please place at least one model file at the top level (e.g., .safetensors/.pth/.bin).'
  183. )
  184. top_files = [
  185. name for name in top_entries
  186. if os.path.isfile(os.path.join(self.model_path, name))
  187. ]
  188. placeholder_names = {
  189. '.gitattributes', 'configuration.json', 'readme.md'
  190. }
  191. if top_files:
  192. normalized = {name.lower() for name in top_files}
  193. if normalized.issubset(placeholder_names):
  194. raise ValueError(
  195. 'Top-level directory contains only [.gitattributes, configuration.json, README.md]. '
  196. 'Please place additional model files at the top level (e.g., .safetensors/.pth/.bin).'
  197. )
  198. # Priority order for metadata file: safetensors -> pth -> bin -> first file
  199. file_extensions = ['.safetensors', '.pth', '.bin']
  200. target_file = None
  201. for ext in file_extensions:
  202. files = glob.glob(os.path.join(self.model_path, f'*{ext}'))
  203. if files:
  204. target_file = files[0]
  205. logger.info(f'Found {ext} file: %s',
  206. os.path.basename(target_file))
  207. if len(files) > 1:
  208. logger.warning(
  209. f'Multiple {ext} files found, using: %s for metadata',
  210. os.path.basename(target_file))
  211. logger.info(f'Other {ext} files: %s',
  212. [os.path.basename(f) for f in files[1:]])
  213. break
  214. # If no preferred files found, use the first available file
  215. if not target_file:
  216. all_files = [
  217. f for f in os.listdir(self.model_path)
  218. if os.path.isfile(os.path.join(self.model_path, f))
  219. ]
  220. if all_files:
  221. target_file = os.path.join(self.model_path, all_files[0])
  222. logger.warning(
  223. 'No safetensors/pth/bin files found, using: %s for metadata',
  224. os.path.basename(target_file))
  225. logger.info('Available files: %s', all_files)
  226. else:
  227. raise ValueError(
  228. f'No files found in directory: {self.model_path}. '
  229. f'AIGC models require at least one model file (.safetensors recommended).'
  230. )
  231. else:
  232. raise ValueError(
  233. f'Model path must be a file or directory: {self.model_path}')
  234. if target_file:
  235. # Calculate file hash and size for the target file
  236. logger.info('Computing hash and size for %s...', target_file)
  237. hash_info = get_file_hash(target_file)
  238. # Store weight information
  239. self.weight_filename = os.path.basename(target_file)
  240. self.weight_sha256 = hash_info['file_hash']
  241. self.weight_size = hash_info['file_size']
  242. self.target_file = target_file
  243. def upload_to_repo(self, api, model_id: str, token: Optional[str] = None):
  244. """Upload model files to repository."""
  245. logger.info('Uploading model to %s...', model_id)
  246. try:
  247. if os.path.isdir(self.model_path):
  248. # Upload entire folder with path_in_repo support
  249. logger.info('Uploading directory: %s', self.model_path)
  250. api.upload_folder(
  251. repo_id=model_id,
  252. folder_path=self.model_path,
  253. path_in_repo=self.path_in_repo,
  254. token=token,
  255. commit_message='Upload model folder for AIGC model')
  256. elif os.path.isfile(self.model_path):
  257. # Upload single file, target_file is guaranteed to be set by _process_model_path
  258. logger.info('Uploading file: %s', self.target_file)
  259. api.upload_file(
  260. path_or_fileobj=self.target_file,
  261. path_in_repo=self.path_in_repo + '/' + self.weight_filename
  262. if self.path_in_repo else self.weight_filename,
  263. repo_id=model_id,
  264. token=token,
  265. commit_message=f'Upload {self.weight_filename} '
  266. 'for AIGC model')
  267. logger.info('Successfully uploaded model to %s', model_id)
  268. return True
  269. except Exception as e:
  270. logger.warning('Warning: Failed to upload model: %s', e)
  271. logger.warning(
  272. 'You may need to upload the model manually after creation.')
  273. return False
  274. def preupload_weights(self,
  275. *,
  276. cookies: Optional[object] = None,
  277. timeout: int = 300,
  278. headers: Optional[dict] = None,
  279. endpoint: Optional[str] = None) -> None:
  280. """Pre-upload aigc model weights to the LFS server.
  281. Server may require the sha256 of weights to be registered before creation.
  282. This method streams the weight file so the sha gets registered.
  283. Args:
  284. cookies: Optional requests-style cookies (CookieJar/dict). If provided, preferred.
  285. timeout: Request timeout seconds.
  286. headers: Optional headers.
  287. """
  288. endpoint = endpoint or get_endpoint()
  289. endpoint_host: str = urlparse(endpoint.strip()).hostname.lstrip('www.')
  290. # https://lfs.modelscope.cn or https://pre-lfs.modelscope.cn
  291. base_url: str = f'{MODELSCOPE_URL_SCHEME}lfs.{endpoint_host}' if not endpoint_host.startswith('pre') \
  292. else f'{MODELSCOPE_URL_SCHEME}pre-lfs.{endpoint_host.lstrip("pre.")}'
  293. url: str = f'{base_url}/api/v1/models/aigc/weights'
  294. file_path = getattr(self, 'target_file', None) or self.model_path
  295. file_path = os.path.abspath(os.path.expanduser(file_path))
  296. if not os.path.isfile(file_path):
  297. raise ValueError(f'Pre-upload expects a file, got: {file_path}')
  298. cookies = dict(cookies) if cookies else None
  299. if cookies is None:
  300. raise ValueError('Token does not exist, please login first.')
  301. headers.update({'Cookie': f"m_session_id={cookies['m_session_id']}"})
  302. file_size = os.path.getsize(file_path)
  303. def read_in_chunks(file_object,
  304. pbar,
  305. chunk_size: int = 1 * 1024 * 1024):
  306. while True:
  307. ck = file_object.read(chunk_size)
  308. if not ck:
  309. break
  310. pbar.update(len(ck))
  311. yield ck
  312. with tqdm(
  313. total=file_size,
  314. unit='B',
  315. unit_scale=True,
  316. dynamic_ncols=True,
  317. desc='[Pre-uploading] ') as pbar:
  318. with open(file_path, 'rb') as f:
  319. r = requests.put(
  320. url,
  321. headers=headers,
  322. data=read_in_chunks(f, pbar),
  323. timeout=timeout,
  324. )
  325. try:
  326. resp = r.json()
  327. except requests.exceptions.JSONDecodeError:
  328. r.raise_for_status()
  329. return
  330. # If JSON body returned, try best-effort check
  331. if isinstance(resp, dict) and resp.get('Success') is False:
  332. msg = resp.get('Message', 'unknown error')
  333. raise RuntimeError(f'Pre-upload failed: {msg}')
  334. def to_dict(self) -> dict:
  335. """Converts the AIGC parameters to a dictionary suitable for API calls."""
  336. return {
  337. 'aigc_type': self.aigc_type,
  338. 'base_model_type': self.base_model_type,
  339. 'tag': self.tag,
  340. 'description': self.description,
  341. 'cover_images': self.cover_images,
  342. 'base_model_id': self.base_model_id,
  343. 'model_path': self.model_path,
  344. 'weight_filename': self.weight_filename,
  345. 'weight_sha256': self.weight_sha256,
  346. 'weight_size': self.weight_size,
  347. 'trigger_words': self.trigger_words,
  348. 'official_tags': self.official_tags,
  349. 'model_source': self.model_source,
  350. 'base_model_sub_type': self.base_model_sub_type,
  351. }
  352. @classmethod
  353. def from_json_file(cls, json_path: str):
  354. """
  355. Creates an AigcModel instance from a JSON configuration file.
  356. Args:
  357. json_path (str): The path to the JSON configuration file.
  358. Returns:
  359. AigcModel: An instance of the AigcModel.
  360. """
  361. import json
  362. json_path = os.path.expanduser(json_path)
  363. if not os.path.exists(json_path):
  364. raise FileNotFoundError(
  365. f'JSON config file not found at: {json_path}')
  366. with open(json_path, 'r', encoding='utf-8') as f:
  367. config = json.load(f)
  368. # Ensure required fields are present
  369. required_fields = [
  370. 'model_path', 'aigc_type', 'base_model_type', 'base_model_id'
  371. ]
  372. for field in required_fields:
  373. if field not in config:
  374. raise ValueError(
  375. f"Missing required field in JSON config: '{field}'")
  376. return cls(**config)