# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2022-present, the HuggingFace Inc. team.
import base64
import functools
import hashlib
import io
import os
import sys
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import datetime
from fnmatch import fnmatch
from pathlib import Path
from typing import (Any, BinaryIO, Callable, Generator, Iterable, Iterator,
                    List, Literal, Optional, TypeVar, Union)

from modelscope.hub.constants import DEFAULT_MODELSCOPE_DATA_ENDPOINT
from modelscope.hub.utils.utils import convert_timestamp
from modelscope.utils.file_utils import get_file_hash

T = TypeVar('T')

# Always ignore `.git` and `.cache` folders (the latter covers
# `.cache/modelscope`) in commits
DEFAULT_IGNORE_PATTERNS = [
    '.git',
    '.git/*',
    '*/.git',
    '**/.git/**',
    '.cache',
    '.cache/*',
    '*/.cache',
    '**/.cache/**',
]

# Allowed values for `CommitOperationAdd._upload_mode`.
UploadMode = Literal['lfs', 'normal']

# File suffixes listed for LFS handling in dataset repositories.
DATASET_LFS_SUFFIX = [
    '.7z',
    '.aac',
    '.arrow',
    '.audio',
    '.bmp',
    '.bin',
    '.bz2',
    '.flac',
    '.ftz',
    '.gif',
    '.gz',
    '.h5',
    '.jack',
    '.jpeg',
    '.jpg',
    '.png',
    '.jsonl',
    '.joblib',
    '.lz4',
    '.msgpack',
    '.npy',
    '.npz',
    '.ot',
    '.parquet',
    '.pb',
    '.pickle',
    '.pcm',
    '.pkl',
    '.raw',
    '.rar',
    '.sam',
    '.tar',
    '.tgz',
    '.wasm',
    '.wav',
    '.webm',
    '.webp',
    '.zip',
    '.zst',
    '.tiff',
    '.mp3',
    '.mp4',
    '.ogg',
]

# File suffixes listed for LFS handling in model repositories.
MODEL_LFS_SUFFIX = [
    '.7z',
    '.arrow',
    '.bin',
    '.bz2',
    '.ckpt',
    '.ftz',
    '.gz',
    '.h5',
    '.joblib',
    '.mlmodel',
    '.model',
    '.msgpack',
    '.npy',
    '.npz',
    '.onnx',
    '.ot',
    '.parquet',
    '.pb',
    '.pickle',
    '.pkl',
    '.pt',
    '.pth',
    '.rar',
    '.safetensors',
    '.tar',
    '.tflite',
    '.tgz',
    '.wasm',
    '.xz',
    '.zip',
    '.zst',
]


