bos_sdk.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2023 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. 本文件实现了对bos的封装, 首先安装 bce-python-sdk
  10. Authors: suoyi@baidu.com
  11. Date: 2024/01/03
  12. """
  13. from aistudio_sdk import log
  14. import os
  15. from baidubce.bce_client_configuration import BceClientConfiguration
  16. from baidubce.auth.bce_credentials import BceCredentials
  17. from baidubce.services.bos.bos_client import BosClient, BceClientError
  18. from baidubce import utils
  19. from typing import List
  20. RETRY_TIMES = int(os.environ.get("AISTUDIO_BOS_RETRY_TIMES", 10))
  21. class MyBosClient(BosClient):
  22. """
  23. 重写BosClient的_upload方法,增加重试功能
  24. """
  25. def _upload_task(self, bucket_name, object_key, upload_id,
  26. part_number, part_size, file_name, offset, part_list, uploadTaskHandle,
  27. progress_callback=None, traffic_limit=None):
  28. if uploadTaskHandle.is_cancel():
  29. log.debug(f"upload task canceled with partNumber={part_number}!")
  30. return
  31. success = False
  32. for i in range(RETRY_TIMES):
  33. try:
  34. response = self.upload_part_from_file(bucket_name, object_key, upload_id,
  35. part_number, part_size, file_name, offset,
  36. progress_callback=progress_callback,
  37. traffic_limit=traffic_limit)
  38. part_list.append({
  39. "partNumber": part_number,
  40. "eTag": response.metadata.etag
  41. })
  42. log.debug(f"upload task success with partNumber={part_number}!")
  43. success = True
  44. break
  45. except Exception as e:
  46. log.error(f"upload task failed with partNumber={part_number}!")
  47. log.debug(e)
  48. log.error(f"重试第{i + 1}次")
  49. if not success:
  50. uploadTaskHandle.cancel()
  51. log.error(f"upload task failed with partNumber={part_number}!已取消上传")
  52. raise BceClientError(f"upload task failed with partNumber={part_number}!")
  53. def put_super_obejct_from_file(self, bucket_name, key, file_name, chunk_size=5,
  54. thread_num=None,
  55. uploadTaskHandle=None,
  56. content_type=None,
  57. storage_class=None,
  58. user_headers=None,
  59. progress_callback=None,
  60. traffic_limit=None,
  61. config=None):
  62. """调用原始的 put_super_obejct_from_file,但这里会使用上面定义的 _upload_task"""
  63. return super().put_super_obejct_from_file(bucket_name, key, file_name, chunk_size,
  64. thread_num, uploadTaskHandle,
  65. content_type, storage_class,
  66. user_headers, progress_callback,
  67. traffic_limit, config)
  68. def _compute_service_id(self):
  69. """需要覆盖父类的方法,否则会报错"""
  70. return "bos"
  71. def sts_client(bos_host, sts_ak, sts_sk, session_token) -> MyBosClient:
  72. """
  73. 获取sts client
  74. """
  75. bos_client = MyBosClient(BceClientConfiguration(
  76. credentials=BceCredentials(sts_ak, sts_sk),
  77. endpoint=bos_host,
  78. security_token=session_token))
  79. return bos_client
  80. def upload_files(bos_client: MyBosClient, bucket: str, files: List[str], key_prefix=""):
  81. """
  82. 上传文件
  83. key_prefix: 上传文件的前缀
  84. """
  85. for file in files:
  86. bos_client.put_super_obejct_from_file(bucket, key_prefix + file, file, chunk_size=5, thread_num=None)
  87. def upload_file(bos_client: MyBosClient, bucket: str, file, key):
  88. """
  89. 上传文件
  90. key: 存储路径
  91. """
  92. return bos_client.put_object_from_file(bucket, key, str(file))
  93. def upload_super_file(bos_client: MyBosClient, bucket: str, file, key):
  94. """
  95. 上传文件
  96. key: 存储路径
  97. """
  98. chunk_size = int(os.environ.get("AISTUDIO_UPLOAD_CHUNK_SIZE_MB", 5))
  99. thread_num = os.environ.get("AISTUDIO_UPLOAD_THREAD_NUM", None)
  100. if thread_num:
  101. thread_num = int(thread_num)
  102. res = bos_client.put_super_obejct_from_file(bucket, key, str(file),
  103. chunk_size=chunk_size,
  104. thread_num=thread_num,
  105. progress_callback=None)
  106. if not res:
  107. log.error("upload file failed: 已经取消或者上传失败,如果上传失败,"
  108. "请配置环境变量 AISTUDIO_UPLOAD_CHUNK_SIZE_MB (int类型,默认为5,单位MB),减小分块大小后重试,"
  109. "例如:export AISTUDIO_UPLOAD_CHUNK_SIZE_MB=3 后重新执行"
  110. "如果带宽过小,需要配置环境变量 AISTUDIO_UPLOAD_THREAD_NUM 减少线程数,防止部分分块上传超时,"
  111. "例如:export AISTUDIO_UPLOAD_THREAD_NUM=1 后重新执行")
  112. return res