# data_loader.py (13 KB)
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from abc import ABC, abstractmethod
  4. from typing import Optional, Union
  5. from datasets import (Dataset, DatasetBuilder, DatasetDict, IterableDataset,
  6. IterableDatasetDict)
  7. from datasets import load_dataset as hf_load_dataset
  8. from modelscope.hub.api import ModelScopeConfig
  9. from modelscope.msdatasets.auth.auth_config import OssAuthConfig
  10. from modelscope.msdatasets.context.dataset_context_config import \
  11. DatasetContextConfig
  12. from modelscope.msdatasets.data_files.data_files_manager import \
  13. DataFilesManager
  14. from modelscope.msdatasets.dataset_cls import ExternalDataset
  15. from modelscope.msdatasets.meta.data_meta_manager import DataMetaManager
  16. from modelscope.utils.constant import (DatasetFormations, DatasetPathName,
  17. DownloadMode, VirgoDatasetConfig)
  18. from modelscope.utils.logger import get_logger
  19. from modelscope.utils.url_utils import valid_url
# Module-level logger used by the downloader classes defined below.
logger = get_logger()
  21. class BaseDownloader(ABC):
  22. """Base dataset downloader to load data."""
  23. def __init__(self, dataset_context_config: DatasetContextConfig):
  24. self.dataset_context_config = dataset_context_config
  25. @abstractmethod
  26. def process(self):
  27. """The entity processing pipeline for fetching the data. """
  28. raise NotImplementedError(
  29. f'No default implementation provided for {BaseDownloader.__name__}.process.'
  30. )
  31. @abstractmethod
  32. def _authorize(self):
  33. raise NotImplementedError(
  34. f'No default implementation provided for {BaseDownloader.__name__}._authorize.'
  35. )
  36. @abstractmethod
  37. def _build(self):
  38. raise NotImplementedError(
  39. f'No default implementation provided for {BaseDownloader.__name__}._build.'
  40. )
  41. @abstractmethod
  42. def _prepare_and_download(self):
  43. raise NotImplementedError(
  44. f'No default implementation provided for {BaseDownloader.__name__}._prepare_and_download.'
  45. )
  46. @abstractmethod
  47. def _post_process(self):
  48. raise NotImplementedError(
  49. f'No default implementation provided for {BaseDownloader.__name__}._post_process.'
  50. )
  51. class OssDownloader(BaseDownloader):
  52. def __init__(self, dataset_context_config: DatasetContextConfig):
  53. super().__init__(dataset_context_config)
  54. self.data_files_builder: Optional[DataFilesManager] = None
  55. self.dataset: Optional[Union[Dataset, IterableDataset, DatasetDict,
  56. IterableDatasetDict,
  57. ExternalDataset]] = None
  58. self.builder: Optional[DatasetBuilder] = None
  59. self.data_files_manager: Optional[DataFilesManager] = None
  60. def process(self) -> None:
  61. """ Sequential data fetching process: authorize -> build -> prepare_and_download -> post_process,
  62. to keep dataset_context_config updated. """
  63. self._authorize()
  64. self._build()
  65. self._prepare_and_download()
  66. self._post_process()
  67. def _authorize(self) -> None:
  68. """ Authorization of target dataset.
  69. Get credentials from cache and send to the modelscope-hub in the future. """
  70. cookies = ModelScopeConfig.get_cookies()
  71. git_token = ModelScopeConfig.get_token()
  72. user_info = ModelScopeConfig.get_user_info()
  73. if not self.dataset_context_config.auth_config:
  74. auth_config = OssAuthConfig(
  75. cookies=cookies, git_token=git_token, user_info=user_info)
  76. else:
  77. auth_config = self.dataset_context_config.auth_config
  78. auth_config.cookies = cookies
  79. auth_config.git_token = git_token
  80. auth_config.user_info = user_info
  81. self.dataset_context_config.auth_config = auth_config
  82. def _build(self) -> None:
  83. """ Sequential data files building process: build_meta -> build_data_files , to keep context_config updated. """
  84. # Build meta data
  85. meta_manager = DataMetaManager(self.dataset_context_config)
  86. meta_manager.fetch_meta_files()
  87. meta_manager.parse_dataset_structure()
  88. self.dataset_context_config = meta_manager.dataset_context_config
  89. # Build data-files manager
  90. self.data_files_manager = DataFilesManager(
  91. dataset_context_config=self.dataset_context_config)
  92. self.builder = self.data_files_manager.get_data_files_builder()
  93. def _prepare_and_download(self) -> None:
  94. """ Fetch data-files from modelscope dataset-hub. """
  95. dataset_py_script = self.dataset_context_config.data_meta_config.dataset_py_script
  96. dataset_formation = self.dataset_context_config.data_meta_config.dataset_formation
  97. dataset_name = self.dataset_context_config.dataset_name
  98. subset_name = self.dataset_context_config.subset_name
  99. version = self.dataset_context_config.version
  100. split = self.dataset_context_config.split
  101. data_dir = self.dataset_context_config.data_dir
  102. data_files = self.dataset_context_config.data_files
  103. cache_dir = self.dataset_context_config.cache_root_dir
  104. download_mode = self.dataset_context_config.download_mode
  105. input_kwargs = self.dataset_context_config.config_kwargs
  106. trust_remote_code = self.dataset_context_config.trust_remote_code
  107. if self.builder is None and not dataset_py_script:
  108. raise f'meta-file: {dataset_name}.py not found on the modelscope hub.'
  109. if dataset_py_script and dataset_formation == DatasetFormations.hf_compatible:
  110. if trust_remote_code:
  111. logger.warning(
  112. f'Use trust_remote_code=True. Will invoke codes from {dataset_name}. Please make '
  113. 'sure that you can trust the external codes.')
  114. self.dataset = hf_load_dataset(
  115. dataset_py_script,
  116. name=subset_name,
  117. revision=version,
  118. split=split,
  119. data_dir=data_dir,
  120. data_files=data_files,
  121. cache_dir=cache_dir,
  122. download_mode=download_mode.value,
  123. trust_remote_code=trust_remote_code,
  124. **input_kwargs)
  125. else:
  126. self.dataset = self.data_files_manager.fetch_data_files(
  127. self.builder)
  128. def _post_process(self) -> None:
  129. if isinstance(self.dataset, ExternalDataset):
  130. self.dataset.custom_map = self.dataset_context_config.data_meta_config.meta_type_map
  131. class VirgoDownloader(BaseDownloader):
  132. """Data downloader for Virgo data source."""
  133. def __init__(self, dataset_context_config: DatasetContextConfig):
  134. super().__init__(dataset_context_config)
  135. self.dataset = None
  136. def process(self):
  137. """
  138. Sequential data fetching virgo dataset process: authorize -> build -> prepare_and_download -> post_process
  139. """
  140. self._authorize()
  141. self._build()
  142. self._prepare_and_download()
  143. self._post_process()
  144. def _authorize(self):
  145. """Authorization of virgo dataset."""
  146. from modelscope.msdatasets.auth.auth_config import VirgoAuthConfig
  147. cookies = ModelScopeConfig.get_cookies()
  148. user_info = ModelScopeConfig.get_user_info()
  149. if not self.dataset_context_config.auth_config:
  150. auth_config = VirgoAuthConfig(
  151. cookies=cookies, git_token='', user_info=user_info)
  152. else:
  153. auth_config = self.dataset_context_config.auth_config
  154. auth_config.cookies = cookies
  155. auth_config.git_token = ''
  156. auth_config.user_info = user_info
  157. self.dataset_context_config.auth_config = auth_config
  158. def _build(self):
  159. """
  160. Fetch virgo meta and build virgo dataset.
  161. """
  162. from modelscope.msdatasets.dataset_cls.dataset import VirgoDataset
  163. import pandas as pd
  164. meta_manager = DataMetaManager(self.dataset_context_config)
  165. meta_manager.fetch_virgo_meta()
  166. self.dataset_context_config = meta_manager.dataset_context_config
  167. self.dataset = VirgoDataset(
  168. **self.dataset_context_config.config_kwargs)
  169. virgo_cache_dir = os.path.join(
  170. self.dataset_context_config.cache_root_dir,
  171. self.dataset_context_config.namespace,
  172. self.dataset_context_config.dataset_name,
  173. self.dataset_context_config.version)
  174. os.makedirs(
  175. os.path.join(virgo_cache_dir, DatasetPathName.META_NAME),
  176. exist_ok=True)
  177. meta_content_cache_file = os.path.join(virgo_cache_dir,
  178. DatasetPathName.META_NAME,
  179. 'meta_content.csv')
  180. if isinstance(self.dataset.meta, pd.DataFrame):
  181. meta_content_df = self.dataset.meta
  182. meta_content_df.to_csv(meta_content_cache_file, index=False)
  183. self.dataset.meta_content_cache_file = meta_content_cache_file
  184. self.dataset.virgo_cache_dir = virgo_cache_dir
  185. logger.info(
  186. f'Virgo meta content saved to {meta_content_cache_file}')
  187. def _prepare_and_download(self):
  188. """
  189. Fetch data-files from oss-urls in the virgo meta content.
  190. """
  191. download_virgo_files = self.dataset_context_config.config_kwargs.pop(
  192. 'download_virgo_files', '')
  193. if self.dataset.data_type == 0 and download_virgo_files:
  194. import requests
  195. import json
  196. import shutil
  197. from urllib.parse import urlparse
  198. from functools import partial
  199. def download_file(meta_info_val, data_dir):
  200. file_url_list = []
  201. file_path_list = []
  202. try:
  203. meta_info_val = json.loads(meta_info_val)
  204. # get url first, if not exist, try to get inner_url
  205. file_url = meta_info_val.get('url', '')
  206. if file_url:
  207. file_url_list.append(file_url)
  208. else:
  209. tmp_inner_member_list = meta_info_val.get(
  210. 'inner_url', '')
  211. for item in tmp_inner_member_list:
  212. file_url = item.get('url', '')
  213. if file_url:
  214. file_url_list.append(file_url)
  215. for one_file_url in file_url_list:
  216. is_url = valid_url(one_file_url)
  217. if is_url:
  218. url_parse_res = urlparse(file_url)
  219. file_name = os.path.basename(url_parse_res.path)
  220. else:
  221. raise ValueError(f'Unsupported url: {file_url}')
  222. file_path = os.path.join(data_dir, file_name)
  223. file_path_list.append((one_file_url, file_path))
  224. except Exception as e:
  225. logger.error(f'parse virgo meta info error: {e}')
  226. file_path_list = []
  227. for file_url_item, file_path_item in file_path_list:
  228. if file_path_item and not os.path.exists(file_path_item):
  229. logger.info(f'Downloading file to {file_path_item}')
  230. os.makedirs(data_dir, exist_ok=True)
  231. with open(file_path_item, 'wb') as f:
  232. f.write(requests.get(file_url_item).content)
  233. return file_path_list
  234. self.dataset.download_virgo_files = True
  235. download_mode = self.dataset_context_config.download_mode
  236. data_files_dir = os.path.join(self.dataset.virgo_cache_dir,
  237. DatasetPathName.DATA_FILES_NAME)
  238. if download_mode == DownloadMode.FORCE_REDOWNLOAD:
  239. shutil.rmtree(data_files_dir, ignore_errors=True)
  240. from tqdm.auto import tqdm
  241. tqdm.pandas(desc='apply download_file')
  242. self.dataset.meta[
  243. VirgoDatasetConfig.
  244. col_cache_file] = self.dataset.meta.progress_apply(
  245. lambda row: partial(
  246. download_file, data_dir=data_files_dir)(row.meta_info),
  247. axis=1)
  248. def _post_process(self):
  249. ...
  250. class MaxComputeDownloader(BaseDownloader):
  251. """Data downloader for MaxCompute data source."""
  252. # TODO: MaxCompute data source to be supported .
  253. def __init__(self, dataset_context_config: DatasetContextConfig):
  254. super().__init__(dataset_context_config)
  255. self.dataset = None
  256. def process(self):
  257. ...
  258. def _authorize(self):
  259. ...
  260. def _build(self):
  261. ...
  262. def _prepare_and_download(self):
  263. ...
  264. def _post_process(self):
  265. ...