resolver.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. import datetime
  2. import io
  3. import logging
  4. from dataclasses import asdict, dataclass
  5. from typing import Any, Dict, List, Optional, Sequence
  6. import wandb
  7. from wandb.sdk.data_types import trace_tree
  8. from wandb.sdk.integration_utils.auto_logging import Response
  9. logger = logging.getLogger(__name__)
  10. @dataclass
  11. class UsageMetrics:
  12. elapsed_time: float = None
  13. prompt_tokens: int = None
  14. completion_tokens: int = None
  15. total_tokens: int = None
  16. @dataclass
  17. class Metrics:
  18. usage: UsageMetrics = None
  19. stats: wandb.Table = None
  20. trace: trace_tree.WBTraceTree = None
  21. usage_metric_keys = {f"usage/{k}" for k in asdict(UsageMetrics())}
  22. class OpenAIRequestResponseResolver:
  23. def __init__(self):
  24. self.define_metrics_called = False
  25. def __call__(
  26. self,
  27. args: Sequence[Any],
  28. kwargs: Dict[str, Any],
  29. response: Response,
  30. start_time: float, # pass to comply with the protocol, but use response["created"] instead
  31. time_elapsed: float,
  32. ) -> Optional[Dict[str, Any]]:
  33. request = kwargs
  34. if not self.define_metrics_called:
  35. # define metrics on first call
  36. for key in usage_metric_keys:
  37. wandb.define_metric(key, step_metric="_timestamp")
  38. self.define_metrics_called = True
  39. try:
  40. if response.get("object") == "edit":
  41. return self._resolve_edit(request, response, time_elapsed)
  42. elif response.get("object") == "text_completion":
  43. return self._resolve_completion(request, response, time_elapsed)
  44. elif response.get("object") == "chat.completion":
  45. return self._resolve_chat_completion(request, response, time_elapsed)
  46. else:
  47. # todo: properly treat failed requests
  48. logger.info(
  49. f"Unsupported OpenAI response object: {response.get('object')}"
  50. )
  51. except Exception as e:
  52. logger.warning(f"Failed to resolve request/response: {e}")
  53. return None
  54. @staticmethod
  55. def results_to_trace_tree(
  56. request: Dict[str, Any],
  57. response: Response,
  58. results: List[trace_tree.Result],
  59. time_elapsed: float,
  60. ) -> trace_tree.WBTraceTree:
  61. """Converts the request, response, and results into a trace tree.
  62. params:
  63. request: The request dictionary
  64. response: The response object
  65. results: A list of results object
  66. time_elapsed: The time elapsed in seconds
  67. returns:
  68. A wandb trace tree object.
  69. """
  70. start_time_ms = int(round(response["created"] * 1000))
  71. end_time_ms = start_time_ms + int(round(time_elapsed * 1000))
  72. span = trace_tree.Span(
  73. name=f"{response.get('model', 'openai')}_{response['object']}_{response.get('created')}",
  74. attributes=dict(response), # type: ignore
  75. start_time_ms=start_time_ms,
  76. end_time_ms=end_time_ms,
  77. span_kind=trace_tree.SpanKind.LLM,
  78. results=results,
  79. )
  80. model_obj = {"request": request, "response": response, "_kind": "openai"}
  81. return trace_tree.WBTraceTree(root_span=span, model_dict=model_obj)
  82. def _resolve_edit(
  83. self,
  84. request: Dict[str, Any],
  85. response: Response,
  86. time_elapsed: float,
  87. ) -> Dict[str, Any]:
  88. """Resolves the request and response objects for `openai.Edit`."""
  89. request_str = (
  90. f"\n\n**Instruction**: {request['instruction']}\n\n"
  91. f"**Input**: {request['input']}\n"
  92. )
  93. choices = [
  94. f"\n\n**Edited**: {choice['text']}\n" for choice in response["choices"]
  95. ]
  96. return self._resolve_metrics(
  97. request=request,
  98. response=response,
  99. request_str=request_str,
  100. choices=choices,
  101. time_elapsed=time_elapsed,
  102. )
  103. def _resolve_completion(
  104. self,
  105. request: Dict[str, Any],
  106. response: Response,
  107. time_elapsed: float,
  108. ) -> Dict[str, Any]:
  109. """Resolves the request and response objects for `openai.Completion`."""
  110. request_str = f"\n\n**Prompt**: {request['prompt']}\n"
  111. choices = [
  112. f"\n\n**Completion**: {choice['text']}\n" for choice in response["choices"]
  113. ]
  114. return self._resolve_metrics(
  115. request=request,
  116. response=response,
  117. request_str=request_str,
  118. choices=choices,
  119. time_elapsed=time_elapsed,
  120. )
  121. def _resolve_chat_completion(
  122. self,
  123. request: Dict[str, Any],
  124. response: Response,
  125. time_elapsed: float,
  126. ) -> Dict[str, Any]:
  127. """Resolves the request and response objects for `openai.Completion`."""
  128. prompt = io.StringIO()
  129. for message in request["messages"]:
  130. prompt.write(f"\n\n**{message['role']}**: {message['content']}\n")
  131. request_str = prompt.getvalue()
  132. choices = [
  133. f"\n\n**{choice['message']['role']}**: {choice['message']['content']}\n"
  134. for choice in response["choices"]
  135. ]
  136. return self._resolve_metrics(
  137. request=request,
  138. response=response,
  139. request_str=request_str,
  140. choices=choices,
  141. time_elapsed=time_elapsed,
  142. )
  143. def _resolve_metrics(
  144. self,
  145. request: Dict[str, Any],
  146. response: Response,
  147. request_str: str,
  148. choices: List[str],
  149. time_elapsed: float,
  150. ) -> Dict[str, Any]:
  151. """Resolves the request and response objects for `openai.Completion`."""
  152. results = [
  153. trace_tree.Result(
  154. inputs={"request": request_str},
  155. outputs={"response": choice},
  156. )
  157. for choice in choices
  158. ]
  159. metrics = self._get_metrics_to_log(request, response, results, time_elapsed)
  160. return self._convert_metrics_to_dict(metrics)
  161. @staticmethod
  162. def _get_usage_metrics(response: Response, time_elapsed: float) -> UsageMetrics:
  163. """Gets the usage stats from the response object."""
  164. if response.get("usage"):
  165. usage_stats = UsageMetrics(**response["usage"])
  166. else:
  167. usage_stats = UsageMetrics()
  168. usage_stats.elapsed_time = time_elapsed
  169. return usage_stats
  170. def _get_metrics_to_log(
  171. self,
  172. request: Dict[str, Any],
  173. response: Response,
  174. results: List[Any],
  175. time_elapsed: float,
  176. ) -> Metrics:
  177. model = response.get("model") or request.get("model")
  178. usage_metrics = self._get_usage_metrics(response, time_elapsed)
  179. usage = []
  180. for result in results:
  181. row = {
  182. "request": result.inputs["request"],
  183. "response": result.outputs["response"],
  184. "model": model,
  185. "start_time": datetime.datetime.fromtimestamp(response["created"]),
  186. "end_time": datetime.datetime.fromtimestamp(
  187. response["created"] + time_elapsed
  188. ),
  189. "request_id": response.get("id", None),
  190. "api_type": response.get("api_type", "openai"),
  191. "session_id": wandb.run.id,
  192. }
  193. row.update(asdict(usage_metrics))
  194. usage.append(row)
  195. usage_table = wandb.Table(
  196. columns=list(usage[0].keys()),
  197. data=[(item.values()) for item in usage],
  198. )
  199. trace = self.results_to_trace_tree(request, response, results, time_elapsed)
  200. metrics = Metrics(stats=usage_table, trace=trace, usage=usage_metrics)
  201. return metrics
  202. @staticmethod
  203. def _convert_metrics_to_dict(metrics: Metrics) -> Dict[str, Any]:
  204. """Converts metrics to a dict."""
  205. metrics_dict = {
  206. "stats": metrics.stats,
  207. "trace": metrics.trace,
  208. }
  209. usage_stats = {f"usage/{k}": v for k, v in asdict(metrics.usage).items()}
  210. metrics_dict.update(usage_stats)
  211. return metrics_dict