utils.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import contextlib
  3. import hashlib
  4. import os
  5. import sys
  6. import time
  7. import zoneinfo
  8. from datetime import datetime
  9. from pathlib import Path
  10. from typing import Generator, List, Optional, Union
  11. from filelock import BaseFileLock, FileLock, SoftFileLock, Timeout
  12. from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
  13. DEFAULT_MODELSCOPE_GROUP,
  14. DEFAULT_MODELSCOPE_INTL_DOMAIN,
  15. MODEL_ID_SEPARATOR, MODELSCOPE_DOMAIN,
  16. MODELSCOPE_SDK_DEBUG,
  17. MODELSCOPE_URL_SCHEME)
  18. from modelscope.hub.errors import FileIntegrityError
  19. from modelscope.utils.logger import get_logger
# Module-level logger obtained from modelscope's get_logger(); used by the
# helpers in this file for debug/info/warning/error output.
logger = get_logger()
  21. def model_id_to_group_owner_name(model_id):
  22. if MODEL_ID_SEPARATOR in model_id:
  23. group_or_owner = model_id.split(MODEL_ID_SEPARATOR)[0]
  24. name = model_id.split(MODEL_ID_SEPARATOR)[1]
  25. else:
  26. group_or_owner = DEFAULT_MODELSCOPE_GROUP
  27. name = model_id
  28. return group_or_owner, name
  29. def is_env_true(var_name):
  30. value = os.environ.get(var_name, '').strip().lower()
  31. return value == 'true'
  32. def get_domain(cn_site=True):
  33. if MODELSCOPE_DOMAIN in os.environ and os.getenv(MODELSCOPE_DOMAIN):
  34. return os.getenv(MODELSCOPE_DOMAIN)
  35. if cn_site:
  36. return DEFAULT_MODELSCOPE_DOMAIN
  37. else:
  38. return DEFAULT_MODELSCOPE_INTL_DOMAIN
  39. def convert_patterns(raw_input: Union[str, List[str]]):
  40. output = None
  41. if isinstance(raw_input, str):
  42. output = list()
  43. if ',' in raw_input:
  44. output = [s.strip() for s in raw_input.split(',')]
  45. else:
  46. output.append(raw_input.strip())
  47. elif isinstance(raw_input, list):
  48. output = list()
  49. for s in raw_input:
  50. if isinstance(s, str):
  51. if ',' in s:
  52. output.extend([ss.strip() for ss in s.split(',')])
  53. else:
  54. output.append(s.strip())
  55. return output
# During model download, every '.' in the model id is converted to '___' to
# produce the actual physical (masked) directory used for storage.
  58. def get_model_masked_directory(directory, model_id):
  59. if sys.platform.startswith('win'):
  60. parts = directory.rsplit('\\', 2)
  61. else:
  62. parts = directory.rsplit('/', 2)
  63. # this is the actual directory the model files are located.
  64. masked_directory = os.path.join(parts[0], model_id.replace('.', '___'))
  65. return masked_directory
  66. def convert_readable_size(size_bytes: int) -> str:
  67. import math
  68. if size_bytes == 0:
  69. return '0B'
  70. size_name = ('B', 'KB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB')
  71. i = int(math.floor(math.log(size_bytes, 1024)))
  72. p = math.pow(1024, i)
  73. s = round(size_bytes / p, 2)
  74. return f'{s} {size_name[i]}'
  75. def get_folder_size(folder_path: str) -> int:
  76. total_size = 0
  77. for path in Path(folder_path).rglob('*'):
  78. if path.is_file():
  79. total_size += path.stat().st_size
  80. return total_size
