| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953 |
- # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
- # Copyright from https://github.com/thu-spmi/LABES
- # Copyright from https://github.com/TonyNemo/UBAR-MultiWOZ
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- from collections import Counter
- import json
- import numpy as np
- from nltk.util import ngrams
- from sklearn.metrics import f1_score
- from modelscope.utils.nlp.space import ontology, utils
- from modelscope.utils.nlp.space.clean_dataset import clean_slot_values
def similar(a, b):
    """Return True when two slot-value strings loosely match.

    Two strings are considered similar when they are equal, when one is a
    substring of the other, or when they share the same first or the same
    last whitespace-separated token.

    Args:
        a: first string.
        b: second string.

    Returns:
        bool: True if the strings match under any of the rules above.
    """
    if a == b or a in b or b in a:
        return True
    a_tokens = a.split()
    b_tokens = b.split()
    if not a_tokens or not b_tokens:
        # A whitespace-only string has no tokens to compare; indexing
        # split()[0] here used to raise IndexError.
        return False
    return a_tokens[0] == b_tokens[0] or a_tokens[-1] == b_tokens[-1]
def setsub(a, b):
    """Check whether every element of ``a`` is covered by ``b``.

    An element of ``a`` is covered when it is ``similar`` to at least one
    element of ``b``. Elements that are not covered are still tolerated if
    they contain one of a fixed list of ignorable substrings.

    Args:
        a: iterable of strings to check.
        b: iterable of strings to check against.

    Returns:
        bool: False as soon as an uncovered element without any ignorable
        substring is found, True otherwise.
    """
    ignorable_markers = (
        'temperature', 'week', 'est ', 'quick', 'reminder', 'near'
    )

    leftovers = []
    for candidate in a:
        # Compare against every element of b (no early exit).
        hits = [similar(candidate, other) for other in b]
        if not any(hits):
            leftovers.append(candidate)

    return all(
        any(marker in leftover for marker in ignorable_markers)
        for leftover in leftovers)
def setsim(a, b):
    """Return True when ``a`` and ``b`` cover each other under ``setsub``.

    Both arguments are converted to sets first, so duplicates are ignored.
    """
    left = set(a)
    right = set(b)
    if not setsub(left, right):
        return False
    return setsub(right, left)
def DA_evaluate(preds, labels):
    """Compute the micro-averaged F1 score of predictions against labels.

    Args:
        preds: predicted labels, converted with ``np.array``.
        labels: reference labels, converted with ``np.array``.

    Returns:
        dict: ``{'f1_micro': <score>}``.
    """
    y_pred = np.array(preds)
    y_true = np.array(labels)
    return {
        'f1_{}'.format(average): f1_score(
            y_true=y_true, y_pred=y_pred, average=average)
        for average in ('micro', )
    }
class BLEUScorer(object):
    """Corpus-level BLEU-4 scorer (GentScorer-style interface).

    N-gram statistics are accumulated over the whole corpus, with each
    item holding several candidate sentences and several references.
    """

    def __init__(self):
        pass

    def score(self, parallel_corpus):
        """Compute BLEU-4 (scaled to 0-100) over a parallel corpus.

        Args:
            parallel_corpus: iterable of ``(hyps, refs)`` pairs, where both
                ``hyps`` and ``refs`` are lists of whitespace-tokenizable
                strings.

        Returns:
            float: BLEU score multiplied by 100. ``0.0`` is returned when
            the total hypothesis length is zero (empty corpus or only empty
            hypotheses), where the brevity penalty is undefined.
        """
        # containers
        count = [0, 0, 0, 0]
        clip_count = [0, 0, 0, 0]
        r = 0
        c = 0
        weights = [0.25, 0.25, 0.25, 0.25]

        # accumulate ngram statistics
        for hyps, refs in parallel_corpus:
            hyps = [hyp.split() for hyp in hyps]
            refs = [ref.split() for ref in refs]
            for hyp in hyps:
                for i in range(4):
                    # accumulate ngram counts
                    hypcnts = Counter(ngrams(hyp, i + 1))
                    count[i] += sum(hypcnts.values())

                    # compute clipped counts: each hypothesis n-gram is
                    # credited at most as often as it occurs in the
                    # reference where it is most frequent
                    max_counts = {}
                    for ref in refs:
                        refcnts = Counter(ngrams(ref, i + 1))
                        for ng in hypcnts:
                            max_counts[ng] = max(
                                max_counts.get(ng, 0), refcnts[ng])
                    clip_count[i] += sum(
                        min(cnt, max_counts[ng])
                        for ng, cnt in hypcnts.items())

                # accumulate r & c: r uses the length of the reference
                # closest in length to the hypothesis
                bestmatch = [1000, 1000]
                for ref in refs:
                    if bestmatch[0] == 0:
                        break
                    diff = abs(len(ref) - len(hyp))
                    if diff < bestmatch[0]:
                        bestmatch[0] = diff
                        bestmatch[1] = len(ref)
                r += bestmatch[1]
                c += len(hyp)

        if c == 0:
            # No hypothesis tokens at all: r / c below would divide by zero.
            return 0.0

        # computing bleu score
        p0 = 1e-7  # smoothing so that log() never sees zero
        bp = 1 if c > r else math.exp(1 - float(r) / float(c))
        p_ns = [
            float(clip_count[i]) / float(count[i] + p0) + p0 for i in range(4)
        ]
        s = math.fsum(w * math.log(p_n) for w, p_n in zip(weights, p_ns))
        bleu = bp * math.exp(s)
        return bleu * 100
