| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os
- from abc import ABC, abstractmethod
- from typing import Optional, Union
- from datasets import (Dataset, DatasetBuilder, DatasetDict, IterableDataset,
- IterableDatasetDict)
- from datasets import load_dataset as hf_load_dataset
- from modelscope.hub.api import ModelScopeConfig
- from modelscope.msdatasets.auth.auth_config import OssAuthConfig
- from modelscope.msdatasets.context.dataset_context_config import \
- DatasetContextConfig
- from modelscope.msdatasets.data_files.data_files_manager import \
- DataFilesManager
- from modelscope.msdatasets.dataset_cls import ExternalDataset
- from modelscope.msdatasets.meta.data_meta_manager import DataMetaManager
- from modelscope.utils.constant import (DatasetFormations, DatasetPathName,
- DownloadMode, VirgoDatasetConfig)
- from modelscope.utils.logger import get_logger
- from modelscope.utils.url_utils import valid_url
- logger = get_logger()
class BaseDownloader(ABC):
    """Abstract interface shared by all dataset downloaders.

    A concrete downloader must implement the public ``process`` entry point
    and the four private stages ``_authorize``, ``_build``,
    ``_prepare_and_download`` and ``_post_process``.
    """

    def __init__(self, dataset_context_config: DatasetContextConfig):
        """Keep a reference to the context config the stages operate on."""
        self.dataset_context_config = dataset_context_config

    @abstractmethod
    def process(self):
        """Run the whole pipeline that fetches the data."""
        message = f'No default implementation provided for {BaseDownloader.__name__}.process.'
        raise NotImplementedError(message)

    @abstractmethod
    def _authorize(self):
        """Authorization stage; must be overridden by subclasses."""
        message = f'No default implementation provided for {BaseDownloader.__name__}._authorize.'
        raise NotImplementedError(message)

    @abstractmethod
    def _build(self):
        """Build stage; must be overridden by subclasses."""
        message = f'No default implementation provided for {BaseDownloader.__name__}._build.'
        raise NotImplementedError(message)

    @abstractmethod
    def _prepare_and_download(self):
        """Download stage; must be overridden by subclasses."""
        message = f'No default implementation provided for {BaseDownloader.__name__}._prepare_and_download.'
        raise NotImplementedError(message)

    @abstractmethod
    def _post_process(self):
        """Post-processing stage; must be overridden by subclasses."""
        message = f'No default implementation provided for {BaseDownloader.__name__}._post_process.'
        raise NotImplementedError(message)
class OssDownloader(BaseDownloader):
    """Downloader that fetches dataset files from the modelscope dataset-hub.

    The stages run in a fixed order (see ``process``) and each stage reads
    and updates ``self.dataset_context_config``.
    """

    def __init__(self, dataset_context_config: DatasetContextConfig):
        super().__init__(dataset_context_config)
        # Not assigned anywhere else in this class; kept for compatibility
        # with external code that may read it.
        self.data_files_builder: Optional[DataFilesManager] = None
        # Result of the download stage.
        self.dataset: Optional[Union[Dataset, IterableDataset, DatasetDict,
                                     IterableDatasetDict,
                                     ExternalDataset]] = None
        # Both are populated by `_build`.
        self.builder: Optional[DatasetBuilder] = None
        self.data_files_manager: Optional[DataFilesManager] = None

    def process(self) -> None:
        """ Sequential data fetching process: authorize -> build -> prepare_and_download -> post_process,
        to keep dataset_context_config updated. """
        self._authorize()
        self._build()
        self._prepare_and_download()
        self._post_process()

    def _authorize(self) -> None:
        """ Authorization of target dataset.
        Get credentials from cache and send to the modelscope-hub in the future. """
        cookies = ModelScopeConfig.get_cookies()
        git_token = ModelScopeConfig.get_token()
        user_info = ModelScopeConfig.get_user_info()

        auth_config = self.dataset_context_config.auth_config
        if not auth_config:
            auth_config = OssAuthConfig(
                cookies=cookies, git_token=git_token, user_info=user_info)
        else:
            # Refresh the credentials on the existing auth config.
            auth_config.cookies = cookies
            auth_config.git_token = git_token
            auth_config.user_info = user_info

        self.dataset_context_config.auth_config = auth_config

    def _build(self) -> None:
        """ Sequential data files building process: build_meta -> build_data_files , to keep context_config updated. """
        # Build meta data
        meta_manager = DataMetaManager(self.dataset_context_config)
        meta_manager.fetch_meta_files()
        meta_manager.parse_dataset_structure()
        self.dataset_context_config = meta_manager.dataset_context_config

        # Build data-files manager
        self.data_files_manager = DataFilesManager(
            dataset_context_config=self.dataset_context_config)
        self.builder = self.data_files_manager.get_data_files_builder()

    def _prepare_and_download(self) -> None:
        """ Fetch data-files from modelscope dataset-hub.

        Raises:
            FileNotFoundError: if there is neither a dataset builder nor a
                dataset python script to load the data with.
        """
        context = self.dataset_context_config
        dataset_py_script = context.data_meta_config.dataset_py_script
        dataset_formation = context.data_meta_config.dataset_formation
        dataset_name = context.dataset_name
        subset_name = context.subset_name
        version = context.version
        split = context.split
        data_dir = context.data_dir
        data_files = context.data_files
        cache_dir = context.cache_root_dir
        download_mode = context.download_mode
        input_kwargs = context.config_kwargs
        trust_remote_code = context.trust_remote_code

        if self.builder is None and not dataset_py_script:
            # Raising a bare string is a TypeError in Python 3; raise a real
            # exception so that the message reaches the caller.
            raise FileNotFoundError(
                f'meta-file: {dataset_name}.py not found on the modelscope hub.'
            )

        if dataset_py_script and dataset_formation == DatasetFormations.hf_compatible:
            if trust_remote_code:
                logger.warning(
                    f'Use trust_remote_code=True. Will invoke codes from {dataset_name}. Please make '
                    'sure that you can trust the external codes.')

            # Load through the `datasets` library using the dataset script.
            self.dataset = hf_load_dataset(
                dataset_py_script,
                name=subset_name,
                revision=version,
                split=split,
                data_dir=data_dir,
                data_files=data_files,
                cache_dir=cache_dir,
                download_mode=download_mode.value,
                trust_remote_code=trust_remote_code,
                **input_kwargs)
        else:
            self.dataset = self.data_files_manager.fetch_data_files(
                self.builder)

    def _post_process(self) -> None:
        """Attach the meta type map to the dataset when it is an ExternalDataset."""
        if isinstance(self.dataset, ExternalDataset):
            self.dataset.custom_map = self.dataset_context_config.data_meta_config.meta_type_map
