ms_dataset.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import warnings
  4. from typing import (Any, Callable, Dict, Iterable, List, Mapping, Optional,
  5. Sequence, Union)
  6. import numpy as np
  7. from datasets import (Dataset, DatasetDict, Features, IterableDataset,
  8. IterableDatasetDict)
  9. from datasets.packaged_modules import _PACKAGED_DATASETS_MODULES
  10. from datasets.utils.file_utils import is_relative_path
  11. from modelscope.hub.repository import DatasetRepository
  12. from modelscope.msdatasets.context.dataset_context_config import \
  13. DatasetContextConfig
  14. from modelscope.msdatasets.data_loader.data_loader_manager import (
  15. LocalDataLoaderManager, LocalDataLoaderType, RemoteDataLoaderManager,
  16. RemoteDataLoaderType)
  17. from modelscope.msdatasets.dataset_cls import (ExternalDataset,
  18. NativeIterableDataset)
  19. from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \
  20. build_custom_dataset
  21. from modelscope.msdatasets.utils.delete_utils import DatasetDeleteManager
  22. from modelscope.msdatasets.utils.hf_datasets_util import load_dataset_with_ctx
  23. from modelscope.msdatasets.utils.upload_utils import DatasetUploadManager
  24. from modelscope.preprocessors import build_preprocessor
  25. from modelscope.utils.config import Config, ConfigDict
  26. from modelscope.utils.config_ds import MS_DATASETS_CACHE
  27. from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE,
  28. DEFAULT_DATASET_REVISION,
  29. REPO_TYPE_DATASET, ConfigFields,
  30. DatasetFormations, DownloadMode, Hubs,
  31. ModeKeys, Tasks, UploadMode)
  32. from modelscope.utils.import_utils import is_tf_available, is_torch_available
  33. from modelscope.utils.logger import get_logger