class RepoUtils:
    """Namespace for static helpers that operate on repository file listings."""

    @staticmethod
    def filter_repo_objects(
        items: Iterable[T],
        *,
        allow_patterns: Optional[Union[List[str], str]] = None,
        ignore_patterns: Optional[Union[List[str], str]] = None,
        key: Optional[Callable[[T], str]] = None,
    ) -> Generator[T, None, None]:
        """Filter repo objects based on an allowlist and a denylist.

        Input must be a list of paths (`str` or `Path`) or a list of arbitrary objects.
        In the latter case, `key` must be provided and specifies a function of one
        argument that is used to extract a path from each element in iterable.

        Patterns are Unix shell-style wildcards which are NOT regular expressions.
        Matching is done with `fnmatch.fnmatch`, so case sensitivity follows the
        host platform. A pattern ending with `/` is treated as a directory and
        matches everything below it. See
        https://docs.python.org/3/library/fnmatch.html for more details.

        Args:
            items (`Iterable`):
                List of items to filter.
            allow_patterns (`str` or `List[str]`, *optional*):
                Patterns constituting the allowlist. If provided, item paths must
                match at least one pattern from the allowlist. An empty value is
                treated as "no allowlist".
            ignore_patterns (`str` or `List[str]`, *optional*):
                Patterns constituting the denylist. If provided, item paths must
                not match any patterns from the denylist. An empty value is
                treated as "no denylist".
            key (`Callable[[T], str]`, *optional*):
                Single-argument function to extract a path from each item. If not
                provided, the `items` must already be `str` or `Path`.

        Returns:
            Filtered list of objects, as a generator.

        Raises:
            :class:`ValueError`:
                If `key` is not provided and items are not `str` or `Path`. Since
                this is a generator, the error is raised while iterating, not at
                call time.

        Example usage with paths:
        ```python
        >>> # Filter only PDFs that are not hidden.
        >>> list(RepoUtils.filter_repo_objects(
        ...     ["aaa.pdf", "bbb.jpg", ".ccc.pdf", ".ddd.png"],
        ...     allow_patterns=["*.pdf"],
        ...     ignore_patterns=[".*"],
        ... ))
        ['aaa.pdf']
        ```
        """
        # Normalize empty values ('' or []) to "no filter".
        allow_patterns = allow_patterns if allow_patterns else None
        ignore_patterns = ignore_patterns if ignore_patterns else None

        # Accept a single pattern given as a bare string.
        if isinstance(allow_patterns, str):
            allow_patterns = [allow_patterns]
        if isinstance(ignore_patterns, str):
            ignore_patterns = [ignore_patterns]

        # Directory patterns ("dir/") must match everything below them.
        if allow_patterns is not None:
            allow_patterns = [
                RepoUtils._add_wildcard_to_directories(p)
                for p in allow_patterns
            ]
        if ignore_patterns is not None:
            ignore_patterns = [
                RepoUtils._add_wildcard_to_directories(p)
                for p in ignore_patterns
            ]

        if key is None:

            # Items must be `str` or `Path`, otherwise raise ValueError
            def _identity(item: T) -> str:
                if isinstance(item, str):
                    return item
                if isinstance(item, Path):
                    return str(item)
                raise ValueError(
                    f'Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.'
                )

            key = _identity

        for item in items:
            path = key(item)

            # Skip if there's an allowlist and path doesn't match any
            if allow_patterns is not None and not any(
                    fnmatch(path, r) for r in allow_patterns):
                continue

            # Skip if there's a denylist and path matches any
            if ignore_patterns is not None and any(
                    fnmatch(path, r) for r in ignore_patterns):
                continue

            yield item

    @staticmethod
    def _add_wildcard_to_directories(pattern: str) -> str:
        """Append `*` to a pattern ending with `/` so it matches the directory content.

        Args:
            pattern (`str`): A shell-style pattern. May be empty.

        Returns:
            `str`: `pattern + '*'` if `pattern` ends with `/`, else `pattern` unchanged.
        """
        # `endswith` (rather than `pattern[-1]`) keeps an empty pattern from
        # raising IndexError.
        if pattern.endswith('/'):
            return pattern + '*'
        return pattern


@dataclass
class CommitInfo:
    """Data structure containing information about a newly created commit.

    Returned by methods that create a commit on the Hub (e.g. `create_commit`,
    `upload_file`, `upload_folder`, `delete_file`, `delete_folder`).

    Attributes:
        commit_url (`str`):
            Url where to find the commit.
        commit_message (`str`):
            The summary (first line) of the commit that has been created.
        commit_description (`str`):
            Description of the commit that has been created. Can be empty.
        oid (`str`):
            Commit hash id. Example: `"91c54ad1727ee830252e457677f467be0bfd8a57"`.
    """
    commit_url: str
    commit_message: str
    commit_description: str
    oid: str

    def to_dict(self) -> dict:
        """Return the four commit attributes as a plain dictionary."""
        return {
            'commit_url': self.commit_url,
            'commit_message': self.commit_message,
            'commit_description': self.commit_description,
            'oid': self.oid,
        }


@dataclass
class DetailedCommitInfo:
    """Detailed commit information from repository history API."""
    id: Optional[str]
    short_id: Optional[str]
    title: Optional[str]
    message: Optional[str]
    author_name: Optional[str]
    authored_date: Optional[datetime]
    author_email: Optional[str]
    committed_date: Optional[datetime]
    committer_name: Optional[str]
    committer_email: Optional[str]
    created_at: Optional[datetime]

    @classmethod
    def from_api_response(cls, data: dict) -> 'DetailedCommitInfo':
        """Create DetailedCommitInfo from API response data.

        Args:
            data (`dict`): One commit entry, keyed by the API's CamelCase names
                (`Id`, `ShortId`, `Title`, ...). Missing text keys default to
                `''`; missing date keys are passed to `convert_timestamp` as
                `None`.

        Returns:
            `DetailedCommitInfo`: The populated instance.
        """
        return cls(
            id=data.get('Id', ''),
            short_id=data.get('ShortId', ''),
            title=data.get('Title', ''),
            message=data.get('Message', ''),
            author_name=data.get('AuthorName', ''),
            authored_date=convert_timestamp(data.get('AuthoredDate', None)),
            author_email=data.get('AuthorEmail', ''),
            committed_date=convert_timestamp(data.get('CommittedDate', None)),
            committer_name=data.get('CommitterName', ''),
            committer_email=data.get('CommitterEmail', ''),
            created_at=convert_timestamp(data.get('CreatedAt', None)),
        )

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return {
            'id': self.id,
            'short_id': self.short_id,
            'title': self.title,
            'message': self.message,
            'author_name': self.author_name,
            'authored_date': self.authored_date,
            'author_email': self.author_email,
            'committed_date': self.committed_date,
            'committer_name': self.committer_name,
            'committer_email': self.committer_email,
            'created_at': self.created_at,
        }


