# !/usr/bin/env python3
# -*- coding: UTF-8 -*-
################################################################################
#
# Copyright (c) 2023 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
Common utility functions.

Authors: xiangyiqing(xiangyiqing@baidu.com)
Date:    2023/07/24
"""
import tempfile
import sys
import os
import io
import re
import base64
import hashlib
from datetime import datetime, timezone, timedelta
import zipfile
from aistudio_sdk import log
from aistudio_sdk.errors import FileIntegrityError
from aistudio_sdk.config import DEFAULT_MAX_WORKERS
from functools import wraps
from tqdm.auto import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Union, BinaryIO, Optional
from aistudio_sdk.constant.version import VERSION
from pathlib import Path


class Dict(dict):
    """A ``dict`` subclass that also supports attribute-style key access.

    Reading an attribute that is not a key returns ``None`` instead of
    raising ``AttributeError``. Setting an attribute stores it as a key.

    A nested plain ``dict`` value is wrapped in ``Dict`` on its first
    attribute access, and the wrapper is stored back under the same key.
    Repeated reads therefore return the same object, and chained writes
    such as ``obj.a.b = 1`` modify ``obj`` itself instead of a temporary copy.
    """

    def __getattr__(self, key):
        value = self.get(key, None)
        if isinstance(value, dict) and not isinstance(value, Dict):
            # Wrap once and store the wrapper back. Later reads and writes
            # through it then operate on the value held by this mapping.
            value = Dict(value)
            self[key] = value
        return value

    def __setattr__(self, key, value):
        self[key] = value


def convert_to_dict_object(resp):
    """
    Wrap a response in :class:`Dict` for attribute-style access.

    Params
        :resp: response from AIStudio; only ``dict`` instances are wrapped
    Returns
        ``Dict(resp)`` (a shallow copy) when ``resp`` is a dict,
        otherwise ``resp`` unchanged
    """
    if isinstance(resp, dict):
        return Dict(resp)
    return resp


def err_resp(sdk_code, msg, biz_code=None, log_id=None):
    """
    Build an error-response dict.

    Params:
        sdk_code (str): SDK error code identifying the error type.
        msg (str): Error description.
        biz_code (str, optional): Business-level error code, passed through
            from the upstream API.
        log_id (str, optional): Log ID related to the error, passed through
            from the upstream API.

    Returns:
        dict: The formatted error information, with keys ``error_code``,
        ``error_msg``, ``biz_code`` and ``log_id``.
    """
    return {
        "error_code": sdk_code,
        "error_msg": msg,
        "biz_code": biz_code,
        "log_id": log_id,
    }


def is_valid_host(host):
    """Check whether ``host`` is valid.

    An optional leading ``http://`` or ``https://`` (case-insensitive) is
    stripped. The remainder is then checked with :func:`is_valid_domain`,
    whose result is returned.
    """
    host = re.sub(r'^https?://', '', host, flags=re.IGNORECASE)
    return is_valid_domain(host)


def is_valid_domain(domain):
    """Check whether ``domain`` is a valid domain name.

    NOTE(review): validation is currently disabled (the previous regex check
    was commented out), so every input is accepted and ``True`` is returned.
    """
    return True
int: """ get size """ if isinstance(file_path_or_obj, (str, Path)): file_path = Path(file_path_or_obj) return file_path.stat().st_size elif isinstance(file_path_or_obj, bytes): return len(file_path_or_obj) elif isinstance(file_path_or_obj, io.BufferedIOBase): current_position = file_path_or_obj.tell() file_path_or_obj.seek(0, os.SEEK_END) size = file_path_or_obj.tell() file_path_or_obj.seek(current_position) return size else: raise TypeError( 'Unsupported type: must be string, Path, bytes, or io.BufferedIOBase' ) def get_file_hash( file_path_or_obj: Union[str, Path, bytes, BinaryIO], buffer_size_mb: Optional[int] = 1, tqdm_desc: Optional[str] = '[Calculating]', disable_tqdm: Optional[bool] = True, ) -> dict: """ calculate hash """ from tqdm.auto import tqdm file_size = get_file_size(file_path_or_obj) if file_size > 1024 * 1024 * 1024: # 1GB disable_tqdm = False name = 'Large File' if isinstance(file_path_or_obj, (str, Path)): path = file_path_or_obj if isinstance( file_path_or_obj, Path) else Path(file_path_or_obj) name = path.name tqdm_desc = f'[Validating Hash for {name}]' buffer_size = buffer_size_mb * 1024 * 1024 file_hash = hashlib.sha256() chunk_hash_list = [] progress = tqdm( total=file_size, initial=0, unit_scale=True, dynamic_ncols=True, unit='B', desc=tqdm_desc, disable=disable_tqdm, ) if isinstance(file_path_or_obj, (str, Path)): with open(file_path_or_obj, 'rb') as f: while byte_chunk := f.read(buffer_size): chunk_hash_list.append(hashlib.sha256(byte_chunk).hexdigest()) file_hash.update(byte_chunk) progress.update(len(byte_chunk)) file_hash = file_hash.hexdigest() final_chunk_size = buffer_size elif isinstance(file_path_or_obj, bytes): file_hash.update(file_path_or_obj) file_hash = file_hash.hexdigest() chunk_hash_list.append(file_hash) final_chunk_size = len(file_path_or_obj) progress.update(final_chunk_size) elif isinstance(file_path_or_obj, io.BufferedIOBase): while byte_chunk := file_path_or_obj.read(buffer_size): 
chunk_hash_list.append(hashlib.sha256(byte_chunk).hexdigest()) file_hash.update(byte_chunk) progress.update(len(byte_chunk)) file_hash = file_hash.hexdigest() final_chunk_size = buffer_size else: progress.close() raise ValueError( 'Input must be str, Path, bytes or a io.BufferedIOBase') progress.close() return { 'file_path_or_obj': file_path_or_obj, 'file_hash': file_hash, 'file_size': file_size, 'chunk_size': final_chunk_size, 'chunk_nums': len(chunk_hash_list), 'chunk_hash_list': chunk_hash_list, }