# Copyright (c) Alibaba, Inc. and its affiliates.
import glob
import os
from typing import List, Optional
from urllib.parse import urlparse

import requests
from tqdm.auto import tqdm

from modelscope.hub.utils.utils import (MODELSCOPE_URL_SCHEME,
                                        encode_media_to_base64, get_endpoint)
from modelscope.utils.logger import get_logger

logger = get_logger()

# Default AIGC model cover image
DEFAULT_AIGC_COVER_IMAGE = (
    'https://modelscope.cn/models/modelscope/modelscope_aigc_default_logo/resolve/master/'
    'aigc_default_logo.png')


class AigcModel:
    """
    Helper class to encapsulate AIGC-specific model creation parameters.

    This class can be initialized directly with parameters, or loaded from a
    JSON configuration file using the `from_json_file` classmethod.

    Construction validates `model_path` on disk and computes the hash and size
    of the weight file that will be used as metadata, so it raises
    `ValueError` if the path does not exist or contains no usable file.

    Example of direct initialization:
        >>> aigc_model = AigcModel(
        ...     aigc_type='Checkpoint',
        ...     base_model_type='SD_XL',
        ...     model_path='/path/to/your/model.safetensors',
        ...     base_model_id='AI-ModelScope/FLUX.1-dev'
        ... )

    Example of loading from a JSON file:
        `config.json`:
        {
            "model_path": "/path/to/your/model.safetensors",
            "aigc_type": "Checkpoint",
            "base_model_type": "SD_XL",
            "base_model_id": "AI-ModelScope/FLUX.1-dev"
        }

        >>> aigc_model = AigcModel.from_json_file('config.json')
    """

    AIGC_TYPES = {'Checkpoint', 'LoRA', 'VAE'}

    # Supported base model types for reference
    BASE_MODEL_TYPES = {
        'WAN_VIDEO_2_1_14_B', 'SD_XL', 'SD_2', 'SD_3',
        'WAN_VIDEO_2_1_T2V_1_3_B', 'UNKNOWN', 'WAN_VIDEO_2_2_TI2V_5_B',
        'WAN_VIDEO_2_2_I2V_A_14_B', 'WAN_VIDEO_2_1_I2V_14_B',
        'QWEN_IMAGE_20_B', 'SD_2_1', 'SD_1_5', 'FLUX_1',
        'WAN_VIDEO_2_2_T2V_A_14_B', 'WAN_VIDEO_2_1_T2V_14_B',
        'WAN_VIDEO_2_1_FLF2V_14_B'
    }

    OFFICIAL_TAGS = {
        'photography', 'illustration-design', 'e-commerce-design',
        'dimension', '3d', 'hand-drawn-style', 'logo', 'commodity',
        'toy-figurines', 'flat-abstraction', 'character-enhancement',
        'scenery', 'animal', 'art-style-strong', 'other-styles',
        'architectural-design', 'classic-painting-style', 'cg-fantasy',
        'artware', 'construction', 'man', 'woman', 'food',
        'automobile-traffic', 'sci-fi-mecha', 'clothing', 'plant',
        'other-functions', 'picture-control', 'main-strong',
        'character-strong'
    }

    def __init__(
        self,
        aigc_type: str,
        base_model_type: str,
        model_path: str,
        base_model_id: str = '',
        tag: Optional[str] = 'v1.0',
        description: Optional[str] = 'this is an aigc model',
        cover_images: Optional[List[str]] = None,
        path_in_repo: Optional[str] = '',
        trigger_words: Optional[List[str]] = None,
        official_tags: Optional[List[str]] = None,
        model_source: Optional[str] = 'USER_UPLOAD',
        base_model_sub_type: Optional[str] = '',
    ):
        """
        Initializes the AigcModel helper.

        Args:
            model_path (str): The path of checkpoint/LoRA weight file or folder.
            aigc_type (str): AIGC model type. Recommended: 'Checkpoint', 'LoRA', 'VAE'.
            base_model_type (str): Vision foundation model type. Recommended values are in BASE_MODEL_TYPES.
            tag (str, optional): Tag for the AIGC model. Defaults to 'v1.0'.
            description (str, optional): Model description. Defaults to 'this is an aigc model'.
            cover_images (List[str], optional): List of cover images. Entries
                starting with `http://`, `https://` or `data:` are kept as-is;
                any other string is treated as a local path and converted to a
                base64 data URL. Defaults to the built-in default cover image.
            base_model_id (str, optional): Base model name. e.g., 'AI-ModelScope/FLUX.1-dev'.
            path_in_repo (str, optional): Path in the repository.
            trigger_words (List[str], optional): Trigger words for the AIGC Lora model.
            official_tags (List[str], optional): Official tags for the AIGC model. Defaults to None.
            model_source (str, optional): Source of the model.
                `USER_UPLOAD`, `TRAINED_FROM_MODELSCOPE` or `TRAINED_FROM_ALIYUN_FC`. Defaults to 'USER_UPLOAD'.
            base_model_sub_type (str, Optional): Sub vision foundation model type. Defaults to ''. e.g. `SD_1_5`

        Raises:
            ValueError: If `model_path` does not exist, is neither a file nor
                a directory, or is a directory without a usable model file.
        """
        self.model_path = model_path
        self.aigc_type = aigc_type
        self.base_model_type = base_model_type
        self.tag = tag
        self.description = description
        self.model_source = model_source
        self.base_model_sub_type = base_model_sub_type

        # Process cover images - convert local paths to base64 data URLs
        if cover_images is not None:
            self.cover_images = self._prepare_cover_images(cover_images)
        else:
            self.cover_images = [DEFAULT_AIGC_COVER_IMAGE]

        self.base_model_id = base_model_id
        self.path_in_repo = path_in_repo
        self.trigger_words = trigger_words

        # Validate types and provide warnings
        self._validate_aigc_type()
        self._validate_base_model_type()
        if official_tags:
            self.official_tags = official_tags
            self._validate_official_tags()
        else:
            self.official_tags = None

        # Process model path and calculate weights information
        self._process_model_path()

    @staticmethod
    def _prepare_cover_images(cover_images) -> list:
        """Return cover images with local file paths encoded as base64 data URLs.

        URLs, data URLs and non-string entries are passed through untouched.
        A local path that cannot be encoded is logged and kept as-is so that
        a bad cover image never blocks model creation.
        """
        remote_prefixes = ('http://', 'https://', 'data:')
        processed_cover_images = []
        for img in cover_images:
            # Keep non-strings, URLs and data URLs as-is
            if not isinstance(img, str) or img.startswith(remote_prefixes):
                processed_cover_images.append(img)
                continue
            try:
                # Convert local path to base64 data URL
                processed_img = encode_media_to_base64(img)
                processed_cover_images.append(processed_img)
                logger.info('Converted local image to base64: %s',
                            os.path.basename(img))
            except (FileNotFoundError, ValueError) as e:
                logger.warning(
                    'Failed to process local image %s: %s. Using as-is.', img,
                    e)
                processed_cover_images.append(img)
        return processed_cover_images

    def _validate_aigc_type(self):
        """Validate aigc_type and provide a warning for unsupported types."""
        if self.aigc_type not in self.AIGC_TYPES:
            supported_types = ', '.join(sorted(self.AIGC_TYPES))
            logger.warning(f'Unsupported aigc_type: "{self.aigc_type}". '
                           f'Recommended values: {supported_types}. '
                           'Custom values are allowed but may cause issues.')

    def _validate_base_model_type(self):
        """Validate base_model_type and provide warning for unsupported types."""
        if self.base_model_type not in self.BASE_MODEL_TYPES:
            supported_types = ', '.join(sorted(self.BASE_MODEL_TYPES))
            logger.warning(
                f'Your base_model_type: "{self.base_model_type}" may not be supported. '
                f'Recommended values: {supported_types}. '
                f'Custom values are allowed but may cause issues. ')

    def _validate_official_tags(self):
        """Validate official tags and provide warning for unsupported tags."""
        invalid_tags = {
            tag
            for tag in self.official_tags if tag not in self.OFFICIAL_TAGS
        }
        if invalid_tags:
            # Sorted so the message is deterministic, like the other validators.
            supported_tags = ', '.join(sorted(self.OFFICIAL_TAGS))
            invalid_tags_str = ', '.join(f'"{tag}"'
                                         for tag in sorted(invalid_tags))
            logger.warning(
                f'Your tag(s): {invalid_tags_str} may not be supported. '
                f'Recommended values: {supported_tags}. ')

    def _process_model_path(self):
        """Process model_path to extract weight information.

        Resolves `self.model_path` (expanding `~`), picks the file used for
        metadata, and stores `weight_filename`, `weight_sha256`, `weight_size`
        and `target_file` on the instance.

        For a directory, the metadata file is the first top-level match in the
        priority order .safetensors -> .pth -> .bin, falling back to the first
        top-level file of any kind.

        Raises:
            ValueError: If the path does not exist, is neither a file nor a
                directory, is an empty directory, or holds only placeholder
                files (.gitattributes / configuration.json / README.md).
        """
        from modelscope.utils.file_utils import get_file_hash

        # Expand user path
        self.model_path = os.path.expanduser(self.model_path)

        if not os.path.exists(self.model_path):
            raise ValueError(f'Model path does not exist: {self.model_path}')

        if os.path.isfile(self.model_path):
            target_file = self.model_path
            logger.info('Using file: %s', os.path.basename(target_file))
        elif os.path.isdir(self.model_path):
            # Validate top-level directory: it must not be empty; and if it has files,
            # they must not be only the common placeholder files
            top_entries = os.listdir(self.model_path)
            if len(top_entries) == 0:
                raise ValueError(
                    f'Directory is empty: {self.model_path}. '
                    f'Please place at least one model file at the top level (e.g., .safetensors/.pth/.bin).'
                )

            top_files = [
                name for name in top_entries
                if os.path.isfile(os.path.join(self.model_path, name))
            ]
            placeholder_names = {
                '.gitattributes', 'configuration.json', 'readme.md'
            }
            if top_files:
                normalized = {name.lower() for name in top_files}
                if normalized.issubset(placeholder_names):
                    raise ValueError(
                        'Top-level directory contains only [.gitattributes, configuration.json, README.md]. '
                        'Please place additional model files at the top level (e.g., .safetensors/.pth/.bin).'
                    )

            # Priority order for metadata file: safetensors -> pth -> bin -> first file
            file_extensions = ['.safetensors', '.pth', '.bin']
            # Escape the directory so glob metacharacters in the path itself
            # (e.g. "[", "*") are matched literally rather than as a pattern.
            escaped_dir = glob.escape(self.model_path)

            target_file = None
            for ext in file_extensions:
                files = glob.glob(os.path.join(escaped_dir, f'*{ext}'))
                if files:
                    target_file = files[0]
                    logger.info('Found %s file: %s', ext,
                                os.path.basename(target_file))
                    if len(files) > 1:
                        logger.warning(
                            'Multiple %s files found, using: %s for metadata',
                            ext, os.path.basename(target_file))
                        logger.info('Other %s files: %s', ext,
                                    [os.path.basename(f) for f in files[1:]])
                    break

            # If no preferred files found, use the first available file
            if not target_file:
                all_files = [
                    f for f in os.listdir(self.model_path)
                    if os.path.isfile(os.path.join(self.model_path, f))
                ]
                if all_files:
                    target_file = os.path.join(self.model_path, all_files[0])
                    logger.warning(
                        'No safetensors/pth/bin files found, using: %s for metadata',
                        os.path.basename(target_file))
                    logger.info('Available files: %s', all_files)
                else:
                    raise ValueError(
                        f'No files found in directory: {self.model_path}. '
                        f'AIGC models require at least one model file (.safetensors recommended).'
                    )
        else:
            raise ValueError(
                f'Model path must be a file or directory: {self.model_path}')

        if target_file:
            # Calculate file hash and size for the target file
            logger.info('Computing hash and size for %s...', target_file)
            hash_info = get_file_hash(target_file)

            # Store weight information
            self.weight_filename = os.path.basename(target_file)
            self.weight_sha256 = hash_info['file_hash']
            self.weight_size = hash_info['file_size']
            self.target_file = target_file

    def upload_to_repo(self, api, model_id: str, token: Optional[str] = None):
        """Upload model files to repository.

        A directory `model_path` is uploaded with `api.upload_folder`; a single
        file is uploaded with `api.upload_file`, placed under `path_in_repo`
        when that is set.

        Args:
            api: Hub API object exposing `upload_folder` and `upload_file`.
            model_id (str): Target repository id.
            token (str, optional): Access token forwarded to the API.

        Returns:
            bool: True on success. False if the upload raised; the failure is
            logged as a warning instead of propagating, so that the caller can
            fall back to a manual upload.
        """
        logger.info('Uploading model to %s...', model_id)
        try:
            if os.path.isdir(self.model_path):
                # Upload entire folder with path_in_repo support
                logger.info('Uploading directory: %s', self.model_path)
                api.upload_folder(
                    repo_id=model_id,
                    folder_path=self.model_path,
                    path_in_repo=self.path_in_repo,
                    token=token,
                    commit_message='Upload model folder for AIGC model')
            elif os.path.isfile(self.model_path):
                # Upload single file, target_file is guaranteed to be set by _process_model_path
                logger.info('Uploading file: %s', self.target_file)
                api.upload_file(
                    path_or_fileobj=self.target_file,
                    path_in_repo=self.path_in_repo + '/'
                    + self.weight_filename
                    if self.path_in_repo else self.weight_filename,
                    repo_id=model_id,
                    token=token,
                    commit_message=f'Upload {self.weight_filename} '
                    'for AIGC model')

            logger.info('Successfully uploaded model to %s', model_id)
            return True
        except Exception as e:
            # Deliberate best-effort: report and let the caller decide.
            logger.warning('Warning: Failed to upload model: %s', e)
            logger.warning(
                'You may need to upload the model manually after creation.')
            return False

    @staticmethod
    def _resolve_lfs_base_url(endpoint: str) -> str:
        """Derive the LFS server base URL from a hub endpoint URL.

        A leading `www.` label is dropped; a host whose first label is `pre`
        maps to `pre-lfs.<rest>`, every other host maps to `lfs.<host>`.

        Raises:
            ValueError: If no hostname can be parsed from `endpoint`
                (for example when the scheme is missing).
        """
        host = urlparse(endpoint.strip()).hostname
        if not host:
            raise ValueError(
                f'Cannot parse hostname from endpoint: {endpoint!r}')

        # NOTE: str.lstrip() takes a *set of characters*, not a prefix, and
        # would eat into the real hostname, so prefixes are removed by slicing.
        www_prefix = 'www.'
        if host.startswith(www_prefix):
            host = host[len(www_prefix):]

        # https://lfs.modelscope.cn or https://pre-lfs.modelscope.cn
        pre_prefix = 'pre.'
        if host.startswith(pre_prefix):
            return f'{MODELSCOPE_URL_SCHEME}pre-lfs.{host[len(pre_prefix):]}'
        return f'{MODELSCOPE_URL_SCHEME}lfs.{host}'

    def preupload_weights(self,
                          *,
                          cookies: Optional[object] = None,
                          timeout: int = 300,
                          headers: Optional[dict] = None,
                          endpoint: Optional[str] = None) -> None:
        """Pre-upload aigc model weights to the LFS server.

        Server may require the sha256 of weights to be registered before creation.
        This method streams the weight file so the sha gets registered.

        Args:
            cookies: Requests-style cookies (CookieJar/dict) holding the
                `m_session_id` of a logged-in session. Required in practice:
                a missing or empty value is rejected.
            timeout: Request timeout seconds.
            headers: Optional extra request headers. The mapping is copied,
                never modified in place.
            endpoint: Optional hub endpoint URL; defaults to `get_endpoint()`.

        Raises:
            ValueError: If the endpoint has no parsable hostname, the weight
                path is not a file, or `cookies` is missing/empty.
            KeyError: If `cookies` does not contain `m_session_id`.
            requests.HTTPError: If the response is not JSON and its status
                indicates an error.
            RuntimeError: If the JSON response reports `Success` as False.
        """
        endpoint = endpoint or get_endpoint()
        base_url: str = self._resolve_lfs_base_url(endpoint)
        url: str = f'{base_url}/api/v1/models/aigc/weights'

        file_path = getattr(self, 'target_file', None) or self.model_path
        file_path = os.path.abspath(os.path.expanduser(file_path))
        if not os.path.isfile(file_path):
            raise ValueError(f'Pre-upload expects a file, got: {file_path}')

        cookies = dict(cookies) if cookies else None
        if cookies is None:
            raise ValueError('Token does not exist, please login first.')

        # Copy so the default (None) works and the caller's dict is untouched.
        request_headers = dict(headers) if headers else {}
        request_headers.update(
            {'Cookie': f"m_session_id={cookies['m_session_id']}"})

        file_size = os.path.getsize(file_path)

        def read_in_chunks(file_object,
                           pbar,
                           chunk_size: int = 1 * 1024 * 1024):
            """Yield the file in chunks, advancing the progress bar."""
            while True:
                ck = file_object.read(chunk_size)
                if not ck:
                    break
                pbar.update(len(ck))
                yield ck

        with tqdm(
                total=file_size,
                unit='B',
                unit_scale=True,
                dynamic_ncols=True,
                desc='[Pre-uploading] ') as pbar:
            with open(file_path, 'rb') as f:
                r = requests.put(
                    url,
                    headers=request_headers,
                    data=read_in_chunks(f, pbar),
                    timeout=timeout,
                )

        try:
            resp = r.json()
        except requests.exceptions.JSONDecodeError:
            # No JSON body: fall back to the HTTP status code.
            r.raise_for_status()
            return

        # If JSON body returned, try best-effort check
        if isinstance(resp, dict) and resp.get('Success') is False:
            msg = resp.get('Message', 'unknown error')
            raise RuntimeError(f'Pre-upload failed: {msg}')

    def to_dict(self) -> dict:
        """Converts the AIGC parameters to a dictionary suitable for API calls."""
        return {
            'aigc_type': self.aigc_type,
            'base_model_type': self.base_model_type,
            'tag': self.tag,
            'description': self.description,
            'cover_images': self.cover_images,
            'base_model_id': self.base_model_id,
            'model_path': self.model_path,
            'weight_filename': self.weight_filename,
            'weight_sha256': self.weight_sha256,
            'weight_size': self.weight_size,
            'trigger_words': self.trigger_words,
            'official_tags': self.official_tags,
            'model_source': self.model_source,
            'base_model_sub_type': self.base_model_sub_type,
        }

    @classmethod
    def from_json_file(cls, json_path: str):
        """
        Creates an AigcModel instance from a JSON configuration file.

        Every key in the JSON object is forwarded to the constructor as a
        keyword argument.

        Args:
            json_path (str): The path to the JSON configuration file.

        Returns:
            AigcModel: An instance of the AigcModel.

        Raises:
            FileNotFoundError: If `json_path` does not exist.
            ValueError: If one of `model_path`, `aigc_type`,
                `base_model_type` or `base_model_id` is missing.
        """
        import json
        json_path = os.path.expanduser(json_path)
        if not os.path.exists(json_path):
            raise FileNotFoundError(
                f'JSON config file not found at: {json_path}')

        with open(json_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        # Ensure required fields are present
        required_fields = [
            'model_path', 'aigc_type', 'base_model_type', 'base_model_id'
        ]
        for field in required_fields:
            if field not in config:
                raise ValueError(
                    f"Missing required field in JSON config: '{field}'")

        return cls(**config)