# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
# Copyright from https://github.com/thu-spmi/LABES
# Copyright from https://github.com/TonyNemo/UBAR-MultiWOZ
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from collections import Counter

import json
import numpy as np
from nltk.util import ngrams
from sklearn.metrics import f1_score

from modelscope.utils.nlp.space import ontology, utils
from modelscope.utils.nlp.space.clean_dataset import clean_slot_values


def similar(a, b):
    """Return True when two strings are considered loosely equal.

    Two strings match when they are equal, when one contains the other, or
    when they share the same first or the same last whitespace-separated
    token.
    """
    if a == b or a in b or b in a:
        return True
    a_tokens, b_tokens = a.split(), b.split()
    # Whitespace-only strings have no tokens to compare; indexing them
    # would raise IndexError.
    if not a_tokens or not b_tokens:
        return False
    return a_tokens[0] == b_tokens[0] or a_tokens[-1] == b_tokens[-1]


def setsub(a, b):
    """Return True when every item of ``a`` is covered by ``b``.

    An item is covered when it is ``similar`` to some item of ``b``.
    Uncovered items are tolerated only if they contain one of the
    "useless constraint" substrings listed below.
    """
    useless_constraint = [
        'temperature', 'week', 'est ', 'quick', 'reminder', 'near'
    ]
    unmatched = [i for i in a if not any(similar(i, j) for j in b)]
    return all(
        any(item in junk for item in useless_constraint)
        for junk in unmatched)


def setsim(a, b):
    """Return True when ``a`` and ``b`` cover each other (see ``setsub``)."""
    a, b = set(a), set(b)
    return setsub(a, b) and setsub(b, a)


def DA_evaluate(preds, labels):
    """Compute the micro-averaged F1 between ``preds`` and ``labels``.

    Returns a dict with the single key ``'f1_micro'``.
    """
    preds = np.array(preds)
    labels = np.array(labels)
    results = {}
    for avg_name in ['micro']:
        my_f1_score = f1_score(y_true=labels, y_pred=preds, average=avg_name)
        results['f1_{}'.format(avg_name)] = my_f1_score
    return results


class BLEUScorer(object):
    # BLEU score calculator via GentScorer interface
    # it calculates the BLEU-4 by taking the entire corpus in
    # Calculate based multiple candidates against multiple references
    def __init__(self):
        pass

    def score(self, parallel_corpus):
        """Return the corpus-level BLEU-4 score scaled to [0, 100].

        :param parallel_corpus: iterable of ``(hyps, refs)`` pairs, where
            both elements are lists of whitespace-tokenizable strings
        :returns: BLEU * 100, or 0.0 when every hypothesis is empty
        """
        # containers
        count = [0, 0, 0, 0]
        clip_count = [0, 0, 0, 0]
        r = 0
        c = 0
        weights = [0.25, 0.25, 0.25, 0.25]

        # accumulate ngram statistics
        for hyps, refs in parallel_corpus:
            hyps = [hyp.split() for hyp in hyps]
            refs = [ref.split() for ref in refs]
            for hyp in hyps:
                for i in range(4):
                    # accumulate ngram counts
                    hypcnts = Counter(ngrams(hyp, i + 1))
                    count[i] += sum(hypcnts.values())

                    # compute clipped counts
                    max_counts = {}
                    for ref in refs:
                        refcnts = Counter(ngrams(ref, i + 1))
                        for ng in hypcnts:
                            max_counts[ng] = max(
                                max_counts.get(ng, 0), refcnts[ng])
                    # .get: with no references nothing can be matched, so
                    # the clipped count is 0 instead of a KeyError.
                    clip_count[i] += sum(
                        min(ng_count, max_counts.get(ng, 0))
                        for ng, ng_count in hypcnts.items())

                # accumulate r & c: pick the reference whose length is
                # closest to the hypothesis length
                bestmatch = [1000, 1000]
                for ref in refs:
                    if bestmatch[0] == 0:
                        break
                    diff = abs(len(ref) - len(hyp))
                    if diff < bestmatch[0]:
                        bestmatch[0] = diff
                        bestmatch[1] = len(ref)
                r += bestmatch[1]
                c += len(hyp)

        # No hypothesis tokens at all: the brevity penalty below would
        # divide by zero, and nothing can match anyway.
        if c == 0:
            return 0.0

        # computing bleu score
        p0 = 1e-7
        bp = 1 if c > r else math.exp(1 - float(r) / float(c))
        p_ns = [
            float(clip_count[i]) / float(count[i] + p0) + p0
            for i in range(4)
        ]
        s = math.fsum(
            w * math.log(p_n) for w, p_n in zip(weights, p_ns) if p_n)
        bleu = bp * math.exp(s)
        return bleu * 100


