pipeline.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. 本文件实现了请求产线任务
  10. Authors: xiangyiqing(xiangyiqing@baidu.com)
  11. Date: 2024/3/2
  12. """
  13. import json
  14. import requests
  15. from aistudio_sdk import config, log
  16. from baidubce.bce_client_configuration import BceClientConfiguration
  17. from baidubce.auth.bce_credentials import BceCredentials
  18. from baidubce.services.bos.bos_client import BosClient
  19. class RequestPipelineException(Exception):
  20. """
  21. exception for requesting pipeline server
  22. """
  23. pass
  24. def _request(
  25. method: str,
  26. url: str,
  27. headers: dict,
  28. params: dict,
  29. data
  30. ):
  31. """request api
  32. :param url: http url
  33. :param headers: dictionary of HTTP Headers to send
  34. :param json_data: json data to send in the body
  35. :param data: dictionary, list of tuples, bytes, or file-like object to send in the body
  36. :return: response data in json format
  37. """
  38. log.debug(f"\n- method: {method}\n- url: {url}\n- headers: {headers}\n- params: {params}\n- data: {data}")
  39. err_msg = ''
  40. for _ in range(config.CONNECTION_RETRY_TIMES):
  41. try:
  42. response = requests.request(
  43. method,
  44. url,
  45. headers=headers,
  46. params=params,
  47. data=data,
  48. timeout=config.CONNECTION_TIMEOUT
  49. )
  50. log.debug(f"\n- response: {response.json()}")
  51. return response.json()
  52. except requests.exceptions.JSONDecodeError:
  53. err_msg = "Response body does not contain valid json: {}".format(response)
  54. except Exception as e:
  55. err_msg = 'Error occurred when request for "{}": {}.'.format(url, str(e))
  56. log.debug(f"\n- err_msg: {err_msg}")
  57. raise RequestPipelineException(err_msg)
  58. def _request_pipepline(
  59. token: str,
  60. method: str,
  61. url: str,
  62. params: dict,
  63. data
  64. ):
  65. """
  66. 请求pp-pipeline API
  67. """
  68. headers = {
  69. 'Content-Type': 'application/json',
  70. 'Authorization': f'token {token}'
  71. }
  72. access_url = f"{config.STUDIO_MODEL_API_URL_PREFIX_DEFAULT}{url}"
  73. return _request(method, access_url, headers, params, data)
  74. def create(
  75. token: str,
  76. name: str,
  77. cmd: str,
  78. env: str,
  79. device: str,
  80. gpus: str,
  81. payment: str,
  82. dataset: dict
  83. ):
  84. """
  85. 请求创建产线
  86. """
  87. body = {
  88. "name": name,
  89. "cmd": cmd,
  90. "env": env,
  91. "device": device,
  92. "gpus": gpus,
  93. "payment": payment,
  94. "dataset": dataset,
  95. }
  96. return _request_pipepline(
  97. token,
  98. "POST",
  99. config.PIPELINE_CREATE_URL,
  100. None,
  101. json.dumps(body)
  102. )
  103. def bosacl(
  104. token: str,
  105. pipeline_id: str
  106. ):
  107. """
  108. 申请ak/sk
  109. """
  110. body = {
  111. 'source': 'SDK',
  112. 'pipelineId': pipeline_id,
  113. }
  114. return _request_pipepline(
  115. token,
  116. "GET",
  117. config.PIPELINE_BOSACL_URL,
  118. body,
  119. None
  120. )
  121. def bosacl_ls_cp(
  122. token: str,
  123. pipeline_id: str
  124. ):
  125. """
  126. 申请ak/sk
  127. """
  128. body = {
  129. 'source': 'customCodeOutput',
  130. 'pipelineId': pipeline_id,
  131. }
  132. return _request_pipepline(
  133. token,
  134. "GET",
  135. config.PIPELINE_BOSACL_URL,
  136. body,
  137. None
  138. )
  139. def bos_upload(
  140. local_file: str,
  141. endpoint: str,
  142. bucket_name: str,
  143. file_key: str,
  144. access_key_id: str,
  145. secret_access_key: str,
  146. session_token: str
  147. ):
  148. """
  149. 本地文件 上传至bos指定位置
  150. """
  151. # sts配置
  152. bos_conf = BceClientConfiguration(
  153. credentials=BceCredentials(access_key_id, secret_access_key),
  154. endpoint=endpoint, # "bj.bcebos.com"
  155. security_token=session_token
  156. )
  157. bos_client = BosClient(bos_conf)
  158. # 从文件中上传的Object
  159. bos_client.put_object_from_file(bucket_name, file_key.lstrip("/"), local_file)
  160. def create_callback(
  161. token: str,
  162. pipeline_id: str,
  163. is_succuss: bool,
  164. file_key: str = None,
  165. file_name: str = None
  166. ):
  167. """
  168. 创建产线回调, 成功or失败
  169. """
  170. body = {
  171. "pipelineId": pipeline_id,
  172. "success": is_succuss,
  173. "fileKey": file_key,
  174. "fileName": file_name, # 真实文件名
  175. }
  176. return _request_pipepline(
  177. token,
  178. "POST",
  179. config.PIPELINE_CREATE_CALLBACK_URL,
  180. None,
  181. json.dumps(body)
  182. )
  183. def query(
  184. token: str,
  185. pipeline_id: str,
  186. name: str,
  187. status: str
  188. ):
  189. """
  190. 查询产线
  191. """
  192. body = {
  193. "pipelineId": pipeline_id,
  194. "pipelineName": name,
  195. "stage": status,
  196. }
  197. return _request_pipepline(
  198. token,
  199. "POST",
  200. config.PIPELINE_QUERY_URL,
  201. None,
  202. json.dumps(body)
  203. )
  204. def stop(
  205. token: str,
  206. pipeline_id: str
  207. ):
  208. """
  209. 停止产线
  210. """
  211. body = {
  212. "pipelineId": pipeline_id,
  213. }
  214. return _request_pipepline(
  215. token,
  216. "POST",
  217. config.PIPELINE_STOP_URL,
  218. None,
  219. json.dumps(body)
  220. )