class VirgoDownloader(BaseDownloader):
    """Data downloader for Virgo data source."""

    def __init__(self, dataset_context_config: DatasetContextConfig):
        super().__init__(dataset_context_config)
        # Populated by `_build` with a VirgoDataset instance.
        self.dataset = None

    def process(self):
        """
        Sequential data fetching virgo dataset process: authorize -> build -> prepare_and_download -> post_process
        """
        self._authorize()
        self._build()
        self._prepare_and_download()
        self._post_process()

    def _authorize(self):
        """Authorization of virgo dataset."""
        from modelscope.msdatasets.auth.auth_config import VirgoAuthConfig

        cookies = ModelScopeConfig.get_cookies()
        user_info = ModelScopeConfig.get_user_info()

        auth_config = self.dataset_context_config.auth_config
        if not auth_config:
            auth_config = VirgoAuthConfig(
                cookies=cookies, git_token='', user_info=user_info)
        else:
            # Refresh the credentials on the existing auth config; the git
            # token is always set to an empty string here.
            auth_config.cookies = cookies
            auth_config.git_token = ''
            auth_config.user_info = user_info

        self.dataset_context_config.auth_config = auth_config

    def _build(self):
        """
        Fetch virgo meta and build virgo dataset.

        When the dataset meta is a pandas DataFrame it is also written to
        ``<cache_root>/<namespace>/<dataset_name>/<version>/<META_NAME>/meta_content.csv``
        and the cache locations are recorded on the dataset object.
        """
        from modelscope.msdatasets.dataset_cls.dataset import VirgoDataset
        import pandas as pd

        meta_manager = DataMetaManager(self.dataset_context_config)
        meta_manager.fetch_virgo_meta()
        self.dataset_context_config = meta_manager.dataset_context_config
        self.dataset = VirgoDataset(
            **self.dataset_context_config.config_kwargs)

        virgo_cache_dir = os.path.join(
            self.dataset_context_config.cache_root_dir,
            self.dataset_context_config.namespace,
            self.dataset_context_config.dataset_name,
            self.dataset_context_config.version)
        os.makedirs(
            os.path.join(virgo_cache_dir, DatasetPathName.META_NAME),
            exist_ok=True)
        meta_content_cache_file = os.path.join(virgo_cache_dir,
                                               DatasetPathName.META_NAME,
                                               'meta_content.csv')

        if isinstance(self.dataset.meta, pd.DataFrame):
            meta_content_df = self.dataset.meta
            meta_content_df.to_csv(meta_content_cache_file, index=False)
            self.dataset.meta_content_cache_file = meta_content_cache_file
            self.dataset.virgo_cache_dir = virgo_cache_dir
            logger.info(
                f'Virgo meta content saved to {meta_content_cache_file}')

    @staticmethod
    def _resolve_virgo_file_paths(meta_info_val, data_dir):
        """Map one meta-info JSON string to ``(url, local_path)`` pairs.

        The top-level ``url`` is used when present; otherwise the ``url`` of
        every entry in ``inner_url`` is used. The local file name is the
        basename of each url's own path.

        Args:
            meta_info_val: JSON string holding the meta info of one row.
            data_dir: Directory the local paths are placed in.

        Returns:
            A list of ``(url, local_path)`` tuples. Any parsing or validation
            error is logged and yields an empty list.
        """
        import json
        from urllib.parse import urlparse

        try:
            meta_info = json.loads(meta_info_val)
            # get url first, if not exist, try to get inner_url
            file_url = meta_info.get('url', '')
            if file_url:
                file_urls = [file_url]
            else:
                inner_members = meta_info.get('inner_url') or []
                file_urls = [
                    item.get('url', '') for item in inner_members
                    if item.get('url', '')
                ]

            file_path_list = []
            for one_file_url in file_urls:
                if not valid_url(one_file_url):
                    raise ValueError(f'Unsupported url: {one_file_url}')
                # Derive the name from the url being processed, not from the
                # last url seen while collecting the candidates.
                file_name = os.path.basename(urlparse(one_file_url).path)
                file_path_list.append(
                    (one_file_url, os.path.join(data_dir, file_name)))
        except Exception as e:
            logger.error(f'parse virgo meta info error: {e}')
            return []
        return file_path_list

    @classmethod
    def _download_virgo_files(cls, meta_info_val, data_dir):
        """Download the files referenced by one meta-info value into data_dir.

        Files whose local path already exists are not downloaded again.

        Args:
            meta_info_val: JSON string holding the meta info of one row.
            data_dir: Directory the files are stored in.

        Returns:
            The list of ``(url, local_path)`` tuples resolved for this row.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        import requests

        file_path_list = cls._resolve_virgo_file_paths(meta_info_val, data_dir)
        for file_url_item, file_path_item in file_path_list:
            if file_path_item and not os.path.exists(file_path_item):
                logger.info(f'Downloading file to {file_path_item}')
                os.makedirs(data_dir, exist_ok=True)
                # Fetch before opening the target so that a failed request
                # does not leave an empty file that would later be treated
                # as already downloaded.
                response = requests.get(file_url_item)
                response.raise_for_status()
                with open(file_path_item, 'wb') as f:
                    f.write(response.content)
        return file_path_list

    def _prepare_and_download(self):
        """
        Fetch data-files from oss-urls in the virgo meta content.

        Runs only when the dataset's ``data_type`` is 0 and a truthy
        ``download_virgo_files`` entry was present in ``config_kwargs``
        (the entry is removed from ``config_kwargs`` either way).
        """
        download_virgo_files = self.dataset_context_config.config_kwargs.pop(
            'download_virgo_files', '')

        if self.dataset.data_type == 0 and download_virgo_files:
            import shutil

            self.dataset.download_virgo_files = True
            download_mode = self.dataset_context_config.download_mode
            data_files_dir = os.path.join(self.dataset.virgo_cache_dir,
                                          DatasetPathName.DATA_FILES_NAME)

            # Drop previously downloaded files when a re-download is forced.
            if download_mode == DownloadMode.FORCE_REDOWNLOAD:
                shutil.rmtree(data_files_dir, ignore_errors=True)

            from tqdm.auto import tqdm
            # Registers `progress_apply` on pandas objects.
            tqdm.pandas(desc='apply download_file')
            self.dataset.meta[VirgoDatasetConfig.col_cache_file] = \
                self.dataset.meta.progress_apply(
                    lambda row: self._download_virgo_files(
                        row.meta_info, data_files_dir),
                    axis=1)

    def _post_process(self):
        """No post-processing is performed for virgo datasets."""
        ...
class MaxComputeDownloader(BaseDownloader):
    """Data downloader for MaxCompute data source."""

    # TODO: MaxCompute data source to be supported .

    def __init__(self, dataset_context_config: DatasetContextConfig):
        super().__init__(dataset_context_config)
        self.dataset = None

    def process(self):
        """Placeholder; currently does nothing."""
        pass

    def _authorize(self):
        """Placeholder; currently does nothing."""
        pass

    def _build(self):
        """Placeholder; currently does nothing."""
        pass

    def _prepare_and_download(self):
        """Placeholder; currently does nothing."""
        pass

    def _post_process(self):
        """Placeholder; currently does nothing."""
        pass
|