hf_datasets_util.py 62 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515
  1. # noqa: isort:skip_file, yapf: disable
  2. # Copyright (c) Alibaba, Inc. and its affiliates.
  3. # Copyright 2020 The HuggingFace Datasets Authors and the TensorFlow Datasets Authors.
  4. import importlib
  5. import contextlib
  6. import inspect
  7. import os
  8. import warnings
  9. from dataclasses import dataclass, field, fields
  10. from functools import partial
  11. from pathlib import Path
  12. from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Union, Tuple, Literal, Any, ClassVar
  13. from urllib.parse import urlencode
  14. import requests
  15. from datasets import (BuilderConfig, Dataset, DatasetBuilder, DatasetDict,
  16. DownloadConfig, DownloadManager, DownloadMode, Features,
  17. IterableDataset, IterableDatasetDict, Split,
  18. VerificationMode, Version, config, data_files, LargeList, Sequence as SequenceHf)
  19. from datasets.features import features
  20. from datasets.features.features import _FEATURE_TYPES
  21. from datasets.data_files import (
  22. FILES_TO_IGNORE, DataFilesDict, EmptyDatasetError,
  23. _get_data_files_patterns, _is_inside_unrequested_special_dir,
  24. _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir, sanitize_patterns)
  25. from datasets.download.streaming_download_manager import (
  26. _prepare_path_and_storage_options, xbasename, xjoin)
  27. from datasets.exceptions import DataFilesNotFoundError, DatasetNotFoundError
  28. from datasets.info import DatasetInfosDict
  29. from datasets.load import (
  30. ALL_ALLOWED_EXTENSIONS, BuilderConfigsParameters,
  31. CachedDatasetModuleFactory, DatasetModule,
  32. HubDatasetModuleFactoryWithoutScript,
  33. HubDatasetModuleFactoryWithParquetExport,
  34. HubDatasetModuleFactoryWithScript, LocalDatasetModuleFactoryWithoutScript,
  35. LocalDatasetModuleFactoryWithScript, PackagedDatasetModuleFactory,
  36. create_builder_configs_from_metadata_configs, get_dataset_builder_class,
  37. import_main_class, infer_module_for_data_files, files_to_hash,
  38. _get_importable_file_path, resolve_trust_remote_code, _create_importable_file, _load_importable_file,
  39. init_dynamic_modules)
  40. from datasets.naming import camelcase_to_snakecase
  41. from datasets.packaged_modules import (_EXTENSION_TO_MODULE,
  42. _MODULE_TO_EXTENSIONS,
  43. _PACKAGED_DATASETS_MODULES)
  44. from datasets.utils import file_utils
  45. from datasets.utils.file_utils import (_raise_if_offline_mode_is_enabled,
  46. cached_path, is_local_path,
  47. is_relative_path,
  48. relative_to_absolute_path)
  49. from datasets.utils.info_utils import is_small_dataset
  50. from datasets.utils.metadata import MetadataConfigs
  51. from datasets.utils.py_utils import get_imports
  52. from datasets.utils.track import tracked_str
  53. from fsspec import filesystem
  54. from fsspec.core import _un_chain
  55. from fsspec.utils import stringify_path
  56. from huggingface_hub import (DatasetCard, DatasetCardData)
  57. from huggingface_hub.errors import OfflineModeIsEnabled
  58. from huggingface_hub.hf_api import DatasetInfo as HfDatasetInfo
  59. from huggingface_hub.hf_api import HfApi, RepoFile, RepoFolder
  60. from packaging import version
  61. from modelscope import HubApi
  62. from modelscope.hub.utils.utils import get_endpoint
  63. from modelscope.msdatasets.utils.hf_file_utils import get_from_cache_ms
  64. from modelscope.utils.config_ds import MS_DATASETS_CACHE
  65. from modelscope.utils.constant import DEFAULT_DATASET_REVISION, REPO_TYPE_DATASET
  66. from modelscope.utils.import_utils import has_attr_in_class
  67. from modelscope.utils.logger import get_logger
  68. logger = get_logger()
  69. ExpandDatasetProperty_T = Literal[
  70. 'author',
  71. 'cardData',
  72. 'citation',
  73. 'createdAt',
  74. 'disabled',
  75. 'description',
  76. 'downloads',
  77. 'downloadsAllTime',
  78. 'gated',
  79. 'lastModified',
  80. 'likes',
  81. 'paperswithcode_id',
  82. 'private',
  83. 'siblings',
  84. 'sha',
  85. 'tags',
  86. ]
  87. # Patch datasets features
  88. @dataclass(repr=False)
  89. class ListMs(SequenceHf):
  90. """Feature type for large list data composed of child feature data type.
  91. It is backed by `pyarrow.ListType`, which uses 32-bit offsets or a fixed length.
  92. Args:
  93. feature ([`FeatureType`]):
  94. Child feature data type of each item within the large list.
  95. length (optional `int`, default to -1):
  96. Length of the list if it is fixed.
  97. Defaults to -1 which means an arbitrary length.
  98. """
  99. feature: Any
  100. length: int = -1
  101. id: Optional[str] = field(default=None, repr=False)
  102. # Automatically constructed
  103. pa_type: ClassVar[Any] = None
  104. _type: str = field(default='List', init=False, repr=False)
  105. def __repr__(self):
  106. if self.length != -1:
  107. return f'{type(self).__name__}({self.feature}, length={self.length})'
  108. else:
  109. return f'{type(self).__name__}({self.feature})'
  110. _FEATURE_TYPES['List'] = ListMs
  111. def generate_from_dict_ms(obj: Any):
  112. """Regenerate the nested feature object from a deserialized dict.
  113. We use the '_type' fields to get the dataclass name to load.
  114. generate_from_dict is the recursive helper for Features.from_dict, and allows for a convenient constructor syntax
  115. to define features from deserialized JSON dictionaries. This function is used in particular when deserializing
  116. a :class:`DatasetInfo` that was dumped to a JSON object. This acts as an analogue to
  117. :meth:`Features.from_arrow_schema` and handles the recursive field-by-field instantiation, but doesn't require any
  118. mapping to/from pyarrow, except for the fact that it takes advantage of the mapping of pyarrow primitive dtypes
  119. that :class:`Value` automatically performs.
  120. """
  121. # Nested structures: we allow dict, list/tuples, sequences
  122. if isinstance(obj, list):
  123. return [generate_from_dict_ms(value) for value in obj]
  124. # Otherwise we have a dict or a dataclass
  125. if '_type' not in obj or isinstance(obj['_type'], dict):
  126. return {key: generate_from_dict_ms(value) for key, value in obj.items()}
  127. obj = dict(obj)
  128. _type = obj.pop('_type')
  129. class_type = _FEATURE_TYPES.get(_type, None) or globals().get(_type, None)
  130. if class_type is None:
  131. raise ValueError(f"Feature type '{_type}' not found. Available feature types: {list(_FEATURE_TYPES.keys())}")
  132. if class_type == LargeList:
  133. feature = obj.pop('feature')
  134. return LargeList(generate_from_dict_ms(feature), **obj)
  135. if class_type == ListMs:
  136. feature = obj.pop('feature')
  137. return ListMs(generate_from_dict_ms(feature), **obj)
  138. if class_type == SequenceHf: # backward compatibility, this translates to a List or a dict
  139. feature = obj.pop('feature')
  140. return SequenceHf(feature=generate_from_dict_ms(feature), **obj)
  141. field_names = {f.name for f in fields(class_type)}
  142. return class_type(**{k: v for k, v in obj.items() if k in field_names})
  143. def _download_ms(self, url_or_filename: str, download_config: DownloadConfig) -> str:
  144. url_or_filename = str(url_or_filename)
  145. # for temp val
  146. revision = None
  147. if url_or_filename.startswith('hf://'):
  148. revision, url_or_filename = url_or_filename.split('@', 1)[-1].split('/', 1)
  149. if is_relative_path(url_or_filename):
  150. # append the relative path to the base_path
  151. # url_or_filename = url_or_path_join(self._base_path, url_or_filename)
  152. revision = revision or DEFAULT_DATASET_REVISION
  153. # Note: make sure the FilePath is the last param
  154. params: dict = {'Source': 'SDK', 'Revision': revision, 'FilePath': url_or_filename}
  155. params: str = urlencode(params)
  156. url_or_filename = self._base_path + params
  157. out = cached_path(url_or_filename, download_config=download_config)
  158. out = tracked_str(out)
  159. out.set_origin(url_or_filename)
  160. return out
  161. def _dataset_info(
  162. self,
  163. repo_id: str,
  164. *,
  165. revision: Optional[str] = None,
  166. timeout: Optional[float] = None,
  167. files_metadata: bool = False,
  168. token: Optional[Union[bool, str]] = None,
  169. expand: Optional[List[ExpandDatasetProperty_T]] = None,
  170. ) -> HfDatasetInfo:
  171. """
  172. Get info on one specific dataset on huggingface.co.
  173. Dataset can be private if you pass an acceptable token.
  174. Args:
  175. repo_id (`str`):
  176. A namespace (user or an organization) and a repo name separated
  177. by a `/`.
  178. revision (`str`, *optional*):
  179. The revision of the dataset repository from which to get the
  180. information.
  181. timeout (`float`, *optional*):
  182. Whether to set a timeout for the request to the Hub.
  183. files_metadata (`bool`, *optional*):
  184. Whether or not to retrieve metadata for files in the repository
  185. (size, LFS metadata, etc). Defaults to `False`.
  186. token (`bool` or `str`, *optional*):
  187. A valid authentication token (see https://huggingface.co/settings/token).
  188. If `None` or `True` and machine is logged in (through `huggingface-cli login`
  189. or [`~huggingface_hub.login`]), token will be retrieved from the cache.
  190. If `False`, token is not sent in the request header.
  191. Returns:
  192. [`hf_api.DatasetInfo`]: The dataset repository information.
  193. <Tip>
  194. Raises the following errors:
  195. - [`~utils.RepositoryNotFoundError`]
  196. If the repository to download from cannot be found. This may be because it doesn't exist,
  197. or because it is set to `private` and you do not have access.
  198. - [`~utils.RevisionNotFoundError`]
  199. If the revision to download from cannot be found.
  200. </Tip>
  201. """
  202. # Note: refer to `_list_repo_tree()`, for patching `HfApi.list_repo_tree`
  203. repo_info_iter = self.list_repo_tree(
  204. repo_id=repo_id,
  205. path_in_repo='/',
  206. revision=revision,
  207. recursive=False,
  208. expand=expand,
  209. token=token,
  210. repo_type=REPO_TYPE_DATASET,
  211. )
  212. # Update data_info
  213. data_info = dict({})
  214. data_info['id'] = repo_id
  215. data_info['private'] = False
  216. data_info['author'] = repo_id.split('/')[0] if repo_id else None
  217. data_info['sha'] = revision
  218. data_info['lastModified'] = None
  219. data_info['gated'] = False
  220. data_info['disabled'] = False
  221. data_info['downloads'] = 0
  222. data_info['likes'] = 0
  223. data_info['tags'] = []
  224. data_info['cardData'] = []
  225. data_info['createdAt'] = None
  226. # e.g. {'rfilename': 'xxx', 'blobId': 'xxx', 'size': 0, 'lfs': {'size': 0, 'sha256': 'xxx', 'pointerSize': 0}}
  227. data_siblings = []
  228. for info_item in repo_info_iter:
  229. if isinstance(info_item, RepoFile):
  230. data_siblings.append(
  231. dict(
  232. rfilename=info_item.rfilename,
  233. blobId=info_item.blob_id,
  234. size=info_item.size,
  235. )
  236. )
  237. data_info['siblings'] = data_siblings
  238. return HfDatasetInfo(**data_info)
  239. def _list_repo_tree(
  240. self,
  241. repo_id: str,
  242. path_in_repo: Optional[str] = None,
  243. *,
  244. recursive: bool = True,
  245. expand: bool = False,
  246. revision: Optional[str] = None,
  247. repo_type: Optional[str] = None,
  248. token: Optional[Union[bool, str]] = None,
  249. ) -> Iterable[Union[RepoFile, RepoFolder]]:
  250. _api = HubApi(timeout=3 * 60, max_retries=3)
  251. endpoint = _api.get_endpoint_for_read(
  252. repo_id=repo_id, repo_type=REPO_TYPE_DATASET)
  253. # List all files in the repo
  254. page_number = 1
  255. page_size = 100
  256. while True:
  257. try:
  258. dataset_files = _api.get_dataset_files(
  259. repo_id=repo_id,
  260. revision=revision or DEFAULT_DATASET_REVISION,
  261. root_path=path_in_repo or '/',
  262. recursive=recursive,
  263. page_number=page_number,
  264. page_size=page_size,
  265. endpoint=endpoint,
  266. )
  267. except Exception as e:
  268. logger.error(f'Get dataset: {repo_id} file list failed, message: {e}')
  269. break
  270. for file_info_d in dataset_files:
  271. path_info = {}
  272. path_info['type'] = 'directory' if file_info_d['Type'] == 'tree' else 'file'
  273. path_info['path'] = file_info_d['Path']
  274. path_info['size'] = file_info_d['Size']
  275. path_info['oid'] = file_info_d['Sha256']
  276. yield RepoFile(**path_info) if path_info['type'] == 'file' else RepoFolder(**path_info)
  277. if len(dataset_files) < page_size:
  278. break
  279. page_number += 1
  280. def _get_paths_info(
  281. self,
  282. repo_id: str,
  283. paths: Union[List[str], str],
  284. *,
  285. expand: bool = False,
  286. revision: Optional[str] = None,
  287. repo_type: Optional[str] = None,
  288. token: Optional[Union[bool, str]] = None,
  289. ) -> List[Union[RepoFile, RepoFolder]]:
  290. # Refer to func: `_list_repo_tree()`, for patching `HfApi.list_repo_tree`
  291. repo_info_iter = self.list_repo_tree(
  292. repo_id=repo_id,
  293. recursive=False,
  294. expand=expand,
  295. revision=revision,
  296. repo_type=repo_type,
  297. token=token,
  298. )
  299. return [item_info for item_info in repo_info_iter]
  300. def _download_repo_file(repo_id: str, path_in_repo: str, download_config: DownloadConfig, revision: str):
  301. _api = HubApi()
  302. _namespace, _dataset_name = repo_id.split('/')
  303. endpoint = _api.get_endpoint_for_read(
  304. repo_id=repo_id, repo_type=REPO_TYPE_DATASET)
  305. if download_config and download_config.download_desc is None:
  306. download_config.download_desc = f'Downloading [{path_in_repo}]'
  307. try:
  308. url_or_filename = _api.get_dataset_file_url(
  309. file_name=path_in_repo,
  310. dataset_name=_dataset_name,
  311. namespace=_namespace,
  312. revision=revision,
  313. extension_filter=False,
  314. endpoint=endpoint
  315. )
  316. repo_file_path = cached_path(
  317. url_or_filename=url_or_filename, download_config=download_config)
  318. except FileNotFoundError as e:
  319. repo_file_path = ''
  320. logger.error(e)
  321. return repo_file_path
  322. def get_fs_token_paths(
  323. urlpath,
  324. storage_options=None,
  325. protocol=None,
  326. ):
  327. if isinstance(urlpath, (list, tuple, set)):
  328. if not urlpath:
  329. raise ValueError('empty urlpath sequence')
  330. urlpath0 = stringify_path(list(urlpath)[0])
  331. else:
  332. urlpath0 = stringify_path(urlpath)
  333. storage_options = storage_options or {}
  334. if protocol:
  335. storage_options['protocol'] = protocol
  336. chain = _un_chain(urlpath0, storage_options or {})
  337. inkwargs = {}
  338. # Reverse iterate the chain, creating a nested target_* structure
  339. for i, ch in enumerate(reversed(chain)):
  340. urls, nested_protocol, kw = ch
  341. if i == len(chain) - 1:
  342. inkwargs = dict(**kw, **inkwargs)
  343. continue
  344. inkwargs['target_options'] = dict(**kw, **inkwargs)
  345. inkwargs['target_protocol'] = nested_protocol
  346. inkwargs['fo'] = urls
  347. paths, protocol, _ = chain[0]
  348. fs = filesystem(protocol, **inkwargs)
  349. return fs
  350. def _resolve_pattern(
  351. pattern: str,
  352. base_path: str,
  353. allowed_extensions: Optional[List[str]] = None,
  354. download_config: Optional[DownloadConfig] = None,
  355. ) -> List[str]:
  356. """
  357. Resolve the paths and URLs of the data files from the pattern passed by the user.
  358. You can use patterns to resolve multiple local files. Here are a few examples:
  359. - *.csv to match all the CSV files at the first level
  360. - **.csv to match all the CSV files at any level
  361. - data/* to match all the files inside "data"
  362. - data/** to match all the files inside "data" and its subdirectories
  363. The patterns are resolved using the fsspec glob.
  364. glob.glob, Path.glob, Path.match or fnmatch do not support ** with a prefix/suffix other than a forward slash /.
  365. For instance, this means **.json is the same as *.json. On the contrary, the fsspec glob has no limits regarding the ** prefix/suffix, # noqa: E501
  366. resulting in **.json being equivalent to **/*.json.
  367. More generally:
  368. - '*' matches any character except a forward-slash (to match just the file or directory name)
  369. - '**' matches any character including a forward-slash /
  370. Hidden files and directories (i.e. whose names start with a dot) are ignored, unless they are explicitly requested.
  371. The same applies to special directories that start with a double underscore like "__pycache__".
  372. You can still include one if the pattern explicitly mentions it:
  373. - to include a hidden file: "*/.hidden.txt" or "*/.*"
  374. - to include a hidden directory: ".hidden/*" or ".*/*"
  375. - to include a special directory: "__special__/*" or "__*/*"
  376. Example::
  377. >>> from datasets.data_files import resolve_pattern
  378. >>> base_path = "."
  379. >>> resolve_pattern("docs/**/*.py", base_path)
  380. [/Users/mariosasko/Desktop/projects/datasets/docs/source/_config.py']
  381. Args:
  382. pattern (str): Unix pattern or paths or URLs of the data files to resolve.
  383. The paths can be absolute or relative to base_path.
  384. Remote filesystems using fsspec are supported, e.g. with the hf:// protocol.
  385. base_path (str): Base path to use when resolving relative paths.
  386. allowed_extensions (Optional[list], optional): White-list of file extensions to use. Defaults to None (all extensions).
  387. For example: allowed_extensions=[".csv", ".json", ".txt", ".parquet"]
  388. Returns:
  389. List[str]: List of paths or URLs to the local or remote files that match the patterns.
  390. """
  391. if is_relative_path(pattern):
  392. pattern = xjoin(base_path, pattern)
  393. elif is_local_path(pattern):
  394. base_path = os.path.splitdrive(pattern)[0] + os.sep
  395. else:
  396. base_path = ''
  397. # storage_options: {'hf': {'token': None, 'endpoint': 'https://huggingface.co'}}
  398. pattern, storage_options = _prepare_path_and_storage_options(
  399. pattern, download_config=download_config)
  400. fs = get_fs_token_paths(pattern, storage_options=storage_options)
  401. fs_base_path = base_path.split('::')[0].split('://')[-1] or fs.root_marker
  402. fs_pattern = pattern.split('::')[0].split('://')[-1]
  403. files_to_ignore = set(FILES_TO_IGNORE) - {xbasename(pattern)}
  404. protocol = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[0]
  405. protocol_prefix = protocol + '://' if protocol != 'file' else ''
  406. glob_kwargs = {}
  407. if protocol == 'hf' and config.HF_HUB_VERSION >= version.parse('0.20.0'):
  408. # 10 times faster glob with detail=True (ignores costly info like lastCommit)
  409. glob_kwargs['expand_info'] = False
  410. try:
  411. tmp_file_paths = fs.glob(pattern, detail=True, **glob_kwargs)
  412. except FileNotFoundError:
  413. raise DataFilesNotFoundError(f"Unable to find '{pattern}'")
  414. matched_paths = [
  415. filepath if filepath.startswith(protocol_prefix) else protocol_prefix
  416. + filepath for filepath, info in tmp_file_paths.items()
  417. if info['type'] == 'file' and (
  418. xbasename(filepath) not in files_to_ignore)
  419. and not _is_inside_unrequested_special_dir(
  420. os.path.relpath(filepath, fs_base_path),
  421. os.path.relpath(fs_pattern, fs_base_path)) and # noqa: W504
  422. not _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir( # noqa: W504
  423. os.path.relpath(filepath, fs_base_path),
  424. os.path.relpath(fs_pattern, fs_base_path))
  425. ] # ignore .ipynb and __pycache__, but keep /../
  426. if allowed_extensions is not None:
  427. out = [
  428. filepath for filepath in matched_paths
  429. if any('.' + suffix in allowed_extensions
  430. for suffix in xbasename(filepath).split('.')[1:])
  431. ]
  432. if len(out) < len(matched_paths):
  433. invalid_matched_files = list(set(matched_paths) - set(out))
  434. logger.info(
  435. f"Some files matched the pattern '{pattern}' but don't have valid data file extensions: "
  436. f'{invalid_matched_files}')
  437. else:
  438. out = matched_paths
  439. if not out:
  440. error_msg = f"Unable to find '{pattern}'"
  441. if allowed_extensions is not None:
  442. error_msg += f' with any supported extension {list(allowed_extensions)}'
  443. raise FileNotFoundError(error_msg)
  444. return out
  445. def _get_data_patterns(
  446. base_path: str,
  447. download_config: Optional[DownloadConfig] = None) -> Dict[str,
  448. List[str]]:
  449. """
  450. Get the default pattern from a directory testing all the supported patterns.
  451. The first patterns to return a non-empty list of data files is returned.
  452. Some examples of supported patterns:
  453. Input:
  454. my_dataset_repository/
  455. ├── README.md
  456. └── dataset.csv
  457. Output:
  458. {"train": ["**"]}
  459. Input:
  460. my_dataset_repository/
  461. ├── README.md
  462. ├── train.csv
  463. └── test.csv
  464. my_dataset_repository/
  465. ├── README.md
  466. └── data/
  467. ├── train.csv
  468. └── test.csv
  469. my_dataset_repository/
  470. ├── README.md
  471. ├── train_0.csv
  472. ├── train_1.csv
  473. ├── train_2.csv
  474. ├── train_3.csv
  475. ├── test_0.csv
  476. └── test_1.csv
  477. Output:
  478. {'train': ['train[-._ 0-9/]**', '**/*[-._ 0-9/]train[-._ 0-9/]**',
  479. 'training[-._ 0-9/]**', '**/*[-._ 0-9/]training[-._ 0-9/]**'],
  480. 'test': ['test[-._ 0-9/]**', '**/*[-._ 0-9/]test[-._ 0-9/]**',
  481. 'testing[-._ 0-9/]**', '**/*[-._ 0-9/]testing[-._ 0-9/]**', ...]}
  482. Input:
  483. my_dataset_repository/
  484. ├── README.md
  485. └── data/
  486. ├── train/
  487. │ ├── shard_0.csv
  488. │ ├── shard_1.csv
  489. │ ├── shard_2.csv
  490. │ └── shard_3.csv
  491. └── test/
  492. ├── shard_0.csv
  493. └── shard_1.csv
  494. Output:
  495. {'train': ['train[-._ 0-9/]**', '**/*[-._ 0-9/]train[-._ 0-9/]**',
  496. 'training[-._ 0-9/]**', '**/*[-._ 0-9/]training[-._ 0-9/]**'],
  497. 'test': ['test[-._ 0-9/]**', '**/*[-._ 0-9/]test[-._ 0-9/]**',
  498. 'testing[-._ 0-9/]**', '**/*[-._ 0-9/]testing[-._ 0-9/]**', ...]}
  499. Input:
  500. my_dataset_repository/
  501. ├── README.md
  502. └── data/
  503. ├── train-00000-of-00003.csv
  504. ├── train-00001-of-00003.csv
  505. ├── train-00002-of-00003.csv
  506. ├── test-00000-of-00001.csv
  507. ├── random-00000-of-00003.csv
  508. ├── random-00001-of-00003.csv
  509. └── random-00002-of-00003.csv
  510. Output:
  511. {'train': ['data/train-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*'],
  512. 'test': ['data/test-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*'],
  513. 'random': ['data/random-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*']}
  514. In order, it first tests if SPLIT_PATTERN_SHARDED works, otherwise it tests the patterns in ALL_DEFAULT_PATTERNS.
  515. """
  516. resolver = partial(
  517. _resolve_pattern, base_path=base_path, download_config=download_config)
  518. try:
  519. return _get_data_files_patterns(resolver)
  520. except FileNotFoundError:
  521. raise EmptyDatasetError(
  522. f"The directory at {base_path} doesn't contain any data files"
  523. ) from None
  524. def get_module_without_script(self) -> DatasetModule:
  525. # hfh_dataset_info = HfApi(config.HF_ENDPOINT).dataset_info(
  526. # self.name,
  527. # revision=self.revision,
  528. # token=self.download_config.token,
  529. # timeout=100.0,
  530. # )
  531. # even if metadata_configs is not None (which means that we will resolve files for each config later)
  532. # we cannot skip resolving all files because we need to infer module name by files extensions
  533. # revision = hfh_dataset_info.sha # fix the revision in case there are new commits in the meantime
  534. revision = self.download_config.storage_options.get('revision', None) or DEFAULT_DATASET_REVISION
  535. base_path = f"hf://datasets/{self.name}@{revision}/{self.data_dir or ''}".rstrip(
  536. '/')
  537. repo_id: str = self.name
  538. download_config = self.download_config.copy()
  539. dataset_readme_path = _download_repo_file(
  540. repo_id=repo_id,
  541. path_in_repo='README.md',
  542. download_config=download_config,
  543. revision=revision)
  544. dataset_card_data = DatasetCard.load(Path(dataset_readme_path)).data if dataset_readme_path else DatasetCardData()
  545. subset_name: str = download_config.storage_options.get('name', None)
  546. metadata_configs = MetadataConfigs.from_dataset_card_data(
  547. dataset_card_data)
  548. dataset_infos = DatasetInfosDict.from_dataset_card_data(dataset_card_data)
  549. # we need a set of data files to find which dataset builder to use
  550. # because we need to infer module name by files extensions
  551. if self.data_files is not None:
  552. patterns = sanitize_patterns(self.data_files)
  553. elif metadata_configs and 'data_files' in next(
  554. iter(metadata_configs.values())):
  555. if subset_name is not None:
  556. subset_data_files = metadata_configs[subset_name]['data_files']
  557. else:
  558. subset_data_files = next(iter(metadata_configs.values()))['data_files']
  559. patterns = sanitize_patterns(subset_data_files)
  560. else:
  561. patterns = _get_data_patterns(
  562. base_path, download_config=self.download_config)
  563. data_files = DataFilesDict.from_patterns(
  564. patterns,
  565. base_path=base_path,
  566. allowed_extensions=ALL_ALLOWED_EXTENSIONS,
  567. download_config=self.download_config,
  568. )
  569. module_name, default_builder_kwargs = infer_module_for_data_files(
  570. data_files=data_files,
  571. path=self.name,
  572. download_config=self.download_config,
  573. )
  574. if hasattr(data_files, 'filter'):
  575. data_files = data_files.filter(extensions=_MODULE_TO_EXTENSIONS[module_name])
  576. else:
  577. data_files = data_files.filter_extensions(_MODULE_TO_EXTENSIONS[module_name])
  578. module_path, _ = _PACKAGED_DATASETS_MODULES[module_name]
  579. if metadata_configs:
  580. supports_metadata = module_name in {'imagefolder', 'audiofolder'}
  581. create_builder_signature = inspect.signature(create_builder_configs_from_metadata_configs)
  582. in_args = {
  583. 'module_path': module_path,
  584. 'metadata_configs': metadata_configs,
  585. 'base_path': base_path,
  586. 'default_builder_kwargs': default_builder_kwargs,
  587. 'download_config': self.download_config,
  588. }
  589. if 'supports_metadata' in create_builder_signature.parameters:
  590. in_args['supports_metadata'] = supports_metadata
  591. builder_configs, default_config_name = create_builder_configs_from_metadata_configs(**in_args)
  592. else:
  593. builder_configs: List[BuilderConfig] = [
  594. import_main_class(module_path).BUILDER_CONFIG_CLASS(
  595. data_files=data_files,
  596. **default_builder_kwargs,
  597. )
  598. ]
  599. default_config_name = None
  600. _api = HubApi()
  601. endpoint = _api.get_endpoint_for_read(
  602. repo_id=repo_id, repo_type=REPO_TYPE_DATASET)
  603. builder_kwargs = {
  604. # "base_path": hf_hub_url(self.name, "", revision=revision).rstrip("/"),
  605. 'base_path':
  606. HubApi().get_file_base_path(repo_id=repo_id, endpoint=endpoint),
  607. 'repo_id':
  608. self.name,
  609. 'dataset_name':
  610. camelcase_to_snakecase(Path(self.name).name),
  611. 'data_files': data_files,
  612. }
  613. download_config = self.download_config.copy()
  614. if download_config.download_desc is None:
  615. download_config.download_desc = 'Downloading metadata'
  616. # Note: `dataset_infos.json` is deprecated and can cause an error during loading if it exists
  617. if default_config_name is None and len(dataset_infos) == 1:
  618. default_config_name = next(iter(dataset_infos))
  619. hash = revision
  620. return DatasetModule(
  621. module_path,
  622. hash,
  623. builder_kwargs,
  624. dataset_infos=dataset_infos,
  625. builder_configs_parameters=BuilderConfigsParameters(
  626. metadata_configs=metadata_configs,
  627. builder_configs=builder_configs,
  628. default_config_name=default_config_name,
  629. ),
  630. )
  631. def _download_additional_modules(
  632. name: str,
  633. dataset_name: str,
  634. namespace: str,
  635. revision: str,
  636. imports: Tuple[str, str, str, str],
  637. download_config: Optional[DownloadConfig],
  638. trust_remote_code: Optional[bool] = False,
  639. ) -> List[Tuple[str, str]]:
  640. """
  641. Download additional module for a module <name>.py at URL (or local path) <base_path>/<name>.py
  642. The imports must have been parsed first using ``get_imports``.
  643. If some modules need to be installed with pip, an error is raised showing how to install them.
  644. This function return the list of downloaded modules as tuples (import_name, module_file_path).
  645. The downloaded modules can then be moved into an importable directory
  646. with ``_copy_script_and_other_resources_in_importable_dir``.
  647. """
  648. local_imports = []
  649. library_imports = []
  650. # Check if we need to execute remote code
  651. has_remote_code = any(
  652. import_type in ('internal', 'external')
  653. for import_type, _, _, _ in imports
  654. )
  655. if has_remote_code and not trust_remote_code:
  656. raise ValueError(
  657. f'Loading {name} requires executing code from the repository. '
  658. 'This is disabled by default for security reasons. '
  659. 'If you trust the authors of this dataset, you can enable it with '
  660. '`trust_remote_code=True`.'
  661. )
  662. download_config = download_config.copy()
  663. if download_config.download_desc is None:
  664. download_config.download_desc = 'Downloading extra modules'
  665. for import_type, import_name, import_path, sub_directory in imports:
  666. if import_type == 'library':
  667. library_imports.append((import_name, import_path)) # Import from a library
  668. continue
  669. if import_name == name:
  670. raise ValueError(
  671. f'Error in the {name} script, importing relative {import_name} module '
  672. f'but {import_name} is the name of the script. '
  673. f"Please change relative import {import_name} to another name and add a '# From: URL_OR_PATH' "
  674. f'comment pointing to the original relative import file path.'
  675. )
  676. if import_type == 'internal':
  677. _api = HubApi()
  678. # url_or_filename = url_or_path_join(base_path, import_path + ".py")
  679. file_name = import_path + '.py'
  680. url_or_filename = _api.get_dataset_file_url(file_name=file_name,
  681. dataset_name=dataset_name,
  682. namespace=namespace,
  683. revision=revision,)
  684. elif import_type == 'external':
  685. url_or_filename = import_path
  686. else:
  687. raise ValueError('Wrong import_type')
  688. local_import_path = cached_path(
  689. url_or_filename,
  690. download_config=download_config,
  691. )
  692. if sub_directory is not None:
  693. local_import_path = os.path.join(local_import_path, sub_directory)
  694. local_imports.append((import_name, local_import_path))
  695. # Check library imports
  696. needs_to_be_installed = {}
  697. for library_import_name, library_import_path in library_imports:
  698. try:
  699. lib = importlib.import_module(library_import_name) # noqa F841
  700. except ImportError:
  701. if library_import_name not in needs_to_be_installed or library_import_path != library_import_name:
  702. needs_to_be_installed[library_import_name] = library_import_path
  703. if needs_to_be_installed:
  704. _dependencies_str = 'dependencies' if len(needs_to_be_installed) > 1 else 'dependency'
  705. _them_str = 'them' if len(needs_to_be_installed) > 1 else 'it'
  706. if 'sklearn' in needs_to_be_installed.keys():
  707. needs_to_be_installed['sklearn'] = 'scikit-learn'
  708. if 'Bio' in needs_to_be_installed.keys():
  709. needs_to_be_installed['Bio'] = 'biopython'
  710. raise ImportError(
  711. f'To be able to use {name}, you need to install the following {_dependencies_str}: '
  712. f"{', '.join(needs_to_be_installed)}.\nPlease install {_them_str} using 'pip install "
  713. f"{' '.join(needs_to_be_installed.values())}' for instance."
  714. )
  715. return local_imports
  716. def get_module_with_script(self) -> DatasetModule:
  717. repo_id: str = self.name
  718. _namespace, _dataset_name = repo_id.split('/')
  719. revision = self.download_config.storage_options.get('revision', None) or DEFAULT_DATASET_REVISION
  720. script_file_name = f'{_dataset_name}.py'
  721. local_script_path = _download_repo_file(
  722. repo_id=repo_id,
  723. path_in_repo=script_file_name,
  724. download_config=self.download_config,
  725. revision=revision,
  726. )
  727. if not local_script_path:
  728. raise FileNotFoundError(
  729. f'Cannot find {script_file_name} in {repo_id} at revision {revision}. '
  730. f'Please create {script_file_name} in the repo.'
  731. )
  732. dataset_infos_path = None
  733. # try:
  734. # dataset_infos_url: str = _api.get_dataset_file_url(
  735. # file_name='dataset_infos.json',
  736. # dataset_name=_dataset_name,
  737. # namespace=_namespace,
  738. # revision=self.revision,
  739. # extension_filter=False,
  740. # )
  741. # dataset_infos_path = cached_path(
  742. # url_or_filename=dataset_infos_url, download_config=self.download_config)
  743. # except Exception as e:
  744. # logger.info(f'Cannot find dataset_infos.json: {e}')
  745. # dataset_infos_path = None
  746. dataset_readme_path = _download_repo_file(
  747. repo_id=repo_id,
  748. path_in_repo='README.md',
  749. download_config=self.download_config,
  750. revision=revision
  751. )
  752. imports = get_imports(local_script_path)
  753. local_imports = _download_additional_modules(
  754. name=repo_id,
  755. dataset_name=_dataset_name,
  756. namespace=_namespace,
  757. revision=revision,
  758. imports=imports,
  759. download_config=self.download_config,
  760. trust_remote_code=self.trust_remote_code,
  761. )
  762. additional_files = []
  763. if dataset_infos_path:
  764. additional_files.append((config.DATASETDICT_INFOS_FILENAME, dataset_infos_path))
  765. if dataset_readme_path:
  766. additional_files.append((config.REPOCARD_FILENAME, dataset_readme_path))
  767. # copy the script and the files in an importable directory
  768. dynamic_modules_path = self.dynamic_modules_path if self.dynamic_modules_path else init_dynamic_modules()
  769. hash = files_to_hash([local_script_path] + [loc[1] for loc in local_imports])
  770. importable_file_path = _get_importable_file_path(
  771. dynamic_modules_path=dynamic_modules_path,
  772. module_namespace='datasets',
  773. subdirectory_name=hash,
  774. name=repo_id,
  775. )
  776. if not os.path.exists(importable_file_path):
  777. trust_remote_code = resolve_trust_remote_code(trust_remote_code=self.trust_remote_code, repo_id=self.name)
  778. if trust_remote_code:
  779. logger.warning(f'Use trust_remote_code=True. Will invoke codes from {repo_id}. Please make sure that '
  780. 'you can trust the external codes.')
  781. _create_importable_file(
  782. local_path=local_script_path,
  783. local_imports=local_imports,
  784. additional_files=additional_files,
  785. dynamic_modules_path=dynamic_modules_path,
  786. module_namespace='datasets',
  787. subdirectory_name=hash,
  788. name=repo_id,
  789. download_mode=self.download_mode,
  790. )
  791. else:
  792. raise ValueError(
  793. f'Loading {repo_id} requires you to execute the dataset script in that'
  794. ' repo on your local machine. Make sure you have read the code there to avoid malicious use, then'
  795. ' set the option `trust_remote_code=True` to remove this error.'
  796. )
  797. module_path, hash = _load_importable_file(
  798. dynamic_modules_path=dynamic_modules_path,
  799. module_namespace='datasets',
  800. subdirectory_name=hash,
  801. name=repo_id,
  802. )
  803. # make the new module to be noticed by the import system
  804. importlib.invalidate_caches()
  805. builder_kwargs = {
  806. # "base_path": hf_hub_url(self.name, "", revision=self.revision).rstrip("/"),
  807. 'base_path': HubApi().get_file_base_path(repo_id=repo_id),
  808. 'repo_id': repo_id,
  809. }
  810. return DatasetModule(module_path, hash, builder_kwargs)
  811. class DatasetsWrapperHF:
  812. @staticmethod
  813. def load_dataset(
  814. path: str,
  815. name: Optional[str] = None,
  816. data_dir: Optional[str] = None,
  817. data_files: Optional[Union[str, Sequence[str],
  818. Mapping[str, Union[str,
  819. Sequence[str]]]]] = None,
  820. split: Optional[Union[str, Split]] = None,
  821. cache_dir: Optional[str] = None,
  822. features: Optional[Features] = None,
  823. download_config: Optional[DownloadConfig] = None,
  824. download_mode: Optional[Union[DownloadMode, str]] = None,
  825. verification_mode: Optional[Union[VerificationMode, str]] = None,
  826. keep_in_memory: Optional[bool] = None,
  827. save_infos: bool = False,
  828. revision: Optional[Union[str, Version]] = None,
  829. token: Optional[Union[bool, str]] = None,
  830. use_auth_token='deprecated',
  831. task='deprecated',
  832. streaming: bool = False,
  833. num_proc: Optional[int] = None,
  834. storage_options: Optional[Dict] = None,
  835. trust_remote_code: bool = False,
  836. dataset_info_only: Optional[bool] = False,
  837. **config_kwargs,
  838. ) -> Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset,
  839. dict]:
  840. if use_auth_token != 'deprecated':
  841. warnings.warn(
  842. "'use_auth_token' was deprecated in favor of 'token' in version 2.14.0 and will be removed in 3.0.0.\n"
  843. "You can remove this warning by passing 'token=<use_auth_token>' instead.",
  844. FutureWarning,
  845. )
  846. token = use_auth_token
  847. if task != 'deprecated':
  848. warnings.warn(
  849. "'task' was deprecated in version 2.13.0 and will be removed in 3.0.0.\n",
  850. FutureWarning,
  851. )
  852. else:
  853. task = None
  854. if data_files is not None and not data_files:
  855. raise ValueError(
  856. f"Empty 'data_files': '{data_files}'. It should be either non-empty or None (default)."
  857. )
  858. if Path(path, config.DATASET_STATE_JSON_FILENAME).exists(
  859. ):
  860. raise ValueError(
  861. 'You are trying to load a dataset that was saved using `save_to_disk`. '
  862. 'Please use `load_from_disk` instead.')
  863. if streaming and num_proc is not None:
  864. raise NotImplementedError(
  865. 'Loading a streaming dataset in parallel with `num_proc` is not implemented. '
  866. 'To parallelize streaming, you can wrap the dataset with a PyTorch DataLoader '
  867. 'using `num_workers` > 1 instead.')
  868. download_mode = DownloadMode(download_mode
  869. or DownloadMode.REUSE_DATASET_IF_EXISTS)
  870. verification_mode = VerificationMode((
  871. verification_mode or VerificationMode.BASIC_CHECKS
  872. ) if not save_infos else VerificationMode.ALL_CHECKS)
  873. if trust_remote_code:
  874. logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure '
  875. 'that you can trust the external codes.'
  876. )
  877. # Create a dataset builder
  878. builder_instance = DatasetsWrapperHF.load_dataset_builder(
  879. path=path,
  880. name=name,
  881. data_dir=data_dir,
  882. data_files=data_files,
  883. cache_dir=cache_dir,
  884. features=features,
  885. download_config=download_config,
  886. download_mode=download_mode,
  887. revision=revision,
  888. token=token,
  889. storage_options=storage_options,
  890. trust_remote_code=trust_remote_code,
  891. _require_default_config_name=name is None,
  892. **config_kwargs,
  893. )
  894. # Note: Only for preview mode
  895. if dataset_info_only:
  896. ret_dict = {}
  897. # Get dataset config info from python script
  898. if isinstance(path, str) and path.endswith('.py') and os.path.exists(path):
  899. from datasets import get_dataset_config_names
  900. subset_list = get_dataset_config_names(path)
  901. ret_dict = {_subset: [] for _subset in subset_list}
  902. return ret_dict
  903. if builder_instance is None or not hasattr(builder_instance,
  904. 'builder_configs'):
  905. logger.error(f'No builder_configs found for {path} dataset.')
  906. return ret_dict
  907. _tmp_builder_configs = builder_instance.builder_configs
  908. for tmp_config_name, tmp_builder_config in _tmp_builder_configs.items():
  909. tmp_config_name = str(tmp_config_name)
  910. if hasattr(tmp_builder_config, 'data_files') and tmp_builder_config.data_files is not None:
  911. ret_dict[tmp_config_name] = [str(item) for item in list(tmp_builder_config.data_files.keys())]
  912. else:
  913. ret_dict[tmp_config_name] = []
  914. return ret_dict
  915. # Return iterable dataset in case of streaming
  916. if streaming:
  917. return builder_instance.as_streaming_dataset(split=split)
  918. # Some datasets are already processed on the HF google storage
  919. # Don't try downloading from Google storage for the packaged datasets as text, json, csv or pandas
  920. # try_from_hf_gcs = path not in _PACKAGED_DATASETS_MODULES
  921. # Download and prepare data
  922. builder_instance.download_and_prepare(
  923. download_config=download_config,
  924. download_mode=download_mode,
  925. verification_mode=verification_mode,
  926. num_proc=num_proc,
  927. storage_options=storage_options,
  928. # base_path=builder_instance.base_path,
  929. # file_format=builder_instance.name or 'arrow',
  930. )
  931. # Build dataset for splits
  932. keep_in_memory = (
  933. keep_in_memory if keep_in_memory is not None else is_small_dataset(
  934. builder_instance.info.dataset_size))
  935. ds = builder_instance.as_dataset(
  936. split=split,
  937. verification_mode=verification_mode,
  938. in_memory=keep_in_memory)
  939. # Rename and cast features to match task schema
  940. if task is not None:
  941. # To avoid issuing the same warning twice
  942. with warnings.catch_warnings():
  943. warnings.simplefilter('ignore', FutureWarning)
  944. ds = ds.prepare_for_task(task)
  945. if save_infos:
  946. builder_instance._save_infos()
  947. try:
  948. _api = HubApi()
  949. if is_relative_path(path) and path.count('/') == 1:
  950. _namespace, _dataset_name = path.split('/')
  951. endpoint = _api.get_endpoint_for_read(
  952. repo_id=path, repo_type=REPO_TYPE_DATASET)
  953. _api.dataset_download_statistics(dataset_name=_dataset_name, namespace=_namespace, endpoint=endpoint)
  954. except Exception as e:
  955. logger.warning(f'Could not record download statistics: {e}')
  956. return ds
  957. @staticmethod
  958. def load_dataset_builder(
  959. path: str,
  960. name: Optional[str] = None,
  961. data_dir: Optional[str] = None,
  962. data_files: Optional[Union[str, Sequence[str],
  963. Mapping[str, Union[str,
  964. Sequence[str]]]]] = None,
  965. cache_dir: Optional[str] = None,
  966. features: Optional[Features] = None,
  967. download_config: Optional[DownloadConfig] = None,
  968. download_mode: Optional[Union[DownloadMode, str]] = None,
  969. revision: Optional[Union[str, Version]] = None,
  970. token: Optional[Union[bool, str]] = None,
  971. use_auth_token='deprecated',
  972. storage_options: Optional[Dict] = None,
  973. trust_remote_code: Optional[bool] = None,
  974. _require_default_config_name=True,
  975. **config_kwargs,
  976. ) -> DatasetBuilder:
  977. if use_auth_token != 'deprecated':
  978. warnings.warn(
  979. "'use_auth_token' was deprecated in favor of 'token' in version 2.14.0 and will be removed in 3.0.0.\n"
  980. "You can remove this warning by passing 'token=<use_auth_token>' instead.",
  981. FutureWarning,
  982. )
  983. token = use_auth_token
  984. download_mode = DownloadMode(download_mode
  985. or DownloadMode.REUSE_DATASET_IF_EXISTS)
  986. if token is not None:
  987. download_config = download_config.copy(
  988. ) if download_config else DownloadConfig()
  989. download_config.token = token
  990. if storage_options is not None:
  991. download_config = download_config.copy(
  992. ) if download_config else DownloadConfig()
  993. download_config.storage_options.update(storage_options)
  994. if trust_remote_code:
  995. logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure '
  996. 'that you can trust the external codes.'
  997. )
  998. dataset_module = DatasetsWrapperHF.dataset_module_factory(
  999. path,
  1000. revision=revision,
  1001. download_config=download_config,
  1002. download_mode=download_mode,
  1003. data_dir=data_dir,
  1004. data_files=data_files,
  1005. cache_dir=cache_dir,
  1006. trust_remote_code=trust_remote_code,
  1007. _require_default_config_name=_require_default_config_name,
  1008. _require_custom_configs=bool(config_kwargs),
  1009. name=name,
  1010. )
  1011. # Get dataset builder class from the processing script
  1012. builder_kwargs = dataset_module.builder_kwargs
  1013. data_dir = builder_kwargs.pop('data_dir', data_dir)
  1014. data_files = builder_kwargs.pop('data_files', data_files)
  1015. config_name = builder_kwargs.pop(
  1016. 'config_name', name
  1017. or dataset_module.builder_configs_parameters.default_config_name)
  1018. dataset_name = builder_kwargs.pop('dataset_name', None)
  1019. info = dataset_module.dataset_infos.get(
  1020. config_name) if dataset_module.dataset_infos else None
  1021. if (path in _PACKAGED_DATASETS_MODULES and data_files is None
  1022. and dataset_module.builder_configs_parameters.
  1023. builder_configs[0].data_files is None):
  1024. error_msg = f'Please specify the data files or data directory to load for the {path} dataset builder.'
  1025. example_extensions = [
  1026. extension for extension in _EXTENSION_TO_MODULE
  1027. if _EXTENSION_TO_MODULE[extension] == path
  1028. ]
  1029. if example_extensions:
  1030. error_msg += f'\nFor example `data_files={{"train": "path/to/data/train/*.{example_extensions[0]}"}}`'
  1031. raise ValueError(error_msg)
  1032. builder_cls = get_dataset_builder_class(
  1033. dataset_module, dataset_name=dataset_name)
  1034. builder_instance: DatasetBuilder = builder_cls(
  1035. cache_dir=cache_dir,
  1036. dataset_name=dataset_name,
  1037. config_name=config_name,
  1038. data_dir=data_dir,
  1039. data_files=data_files,
  1040. hash=dataset_module.hash,
  1041. info=info,
  1042. features=features,
  1043. token=token,
  1044. storage_options=storage_options,
  1045. **builder_kwargs, # contains base_path
  1046. **config_kwargs,
  1047. )
  1048. builder_instance._use_legacy_cache_dir_if_possible(dataset_module)
  1049. return builder_instance
  1050. @staticmethod
  1051. def dataset_module_factory(
  1052. path: str,
  1053. revision: Optional[Union[str, Version]] = None,
  1054. download_config: Optional[DownloadConfig] = None,
  1055. download_mode: Optional[Union[DownloadMode, str]] = None,
  1056. dynamic_modules_path: Optional[str] = None,
  1057. data_dir: Optional[str] = None,
  1058. data_files: Optional[Union[Dict, List, str, DataFilesDict]] = None,
  1059. cache_dir: Optional[str] = None,
  1060. trust_remote_code: Optional[bool] = None,
  1061. _require_default_config_name=True,
  1062. _require_custom_configs=False,
  1063. **download_kwargs,
  1064. ) -> DatasetModule:
  1065. subset_name: str = download_kwargs.pop('name', None)
  1066. revision = revision or DEFAULT_DATASET_REVISION
  1067. if download_config is None:
  1068. download_config = DownloadConfig(**download_kwargs)
  1069. download_config.storage_options.update({'name': subset_name})
  1070. download_config.storage_options.update({'revision': revision})
  1071. if download_config and download_config.cache_dir is None:
  1072. download_config.cache_dir = MS_DATASETS_CACHE
  1073. download_mode = DownloadMode(download_mode
  1074. or DownloadMode.REUSE_DATASET_IF_EXISTS)
  1075. download_config.extract_compressed_file = True
  1076. download_config.force_extract = True
  1077. download_config.force_download = download_mode == DownloadMode.FORCE_REDOWNLOAD
  1078. filename = list(
  1079. filter(lambda x: x,
  1080. path.replace(os.sep, '/').split('/')))[-1]
  1081. if not filename.endswith('.py'):
  1082. filename = filename + '.py'
  1083. combined_path = os.path.join(path, filename)
  1084. # We have several ways to get a dataset builder:
  1085. #
  1086. # - if path is the name of a packaged dataset module
  1087. # -> use the packaged module (json, csv, etc.)
  1088. #
  1089. # - if os.path.join(path, name) is a local python file
  1090. # -> use the module from the python file
  1091. # - if path is a local directory (but no python file)
  1092. # -> use a packaged module (csv, text etc.) based on content of the directory
  1093. #
  1094. # - if path has one "/" and is dataset repository on the HF hub with a python file
  1095. # -> the module from the python file in the dataset repository
  1096. # - if path has one "/" and is dataset repository on the HF hub without a python file
  1097. # -> use a packaged module (csv, text etc.) based on content of the repository
  1098. if trust_remote_code:
  1099. logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure '
  1100. 'that you can trust the external codes.'
  1101. )
  1102. # Try packaged
  1103. if path in _PACKAGED_DATASETS_MODULES:
  1104. return PackagedDatasetModuleFactory(
  1105. path,
  1106. data_dir=data_dir,
  1107. data_files=data_files,
  1108. download_config=download_config,
  1109. download_mode=download_mode,
  1110. ).get_module()
  1111. # Try locally
  1112. elif path.endswith(filename):
  1113. if os.path.isfile(path):
  1114. return LocalDatasetModuleFactoryWithScript(
  1115. path,
  1116. download_mode=download_mode,
  1117. dynamic_modules_path=dynamic_modules_path,
  1118. trust_remote_code=trust_remote_code,
  1119. ).get_module()
  1120. else:
  1121. raise FileNotFoundError(
  1122. f"Couldn't find a dataset script at {relative_to_absolute_path(path)}"
  1123. )
  1124. elif os.path.isfile(combined_path):
  1125. return LocalDatasetModuleFactoryWithScript(
  1126. combined_path,
  1127. download_mode=download_mode,
  1128. dynamic_modules_path=dynamic_modules_path,
  1129. trust_remote_code=trust_remote_code,
  1130. ).get_module()
  1131. elif os.path.isdir(path):
  1132. return LocalDatasetModuleFactoryWithoutScript(
  1133. path,
  1134. data_dir=data_dir,
  1135. data_files=data_files,
  1136. download_mode=download_mode).get_module()
  1137. # Try remotely
  1138. elif is_relative_path(path) and path.count('/') == 1:
  1139. try:
  1140. _raise_if_offline_mode_is_enabled()
  1141. try:
  1142. dataset_info = HfApi().dataset_info(
  1143. repo_id=path,
  1144. revision=revision,
  1145. token=download_config.token,
  1146. timeout=100.0,
  1147. )
  1148. except Exception as e: # noqa catch any exception of hf_hub and consider that the dataset doesn't exist
  1149. if isinstance(
  1150. e,
  1151. ( # noqa: E131
  1152. OfflineModeIsEnabled, # noqa: E131
  1153. requests.exceptions.
  1154. ConnectTimeout, # noqa: E131, E261
  1155. requests.exceptions.ConnectionError, # noqa: E131
  1156. ), # noqa: E131
  1157. ):
  1158. raise ConnectionError(
  1159. f"Couldn't reach '{path}' on the Hub ({type(e).__name__})"
  1160. )
  1161. elif '404' in str(e):
  1162. msg = f"Dataset '{path}' doesn't exist on the Hub"
  1163. raise DatasetNotFoundError(
  1164. msg
  1165. + f" at revision '{revision}'" if revision else msg
  1166. )
  1167. elif '401' in str(e):
  1168. msg = f"Dataset '{path}' doesn't exist on the Hub"
  1169. msg = msg + f" at revision '{revision}'" if revision else msg
  1170. raise DatasetNotFoundError(
  1171. msg + '. If the repo is private or gated, '
  1172. 'make sure to log in with `huggingface-cli login`.'
  1173. )
  1174. else:
  1175. raise e
  1176. dataset_readme_path = _download_repo_file(
  1177. repo_id=path,
  1178. path_in_repo='README.md',
  1179. download_config=download_config,
  1180. revision=revision,
  1181. )
  1182. commit_hash = os.path.basename(os.path.dirname(dataset_readme_path))
  1183. if filename in [
  1184. sibling.rfilename for sibling in dataset_info.siblings
  1185. ]: # contains a dataset script
  1186. # fs = HfFileSystem(
  1187. # endpoint=config.HF_ENDPOINT,
  1188. # token=download_config.token)
  1189. # TODO
  1190. can_load_config_from_parquet_export = False
  1191. # if _require_custom_configs:
  1192. # can_load_config_from_parquet_export = False
  1193. # elif _require_default_config_name:
  1194. # with fs.open(
  1195. # f'datasets/{path}/{filename}',
  1196. # 'r',
  1197. # revision=revision,
  1198. # encoding='utf-8') as f:
  1199. # can_load_config_from_parquet_export = 'DEFAULT_CONFIG_NAME' not in f.read(
  1200. # )
  1201. # else:
  1202. # can_load_config_from_parquet_export = True
  1203. if config.USE_PARQUET_EXPORT and can_load_config_from_parquet_export:
  1204. # If the parquet export is ready (parquet files + info available for the current sha),
  1205. # we can use it instead
  1206. # This fails when the dataset has multiple configs and a default config and
  1207. # the user didn't specify a configuration name (_require_default_config_name=True).
  1208. try:
  1209. if has_attr_in_class(HubDatasetModuleFactoryWithParquetExport, 'revision'):
  1210. return HubDatasetModuleFactoryWithParquetExport(
  1211. path,
  1212. revision=revision,
  1213. download_config=download_config).get_module()
  1214. return HubDatasetModuleFactoryWithParquetExport(
  1215. path,
  1216. commit_hash=commit_hash,
  1217. download_config=download_config).get_module()
  1218. except Exception as e:
  1219. logger.error(e)
  1220. # Otherwise we must use the dataset script if the user trusts it
  1221. # To be adapted to the old version of datasets
  1222. if has_attr_in_class(HubDatasetModuleFactoryWithScript, 'revision'):
  1223. return HubDatasetModuleFactoryWithScript(
  1224. path,
  1225. revision=revision,
  1226. download_config=download_config,
  1227. download_mode=download_mode,
  1228. dynamic_modules_path=dynamic_modules_path,
  1229. trust_remote_code=trust_remote_code,
  1230. ).get_module()
  1231. return HubDatasetModuleFactoryWithScript(
  1232. path,
  1233. commit_hash=commit_hash,
  1234. download_config=download_config,
  1235. download_mode=download_mode,
  1236. dynamic_modules_path=dynamic_modules_path,
  1237. trust_remote_code=trust_remote_code,
  1238. ).get_module()
  1239. else:
  1240. # To be adapted to the old version of datasets
  1241. if has_attr_in_class(HubDatasetModuleFactoryWithoutScript, 'revision'):
  1242. return HubDatasetModuleFactoryWithoutScript(
  1243. path,
  1244. revision=revision,
  1245. data_dir=data_dir,
  1246. data_files=data_files,
  1247. download_config=download_config,
  1248. download_mode=download_mode,
  1249. ).get_module()
  1250. return HubDatasetModuleFactoryWithoutScript(
  1251. path,
  1252. commit_hash=commit_hash,
  1253. data_dir=data_dir,
  1254. data_files=data_files,
  1255. download_config=download_config,
  1256. download_mode=download_mode,
  1257. ).get_module()
  1258. except Exception as e1:
  1259. # All the attempts failed, before raising the error we should check if the module is already cached
  1260. logger.error(f'>> Error loading {path}: {e1}')
  1261. try:
  1262. return CachedDatasetModuleFactory(
  1263. path,
  1264. dynamic_modules_path=dynamic_modules_path,
  1265. cache_dir=cache_dir).get_module()
  1266. except Exception:
  1267. # If it's not in the cache, then it doesn't exist.
  1268. if isinstance(e1, OfflineModeIsEnabled):
  1269. raise ConnectionError(
  1270. f"Couldn't reach the Hugging Face Hub for dataset '{path}': {e1}"
  1271. ) from None
  1272. if isinstance(e1,
  1273. (DataFilesNotFoundError,
  1274. DatasetNotFoundError, EmptyDatasetError)):
  1275. raise e1 from None
  1276. if isinstance(e1, FileNotFoundError):
  1277. raise FileNotFoundError(
  1278. f"Couldn't find a dataset script at {relative_to_absolute_path(combined_path)} or "
  1279. f'any data file in the same directory. '
  1280. f"Couldn't find '{path}' on the Hugging Face Hub either: {type(e1).__name__}: {e1}"
  1281. ) from None
  1282. raise e1 from None
  1283. else:
  1284. raise FileNotFoundError(
  1285. f"Couldn't find a dataset script at {relative_to_absolute_path(combined_path)} or "
  1286. f'any data file in the same directory.')
  1287. @contextlib.contextmanager
  1288. def load_dataset_with_ctx(*args, **kwargs):
  1289. # Keep the original functions
  1290. hf_endpoint_origin = config.HF_ENDPOINT
  1291. get_from_cache_origin = file_utils.get_from_cache
  1292. # Compatible with datasets 2.18.0
  1293. _download_origin = DownloadManager._download if hasattr(DownloadManager, '_download') \
  1294. else DownloadManager._download_single
  1295. dataset_info_origin = HfApi.dataset_info
  1296. list_repo_tree_origin = HfApi.list_repo_tree
  1297. get_paths_info_origin = HfApi.get_paths_info
  1298. resolve_pattern_origin = data_files.resolve_pattern
  1299. get_module_without_script_origin = HubDatasetModuleFactoryWithoutScript.get_module
  1300. get_module_with_script_origin = HubDatasetModuleFactoryWithScript.get_module
  1301. generate_from_dict_origin = features.generate_from_dict
  1302. # Monkey patching with modelscope functions
  1303. config.HF_ENDPOINT = get_endpoint()
  1304. file_utils.get_from_cache = get_from_cache_ms
  1305. # Compatible with datasets 2.18.0
  1306. if hasattr(DownloadManager, '_download'):
  1307. DownloadManager._download = _download_ms
  1308. else:
  1309. DownloadManager._download_single = _download_ms
  1310. HfApi.dataset_info = _dataset_info
  1311. HfApi.list_repo_tree = _list_repo_tree
  1312. HfApi.get_paths_info = _get_paths_info
  1313. data_files.resolve_pattern = _resolve_pattern
  1314. HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script
  1315. HubDatasetModuleFactoryWithScript.get_module = get_module_with_script
  1316. features.generate_from_dict = generate_from_dict_ms
  1317. streaming = kwargs.get('streaming', False)
  1318. try:
  1319. dataset_res = DatasetsWrapperHF.load_dataset(*args, **kwargs)
  1320. yield dataset_res
  1321. finally:
  1322. # Restore the original functions
  1323. config.HF_ENDPOINT = hf_endpoint_origin
  1324. file_utils.get_from_cache = get_from_cache_origin
  1325. features.generate_from_dict = generate_from_dict_origin
  1326. # Keep the context during the streaming iteration
  1327. if not streaming:
  1328. config.HF_ENDPOINT = hf_endpoint_origin
  1329. file_utils.get_from_cache = get_from_cache_origin
  1330. # Compatible with datasets 2.18.0
  1331. if hasattr(DownloadManager, '_download'):
  1332. DownloadManager._download = _download_origin
  1333. else:
  1334. DownloadManager._download_single = _download_origin
  1335. HfApi.dataset_info = dataset_info_origin
  1336. HfApi.list_repo_tree = list_repo_tree_origin
  1337. HfApi.get_paths_info = get_paths_info_origin
  1338. data_files.resolve_pattern = resolve_pattern_origin
  1339. HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script_origin
  1340. HubDatasetModuleFactoryWithScript.get_module = get_module_with_script_origin