resolver.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. import logging
  2. import os
  3. from datetime import datetime
  4. from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
  5. import pytz
  6. import wandb
  7. from wandb.sdk.integration_utils.auto_logging import Response
  8. from wandb.sdk.lib.runid import generate_id
  9. logger = logging.getLogger(__name__)
  10. SUPPORTED_PIPELINE_TASKS = [
  11. "text-classification",
  12. "sentiment-analysis",
  13. "question-answering",
  14. "summarization",
  15. "translation",
  16. "text2text-generation",
  17. "text-generation",
  18. # "conversational",
  19. ]
  20. PIPELINES_WITH_TOP_K = [
  21. "text-classification",
  22. "sentiment-analysis",
  23. "question-answering",
  24. ]
  25. class HuggingFacePipelineRequestResponseResolver:
  26. """Resolver for HuggingFace's pipeline request and responses, providing necessary data transformations and formatting.
  27. This is based off (from wandb.sdk.integration_utils.auto_logging import RequestResponseResolver)
  28. """
  29. autolog_id = None
  30. def __call__(
  31. self,
  32. args: Sequence[Any],
  33. kwargs: Dict[str, Any],
  34. response: Response,
  35. start_time: float,
  36. time_elapsed: float,
  37. ) -> Optional[Dict[str, Any]]:
  38. """Main call method for this class.
  39. :param args: list of arguments
  40. :param kwargs: dictionary of keyword arguments
  41. :param response: the response from the request
  42. :param start_time: time when request started
  43. :param time_elapsed: time elapsed for the request
  44. :returns: packed data as a dictionary for logging to wandb, None if an exception occurred
  45. """
  46. try:
  47. pipe, input_data = args[:2]
  48. task = pipe.task
  49. # Translation tasks are in the form of `translation_x_to_y`
  50. if task in SUPPORTED_PIPELINE_TASKS or task.startswith("translation"):
  51. model = self._get_model(pipe)
  52. if model is None:
  53. return None
  54. model_alias = model.name_or_path
  55. timestamp = datetime.now(pytz.utc)
  56. input_data, response = self._transform_task_specific_data(
  57. task, input_data, response
  58. )
  59. formatted_data = self._format_data(task, input_data, response, kwargs)
  60. packed_data = self._create_table(
  61. formatted_data, model_alias, timestamp, time_elapsed
  62. )
  63. table_name = os.environ.get("WANDB_AUTOLOG_TABLE_NAME", f"{task}")
  64. # TODO: Let users decide the name in a way that does not use an environment variable
  65. return {
  66. table_name: wandb.Table(
  67. columns=packed_data[0], data=packed_data[1:]
  68. )
  69. }
  70. logger.warning(
  71. f"The task: `{task}` is not yet supported.\nPlease contact `wandb` to notify us if you would like support for this task"
  72. )
  73. except Exception as e:
  74. logger.warning(e)
  75. return None
  76. # TODO: This should have a dependency on PreTrainedModel. i.e. isinstance(PreTrainedModel)
  77. # from transformers.modeling_utils import PreTrainedModel
  78. # We do not want this dependency explicitly in our codebase so we make a very general
  79. # assumption about the structure of the pipeline which may have unintended consequences
  80. def _get_model(self, pipe) -> Optional[Any]:
  81. """Extracts model from the pipeline.
  82. :param pipe: the HuggingFace pipeline
  83. :returns: Model if available, None otherwise
  84. """
  85. model = pipe.model
  86. try:
  87. return model.model
  88. except AttributeError:
  89. logger.info(
  90. "Model does not have a `.model` attribute. Assuming `pipe.model` is the correct model."
  91. )
  92. return model
  93. @staticmethod
  94. def _transform_task_specific_data(
  95. task: str, input_data: Union[List[Any], Any], response: Union[List[Any], Any]
  96. ) -> Tuple[Union[List[Any], Any], Union[List[Any], Any]]:
  97. """Transform input and response data based on specific tasks.
  98. :param task: the task name
  99. :param input_data: the input data
  100. :param response: the response data
  101. :returns: tuple of transformed input_data and response
  102. """
  103. if task == "question-answering":
  104. input_data = input_data if isinstance(input_data, list) else [input_data]
  105. input_data = [data.__dict__ for data in input_data]
  106. elif task == "conversational":
  107. # We only grab the latest input/output pair from the conversation
  108. # Logging the whole conversation renders strangely.
  109. input_data = input_data if isinstance(input_data, list) else [input_data]
  110. input_data = [data.__dict__["past_user_inputs"][-1] for data in input_data]
  111. response = response if isinstance(response, list) else [response]
  112. response = [data.__dict__["generated_responses"][-1] for data in response]
  113. return input_data, response
  114. def _format_data(
  115. self,
  116. task: str,
  117. input_data: Union[List[Any], Any],
  118. response: Union[List[Any], Any],
  119. kwargs: Dict[str, Any],
  120. ) -> List[Dict[str, Any]]:
  121. """Formats input data, response, and kwargs into a list of dictionaries.
  122. :param task: the task name
  123. :param input_data: the input data
  124. :param response: the response data
  125. :param kwargs: dictionary of keyword arguments
  126. :returns: list of dictionaries containing formatted data
  127. """
  128. input_data = input_data if isinstance(input_data, list) else [input_data]
  129. response = response if isinstance(response, list) else [response]
  130. formatted_data = []
  131. for i_text, r_text in zip(input_data, response):
  132. # Unpack single element responses for better rendering in wandb UI when it is a task without top_k
  133. # top_k = 1 would unpack the response into a single element while top_k > 1 would be a list
  134. # this would cause the UI to not properly concatenate the tables of the same task by omitting the elements past the first
  135. if (
  136. (isinstance(r_text, list))
  137. and (len(r_text) == 1)
  138. and task not in PIPELINES_WITH_TOP_K
  139. ):
  140. r_text = r_text[0]
  141. formatted_data.append(
  142. {"input": i_text, "response": r_text, "kwargs": kwargs}
  143. )
  144. return formatted_data
  145. def _create_table(
  146. self,
  147. formatted_data: List[Dict[str, Any]],
  148. model_alias: str,
  149. timestamp: float,
  150. time_elapsed: float,
  151. ) -> List[List[Any]]:
  152. """Creates a table from formatted data, model alias, timestamp, and elapsed time.
  153. :param formatted_data: list of dictionaries containing formatted data
  154. :param model_alias: alias of the model
  155. :param timestamp: timestamp of the data
  156. :param time_elapsed: time elapsed from the beginning
  157. :returns: list of lists, representing a table of data. [0]th element = columns. [1]st element = data
  158. """
  159. header = [
  160. "ID",
  161. "Model Alias",
  162. "Timestamp",
  163. "Elapsed Time",
  164. "Input",
  165. "Response",
  166. "Kwargs",
  167. ]
  168. table = [header]
  169. autolog_id = generate_id(length=16)
  170. for data in formatted_data:
  171. row = [
  172. autolog_id,
  173. model_alias,
  174. timestamp,
  175. time_elapsed,
  176. data["input"],
  177. data["response"],
  178. data["kwargs"],
  179. ]
  180. table.append(row)
  181. self.autolog_id = autolog_id
  182. return table
  183. def get_latest_id(self):
  184. return self.autolog_id