oss_utils.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from __future__ import print_function
  3. import multiprocessing
  4. import os
  5. from datasets.utils.file_utils import hash_url_to_filename
  6. from modelscope.hub.api import HubApi
  7. from modelscope.msdatasets.download.download_config import DataDownloadConfig
  8. from modelscope.utils.config_ds import MS_CACHE_HOME
  9. from modelscope.utils.constant import (DEFAULT_DATA_ACCELERATION_ENDPOINT,
  10. MetaDataFields, UploadMode)
  11. from modelscope.utils.logger import get_logger
  12. logger = get_logger()
  13. ACCESS_ID = 'AccessId'
  14. ACCESS_SECRET = 'AccessSecret'
  15. SECURITY_TOKEN = 'SecurityToken'
  16. BUCKET = 'Bucket'
  17. BACK_DIR = 'BackupDir'
  18. DIR = 'Dir'
  19. class OssUtilities:
  20. def __init__(self, oss_config, dataset_name, namespace, revision):
  21. self._do_init(oss_config=oss_config)
  22. self.dataset_name = dataset_name
  23. self.namespace = namespace
  24. self.revision = revision
  25. self.resumable_store_root_path = os.path.join(MS_CACHE_HOME,
  26. 'tmp/resumable_store')
  27. self.num_threads = multiprocessing.cpu_count()
  28. self.part_size = 1 * 1024 * 1024
  29. self.multipart_threshold = 50 * 1024 * 1024
  30. self.max_retries = 3
  31. import oss2
  32. self.resumable_store_download = oss2.ResumableDownloadStore(
  33. root=self.resumable_store_root_path)
  34. self.resumable_store_upload = oss2.ResumableStore(
  35. root=self.resumable_store_root_path)
  36. self.api = HubApi()
  37. def _do_init(self, oss_config):
  38. import oss2
  39. self.key = oss_config[ACCESS_ID]
  40. self.secret = oss_config[ACCESS_SECRET]
  41. self.token = oss_config[SECURITY_TOKEN]
  42. if os.getenv('ENABLE_DATASET_ACCELERATION') == 'True':
  43. self.endpoint = DEFAULT_DATA_ACCELERATION_ENDPOINT
  44. else:
  45. self.endpoint = f"https://{oss_config['Region']}.aliyuncs.com"
  46. self.bucket_name = oss_config[BUCKET]
  47. auth = oss2.StsAuth(self.key, self.secret, self.token)
  48. self.bucket = oss2.Bucket(
  49. auth, self.endpoint, self.bucket_name, connect_timeout=120)
  50. self.oss_dir = oss_config[DIR]
  51. self.oss_backup_dir = oss_config[BACK_DIR]
  52. def _reload_sts(self):
  53. logger.info('Reloading sts token automatically.')
  54. oss_config_refresh = self.api.get_dataset_access_config_session(
  55. dataset_name=self.dataset_name,
  56. namespace=self.namespace,
  57. check_cookie=True,
  58. revision=self.revision)
  59. self._do_init(oss_config_refresh)
  60. @staticmethod
  61. def _percentage(consumed_bytes, total_bytes):
  62. if total_bytes:
  63. rate = int(100 * (float(consumed_bytes) / float(total_bytes)))
  64. print('\r{0}% '.format(rate), end='', flush=True)
  65. def download(self, oss_file_name: str,
  66. download_config: DataDownloadConfig):
  67. import oss2
  68. cache_dir = download_config.cache_dir
  69. candidate_key = os.path.join(self.oss_dir, oss_file_name)
  70. candidate_key_backup = os.path.join(self.oss_backup_dir, oss_file_name)
  71. split = download_config.split
  72. big_data = False
  73. if split:
  74. args_dict = download_config.meta_args_map.get(split)
  75. if args_dict:
  76. big_data = args_dict.get(MetaDataFields.ARGS_BIG_DATA)
  77. retry_count = 0
  78. while True:
  79. try:
  80. retry_count += 1
  81. # big_data is True when the dataset contains large number of objects
  82. if big_data:
  83. file_oss_key = candidate_key
  84. else:
  85. file_oss_key = candidate_key if self.bucket.object_exists(
  86. candidate_key) else candidate_key_backup
  87. filename = hash_url_to_filename(file_oss_key, etag=None)
  88. local_path = os.path.join(cache_dir, filename)
  89. if download_config.force_download or not os.path.exists(
  90. local_path):
  91. oss2.resumable_download(
  92. self.bucket,
  93. file_oss_key,
  94. local_path,
  95. store=self.resumable_store_download,
  96. multiget_threshold=self.multipart_threshold,
  97. part_size=self.part_size,
  98. progress_callback=self._percentage,
  99. num_threads=self.num_threads)
  100. break
  101. except Exception as e:
  102. if e.__dict__.get('status') == 403:
  103. self._reload_sts()
  104. if retry_count >= self.max_retries:
  105. logger.warning(f'Failed to download {oss_file_name}')
  106. raise e
  107. return local_path
  108. def upload(self, oss_object_name: str, local_file_path: str,
  109. indicate_individual_progress: bool,
  110. upload_mode: UploadMode) -> str:
  111. import oss2
  112. retry_count = 0
  113. object_key = os.path.join(self.oss_dir, oss_object_name)
  114. if indicate_individual_progress:
  115. progress_callback = self._percentage
  116. else:
  117. progress_callback = None
  118. while True:
  119. try:
  120. retry_count += 1
  121. exist = self.bucket.object_exists(object_key)
  122. if upload_mode == UploadMode.APPEND and exist:
  123. logger.info(
  124. f'Skip {oss_object_name} in case of {upload_mode.value} mode.'
  125. )
  126. break
  127. oss2.resumable_upload(
  128. self.bucket,
  129. object_key,
  130. local_file_path,
  131. store=self.resumable_store_upload,
  132. multipart_threshold=self.multipart_threshold,
  133. part_size=self.part_size,
  134. progress_callback=progress_callback,
  135. num_threads=self.num_threads)
  136. break
  137. except Exception as e:
  138. if e.__dict__.get('status') == 403:
  139. self._reload_sts()
  140. if retry_count >= self.max_retries:
  141. raise
  142. return object_key