- """
- For the data preparation and evaluation on MultiWOZ2.0/2.1,
- we refer to the code of UBAR (https://github.com/TonyNemo/UBAR-MultiWOZ)
- """
class MultiWOZEvaluator(object):
    """Computes BLEU, match (inform) and success rates for MultiWOZ output.

    The evaluator reads the gold dialogues and goals from ``reader.data``
    and queries ``reader.db`` to decide whether the offered entities are
    consistent with the user goal.
    """

    def __init__(self, reader, **kwargs):
        """Store the reader and prepare slot lists.

        Args:
            reader: dataset reader; this class reads its ``data``, ``test``,
                ``db``, ``nlp``, ``bspan_to_constraint_dict`` and the
                ``use_true_*`` flags.
            **kwargs: must contain ``data_dir`` (passed to
                ``clean_slot_values`` when parsing goals).

        Raises:
            KeyError: if ``data_dir`` is not given.
        """
        self.reader = reader
        self.domains = ontology.all_domains
        self.all_data = self.reader.data
        self.test_data = self.reader.test

        self.bleu_scorer = BLEUScorer()

        self.all_info_slot = []
        for d, s_list in ontology.informable_slots.items():
            for s in s_list:
                self.all_info_slot.append(d + '-' + s)

        # only evaluate these slots for dialog success
        self.requestables = ['phone', 'address', 'postcode', 'reference', 'id']
        self.db_dir = kwargs['data_dir']

    def pack_dial(self, data):
        """Group a flat list of turns by their ``dial_id``.

        Returns:
            dict: ``dial_id`` -> list of turns in their original order.
        """
        dials = {}
        for turn in data:
            dials.setdefault(turn['dial_id'], []).append(turn)
        return dials

    def validation_metric(self, data, fout=None):
        """Return ``(bleu, success_rate, match_rate)`` for ``data``.

        Args:
            data: list of turn dicts.
            fout: optional writable text stream; forwarded to
                ``context_to_response_eval``.
        """
        bleu = self.bleu_metric(data)
        success, match, _req_offer_counts, _dial_num = \
            self.context_to_response_eval(
                data, same_eval_as_cambridge=True, fout=fout)
        return bleu, success, match

    def bleu_metric(self, data, eval_dial_list=None):
        """Compute corpus BLEU of ``resp_gen`` against ``resp``.

        Args:
            data: list of turn dicts.
            eval_dial_list: optional collection of ``'<dial_id>.json'``
                names; when given, other dialogues are skipped.

        Returns:
            float: BLEU score, or ``0.0`` when there is nothing to score or
            the scorer fails.
        """
        gen, truth = [], []
        for row in data:
            if eval_dial_list and \
                    row['dial_id'] + '.json' not in eval_dial_list:
                continue
            gen.append(row['resp_gen'])
            truth.append(row['resp'])

        if not gen:
            return 0.0

        wrap_generated = [[_] for _ in gen]
        wrap_truth = [[_] for _ in truth]
        try:
            return self.bleu_scorer.score(zip(wrap_generated, wrap_truth))
        except Exception:
            # Best effort by design: a scorer failure counts as BLEU 0.
            return 0.0

    def context_to_response_eval(self,
                                 data,
                                 eval_dial_list=None,
                                 same_eval_as_cambridge=False,
                                 fout=None):
        """Compute dialogue-level success and match rates.

        Args:
            data: list of turn dicts.
            eval_dial_list: optional collection of ``'<dial_id>.json'``
                names; when given, other dialogues are skipped.
            same_eval_as_cambridge: forwarded to
                ``_evaluateGeneratedDialogue``.
            fout: optional writable text stream for failed dialogues.

        Returns:
            tuple: ``(succ_rate, match_rate, counts, dial_num)`` where the
            rates are percentages and ``counts`` holds per-requestable
            ``*_total`` / ``*_offer`` counters.
        """
        dials = self.pack_dial(data)
        counts = {}
        for req in self.requestables:
            counts[req + '_total'] = 0
            counts[req + '_offer'] = 0

        # Whether the gold data is keyed by '<id>.json'; checked once on the
        # first key instead of rebuilding the key list for every dialogue.
        keys_use_json_suffix = bool(self.all_data) and \
            '.json' in next(iter(self.all_data))

        dial_num, successes, matches = 0, 0, 0
        for dial_id, dial in dials.items():
            if eval_dial_list and dial_id + '.json' not in eval_dial_list:
                continue

            if keys_use_json_suffix and '.json' not in dial_id:
                dial_id = dial_id + '.json'

            true_goal = self.all_data[dial_id]['goal']
            goal = {}
            for domain in ontology.all_domains:
                if true_goal.get(domain):
                    goal = self._parseGoal(goal, true_goal, domain)

            reqs = {domain: goal[domain]['requestable'] for domain in goal}

            success, match, stats, counts = \
                self._evaluateGeneratedDialogue(
                    dial, goal, reqs, counts,
                    same_eval_as_cambridge=same_eval_as_cambridge,
                    fout=fout)

            successes += success
            matches += match
            dial_num += 1

        # the small epsilon avoids a division by zero with no dialogues
        succ_rate = successes / (float(dial_num) + 1e-10) * 100
        match_rate = matches / (float(dial_num) + 1e-10) * 100
        return succ_rate, match_rate, counts, dial_num

    @staticmethod
    def _count_successful_domains(domains, real_requestables,
                                  provided_requestables, stats):
        """Count domains whose requested slots were all provided.

        Writes 1 or 0 into ``stats[domain][1]`` for every domain and returns
        the number of successful domains. A domain without requestables is
        successful.
        """
        success = 0
        for domain in domains:
            provided = provided_requestables[domain]
            satisfied = all(
                request in provided for request in real_requestables[domain])
            stats[domain][1] = 1 if satisfied else 0
            if satisfied:
                success += 1
        return success

    def _evaluateGeneratedDialogue(self,
                                   dialog,
                                   goal,
                                   real_requestables,
                                   counts,
                                   soft_acc=False,
                                   same_eval_as_cambridge=False,
                                   fout=None):
        """Evaluates the dialogue created by the model.

        First we load the user goal of the dialogue, then for each turn
        generated by the system we look for key-words.
        For the Inform rate we look whether the entity was proposed.
        For the Success rate we look for requestables slots.

        Args:
            dialog: list of turn dicts of one dialogue; the first element is
                skipped.
            goal: parsed goal, ``domain -> {'informable', 'requestable',
                'booking'}``.
            real_requestables: ``domain -> list`` of requested slot names.
            counts: running ``*_total`` / ``*_offer`` counters, updated in
                place.
            soft_acc: when True, match and success are fractions over
                domains instead of all-or-nothing values.
            same_eval_as_cambridge: when True, a turn only contributes to
                the domains named in its domain span.
            fout: optional writable text stream. A JSON line is written for
                every dialogue whose success is 0.

        Returns:
            tuple: ``(success, match, stats, counts)``; ``stats`` maps each
            domain to ``[match, success, in_goal]`` flags.
        """
        # for computing corpus success
        requestables = self.requestables

        # CHECK IF MATCH HAPPENED
        provided_requestables = {}
        venue_offered = {}
        domains_in_goal = []
        log = []

        for domain in goal.keys():
            venue_offered[domain] = []
            provided_requestables[domain] = []
            domains_in_goal.append(domain)

        for t, turn in enumerate(dialog):
            if t == 0:
                continue
            if fout is not None:
                log.append({
                    'turn_num': turn['turn_num'],
                    'turn_domain': turn['dspn'],
                    'user': turn['user'],
                    'aspn': turn['aspn'],
                    'aspn_gen': turn['aspn_gen'],
                    'resp': turn['resp'],
                    'resp_gen': turn['resp_gen'],
                    'pointer': turn['pointer'],
                })

            sent_t = turn['resp_gen']

            for domain in goal.keys():
                # for computing success
                if same_eval_as_cambridge:
                    # [restaurant_name], [hotel_name] instead of [value_name]
                    if self.reader.use_true_domain_for_ctr_eval:
                        dom_pred = [d[1:-1] for d in turn['dspn'].split()]
                    else:
                        dom_pred = [d[1:-1] for d in turn['dspn_gen'].split()]

                    if domain not in dom_pred:  # fail
                        continue

                if '[value_name]' in sent_t or '[value_id]' in sent_t:
                    if domain in ['restaurant', 'hotel', 'attraction',
                                  'train']:
                        # HERE YOU CAN PUT YOUR BELIEF STATE ESTIMATION
                        if not self.reader.use_true_curr_bspn and \
                                not self.reader.use_true_bspn_for_ctr_eval:
                            bspn = turn['bspn_gen']
                        else:
                            bspn = turn['bspn']

                        constraint_dict = \
                            self.reader.bspan_to_constraint_dict(bspn)
                        if constraint_dict.get(domain):
                            venues = self.reader.db.queryJsons(
                                domain,
                                constraint_dict[domain],
                                return_name=True)
                        else:
                            venues = []

                        # Take over the new result list when nothing was
                        # offered yet or when it contains a venue that was
                        # not offered before; an empty result never replaces
                        # the current offer.
                        if any(ven not in venue_offered[domain]
                               for ven in venues):
                            venue_offered[domain] = venues
                    else:  # not limited so we can provide one
                        venue_offered[domain] = '[value_name]'

                # ATTENTION: assumption here - we didn't provide phone or
                # address twice! etc
                for requestable in requestables:
                    if requestable == 'reference':
                        if '[value_reference]' in sent_t:
                            if domain in ['restaurant', 'hotel', 'train']:
                                # if pointer was allowing for that?
                                if 'booked' in turn['pointer'] \
                                        or 'ok' in turn['pointer'] \
                                        or '[value_reference]' in turn['resp']:
                                    provided_requestables[domain].append(
                                        'reference')
                            else:
                                provided_requestables[domain].append(
                                    'reference')
                    else:
                        if '[value_' + requestable + ']' in sent_t:
                            provided_requestables[domain].append(requestable)

        # if name was given in the task
        for domain in goal.keys():
            # if name was provided for the user, the match is being done
            # automatically
            if 'name' in goal[domain]['informable']:
                venue_offered[domain] = '[value_name]'

            # special domains - entity does not need to be provided
            if domain in ['taxi', 'police', 'hospital']:
                venue_offered[domain] = '[value_name]'

            if domain == 'train':
                if not venue_offered[domain] and \
                        'id' not in goal[domain]['requestable']:
                    venue_offered[domain] = '[value_name]'

        # Given all inform and requestable slots we go through each domain
        # from the user goal and check whether right entity was provided and
        # all requestable slots were given to the user. The dialogue is
        # successful if that's the case for all domains.

        # HARD EVAL
        stats = {
            'restaurant': [0, 0, 0],
            'hotel': [0, 0, 0],
            'attraction': [0, 0, 0],
            'train': [0, 0, 0],
            'taxi': [0, 0, 0],
            'hospital': [0, 0, 0],
            'police': [0, 0, 0]
        }

        match = 0
        success = 0
        # MATCH
        for domain in goal.keys():
            match_stat = 0
            if domain in ['restaurant', 'hotel', 'attraction', 'train']:
                goal_venues = self.reader.db.queryJsons(
                    domain, goal[domain]['informable'], return_name=True)
                if isinstance(venue_offered[domain], str) \
                        and '_name' in venue_offered[domain]:
                    match += 1
                    match_stat = 1
                elif len(venue_offered[domain]) > 0 and len(
                        set(venue_offered[domain]) & set(goal_venues)) > 0:
                    match += 1
                    match_stat = 1
            else:
                if '_name]' in venue_offered[domain]:
                    match += 1
                    match_stat = 1

            stats[domain][0] = match_stat
            stats[domain][2] = 1

        if soft_acc:
            match = float(match) / len(goal.keys())
        else:
            match = 1.0 if match == len(goal.keys()) else 0.0

        for domain in domains_in_goal:
            for request in real_requestables[domain]:
                counts[request + '_total'] += 1
                if request in provided_requestables[domain]:
                    counts[request + '_offer'] += 1

        # SUCCESS
        # NOTE(review): with an output stream, success is scored even when
        # the match failed; without one it is only scored after a full
        # match. Kept exactly as before - confirm this is intended.
        if fout is not None or match == 1.0:
            success = self._count_successful_domains(
                domains_in_goal, real_requestables, provided_requestables,
                stats)

            # final eval
            if soft_acc:
                success = float(success) / len(real_requestables)
            else:
                success = 1 if success >= len(real_requestables) else 0

        if fout is not None and success == 0:
            sample = {
                dialog[0]['dial_id']: {
                    'log': log,
                    'real_requestables': real_requestables,
                    'provided_requestables': provided_requestables
                }
            }
            line = json.dumps(sample)
            fout.write(line)
            fout.write('\n')

        return success, match, stats, counts

    def _parseGoal(self, goal, true_goal, domain):
        """Parses user goal into dictionary format.

        Args:
            goal: dict being built; ``goal[domain]`` is (re)initialised.
            true_goal: raw goal annotation of the dialogue.
            domain: domain name to parse.

        Returns:
            dict: the same ``goal`` object with ``goal[domain]`` filled in.
        """
        goal[domain] = {'informable': {}, 'requestable': [], 'booking': []}
        if 'info' in true_goal[domain]:
            if domain == 'train':
                # we consider dialogues only where train had to be booked!
                if 'book' in true_goal[domain]:
                    goal[domain]['requestable'].append('reference')
                if 'reqt' in true_goal[domain]:
                    if 'id' in true_goal[domain]['reqt']:
                        goal[domain]['requestable'].append('id')
            else:
                if 'reqt' in true_goal[domain]:
                    for s in true_goal[domain]['reqt']:  # additional requests
                        if s in [
                                'phone', 'address', 'postcode', 'reference',
                                'id'
                        ]:
                            # ones that can be easily delexicalized
                            goal[domain]['requestable'].append(s)
                if 'book' in true_goal[domain]:
                    goal[domain]['requestable'].append('reference')

            for s, v in true_goal[domain]['info'].items():
                s_, v_ = clean_slot_values(self.db_dir, domain, s, v)
                if len(v_.split()) > 1:
                    # re-tokenize multi-word values with the reader's
                    # tokenizer and join the tokens with single spaces
                    v_ = ' '.join(
                        [token.text for token in self.reader.nlp(v_)]).strip()
                goal[domain]['informable'][s_] = v_

            if 'book' in true_goal[domain]:
                goal[domain]['booking'] = true_goal[domain]['book']
        return goal
