download_manager.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from datasets.download.download_manager import DownloadManager
  3. from datasets.download.streaming_download_manager import \
  4. StreamingDownloadManager
  5. from datasets.utils.file_utils import cached_path, is_relative_path
  6. from modelscope.msdatasets.download.download_config import DataDownloadConfig
  7. from modelscope.msdatasets.utils.oss_utils import OssUtilities
  8. class DataDownloadManager(DownloadManager):
  9. def __init__(self, download_config: DataDownloadConfig):
  10. super().__init__(
  11. dataset_name=download_config.dataset_name,
  12. data_dir=download_config.data_dir,
  13. download_config=download_config,
  14. record_checksums=True)
  15. def _download(self, url_or_filename: str,
  16. download_config: DataDownloadConfig) -> str:
  17. url_or_filename = str(url_or_filename)
  18. oss_utilities = OssUtilities(
  19. oss_config=download_config.oss_config,
  20. dataset_name=download_config.dataset_name,
  21. namespace=download_config.namespace,
  22. revision=download_config.version)
  23. if is_relative_path(url_or_filename):
  24. # fetch oss files
  25. return oss_utilities.download(
  26. url_or_filename, download_config=download_config)
  27. else:
  28. return cached_path(
  29. url_or_filename, download_config=download_config)
  30. def _download_single(self, url_or_filename: str,
  31. download_config: DataDownloadConfig) -> str:
  32. # Note: _download_single function is available for datasets>=2.19.0
  33. return self._download(url_or_filename, download_config)
  34. class DataStreamingDownloadManager(StreamingDownloadManager):
  35. """The data streaming download manager."""
  36. def __init__(self, download_config: DataDownloadConfig):
  37. super().__init__(
  38. dataset_name=download_config.dataset_name,
  39. data_dir=download_config.data_dir,
  40. download_config=download_config,
  41. base_path=download_config.cache_dir)
  42. def _download(self, url_or_filename: str) -> str:
  43. url_or_filename = str(url_or_filename)
  44. oss_utilities = OssUtilities(
  45. oss_config=self.download_config.oss_config,
  46. dataset_name=self.download_config.dataset_name,
  47. namespace=self.download_config.namespace,
  48. revision=self.download_config.version)
  49. if is_relative_path(url_or_filename):
  50. # fetch oss files
  51. return oss_utilities.download(
  52. url_or_filename, download_config=self.download_config)
  53. else:
  54. return cached_path(
  55. url_or_filename, download_config=self.download_config)
  56. def _download_single(self, url_or_filename: str) -> str:
  57. # Note: _download_single function is available for datasets>=2.19.0
  58. return self._download(url_or_filename)