# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from abc import ABC, abstractmethod
from typing import Optional, Union

from datasets import (Dataset, DatasetBuilder, DatasetDict, IterableDataset,
                      IterableDatasetDict)
from datasets import load_dataset as hf_load_dataset

from modelscope.hub.api import ModelScopeConfig
from modelscope.msdatasets.auth.auth_config import OssAuthConfig
from modelscope.msdatasets.context.dataset_context_config import \
    DatasetContextConfig
from modelscope.msdatasets.data_files.data_files_manager import \
    DataFilesManager
from modelscope.msdatasets.dataset_cls import ExternalDataset
from modelscope.msdatasets.meta.data_meta_manager import DataMetaManager
from modelscope.utils.constant import (DatasetFormations, DatasetPathName,
                                       DownloadMode, VirgoDatasetConfig)
from modelscope.utils.logger import get_logger
from modelscope.utils.url_utils import valid_url

logger = get_logger()


class BaseDownloader(ABC):
    """Base dataset downloader to load data.

    Declares the four pipeline steps (``_authorize``, ``_build``,
    ``_prepare_and_download``, ``_post_process``) plus the ``process``
    entry point; the concrete downloaders below run the steps in that order
    from ``process``.
    """

    def __init__(self, dataset_context_config: DatasetContextConfig):
        # Context object shared by (and updated through) the pipeline steps.
        self.dataset_context_config = dataset_context_config

    @abstractmethod
    def process(self):
        """The entity processing pipeline for fetching the data."""
        raise NotImplementedError(
            f'No default implementation provided for {BaseDownloader.__name__}.process.'
        )

    @abstractmethod
    def _authorize(self):
        """Attach credentials to ``dataset_context_config``."""
        raise NotImplementedError(
            f'No default implementation provided for {BaseDownloader.__name__}._authorize.'
        )

    @abstractmethod
    def _build(self):
        """Fetch meta information and prepare whatever is needed to load."""
        raise NotImplementedError(
            f'No default implementation provided for {BaseDownloader.__name__}._build.'
        )

    @abstractmethod
    def _prepare_and_download(self):
        """Download / load the actual data."""
        raise NotImplementedError(
            f'No default implementation provided for {BaseDownloader.__name__}._prepare_and_download.'
        )

    @abstractmethod
    def _post_process(self):
        """Final adjustments on the loaded dataset."""
        raise NotImplementedError(
            f'No default implementation provided for {BaseDownloader.__name__}._post_process.'
        )


class OssDownloader(BaseDownloader):
    """Downloader for datasets hosted on the modelscope dataset-hub.

    After :meth:`process`, the loaded data is available in ``self.dataset``.
    """

    def __init__(self, dataset_context_config: DatasetContextConfig):
        super().__init__(dataset_context_config)
        # NOTE(review): never assigned elsewhere in this module; kept so that
        # external code reading the attribute keeps working.
        self.data_files_builder: Optional[DataFilesManager] = None
        self.dataset: Optional[Union[Dataset, IterableDataset, DatasetDict,
                                     IterableDatasetDict,
                                     ExternalDataset]] = None
        self.builder: Optional[DatasetBuilder] = None
        self.data_files_manager: Optional[DataFilesManager] = None

    def process(self) -> None:
        """
        Sequential data fetching process: authorize -> build -> prepare_and_download -> post_process,
        to keep dataset_context_config updated.
        """
        self._authorize()
        self._build()
        self._prepare_and_download()
        self._post_process()

    def _authorize(self) -> None:
        """Authorization of target dataset.

        Reads cookies, git token and user info from ``ModelScopeConfig`` and
        stores them on the context's auth config, creating an
        ``OssAuthConfig`` when the context has none yet.
        """
        cookies = ModelScopeConfig.get_cookies()
        git_token = ModelScopeConfig.get_token()
        user_info = ModelScopeConfig.get_user_info()

        if not self.dataset_context_config.auth_config:
            auth_config = OssAuthConfig(
                cookies=cookies, git_token=git_token, user_info=user_info)
        else:
            # Refresh credentials on the caller-supplied auth config.
            auth_config = self.dataset_context_config.auth_config
            auth_config.cookies = cookies
            auth_config.git_token = git_token
            auth_config.user_info = user_info

        self.dataset_context_config.auth_config = auth_config

    def _build(self) -> None:
        """Sequential data files building process: build_meta -> build_data_files.

        Fetches and parses the dataset meta files via ``DataMetaManager``,
        adopts the updated context, then creates a ``DataFilesManager`` and
        stores its data-files builder in ``self.builder``.
        """
        # Build meta data
        meta_manager = DataMetaManager(self.dataset_context_config)
        meta_manager.fetch_meta_files()
        meta_manager.parse_dataset_structure()
        self.dataset_context_config = meta_manager.dataset_context_config

        # Build data-files manager
        self.data_files_manager = DataFilesManager(
            dataset_context_config=self.dataset_context_config)
        self.builder = self.data_files_manager.get_data_files_builder()

    def _prepare_and_download(self) -> None:
        """Fetch data-files from modelscope dataset-hub.

        If the meta provides a dataset script and the formation is
        ``DatasetFormations.hf_compatible``, the script is loaded with
        ``datasets.load_dataset``; otherwise the data files are fetched by
        the data-files manager using ``self.builder``.

        Raises:
            FileNotFoundError: when there is neither a data-files builder nor
                a dataset script for the dataset.
        """
        dataset_py_script = self.dataset_context_config.data_meta_config.dataset_py_script
        dataset_formation = self.dataset_context_config.data_meta_config.dataset_formation
        dataset_name = self.dataset_context_config.dataset_name
        subset_name = self.dataset_context_config.subset_name
        version = self.dataset_context_config.version
        split = self.dataset_context_config.split
        data_dir = self.dataset_context_config.data_dir
        data_files = self.dataset_context_config.data_files
        cache_dir = self.dataset_context_config.cache_root_dir
        download_mode = self.dataset_context_config.download_mode
        input_kwargs = self.dataset_context_config.config_kwargs
        trust_remote_code = self.dataset_context_config.trust_remote_code

        if self.builder is None and not dataset_py_script:
            # A bare string cannot be raised in Python 3; wrap it in a real
            # exception so the intended message reaches the caller.
            raise FileNotFoundError(
                f'meta-file: {dataset_name}.py not found on the modelscope hub.'
            )

        if dataset_py_script and dataset_formation == DatasetFormations.hf_compatible:
            if trust_remote_code:
                logger.warning(
                    f'Use trust_remote_code=True. Will invoke codes from {dataset_name}. Please make '
                    'sure that you can trust the external codes.')
            self.dataset = hf_load_dataset(
                dataset_py_script,
                name=subset_name,
                revision=version,
                split=split,
                data_dir=data_dir,
                data_files=data_files,
                cache_dir=cache_dir,
                download_mode=download_mode.value,
                trust_remote_code=trust_remote_code,
                **input_kwargs)
        else:
            self.dataset = self.data_files_manager.fetch_data_files(
                self.builder)

    def _post_process(self) -> None:
        """Attach the meta type map to ``ExternalDataset`` results."""
        if isinstance(self.dataset, ExternalDataset):
            self.dataset.custom_map = self.dataset_context_config.data_meta_config.meta_type_map


