dataset_builder.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from typing import Dict, Union
  4. import datasets
  5. import pandas as pd
  6. import pyarrow as pa
  7. from datasets import (ArrowBasedBuilder, Dataset, DatasetDict,
  8. GeneratorBasedBuilder, IterableDataset,
  9. IterableDatasetDict)
  10. from datasets.filesystems import is_remote_filesystem
  11. from datasets.info import DatasetInfo
  12. from datasets.naming import camelcase_to_snakecase
  13. from datasets.packaged_modules import csv
  14. from datasets.utils.filelock import FileLock
  15. from datasets.utils.py_utils import map_nested
  16. from modelscope.hub.api import HubApi
  17. from modelscope.msdatasets.context.dataset_context_config import \
  18. DatasetContextConfig
  19. from modelscope.msdatasets.dataset_cls import (ExternalDataset,
  20. NativeIterableDataset)
  21. from modelscope.msdatasets.download.download_manager import \
  22. DataStreamingDownloadManager
  23. from modelscope.msdatasets.utils.dataset_utils import \
  24. get_subdir_hash_from_split
  25. from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE,
  26. DatasetPathName, DownloadMode)
  27. from modelscope.utils.logger import get_logger
  28. logger = get_logger()
  29. DELIMITER_NAME = 'delimiter'
  30. DEFAULT_CSV_DELIMITER = ','
  31. class CsvDatasetBuilder(csv.Csv):
  32. def __init__(self, dataset_context_config: DatasetContextConfig):
  33. # Init config args
  34. self.dataset_name = dataset_context_config.dataset_name
  35. self.cache_root_dir = dataset_context_config.cache_root_dir
  36. self.namespace = dataset_context_config.namespace
  37. self.version = dataset_context_config.version
  38. self.subset_name = dataset_context_config.subset_name
  39. self.split = dataset_context_config.split
  40. self.meta_data_files = dataset_context_config.data_meta_config.meta_data_files
  41. self.zip_data_files = dataset_context_config.data_meta_config.zip_data_files
  42. self.input_config_kwargs = dataset_context_config.config_kwargs
  43. self.split_path_dict = dict({})
  44. self.cache_build_dir = os.path.join(self.cache_root_dir,
  45. self.namespace, self.dataset_name,
  46. self.version,
  47. DatasetPathName.META_NAME)
  48. self.csv_delimiter = DEFAULT_CSV_DELIMITER
  49. if DELIMITER_NAME in self.input_config_kwargs:
  50. self.csv_delimiter = self.input_config_kwargs[DELIMITER_NAME]
  51. split = self.split or list(dataset_context_config.data_meta_config.
  52. target_dataset_structure.keys())
  53. sub_dir_hash = get_subdir_hash_from_split(
  54. split=split, version=self.version)
  55. from datasets.data_files import DataFilesDict, DataFilesList
  56. data_files = {
  57. k: DataFilesList([v], origin_metadata=None)
  58. for k, v in self.meta_data_files.items()
  59. }
  60. data_files = DataFilesDict.from_local_or_remote(data_files)
  61. super().__init__(
  62. cache_dir=self.cache_build_dir,
  63. config_name=self.namespace,
  64. hash=sub_dir_hash,
  65. data_files=data_files,
  66. **self.input_config_kwargs)
  67. self.info.builder_name = self.dataset_name
  68. self.name = camelcase_to_snakecase(self.dataset_name)
  69. self.local_meta_csv_paths: dict = dict({})
  70. def _build_cache_dir(self, namespace=DEFAULT_DATASET_NAMESPACE):
  71. builder_data_dir = os.path.join(
  72. self._cache_dir_root,
  73. self._relative_data_dir(
  74. with_version=False, with_hash=True, namespace=namespace))
  75. return builder_data_dir
  76. def _relative_data_dir(self,
  77. with_version=True,
  78. with_hash=True,
  79. namespace=DEFAULT_DATASET_NAMESPACE) -> str:
  80. """Relative path of this dataset in cache_dir:
  81. Will be:
  82. self.name/self.config.version/self.hash/
  83. or if a namespace has been specified:
  84. self.namespace___self.name/self.config.version/self.hash/
  85. """
  86. builder_data_dir = self.info.builder_name if namespace is None else f'{namespace}___{self.info.builder_name}'
  87. builder_config = self.config
  88. hash = self.hash
  89. if builder_config:
  90. builder_data_dir = os.path.join(builder_data_dir, self.config_id)
  91. if with_version:
  92. builder_data_dir = os.path.join(builder_data_dir,
  93. str(self.config.version))
  94. if with_hash and hash and isinstance(hash, str):
  95. builder_data_dir = os.path.join(builder_data_dir, hash)
  96. return builder_data_dir
  97. def _split_generators(self, dl_manager):
  98. if not self.config.data_files:
  99. raise ValueError(
  100. 'At least one data file must be specified, but got none.')
  101. data_files = dl_manager.download_and_extract(self.config.data_files)
  102. zip_data_files = dl_manager.download_and_extract(self.zip_data_files)
  103. splits = []
  104. for split_name, files in data_files.items():
  105. if isinstance(files, str):
  106. files = [files]
  107. splits.append(
  108. datasets.SplitGenerator(
  109. name=split_name,
  110. gen_kwargs={
  111. 'files': dl_manager.iter_files(files),
  112. 'base_dir': zip_data_files.get(split_name)
  113. }))
  114. return splits
  115. def _generate_tables(self, files, base_dir):
  116. schema = pa.schema(self.config.features.type
  117. ) if self.config.features is not None else None
  118. dtype = {
  119. name: dtype.to_pandas_dtype()
  120. for name, dtype in zip(schema.names, schema.types)
  121. } if schema else None
  122. for file_idx, file in enumerate(files):
  123. csv_file_reader = pd.read_csv(
  124. file, iterator=True, dtype=dtype, delimiter=self.csv_delimiter)
  125. transform_fields = []
  126. for field_name in csv_file_reader._engine.names:
  127. if field_name.endswith(':FILE'):
  128. transform_fields.append(field_name)
  129. try:
  130. for batch_idx, df in enumerate(csv_file_reader):
  131. for field_name in transform_fields:
  132. if base_dir:
  133. df[field_name] = df[field_name].apply(
  134. lambda x: os.path.join(base_dir, x))
  135. pa_table = pa.Table.from_pandas(df, schema=schema)
  136. yield (file_idx, batch_idx), pa_table
  137. except ValueError as e:
  138. logger.error(
  139. f"Failed to read file '{file}' with error {type(e)}: {e}")
  140. raise
  141. def download_and_prepare(self, download_mode, dl_manager,
  142. **download_kwargs):
  143. target_cache_dir = dl_manager.download_config.cache_dir
  144. split_name = dl_manager.download_config.split
  145. if not split_name:
  146. split_name = DatasetPathName.LOCK_FILE_NAME_ANY
  147. version_name = dl_manager.download_config.version
  148. if not version_name:
  149. version_name = DatasetPathName.LOCK_FILE_NAME_ANY
  150. subset_name = self.subset_name
  151. if not subset_name:
  152. subset_name = DatasetPathName.LOCK_FILE_NAME_ANY
  153. # Prevent parallel disk operations
  154. lock_file_names = []
  155. lock_file_names.append(DatasetPathName.DATA_FILES_NAME)
  156. lock_file_names.append(dl_manager.download_config.dataset_name)
  157. lock_file_names.append(version_name)
  158. lock_file_names.append(subset_name)
  159. lock_file_names.append(split_name)
  160. lock_file_name = DatasetPathName.LOCK_FILE_NAME_DELIMITER.join(
  161. lock_file_names)
  162. lock_path = os.path.join(
  163. target_cache_dir.strip(DatasetPathName.DATA_FILES_NAME),
  164. lock_file_name + '.lock')
  165. with FileLock(lock_path):
  166. data_exists = os.path.exists(target_cache_dir)
  167. if data_exists and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS.value:
  168. logger.warning(
  169. f'Reusing dataset {self.name} ({target_cache_dir})')
  170. logger.info(f'Generating dataset {self.name} ({target_cache_dir})')
  171. self._download_and_prepare(
  172. dl_manager=dl_manager, download_mode=download_mode)
  173. def _download_and_prepare(self, dl_manager, download_mode):
  174. import shutil
  175. target_cache_dir = dl_manager.download_config.cache_dir
  176. if download_mode == DownloadMode.FORCE_REDOWNLOAD.value:
  177. shutil.rmtree(target_cache_dir, ignore_errors=True)
  178. os.makedirs(target_cache_dir, exist_ok=True)
  179. self.local_meta_csv_paths = {
  180. k: HubApi.fetch_meta_files_from_url(v, target_cache_dir)
  181. for k, v in self.meta_data_files.items()
  182. }
  183. self.split_path_dict = dl_manager.download_and_extract(
  184. self.zip_data_files)
  185. def _convert_csv_to_dataset(self, split_name, csv_file_path):
  186. df = pd.read_csv(
  187. csv_file_path, iterator=False, delimiter=self.csv_delimiter)
  188. transform_fields = []
  189. for field_name in df.columns.tolist():
  190. if field_name.endswith(':FILE'):
  191. transform_fields.append(field_name)
  192. base_extracted_dir: Union[str, list] = self.split_path_dict.get(
  193. split_name, '')
  194. for field_name in transform_fields:
  195. if isinstance(base_extracted_dir,
  196. list) and len(base_extracted_dir) > 0:
  197. if df.shape[0] != len(base_extracted_dir):
  198. logger.error(
  199. f"Number of lines in meta-csv file for split '{split_name}' ({df.shape[0]}) "
  200. f'does not match number of data-files({len(base_extracted_dir)})!'
  201. )
  202. else:
  203. df[field_name] = base_extracted_dir
  204. elif isinstance(base_extracted_dir, str) and base_extracted_dir:
  205. df[field_name] = df[field_name].apply(
  206. lambda x: os.path.join(base_extracted_dir, x))
  207. else:
  208. logger.warning(f'Nothing to do for field {field_name}')
  209. pa_data = pa.Table.from_pandas(df)
  210. return Dataset(arrow_table=pa_data)
  211. def as_dataset(self) -> DatasetDict:
  212. return DatasetDict({
  213. k: self._convert_csv_to_dataset(k, v)
  214. for k, v in self.local_meta_csv_paths.items()
  215. })
  216. class TaskSpecificDatasetBuilder(CsvDatasetBuilder):
  217. def __init__(self, dataset_context_config: DatasetContextConfig):
  218. # Init args
  219. self.name = dataset_context_config.dataset_name
  220. self.subset_name = dataset_context_config.subset_name
  221. self.namespace = dataset_context_config.namespace
  222. self.split = dataset_context_config.split
  223. self.version = dataset_context_config.version
  224. split = self.split or list(dataset_context_config.data_meta_config.
  225. target_dataset_structure.keys())
  226. self.hash = get_subdir_hash_from_split(
  227. split=split, version=self.version)
  228. self.data_files = dataset_context_config.data_meta_config.meta_data_files
  229. self.zip_data_files = dataset_context_config.data_meta_config.zip_data_files
  230. self.split_path_dict = None
  231. self.config = None
  232. self.info = DatasetInfo.from_dict(
  233. {'builder_name': dataset_context_config.dataset_name})
  234. self._cache_dir_root = os.path.expanduser(
  235. dataset_context_config.cache_root_dir)
  236. self._cache_dir = self._build_cache_dir()
  237. self._config_kwargs = dataset_context_config.data_meta_config.meta_args_map
  238. def download_and_prepare(self, download_mode, dl_manager,
  239. **download_kwargs):
  240. # Prevent parallel disk operations
  241. lock_path = os.path.join(
  242. self._cache_dir_root,
  243. self._cache_dir.replace(os.sep, '_') + '.lock')
  244. with FileLock(lock_path):
  245. data_exists = os.path.exists(self._cache_dir)
  246. if data_exists and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS: # TODO: .value??
  247. logger.warning(
  248. f'Reusing dataset {self.name} ({self._cache_dir})')
  249. return
  250. logger.info(f'Generating dataset {self.name} ({self._cache_dir})')
  251. self._download_and_prepare(dl_manager=dl_manager)
  252. def _download_and_prepare(self, dl_manager):
  253. self.split_path_dict = dl_manager.download_and_extract(
  254. self.zip_data_files)
  255. def as_dataset(self):
  256. return ExternalDataset(self.split_path_dict, self._config_kwargs)
  257. class IterableDatasetBuilder(csv.Csv):
  258. def __init__(self, dataset_context_config: DatasetContextConfig):
  259. # Init config args
  260. self.dataset_name = dataset_context_config.dataset_name
  261. self.cache_root_dir = dataset_context_config.cache_root_dir
  262. self.namespace = dataset_context_config.namespace
  263. self.version = dataset_context_config.version
  264. self.subset_name = dataset_context_config.subset_name
  265. self.split = dataset_context_config.split
  266. self.meta_data_files = dataset_context_config.data_meta_config.meta_data_files
  267. self.zip_data_files = dataset_context_config.data_meta_config.zip_data_files
  268. self.input_config_kwargs = dataset_context_config.config_kwargs
  269. self.stream_batch_size = dataset_context_config.stream_batch_size
  270. self.cache_build_dir = os.path.join(self.cache_root_dir,
  271. self.namespace, self.dataset_name,
  272. self.version,
  273. DatasetPathName.META_NAME)
  274. self.csv_delimiter = DEFAULT_CSV_DELIMITER
  275. if DELIMITER_NAME in self.input_config_kwargs:
  276. self.csv_delimiter = self.input_config_kwargs[DELIMITER_NAME]
  277. split = self.split or list(dataset_context_config.data_meta_config.
  278. target_dataset_structure.keys())
  279. sub_dir_hash = get_subdir_hash_from_split(
  280. split=split, version=self.version)
  281. super().__init__(
  282. cache_dir=self.cache_build_dir,
  283. dataset_name=self.dataset_name,
  284. config_name=self.namespace,
  285. hash=sub_dir_hash,
  286. data_files=None, # TODO: self.meta_data_files,
  287. **self.input_config_kwargs)
  288. self.info.builder_name = self.dataset_name
  289. self.name = camelcase_to_snakecase(self.dataset_name)
  290. self.meta_csv_df = None
  291. self.meta_cache_dir = dataset_context_config.data_meta_config.meta_cache_dir
  292. @staticmethod
  293. def get_builder_instance(
  294. dataset_context_config: DatasetContextConfig) -> csv.Csv:
  295. builder_instance = IterableDatasetBuilder(
  296. dataset_context_config=dataset_context_config)
  297. return builder_instance
  298. def as_streaming_dataset(
  299. self, dl_manager: DataStreamingDownloadManager
  300. ) -> Union[Dict[str, IterableDataset], IterableDataset]:
  301. if not isinstance(self, (GeneratorBasedBuilder, ArrowBasedBuilder)):
  302. raise ValueError(f'Builder {self.name} is not streamable.')
  303. is_local = not is_remote_filesystem(self._fs)
  304. if not is_local:
  305. raise NotImplementedError(
  306. f'Loading a streaming dataset cached in a {type(self._fs).__name__} is not supported yet.'
  307. )
  308. self._check_manual_download(dl_manager)
  309. splits_generators = {
  310. sg.name: sg
  311. for sg in self._split_generators(dl_manager)
  312. }
  313. # By default, return all splits
  314. split = dl_manager.download_config.split
  315. if split is None:
  316. splits_generator = splits_generators
  317. elif split in splits_generators:
  318. splits_generator = splits_generators[split]
  319. else:
  320. raise ValueError(
  321. f'Bad split: {split}. Available splits: {list(splits_generators)}'
  322. )
  323. # Create a dataset for each of the given splits
  324. streaming_datasets = map_nested(
  325. self._as_streaming_dataset_single,
  326. splits_generator,
  327. map_tuple=True,
  328. )
  329. if isinstance(streaming_datasets, dict):
  330. streaming_datasets = IterableDatasetDict(streaming_datasets)
  331. return streaming_datasets
  332. def _split_generators(self, dl_manager: DataStreamingDownloadManager):
  333. splits = []
  334. meta_data_file = ''
  335. zip_data_file = ''
  336. if self.meta_data_files:
  337. meta_data_file = next(iter(self.meta_data_files.values()))
  338. if self.zip_data_files:
  339. zip_data_file = next(iter(self.zip_data_files.values()))
  340. if meta_data_file and not zip_data_file:
  341. for split_name, meta_file_url in self.meta_data_files.items():
  342. splits.append(
  343. datasets.SplitGenerator(
  344. name=split_name,
  345. gen_kwargs={
  346. 'meta': meta_file_url,
  347. 'files': [],
  348. 'dl_manager': dl_manager,
  349. }))
  350. elif meta_data_file and zip_data_file:
  351. for split_name, files in self.zip_data_files.items():
  352. if isinstance(files, str):
  353. files = [files]
  354. meta_file_url = self.meta_data_files.get(split_name)
  355. splits.append(
  356. datasets.SplitGenerator(
  357. name=split_name,
  358. gen_kwargs={
  359. 'meta': meta_file_url,
  360. 'files': files,
  361. 'dl_manager': dl_manager,
  362. }))
  363. elif not meta_data_file and zip_data_file:
  364. for split_name, files in self.zip_data_files.items():
  365. if isinstance(files, str):
  366. files = [files]
  367. splits.append(
  368. datasets.SplitGenerator(
  369. name=split_name,
  370. gen_kwargs={
  371. 'meta': '',
  372. 'files': files,
  373. 'dl_manager': dl_manager,
  374. }))
  375. else:
  376. raise f'Neither column meta nor data file found in {self.dataset_name}.json, specify at least one column.'
  377. return splits
  378. def _as_streaming_dataset_single(
  379. self,
  380. splits_generator,
  381. ) -> NativeIterableDataset:
  382. ex_iterable = self._get_examples_iterable_for_split(splits_generator)
  383. return NativeIterableDataset(
  384. ex_iterable,
  385. info=self.info,
  386. split=splits_generator.name,
  387. stream_batch_size=self.stream_batch_size)
  388. def _generate_tables(self, **gen_kwargs):
  389. meta_file_url = gen_kwargs.get('meta')
  390. files = gen_kwargs.get('files')
  391. dl_manager = gen_kwargs.get('dl_manager')
  392. hub_api = HubApi()
  393. is_zip = False
  394. zip_file_name = ''
  395. if files:
  396. zip_file = str(next(iter(files)))
  397. if zip_file.endswith('.zip'):
  398. is_zip = True
  399. zip_file_name = os.path.splitext(zip_file)[0]
  400. if meta_file_url and not files:
  401. self._get_meta_csv_df(meta_file_url)
  402. pa_table = pa.Table.from_pandas(self.meta_csv_df)
  403. yield 0, pa_table
  404. elif meta_file_url and files:
  405. # Get meta file
  406. self._get_meta_csv_df(meta_file_url)
  407. if is_zip:
  408. oss_config_for_unzipped = hub_api.get_dataset_access_config_for_unzipped(
  409. self.dataset_name, self.namespace, self.version,
  410. zip_file_name)
  411. dl_manager.download_config.oss_config = oss_config_for_unzipped
  412. pa_table = pa.Table.from_pandas(self.meta_csv_df)
  413. yield 0, pa_table
  414. elif not meta_file_url and files:
  415. pa_table = pa.Table.from_pydict({'Input:FILE': files})
  416. yield 0, pa_table
  417. else:
  418. raise f'Neither column meta nor data file found in {self.dataset_name}.json .'
  419. def _get_meta_csv_df(self, meta_file_url: str) -> None:
  420. if self.meta_csv_df is None or self.meta_csv_df.empty:
  421. meta_csv_file_path = HubApi.fetch_meta_files_from_url(
  422. meta_file_url, self.meta_cache_dir)
  423. self.meta_csv_df = pd.read_csv(
  424. meta_csv_file_path,
  425. iterator=False,
  426. delimiter=self.csv_delimiter)
  427. @staticmethod
  428. def trans_data_to_mapping(headers: str, texts: list, delimiter: str):
  429. res = {}
  430. headers = headers.split(delimiter)
  431. for idx in range(0, len(headers)):
  432. col_list = []
  433. for line in texts:
  434. col_list.append(line.split(delimiter)[idx])
  435. res[headers[idx]] = col_list
  436. return res