@dataclass
class CommitHistoryResponse:
    """Response from commit history API."""
    commits: Optional[List[DetailedCommitInfo]]
    total_count: Optional[int]

    @classmethod
    def from_api_response(cls, data: dict) -> 'CommitHistoryResponse':
        """Create CommitHistoryResponse from API response data.

        Args:
            data (`dict`): Response payload. Commits are read from
                `data['Data']['Commit']` and the total from `data['TotalCount']`.

        Returns:
            `CommitHistoryResponse`: An empty response (`commits=[]`,
            `total_count=0`) when the payload carries no commits, otherwise the
            parsed commits and the reported total count.
        """
        # `or {}` also covers an explicit `"Data": null`, which would otherwise
        # make the chained `.get` fail with AttributeError.
        commits_data = (data.get('Data') or {}).get('Commit', [])
        if not commits_data:
            return cls(
                commits=[],
                total_count=0,
            )

        commits = [
            DetailedCommitInfo.from_api_response(commit)
            for commit in commits_data
        ]
        return cls(
            commits=commits,
            total_count=data.get('TotalCount', 0),
        )


@dataclass
class RepoUrl:
    """Description of a repository location; every field is optional."""
    url: Optional[str] = None
    namespace: Optional[str] = None
    repo_name: Optional[str] = None
    repo_id: Optional[str] = None
    repo_type: Optional[str] = None
    endpoint: Optional[str] = DEFAULT_MODELSCOPE_DATA_ENDPOINT

    def __repr__(self) -> str:
        # Format `self.url`, not `self`: this class does not define `__str__`,
        # so formatting `self` would fall back to `__repr__` and recurse forever.
        return (
            f"RepoUrl('{self.url}', endpoint='{self.endpoint}', "
            f"repo_type='{self.repo_type}', repo_id='{self.repo_id}')")


def git_hash(data: bytes) -> str:
    """
    Computes the git-sha1 hash of the given bytes, using the same algorithm as git.

    This is equivalent to running `git hash-object`. See https://git-scm.com/docs/git-hash-object
    for more details.

    Note: this method is valid for regular files. For LFS files, the proper git hash is supposed to be computed on the
    pointer file content, not the actual file content. However, for simplicity, we directly compare the sha256 of
    the LFS file content when we want to compare LFS files.

    Args:
        data (`bytes`):
            The data to compute the git-hash for.

    Returns:
        `str`: the git-hash of `data` as an hexadecimal string.
    """
    # `usedforsecurity` is only passed on Python >= 3.9.
    _kwargs = {'usedforsecurity': False} if sys.version_info >= (3, 9) else {}
    sha1 = functools.partial(hashlib.sha1, **_kwargs)
    sha = sha1()
    # Git blob header: b'blob <size>\0' followed by the content.
    sha.update(b'blob ')
    sha.update(str(len(data)).encode())
    sha.update(b'\0')
    sha.update(data)
    return sha.hexdigest()