# Module-level logger used by the MsDataset methods below for warnings/info.
logger = get_logger()
  35. def format_list(para) -> List:
  36. if para is None:
  37. para = []
  38. elif isinstance(para, str):
  39. para = [para]
  40. elif len(set(para)) < len(para):
  41. raise ValueError(f'List columns contains duplicates: {para}')
  42. return para
  43. class MsDataset:
  44. """
  45. ModelScope Dataset (aka, MsDataset) is backed by a huggingface Dataset to
  46. provide efficient data access and local storage managements. On top of
  47. that, MsDataset supports the data integration and interactions with multiple
  48. remote hubs, particularly, ModelScope's own Dataset-hub. MsDataset also
  49. abstracts away data-access details with other remote storage, including both
  50. general external web-hosted data and cloud storage such as OSS.
  51. """
  52. # the underlying huggingface Dataset
  53. _hf_ds = None
  54. _dataset_context_config: DatasetContextConfig = None
  55. def __init__(self,
  56. ds_instance: Union[Dataset, IterableDataset, ExternalDataset,
  57. NativeIterableDataset],
  58. target: Optional[str] = None):
  59. self._hf_ds = ds_instance
  60. if target is not None and target not in self._hf_ds.features:
  61. raise TypeError(
  62. f'"target" must be a column of the dataset({list(self._hf_ds.features.keys())}, but got {target}'
  63. )
  64. self.target = target
  65. self.is_custom = False
  66. def __iter__(self):
  67. for item in self._hf_ds:
  68. if self.target is not None:
  69. yield item[self.target]
  70. else:
  71. yield item
  72. def __getitem__(self, key):
  73. return self._hf_ds[key]
  74. def __len__(self):
  75. return len(self._hf_ds)
  76. @property
  77. def ds_instance(self):
  78. return self._hf_ds
  79. @property
  80. def config_kwargs(self):
  81. if isinstance(self._hf_ds, ExternalDataset):
  82. return self._hf_ds.config_kwargs
  83. else:
  84. return None
  85. @classmethod
  86. def from_hf_dataset(cls,
  87. hf_ds: Union[Dataset, DatasetDict, ExternalDataset],
  88. target: str = None) -> Union[dict, 'MsDataset']:
  89. r"""
  90. @deprecated
  91. This method is deprecated and may be removed in future releases, please use `to_ms_dataset()` instead.
  92. """
  93. warnings.warn(
  94. 'from_hf_dataset is deprecated, please use to_ms_dataset instead.',
  95. DeprecationWarning)
  96. if isinstance(hf_ds, Dataset):
  97. return cls(hf_ds, target)
  98. elif isinstance(hf_ds, DatasetDict):
  99. if len(hf_ds.keys()) == 1:
  100. return cls(next(iter(hf_ds.values())), target)
  101. return {k: cls(v, target) for k, v in hf_ds.items()}
  102. elif isinstance(hf_ds, ExternalDataset):
  103. return cls(hf_ds)
  104. else:
  105. raise TypeError(
  106. f'"hf_ds" must be a Dataset or DatasetDict, but got {type(hf_ds)}'
  107. )
  108. @classmethod
  109. def to_ms_dataset(cls,
  110. ds_instance: Union[Dataset, DatasetDict, ExternalDataset,
  111. NativeIterableDataset,
  112. IterableDataset, IterableDatasetDict],
  113. target: str = None) -> Union[dict, 'MsDataset']:
  114. """Convert input to `MsDataset` instance."""
  115. if isinstance(ds_instance, Dataset):
  116. return cls(ds_instance, target)
  117. elif isinstance(ds_instance, DatasetDict):
  118. if len(ds_instance.keys()) == 1:
  119. return cls(next(iter(ds_instance.values())), target)
  120. return {k: cls(v, target) for k, v in ds_instance.items()}
  121. elif isinstance(ds_instance, ExternalDataset):
  122. return cls(ds_instance)
  123. elif isinstance(ds_instance, NativeIterableDataset):
  124. return cls(ds_instance)
  125. elif isinstance(ds_instance, IterableDataset):
  126. return cls(ds_instance)
  127. elif isinstance(ds_instance, IterableDatasetDict):
  128. if len(ds_instance.keys()) == 1:
  129. return cls(next(iter(ds_instance.values())), target)
  130. return {k: cls(v, target) for k, v in ds_instance.items()}
  131. else:
  132. raise TypeError(
  133. f'"ds_instance" must be a Dataset or DatasetDict, but got {type(ds_instance)}'
  134. )
  135. @staticmethod
  136. def load(
  137. dataset_name: Union[str, list],
  138. namespace: Optional[str] = DEFAULT_DATASET_NAMESPACE,
  139. target: Optional[str] = None,
  140. version: Optional[str] = DEFAULT_DATASET_REVISION,
  141. hub: Optional[Hubs] = Hubs.modelscope,
  142. subset_name: Optional[str] = None,
  143. split: Optional[str] = None,
  144. data_dir: Optional[str] = None,
  145. data_files: Optional[Union[str, Sequence[str],
  146. Mapping[str, Union[str,
  147. Sequence[str]]]]] = None,
  148. download_mode: Optional[DownloadMode] = DownloadMode.
  149. REUSE_DATASET_IF_EXISTS,
  150. cache_dir: Optional[str] = MS_DATASETS_CACHE,
  151. features: Optional[Features] = None,
  152. use_streaming: Optional[bool] = False,
  153. stream_batch_size: Optional[int] = 1,
  154. custom_cfg: Optional[Config] = Config(),
  155. token: Optional[str] = None,
  156. dataset_info_only: Optional[bool] = False,
  157. trust_remote_code: Optional[bool] = False,
  158. **config_kwargs,
  159. ) -> Union[dict, 'MsDataset', NativeIterableDataset]:
  160. """Load a MsDataset from the ModelScope Hub, Hugging Face Hub, urls, or a local dataset.
  161. Args:
  162. dataset_name (str): Path or name of the dataset.
  163. The form of `namespace/dataset_name` is also supported.
  164. namespace(str, optional): Namespace of the dataset. It should not be None if you load a remote dataset
  165. from Hubs.modelscope,
  166. namespace (str, optional):
  167. Namespace of the dataset. It should not be None if you load a remote dataset
  168. from Hubs.modelscope,
  169. target (str, optional): Name of the column to output.
  170. version (str, optional): Version of the dataset script to load:
  171. subset_name (str, optional): Defining the subset_name of the dataset.
  172. data_dir (str, optional): Defining the data_dir of the dataset configuration. I
  173. data_files (str or Sequence or Mapping, optional): Path(s) to source data file(s).
  174. split (str, optional): Which split of the data to load.
  175. hub (Hubs or str, optional): When loading from a remote hub, where it is from. default Hubs.modelscope
  176. download_mode (DownloadMode or str, optional): How to treat existing datasets. default
  177. DownloadMode.REUSE_DATASET_IF_EXISTS
  178. cache_dir (str, Optional): User-define local cache directory.
  179. use_streaming (bool, Optional): If set to True, no need to download all data files.
  180. Instead, it streams the data progressively, and returns
  181. NativeIterableDataset or a dict of NativeIterableDataset.
  182. stream_batch_size (int, Optional): The batch size of the streaming data.
  183. custom_cfg (str, Optional): Model configuration, this can be used for custom datasets.
  184. see https://modelscope.cn/docs/Configuration%E8%AF%A6%E8%A7%A3
  185. token (str, Optional): SDK token of ModelScope.
  186. dataset_info_only (bool, Optional): If set to True, only return the dataset config and info (dict).
  187. trust_remote_code (bool, Optional): If set to True, trust the remote code. Default to `False`.
  188. **config_kwargs (additional keyword arguments): Keyword arguments to be passed
  189. Returns:
  190. MsDataset (MsDataset): MsDataset object for a certain dataset.
  191. """
  192. if token:
  193. from modelscope.hub.api import HubApi
  194. api = HubApi()
  195. api.login(token)
  196. download_mode = DownloadMode(download_mode
  197. or DownloadMode.REUSE_DATASET_IF_EXISTS)
  198. hub = Hubs(hub or Hubs.modelscope)
  199. is_huggingface_hub = (hub == Hubs.huggingface)
  200. if not isinstance(dataset_name, str) and not isinstance(
  201. dataset_name, list):
  202. raise TypeError(
  203. f'dataset_name must be `str` or `list`, but got {type(dataset_name)}'
  204. )
  205. if isinstance(dataset_name, list):
  206. if target is None:
  207. target = 'target'
  208. dataset_inst = Dataset.from_dict({target: dataset_name})
  209. return MsDataset.to_ms_dataset(dataset_inst, target=target)
  210. dataset_name = os.path.expanduser(dataset_name)
  211. is_local_path = os.path.exists(dataset_name)
  212. if is_relative_path(dataset_name) and dataset_name.count(
  213. '/') == 1 and not is_local_path and not is_huggingface_hub:
  214. dataset_name_split = dataset_name.split('/')
  215. namespace = dataset_name_split[0].strip()
  216. dataset_name = dataset_name_split[1].strip()
  217. if not namespace or not dataset_name:
  218. raise 'The dataset_name should be in the form of `namespace/dataset_name` or `dataset_name`.'
  219. if trust_remote_code:
  220. logger.warning(
  221. f'Use trust_remote_code=True. Will invoke codes from {dataset_name}. Please make sure that '
  222. 'you can trust the external codes.')
  223. # Init context config
  224. dataset_context_config = DatasetContextConfig(
  225. dataset_name=dataset_name,
  226. namespace=namespace,
  227. version=version,
  228. subset_name=subset_name,
  229. split=split,
  230. target=target,
  231. hub=hub,
  232. data_dir=data_dir,
  233. data_files=data_files,
  234. download_mode=download_mode,
  235. cache_root_dir=cache_dir,
  236. use_streaming=use_streaming,
  237. stream_batch_size=stream_batch_size,
  238. trust_remote_code=trust_remote_code,
  239. **config_kwargs)
  240. # Load from local disk
  241. if dataset_name in _PACKAGED_DATASETS_MODULES or os.path.isdir(
  242. dataset_name) or os.path.isfile(dataset_name):
  243. dataset_inst = LocalDataLoaderManager(
  244. dataset_context_config).load_dataset(
  245. LocalDataLoaderType.HF_DATA_LOADER)
  246. dataset_inst = MsDataset.to_ms_dataset(dataset_inst, target=target)
  247. if isinstance(dataset_inst, MsDataset):
  248. dataset_inst._dataset_context_config = dataset_context_config
  249. if custom_cfg:
  250. dataset_inst.to_custom_dataset(
  251. custom_cfg=custom_cfg, **config_kwargs)
  252. dataset_inst.is_custom = True
  253. return dataset_inst
  254. # Load from the huggingface hub
  255. elif hub == Hubs.huggingface:
  256. from datasets import load_dataset
  257. return load_dataset(
  258. dataset_name,
  259. name=subset_name,
  260. split=split,
  261. streaming=use_streaming,
  262. download_mode=download_mode.value,
  263. trust_remote_code=trust_remote_code,
  264. **config_kwargs)
  265. # Load from the modelscope hub
  266. elif hub == Hubs.modelscope:
  267. # Get dataset type from ModelScope Hub; dataset_type->4: General Dataset
  268. from modelscope.hub.api import HubApi
  269. _api = HubApi()
  270. endpoint = _api.get_endpoint_for_read(
  271. repo_id=namespace + '/' + dataset_name,
  272. repo_type=REPO_TYPE_DATASET)
  273. dataset_id_on_hub, dataset_type = _api.get_dataset_id_and_type(
  274. dataset_name=dataset_name,
  275. namespace=namespace,
  276. endpoint=endpoint)
  277. # Load from the ModelScope Hub for type=4 (general)
  278. if str(dataset_type) == str(DatasetFormations.general.value):
  279. with load_dataset_with_ctx(
  280. path=namespace + '/' + dataset_name,
  281. name=subset_name,
  282. data_dir=data_dir,
  283. data_files=data_files,
  284. split=split,
  285. cache_dir=cache_dir,
  286. features=features,
  287. download_config=None,
  288. download_mode=download_mode.value,
  289. revision=version,
  290. token=token,
  291. streaming=use_streaming,
  292. dataset_info_only=dataset_info_only,
  293. trust_remote_code=trust_remote_code,
  294. **config_kwargs) as dataset_res:
  295. return dataset_res
  296. else:
  297. remote_dataloader_manager = RemoteDataLoaderManager(
  298. dataset_context_config)
  299. dataset_inst = remote_dataloader_manager.load_dataset(
  300. RemoteDataLoaderType.MS_DATA_LOADER)
  301. dataset_inst = MsDataset.to_ms_dataset(
  302. dataset_inst, target=target)
  303. if isinstance(dataset_inst, MsDataset):
  304. dataset_inst._dataset_context_config = remote_dataloader_manager.dataset_context_config
  305. if custom_cfg:
  306. dataset_inst.to_custom_dataset(
  307. custom_cfg=custom_cfg, **config_kwargs)
  308. dataset_inst.is_custom = True
  309. return dataset_inst
  310. elif hub == Hubs.virgo:
  311. warnings.warn(
  312. 'The option `Hubs.virgo` is deprecated, '
  313. 'will be removed in the future version.', DeprecationWarning)
  314. from modelscope.msdatasets.data_loader.data_loader import VirgoDownloader
  315. from modelscope.utils.constant import VirgoDatasetConfig
  316. # Rewrite the namespace, version and cache_dir for virgo dataset.
  317. if namespace == DEFAULT_DATASET_NAMESPACE:
  318. dataset_context_config.namespace = VirgoDatasetConfig.default_virgo_namespace
  319. if version == DEFAULT_DATASET_REVISION:
  320. dataset_context_config.version = VirgoDatasetConfig.default_dataset_version
  321. if cache_dir == MS_DATASETS_CACHE:
  322. from modelscope.utils.config_ds import CACHE_HOME
  323. cache_dir = os.path.join(CACHE_HOME, 'virgo', 'hub',
  324. 'datasets')
  325. dataset_context_config.cache_root_dir = cache_dir
  326. virgo_downloader = VirgoDownloader(dataset_context_config)
  327. virgo_downloader.process()
  328. return virgo_downloader.dataset
  329. else:
  330. raise 'Please adjust input args to specify a loading mode, we support following scenes: ' \
  331. 'loading from local disk, huggingface hub and modelscope hub.'
  332. @staticmethod
  333. def upload(
  334. object_name: str,
  335. local_file_path: str,
  336. dataset_name: str,
  337. namespace: Optional[str] = DEFAULT_DATASET_NAMESPACE,
  338. version: Optional[str] = DEFAULT_DATASET_REVISION,
  339. num_processes: Optional[int] = None,
  340. chunksize: Optional[int] = 1,
  341. filter_hidden_files: Optional[bool] = True,
  342. upload_mode: Optional[UploadMode] = UploadMode.OVERWRITE) -> None:
  343. r"""
  344. @deprecated
  345. This method is deprecated and may be removed in future releases, please use git command line instead.
  346. """
  347. """Upload dataset file or directory to the ModelScope Hub. Please log in to the ModelScope Hub first.
  348. Args:
  349. object_name (str): The object name on ModelScope, in the form of your-dataset-name.zip or your-dataset-name
  350. local_file_path (str): Local file or directory to upload
  351. dataset_name (str): Name of the dataset
  352. namespace(str, optional): Namespace of the dataset
  353. version: Optional[str]: Version of the dataset
  354. num_processes: Optional[int]: The number of processes used for multiprocess uploading.
  355. This is only applicable when local_file_path is a directory, and we are uploading mutliple-files
  356. inside the directory. When None provided, the number returned by os.cpu_count() is used as default.
  357. chunksize: Optional[int]: The chunksize of objects to upload.
  358. For very long iterables using a large value for chunksize can make the job complete much faster than
  359. using the default value of 1. Available if local_file_path is a directory.
  360. filter_hidden_files: Optional[bool]: Whether to filter hidden files.
  361. Available if local_file_path is a directory.
  362. upload_mode: Optional[UploadMode]: How to upload objects from local. Default: UploadMode.OVERWRITE, upload
  363. all objects from local, existing remote objects may be overwritten.
  364. Returns:
  365. None
  366. """
  367. warnings.warn(
  368. 'The function `upload` is deprecated, '
  369. 'please use git command '
  370. 'or modelscope.hub.api.HubApi.upload_folder '
  371. 'or modelscope.hub.api.HubApi.upload_file.', DeprecationWarning)
  372. if not object_name:
  373. raise ValueError('object_name cannot be empty!')
  374. _upload_manager = DatasetUploadManager(
  375. dataset_name=dataset_name, namespace=namespace, version=version)
  376. upload_mode = UploadMode(upload_mode or UploadMode.OVERWRITE)
  377. if os.path.isfile(local_file_path):
  378. _upload_manager.upload(
  379. object_name=object_name,
  380. local_file_path=local_file_path,
  381. upload_mode=upload_mode)
  382. elif os.path.isdir(local_file_path):
  383. _upload_manager.upload_dir(
  384. object_dir_name=object_name,
  385. local_dir_path=local_file_path,
  386. num_processes=num_processes,
  387. chunksize=chunksize,
  388. filter_hidden_files=filter_hidden_files,
  389. upload_mode=upload_mode)
  390. else:
  391. raise ValueError(
  392. f'{local_file_path} is not a valid file path or directory')
  393. @staticmethod
  394. def clone_meta(dataset_work_dir: str,
  395. dataset_id: str,
  396. revision: Optional[str] = DEFAULT_DATASET_REVISION,
  397. auth_token: Optional[str] = None,
  398. git_path: Optional[str] = None) -> None:
  399. """Clone meta-file of dataset from the ModelScope Hub.
  400. Args:
  401. dataset_work_dir (str): Current git working directory.
  402. dataset_id (str): Dataset id, in the form of your-namespace/your-dataset-name .
  403. revision (str, optional):
  404. revision of the model you want to clone from. Can be any of a branch, tag or commit hash
  405. auth_token (str, optional):
  406. token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter
  407. as the token is already saved when you login the first time, if None, we will use saved token.
  408. git_path (str, optional):
  409. The git command line path, if None, we use 'git'
  410. Returns:
  411. None
  412. """
  413. warnings.warn(
  414. 'The function `clone_meta` is deprecated, please use git command line to clone the repo.',
  415. DeprecationWarning)
  416. _repo = DatasetRepository(
  417. repo_work_dir=dataset_work_dir,
  418. dataset_id=dataset_id,
  419. revision=revision,
  420. auth_token=auth_token,
  421. git_path=git_path)
  422. clone_work_dir = _repo.clone()
  423. if clone_work_dir:
  424. logger.info('Already cloned repo to: {}'.format(clone_work_dir))
  425. else:
  426. logger.warning(
  427. 'Repo dir already exists: {}'.format(clone_work_dir))
  428. @staticmethod
  429. def upload_meta(dataset_work_dir: str,
  430. commit_message: str,
  431. revision: Optional[str] = DEFAULT_DATASET_REVISION,
  432. auth_token: Optional[str] = None,
  433. git_path: Optional[str] = None,
  434. force: bool = False) -> None:
  435. """Upload meta-file of dataset to the ModelScope Hub. Please clone the meta-data from the ModelScope Hub first.
  436. Args:
  437. dataset_work_dir (str): Current working directory.
  438. commit_message (str): Commit message.
  439. revision(`Optional[str]`):
  440. revision of the model you want to clone from. Can be any of a branch, tag or commit hash
  441. auth_token(`Optional[str]`):
  442. token obtained when calling `HubApi.login()`. Usually you can safely ignore the parameter
  443. as the token is already saved when you log in the first time, if None, we will use saved token.
  444. git_path:(`Optional[str]`):
  445. The git command line path, if None, we use 'git'
  446. force (Optional[bool]): whether to use forced-push.
  447. Returns:
  448. None
  449. """
  450. warnings.warn(
  451. 'The function `upload_meta` is deprecated, '
  452. 'please use git command '
  453. 'or CLI `modelscope upload owner_name/repo_name ...`.',
  454. DeprecationWarning)
  455. _repo = DatasetRepository(
  456. repo_work_dir=dataset_work_dir,
  457. dataset_id='',
  458. revision=revision,
  459. auth_token=auth_token,
  460. git_path=git_path)
  461. _repo.push(commit_message=commit_message, branch=revision, force=force)
  462. @staticmethod
  463. def delete(object_name: str,
  464. dataset_name: str,
  465. namespace: Optional[str] = DEFAULT_DATASET_NAMESPACE,
  466. version: Optional[str] = DEFAULT_DATASET_REVISION) -> str:
  467. """ Delete object of dataset. Please log in first and make sure you have permission to manage the dataset.
  468. Args:
  469. object_name (str): The object name of dataset to be deleted. Could be a name of file or directory. If it's
  470. directory, then ends with `/`.
  471. For example: your-data-name.zip, train/001/img_001.png, train/, ...
  472. dataset_name (str): Path or name of the dataset.
  473. namespace(str, optional): Namespace of the dataset.
  474. version (str, optional): Version of the dataset.
  475. Returns:
  476. res_msg (str): Response message.
  477. """
  478. _delete_manager = DatasetDeleteManager(
  479. dataset_name=dataset_name, namespace=namespace, version=version)
  480. resp_msg = _delete_manager.delete(object_name=object_name)
  481. logger.info(f'Object {object_name} successfully removed!')
  482. return resp_msg
  483. def to_torch_dataset(
  484. self,
  485. columns: Union[str, List[str]] = None,
  486. preprocessors: Union[Callable, List[Callable]] = None,
  487. task_name: str = None,
  488. data_config: ConfigDict = None,
  489. to_tensor: bool = True,
  490. **format_kwargs,
  491. ):
  492. """Create a torch.utils.data.Dataset from the MS Dataset. The torch.utils.data.Dataset can be passed to
  493. torch.utils.data.DataLoader.
  494. Args:
  495. preprocessors (Callable or List[Callable], default None): (list of) Preprocessor object used to process
  496. every sample of the dataset. The output type of processors is dict, and each (numeric) field of the dict
  497. will be used as a field of torch.utils.data.Dataset.
  498. columns (str or List[str], default None): Dataset column(s) to be loaded (numeric data only if
  499. `to_tensor` is True). If the preprocessor is None, the arg columns must have at least one column.
  500. If the `preprocessors` is not None, the output fields of processors will also be added.
  501. task_name (str, default None): task name, refer to :obj:`Tasks` for more details
  502. data_config (ConfigDict, default None): config dict for model object.
  503. Attributes of ConfigDict:
  504. `preprocessor` (Callable, List[Callable], optional): preprocessors to deal with dataset
  505. `type` (str): the type of task
  506. `split_config` (dict, optional): get the split config for ExternalDataset
  507. `test_mode` (bool, optional): is test mode or not
  508. to_tensor (bool, default None): whether convert the data types of dataset column(s) to torch.tensor or not.
  509. format_kwargs: A `dict` of arguments to be passed to the `torch.tensor`.
  510. Returns:
  511. :class:`torch.utils.data.Dataset`
  512. """
  513. if not is_torch_available():
  514. raise ImportError(
  515. 'The function to_torch_dataset requires pytorch to be installed'
  516. )
  517. if isinstance(self._hf_ds, ExternalDataset):
  518. data_config.update({'preprocessor': preprocessors})
  519. data_config.update(self._hf_ds.config_kwargs)
  520. return build_custom_dataset(data_config, task_name)
  521. if preprocessors is not None:
  522. return self._to_torch_dataset_with_processors(
  523. preprocessors, columns=columns, to_tensor=to_tensor)
  524. else:
  525. self._hf_ds.reset_format()
  526. self._hf_ds.set_format(
  527. type='torch', columns=columns, format_kwargs=format_kwargs)
  528. return self._hf_ds
  529. def to_tf_dataset(
  530. self,
  531. batch_size: int,
  532. shuffle: bool,
  533. preprocessors: Union[Callable, List[Callable]] = None,
  534. columns: Union[str, List[str]] = None,
  535. collate_fn: Callable = None,
  536. drop_remainder: bool = None,
  537. collate_fn_args: Dict[str, Any] = None,
  538. label_cols: Union[str, List[str]] = None,
  539. prefetch: bool = True,
  540. ):
  541. """Create a tf.data.Dataset from the MS Dataset. This tf.data.Dataset can be passed to tf methods like
  542. model.fit() or model.predict().
  543. Args:
  544. batch_size (int): Number of samples in a single batch.
  545. shuffle(bool): Shuffle the dataset order.
  546. preprocessors (Callable or List[Callable], default None): (list of) Preprocessor object used to process
  547. every sample of the dataset. The output type of processors is dict, and each field of the dict will be
  548. used as a field of the tf.data. Dataset. If the `preprocessors` is None, the `collate_fn`
  549. shouldn't be None.
  550. columns (str or List[str], default None): Dataset column(s) to be loaded. If the preprocessor is None,
  551. the arg columns must have at least one column. If the `preprocessors` is not None, the output fields of
  552. processors will also be added.
  553. collate_fn(Callable, default None): A callable object used to collect lists of samples into a batch. If
  554. the `preprocessors` is None, the `collate_fn` shouldn't be None.
  555. drop_remainder(bool, default None): Drop the last incomplete batch when loading.
  556. collate_fn_args (Dict, optional): A `dict` of arguments to be passed to the`collate_fn`.
  557. label_cols (str or List[str], default None): Dataset column(s) to load as labels.
  558. prefetch (bool, default True): Prefetch data.
  559. Returns:
  560. :class:`tf.data.Dataset`
  561. """
  562. if not is_tf_available():
  563. raise ImportError(
  564. 'The function to_tf_dataset requires Tensorflow to be installed.'
  565. )
  566. if preprocessors is not None:
  567. return self._to_tf_dataset_with_processors(
  568. batch_size,
  569. shuffle,
  570. preprocessors,
  571. drop_remainder=drop_remainder,
  572. prefetch=prefetch,
  573. label_cols=label_cols,
  574. columns=columns)
  575. if collate_fn is None:
  576. logger.error(
  577. 'The `preprocessors` and the `collate_fn` should`t be both None.'
  578. )
  579. return None
  580. self._hf_ds.reset_format()
  581. return self._hf_ds.to_tf_dataset(
  582. columns,
  583. batch_size,
  584. shuffle,
  585. collate_fn,
  586. drop_remainder=drop_remainder,
  587. collate_fn_args=collate_fn_args,
  588. label_cols=label_cols,
  589. prefetch=prefetch)
  590. def to_hf_dataset(self) -> Dataset:
  591. self._hf_ds.reset_format()
  592. return self._hf_ds
  593. def remap_columns(self, column_mapping: Dict[str, str]) -> Dataset:
  594. """
  595. Rename columns and return the underlying hf dataset directly
  596. TODO: support native MsDataset column rename.
  597. Args:
  598. column_mapping: the mapping of the original and new column names
  599. Returns:
  600. underlying hf dataset
  601. """
  602. self._hf_ds.reset_format()
  603. return self._hf_ds.rename_columns(column_mapping)
  604. def _to_torch_dataset_with_processors(
  605. self,
  606. preprocessors: Union[Callable, List[Callable]],
  607. columns: Union[str, List[str]] = None,
  608. to_tensor: bool = True,
  609. ):
  610. preprocessor_list = preprocessors if isinstance(
  611. preprocessors, list) else [preprocessors]
  612. columns = format_list(columns)
  613. columns = [
  614. key for key in self._hf_ds.features.keys() if key in columns
  615. ]
  616. retained_numeric_columns = []
  617. retained_unumeric_columns = []
  618. if to_tensor:
  619. sample = next(iter(self._hf_ds))
  620. sample_res = {k: np.array(sample[k]) for k in columns}
  621. for processor in preprocessor_list:
  622. sample_res.update(
  623. {k: np.array(v)
  624. for k, v in processor(sample).items()})
  625. def is_numpy_number(value):
  626. return np.issubdtype(value.dtype, np.integer) or np.issubdtype(
  627. value.dtype, np.floating)
  628. for k in sample_res.keys():
  629. if not is_numpy_number(sample_res[k]):
  630. logger.warning(
  631. f'Data of column {k} is non-numeric, will be removed')
  632. retained_unumeric_columns.append(k)
  633. continue
  634. retained_numeric_columns.append(k)
  635. import torch
  636. class MsMapDataset(torch.utils.data.Dataset):
  637. def __init__(self, dataset: Iterable, preprocessor_list,
  638. retained_numeric_columns, retained_unumeric_columns,
  639. columns, to_tensor):
  640. super(MsDataset).__init__()
  641. self.dataset = dataset
  642. self.preprocessor_list = preprocessor_list
  643. self.to_tensor = to_tensor
  644. self.retained_numeric_columns = retained_numeric_columns
  645. self.retained_unumeric_columns = retained_unumeric_columns
  646. self.columns = columns
  647. def __len__(self):
  648. return len(self.dataset)
  649. def type_converter(self, x):
  650. if self.to_tensor:
  651. return torch.as_tensor(x)
  652. else:
  653. return x
  654. def __getitem__(self, index):
  655. item_dict = self.dataset[index]
  656. res = {
  657. k: self.type_converter(item_dict[k])
  658. for k in self.columns if (not self.to_tensor)
  659. or k in self.retained_numeric_columns
  660. }
  661. for preprocessor in self.preprocessor_list:
  662. for k, v in preprocessor(item_dict).items():
  663. if (not self.to_tensor) or \
  664. k in self.retained_numeric_columns:
  665. res[k] = self.type_converter(v)
  666. elif k in self.retained_unumeric_columns:
  667. res[k] = v
  668. return res
  669. return MsMapDataset(self._hf_ds, preprocessor_list,
  670. retained_numeric_columns,
  671. retained_unumeric_columns, columns, to_tensor)
  672. def _to_tf_dataset_with_processors(
  673. self,
  674. batch_size: int,
  675. shuffle: bool,
  676. preprocessors: Union[Callable, List[Callable]],
  677. drop_remainder: bool = None,
  678. prefetch: bool = True,
  679. label_cols: Union[str, List[str]] = None,
  680. columns: Union[str, List[str]] = None,
  681. ):
  682. preprocessor_list = preprocessors if isinstance(
  683. preprocessors, list) else [preprocessors]
  684. label_cols = format_list(label_cols)
  685. columns = format_list(columns)
  686. cols_to_retain = list(set(label_cols + columns))
  687. retained_columns = [
  688. key for key in self._hf_ds.features.keys() if key in cols_to_retain
  689. ]
  690. import tensorflow as tf
  691. tf_dataset = tf.data.Dataset.from_tensor_slices(
  692. np.arange(len(self._hf_ds), dtype=np.int64))
  693. if shuffle:
  694. tf_dataset = tf_dataset.shuffle(buffer_size=len(self._hf_ds))
  695. def func(i, return_dict=False):
  696. i = int(i)
  697. res = {k: np.array(self._hf_ds[i][k]) for k in retained_columns}
  698. for preprocessor in preprocessor_list:
  699. # TODO preprocessor output may have the same key
  700. res.update({
  701. k: np.array(v)
  702. for k, v in preprocessor(self._hf_ds[i]).items()
  703. })
  704. if return_dict:
  705. return res
  706. return tuple(list(res.values()))
  707. sample_res = func(0, True)
  708. @tf.function(input_signature=[tf.TensorSpec(None, tf.int64)])
  709. def fetch_function(i):
  710. output = tf.numpy_function(
  711. func,
  712. inp=[i],
  713. Tout=[
  714. tf.dtypes.as_dtype(val.dtype)
  715. for val in sample_res.values()
  716. ],
  717. )
  718. return {key: output[i] for i, key in enumerate(sample_res)}
  719. from tensorflow.data.experimental import AUTOTUNE
  720. tf_dataset = tf_dataset.map(
  721. fetch_function, num_parallel_calls=AUTOTUNE)
  722. if label_cols:
  723. def split_features_and_labels(input_batch):
  724. labels = {
  725. key: tensor
  726. for key, tensor in input_batch.items() if key in label_cols
  727. }
  728. if len(input_batch) == 1:
  729. input_batch = next(iter(input_batch.values()))
  730. if len(labels) == 1:
  731. labels = next(iter(labels.values()))
  732. return input_batch, labels
  733. tf_dataset = tf_dataset.map(split_features_and_labels)
  734. elif len(columns) == 1:
  735. tf_dataset = tf_dataset.map(lambda x: next(iter(x.values())))
  736. if batch_size > 1:
  737. tf_dataset = tf_dataset.batch(
  738. batch_size, drop_remainder=drop_remainder)
  739. if prefetch:
  740. tf_dataset = tf_dataset.prefetch(AUTOTUNE)
  741. return tf_dataset
  742. def to_custom_dataset(self,
  743. custom_cfg: Config,
  744. preprocessor=None,
  745. mode=None,
  746. **kwargs):
  747. """Convert the input datasets to specific custom datasets by given model configuration and preprocessor.
  748. Args:
  749. custom_cfg (Config): The model configuration for custom datasets.
  750. preprocessor (Preprocessor, Optional): Preprocessor for data samples.
  751. mode (str, Optional): See modelscope.utils.constant.ModeKeys
  752. Returns:
  753. `MsDataset`
  754. """
  755. if not is_torch_available():
  756. raise ImportError(
  757. 'The function to_custom_dataset requires pytorch to be installed'
  758. )
  759. if not custom_cfg:
  760. return
  761. # Set the flag that it has been converted to custom dataset
  762. self.is_custom = True
  763. # Check mode
  764. if mode is None:
  765. if 'mode' in kwargs:
  766. mode = kwargs.get('mode')
  767. # Parse cfg
  768. ds_cfg_key = 'train' if mode == ModeKeys.TRAIN else 'val'
  769. data_cfg = custom_cfg.safe_get(f'dataset.{ds_cfg_key}')
  770. if data_cfg is None:
  771. data_cfg = ConfigDict(type=custom_cfg.model.type) if hasattr(
  772. custom_cfg, ConfigFields.model) else ConfigDict(type=None)
  773. data_cfg.update(dict(mode=mode))
  774. # Get preprocessors from custom_cfg
  775. task_name = custom_cfg.task
  776. if 'task' in kwargs:
  777. task_name = kwargs.pop('task')
  778. field_name = Tasks.find_field_by_task(task_name)
  779. if 'field' in kwargs:
  780. field_name = kwargs.pop('field')
  781. if preprocessor is None and hasattr(custom_cfg, 'preprocessor'):
  782. preprocessor_cfg = custom_cfg.preprocessor
  783. if preprocessor_cfg:
  784. preprocessor = build_preprocessor(preprocessor_cfg, field_name)
  785. # Build custom dataset
  786. if isinstance(self._hf_ds, ExternalDataset):
  787. data_cfg.update(dict(preprocessor=preprocessor))
  788. data_cfg.update(self._hf_ds.config_kwargs)
  789. self._hf_ds = build_custom_dataset(
  790. cfg=data_cfg, task_name=custom_cfg.task)
  791. return
  792. if preprocessor is not None:
  793. to_tensor = kwargs.get('to_tensor', True)
  794. self._hf_ds = self._to_torch_dataset_with_processors(
  795. preprocessors=preprocessor, to_tensor=to_tensor)
  796. else:
  797. self._hf_ds.reset_format()
  798. self._hf_ds.set_format(type='torch')
  799. return