| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import json
- import os
- from typing import Dict, List, Optional, Tuple
- from .utils import split_str_parts_by, split_parts_by_regex
def calculate_loss_scale(query: str,
                         response: str,
                         response_loss_scale_map: Optional[Dict[str, list]] = None,
                         query_loss_scale_map: Optional[Dict[str, list]] = None) -> Tuple[List[str], List[float]]:
    """Calculate the loss scale by splitting the agent response.

    This algorithm comes from paper: https://arxiv.org/pdf/2309.00986.pdf

    Agent response format:

    ```text
        Thought: you should always think about what to do
        Action: the action to take, should be one of the above tools[fire_recognition,
            fire_alert, call_police, call_fireman]
        Action Input: the input to the action
        Observation: the result of the action
        ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
        Thought: I now know the final answer
        Final Answer: the final answer to the original input question
    ```

    Args:
        query: The query text; only used to look up `query_loss_scale_map` keys
            by substring match.
        response: The agent response to split.
        response_loss_scale_map: Maps a delimiter to a list of weights. Keys whose
            list has two items are handed to `split_str_parts_by` as delimiters and
            contribute two weights (first for the delimiter text, second for the
            content that follows). Keys whose list has one item are handed to
            `split_parts_by_regex`. When None, the whole response is returned with
            weight 1.0.
        query_loss_scale_map: Maps a query substring to a weight, given either as
            a number or as a list whose first item is the weight. The first key
            (in mapping order) found in `query` decides the weight of the whole
            response. This mapping is never modified.

    Returns:
        A tuple of agent response parts and their weights.
    """
    # Query-level override: a matching key weights the entire response at once.
    if query_loss_scale_map is not None:
        for key, scale in query_loss_scale_map.items():
            if key in query:
                # Accept both a bare number and a list, without rewriting the
                # caller's mapping.
                loss_scale_value = scale if isinstance(scale, (float, int)) else scale[0]
                return [response], [float(loss_scale_value)]

    # Without a response map there is nothing to split on.
    if response_loss_scale_map is None:
        return [response], [1.0]

    delimiters = [k for k, v in response_loss_scale_map.items() if len(v) == 2]
    agent_parts = split_str_parts_by(response, delimiters)
    regex_delimiters = {k: v for k, v in response_loss_scale_map.items() if len(v) == 1}
    if regex_delimiters:
        # Return value is unused; presumably refines `agent_parts` in place.
        split_parts_by_regex(agent_parts, regex_delimiters)

    weights = []
    agent_content = []
    for part in agent_parts:
        key = part['key']
        if isinstance(key, (float, int)):
            # A numeric key is used directly as the weight of this part.
            weights.append(key)
            agent_content.append(part['content'])
        elif key in response_loss_scale_map:
            # Delimiter and its content are emitted as two separately weighted parts.
            scales = response_loss_scale_map[key]
            weights += [scales[0], scales[1]]
            agent_content.append(key)
            agent_content.append(part['content'])
        else:
            weights.append(1.0)
            agent_content.append(part['content'])
    return agent_content, weights
def alpha_umi_loss_scale(query: str, response: str):
    """Compute loss scales using the alpha-umi configuration file.

    Reads `alpha_umi_loss_scale_config.json` from the current working directory
    on every call and passes the parsed JSON to `calculate_loss_scale` as the
    response loss-scale map.

    Args:
        query: The query text.
        response: The agent response to split and weight.

    Returns:
        Whatever `calculate_loss_scale` returns for this configuration.
    """
    config_file = os.path.join(os.getcwd(), 'alpha_umi_loss_scale_config.json')
    with open(config_file, 'r') as fp:
        response_scales = json.load(fp)
    return calculate_loss_scale(query, response, response_scales)
def agentflan_loss_scale(query: str, response: str):
    """Compute loss scales using the agentflan configuration file.

    Reads `agentflan.json` from the current working directory on every call.
    The parsed JSON must contain a `'query'` entry and a `'response'` entry,
    which are passed to `calculate_loss_scale` as the query and response
    loss-scale maps respectively.

    Args:
        query: The query text.
        response: The agent response to split and weight.

    Returns:
        Whatever `calculate_loss_scale` returns for this configuration.
    """
    config_file = os.path.join(os.getcwd(), 'agentflan.json')
    with open(config_file, 'r') as fp:
        config = json.load(fp)
    query_scales = config['query']
    response_scales = config['response']
    return calculate_loss_scale(query, response, response_scales, query_scales)
def react_loss_scale(query: str, response: str):
    """Compute loss scales using the default (ReAct) configuration file.

    Reads `default_loss_scale_config.json` from the current working directory
    on every call and passes the parsed JSON to `calculate_loss_scale` as the
    response loss-scale map.

    Args:
        query: The query text.
        response: The agent response to split and weight.

    Returns:
        Whatever `calculate_loss_scale` returns for this configuration.
    """
    config_file = os.path.join(os.getcwd(), 'default_loss_scale_config.json')
    with open(config_file, 'r') as fp:
        response_scales = json.load(fp)
    return calculate_loss_scale(query, response, response_scales)
def default_loss_scale(query: str, response: str):
    """Treat the whole response as a single part with weight 1.0.

    Args:
        query: Unused; accepted so the signature matches the other strategies.
        response: The agent response.

    Returns:
        A tuple `([response], [1.0])`.
    """
    parts = [response]
    weights = [1.0]
    return parts, weights
# Registry of loss-scale strategies, keyed by strategy name. Every value is a
# callable taking `(query, response)`.
loss_scale_map = {
    # Query and response maps loaded from agentflan.json.
    'agentflan': agentflan_loss_scale,
    # Response map loaded from default_loss_scale_config.json.
    'react': react_loss_scale,
    # Response map loaded from alpha_umi_loss_scale_config.json.
    'alpha_umi': alpha_umi_loss_scale,
    # No splitting: the whole response with weight 1.0.
    'default': default_loss_scale,
}
|