# Copyright (c) Alibaba, Inc. and its affiliates. from __future__ import print_function import multiprocessing import os from datasets.utils.file_utils import hash_url_to_filename from modelscope.hub.api import HubApi from modelscope.msdatasets.download.download_config import DataDownloadConfig from modelscope.utils.config_ds import MS_CACHE_HOME from modelscope.utils.constant import (DEFAULT_DATA_ACCELERATION_ENDPOINT, MetaDataFields, UploadMode) from modelscope.utils.logger import get_logger logger = get_logger() ACCESS_ID = 'AccessId' ACCESS_SECRET = 'AccessSecret' SECURITY_TOKEN = 'SecurityToken' BUCKET = 'Bucket' BACK_DIR = 'BackupDir' DIR = 'Dir' class OssUtilities: def __init__(self, oss_config, dataset_name, namespace, revision): self._do_init(oss_config=oss_config) self.dataset_name = dataset_name self.namespace = namespace self.revision = revision self.resumable_store_root_path = os.path.join(MS_CACHE_HOME, 'tmp/resumable_store') self.num_threads = multiprocessing.cpu_count() self.part_size = 1 * 1024 * 1024 self.multipart_threshold = 50 * 1024 * 1024 self.max_retries = 3 import oss2 self.resumable_store_download = oss2.ResumableDownloadStore( root=self.resumable_store_root_path) self.resumable_store_upload = oss2.ResumableStore( root=self.resumable_store_root_path) self.api = HubApi() def _do_init(self, oss_config): import oss2 self.key = oss_config[ACCESS_ID] self.secret = oss_config[ACCESS_SECRET] self.token = oss_config[SECURITY_TOKEN] if os.getenv('ENABLE_DATASET_ACCELERATION') == 'True': self.endpoint = DEFAULT_DATA_ACCELERATION_ENDPOINT else: self.endpoint = f"https://{oss_config['Region']}.aliyuncs.com" self.bucket_name = oss_config[BUCKET] auth = oss2.StsAuth(self.key, self.secret, self.token) self.bucket = oss2.Bucket( auth, self.endpoint, self.bucket_name, connect_timeout=120) self.oss_dir = oss_config[DIR] self.oss_backup_dir = oss_config[BACK_DIR] def _reload_sts(self): logger.info('Reloading sts token automatically.') oss_config_refresh = self.api.get_dataset_access_config_session( dataset_name=self.dataset_name, namespace=self.namespace, check_cookie=True, revision=self.revision) self._do_init(oss_config_refresh) @staticmethod def _percentage(consumed_bytes, total_bytes): if total_bytes: rate = int(100 * (float(consumed_bytes) / float(total_bytes))) print('\r{0}% '.format(rate), end='', flush=True) def download(self, oss_file_name: str, download_config: DataDownloadConfig): import oss2 cache_dir = download_config.cache_dir candidate_key = os.path.join(self.oss_dir, oss_file_name) candidate_key_backup = os.path.join(self.oss_backup_dir, oss_file_name) split = download_config.split big_data = False if split: args_dict = download_config.meta_args_map.get(split) if args_dict: big_data = args_dict.get(MetaDataFields.ARGS_BIG_DATA) retry_count = 0 while True: try: retry_count += 1 # big_data is True when the dataset contains large number of objects if big_data: file_oss_key = candidate_key else: file_oss_key = candidate_key if self.bucket.object_exists( candidate_key) else candidate_key_backup filename = hash_url_to_filename(file_oss_key, etag=None) local_path = os.path.join(cache_dir, filename) if download_config.force_download or not os.path.exists( local_path): oss2.resumable_download( self.bucket, file_oss_key, local_path, store=self.resumable_store_download, multiget_threshold=self.multipart_threshold, part_size=self.part_size, progress_callback=self._percentage, num_threads=self.num_threads) break except Exception as e: if e.__dict__.get('status') == 403: self._reload_sts() if retry_count >= self.max_retries: logger.warning(f'Failed to download {oss_file_name}') raise e return local_path def upload(self, oss_object_name: str, local_file_path: str, indicate_individual_progress: bool, upload_mode: UploadMode) -> str: import oss2 retry_count = 0 object_key = os.path.join(self.oss_dir, oss_object_name) if indicate_individual_progress: progress_callback = self._percentage else: progress_callback = None while True: try: retry_count += 1 exist = self.bucket.object_exists(object_key) if upload_mode == UploadMode.APPEND and exist: logger.info( f'Skip {oss_object_name} in case of {upload_mode.value} mode.' ) break oss2.resumable_upload( self.bucket, object_key, local_file_path, store=self.resumable_store_upload, multipart_threshold=self.multipart_threshold, part_size=self.part_size, progress_callback=progress_callback, num_threads=self.num_threads) break except Exception as e: if e.__dict__.get('status') == 403: self._reload_sts() if retry_count >= self.max_retries: raise return object_key