dataset_utils.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from collections import defaultdict
  4. from typing import Optional, Union
  5. import pandas as pd
  6. from modelscope.hub.api import HubApi
  7. from modelscope.msdatasets.context.dataset_context_config import \
  8. DatasetContextConfig
  9. from modelscope.utils.constant import DEFAULT_DATASET_REVISION, MetaDataFields
  10. from modelscope.utils.logger import get_logger
# Module-level logger used by the helpers below for info/error messages.
logger = get_logger()
  12. def format_dataset_structure(dataset_structure):
  13. return {
  14. k: v
  15. for k, v in dataset_structure.items()
  16. if (v.get('meta') or v.get('file'))
  17. }
  18. def get_target_dataset_structure(dataset_structure: dict,
  19. subset_name: Optional[str] = None,
  20. split: Optional[str] = None):
  21. """
  22. Args:
  23. dataset_structure (dict): Dataset Structure, like
  24. {
  25. "default":{
  26. "train":{
  27. "meta":"my_train.csv",
  28. "file":"pictures.zip"
  29. }
  30. },
  31. "subsetA":{
  32. "test":{
  33. "meta":"mytest.csv",
  34. "file":"pictures.zip"
  35. }
  36. }
  37. }
  38. subset_name (str, optional): Defining the subset_name of the dataset.
  39. split (str, optional): Which split of the data to load.
  40. Returns:
  41. target_subset_name (str): Name of the chosen subset.
  42. target_dataset_structure (dict): Structure of the chosen split(s), like
  43. {
  44. "test":{
  45. "meta":"mytest.csv",
  46. "file":"pictures.zip"
  47. }
  48. }
  49. """
  50. # verify dataset subset
  51. if (subset_name and subset_name not in dataset_structure) or (
  52. not subset_name and len(dataset_structure.keys()) > 1):
  53. raise ValueError(
  54. f'subset_name {subset_name} not found. Available: {dataset_structure.keys()}'
  55. )
  56. target_subset_name = subset_name
  57. if not subset_name:
  58. target_subset_name = next(iter(dataset_structure.keys()))
  59. logger.info(
  60. f'No subset_name specified, defaulting to the {target_subset_name}'
  61. )
  62. # verify dataset split
  63. target_dataset_structure = format_dataset_structure(
  64. dataset_structure[target_subset_name])
  65. if split and split not in target_dataset_structure:
  66. raise ValueError(
  67. f'split {split} not found. Available: {target_dataset_structure.keys()}'
  68. )
  69. if split:
  70. target_dataset_structure = {split: target_dataset_structure[split]}
  71. return target_subset_name, target_dataset_structure
  72. def list_dataset_objects(hub_api: HubApi, max_limit: int, is_recursive: bool,
  73. dataset_name: str, namespace: str,
  74. version: str) -> list:
  75. """
  76. List all objects for specific dataset.
  77. Args:
  78. hub_api (class HubApi): HubApi instance.
  79. max_limit (int): Max number of objects.
  80. is_recursive (bool): Whether to list objects recursively.
  81. dataset_name (str): Dataset name.
  82. namespace (str): Namespace.
  83. version (str): Dataset version.
  84. Returns:
  85. res (list): List of objects, i.e., ['train/images/001.png', 'train/images/002.png', 'val/images/001.png', ...]
  86. """
  87. res = []
  88. objects = hub_api.list_oss_dataset_objects(
  89. dataset_name=dataset_name,
  90. namespace=namespace,
  91. max_limit=max_limit,
  92. is_recursive=is_recursive,
  93. is_filter_dir=True,
  94. revision=version)
  95. for item in objects:
  96. object_key = item.get('Key')
  97. if not object_key:
  98. continue
  99. res.append(object_key)
  100. return res
  101. def contains_dir(file_map) -> bool:
  102. """
  103. To check whether input contains at least one directory.
  104. Args:
  105. file_map (dict): Structure of data files. e.g., {'train': 'train.zip', 'validation': 'val.zip'}
  106. Returns:
  107. True if input contains at least one directory, False otherwise.
  108. """
  109. res = False
  110. for k, v in file_map.items():
  111. if isinstance(v, str) and not v.endswith('.zip'):
  112. res = True
  113. break
  114. return res
  115. def get_subdir_hash_from_split(split: Union[str, list], version: str) -> str:
  116. if isinstance(split, str):
  117. split = [split]
  118. return os.path.join(version, '_'.join(split))
  119. def get_split_list(split: Union[str, list]) -> list:
  120. """ Unify the split to list-format. """
  121. if isinstance(split, str):
  122. return [split]
  123. elif isinstance(split, list):
  124. return split
  125. else:
  126. raise f'Expected format of split: str or list, but got {type(split)}.'
  127. def get_split_objects_map(file_map, objects):
  128. """
  129. Get the map between dataset split and oss objects.
  130. Args:
  131. file_map (dict): Structure of data files. e.g., {'train': 'train', 'validation': 'val'}, both of train and val
  132. are dirs.
  133. objects (list): List of oss objects. e.g., ['train/001/1_123.png', 'train/001/1_124.png', 'val/003/3_38.png']
  134. Returns:
  135. A map of split-objects. e.g., {'train': ['train/001/1_123.png', 'train/001/1_124.png'],
  136. 'validation':['val/003/3_38.png']}
  137. """
  138. res = {}
  139. for k, v in file_map.items():
  140. res[k] = []
  141. for obj_key in objects:
  142. for k, v in file_map.items():
  143. if obj_key.startswith(v.rstrip('/') + '/'):
  144. res[k].append(obj_key)
  145. return res