@dataclass
class UploadInfo:
    """
    Dataclass holding required information to determine whether a blob
    should be uploaded to the hub using the LFS protocol or the regular protocol

    Args:
        sha256 (`str`):
            SHA256 hash of the blob
        size (`int`):
            Size in bytes of the blob
        sample (`bytes`):
            First 512 bytes of the blob
    """
    sha256: str
    size: int
    sample: bytes

    @classmethod
    def from_path(cls, path: str, file_hash_info: Optional[dict] = None):
        """Build an `UploadInfo` from a file on disk.

        Args:
            path (`str`): Path of the file to describe.
            file_hash_info (`dict`, *optional*): Precomputed hash info with keys
                `file_hash` and `file_size`. When empty or `None`, it is computed
                with `get_file_hash(path)`.
        """
        file_hash_info = file_hash_info or get_file_hash(path)
        size = file_hash_info['file_size']
        sha = file_hash_info['file_hash']
        with open(path, 'rb') as f:
            sample = f.read(512)

        return cls(sha256=sha, size=size, sample=sample)

    @classmethod
    def from_bytes(cls, data: bytes, file_hash_info: Optional[dict] = None):
        """Build an `UploadInfo` from in-memory bytes.

        Args:
            data (`bytes`): The blob content. Its length is used as the size.
            file_hash_info (`dict`, *optional*): Precomputed hash info with key
                `file_hash`. When empty or `None`, it is computed with
                `get_file_hash(data)`.
        """
        file_hash_info = file_hash_info or get_file_hash(data)
        sha = file_hash_info['file_hash']
        return cls(size=len(data), sample=data[:512], sha256=sha)

    @classmethod
    def from_fileobj(cls,
                     fileobj: BinaryIO,
                     file_hash_info: Optional[dict] = None):
        """Build an `UploadInfo` from a seekable binary file object.

        The stream is left positioned at its start when this method returns.

        Args:
            fileobj (`BinaryIO`): Seekable binary stream holding the blob.
            file_hash_info (`dict`, *optional*): Precomputed hash info with keys
                `file_hash` and `file_size`. When empty or `None`, it is computed
                with `get_file_hash(fileobj)`.
        """
        file_hash_info = file_hash_info or get_file_hash(fileobj)
        # Rewind before sampling, and again afterwards so the caller reads
        # from the start.
        fileobj.seek(0, os.SEEK_SET)
        sample = fileobj.read(512)
        fileobj.seek(0, os.SEEK_SET)
        return cls(
            sha256=file_hash_info['file_hash'],
            size=file_hash_info['file_size'],
            sample=sample)


