repo_utils.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. # Copyright 2022-present, the HuggingFace Inc. team.
  3. import base64
  4. import functools
  5. import hashlib
  6. import io
  7. import os
  8. import sys
  9. from contextlib import contextmanager
  10. from dataclasses import dataclass, field
  11. from datetime import datetime
  12. from fnmatch import fnmatch
  13. from pathlib import Path
  14. from typing import (Any, BinaryIO, Callable, Generator, Iterable, Iterator,
  15. List, Literal, Optional, TypeVar, Union)
  16. from modelscope.hub.constants import DEFAULT_MODELSCOPE_DATA_ENDPOINT
  17. from modelscope.hub.utils.utils import convert_timestamp
  18. from modelscope.utils.file_utils import get_file_hash
  19. T = TypeVar('T')
  20. # Always ignore `.git` and `.cache/modelscope` folders in commits
  21. DEFAULT_IGNORE_PATTERNS = [
  22. '.git',
  23. '.git/*',
  24. '*/.git',
  25. '**/.git/**',
  26. '.cache',
  27. '.cache/*',
  28. '*/.cache',
  29. '**/.cache/**',
  30. ]
  31. UploadMode = Literal['lfs', 'normal']
  32. DATASET_LFS_SUFFIX = [
  33. '.7z',
  34. '.aac',
  35. '.arrow',
  36. '.audio',
  37. '.bmp',
  38. '.bin',
  39. '.bz2',
  40. '.flac',
  41. '.ftz',
  42. '.gif',
  43. '.gz',
  44. '.h5',
  45. '.jack',
  46. '.jpeg',
  47. '.jpg',
  48. '.png',
  49. '.jsonl',
  50. '.joblib',
  51. '.lz4',
  52. '.msgpack',
  53. '.npy',
  54. '.npz',
  55. '.ot',
  56. '.parquet',
  57. '.pb',
  58. '.pickle',
  59. '.pcm',
  60. '.pkl',
  61. '.raw',
  62. '.rar',
  63. '.sam',
  64. '.tar',
  65. '.tgz',
  66. '.wasm',
  67. '.wav',
  68. '.webm',
  69. '.webp',
  70. '.zip',
  71. '.zst',
  72. '.tiff',
  73. '.mp3',
  74. '.mp4',
  75. '.ogg',
  76. ]
  77. MODEL_LFS_SUFFIX = [
  78. '.7z',
  79. '.arrow',
  80. '.bin',
  81. '.bz2',
  82. '.ckpt',
  83. '.ftz',
  84. '.gz',
  85. '.h5',
  86. '.joblib',
  87. '.mlmodel',
  88. '.model',
  89. '.msgpack',
  90. '.npy',
  91. '.npz',
  92. '.onnx',
  93. '.ot',
  94. '.parquet',
  95. '.pb',
  96. '.pickle',
  97. '.pkl',
  98. '.pt',
  99. '.pth',
  100. '.rar',
  101. '.safetensors',
  102. '.tar',
  103. '.tflite',
  104. '.tgz',
  105. '.wasm',
  106. '.xz',
  107. '.zip',
  108. '.zst',
  109. ]
  110. class RepoUtils:
  111. @staticmethod
  112. def filter_repo_objects(
  113. items: Iterable[T],
  114. *,
  115. allow_patterns: Optional[Union[List[str], str]] = None,
  116. ignore_patterns: Optional[Union[List[str], str]] = None,
  117. key: Optional[Callable[[T], str]] = None,
  118. ) -> Generator[T, None, None]:
  119. """Filter repo objects based on an allowlist and a denylist.
  120. Input must be a list of paths (`str` or `Path`) or a list of arbitrary objects.
  121. In the later case, `key` must be provided and specifies a function of one argument
  122. that is used to extract a path from each element in iterable.
  123. Patterns are Unix shell-style wildcards which are NOT regular expressions. See
  124. https://docs.python.org/3/library/fnmatch.html for more details.
  125. Args:
  126. items (`Iterable`):
  127. List of items to filter.
  128. allow_patterns (`str` or `List[str]`, *optional*):
  129. Patterns constituting the allowlist. If provided, item paths must match at
  130. least one pattern from the allowlist.
  131. ignore_patterns (`str` or `List[str]`, *optional*):
  132. Patterns constituting the denylist. If provided, item paths must not match
  133. any patterns from the denylist.
  134. key (`Callable[[T], str]`, *optional*):
  135. Single-argument function to extract a path from each item. If not provided,
  136. the `items` must already be `str` or `Path`.
  137. Returns:
  138. Filtered list of objects, as a generator.
  139. Raises:
  140. :class:`ValueError`:
  141. If `key` is not provided and items are not `str` or `Path`.
  142. Example usage with paths:
  143. ```python
  144. >>> # Filter only PDFs that are not hidden.
  145. >>> list(RepoUtils.filter_repo_objects(
  146. ... ["aaa.PDF", "bbb.jpg", ".ccc.pdf", ".ddd.png"],
  147. ... allow_patterns=["*.pdf"],
  148. ... ignore_patterns=[".*"],
  149. ... ))
  150. ["aaa.pdf"]
  151. ```
  152. """
  153. allow_patterns = allow_patterns if allow_patterns else None
  154. ignore_patterns = ignore_patterns if ignore_patterns else None
  155. if isinstance(allow_patterns, str):
  156. allow_patterns = [allow_patterns]
  157. if isinstance(ignore_patterns, str):
  158. ignore_patterns = [ignore_patterns]
  159. if allow_patterns is not None:
  160. allow_patterns = [
  161. RepoUtils._add_wildcard_to_directories(p)
  162. for p in allow_patterns
  163. ]
  164. if ignore_patterns is not None:
  165. ignore_patterns = [
  166. RepoUtils._add_wildcard_to_directories(p)
  167. for p in ignore_patterns
  168. ]
  169. if key is None:
  170. def _identity(item: T) -> str:
  171. if isinstance(item, str):
  172. return item
  173. if isinstance(item, Path):
  174. return str(item)
  175. raise ValueError(
  176. f'Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.'
  177. )
  178. key = _identity # Items must be `str` or `Path`, otherwise raise ValueError
  179. for item in items:
  180. path = key(item)
  181. # Skip if there's an allowlist and path doesn't match any
  182. if allow_patterns is not None and not any(
  183. fnmatch(path, r) for r in allow_patterns):
  184. continue
  185. # Skip if there's a denylist and path matches any
  186. if ignore_patterns is not None and any(
  187. fnmatch(path, r) for r in ignore_patterns):
  188. continue
  189. yield item
  190. @staticmethod
  191. def _add_wildcard_to_directories(pattern: str) -> str:
  192. if pattern[-1] == '/':
  193. return pattern + '*'
  194. return pattern
  195. @dataclass
  196. class CommitInfo:
  197. """Data structure containing information about a newly created commit.
  198. Returned by any method that creates a commit on the Hub: [`create_commit`], [`upload_file`], [`upload_folder`],
  199. [`delete_file`], [`delete_folder`]. It inherits from `str` for backward compatibility but using methods specific
  200. to `str` is deprecated.
  201. Attributes:
  202. commit_url (`str`):
  203. Url where to find the commit.
  204. commit_message (`str`):
  205. The summary (first line) of the commit that has been created.
  206. commit_description (`str`):
  207. Description of the commit that has been created. Can be empty.
  208. oid (`str`):
  209. Commit hash id. Example: `"91c54ad1727ee830252e457677f467be0bfd8a57"`.
  210. """
  211. commit_url: str
  212. commit_message: str
  213. commit_description: str
  214. oid: str
  215. def to_dict(cls):
  216. return {
  217. 'commit_url': cls.commit_url,
  218. 'commit_message': cls.commit_message,
  219. 'commit_description': cls.commit_description,
  220. 'oid': cls.oid,
  221. }
  222. @dataclass
  223. class DetailedCommitInfo:
  224. """Detailed commit information from repository history API."""
  225. id: Optional[str]
  226. short_id: Optional[str]
  227. title: Optional[str]
  228. message: Optional[str]
  229. author_name: Optional[str]
  230. authored_date: Optional[datetime]
  231. author_email: Optional[str]
  232. committed_date: Optional[datetime]
  233. committer_name: Optional[str]
  234. committer_email: Optional[str]
  235. created_at: Optional[datetime]
  236. @classmethod
  237. def from_api_response(cls, data: dict) -> 'DetailedCommitInfo':
  238. """Create DetailedCommitInfo from API response data."""
  239. return cls(
  240. id=data.get('Id', ''),
  241. short_id=data.get('ShortId', ''),
  242. title=data.get('Title', ''),
  243. message=data.get('Message', ''),
  244. author_name=data.get('AuthorName', ''),
  245. authored_date=convert_timestamp(data.get('AuthoredDate', None)),
  246. author_email=data.get('AuthorEmail', ''),
  247. committed_date=convert_timestamp(data.get('CommittedDate', None)),
  248. committer_name=data.get('CommitterName', ''),
  249. committer_email=data.get('CommitterEmail', ''),
  250. created_at=convert_timestamp(data.get('CreatedAt', None)),
  251. )
  252. def to_dict(self) -> dict:
  253. """Convert to dictionary."""
  254. return {
  255. 'id': self.id,
  256. 'short_id': self.short_id,
  257. 'title': self.title,
  258. 'message': self.message,
  259. 'author_name': self.author_name,
  260. 'authored_date': self.authored_date,
  261. 'author_email': self.author_email,
  262. 'committed_date': self.committed_date,
  263. 'committer_name': self.committer_name,
  264. 'committer_email': self.committer_email,
  265. 'created_at': self.created_at,
  266. }
  267. @dataclass
  268. class CommitHistoryResponse:
  269. """Response from commit history API."""
  270. commits: Optional[List[DetailedCommitInfo]]
  271. total_count: Optional[int]
  272. @classmethod
  273. def from_api_response(cls, data: dict) -> 'CommitHistoryResponse':
  274. """Create CommitHistoryResponse from API response data."""
  275. commits_data = data.get('Data', {}).get('Commit', [])
  276. if not commits_data:
  277. return cls(
  278. commits=[],
  279. total_count=0,
  280. )
  281. commits = [
  282. DetailedCommitInfo.from_api_response(commit)
  283. for commit in commits_data
  284. ]
  285. return cls(
  286. commits=commits,
  287. total_count=data.get('TotalCount', 0),
  288. )
  289. @dataclass
  290. class RepoUrl:
  291. url: Optional[str] = None
  292. namespace: Optional[str] = None
  293. repo_name: Optional[str] = None
  294. repo_id: Optional[str] = None
  295. repo_type: Optional[str] = None
  296. endpoint: Optional[str] = DEFAULT_MODELSCOPE_DATA_ENDPOINT
  297. def __repr__(self) -> str:
  298. return f"RepoUrl('{self}', endpoint='{self.endpoint}', repo_type='{self.repo_type}', repo_id='{self.repo_id}')"
  299. def git_hash(data: bytes) -> str:
  300. """
  301. Computes the git-sha1 hash of the given bytes, using the same algorithm as git.
  302. This is equivalent to running `git hash-object`. See https://git-scm.com/docs/git-hash-object
  303. for more details.
  304. Note: this method is valid for regular files. For LFS files, the proper git hash is supposed to be computed on the
  305. pointer file content, not the actual file content. However, for simplicity, we directly compare the sha256 of
  306. the LFS file content when we want to compare LFS files.
  307. Args:
  308. data (`bytes`):
  309. The data to compute the git-hash for.
  310. Returns:
  311. `str`: the git-hash of `data` as an hexadecimal string.
  312. """
  313. _kwargs = {'usedforsecurity': False} if sys.version_info >= (3, 9) else {}
  314. sha1 = functools.partial(hashlib.sha1, **_kwargs)
  315. sha = sha1()
  316. sha.update(b'blob ')
  317. sha.update(str(len(data)).encode())
  318. sha.update(b'\0')
  319. sha.update(data)
  320. return sha.hexdigest()
  321. @dataclass
  322. class UploadInfo:
  323. """
  324. Dataclass holding required information to determine whether a blob
  325. should be uploaded to the hub using the LFS protocol or the regular protocol
  326. Args:
  327. sha256 (`str`):
  328. SHA256 hash of the blob
  329. size (`int`):
  330. Size in bytes of the blob
  331. sample (`bytes`):
  332. First 512 bytes of the blob
  333. """
  334. sha256: str
  335. size: int
  336. sample: bytes
  337. @classmethod
  338. def from_path(cls, path: str, file_hash_info: dict = None):
  339. file_hash_info = file_hash_info or get_file_hash(path)
  340. size = file_hash_info['file_size']
  341. sha = file_hash_info['file_hash']
  342. with open(path, 'rb') as f:
  343. sample = f.read(512)
  344. return cls(sha256=sha, size=size, sample=sample)
  345. @classmethod
  346. def from_bytes(cls, data: bytes, file_hash_info: dict = None):
  347. file_hash_info = file_hash_info or get_file_hash(data)
  348. sha = file_hash_info['file_hash']
  349. return cls(size=len(data), sample=data[:512], sha256=sha)
  350. @classmethod
  351. def from_fileobj(cls, fileobj: BinaryIO, file_hash_info: dict = None):
  352. file_hash_info: dict = file_hash_info or get_file_hash(fileobj)
  353. fileobj.seek(0, os.SEEK_SET)
  354. sample = fileobj.read(512)
  355. fileobj.seek(0, os.SEEK_SET)
  356. return cls(
  357. sha256=file_hash_info['file_hash'],
  358. size=file_hash_info['file_size'],
  359. sample=sample)
  360. @dataclass
  361. class CommitOperationAdd:
  362. """Data structure containing information about a file to be added to a commit."""
  363. path_in_repo: str
  364. path_or_fileobj: Union[str, Path, bytes, BinaryIO]
  365. upload_info: UploadInfo = field(init=False, repr=False)
  366. file_hash_info: dict = field(default_factory=dict)
  367. # Internal attributes
  368. # set to "lfs" or "regular" once known
  369. _upload_mode: Optional[UploadMode] = field(
  370. init=False, repr=False, default=None)
  371. # set to True if .gitignore rules prevent the file from being uploaded as LFS
  372. # (server-side check)
  373. _should_ignore: Optional[bool] = field(
  374. init=False, repr=False, default=None)
  375. # set to the remote OID of the file if it has already been uploaded
  376. # useful to determine if a commit will be empty or not
  377. _remote_oid: Optional[str] = field(init=False, repr=False, default=None)
  378. # set to True once the file has been uploaded as LFS
  379. _is_uploaded: bool = field(init=False, repr=False, default=False)
  380. # set to True once the file has been committed
  381. _is_committed: bool = field(init=False, repr=False, default=False)
  382. def __post_init__(self) -> None:
  383. """Validates `path_or_fileobj` and compute `upload_info`."""
  384. self.path_in_repo = _validate_path_in_repo(self.path_in_repo)
  385. # Validate `path_or_fileobj` value
  386. if isinstance(self.path_or_fileobj, Path):
  387. self.path_or_fileobj = str(self.path_or_fileobj)
  388. if isinstance(self.path_or_fileobj, str):
  389. path_or_fileobj = os.path.normpath(
  390. os.path.expanduser(self.path_or_fileobj))
  391. if not os.path.isfile(path_or_fileobj):
  392. raise ValueError(
  393. f"Provided path: '{path_or_fileobj}' is not a file on the local file system"
  394. )
  395. elif not isinstance(self.path_or_fileobj, (io.BufferedIOBase, bytes)):
  396. raise ValueError(
  397. 'path_or_fileobj must be either an instance of str, bytes or'
  398. ' io.BufferedIOBase. If you passed a file-like object, make sure it is'
  399. ' in binary mode.')
  400. if isinstance(self.path_or_fileobj, io.BufferedIOBase):
  401. try:
  402. self.path_or_fileobj.tell()
  403. self.path_or_fileobj.seek(0, os.SEEK_CUR)
  404. except (OSError, AttributeError) as exc:
  405. raise ValueError(
  406. 'path_or_fileobj is a file-like object but does not implement seek() and tell()'
  407. ) from exc
  408. # Compute "upload_info" attribute
  409. if isinstance(self.path_or_fileobj, str):
  410. self.upload_info = UploadInfo.from_path(self.path_or_fileobj,
  411. self.file_hash_info)
  412. elif isinstance(self.path_or_fileobj, bytes):
  413. self.upload_info = UploadInfo.from_bytes(self.path_or_fileobj,
  414. self.file_hash_info)
  415. else:
  416. self.upload_info = UploadInfo.from_fileobj(self.path_or_fileobj,
  417. self.file_hash_info)
  418. @contextmanager
  419. def as_file(self) -> Iterator[BinaryIO]:
  420. """
  421. A context manager that yields a file-like object allowing to read the underlying
  422. data behind `path_or_fileobj`.
  423. """
  424. if isinstance(self.path_or_fileobj, str) or isinstance(
  425. self.path_or_fileobj, Path):
  426. with open(self.path_or_fileobj, 'rb') as file:
  427. yield file
  428. elif isinstance(self.path_or_fileobj, bytes):
  429. yield io.BytesIO(self.path_or_fileobj)
  430. elif isinstance(self.path_or_fileobj, io.BufferedIOBase):
  431. prev_pos = self.path_or_fileobj.tell()
  432. yield self.path_or_fileobj
  433. self.path_or_fileobj.seek(prev_pos, 0)
  434. def b64content(self) -> bytes:
  435. """
  436. The base64-encoded content of `path_or_fileobj`
  437. Returns: `bytes`
  438. """
  439. with self.as_file() as file:
  440. return base64.b64encode(file.read())
  441. @property
  442. def _local_oid(self) -> Optional[str]:
  443. """Return the OID of the local file.
  444. This OID is then compared to `self._remote_oid` to check if the file has changed compared to the remote one.
  445. If the file did not change, we won't upload it again to prevent empty commits.
  446. For LFS files, the OID corresponds to the SHA256 of the file content (used a LFS ref).
  447. For regular files, the OID corresponds to the SHA1 of the file content.
  448. Note: this is slightly different to git OID computation since the oid of an LFS file is usually the git-SHA1
  449. of the pointer file content (not the actual file content). However, using the SHA256 is enough to detect
  450. changes and more convenient client-side.
  451. """
  452. if self._upload_mode is None:
  453. return None
  454. elif self._upload_mode == 'lfs':
  455. return self.upload_info.sha256
  456. else:
  457. # Regular file => compute sha1
  458. # => no need to read by chunk since the file is guaranteed to be <=5MB.
  459. with self.as_file() as file:
  460. return git_hash(file.read())
  461. def _validate_path_in_repo(path_in_repo: str) -> str:
  462. # Validate `path_in_repo` value to prevent a server-side issue
  463. if path_in_repo.startswith('/'):
  464. path_in_repo = path_in_repo[1:]
  465. if path_in_repo == '.' or path_in_repo == '..' or path_in_repo.startswith(
  466. '../'):
  467. raise ValueError(
  468. f"Invalid `path_in_repo` in CommitOperation: '{path_in_repo}'")
  469. if path_in_repo.startswith('./'):
  470. path_in_repo = path_in_repo[2:]
  471. return path_in_repo
  472. CommitOperation = Union[CommitOperationAdd, ]