def get_dataset_files(subset_split_into: dict,
                      dataset_name: str,
                      namespace: str,
                      context_config: DatasetContextConfig,
                      revision: Optional[str] = DEFAULT_DATASET_REVISION):
    """
    Build the per-split maps of meta-file urls, data files, args and custom types for a dataset.

    Args:
        subset_split_into (dict): Map of split name -> info dict. The keys 'custom', 'meta', 'file' and 'args'
            are read from each info dict.
        dataset_name (str): Dataset name.
        namespace (str): Namespace.
        context_config (DatasetContextConfig): Read for `data_meta_config.meta_cache_dir` and for
            `config_kwargs` (csv 'delimiter', ',' when absent).
        revision (str, optional): Dataset revision, DEFAULT_DATASET_REVISION when not given.

    Return:
        meta_map: Structure of meta files (.csv), the meta file name will be replaced by url, like
           {
               "test": "https://xxx/mytest.csv"
           }
        file_map: Structure of data files (.zip), like
            {
                "test": "pictures.zip"
            }
            For a split flagged as big data the value is the list read from the meta csv instead; when the map
            contains a directory it is replaced by the result of `get_split_objects_map`.
        args_map: The 'args' value of the split info, per split.
        custom_type_map: The 'custom' value of the split info ('' when absent), per split.
    """
    meta_map = defaultdict(dict)
    file_map = defaultdict(dict)
    args_map = defaultdict(dict)
    custom_type_map = defaultdict(dict)
    modelscope_api = HubApi()
    meta_cache_dir = context_config.data_meta_config.meta_cache_dir

    for split, info in subset_split_into.items():
        custom_type_map[split] = info.get('custom', '')
        meta_map[split] = modelscope_api.get_dataset_file_url_origin(
            info.get('meta', ''), dataset_name, namespace, revision)
        # NOTE(review): nesting reconstructed; 'args' is assumed to be recorded only for splits that
        # declare a 'file' - confirm against the upstream file.
        if info.get('file'):
            file_map[split] = info['file']
            args_map[split] = info.get('args')

    objects = []
    # If `big_data` is true, then fetch objects from meta-csv file directly.
    for split, args_dict in args_map.items():
        if args_dict and args_dict.get(MetaDataFields.ARGS_BIG_DATA):
            meta_csv_file_url = meta_map[split]
            meta_csv_file_path = HubApi.fetch_meta_files_from_url(
                meta_csv_file_url, meta_cache_dir)

            csv_delimiter = context_config.config_kwargs.get('delimiter', ',')
            csv_df = pd.read_csv(
                meta_csv_file_path,
                iterator=False,
                delimiter=csv_delimiter,
                escapechar='\\')
            # Use the first column whose name contains ':FILE'; fall back to the first column of the csv
            # (after logging an error) when there is none.
            target_col = csv_df.columns[csv_df.columns.str.contains(
                ':FILE')].to_list()
            if len(target_col) == 0:
                logger.error(
                    f'No column contains ":FILE" in {meta_csv_file_path}.')
                target_col = csv_df.columns[0]
            else:
                target_col = target_col[0]
            # NOTE(review): `objects` is reassigned on every big-data split, so after the loop it holds
            # only the values of the last one processed - confirm this is intended.
            objects = csv_df[target_col].to_list()

            file_map[split] = objects

    # More general but low-efficiency.
    if not objects:
        objects = list_dataset_objects(
            hub_api=modelscope_api,
            max_limit=-1,
            is_recursive=True,
            dataset_name=dataset_name,
            namespace=namespace,
            version=revision)

    # Directory entries are expanded into the list of objects found under them.
    if contains_dir(file_map):
        file_map = get_split_objects_map(file_map, objects)

    return meta_map, file_map, args_map, custom_type_map
  207. if contains_dir(file_map):
  208. file_map = get_split_objects_map(file_map, objects)
  209. return meta_map, file_map, args_map, custom_type_map