class GenericEvaluator:
    """Shared metric implementations for single-domain dialogue datasets.

    The methods here read ``self.informable_slots``,
    ``self.requestable_slots``, ``self.entities_flat`` and call
    ``self.clean``; none of these is defined in this class, so the instance
    has to provide them before the metrics are used.
    """

    def __init__(self, reader):
        self.reader = reader
        self.metric_dict = {}

    def pack_dial(self, data):
        """Group a flat list of turns by their ``dial_id``.

        Returns:
            dict: ``dial_id`` -> list of turns in their original order.
        """
        dials = {}
        for turn in data:
            dials.setdefault(turn['dial_id'], []).append(turn)
        return dials

    def run_metrics(self, results):
        """Placeholder; always raises ``ValueError``."""
        raise ValueError('Please specify the evaluator first')

    def bleu_metric(self, data, type='bleu'):
        """Compute corpus BLEU of cleaned ``resp_gen`` against ``resp``.

        Args:
            data: iterable of turn dicts.
            type: unused; kept for interface compatibility.

        Returns:
            float: BLEU score, ``0.0`` when ``data`` holds no turns.
        """
        gen, truth = [], []
        for row in data:
            gen.append(self.clean(row['resp_gen']))
            truth.append(self.clean(row['resp']))
        if not gen:
            # Nothing to score; the scorer would divide by zero.
            return 0.0
        wrap_generated = [[_] for _ in gen]
        wrap_truth = [[_] for _ in truth]
        return BLEUScorer().score(zip(wrap_generated, wrap_truth))

    def _normalize_constraint(self,
                              constraint,
                              ignore_dontcare=False,
                              intersection=True):
        """
        Normalize belief span into a dict with one entry per informable slot.

        :param constraint - {'food': 'asian oriental', 'pricerange': 'cheap'}
        :param ignore_dontcare: if true, 'dontcare' values are dropped
        :param intersection: if true, only keeps the values that appear in
            the ontology; we set intersection=True as in previous works
        :returns: normalized constraint dict
            e.g. - {'food': 'asian oriental', 'pricerange': 'cheap', 'area': ''}
        """
        normalized = {}
        for s in self.informable_slots:
            normalized[s] = ''
        for s, v in constraint.items():
            if ignore_dontcare and v == 'dontcare':
                continue
            if intersection and v != 'dontcare' \
                    and v not in self.entities_flat:
                continue

            normalized[s] = v

        return normalized

    def _normalize_act(self, aspn, intersection=False):
        """Split an act span on '|' into one word set per act type.

        The i-th segment is stored under ``self.reader.act_order[i]``.
        With ``intersection`` enabled, the 'av' segment keeps only words
        containing '[value' and every other segment keeps only words found
        in ``self.requestable_slots``.
        """
        aspn_list = aspn.split('|')
        normalized = {}
        for i, v in enumerate(aspn_list):
            seq = v.strip()
            word_set = set()
            for w in seq.split():
                if intersection:
                    if self.reader.act_order[i] == 'av':
                        if '[value' in w:
                            word_set.add(w)
                    else:
                        if w in self.requestable_slots:
                            word_set.add(w)
                else:
                    word_set.add(w)
            normalized[self.reader.act_order[i]] = word_set
        return normalized

    def tracker_metric(self, data, normalize=True):
        """Turn-level belief tracking metrics.

        Turns whose user utterance contains 'thank' or 'bye' are excluded.

        Returns:
            tuple: ``(precision, recall, f1, goal_accr, slot_accr,
            db_correct)`` where ``slot_accr`` maps each informable slot to
            its accuracy.
        """
        # turn level metric
        tp, fp, fn, db_correct = 0, 0, 0, 0
        # total starts at a tiny epsilon so the divisions below are safe
        goal_accr, slot_accr, total = 0, {}, 1e-8
        for s in self.informable_slots:
            slot_accr[s] = 0

        for row in data:
            if normalize:
                gen = self._normalize_constraint(row['bspn_gen'])
                truth = self._normalize_constraint(row['bspn'])
            else:
                gen = self._normalize_constraint(
                    row['bspn_gen'], intersection=False)
                truth = self._normalize_constraint(
                    row['bspn'], intersection=False)
            valid = 'thank' not in row['user'] and 'bye' not in row['user']
            if valid:
                # substring containment, so an empty value always matches
                for slot, value in gen.items():
                    if value in truth[slot]:
                        tp += 1
                    else:
                        fp += 1
                for slot, value in truth.items():
                    if value not in gen[slot]:
                        fn += 1

            if truth and valid:
                total += 1
                for s in self.informable_slots:
                    if gen[s] == truth[s]:
                        slot_accr[s] += 1
                if gen == truth:
                    goal_accr += 1
                if row.get('db_gen') and row.get('db_match'):
                    if row['db_gen'] == row['db_match']:
                        db_correct += 1
        precision, recall = tp / (tp + fp + 1e-8), tp / (tp + fn + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        goal_accr /= total
        db_correct /= total
        for s in slot_accr:
            slot_accr[s] /= total
        return precision, recall, f1, goal_accr, slot_accr, db_correct

    def request_metric(self, data):
        """Dialog-level F1 of requested slot placeholders.

        For every dialogue, the ``[value_*]`` placeholders (other than
        ``[value_name]``) in the cleaned generated and reference responses
        are collected into two sets and compared.

        Returns:
            tuple: ``(f1, precision, recall)``.
        """
        # dialog level metric
        dials = self.pack_dial(data)
        tp, fp, fn = 0, 0, 0
        for dial in dials.values():
            truth_req, gen_req = set(), set()
            for turn in dial:
                resp_gen_token = self.clean(turn['resp_gen']).split()
                resp_token = self.clean(turn['resp']).split()
                for w in resp_gen_token:
                    if '[value_' in w and w.endswith(']') \
                            and w != '[value_name]':
                        gen_req.add(w[1:-1].split('_')[1])
                for w in resp_token:
                    if '[value_' in w and w.endswith(']') \
                            and w != '[value_name]':
                        truth_req.add(w[1:-1].split('_')[1])
            for req in gen_req:
                if req in truth_req:
                    tp += 1
                else:
                    fp += 1
            for req in truth_req:
                if req not in gen_req:
                    fn += 1
        precision, recall = tp / (tp + fp + 1e-8), tp / (tp + fn + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        return f1, precision, recall

    def act_metric(self, data):
        """Turn-level F1 of the predicted system act.

        Counts are kept in aggregate ('all_v' for the 'av' segment, 'all_s'
        for the 'as' segment) and per tracked key: every requestable slot
        ``s`` and its placeholder ``[value_s]``.

        Returns:
            dict: key -> ``[f1, precision, recall]``.
        """
        # turn level metric
        tp = {'all_s': 0, 'all_v': 0}
        fp = {'all_s': 0, 'all_v': 0}
        fn = {'all_s': 0, 'all_v': 0}
        for s in self.requestable_slots:
            tp[s], fp[s], fn[s] = 0, 0, 0
            placeholder = '[value_%s]' % s
            tp[placeholder], fp[placeholder], fn[placeholder] = 0, 0, 0

        for row in data:
            gen = self._normalize_act(row['aspn_gen'])
            truth = self._normalize_act(row['aspn'])
            valid = 'thank' not in row['user'] and 'bye' not in row['user']
            if valid:
                # how well the act decoder captures user's requests
                for value in gen['av']:
                    if value in truth['av']:
                        tp['all_v'] += 1
                        # Membership test, not truthiness: counters start
                        # at 0, so `tp.get(value)` would never let them
                        # grow.
                        if value in tp:
                            tp[value] += 1
                    else:
                        fp['all_v'] += 1
                        if value in fp:
                            fp[value] += 1
                for value in truth['av']:
                    if value not in gen['av']:
                        fn['all_v'] += 1
                        if value in fn:
                            fn[value] += 1

                # how accurately the act decoder predicts system's question
                if 'as' not in gen:
                    continue

                for slot in gen['as']:
                    if slot in truth['as']:
                        tp['all_s'] += 1
                        if slot in tp:
                            tp[slot] += 1
                    else:
                        fp['all_s'] += 1
                        if slot in fp:
                            fp[slot] += 1
                for slot in truth['as']:
                    if slot not in gen['as']:
                        fn['all_s'] += 1
                        if slot in fn:
                            fn[slot] += 1

        result = {}
        for k in tp:
            precision = tp[k] / (tp[k] + fp[k] + 1e-8)
            recall = tp[k] / (tp[k] + fn[k] + 1e-8)
            f1 = 2 * precision * recall / (precision + recall + 1e-8)
            result[k] = [f1, precision, recall]
        return result
- """
- For the data preparation and evaluation on In-Car Assistant/CamRest,
- we refer to the code of LABES (https://github.com/thu-spmi/LABES)
- """
class CamRestEvaluator(GenericEvaluator):
    """Evaluator for CamRest-style data: BLEU, match, request and tracker
    metrics computed from the reader's ontology."""

    def __init__(self, reader):
        super().__init__(reader)
        self.entities_flat, self.entitiy_to_slot_dict = self.get_entities(
            self.reader.ontology_path)
        self.informable_slots = self.reader.otlg.informable_slots
        self.requestable_slots = self.reader.otlg.requestable_slots

    def run_metrics(self, results):
        """Run all metrics on ``results`` and return them in a dict.

        Keys: 'bleu', 'match', 'req_f1', 'joint_goal', 'slot_accu',
        'slot-p/r/f1' and 'db_acc'.
        """
        metrics = {}
        bleu = self.bleu_metric(results)
        p, r, f1, goal_acc, slot_acc, db_acc = self.tracker_metric(results)
        match = self.match_metric(results)
        req_f1, req_p, req_r = self.request_metric(results)

        metrics['bleu'] = bleu
        metrics['match'] = match
        metrics['req_f1'] = req_f1
        metrics['joint_goal'] = goal_acc
        metrics['slot_accu'] = slot_acc
        metrics['slot-p/r/f1'] = (p, r, f1)
        metrics['db_acc'] = db_acc

        return metrics

    def get_entities(self, entity_path):
        """Load the ontology file and index its informable values.

        The whole file is lower-cased before it is parsed as JSON.

        Args:
            entity_path: path of a JSON file with an 'informable' mapping
                from slot name to a list of values.

        Returns:
            tuple: ``(entities_flat, entitiy_to_slot_dict)`` - the list of
            all values and a value -> slot lookup.
        """
        # `with` closes the handle; it was previously left open.
        with open(entity_path, encoding='utf-8') as f:
            raw_entities = json.loads(f.read().lower())

        entities_flat = []
        entitiy_to_slot_dict = {}
        for slot, values in raw_entities['informable'].items():
            entities_flat.extend(values)
            for v in values:
                entitiy_to_slot_dict[v] = slot
        return entities_flat, entitiy_to_slot_dict

    def constraint_same(self, truth_cons, gen_cons):
        """Return True when both constraint collections are empty or when
        they are similar under ``setsim``."""
        if not truth_cons and not gen_cons:
            return True
        if not truth_cons or not gen_cons:
            return False
        return setsim(gen_cons, truth_cons)

    def match_metric(self, data):
        """Dialog-level match rate of the belief state at the last offer.

        For each dialogue the constraints of the last turn whose response
        contains a '[value' placeholder are taken (separately for generated
        and reference responses) and compared for exact equality.
        Dialogues whose reference constraints are all empty are skipped.

        Returns:
            float: matched dialogues divided by counted dialogues.
        """
        dials = self.pack_dial(data)
        # total starts at a tiny epsilon so the final division is safe
        match, total = 0, 1e-8
        for dial in dials.values():
            truth_cons, gen_cons = None, None
            for turn in dial:
                # find the last turn which the system provide an entity
                if '[value' in turn['resp_gen']:
                    gen_cons = self._normalize_constraint(
                        turn['bspn_gen'], ignore_dontcare=True)
                if '[value' in turn['resp']:
                    truth_cons = self._normalize_constraint(
                        turn['bspn'], ignore_dontcare=True)
            if not gen_cons:
                # if no entity is provided, choose the state of the last
                # dialog turn
                gen_cons = self._normalize_constraint(
                    dial[-1]['bspn_gen'], ignore_dontcare=True)
            # Only dialogues with at least one reference constraint count;
            # this no longer assumes a fixed number of informable slots.
            if truth_cons is not None and any(truth_cons.values()):
                if gen_cons == truth_cons:
                    match += 1
                total += 1

        return match / total

    def clean(self, resp):
        """Wrap ``resp`` in response start/end tokens and delexicalize it.

        Existing start/end tokens are stripped first, then every known
        entity value is replaced by its ``[value_<slot>]`` placeholder.
        """
        # we use the same clean process as in Sequicity, SEDST, FSDM
        # to ensure comparable results
        resp = resp.replace(f'{self.reader.sos_r_token} ', '')
        resp = resp.replace(f' {self.reader.eos_r_token}', '')
        resp = f'{self.reader.sos_r_token} {resp} {self.reader.eos_r_token}'
        for value, slot in self.entitiy_to_slot_dict.items():
            resp = utils.clean_replace(resp, value, '[value_%s]' % slot)
        return resp
class KvretEvaluator(GenericEvaluator):
    """Evaluator for Kvret (In-Car Assistant) style data: BLEU, match,
    request and tracker metrics based on the reader's entity dictionary."""

    def __init__(self, reader):
        super().__init__(reader)
        self.entities_flat, self.entitiy_to_slot_dict = self.get_entities(
            self.reader.ontology_path)
        self.informable_slots = self.reader.otlg.informable_slots
        self.requestable_slots = self.reader.otlg.requestable_slots

    def run_metrics(self, results):
        """Run all metrics on ``results`` and return them in a dict.

        Keys: 'bleu', 'match', 'req_f1', 'joint_goal', 'slot_accu',
        'slot-p/r/f1' and 'db_acc'.
        """
        metrics = {}
        bleu = self.bleu_metric(results)
        p, r, f1, goal_acc, slot_acc, db_acc = self.tracker_metric(
            results, normalize=True)
        match = self.match_metric(results)
        req_f1, req_p, req_r = self.request_metric(results)

        metrics['bleu'] = bleu
        metrics['match'] = match
        metrics['req_f1'] = req_f1
        metrics['joint_goal'] = goal_acc
        metrics['slot_accu'] = slot_acc
        metrics['slot-p/r/f1'] = (p, r, f1)
        metrics['db_acc'] = db_acc

        return metrics

    def _normalize_constraint(self,
                              constraint,
                              ignore_dontcare=False,
                              intersection=True):
        """
        Normalize belief span into a dict with one entry per informable slot.

        A fixed list of filler words is removed from every value before it
        is checked against the known entities.

        :param constraint - {'food': 'asian oriental', 'pricerange': 'cheap'}
        :param ignore_dontcare: accepted for interface compatibility; not
            used by this implementation
        :param intersection: if true, only keeps the values that appear in
            the ontology; we set intersection=True as in previous works
        :returns: normalized constraint dict
            e.g. - {'food': 'asian oriental', 'pricerange': 'cheap', 'area': ''}
        """
        junk = [
            'good', 'great', 'quickest', 'shortest', 'route', 'week',
            'fastest', 'nearest', 'next', 'closest', 'way', 'mile', 'activity',
            'restaurant', 'appointment'
        ]
        normalized = {}
        for s in self.informable_slots:
            normalized[s] = ''
        for s, v in constraint.items():
            for j in junk:
                # drop the filler word and collapse the whitespace left over
                v = ' '.join(v.replace(j, '').split())
            if intersection and v not in self.entities_flat:
                continue

            if s in self.informable_slots:
                normalized[s] = v
            else:
                # TODO only use slot (not domain) in s for matching !!!
                pass

        return normalized

    def get_entities(self, entity_path):
        """Return the entity list and the entity -> slot lookup.

        Both come from ``self.reader.entity_dict``; ``entity_path`` is
        accepted for interface compatibility but not read.
        """
        entitiy_to_slot_dict = self.reader.entity_dict
        # dict keys are already unique, so a plain copy keeps order and
        # avoids the quadratic "not in list" de-duplication
        entities_flat = list(entitiy_to_slot_dict)
        return entities_flat, entitiy_to_slot_dict

    def constraint_same(self, truth_cons, gen_cons):
        """Return True when both constraint collections are empty or when
        they are similar under ``setsim``."""
        if not truth_cons and not gen_cons:
            return True
        if not truth_cons or not gen_cons:
            return False
        return setsim(gen_cons, truth_cons)

    def match_metric(self, data):
        """Dialog-level match rate of the belief state at the last offer.

        For each dialogue the constraints of the last turn whose response
        contains a '[value' placeholder are taken (separately for generated
        and reference responses). Their non-empty values are compared with
        ``constraint_same``. Dialogues whose reference constraints are all
        empty are skipped.

        Returns:
            float: matched dialogues divided by counted dialogues.
        """
        dials = self.pack_dial(data)
        # total starts at a tiny epsilon so the final division is safe
        match, total = 0, 1e-8
        for dial in dials.values():
            truth_cons, gen_cons = None, None
            for turn in dial:
                # find the last turn which the system provide an entity
                if '[value' in turn['resp_gen']:
                    gen_cons = self._normalize_constraint(
                        turn['bspn_gen'], ignore_dontcare=True)
                if '[value' in turn['resp']:
                    truth_cons = self._normalize_constraint(
                        turn['bspn'], ignore_dontcare=True)

            if not gen_cons:
                # if no entity is provided, choose the state of the last
                # dialog turn
                gen_cons = self._normalize_constraint(
                    dial[-1]['bspn_gen'], ignore_dontcare=True)

            # Only dialogues with at least one reference constraint count;
            # this no longer assumes a fixed number of informable slots.
            if truth_cons is not None and any(truth_cons.values()):
                gen_values = [x for x in gen_cons.values() if x]
                truth_values = [x for x in truth_cons.values() if x]
                if self.constraint_same(gen_values, truth_values):
                    match += 1
                total += 1

        return match / total

    def clean(self, resp):
        """Wrap ``resp`` in response start/end tokens and delexicalize it.

        Existing start/end tokens are stripped first, then every known
        entity value is replaced by its ``[value_<slot>]`` placeholder.
        """
        # we use the same clean process as in Sequicity, SEDST, FSDM
        # to ensure comparable results
        resp = resp.replace(f'{self.reader.sos_r_token} ', '')
        resp = resp.replace(f' {self.reader.eos_r_token}', '')
        resp = f'{self.reader.sos_r_token} {resp} {self.reader.eos_r_token}'
        for value, slot in self.entitiy_to_slot_dict.items():
            resp = utils.clean_replace(resp, value, '[value_%s]' % slot)
        return resp
|