loss_scale.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import json
  3. import os
  4. from typing import Dict, List, Optional, Tuple
  5. from .utils import split_str_parts_by, split_parts_by_regex
  6. def calculate_loss_scale(query: str,
  7. response: str,
  8. response_loss_scale_map: Optional[Dict[str, list]] = None,
  9. query_loss_scale_map: Optional[Dict[str, list]] = None) -> Tuple[List[str], List[float]]:
  10. """Calculate the loss scale by splitting the agent response.
  11. This algorithm comes from paper: https://arxiv.org/pdf/2309.00986.pdf
  12. Agent response format:
  13. ```text
  14. Thought: you should always think about what to do
  15. Action: the action to take, should be one of the above tools[fire_recognition,
  16. fire_alert, call_police, call_fireman]
  17. Action Input: the input to the action
  18. Observation: the result of the action
  19. ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
  20. Thought: I now know the final answer
  21. Final Answer: the final answer to the original input question
  22. ```
  23. Returns:
  24. A tuple of agent response parts and their weights.
  25. """
  26. # query loss scale map
  27. if query_loss_scale_map is not None:
  28. for key in query_loss_scale_map.keys():
  29. if key in query:
  30. if isinstance(query_loss_scale_map[key], (float, int)):
  31. query_loss_scale_map[key] = [query_loss_scale_map[key]]
  32. loss_scale_value = query_loss_scale_map[key][0]
  33. return [response], [float(loss_scale_value)]
  34. delimiters = list(k for k in response_loss_scale_map.keys() if len(response_loss_scale_map[k]) == 2)
  35. agent_parts = split_str_parts_by(response, delimiters)
  36. regex_delimiters = {k: v for k, v in response_loss_scale_map.items() if len(v) == 1}
  37. if len(regex_delimiters):
  38. split_parts_by_regex(agent_parts, regex_delimiters)
  39. weights = []
  40. agent_content = []
  41. for c in agent_parts:
  42. if isinstance(c['key'], (float, int)):
  43. weights += [c['key']]
  44. agent_content.append(c['content'])
  45. else:
  46. if c['key'] in response_loss_scale_map:
  47. weights += [response_loss_scale_map[c['key']][0]]
  48. weights += [response_loss_scale_map[c['key']][1]]
  49. agent_content.append(c['key'])
  50. agent_content.append(c['content'])
  51. else:
  52. weights += [1.0]
  53. agent_content.append(c['content'])
  54. return agent_content, weights
  55. def alpha_umi_loss_scale(query: str, response: str):
  56. cwd = os.getcwd()
  57. loss_scale_config_path = 'alpha_umi_loss_scale_config.json'
  58. config_path = os.path.join(cwd, loss_scale_config_path)
  59. with open(config_path, 'r') as json_file:
  60. loss_scale_map = json.load(json_file)
  61. return calculate_loss_scale(query, response, loss_scale_map)
  62. def agentflan_loss_scale(query: str, response: str):
  63. cwd = os.getcwd()
  64. loss_scale_config_path = 'agentflan.json'
  65. config_path = os.path.join(cwd, loss_scale_config_path)
  66. with open(config_path, 'r') as json_file:
  67. loss_scale_map = json.load(json_file)
  68. query_loss_scale_map = loss_scale_map['query']
  69. response_loss_scale_map = loss_scale_map['response']
  70. return calculate_loss_scale(query, response, response_loss_scale_map, query_loss_scale_map)
  71. def react_loss_scale(query: str, response: str):
  72. cwd = os.getcwd()
  73. loss_scale_config_path = 'default_loss_scale_config.json'
  74. config_path = os.path.join(cwd, loss_scale_config_path)
  75. with open(config_path, 'r') as json_file:
  76. loss_scale_map = json.load(json_file)
  77. return calculate_loss_scale(query, response, loss_scale_map)
  78. def default_loss_scale(query: str, response: str):
  79. return [response], [1.0]
  80. loss_scale_map = {
  81. 'agentflan': agentflan_loss_scale,
  82. 'react': react_loss_scale,
  83. 'alpha_umi': alpha_umi_loss_scale,
  84. 'default': default_loss_scale,
  85. }