# Return a readable string that describes the size of a given folder (MB, GB, etc.).
  82. def get_readable_folder_size(folder_path: str) -> str:
  83. return convert_readable_size(get_folder_size(folder_path=folder_path))
  84. def get_cache_dir(model_id: Optional[str] = None):
  85. """cache dir precedence:
  86. function parameter > environment > ~/.cache/modelscope/hub
  87. Args:
  88. model_id (str, optional): The model id.
  89. Returns:
  90. str: the model_id dir if model_id not None, otherwise cache root dir.
  91. """
  92. default_cache_dir = Path.home().joinpath('.cache', 'modelscope')
  93. base_path = os.getenv('MODELSCOPE_CACHE',
  94. os.path.join(default_cache_dir, 'hub'))
  95. return base_path if model_id is None else os.path.join(
  96. base_path, model_id + '/')
  97. def get_release_datetime():
  98. if MODELSCOPE_SDK_DEBUG in os.environ:
  99. rt = int(round(datetime.now().timestamp()))
  100. else:
  101. from modelscope import version
  102. rt = int(
  103. round(
  104. datetime.strptime(version.__release_datetime__,
  105. '%Y-%m-%d %H:%M:%S').timestamp()))
  106. return rt
  107. def get_endpoint(cn_site=True):
  108. return MODELSCOPE_URL_SCHEME + get_domain(cn_site)
  109. def compute_hash(file_path):
  110. BUFFER_SIZE = 1024 * 64 # 64k buffer size
  111. sha256_hash = hashlib.sha256()
  112. with open(file_path, 'rb') as f:
  113. while True:
  114. data = f.read(BUFFER_SIZE)
  115. if not data:
  116. break
  117. sha256_hash.update(data)
  118. return sha256_hash.hexdigest()
  119. def file_integrity_validation(file_path, expected_sha256) -> bool:
  120. """Validate the file hash is expected, if not, delete the file
  121. Args:
  122. file_path (str): The file to validate
  123. expected_sha256 (str): The expected sha256 hash
  124. Returns:
  125. bool: True if the file is valid, False otherwise
  126. """
  127. file_sha256 = compute_hash(file_path)
  128. if not file_sha256 == expected_sha256:
  129. os.remove(file_path)
  130. msg = 'File %s integrity check failed, expected sha256 signature is %s, actual is %s, the download may be incomplete, please try again.' % ( # noqa E501
  131. file_path, expected_sha256, file_sha256)
  132. logger.error(msg)
  133. return False
  134. return True
def add_content_to_file(repo,
                        file_name: str,
                        patterns: List[str],
                        commit_message: Optional[str] = None,
                        ignore_push_error=False) -> None:
    """Append missing patterns to a file in the repo directory, then push.

    Args:
        repo: Repository object. This function reads ``repo.model_dir`` and
            calls ``repo.push(commit_message)``.
        file_name (str): Name of the file, joined onto ``repo.model_dir``.
        patterns (List[str]): Patterns to add, one per line. A single string
            is wrapped into a one-element list.
        commit_message (str, optional): Message passed to ``repo.push``.
            When None, a message naming the first pattern and the file is
            generated.
        ignore_push_error (bool): When True, any exception raised by
            ``repo.push`` is suppressed; otherwise it is re-raised.
    """
    if isinstance(patterns, str):
        patterns = [patterns]
    if commit_message is None:
        # Uses patterns[0], so an empty ``patterns`` list raises IndexError here.
        commit_message = f'Add `{patterns[0]}` patterns to {file_name}'
    # Get current file content
    repo_dir = repo.model_dir
    file_path = os.path.join(repo_dir, file_name)
    if os.path.exists(file_path):
        with open(file_path, 'r', encoding='utf-8') as f:
            current_content = f.read()
    else:
        current_content = ''
    # Add the patterns to file
    content = current_content
    for pattern in patterns:
        # Substring test against the whole content: a pattern that occurs
        # anywhere in the text (even inside a longer line) is treated as
        # already present and is not appended.
        if pattern not in content:
            # Make sure the new pattern starts on its own line.
            if len(content) > 0 and not content.endswith('\n'):
                content += '\n'
            content += f'{pattern}\n'
    # Write the file if it has changed
    if content != current_content:
        with open(file_path, 'w', encoding='utf-8') as f:
            logger.debug(f'Writing {file_name} file. Content: {content}')
            f.write(content)
    # NOTE(review): the indentation of this function was reconstructed; the
    # push is placed at function level, so it runs even when the file was
    # not modified - confirm against the upstream file.
    try:
        repo.push(commit_message)
    except Exception as e:
        if ignore_push_error:
            pass
        else:
            raise e
  171. _TIMESINCE_CHUNKS = (
  172. # Label, divider, max value
  173. ('second', 1, 60),
  174. ('minute', 60, 60),
  175. ('hour', 60 * 60, 24),
  176. ('day', 60 * 60 * 24, 6),
  177. ('week', 60 * 60 * 24 * 7, 6),
  178. ('month', 60 * 60 * 24 * 30, 11),
  179. ('year', 60 * 60 * 24 * 365, None),
  180. )
  181. def format_timesince(ts: float) -> str:
  182. """Format timestamp in seconds into a human-readable string, relative to now.
  183. """
  184. delta = time.time() - ts
  185. if delta < 20:
  186. return 'a few seconds ago'
  187. for label, divider, max_value in _TIMESINCE_CHUNKS: # noqa: B007
  188. value = round(delta / divider)
  189. if max_value is not None and value <= max_value:
  190. break
  191. return f"{value} {label}{'s' if value > 1 else ''} ago"
  192. def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
  193. """
  194. Inspired by:
  195. - stackoverflow.com/a/8356620/593036
  196. - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
  197. """
  198. col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
  199. row_format = ('{{:{}}} ' * len(headers)).format(*col_widths)
  200. lines = []
  201. lines.append(row_format.format(*headers))
  202. lines.append(row_format.format(*['-' * w for w in col_widths]))
  203. for row in rows:
  204. lines.append(row_format.format(*row))
  205. return '\n'.join(lines)
  206. # Part of the code borrowed from the awesome work of huggingface_hub/transformers
  207. def strtobool(val):
  208. val = val.lower()
  209. if val in {'y', 'yes', 't', 'true', 'on', '1'}:
  210. return 1
  211. if val in {'n', 'no', 'f', 'false', 'off', '0'}:
  212. return 0
  213. raise ValueError(f'invalid truth value {val!r}')
