resolver.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. import logging
  2. from datetime import datetime
  3. from typing import Any, Dict, List, Optional, Sequence, Tuple
  4. import wandb
  5. from wandb.sdk.integration_utils.auto_logging import Response
  6. from wandb.sdk.lib.runid import generate_id
# Module-level logger. The resolver reports unsupported or unparseable Cohere
# responses through it (info/warning) instead of raising to the caller.
logger = logging.getLogger(__name__)
  8. def subset_dict(
  9. original_dict: Dict[str, Any], keys_subset: Sequence[str]
  10. ) -> Dict[str, Any]:
  11. """Create a subset of a dictionary using a subset of keys.
  12. :param original_dict: The original dictionary.
  13. :param keys_subset: The subset of keys to extract.
  14. :return: A dictionary containing only the specified keys.
  15. """
  16. return {key: original_dict[key] for key in keys_subset if key in original_dict}
  17. def reorder_and_convert_dict_list_to_table(
  18. data: List[Dict[str, Any]], order: List[str]
  19. ) -> Tuple[List[str], List[List[Any]]]:
  20. """Convert a list of dictionaries to a pair of column names and corresponding values, with the option to order specific dictionaries.
  21. :param data: A list of dictionaries.
  22. :param order: A list of keys specifying the desired order for specific dictionaries. The remaining dictionaries will be ordered based on their original order.
  23. :return: A pair of column names and corresponding values.
  24. """
  25. final_columns = []
  26. keys_present = set()
  27. # First, add all ordered keys to the final columns
  28. for key in order:
  29. if key not in keys_present:
  30. final_columns.append(key)
  31. keys_present.add(key)
  32. # Then, add any keys present in the dictionaries but not in the order
  33. for d in data:
  34. for key in d:
  35. if key not in keys_present:
  36. final_columns.append(key)
  37. keys_present.add(key)
  38. # Then, construct the table of values
  39. values = []
  40. for d in data:
  41. row = []
  42. for key in final_columns:
  43. row.append(d.get(key, None))
  44. values.append(row)
  45. return final_columns, values
  46. def flatten_dict(
  47. dictionary: Dict[str, Any], parent_key: str = "", sep: str = "-"
  48. ) -> Dict[str, Any]:
  49. """Flatten a nested dictionary, joining keys using a specified separator.
  50. :param dictionary: The dictionary to flatten.
  51. :param parent_key: The base key to prepend to each key.
  52. :param sep: The separator to use when joining keys.
  53. :return: A flattened dictionary.
  54. """
  55. flattened_dict = {}
  56. for key, value in dictionary.items():
  57. new_key = f"{parent_key}{sep}{key}" if parent_key else key
  58. if isinstance(value, dict):
  59. flattened_dict.update(flatten_dict(value, new_key, sep=sep))
  60. else:
  61. flattened_dict[new_key] = value
  62. return flattened_dict
  63. def collect_common_keys(list_of_dicts: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
  64. """Collect the common keys of a list of dictionaries. For each common key, put its values into a list in the order they appear in the original dictionaries.
  65. :param list_of_dicts: The list of dictionaries to inspect.
  66. :return: A dictionary with each common key and its corresponding list of values.
  67. """
  68. common_keys = set.intersection(*map(set, list_of_dicts))
  69. common_dict = {key: [] for key in common_keys}
  70. for d in list_of_dicts:
  71. for key in common_keys:
  72. common_dict[key].append(d[key])
  73. return common_dict
  74. class CohereRequestResponseResolver:
  75. """Class to resolve the request/response from the Cohere API and convert it to a dictionary that can be logged."""
  76. def __call__(
  77. self,
  78. args: Sequence[Any],
  79. kwargs: Dict[str, Any],
  80. response: Response,
  81. start_time: float,
  82. time_elapsed: float,
  83. ) -> Optional[Dict[str, Any]]:
  84. """Process the response from the Cohere API and convert it to a dictionary that can be logged.
  85. :param args: The arguments of the original function.
  86. :param kwargs: The keyword arguments of the original function.
  87. :param response: The response from the Cohere API.
  88. :param start_time: The start time of the request.
  89. :param time_elapsed: The time elapsed for the request.
  90. :return: A dictionary containing the parsed response and timing information.
  91. """
  92. try:
  93. # Each of the different endpoints map to one specific response type
  94. # We want to 'type check' the response without directly importing the packages type
  95. # It may make more sense to pass the invoked symbol from the AutologAPI instead
  96. response_type = str(type(response)).split("'")[1].split(".")[-1]
  97. # Initialize parsed_response to None to handle the case where the response type is unsupported
  98. parsed_response = None
  99. if response_type == "Generations":
  100. parsed_response = self._resolve_generate_response(response)
  101. # TODO: Remove hard-coded default model name
  102. table_column_order = [
  103. "start_time",
  104. "query_id",
  105. "model",
  106. "prompt",
  107. "text",
  108. "token_likelihoods",
  109. "likelihood",
  110. "time_elapsed_(seconds)",
  111. "end_time",
  112. ]
  113. default_model = "command"
  114. elif response_type == "Chat":
  115. parsed_response = self._resolve_chat_response(response)
  116. table_column_order = [
  117. "start_time",
  118. "query_id",
  119. "model",
  120. "conversation_id",
  121. "response_id",
  122. "query",
  123. "text",
  124. "prompt",
  125. "preamble",
  126. "chat_history",
  127. "chatlog",
  128. "time_elapsed_(seconds)",
  129. "end_time",
  130. ]
  131. default_model = "command"
  132. elif response_type == "Classifications":
  133. parsed_response = self._resolve_classify_response(response)
  134. kwargs = self._resolve_classify_kwargs(kwargs)
  135. table_column_order = [
  136. "start_time",
  137. "query_id",
  138. "model",
  139. "id",
  140. "input",
  141. "prediction",
  142. "confidence",
  143. "time_elapsed_(seconds)",
  144. "end_time",
  145. ]
  146. default_model = "embed-english-v2.0"
  147. elif response_type == "SummarizeResponse":
  148. parsed_response = self._resolve_summarize_response(response)
  149. table_column_order = [
  150. "start_time",
  151. "query_id",
  152. "model",
  153. "response_id",
  154. "text",
  155. "additional_command",
  156. "summary",
  157. "time_elapsed_(seconds)",
  158. "end_time",
  159. "length",
  160. "format",
  161. ]
  162. default_model = "summarize-xlarge"
  163. elif response_type == "Reranking":
  164. parsed_response = self._resolve_rerank_response(response)
  165. table_column_order = [
  166. "start_time",
  167. "query_id",
  168. "model",
  169. "id",
  170. "query",
  171. "top_n",
  172. # This is a nested dict key that got flattened
  173. "document-text",
  174. "relevance_score",
  175. "index",
  176. "time_elapsed_(seconds)",
  177. "end_time",
  178. ]
  179. default_model = "rerank-english-v2.0"
  180. else:
  181. logger.info(f"Unsupported Cohere response object: {response}")
  182. return self._resolve(
  183. args,
  184. kwargs,
  185. parsed_response,
  186. start_time,
  187. time_elapsed,
  188. response_type,
  189. table_column_order,
  190. default_model,
  191. )
  192. except Exception as e:
  193. logger.warning(f"Failed to resolve request/response: {e}")
  194. return None
  195. # These helper functions process the response from different endpoints of the Cohere API.
  196. # Since the response objects for different endpoints have different structures,
  197. # we need different logic to process them.
  198. def _resolve_generate_response(self, response: Response) -> List[Dict[str, Any]]:
  199. return_list = []
  200. for _response in response:
  201. # Built in Cohere.*.Generations function to color token_likelihoods and return a dict of response data
  202. _response_dict = _response._visualize_helper()
  203. try:
  204. _response_dict["token_likelihoods"] = wandb.Html(
  205. _response_dict["token_likelihoods"]
  206. )
  207. except (KeyError, ValueError):
  208. pass
  209. return_list.append(_response_dict)
  210. return return_list
  211. def _resolve_chat_response(self, response: Response) -> List[Dict[str, Any]]:
  212. return [
  213. subset_dict(
  214. response.__dict__,
  215. [
  216. "response_id",
  217. "generation_id",
  218. "query",
  219. "text",
  220. "conversation_id",
  221. "prompt",
  222. "chatlog",
  223. "preamble",
  224. ],
  225. )
  226. ]
  227. def _resolve_classify_response(self, response: Response) -> List[Dict[str, Any]]:
  228. # The labels key is a dict returning the scores for the classification probability for each label provided
  229. # We flatten this nested dict for ease of consumption in the wandb UI
  230. return [flatten_dict(_response.__dict__) for _response in response]
  231. def _resolve_classify_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
  232. # Example texts look strange when rendered in Wandb UI as it is a list of text and label
  233. # We extract each value into its own column
  234. example_texts = []
  235. example_labels = []
  236. for example in kwargs["examples"]:
  237. example_texts.append(example.text)
  238. example_labels.append(example.label)
  239. kwargs.pop("examples")
  240. kwargs["example_texts"] = example_texts
  241. kwargs["example_labels"] = example_labels
  242. return kwargs
  243. def _resolve_summarize_response(self, response: Response) -> List[Dict[str, Any]]:
  244. return [{"response_id": response.id, "summary": response.summary}]
  245. def _resolve_rerank_response(self, response: Response) -> List[Dict[str, Any]]:
  246. # The documents key contains a dict containing the content of the document which is at least "text"
  247. # We flatten this nested dict for ease of consumption in the wandb UI
  248. flattened_response_dicts = [
  249. flatten_dict(_response.__dict__) for _response in response
  250. ]
  251. # ReRank returns each document provided a top_n value so we aggregate into one view so users can paginate a row
  252. # As opposed to each row being one of the top_n responses
  253. return_dict = collect_common_keys(flattened_response_dicts)
  254. return_dict["id"] = response.id
  255. return [return_dict]
  256. def _resolve(
  257. self,
  258. args: Sequence[Any],
  259. kwargs: Dict[str, Any],
  260. parsed_response: List[Dict[str, Any]],
  261. start_time: float,
  262. time_elapsed: float,
  263. response_type: str,
  264. table_column_order: List[str],
  265. default_model: str,
  266. ) -> Dict[str, Any]:
  267. """Convert a list of dictionaries to a pair of column names and corresponding values, with the option to order specific dictionaries.
  268. :param args: The arguments passed to the API client.
  269. :param kwargs: The keyword arguments passed to the API client.
  270. :param parsed_response: The parsed response from the API.
  271. :param start_time: The start time of the API request.
  272. :param time_elapsed: The time elapsed during the API request.
  273. :param response_type: The type of the API response.
  274. :param table_column_order: The desired order of columns in the resulting table.
  275. :param default_model: The default model to use if not specified in the response.
  276. :return: A dictionary containing the formatted response.
  277. """
  278. # Args[0] is the client object where we can grab specific metadata about the underlying API status
  279. query_id = generate_id(length=16)
  280. parsed_args = subset_dict(
  281. args[0].__dict__,
  282. ["api_version", "batch_size", "max_retries", "num_workers", "timeout"],
  283. )
  284. start_time_dt = datetime.fromtimestamp(start_time)
  285. end_time_dt = datetime.fromtimestamp(start_time + time_elapsed)
  286. timings = {
  287. "start_time": start_time_dt,
  288. "end_time": end_time_dt,
  289. "time_elapsed_(seconds)": time_elapsed,
  290. }
  291. packed_data = []
  292. for _parsed_response in parsed_response:
  293. _packed_dict = {
  294. "query_id": query_id,
  295. **kwargs,
  296. **_parsed_response,
  297. **timings,
  298. **parsed_args,
  299. }
  300. if "model" not in _packed_dict:
  301. _packed_dict["model"] = default_model
  302. packed_data.append(_packed_dict)
  303. columns, data = reorder_and_convert_dict_list_to_table(
  304. packed_data, table_column_order
  305. )
  306. request_response_table = wandb.Table(data=data, columns=columns)
  307. return {f"{response_type}": request_response_table}