dataset.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import copy
  3. import math
  4. import os
  5. from itertools import islice
  6. import datasets
  7. import pandas as pd
  8. from datasets import IterableDataset
  9. from tqdm.auto import tqdm
  10. from modelscope.msdatasets.utils.maxcompute_utils import MaxComputeUtil
  11. from modelscope.utils.constant import (DEFAULT_MAXCOMPUTE_ENDPOINT,
  12. EXTENSIONS_TO_LOAD, MaxComputeEnvs,
  13. VirgoDatasetConfig)
  14. from modelscope.utils.logger import get_logger
  15. from modelscope.utils.url_utils import fetch_csv_with_url, valid_url
# Module-level logger used by the dataset classes defined in this file.
logger = get_logger()
  17. class ExternalDataset(object):
  18. """Dataset class for custom datasets."""
  19. def __init__(self, split_path_dict, config_kwargs):
  20. self.split_path_dict = split_path_dict
  21. self.config_kwargs = copy.deepcopy(config_kwargs)
  22. self.config_kwargs.update({'split_config': self.split_path_dict})
  23. # dataset for specific extensions
  24. self.spec_extension_dataset = None
  25. self.split_data_files = {
  26. k: []
  27. for k, _ in self.split_path_dict.items()
  28. }
  29. self.custom_map = {}
  30. # the extension of file
  31. file_ext = ''
  32. for split_name, split_dir in self.split_path_dict.items():
  33. if isinstance(split_dir, str) and os.path.isdir(split_dir):
  34. split_file_names = os.listdir(split_dir)
  35. set_files_exts = set([
  36. os.path.splitext(file_name)[-1].strip('.')
  37. for file_name in split_file_names
  38. ])
  39. if '' in set_files_exts:
  40. continue
  41. # ensure these files have same extensions
  42. if len(set_files_exts) != 1:
  43. supported_exts = ','.join(EXTENSIONS_TO_LOAD.keys())
  44. logger.error(
  45. f'Split-{split_name} has been ignored, please flatten your folder structure, '
  46. f'and make sure these files have same extensions. '
  47. f'Supported extensions: {supported_exts} .')
  48. continue
  49. file_ext = list(set_files_exts)[0]
  50. if file_ext not in EXTENSIONS_TO_LOAD:
  51. continue
  52. split_file_paths = [
  53. os.path.join(split_dir, file_name)
  54. for file_name in split_file_names
  55. ]
  56. self.split_data_files[split_name] = split_file_paths
  57. if file_ext:
  58. file_ext = EXTENSIONS_TO_LOAD.get(file_ext)
  59. self.spec_extension_dataset = datasets.load_dataset(
  60. file_ext, data_files=self.split_data_files, **config_kwargs)
  61. def __len__(self):
  62. return len(
  63. self.split_path_dict
  64. ) if not self.spec_extension_dataset else self.spec_extension_dataset.__len__(
  65. )
  66. def __getitem__(self, item):
  67. if not self.spec_extension_dataset:
  68. return self.split_path_dict.get(item)
  69. else:
  70. return self.spec_extension_dataset.__getitem__(item)
  71. def __iter__(self):
  72. if not self.spec_extension_dataset:
  73. for k, v in self.split_path_dict.items():
  74. yield k, v
  75. else:
  76. for k, v in self.spec_extension_dataset.items():
  77. yield k, v
  78. class NativeIterableDataset(IterableDataset):
  79. """The modelscope iterable dataset class."""
  80. def __init__(self, ex_iterable, info, split, stream_batch_size=1):
  81. super().__init__(ex_iterable=ex_iterable, info=info, split=split)
  82. self.stream_batch_size = stream_batch_size
  83. def __iter__(self):
  84. for item in tqdm(
  85. self.iter(
  86. batch_size=self.stream_batch_size, drop_last_batch=False),
  87. desc='Overall progress',
  88. total=self.n_shards,
  89. dynamic_ncols=True):
  90. ret = self._download_item(item)
  91. yield ret
  92. def __len__(self):
  93. return self.n_shards
  94. def __getitem__(self, index):
  95. """
  96. Returns the item at index `index` in the dataset. Slice indexing is supported.
  97. """
  98. if isinstance(index, int):
  99. start = index
  100. stop = index + 1
  101. step = None
  102. else:
  103. start = index.start
  104. stop = index.stop
  105. step = index.step
  106. if step is not None and step <= 0:
  107. raise ValueError('step must be positive')
  108. for item in tqdm(
  109. islice(
  110. self.iter(batch_size=1, drop_last_batch=False), start,
  111. stop, step),
  112. desc='Slicing progress',
  113. dynamic_ncols=True):
  114. ret = self._download_item(item)
  115. yield ret
  116. def _download_item(self, item):
  117. ret = {}
  118. if isinstance(item, dict):
  119. try:
  120. for k, v in item.items():
  121. ret[k] = v
  122. if k.endswith(':FILE'):
  123. dl_manager = self._ex_iterable.kwargs.get('dl_manager')
  124. ex_cache_path = dl_manager.download_and_extract(v)
  125. if isinstance(ex_cache_path, str):
  126. ex_cache_path = [ex_cache_path]
  127. ret[k] = ex_cache_path
  128. ret[k.strip(':FILE')] = v
  129. except Exception as e:
  130. logger.error(e)
  131. ret = item
  132. else:
  133. ret = item
  134. return ret
  135. def head(self, n=5):
  136. """
  137. Returns the first n rows of the dataset.
  138. Args:
  139. n (int): Number of rows to return.
  140. Returns:
  141. list: The list of results, e.g. [{'id': 'abc123', 'text': 'hello world'}, ...]
  142. """
  143. # return self._head(n=n)
  144. res = []
  145. if n <= 0:
  146. return res
  147. iter_num = 0
  148. for item in self.__iter__():
  149. if iter_num >= n:
  150. break
  151. res.append(item)
  152. iter_num += 1
  153. return res
  154. class VirgoDataset(object):
  155. """Dataset class for Virgo.
  156. Attributes:
  157. _meta_content (str): Virgo meta data content, could be a url that contains csv file.
  158. _data_type (int): Virgo dataset type, 0-Standard virgo dataset; Others-User define dataset (to be supported)
  159. Examples:
  160. >>> from modelscope.msdatasets.dataset_cls.dataset import VirgoDataset
  161. >>> input_kwargs = {'metaContent': 'http://xxx-xxx/xxx.csv', 'samplingType': 0}
  162. >>> virgo_dataset = VirgoDataset(**input_kwargs)
  163. >>> print(virgo_dataset[1])
  164. >>> print(len(virgo_dataset))
  165. >>> for line in virgo_dataset:
  166. >>> print(line)
  167. Note: If you set `download_virgo_files` to True by using
  168. MsDataset.load(dataset_name='your-virgo-dataset-id', hub=Hubs.virgo, download_virgo_files=True),
  169. you can get the cache file path of the virgo dataset, the column name is `cache_file`.
  170. >>> if virgo_dataset.download_virgo_files:
  171. >>> print(virgo_dataset[1].get('cache_file'))
  172. """
  173. def __init__(self, **kwargs):
  174. self._meta_content: str = ''
  175. self.data_type: int = 0
  176. self.odps_table_name: str = ''
  177. self.odps_table_partition: str = None
  178. self._odps_utils: MaxComputeUtil = None
  179. self.config_kwargs = kwargs
  180. self._meta: pd.DataFrame = pd.DataFrame()
  181. self._meta_content = self.config_kwargs.pop(
  182. VirgoDatasetConfig.meta_content, '')
  183. self.data_type = self.config_kwargs.pop(
  184. VirgoDatasetConfig.sampling_type, 0)
  185. self._check_variables()
  186. self._parse_meta()
  187. self.meta_content_cache_file = ''
  188. self.virgo_cache_dir = ''
  189. self.download_virgo_files: bool = False
  190. self.odps_table_ins = None
  191. self.odps_reader_ins = None
  192. self.odps_batch_size = self.config_kwargs.pop('odps_batch_size', 100)
  193. self.odps_limit = self.config_kwargs.pop('odps_limit', None)
  194. self.odps_drop_last = self.config_kwargs.pop('odps_drop_last', False)
  195. if self._odps_utils:
  196. self.odps_table_ins, self.odps_reader_ins = self._odps_utils.get_table_reader_ins(
  197. self.odps_table_name, self.odps_table_partition)
  198. def __getitem__(self, index):
  199. if self.odps_reader_ins:
  200. return MaxComputeUtil.gen_reader_item(
  201. reader=self.odps_reader_ins,
  202. index=index,
  203. batch_size_in=self.odps_batch_size,
  204. limit_in=self.odps_limit,
  205. drop_last_in=self.odps_drop_last,
  206. partitions=self.odps_table_ins.table_schema.partitions,
  207. columns=self.odps_table_ins.table_schema.names)
  208. return self._meta.iloc[index].to_dict()
  209. def __len__(self):
  210. if isinstance(self._meta, dict):
  211. return self._meta.get('odpsCount', 0)
  212. return len(self._meta)
  213. def __iter__(self):
  214. if self.odps_reader_ins:
  215. odps_batch_data = MaxComputeUtil.gen_reader_batch(
  216. reader=self.odps_reader_ins,
  217. batch_size_in=self.odps_batch_size,
  218. limit_in=self.odps_limit,
  219. drop_last_in=self.odps_drop_last,
  220. partitions=self.odps_table_ins.table_schema.partitions,
  221. columns=self.odps_table_ins.table_schema.names)
  222. for batch in odps_batch_data:
  223. yield batch
  224. else:
  225. for _, row in self._meta.iterrows():
  226. yield row.to_dict()
  227. @property
  228. def meta(self) -> pd.DataFrame:
  229. """
  230. Virgo meta data. Contains columns: id, meta_info, analysis_result, external_info and
  231. cache_file (if download_virgo_files is True).
  232. """
  233. return self._meta
  234. def _parse_meta(self):
  235. # Fetch csv content
  236. if isinstance(self._meta_content, str) and valid_url(
  237. self._meta_content):
  238. meta_content_df = fetch_csv_with_url(self._meta_content)
  239. self._meta = meta_content_df
  240. elif isinstance(self._meta_content, dict):
  241. self._meta = self._meta_content
  242. self.odps_table_name = self._meta.get('odpsTableName', '')
  243. self.odps_table_partition = self._meta.get('odpsTablePartition',
  244. None)
  245. self._odps_utils = self._get_odps_info()
  246. else:
  247. raise 'The meta content must be url or dict.'
  248. @staticmethod
  249. def _get_odps_info() -> MaxComputeUtil:
  250. """
  251. Get MaxComputeUtil instance.
  252. Args:
  253. None
  254. Returns:
  255. MaxComputeUtil instance.
  256. """
  257. access_id = os.environ.get(MaxComputeEnvs.ACCESS_ID, '')
  258. access_key = os.environ.get(MaxComputeEnvs.ACCESS_SECRET_KEY, '')
  259. proj_name = os.environ.get(MaxComputeEnvs.PROJECT_NAME, '')
  260. endpoint = os.environ.get(MaxComputeEnvs.ENDPOINT,
  261. DEFAULT_MAXCOMPUTE_ENDPOINT)
  262. if not access_id or not access_key or not proj_name:
  263. raise ValueError(
  264. f'Please set MaxCompute envs for Virgo: {MaxComputeEnvs.ACCESS_ID}, '
  265. f'{MaxComputeEnvs.ACCESS_SECRET_KEY}, {MaxComputeEnvs.PROJECT_NAME}, '
  266. f'{MaxComputeEnvs.ENDPOINT}(default: http://service-corp.odps.aliyun-inc.com/api)'
  267. )
  268. return MaxComputeUtil(access_id, access_key, proj_name, endpoint)
  269. def _check_variables(self):
  270. """Check member variables in this class.
  271. 1. Condition-1: self._meta_content cannot be empty
  272. 2. Condition-2: self._meta_content must be url when self._data_type is 0
  273. """
  274. if not self._meta_content:
  275. raise 'Them meta content cannot be empty.'
  276. if self.data_type not in [0, 1]:
  277. raise 'Supported samplingType should be 0 or 1, others are not supported yet.'
  278. if self.data_type == 0 and not valid_url(self._meta_content):
  279. raise 'The meta content must be url when data type is 0.'
  280. if self.data_type == 1 and not isinstance(self._meta_content, dict):
  281. raise 'The meta content must be dict when data type is 1.'