@dataclass
class CommitOperationAdd:
    """Data structure containing information about a file to be added to a commit.

    Args:
        path_in_repo (`str`):
            Destination path inside the repository. A leading `/` or `./` is
            stripped; `.`, `..` and paths starting with `../` are rejected.
        path_or_fileobj (`str`, `Path`, `bytes` or `BinaryIO`):
            Local file path, raw content, or a seekable binary file object.
            `Path` values are converted to `str`, and local paths are stored
            user-expanded and normalized.
        file_hash_info (`dict`, *optional*):
            Precomputed hash info forwarded to `UploadInfo`. When empty, the
            hash is computed during initialization.

    Raises:
        :class:`ValueError`:
            If `path_in_repo` is invalid, if a given path is not an existing
            file, or if `path_or_fileobj` has an unsupported type or is a file
            object that does not support `seek()` / `tell()`.
    """

    path_in_repo: str
    path_or_fileobj: Union[str, Path, bytes, BinaryIO]
    upload_info: UploadInfo = field(init=False, repr=False)
    file_hash_info: dict = field(default_factory=dict)

    # Internal attributes

    # set to "lfs" or "normal" once known
    _upload_mode: Optional[UploadMode] = field(
        init=False, repr=False, default=None)

    # set to True if .gitignore rules prevent the file from being uploaded as LFS
    # (server-side check)
    _should_ignore: Optional[bool] = field(
        init=False, repr=False, default=None)

    # set to the remote OID of the file if it has already been uploaded
    # useful to determine if a commit will be empty or not
    _remote_oid: Optional[str] = field(init=False, repr=False, default=None)

    # set to True once the file has been uploaded as LFS
    _is_uploaded: bool = field(init=False, repr=False, default=False)

    # set to True once the file has been committed
    _is_committed: bool = field(init=False, repr=False, default=False)

    def __post_init__(self) -> None:
        """Validates `path_or_fileobj` and compute `upload_info`."""
        self.path_in_repo = _validate_path_in_repo(self.path_in_repo)

        # Validate `path_or_fileobj` value
        if isinstance(self.path_or_fileobj, Path):
            self.path_or_fileobj = str(self.path_or_fileobj)
        if isinstance(self.path_or_fileobj, str):
            path_or_fileobj = os.path.normpath(
                os.path.expanduser(self.path_or_fileobj))
            if not os.path.isfile(path_or_fileobj):
                raise ValueError(
                    f"Provided path: '{path_or_fileobj}' is not a file on the local file system"
                )
            # Keep the path that was actually validated, so that a `~`-prefixed
            # input is also the path that gets hashed and opened later.
            self.path_or_fileobj = path_or_fileobj
        elif not isinstance(self.path_or_fileobj, (io.BufferedIOBase, bytes)):
            raise ValueError(
                'path_or_fileobj must be either an instance of str, bytes or'
                ' io.BufferedIOBase. If you passed a file-like object, make sure it is'
                ' in binary mode.')
        if isinstance(self.path_or_fileobj, io.BufferedIOBase):
            # Probe seekability without moving the stream.
            try:
                self.path_or_fileobj.tell()
                self.path_or_fileobj.seek(0, os.SEEK_CUR)
            except (OSError, AttributeError) as exc:
                raise ValueError(
                    'path_or_fileobj is a file-like object but does not implement seek() and tell()'
                ) from exc

        # Compute "upload_info" attribute
        if isinstance(self.path_or_fileobj, str):
            self.upload_info = UploadInfo.from_path(self.path_or_fileobj,
                                                    self.file_hash_info)
        elif isinstance(self.path_or_fileobj, bytes):
            self.upload_info = UploadInfo.from_bytes(self.path_or_fileobj,
                                                     self.file_hash_info)
        else:
            self.upload_info = UploadInfo.from_fileobj(self.path_or_fileobj,
                                                       self.file_hash_info)

    @contextmanager
    def as_file(self) -> Iterator[BinaryIO]:
        """
        A context manager that yields a file-like object allowing to read the underlying
        data behind `path_or_fileobj`.

        For a path, the file is opened in binary mode and closed on exit. For
        `bytes`, a fresh `io.BytesIO` is yielded. For a file object, the object
        itself is yielded and its position is restored on exit, even if the
        `with` body raises.
        """
        if isinstance(self.path_or_fileobj, (str, Path)):
            with open(self.path_or_fileobj, 'rb') as file:
                yield file
        elif isinstance(self.path_or_fileobj, bytes):
            yield io.BytesIO(self.path_or_fileobj)
        elif isinstance(self.path_or_fileobj, io.BufferedIOBase):
            prev_pos = self.path_or_fileobj.tell()
            try:
                yield self.path_or_fileobj
            finally:
                # Restore the caller's position even when the body raised.
                self.path_or_fileobj.seek(prev_pos, 0)

    def b64content(self) -> bytes:
        """
        The base64-encoded content of `path_or_fileobj`

        Returns: `bytes`
        """
        with self.as_file() as file:
            return base64.b64encode(file.read())

    @property
    def _local_oid(self) -> Optional[str]:
        """Return the OID of the local file.

        This OID is then compared to `self._remote_oid` to check if the file has changed compared to the remote one.
        If the file did not change, we won't upload it again to prevent empty commits.

        For LFS files, the OID corresponds to the SHA256 of the file content (used a LFS ref).
        For regular files, the OID corresponds to the SHA1 of the file content.
        Note: this is slightly different to git OID computation since the oid of an LFS file is usually the git-SHA1
            of the pointer file content (not the actual file content). However, using the SHA256 is enough to detect
            changes and more convenient client-side.

        Returns `None` while `_upload_mode` has not been determined yet.
        """
        if self._upload_mode is None:
            return None
        elif self._upload_mode == 'lfs':
            return self.upload_info.sha256
        else:
            # Regular file => compute sha1
            # => no need to read by chunk since the file is guaranteed to be <=5MB.
            with self.as_file() as file:
                return git_hash(file.read())


def _validate_path_in_repo(path_in_repo: str) -> str:
    """Normalize and validate a destination path inside a repository.

    A single leading `/` and a single leading `./` are stripped. The result must
    not be `.` or `..`, and must not start with `../`.

    Args:
        path_in_repo (`str`): Path as provided by the caller.

    Returns:
        `str`: The normalized path.

    Raises:
        :class:`ValueError`: If the normalized path is `.`, `..` or starts
            with `../`.
    """
    # Validate `path_in_repo` value to prevent a server-side issue
    if path_in_repo.startswith('/'):
        path_in_repo = path_in_repo[1:]
    # Strip `./` BEFORE the traversal check, otherwise inputs such as `./..`
    # or `./../x` would slip through as `..` / `../x`.
    if path_in_repo.startswith('./'):
        path_in_repo = path_in_repo[2:]
    if path_in_repo == '.' or path_in_repo == '..' or path_in_repo.startswith(
            '../'):
        raise ValueError(
            f"Invalid `path_in_repo` in CommitOperation: '{path_in_repo}'")
    return path_in_repo


CommitOperation = Union[CommitOperationAdd, ]