@contextlib.contextmanager
def weak_file_lock(lock_file: Union[str, Path],
                   *,
                   timeout: Optional[float] = None
                   ) -> Generator[BaseFileLock, None, None]:
    """Context manager that acquires a file lock on ``lock_file`` and yields it.

    Acquisition is attempted in slices of at most ``default_interval``
    seconds; each time a slice expires without the lock being obtained, an
    info message with the elapsed time is logged and the attempt is repeated.
    If acquiring raises ``NotImplementedError`` whose message contains
    ``'use SoftFileLock instead'``, the lock is replaced by a ``SoftFileLock``
    and the attempt is retried.

    Args:
        lock_file (Union[str, Path]): Path of the lock file.
        timeout (float, optional): Overall limit in seconds. ``None`` means
            the loop keeps retrying with no overall limit.

    Yields:
        BaseFileLock: The acquired lock.

    Raises:
        Timeout: When ``timeout`` is not None and the elapsed time has
            reached it at the start of an attempt.
    """
    # Length in seconds of a single acquire attempt (also the logging cadence).
    default_interval = 60
    lock = FileLock(lock_file, timeout=default_interval)
    start_time = time.time()
    while True:
        elapsed_time = time.time() - start_time
        # Checked before every attempt. NOTE(review): with timeout=0 this is
        # true on the first pass, so Timeout is raised without trying.
        if timeout is not None and elapsed_time >= timeout:
            raise Timeout(str(lock_file))
        try:
            # Wait for the remaining time budget, capped at default_interval;
            # a falsy timeout (None or 0) uses default_interval.
            lock.acquire(
                timeout=min(default_interval, timeout - elapsed_time)
                if timeout else default_interval)  # noqa
        except Timeout:
            # This slice expired; report progress and loop for another attempt.
            logger.info(
                f'Still waiting to acquire lock on {lock_file} (elapsed: {time.time() - start_time:.1f} seconds)'
            )
        except NotImplementedError as e:
            # NOTE(review): a NotImplementedError without the marker text is
            # swallowed here and the loop simply retries - confirm intended.
            if 'use SoftFileLock instead' in str(e):
                logger.warning(
                    'FileSystem does not appear to support flock. Falling back to SoftFileLock for %s',
                    lock_file)
                lock = SoftFileLock(lock_file, timeout=default_interval)
                continue
        else:
            # acquire() returned without raising: the lock is held.
            break
    try:
        yield lock
    finally:
        try:
            lock.release()
        except OSError:
            # Release failed; best effort: try to delete the lock file and
            # ignore a failure to do so.
            try:
                Path(lock_file).unlink()
            except OSError:
                pass
  253. def convert_timestamp(time_stamp: Union[int, str, datetime],
  254. time_zone: str = 'Asia/Shanghai') -> Optional[datetime]:
  255. """Convert a UNIX/string timestamp to a timezone-aware datetime object.
  256. Args:
  257. time_stamp: UNIX timestamp (int), ISO string, or datetime object
  258. time_zone: Target timezone for non-UTC timestamps (default: 'Asia/Shanghai')
  259. Returns:
  260. Timezone-aware datetime object or None if input is None
  261. """
  262. if not time_stamp:
  263. return None
  264. # Handle datetime objects first
  265. if isinstance(time_stamp, datetime):
  266. return time_stamp
  267. if isinstance(time_stamp, str):
  268. try:
  269. if time_stamp.endswith('Z'):
  270. # Normalize fractional seconds to 6 digits
  271. if '.' not in time_stamp:
  272. # No fractional seconds (e.g., "2024-11-16T00:27:02Z")
  273. time_stamp = time_stamp[:-1] + '.000000Z'
  274. else:
  275. # Has fractional seconds (e.g., "2022-08-19T07:19:38.123456789Z")
  276. base, fraction = time_stamp[:-1].split('.')
  277. # Truncate or pad to 6 digits
  278. fraction = fraction[:6].ljust(6, '0')
  279. time_stamp = f'{base}.{fraction}Z'
  280. dt = datetime.strptime(time_stamp,
  281. '%Y-%m-%dT%H:%M:%S.%fZ').replace(
  282. tzinfo=zoneinfo.ZoneInfo('UTC'))
  283. if time_zone != 'UTC':
  284. dt = dt.astimezone(zoneinfo.ZoneInfo(time_zone))
  285. return dt
  286. else:
  287. # Try parsing common ISO formats
  288. formats = [
  289. '%Y-%m-%dT%H:%M:%S.%f', # With microseconds
  290. '%Y-%m-%dT%H:%M:%S', # Without microseconds
  291. '%Y-%m-%d %H:%M:%S.%f', # Space separator with microseconds
  292. '%Y-%m-%d %H:%M:%S', # Space separator without microseconds
  293. ]
  294. for fmt in formats:
  295. try:
  296. return datetime.strptime(
  297. time_stamp,
  298. fmt).replace(tzinfo=zoneinfo.ZoneInfo(time_zone))
  299. except ValueError:
  300. continue
  301. raise ValueError(
  302. f"Unsupported timestamp format: '{time_stamp}'")
  303. except ValueError as e:
  304. raise ValueError(
  305. f"Cannot parse '{time_stamp}' as a datetime. Expected formats: "
  306. f"'YYYY-MM-DDTHH:MM:SS[.ffffff]Z' (UTC) or 'YYYY-MM-DDTHH:MM:SS[.ffffff]' (local)"
  307. ) from e
  308. elif isinstance(time_stamp, int):
  309. try:
  310. # UNIX timestamps are always in UTC, then convert to target timezone
  311. return datetime.fromtimestamp(
  312. time_stamp, tz=zoneinfo.ZoneInfo('UTC')).astimezone(
  313. zoneinfo.ZoneInfo(time_zone))
  314. except (ValueError, OSError) as e:
  315. raise ValueError(
  316. f"Cannot convert '{time_stamp}' to datetime. Ensure it's a valid UNIX timestamp."
  317. ) from e
  318. else:
  319. raise TypeError(
  320. f"Unsupported type '{type(time_stamp)}'. Expected int, str, or datetime."
  321. )
  322. def encode_media_to_base64(media_file_path: str) -> str:
  323. """
  324. Encode image or video file to base64 string.
  325. Args:
  326. media_file_path (str): Path to the image or video file
  327. Returns:
  328. str: Base64 encoded string with data URL prefix
  329. Raises:
  330. FileNotFoundError: If image/video file doesn't exist
  331. ValueError: If file is not a valid format
  332. """
  333. import base64
  334. import mimetypes
  335. # Expand user path
  336. media_file_path = os.path.expanduser(media_file_path)
  337. if not os.path.exists(media_file_path):
  338. raise FileNotFoundError(f'Image file not found: {media_file_path}')
  339. if not os.path.isfile(media_file_path):
  340. raise ValueError(f'Path is not a file: {media_file_path}')
  341. # Get MIME type
  342. mime_type, _ = mimetypes.guess_type(media_file_path)
  343. if not mime_type:
  344. raise ValueError(f'File is not a valid format: {media_file_path}')
  345. # Read and encode file
  346. with open(media_file_path, 'rb') as media_file:
  347. image_data = media_file.read()
  348. base64_data = base64.b64encode(image_data).decode('utf-8')
  349. # Return data URL format
  350. return f'data:{mime_type};base64,{base64_data}'