upload_utils.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from multiprocessing.dummy import Pool as ThreadPool
  4. from tqdm.auto import tqdm
  5. from modelscope.msdatasets.utils.oss_utils import OssUtilities
  6. from modelscope.utils.constant import UploadMode
  7. class DatasetUploadManager(object):
  8. def __init__(self, dataset_name: str, namespace: str, version: str):
  9. from modelscope.hub.api import HubApi
  10. _hub_api = HubApi()
  11. _oss_config = _hub_api.get_dataset_access_config_session(
  12. dataset_name=dataset_name,
  13. namespace=namespace,
  14. check_cookie=False,
  15. revision=version)
  16. self.oss_utilities = OssUtilities(
  17. oss_config=_oss_config,
  18. dataset_name=dataset_name,
  19. namespace=namespace,
  20. revision=version)
  21. def upload(self, object_name: str, local_file_path: str,
  22. upload_mode: UploadMode) -> str:
  23. object_key = self.oss_utilities.upload(
  24. oss_object_name=object_name,
  25. local_file_path=local_file_path,
  26. indicate_individual_progress=True,
  27. upload_mode=upload_mode)
  28. return object_key
  29. def upload_dir(self, object_dir_name: str, local_dir_path: str,
  30. num_processes: int, chunksize: int,
  31. filter_hidden_files: bool, upload_mode: UploadMode) -> int:
  32. def run_upload(args):
  33. self.oss_utilities.upload(
  34. oss_object_name=args[0],
  35. local_file_path=args[1],
  36. indicate_individual_progress=False,
  37. upload_mode=upload_mode)
  38. files_list = []
  39. for root, dirs, files in os.walk(local_dir_path):
  40. for file_name in files:
  41. if filter_hidden_files and file_name.startswith('.'):
  42. continue
  43. # Concatenate directory name and relative path into oss object key. e.g., train/001/1_1230.png
  44. object_name = os.path.join(
  45. object_dir_name,
  46. root.replace(local_dir_path, '', 1).strip('/'), file_name)
  47. local_file_path = os.path.join(root, file_name)
  48. files_list.append((object_name, local_file_path))
  49. with ThreadPool(processes=num_processes) as pool:
  50. result = list(
  51. tqdm(
  52. pool.imap(run_upload, files_list, chunksize=chunksize),
  53. total=len(files_list)))
  54. return len(result)