# For the data preparation and evaluation on MultiWOZ2.0/2.1,
# we refer to the code of UBAR (https://github.com/TonyNemo/UBAR-MultiWOZ)


class MultiWOZEvaluator(object):
    """Evaluator computing BLEU, match and success rates for MultiWOZ."""

    def __init__(self, reader, **kwargs):
        """
        :param reader: data reader; this class reads its ``data``, ``test``,
            ``db``, ``nlp`` attributes and several ``use_true_*`` flags
        :param kwargs: must contain ``data_dir``
        """
        self.reader = reader
        self.domains = ontology.all_domains
        self.all_data = self.reader.data
        self.test_data = self.reader.test

        self.bleu_scorer = BLEUScorer()

        self.all_info_slot = []
        for d, s_list in ontology.informable_slots.items():
            for s in s_list:
                self.all_info_slot.append(d + '-' + s)

        # only evaluate these slots for dialog success
        self.requestables = ['phone', 'address', 'postcode', 'reference', 'id']
        self.db_dir = kwargs['data_dir']

    def pack_dial(self, data):
        """Group a flat list of turns into ``{dial_id: [turn, ...]}``."""
        dials = {}
        for turn in data:
            dials.setdefault(turn['dial_id'], []).append(turn)
        return dials

    def validation_metric(self, data, fout=None):
        """Return ``(bleu, success, match)`` for the given turns.

        :param fout: optional writable text file; forwarded to
            ``context_to_response_eval`` for logging failed dialogs
        """
        bleu = self.bleu_metric(data)
        # accu_single_dom, accu_multi_dom, multi_dom_num = self.domain_eval(data)
        success, match, req_offer_counts, dial_num = \
            self.context_to_response_eval(
                data, same_eval_as_cambridge=True, fout=fout)
        return bleu, success, match

    def bleu_metric(self, data, eval_dial_list=None):
        """Return corpus BLEU of ``resp_gen`` against ``resp``.

        :param eval_dial_list: optional collection of ``'<dial_id>.json'``
            names; when given, turns of other dialogs are skipped
        :returns: BLEU score, or 0.0 when there is nothing to score or the
            scorer fails
        """
        gen, truth = [], []
        for row in data:
            if eval_dial_list and \
                    row['dial_id'] + '.json' not in eval_dial_list:
                continue
            gen.append(row['resp_gen'])
            truth.append(row['resp'])
        wrap_generated = [[_] for _ in gen]
        wrap_truth = [[_] for _ in truth]
        if gen and truth:
            try:
                sc = self.bleu_scorer.score(zip(wrap_generated, wrap_truth))
            except Exception:
                # deliberate best effort: a scoring failure counts as 0
                sc = 0.0
        else:
            sc = 0.0
        return sc

    def context_to_response_eval(self,
                                 data,
                                 eval_dial_list=None,
                                 same_eval_as_cambridge=False,
                                 fout=None):
        """Compute dialog-level success and match rates.

        :returns: ``(succ_rate, match_rate, counts, dial_num)`` where the
            rates are percentages and ``counts`` holds the per-requestable
            ``'<slot>_total'`` / ``'<slot>_offer'`` counters
        """
        dials = self.pack_dial(data)
        counts = {}
        for req in self.requestables:
            counts[req + '_total'] = 0
            counts[req + '_offer'] = 0

        # Whether the keys of all_data carry a '.json' suffix (probed on the
        # first key, once, instead of rebuilding the key list per dialog).
        keys_have_ext = '.json' in next(iter(self.all_data), '')

        dial_num, successes, matches = 0, 0, 0
        for dial_id in dials:
            if eval_dial_list and dial_id + '.json' not in eval_dial_list:
                continue
            dial = dials[dial_id]
            reqs = {}
            goal = {}
            if '.json' not in dial_id and keys_have_ext:
                dial_id = dial_id + '.json'
            true_goal = self.all_data[dial_id]['goal']
            for domain in ontology.all_domains:
                if true_goal.get(domain):
                    goal = self._parseGoal(goal, true_goal, domain)

            for domain in goal.keys():
                reqs[domain] = goal[domain]['requestable']

            success, match, stats, counts = \
                self._evaluateGeneratedDialogue(
                    dial,
                    goal,
                    reqs,
                    counts,
                    same_eval_as_cambridge=same_eval_as_cambridge,
                    fout=fout)

            successes += success
            matches += match
            dial_num += 1

        # 1e-10 guards against division by zero when no dialog was evaluated
        succ_rate = successes / (float(dial_num) + 1e-10) * 100
        match_rate = matches / (float(dial_num) + 1e-10) * 100
        return succ_rate, match_rate, counts, dial_num

    def _evaluateGeneratedDialogue(self,
                                   dialog,
                                   goal,
                                   real_requestables,
                                   counts,
                                   soft_acc=False,
                                   same_eval_as_cambridge=False,
                                   fout=None):
        """Evaluates the dialogue created by the model.

        First we load the user goal of the dialogue, then for each turn
        generated by the system we look for key-words.
        For the Inform rate we look whether the entity was proposed.
        For the Success rate we look for requestables slots.

        :param dialog: list of turn dicts of one dialog; the first turn is
            skipped
        :param goal: parsed goal as produced by ``_parseGoal``
        :param real_requestables: ``{domain: [requestable slot, ...]}``
        :param counts: running ``'<slot>_total'`` / ``'<slot>_offer'``
            counters, updated in place
        :param soft_acc: if True, return fractional match/success
        :param fout: optional writable text file; failed dialogs are dumped
            to it as one JSON object per line
        :returns: ``(success, match, stats, counts)``
        """
        # for computing corpus success
        requestables = self.requestables

        # CHECK IF MATCH HAPPENED
        provided_requestables = {}
        venue_offered = {}
        domains_in_goal = []
        log = []
        bspans = {}

        for domain in goal.keys():
            venue_offered[domain] = []
            provided_requestables[domain] = []
            domains_in_goal.append(domain)

        for t, turn in enumerate(dialog):
            if t == 0:
                continue
            if fout is not None:
                log.append({
                    'turn_num': turn['turn_num'],
                    'turn_domain': turn['dspn'],
                    'user': turn['user'],
                    'aspn': turn['aspn'],
                    'aspn_gen': turn['aspn_gen'],
                    'resp': turn['resp'],
                    'resp_gen': turn['resp_gen'],
                    'pointer': turn['pointer'],
                })

            sent_t = turn['resp_gen']

            for domain in goal.keys():
                # for computing success
                if same_eval_as_cambridge:
                    # [restaurant_name], [hotel_name] instead of [value_name]
                    if self.reader.use_true_domain_for_ctr_eval:
                        dom_pred = [d[1:-1] for d in turn['dspn'].split()]
                    else:
                        dom_pred = [d[1:-1] for d in turn['dspn_gen'].split()]

                    if domain not in dom_pred:  # fail
                        continue
                if '[value_name]' in sent_t or '[value_id]' in sent_t:
                    if domain in [
                            'restaurant', 'hotel', 'attraction', 'train'
                    ]:
                        # HERE YOU CAN PUT YOUR BELIEF STATE ESTIMATION
                        if not self.reader.use_true_curr_bspn and \
                                not self.reader.use_true_bspn_for_ctr_eval:
                            bspn = turn['bspn_gen']
                        else:
                            bspn = turn['bspn']

                        constraint_dict = \
                            self.reader.bspan_to_constraint_dict(bspn)
                        if constraint_dict.get(domain):
                            venues = self.reader.db.queryJsons(
                                domain,
                                constraint_dict[domain],
                                return_name=True)
                        else:
                            venues = []

                        if len(venue_offered[domain]) == 0 and venues:
                            venue_offered[domain] = venues
                            bspans[domain] = constraint_dict[domain]
                        elif any(ven not in venue_offered[domain]
                                 for ven in venues):
                            # the offered venue set has changed
                            # (any() is False for empty results, so no
                            # separate emptiness check is needed)
                            venue_offered[domain] = venues
                            bspans[domain] = constraint_dict[domain]
                    else:  # not limited so we can provide one
                        venue_offered[domain] = '[value_name]'

                # ATTENTION: assumption here - we didn't provide phone or address twice! etc
                for requestable in requestables:
                    if requestable == 'reference':
                        if '[value_reference]' in sent_t:
                            if domain in ['restaurant', 'hotel', 'train']:
                                # if pointer was allowing for that?
                                if 'booked' in turn['pointer'] \
                                        or 'ok' in turn['pointer'] \
                                        or '[value_reference]' in turn['resp']:
                                    provided_requestables[domain].append(
                                        'reference')
                            else:
                                provided_requestables[domain].append(
                                    'reference')
                    else:
                        if '[value_' + requestable + ']' in sent_t:
                            provided_requestables[domain].append(requestable)

        # if name was given in the task
        for domain in goal.keys():
            # if name was provided for the user, the match is being done automatically
            if 'name' in goal[domain]['informable']:
                venue_offered[domain] = '[value_name]'

            # special domains - entity does not need to be provided
            if domain in ['taxi', 'police', 'hospital']:
                venue_offered[domain] = '[value_name]'

            if domain == 'train':
                if not venue_offered[domain] \
                        and 'id' not in goal[domain]['requestable']:
                    venue_offered[domain] = '[value_name]'

        # Given all inform and requestable slots we go through each domain
        # from the user goal and check whether right entity was provided and
        # all requestable slots were given to the user. The dialogue is
        # successful if that's the case for all domains.

        # HARD EVAL
        # per domain: [match flag, success flag, domain-in-goal flag]
        stats = {
            'restaurant': [0, 0, 0],
            'hotel': [0, 0, 0],
            'attraction': [0, 0, 0],
            'train': [0, 0, 0],
            'taxi': [0, 0, 0],
            'hospital': [0, 0, 0],
            'police': [0, 0, 0]
        }

        match = 0
        success = 0
        # MATCH
        for domain in goal.keys():
            match_stat = 0
            offered = venue_offered[domain]
            if domain in ['restaurant', 'hotel', 'attraction', 'train']:
                goal_venues = self.reader.db.queryJsons(
                    domain, goal[domain]['informable'], return_name=True)
                if isinstance(offered, str) and '_name' in offered:
                    match += 1
                    match_stat = 1
                elif len(offered) > 0 \
                        and len(set(offered) & set(goal_venues)) > 0:
                    match += 1
                    match_stat = 1
            else:
                if '_name]' in offered:
                    match += 1
                    match_stat = 1

            stats[domain][0] = match_stat
            stats[domain][2] = 1

        if soft_acc:
            match = float(match) / len(goal.keys())
        else:
            match = 1.0 if match == len(goal.keys()) else 0.0

        for domain in domains_in_goal:
            for request in real_requestables[domain]:
                counts[request + '_total'] += 1
                if request in provided_requestables[domain]:
                    counts[request + '_offer'] += 1

        # SUCCESS
        # NOTE(review): when ``fout`` is given, success is computed even if
        # the match failed, so the reported success depends on whether a log
        # file was passed. Kept as-is to preserve existing numbers.
        if fout is not None or match == 1.0:
            for domain in domains_in_goal:
                provided = provided_requestables[domain]
                # a domain succeeds when every requested slot was offered
                # (trivially true when nothing was requested)
                domain_ok = all(request in provided
                                for request in real_requestables[domain])
                if domain_ok:
                    success += 1
                stats[domain][1] = 1 if domain_ok else 0

            # final eval
            if soft_acc:
                success = float(success) / len(real_requestables)
            else:
                success = 1 if success >= len(real_requestables) else 0

        if fout is not None and success == 0:
            sample = {
                dialog[0]['dial_id']: {
                    'log': log,
                    'real_requestables': real_requestables,
                    'provided_requestables': provided_requestables
                }
            }
            line = json.dumps(sample)
            fout.write(line)
            fout.write('\n')

        return success, match, stats, counts

    def _parseGoal(self, goal, true_goal, domain):
        """Parses user goal into dictionary format.

        Adds ``goal[domain]`` with the keys ``'informable'`` (dict),
        ``'requestable'`` (list) and ``'booking'`` and returns ``goal``.
        """
        goal[domain] = {'informable': {}, 'requestable': [], 'booking': []}
        domain_goal = true_goal[domain]
        if 'info' in domain_goal:
            if domain == 'train':
                # we consider dialogues only where train had to be booked!
                if 'book' in domain_goal:
                    goal[domain]['requestable'].append('reference')
                if 'reqt' in domain_goal:
                    if 'id' in domain_goal['reqt']:
                        goal[domain]['requestable'].append('id')
            else:
                if 'reqt' in domain_goal:
                    for s in domain_goal['reqt']:  # additional requests:
                        if s in [
                                'phone', 'address', 'postcode', 'reference',
                                'id'
                        ]:
                            # ones that can be easily delexicalized
                            goal[domain]['requestable'].append(s)
                if 'book' in domain_goal:
                    goal[domain]['requestable'].append('reference')

            for s, v in domain_goal['info'].items():
                s_, v_ = clean_slot_values(self.db_dir, domain, s, v)
                if len(v_.split()) > 1:
                    # re-tokenize multi-word values with the reader's
                    # tokenizer
                    v_ = ' '.join(
                        [token.text for token in self.reader.nlp(v_)]).strip()
                goal[domain]['informable'][s_] = v_

            if 'book' in domain_goal:
                goal[domain]['booking'] = domain_goal['book']
        return goal


