switch_downoad.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2025 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. 本文件实现了sdk cdn下载的功能
  10. Authors: zhaoqingtao(zhaoqingtaog@baidu.com)
  11. Date: 2025/05/23
  12. """
  13. import re
  14. import os
  15. import copy
  16. import requests
  17. from urllib.parse import urlparse, urlunparse
  18. from aistudio_sdk import config
  19. def switch_cdn(url, headers, get_headers):
  20. """
  21. switch to cdn host
  22. """
  23. headers_range = {} if headers is None else copy.deepcopy(headers)
  24. headers_range['Range'] = f'bytes=0-1'
  25. response = requests.get(url, headers=headers_range, stream=True,
  26. timeout=config.CONNECTION_TIMEOUT, allow_redirects=False)
  27. if response.status_code == 307 and response.headers.get("Location").startswith('/'):
  28. url_parsed = urlparse(url)
  29. new_parts = url_parsed._replace(path=response.headers.get("Location"), params='', query='', fragment='')
  30. response = requests.get(urlunparse(new_parts), headers=headers_range, stream=True,
  31. timeout=config.CONNECTION_TIMEOUT, allow_redirects=False)
  32. match = re.search(r"/repos/([^/]+)/", url)
  33. paddle_repo = False
  34. if match:
  35. repo_name = match.group(1)
  36. if "paddlepaddle" == repo_name.lower() or "baidu" == repo_name.lower():
  37. paddle_repo = True
  38. if response.is_redirect:
  39. redirect_url = response.headers.get("Location")
  40. parsed = urlparse(redirect_url)
  41. cdn_host = os.getenv("STUDIO_CDN_HOST")
  42. if cdn_host:
  43. new_host = cdn_host
  44. elif paddle_repo:
  45. new_host = config.UNLIMITED_HOST
  46. else:
  47. new_host = config.LIMITED_HOST
  48. parsed = parsed._replace(netloc=new_host)
  49. new_url = urlunparse(parsed)
  50. get_headers.pop("Authorization", None)
  51. return new_url
  52. return url