| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- from __future__ import print_function
- import multiprocessing
- import os
- from datasets.utils.file_utils import hash_url_to_filename
- from modelscope.hub.api import HubApi
- from modelscope.msdatasets.download.download_config import DataDownloadConfig
- from modelscope.utils.config_ds import MS_CACHE_HOME
- from modelscope.utils.constant import (DEFAULT_DATA_ACCELERATION_ENDPOINT,
- MetaDataFields, UploadMode)
- from modelscope.utils.logger import get_logger
# Module-level logger obtained from modelscope's ``get_logger``.
logger = get_logger()

# Keys looked up in the ``oss_config`` mapping by ``OssUtilities._do_init``.
ACCESS_ID = 'AccessId'  # first argument to ``oss2.StsAuth``
ACCESS_SECRET = 'AccessSecret'  # second argument to ``oss2.StsAuth``
SECURITY_TOKEN = 'SecurityToken'  # third argument to ``oss2.StsAuth``
BUCKET = 'Bucket'  # bucket name passed to ``oss2.Bucket``
BACK_DIR = 'BackupDir'  # key prefix tried when an object is not under ``Dir``
DIR = 'Dir'  # primary key prefix for dataset objects
class OssUtilities:
    """Transfer dataset files to and from an OSS bucket.

    Wraps an ``oss2.Bucket`` authenticated with STS credentials and offers
    resumable ``download`` / ``upload`` with a bounded number of attempts.
    Whenever an attempt fails with status 403 the credentials are fetched
    again through ``HubApi`` and the bucket client is rebuilt.
    """

    def __init__(self, oss_config, dataset_name, namespace, revision):
        """Build the bucket client and the resumable-transfer stores.

        Args:
            oss_config: Mapping providing ``AccessId``, ``AccessSecret``,
                ``SecurityToken``, ``Bucket``, ``Dir`` and ``BackupDir``;
                ``Region`` is also read unless the environment variable
                ``ENABLE_DATASET_ACCELERATION`` equals ``'True'``.
            dataset_name: Passed to the hub when credentials are refreshed.
            namespace: Passed to the hub when credentials are refreshed.
            revision: Passed to the hub when credentials are refreshed.
        """
        self._do_init(oss_config=oss_config)

        self.dataset_name = dataset_name
        self.namespace = namespace
        self.revision = revision

        self.resumable_store_root_path = os.path.join(MS_CACHE_HOME,
                                                      'tmp/resumable_store')
        self.num_threads = multiprocessing.cpu_count()
        self.part_size = 1 * 1024 * 1024  # 1 MiB
        self.multipart_threshold = 50 * 1024 * 1024  # 50 MiB
        # Upper bound on attempts made by ``download`` and ``upload``.
        self.max_retries = 3

        # Imported lazily so the module can be imported without oss2.
        import oss2
        self.resumable_store_download = oss2.ResumableDownloadStore(
            root=self.resumable_store_root_path)
        self.resumable_store_upload = oss2.ResumableStore(
            root=self.resumable_store_root_path)
        self.api = HubApi()

    def _do_init(self, oss_config):
        """(Re)create the credentials, endpoint and bucket from ``oss_config``.

        Called from ``__init__`` and again from ``_reload_sts``; every
        attribute set here is overwritten on each call.
        """
        import oss2

        self.key = oss_config[ACCESS_ID]
        self.secret = oss_config[ACCESS_SECRET]
        self.token = oss_config[SECURITY_TOKEN]
        if os.getenv('ENABLE_DATASET_ACCELERATION') == 'True':
            self.endpoint = DEFAULT_DATA_ACCELERATION_ENDPOINT
        else:
            self.endpoint = f"https://{oss_config['Region']}.aliyuncs.com"
        self.bucket_name = oss_config[BUCKET]
        auth = oss2.StsAuth(self.key, self.secret, self.token)
        self.bucket = oss2.Bucket(
            auth, self.endpoint, self.bucket_name, connect_timeout=120)
        self.oss_dir = oss_config[DIR]
        self.oss_backup_dir = oss_config[BACK_DIR]

    def _reload_sts(self):
        """Fetch a fresh access config from the hub and rebuild the client."""
        logger.info('Reloading sts token automatically.')
        oss_config_refresh = self.api.get_dataset_access_config_session(
            dataset_name=self.dataset_name,
            namespace=self.namespace,
            check_cookie=True,
            revision=self.revision)
        self._do_init(oss_config_refresh)

    @staticmethod
    def _object_key(prefix, name):
        """Join ``prefix`` and ``name`` into an object key using ``/``.

        ``os.path.join`` would insert a backslash on Windows; POSIX joining
        keeps the separator fixed on every platform and gives the same
        result as ``os.path.join`` on POSIX systems.
        """
        import posixpath
        return posixpath.join(prefix, name)

    @staticmethod
    def _is_forbidden(error):
        """Return True when ``error`` exposes a ``status`` equal to 403."""
        # getattr also sees a ``status`` defined on the class or as a
        # property, which a lookup in the instance ``__dict__`` would miss.
        return getattr(error, 'status', None) == 403

    @staticmethod
    def _percentage(consumed_bytes, total_bytes):
        """Progress callback: print the integer percent done on one line.

        Prints nothing when ``total_bytes`` is falsy (zero or ``None``).
        """
        if total_bytes:
            rate = int(100 * (float(consumed_bytes) / float(total_bytes)))
            print('\r{0}% '.format(rate), end='', flush=True)

    def download(self, oss_file_name: str,
                 download_config: DataDownloadConfig):
        """Download one object into ``download_config.cache_dir``.

        The object is looked up under ``self.oss_dir`` and, when it does
        not exist there, under ``self.oss_backup_dir``. The existence check
        is skipped when the split's meta args set the big-data flag. The
        transfer is skipped when the local file already exists and
        ``download_config.force_download`` is falsy.

        Args:
            oss_file_name: Object name relative to the dataset directory.
            download_config: Supplies ``cache_dir``, ``split``,
                ``meta_args_map`` and ``force_download``.

        Returns:
            Path of the local file inside ``cache_dir``.

        Raises:
            Exception: The last error, once ``self.max_retries`` attempts
                have failed.
        """
        import oss2

        cache_dir = download_config.cache_dir
        candidate_key = self._object_key(self.oss_dir, oss_file_name)
        candidate_key_backup = self._object_key(self.oss_backup_dir,
                                                oss_file_name)

        split = download_config.split
        big_data = False
        if split:
            args_dict = download_config.meta_args_map.get(split)
            if args_dict:
                big_data = args_dict.get(MetaDataFields.ARGS_BIG_DATA)

        retry_count = 0
        while True:
            try:
                retry_count += 1
                # big_data is True when the dataset contains a large number
                # of objects; the per-object existence probe is skipped.
                if big_data:
                    file_oss_key = candidate_key
                elif self.bucket.object_exists(candidate_key):
                    file_oss_key = candidate_key
                else:
                    file_oss_key = candidate_key_backup

                filename = hash_url_to_filename(file_oss_key, etag=None)
                local_path = os.path.join(cache_dir, filename)

                if download_config.force_download or not os.path.exists(
                        local_path):
                    oss2.resumable_download(
                        self.bucket,
                        file_oss_key,
                        local_path,
                        store=self.resumable_store_download,
                        multiget_threshold=self.multipart_threshold,
                        part_size=self.part_size,
                        progress_callback=self._percentage,
                        num_threads=self.num_threads)
                break
            except Exception as e:
                # A 403 is treated as expired credentials: refresh them
                # so the next attempt (or next call) uses a new token.
                if self._is_forbidden(e):
                    self._reload_sts()
                if retry_count >= self.max_retries:
                    logger.warning(f'Failed to download {oss_file_name}')
                    raise

        return local_path

    def upload(self, oss_object_name: str, local_file_path: str,
               indicate_individual_progress: bool,
               upload_mode: UploadMode) -> str:
        """Upload a local file to ``self.oss_dir`` in the bucket.

        Args:
            oss_object_name: Object name relative to ``self.oss_dir``.
            local_file_path: Path of the file to upload.
            indicate_individual_progress: When truthy, print the progress
                of this transfer via ``_percentage``.
            upload_mode: With ``UploadMode.APPEND`` an object that already
                exists is left untouched.

        Returns:
            The object key, whether or not the upload was skipped.

        Raises:
            Exception: The last error, once ``self.max_retries`` attempts
                have failed.
        """
        import oss2

        retry_count = 0
        object_key = self._object_key(self.oss_dir, oss_object_name)
        progress_callback = (
            self._percentage if indicate_individual_progress else None)

        while True:
            try:
                retry_count += 1
                exist = self.bucket.object_exists(object_key)
                if upload_mode == UploadMode.APPEND and exist:
                    logger.info(
                        f'Skip {oss_object_name} in case of {upload_mode.value} mode.'
                    )
                    break

                oss2.resumable_upload(
                    self.bucket,
                    object_key,
                    local_file_path,
                    store=self.resumable_store_upload,
                    multipart_threshold=self.multipart_threshold,
                    part_size=self.part_size,
                    progress_callback=progress_callback,
                    num_threads=self.num_threads)
                break
            except Exception as e:
                # A 403 is treated as expired credentials: refresh them
                # so the next attempt (or next call) uses a new token.
                if self._is_forbidden(e):
                    self._reload_sts()
                if retry_count >= self.max_retries:
                    logger.warning(f'Failed to upload {oss_object_name}')
                    raise

        return object_key
|