class GenericEvaluator:
    """Shared metric implementations for the CamRest / Kvret evaluators.

    Subclasses are expected to set ``informable_slots``,
    ``requestable_slots``, ``entities_flat`` and ``entitiy_to_slot_dict``,
    which the methods below read.
    """

    def __init__(self, reader):
        self.reader = reader
        self.metric_dict = {}

    def pack_dial(self, data):
        """Group a flat list of turns into ``{dial_id: [turn, ...]}``."""
        dials = {}
        for turn in data:
            dials.setdefault(turn['dial_id'], []).append(turn)
        return dials

    def run_metrics(self, results):
        raise ValueError('Please specify the evaluator first')

    def clean(self, resp):
        """Delexicalize a response and wrap it in the sos/eos tokens."""
        # we use the same clean process as in Sequicity, SEDST, FSDM
        # to ensure comparable results
        resp = resp.replace(f'{self.reader.sos_r_token} ', '')
        resp = resp.replace(f' {self.reader.eos_r_token}', '')
        resp = f'{self.reader.sos_r_token} {resp} {self.reader.eos_r_token}'
        for value, slot in self.entitiy_to_slot_dict.items():
            resp = utils.clean_replace(resp, value, '[value_%s]' % slot)
        return resp

    def constraint_same(self, truth_cons, gen_cons):
        """Return True when both constraint collections loosely agree.

        Two empty collections agree; an empty and a non-empty one do not;
        otherwise the decision is delegated to ``setsim``.
        """
        if not truth_cons and not gen_cons:
            return True
        if not truth_cons or not gen_cons:
            return False
        return setsim(gen_cons, truth_cons)

    def bleu_metric(self, data, type='bleu'):
        """Return corpus BLEU of cleaned ``resp_gen`` against ``resp``.

        ``type`` is accepted for interface compatibility and is not used.
        """
        gen, truth = [], []
        for row in data:
            gen.append(self.clean(row['resp_gen']))
            # gen.append(self.clean(row['resp']))
            truth.append(self.clean(row['resp']))
        wrap_generated = [[_] for _ in gen]
        wrap_truth = [[_] for _ in truth]
        sc = BLEUScorer().score(zip(wrap_generated, wrap_truth))
        return sc

    def _normalize_constraint(self,
                              constraint,
                              ignore_dontcare=False,
                              intersection=True):
        """
        Normalize belief span, e.g. delete repeated words
        :param constraint - {'food': 'asian oritental', 'pricerange': 'cheap'}
        :param ignore_dontcare: if true, slots whose value is 'dontcare'
                                are left empty
        :param intersection: if true, only keeps the words that appear in the ontology
                             we set intersection=True as in previous works
        :returns: normalized constraint dict
                  e.g. - {'food': 'asian oritental', 'pricerange': 'cheap', 'area': ''}
        """
        normalized = {s: '' for s in self.informable_slots}
        for s, v in constraint.items():
            if ignore_dontcare and v == 'dontcare':
                continue
            if intersection and v != 'dontcare' \
                    and v not in self.entities_flat:
                continue

            normalized[s] = v

        return normalized

    def _normalize_act(self, aspn, intersection=False):
        """Split a '|'-separated act span into ``{act_name: set(words)}``.

        The i-th segment is stored under ``self.reader.act_order[i]``. With
        ``intersection=True`` only '[value...' words are kept for the 'av'
        segment and only requestable slots for the other segments.
        """
        normalized = {}
        for i, segment in enumerate(aspn.split('|')):
            act_name = self.reader.act_order[i]
            word_set = set()
            for w in segment.strip().split():
                if not intersection:
                    word_set.add(w)
                elif act_name == 'av':
                    if '[value' in w:
                        word_set.add(w)
                elif w in self.requestable_slots:
                    word_set.add(w)
            normalized[act_name] = word_set
        return normalized

    def tracker_metric(self, data, normalize=True):
        """Turn-level belief tracking metrics.

        Turns whose user utterance contains 'thank' or 'bye' are ignored.

        :returns: ``(precision, recall, f1, goal_accr, slot_accr,
            db_correct)``
        """
        # turn level metric
        tp, fp, fn, db_correct = 0, 0, 0, 0
        # total starts at 1e-8 so the divisions below never hit zero
        goal_accr, slot_accr, total = 0, {}, 1e-8
        for s in self.informable_slots:
            slot_accr[s] = 0

        for row in data:
            if normalize:
                gen = self._normalize_constraint(row['bspn_gen'])
                truth = self._normalize_constraint(row['bspn'])
            else:
                gen = self._normalize_constraint(
                    row['bspn_gen'], intersection=False)
                truth = self._normalize_constraint(
                    row['bspn'], intersection=False)
            valid = 'thank' not in row['user'] and 'bye' not in row['user']
            if valid:
                for slot, value in gen.items():
                    if value in truth[slot]:
                        tp += 1
                    else:
                        fp += 1
                for slot, value in truth.items():
                    if value not in gen[slot]:
                        fn += 1

            if truth and valid:
                total += 1
                for s in self.informable_slots:
                    if gen[s] == truth[s]:
                        slot_accr[s] += 1
                if gen == truth:
                    goal_accr += 1
                if row.get('db_gen') and row.get('db_match'):
                    if row['db_gen'] == row['db_match']:
                        db_correct += 1
        precision, recall = tp / (tp + fp + 1e-8), tp / (tp + fn + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        goal_accr /= total
        db_correct /= total
        for s in slot_accr:
            slot_accr[s] /= total
        return precision, recall, f1, goal_accr, slot_accr, db_correct

    def request_metric(self, data):
        """Dialog-level F1 over requested slots found in the responses.

        For every dialog, the set of ``[value_<slot>]`` placeholders
        (except ``[value_name]``) in the generated responses is compared
        with the set found in the reference responses.

        :returns: ``(f1, precision, recall)``
        """

        def requested_slots(resp):
            # '[value_phone]' -> 'phone'
            return {
                w[1:-1].split('_')[1]
                for w in self.clean(resp).split()
                if '[value_' in w and w.endswith(']') and w != '[value_name]'
            }

        # dialog level metric
        dials = self.pack_dial(data)
        tp, fp, fn = 0, 0, 0
        for dial in dials.values():
            truth_req, gen_req = set(), set()
            for turn in dial:
                gen_req |= requested_slots(turn['resp_gen'])
                truth_req |= requested_slots(turn['resp'])
            tp += len(gen_req & truth_req)
            fp += len(gen_req - truth_req)
            fn += len(truth_req - gen_req)
        precision, recall = tp / (tp + fp + 1e-8), tp / (tp + fn + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        return f1, precision, recall

    def act_metric(self, data):
        """Turn-level F1 of the generated system act against the reference.

        Turns whose user utterance contains 'thank' or 'bye' are ignored.

        :returns: dict mapping 'all_s', 'all_v', every requestable slot and
            every '[value_<slot>]' to ``[f1, precision, recall]``
        """
        # turn level metric
        tp = {'all_s': 0, 'all_v': 0}
        fp = {'all_s': 0, 'all_v': 0}
        fn = {'all_s': 0, 'all_v': 0}
        for s in self.requestable_slots:
            for key in (s, '[value_%s]' % s):
                tp[key], fp[key], fn[key] = 0, 0, 0

        # NOTE: the per-key counters are tested with ``in`` rather than
        # ``.get()``: they start at 0, which is falsy, so a truthiness test
        # would never let them leave 0.
        for row in data:
            gen = self._normalize_act(row['aspn_gen'])
            truth = self._normalize_act(row['aspn'])
            valid = 'thank' not in row['user'] and 'bye' not in row['user']
            if valid:
                # how well the act decoder captures user's requests
                for value in gen['av']:
                    if value in truth['av']:
                        tp['all_v'] += 1
                        if value in tp:
                            tp[value] += 1
                    else:
                        fp['all_v'] += 1
                        if value in fp:
                            fp[value] += 1
                for value in truth['av']:
                    if value not in gen['av']:
                        fn['all_v'] += 1
                        if value in fn:
                            fn[value] += 1

                # how accurately the act decoder predicts system's question
                if 'as' not in gen:
                    continue

                for slot in gen['as']:
                    if slot in truth['as']:
                        tp['all_s'] += 1
                        if slot in tp:
                            tp[slot] += 1
                    else:
                        fp['all_s'] += 1
                        if slot in fp:
                            fp[slot] += 1
                for slot in truth['as']:
                    if slot not in gen['as']:
                        fn['all_s'] += 1
                        if slot in fn:
                            fn[slot] += 1

        result = {}
        for k in tp:
            precision = tp[k] / (tp[k] + fp[k] + 1e-8)
            recall = tp[k] / (tp[k] + fn[k] + 1e-8)
            f1 = 2 * precision * recall / (precision + recall + 1e-8)
            result[k] = [f1, precision, recall]
        return result


# For the data preparation and evaluation on In-Car Assistant/CamRest,
# we refer to the code of LABES (https://github.com/thu-spmi/LABES)


class CamRestEvaluator(GenericEvaluator):
    """Evaluator for the CamRest dataset."""

    def __init__(self, reader):
        super().__init__(reader)
        self.entities_flat, self.entitiy_to_slot_dict = self.get_entities(
            self.reader.ontology_path)
        self.informable_slots = self.reader.otlg.informable_slots
        self.requestable_slots = self.reader.otlg.requestable_slots

    def run_metrics(self, results):
        """Compute all metrics and return them as a dict."""
        metrics = {}
        bleu = self.bleu_metric(results)
        p, r, f1, goal_acc, slot_acc, db_acc = self.tracker_metric(results)
        match = self.match_metric(results)
        req_f1, req_p, req_r = self.request_metric(results)

        metrics['bleu'] = bleu
        metrics['match'] = match
        metrics['req_f1'] = req_f1
        metrics['joint_goal'] = goal_acc
        metrics['slot_accu'] = slot_acc
        metrics['slot-p/r/f1'] = (p, r, f1)
        metrics['db_acc'] = db_acc

        return metrics

    def get_entities(self, entity_path):
        """Load the lower-cased ontology JSON at ``entity_path``.

        :returns: ``(entities_flat, entitiy_to_slot_dict)`` - the list of
            all informable values and a value -> slot mapping
        """
        entities_flat = []
        entitiy_to_slot_dict = {}
        # context manager so the file handle is closed deterministically
        with open(entity_path, encoding='utf-8') as f:
            raw_entities = json.loads(f.read().lower())
        for s, values in raw_entities['informable'].items():
            entities_flat.extend(values)
            for v in values:
                entitiy_to_slot_dict[v] = s
        return entities_flat, entitiy_to_slot_dict

    def match_metric(self, data):
        """Dialog-level rate of exact belief-state agreement.

        Uses, per dialog, the belief states of the last turns whose
        generated / reference response contains a '[value' placeholder.
        Dialogs without any reference constraint are not counted.
        """
        dials = self.pack_dial(data)
        match, total = 0, 1e-8
        for dial in dials.values():
            # the three-empty-values dict is the "no reference constraint
            # seen" sentinel tested below
            truth_cons, gen_cons = {'1': '', '2': '', '3': ''}, None
            for turn in dial:
                # find the last turn which the system provide an entity
                if '[value' in turn['resp_gen']:
                    gen_cons = self._normalize_constraint(
                        turn['bspn_gen'], ignore_dontcare=True)
                if '[value' in turn['resp']:
                    truth_cons = self._normalize_constraint(
                        turn['bspn'], ignore_dontcare=True)
            if not gen_cons:
                # if no entity is provided, choose the state of the last dialog turn
                gen_cons = self._normalize_constraint(
                    dial[-1]['bspn_gen'], ignore_dontcare=True)
            if list(truth_cons.values()) != ['', '', '']:
                if gen_cons == truth_cons:
                    match += 1
                total += 1

        return match / total


class KvretEvaluator(GenericEvaluator):
    """Evaluator for the Kvret (In-Car Assistant) dataset."""

    def __init__(self, reader):
        super().__init__(reader)
        self.entities_flat, self.entitiy_to_slot_dict = self.get_entities(
            self.reader.ontology_path)
        self.informable_slots = self.reader.otlg.informable_slots
        self.requestable_slots = self.reader.otlg.requestable_slots

    def run_metrics(self, results):
        """Compute all metrics and return them as a dict."""
        metrics = {}
        bleu = self.bleu_metric(results)
        p, r, f1, goal_acc, slot_acc, db_acc = self.tracker_metric(
            results, normalize=True)
        match = self.match_metric(results)
        req_f1, req_p, req_r = self.request_metric(results)

        metrics['bleu'] = bleu
        metrics['match'] = match
        metrics['req_f1'] = req_f1
        metrics['joint_goal'] = goal_acc
        metrics['slot_accu'] = slot_acc
        metrics['slot-p/r/f1'] = (p, r, f1)
        metrics['db_acc'] = db_acc

        return metrics

    def _normalize_constraint(self,
                              constraint,
                              ignore_dontcare=False,
                              intersection=True):
        """
        Normalize belief span, e.g. delete repeated words
        :param constraint - {'food': 'asian oritental', 'pricerange': 'cheap'}
        :param ignore_dontcare: accepted for signature compatibility with
                                the base class; not used here
        :param intersection: if true, only keeps the words that appear in the ontology
                             we set intersection=True as in previous works
        :returns: normalized constraint dict
                  e.g. - {'food': 'asian oritental', 'pricerange': 'cheap', 'area': ''}
        """
        junk = [
            'good', 'great', 'quickest', 'shortest', 'route', 'week',
            'fastest', 'nearest', 'next', 'closest', 'way', 'mile',
            'activity', 'restaurant', 'appointment'
        ]
        normalized = {s: '' for s in self.informable_slots}
        for s, v in constraint.items():
            # strip junk substrings and collapse the remaining whitespace
            for j in junk:
                v = ' '.join(v.replace(j, '').split())
            if intersection and v not in self.entities_flat:
                continue

            if s in self.informable_slots:
                normalized[s] = v
            # TODO only use slot (not domain) in s for matching !!!

        return normalized

    def get_entities(self, entity_path):
        """Return ``(entities_flat, entitiy_to_slot_dict)`` from the reader.

        ``entity_path`` is unused; the mapping is taken from
        ``self.reader.entity_dict`` and its keys form the flat list.
        """
        entitiy_to_slot_dict = self.reader.entity_dict
        # order-preserving de-duplication in linear time
        entities_flat = list(dict.fromkeys(entitiy_to_slot_dict))
        return entities_flat, entitiy_to_slot_dict

    def match_metric(self, data):
        """Dialog-level rate of loose belief-state agreement.

        Uses, per dialog, the belief states of the last turns whose
        generated / reference response contains a '[value' placeholder and
        compares their non-empty values with ``constraint_same``. Dialogs
        without any reference constraint are not counted.
        """
        dials = self.pack_dial(data)
        match, total = 0, 1e-8
        for dial in dials.values():
            # eleven empty values form the "no reference constraint seen"
            # sentinel tested below
            truth_cons = {str(i): '' for i in range(1, 12)}
            gen_cons = None
            for turn in dial:
                # find the last turn which the system provide an entity
                if '[value' in turn['resp_gen']:
                    gen_cons = self._normalize_constraint(
                        turn['bspn_gen'], ignore_dontcare=True)
                if '[value' in turn['resp']:
                    truth_cons = self._normalize_constraint(
                        turn['bspn'], ignore_dontcare=True)

            if not gen_cons:
                # if no entity is provided, choose the state of the last dialog turn
                gen_cons = self._normalize_constraint(
                    dial[-1]['bspn_gen'], ignore_dontcare=True)

            if list(truth_cons.values()) != [''] * 11:
                gen_cons = [x for x in gen_cons.values() if x]
                truth_cons = [x for x in truth_cons.values() if x]
                if self.constraint_same(gen_cons, truth_cons):
                    match += 1
                total += 1

        return match / total