class VirgoDownloader(BaseDownloader):
    """Data downloader for Virgo data source."""

    def __init__(self, dataset_context_config: DatasetContextConfig):
        super().__init__(dataset_context_config)
        # Set to a VirgoDataset by _build.
        self.dataset = None

    def process(self):
        """
        Sequential data fetching virgo dataset process: authorize -> build -> prepare_and_download -> post_process
        """
        self._authorize()
        self._build()
        self._prepare_and_download()
        self._post_process()

    def _authorize(self):
        """Authorization of virgo dataset.

        Same as the OSS flow, but the git token is always set to ``''`` and
        a ``VirgoAuthConfig`` is created when the context has no auth config.
        """
        from modelscope.msdatasets.auth.auth_config import VirgoAuthConfig

        cookies = ModelScopeConfig.get_cookies()
        user_info = ModelScopeConfig.get_user_info()

        if not self.dataset_context_config.auth_config:
            auth_config = VirgoAuthConfig(
                cookies=cookies, git_token='', user_info=user_info)
        else:
            auth_config = self.dataset_context_config.auth_config
            auth_config.cookies = cookies
            auth_config.git_token = ''
            auth_config.user_info = user_info

        self.dataset_context_config.auth_config = auth_config

    def _build(self):
        """Fetch virgo meta and build virgo dataset.

        The dataset is built as ``VirgoDataset(**config_kwargs)``. When its
        ``meta`` is a pandas DataFrame, the meta is written to
        ``<cache_root_dir>/<namespace>/<dataset_name>/<version>/<META_NAME>/meta_content.csv``
        and the CSV path and cache dir are recorded on the dataset.
        """
        from modelscope.msdatasets.dataset_cls.dataset import VirgoDataset
        import pandas as pd

        meta_manager = DataMetaManager(self.dataset_context_config)
        meta_manager.fetch_virgo_meta()
        self.dataset_context_config = meta_manager.dataset_context_config
        self.dataset = VirgoDataset(
            **self.dataset_context_config.config_kwargs)

        virgo_cache_dir = os.path.join(
            self.dataset_context_config.cache_root_dir,
            self.dataset_context_config.namespace,
            self.dataset_context_config.dataset_name,
            self.dataset_context_config.version)
        os.makedirs(
            os.path.join(virgo_cache_dir, DatasetPathName.META_NAME),
            exist_ok=True)
        meta_content_cache_file = os.path.join(virgo_cache_dir,
                                               DatasetPathName.META_NAME,
                                               'meta_content.csv')

        # NOTE(review): virgo_cache_dir is only set on the dataset in this
        # branch, but _prepare_and_download reads it unconditionally.
        if isinstance(self.dataset.meta, pd.DataFrame):
            meta_content_df = self.dataset.meta
            meta_content_df.to_csv(meta_content_cache_file, index=False)
            self.dataset.meta_content_cache_file = meta_content_cache_file
            self.dataset.virgo_cache_dir = virgo_cache_dir
            logger.info(
                f'Virgo meta content saved to {meta_content_cache_file}')

    def _prepare_and_download(self):
        """Fetch data-files from oss-urls in the virgo meta content.

        Pops ``download_virgo_files`` from ``config_kwargs``. When it is
        truthy and ``self.dataset.data_type == 0``, each row's ``meta_info``
        JSON is parsed for a ``url`` (or, failing that, the ``url`` of each
        ``inner_url`` entry). Every such file is downloaded into
        ``<virgo_cache_dir>/<DATA_FILES_NAME>``, and the per-row list of
        ``(url, local_path)`` pairs is stored in the
        ``VirgoDatasetConfig.col_cache_file`` column of ``self.dataset.meta``.
        Rows whose meta info cannot be parsed are logged and get ``[]``.
        """
        download_virgo_files = self.dataset_context_config.config_kwargs.pop(
            'download_virgo_files', '')

        if self.dataset.data_type == 0 and download_virgo_files:
            import requests
            import json
            import shutil
            from urllib.parse import urlparse
            from functools import partial

            def fetch_to_file(url, dst_path):
                """Stream ``url`` into ``dst_path`` via a ``.part`` temp file."""
                tmp_path = dst_path + '.part'
                with requests.get(url, stream=True) as resp:
                    # Never persist an HTTP error page as a data file.
                    resp.raise_for_status()
                    with open(tmp_path, 'wb') as f:
                        for chunk in resp.iter_content(chunk_size=1 << 20):
                            f.write(chunk)
                # Publish atomically: an interrupted download must not leave
                # a truncated file that the exists-check would later skip.
                os.replace(tmp_path, dst_path)

            def download_file(meta_info_val, data_dir):
                """Download the files of one meta row; return (url, path) pairs."""
                file_url_list = []
                file_path_list = []
                try:
                    meta_info_val = json.loads(meta_info_val)
                    # get url first, if not exist, try to get inner_url
                    file_url = meta_info_val.get('url', '')
                    if file_url:
                        file_url_list.append(file_url)
                    else:
                        for item in meta_info_val.get('inner_url', []):
                            inner_url = item.get('url', '')
                            if inner_url:
                                file_url_list.append(inner_url)

                    for one_file_url in file_url_list:
                        if not valid_url(one_file_url):
                            raise ValueError(
                                f'Unsupported url: {one_file_url}')
                        # Name the local file after *this* url so multiple
                        # inner urls do not collapse onto a single path.
                        # NOTE(review): distinct urls sharing a basename
                        # still map to the same local file.
                        file_name = os.path.basename(
                            urlparse(one_file_url).path)
                        file_path = os.path.join(data_dir, file_name)
                        file_path_list.append((one_file_url, file_path))
                except Exception as e:
                    # Best effort: an unparsable row yields no files.
                    logger.error(f'parse virgo meta info error: {e}')
                    file_path_list = []

                for file_url_item, file_path_item in file_path_list:
                    if file_path_item and not os.path.exists(file_path_item):
                        logger.info(f'Downloading file to {file_path_item}')
                        os.makedirs(data_dir, exist_ok=True)
                        fetch_to_file(file_url_item, file_path_item)

                return file_path_list

            self.dataset.download_virgo_files = True
            download_mode = self.dataset_context_config.download_mode
            data_files_dir = os.path.join(self.dataset.virgo_cache_dir,
                                          DatasetPathName.DATA_FILES_NAME)

            if download_mode == DownloadMode.FORCE_REDOWNLOAD:
                shutil.rmtree(data_files_dir, ignore_errors=True)

            from tqdm.auto import tqdm
            tqdm.pandas(desc='apply download_file')
            # Apply over the single column that is needed instead of building
            # a Series per row (and a new partial per row) with axis=1.
            self.dataset.meta[
                VirgoDatasetConfig.col_cache_file] = self.dataset.meta[
                    'meta_info'].progress_apply(
                        partial(download_file, data_dir=data_files_dir))

    def _post_process(self):
        ...
class MaxComputeDownloader(BaseDownloader):
    """Data downloader for MaxCompute data source.

    Placeholder only: every pipeline step is a no-op returning ``None`` and
    ``dataset`` is left as ``None``.
    """

    # TODO: MaxCompute data source to be supported .

    def __init__(self, dataset_context_config: DatasetContextConfig):
        super().__init__(dataset_context_config)
        # Not assigned anywhere else in this class; stays None.
        self.dataset = None

    def process(self):
        """No-op: MaxCompute fetching is not implemented yet."""

    def _authorize(self):
        """No-op: MaxCompute authorization is not implemented yet."""

    def _build(self):
        """No-op: MaxCompute building is not implemented yet."""

    def _prepare_and_download(self):
        """No-op: MaxCompute downloading is not implemented yet."""

    def _post_process(self):
        """No-op: MaxCompute post-processing is not implemented yet."""