eval.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953
  1. # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
  2. # Copyright from https://github.com/thu-spmi/LABES
  3. # Copyright from https://github.com/TonyNemo/UBAR-MultiWOZ
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import math
  17. from collections import Counter
  18. import json
  19. import numpy as np
  20. from nltk.util import ngrams
  21. from sklearn.metrics import f1_score
  22. from modelscope.utils.nlp.space import ontology, utils
  23. from modelscope.utils.nlp.space.clean_dataset import clean_slot_values
  24. def similar(a, b):
  25. return a == b or a in b or b in a or a.split()[0] == b.split(
  26. )[0] or a.split()[-1] == b.split()[-1]
  27. def setsub(a, b):
  28. junks_a = []
  29. useless_constraint = [
  30. 'temperature', 'week', 'est ', 'quick', 'reminder', 'near'
  31. ]
  32. for i in a:
  33. flg = False
  34. for j in b:
  35. if similar(i, j):
  36. flg = True
  37. if not flg:
  38. junks_a.append(i)
  39. for junk in junks_a:
  40. flg = False
  41. for item in useless_constraint:
  42. if item in junk:
  43. flg = True
  44. if not flg:
  45. return False
  46. return True
  47. def setsim(a, b):
  48. a, b = set(a), set(b)
  49. return setsub(a, b) and setsub(b, a)
  50. def DA_evaluate(preds, labels):
  51. preds = np.array(preds)
  52. labels = np.array(labels)
  53. results = {}
  54. for avg_name in ['micro']:
  55. my_f1_score = f1_score(y_true=labels, y_pred=preds, average=avg_name)
  56. results['f1_{}'.format(avg_name)] = my_f1_score
  57. return results
  58. class BLEUScorer(object):
  59. # BLEU score calculator via GentScorer interface
  60. # it calculates the BLEU-4 by taking the entire corpus in
  61. # Calculate based multiple candidates against multiple references
  62. def __init__(self):
  63. pass
  64. def score(self, parallel_corpus):
  65. # containers
  66. count = [0, 0, 0, 0]
  67. clip_count = [0, 0, 0, 0]
  68. r = 0
  69. c = 0
  70. weights = [0.25, 0.25, 0.25, 0.25]
  71. # accumulate ngram statistics
  72. for hyps, refs in parallel_corpus:
  73. hyps = [hyp.split() for hyp in hyps]
  74. refs = [ref.split() for ref in refs]
  75. for hyp in hyps:
  76. for i in range(4):
  77. # accumulate ngram counts
  78. hypcnts = Counter(ngrams(hyp, i + 1))
  79. cnt = sum(hypcnts.values())
  80. count[i] += cnt
  81. # compute clipped counts
  82. max_counts = {}
  83. for ref in refs:
  84. refcnts = Counter(ngrams(ref, i + 1))
  85. for ng in hypcnts:
  86. max_counts[ng] = max(
  87. max_counts.get(ng, 0), refcnts[ng])
  88. clipcnt = \
  89. dict((ng, min(count, max_counts[ng])) for ng, count in hypcnts.items())
  90. clip_count[i] += sum(clipcnt.values())
  91. # accumulate r & c
  92. bestmatch = [1000, 1000]
  93. for ref in refs:
  94. if bestmatch[0] == 0:
  95. break
  96. diff = abs(len(ref) - len(hyp))
  97. if diff < bestmatch[0]:
  98. bestmatch[0] = diff
  99. bestmatch[1] = len(ref)
  100. r += bestmatch[1]
  101. c += len(hyp)
  102. # computing bleu score
  103. p0 = 1e-7
  104. bp = \
  105. 1 if c > r else math.exp(1 - float(r) / float(c))
  106. p_ns = \
  107. [float(clip_count[i]) / float(count[i] + p0) + p0 for i in range(4)]
  108. s = \
  109. math.fsum(w * math.log(p_n) for w, p_n in zip(weights, p_ns) if p_n)
  110. bleu = bp * math.exp(s)
  111. return bleu * 100
"""
For the data preparation and evaluation on MultiWOZ2.0/2.1,
we refer to the code of UBAR (https://github.com/TonyNemo/UBAR-MultiWOZ)
"""
  116. class MultiWOZEvaluator(object):
  117. def __init__(self, reader, **kwargs):
  118. self.reader = reader
  119. self.domains = ontology.all_domains
  120. self.all_data = self.reader.data
  121. self.test_data = self.reader.test
  122. self.bleu_scorer = BLEUScorer()
  123. self.all_info_slot = []
  124. for d, s_list in ontology.informable_slots.items():
  125. for s in s_list:
  126. self.all_info_slot.append(d + '-' + s)
  127. # only evaluate these slots for dialog success
  128. self.requestables = ['phone', 'address', 'postcode', 'reference', 'id']
  129. self.db_dir = kwargs['data_dir']
  130. def pack_dial(self, data):
  131. dials = {}
  132. for turn in data:
  133. dial_id = turn['dial_id']
  134. if dial_id not in dials:
  135. dials[dial_id] = []
  136. dials[dial_id].append(turn)
  137. return dials
  138. def validation_metric(self, data, fout=None):
  139. bleu = self.bleu_metric(data)
  140. # accu_single_dom, accu_multi_dom, multi_dom_num = self.domain_eval(data)
  141. success, match, req_offer_counts, dial_num = \
  142. self.context_to_response_eval(data, same_eval_as_cambridge=True, fout=fout)
  143. return bleu, success, match
  144. def bleu_metric(self, data, eval_dial_list=None):
  145. gen, truth = [], []
  146. for row in data:
  147. if eval_dial_list and row[
  148. 'dial_id'] + '.json' not in eval_dial_list:
  149. continue
  150. gen.append(row['resp_gen'])
  151. truth.append(row['resp'])
  152. wrap_generated = [[_] for _ in gen]
  153. wrap_truth = [[_] for _ in truth]
  154. if gen and truth:
  155. try:
  156. sc = self.bleu_scorer.score(zip(wrap_generated, wrap_truth))
  157. except Exception:
  158. sc = 0.0
  159. else:
  160. sc = 0.0
  161. return sc
  162. def context_to_response_eval(self,
  163. data,
  164. eval_dial_list=None,
  165. same_eval_as_cambridge=False,
  166. fout=None):
  167. dials = self.pack_dial(data)
  168. counts = {}
  169. for req in self.requestables:
  170. counts[req + '_total'] = 0
  171. counts[req + '_offer'] = 0
  172. dial_num, successes, matches = 0, 0, 0
  173. for dial_id in dials:
  174. if eval_dial_list and dial_id + '.json' not in eval_dial_list:
  175. continue
  176. dial = dials[dial_id]
  177. reqs = {}
  178. goal = {}
  179. if '.json' not in dial_id and '.json' in list(
  180. self.all_data.keys())[0]:
  181. dial_id = dial_id + '.json'
  182. for domain in ontology.all_domains:
  183. if self.all_data[dial_id]['goal'].get(domain):
  184. true_goal = self.all_data[dial_id]['goal']
  185. goal = self._parseGoal(goal, true_goal, domain)
  186. for domain in goal.keys():
  187. reqs[domain] = goal[domain]['requestable']
  188. success, match, stats, counts = \
  189. self._evaluateGeneratedDialogue(dial, goal, reqs, counts,
  190. same_eval_as_cambridge=same_eval_as_cambridge, fout=fout)
  191. successes += success
  192. matches += match
  193. dial_num += 1
  194. succ_rate = successes / (float(dial_num) + 1e-10) * 100
  195. match_rate = matches / (float(dial_num) + 1e-10) * 100
  196. return succ_rate, match_rate, counts, dial_num
  197. def _evaluateGeneratedDialogue(self,
  198. dialog,
  199. goal,
  200. real_requestables,
  201. counts,
  202. soft_acc=False,
  203. same_eval_as_cambridge=False,
  204. fout=None):
  205. """Evaluates the dialogue created by the model.
  206. First we load the user goal of the dialogue, then for each turn
  207. generated by the system we look for key-words.
  208. For the Inform rate we look whether the entity was proposed.
  209. For the Success rate we look for requestables slots"""
  210. # for computing corpus success
  211. requestables = self.requestables
  212. # CHECK IF MATCH HAPPENED
  213. provided_requestables = {}
  214. venue_offered = {}
  215. domains_in_goal = []
  216. log = []
  217. bspans = {}
  218. for domain in goal.keys():
  219. venue_offered[domain] = []
  220. provided_requestables[domain] = []
  221. domains_in_goal.append(domain)
  222. for t, turn in enumerate(dialog):
  223. if t == 0:
  224. continue
  225. if fout is not None:
  226. log.append({
  227. 'turn_num': turn['turn_num'],
  228. 'turn_domain': turn['dspn'],
  229. 'user': turn['user'],
  230. 'aspn': turn['aspn'],
  231. 'aspn_gen': turn['aspn_gen'],
  232. 'resp': turn['resp'],
  233. 'resp_gen': turn['resp_gen'],
  234. 'pointer': turn['pointer'],
  235. })
  236. sent_t = turn['resp_gen']
  237. for domain in goal.keys():
  238. # for computing success
  239. if same_eval_as_cambridge:
  240. # [restaurant_name], [hotel_name] instead of [value_name]
  241. if self.reader.use_true_domain_for_ctr_eval:
  242. dom_pred = [d[1:-1] for d in turn['dspn'].split()]
  243. else:
  244. dom_pred = [d[1:-1] for d in turn['dspn_gen'].split()]
  245. if domain not in dom_pred: # fail
  246. continue
  247. if '[value_name]' in sent_t or '[value_id]' in sent_t:
  248. if domain in [
  249. 'restaurant', 'hotel', 'attraction', 'train'
  250. ]:
  251. # HERE YOU CAN PUT YOUR BELIEF STATE ESTIMATION
  252. if not self.reader.use_true_curr_bspn and not self.reader.use_true_bspn_for_ctr_eval:
  253. bspn = turn['bspn_gen']
  254. else:
  255. bspn = turn['bspn']
  256. constraint_dict = self.reader.bspan_to_constraint_dict(
  257. bspn)
  258. if constraint_dict.get(domain):
  259. venues = self.reader.db.queryJsons(
  260. domain,
  261. constraint_dict[domain],
  262. return_name=True)
  263. else:
  264. venues = []
  265. if len(venue_offered[domain]) == 0 and venues:
  266. venue_offered[domain] = venues
  267. bspans[domain] = constraint_dict[domain]
  268. else:
  269. flag = False
  270. for ven in venues:
  271. if ven not in venue_offered[domain]:
  272. flag = True
  273. break
  274. if flag and venues: # sometimes there are no results so sample won't work
  275. venue_offered[domain] = venues
  276. bspans[domain] = constraint_dict[domain]
  277. else: # not limited so we can provide one
  278. venue_offered[domain] = '[value_name]'
  279. # ATTENTION: assumption here - we didn't provide phone or address twice! etc
  280. for requestable in requestables:
  281. if requestable == 'reference':
  282. if '[value_reference]' in sent_t:
  283. if domain in ['restaurant', 'hotel', 'train']:
  284. if 'booked' in turn['pointer'] or 'ok' in turn[
  285. 'pointer'] or '[value_reference]' in turn[
  286. 'resp']:
  287. # if pointer was allowing for that?
  288. provided_requestables[domain].append(
  289. 'reference')
  290. else:
  291. provided_requestables[domain].append(
  292. 'reference')
  293. else:
  294. if '[value_' + requestable + ']' in sent_t:
  295. provided_requestables[domain].append(requestable)
  296. # if name was given in the task
  297. for domain in goal.keys():
  298. # if name was provided for the user, the match is being done automatically
  299. if 'name' in goal[domain]['informable']:
  300. venue_offered[domain] = '[value_name]'
  301. # special domains - entity does not need to be provided
  302. if domain in ['taxi', 'police', 'hospital']:
  303. venue_offered[domain] = '[value_name]'
  304. if domain == 'train':
  305. if not venue_offered[domain] and 'id' not in goal[domain][
  306. 'requestable']:
  307. venue_offered[domain] = '[value_name]'
  308. """
  309. Given all inform and requestable slots
  310. we go through each domain from the user goal
  311. and check whether right entity was provided and
  312. all requestable slots were given to the user.
  313. The dialogue is successful if that's the case for all domains.
  314. """
  315. # HARD EVAL
  316. stats = {
  317. 'restaurant': [0, 0, 0],
  318. 'hotel': [0, 0, 0],
  319. 'attraction': [0, 0, 0],
  320. 'train': [0, 0, 0],
  321. 'taxi': [0, 0, 0],
  322. 'hospital': [0, 0, 0],
  323. 'police': [0, 0, 0]
  324. }
  325. match = 0
  326. success = 0
  327. # MATCH
  328. for domain in goal.keys():
  329. match_stat = 0
  330. if domain in ['restaurant', 'hotel', 'attraction', 'train']:
  331. goal_venues = self.reader.db.queryJsons(
  332. domain, goal[domain]['informable'], return_name=True)
  333. if type(venue_offered[domain]
  334. ) is str and '_name' in venue_offered[domain]:
  335. match += 1
  336. match_stat = 1
  337. elif len(venue_offered[domain]) > 0 and len(
  338. set(venue_offered[domain]) & set(goal_venues)) > 0:
  339. match += 1
  340. match_stat = 1
  341. else:
  342. if '_name]' in venue_offered[domain]:
  343. match += 1
  344. match_stat = 1
  345. stats[domain][0] = match_stat
  346. stats[domain][2] = 1
  347. if soft_acc:
  348. match = float(match) / len(goal.keys())
  349. else:
  350. if match == len(goal.keys()):
  351. match = 1.0
  352. else:
  353. match = 0.0
  354. for domain in domains_in_goal:
  355. for request in real_requestables[domain]:
  356. counts[request + '_total'] += 1
  357. if request in provided_requestables[domain]:
  358. counts[request + '_offer'] += 1
  359. # SUCCESS
  360. if fout is not None:
  361. for domain in domains_in_goal:
  362. success_stat = 0
  363. domain_success = 0
  364. if len(real_requestables[domain]) == 0:
  365. success += 1
  366. success_stat = 1
  367. stats[domain][1] = success_stat
  368. continue
  369. # if values in sentences are super set of requestables
  370. for request in real_requestables[domain]:
  371. if request in provided_requestables[domain]:
  372. domain_success += 1
  373. if domain_success == len(real_requestables[domain]):
  374. success += 1
  375. success_stat = 1
  376. stats[domain][1] = success_stat
  377. # final eval
  378. if soft_acc:
  379. success = float(success) / len(real_requestables)
  380. else:
  381. if success >= len(real_requestables):
  382. success = 1
  383. else:
  384. success = 0
  385. else:
  386. if match == 1.0:
  387. for domain in domains_in_goal:
  388. success_stat = 0
  389. domain_success = 0
  390. if len(real_requestables[domain]) == 0:
  391. success += 1
  392. success_stat = 1
  393. stats[domain][1] = success_stat
  394. continue
  395. # if values in sentences are super set of requestables
  396. for request in real_requestables[domain]:
  397. if request in provided_requestables[domain]:
  398. domain_success += 1
  399. if domain_success == len(real_requestables[domain]):
  400. success += 1
  401. success_stat = 1
  402. stats[domain][1] = success_stat
  403. # final eval
  404. if soft_acc:
  405. success = float(success) / len(real_requestables)
  406. else:
  407. if success >= len(real_requestables):
  408. success = 1
  409. else:
  410. success = 0
  411. if fout is not None and success == 0:
  412. sample = {
  413. dialog[0]['dial_id']: {
  414. 'log': log,
  415. 'real_requestables': real_requestables,
  416. 'provided_requestables': provided_requestables
  417. }
  418. }
  419. line = json.dumps(sample)
  420. fout.write(line)
  421. fout.write('\n')
  422. return success, match, stats, counts
  423. def _parseGoal(self, goal, true_goal, domain):
  424. """Parses user goal into dictionary format."""
  425. goal[domain] = {}
  426. goal[domain] = {'informable': {}, 'requestable': [], 'booking': []}
  427. if 'info' in true_goal[domain]:
  428. if domain == 'train':
  429. # we consider dialogues only where train had to be booked!
  430. if 'book' in true_goal[domain]:
  431. goal[domain]['requestable'].append('reference')
  432. if 'reqt' in true_goal[domain]:
  433. if 'id' in true_goal[domain]['reqt']:
  434. goal[domain]['requestable'].append('id')
  435. else:
  436. if 'reqt' in true_goal[domain]:
  437. for s in true_goal[domain]['reqt']: # additional requests:
  438. if s in [
  439. 'phone', 'address', 'postcode', 'reference',
  440. 'id'
  441. ]:
  442. # ones that can be easily delexicalized
  443. goal[domain]['requestable'].append(s)
  444. if 'book' in true_goal[domain]:
  445. goal[domain]['requestable'].append('reference')
  446. for s, v in true_goal[domain]['info'].items():
  447. s_, v_ = clean_slot_values(self.db_dir, domain, s, v)
  448. if len(v_.split()) > 1:
  449. v_ = ' '.join(
  450. [token.text for token in self.reader.nlp(v_)]).strip()
  451. goal[domain]['informable'][s_] = v_
  452. if 'book' in true_goal[domain]:
  453. goal[domain]['booking'] = true_goal[domain]['book']
  454. return goal
  455. class GenericEvaluator:
  456. def __init__(self, reader):
  457. self.reader = reader
  458. self.metric_dict = {}
  459. def pack_dial(self, data):
  460. dials = {}
  461. for turn in data:
  462. dial_id = turn['dial_id']
  463. if dial_id not in dials:
  464. dials[dial_id] = []
  465. dials[dial_id].append(turn)
  466. return dials
  467. def run_metrics(self, results):
  468. raise ValueError('Please specify the evaluator first')
  469. def bleu_metric(self, data, type='bleu'):
  470. gen, truth = [], []
  471. for row in data:
  472. gen.append(self.clean(row['resp_gen']))
  473. # gen.append(self.clean(row['resp']))
  474. truth.append(self.clean(row['resp']))
  475. wrap_generated = [[_] for _ in gen]
  476. wrap_truth = [[_] for _ in truth]
  477. sc = BLEUScorer().score(zip(wrap_generated, wrap_truth))
  478. return sc
  479. def _normalize_constraint(self,
  480. constraint,
  481. ignore_dontcare=False,
  482. intersection=True):
  483. """
  484. Normalize belief span, e.g. delete repeated words
  485. :param constraint - {'food': 'asian oritental', 'pricerange': 'cheap'}
  486. :param intersection: if true, only keeps the words that appear in th ontology
  487. we set intersection=True as in previous works
  488. :returns: normalized constraint dict
  489. e.g. - {'food': 'asian oritental', 'pricerange': 'cheap', 'area': ''}
  490. """
  491. normalized = {}
  492. for s in self.informable_slots:
  493. normalized[s] = ''
  494. for s, v in constraint.items():
  495. if ignore_dontcare and v == 'dontcare':
  496. continue
  497. if intersection and v != 'dontcare' and v not in self.entities_flat:
  498. continue
  499. normalized[s] = v
  500. return normalized
  501. def _normalize_act(self, aspn, intersection=False):
  502. aspn_list = aspn.split('|')
  503. normalized = {}
  504. for i, v in enumerate(aspn_list):
  505. seq = v.strip()
  506. word_set = set()
  507. for w in seq.split():
  508. if intersection:
  509. if self.reader.act_order[i] == 'av':
  510. if '[value' in w:
  511. word_set.add(w)
  512. else:
  513. if w in self.requestable_slots:
  514. word_set.add(w)
  515. else:
  516. word_set.add(w)
  517. normalized[self.reader.act_order[i]] = word_set
  518. return normalized
  519. def tracker_metric(self, data, normalize=True):
  520. # turn level metric
  521. tp, fp, fn, db_correct = 0, 0, 0, 0
  522. goal_accr, slot_accr, total = 0, {}, 1e-8
  523. for s in self.informable_slots:
  524. slot_accr[s] = 0
  525. for row in data:
  526. if normalize:
  527. gen = self._normalize_constraint(row['bspn_gen'])
  528. truth = self._normalize_constraint(row['bspn'])
  529. else:
  530. gen = self._normalize_constraint(
  531. row['bspn_gen'], intersection=False)
  532. truth = self._normalize_constraint(
  533. row['bspn'], intersection=False)
  534. valid = 'thank' not in row['user'] and 'bye' not in row['user']
  535. if valid:
  536. for slot, value in gen.items():
  537. if value in truth[slot]:
  538. tp += 1
  539. else:
  540. fp += 1
  541. for slot, value in truth.items():
  542. if value not in gen[slot]:
  543. fn += 1
  544. if truth and valid:
  545. total += 1
  546. for s in self.informable_slots:
  547. if gen[s] == truth[s]:
  548. slot_accr[s] += 1
  549. if gen == truth:
  550. goal_accr += 1
  551. if row.get('db_gen') and row.get('db_match'):
  552. if row['db_gen'] == row['db_match']:
  553. db_correct += 1
  554. precision, recall = tp / (tp + fp + 1e-8), tp / (tp + fn + 1e-8)
  555. f1 = 2 * precision * recall / (precision + recall + 1e-8)
  556. goal_accr /= total
  557. db_correct /= total
  558. for s in slot_accr:
  559. slot_accr[s] /= total
  560. return precision, recall, f1, goal_accr, slot_accr, db_correct
  561. def request_metric(self, data):
  562. # dialog level metric
  563. dials = self.pack_dial(data)
  564. tp, fp, fn = 0, 0, 0
  565. for dial_id in dials:
  566. truth_req, gen_req = set(), set()
  567. dial = dials[dial_id]
  568. for turn_num, turn in enumerate(dial):
  569. resp_gen_token = self.clean(turn['resp_gen']).split()
  570. resp_token = self.clean(turn['resp']).split()
  571. for w in resp_gen_token:
  572. if '[value_' in w and w.endswith(
  573. ']') and w != '[value_name]':
  574. gen_req.add(w[1:-1].split('_')[1])
  575. for w in resp_token:
  576. if '[value_' in w and w.endswith(
  577. ']') and w != '[value_name]':
  578. truth_req.add(w[1:-1].split('_')[1])
  579. for req in gen_req:
  580. if req in truth_req:
  581. tp += 1
  582. else:
  583. fp += 1
  584. for req in truth_req:
  585. if req not in gen_req:
  586. fn += 1
  587. precision, recall = tp / (tp + fp + 1e-8), tp / (tp + fn + 1e-8)
  588. f1 = 2 * precision * recall / (precision + recall + 1e-8)
  589. return f1, precision, recall
  590. def act_metric(self, data):
  591. # turn level metric
  592. tp, fp, fn = {
  593. 'all_s': 0,
  594. 'all_v': 0
  595. }, {
  596. 'all_s': 0,
  597. 'all_v': 0
  598. }, {
  599. 'all_s': 0,
  600. 'all_v': 0
  601. }
  602. for s in self.requestable_slots:
  603. tp[s], fp[s], fn[s] = 0, 0, 0
  604. tp['[value_%s]' % s], fp['[value_%s]' % s], fn['[value_%s]'
  605. % s] = 0, 0, 0
  606. for row in data:
  607. gen = self._normalize_act(row['aspn_gen'])
  608. truth = self._normalize_act(row['aspn'])
  609. valid = 'thank' not in row['user'] and 'bye' not in row['user']
  610. if valid:
  611. # how well the act decoder captures user's requests
  612. for value in gen['av']:
  613. if value in truth['av']:
  614. tp['all_v'] += 1
  615. if tp.get(value):
  616. tp[value] += 1
  617. else:
  618. fp['all_v'] += 1
  619. if fp.get(value):
  620. fp[value] += 1
  621. for value in truth['av']:
  622. if value not in gen['av']:
  623. fn['all_v'] += 1
  624. if fn.get(value):
  625. fn[value] += 1
  626. # how accurately the act decoder predicts system's question
  627. if 'as' not in gen:
  628. continue
  629. for slot in gen['as']:
  630. if slot in truth['as']:
  631. tp['all_s'] += 1
  632. if tp.get(slot):
  633. tp[slot] += 1
  634. else:
  635. fp['all_s'] += 1
  636. if fp.get(slot):
  637. fp[slot] += 1
  638. for slot in truth['as']:
  639. if slot not in gen['as']:
  640. fn['all_s'] += 1
  641. if fn.get(slot):
  642. fn[slot] += 1
  643. result = {}
  644. for k, v in tp.items():
  645. precision, recall = tp[k] / (tp[k] + fp[k] + 1e-8), tp[k] / (
  646. tp[k] + fn[k] + 1e-8)
  647. f1 = 2 * precision * recall / (precision + recall + 1e-8)
  648. result[k] = [f1, precision, recall]
  649. return result
  650. """
  651. For the data preparation and evaluation on In-Car Assistant/CamRest,
  652. we refer to the code of LABES (https://github.com/thu-spmi/LABES)
  653. """
  654. class CamRestEvaluator(GenericEvaluator):
  655. def __init__(self, reader):
  656. super().__init__(reader)
  657. self.entities_flat, self.entitiy_to_slot_dict = self.get_entities(
  658. self.reader.ontology_path)
  659. self.informable_slots = self.reader.otlg.informable_slots
  660. self.requestable_slots = self.reader.otlg.requestable_slots
  661. def run_metrics(self, results):
  662. metrics = {}
  663. bleu = self.bleu_metric(results)
  664. p, r, f1, goal_acc, slot_acc, db_acc = self.tracker_metric(results)
  665. match = self.match_metric(results)
  666. req_f1, req_p, req_r = self.request_metric(results)
  667. metrics['bleu'] = bleu
  668. metrics['match'] = match
  669. metrics['req_f1'] = req_f1
  670. metrics['joint_goal'] = goal_acc
  671. metrics['slot_accu'] = slot_acc
  672. metrics['slot-p/r/f1'] = (p, r, f1)
  673. metrics['db_acc'] = db_acc
  674. return metrics
  675. def get_entities(self, entity_path):
  676. entities_flat = []
  677. entitiy_to_slot_dict = {}
  678. raw_entities = json.loads(
  679. open(entity_path, encoding='utf-8').read().lower())
  680. for s in raw_entities['informable']:
  681. entities_flat.extend(raw_entities['informable'][s])
  682. for v in raw_entities['informable'][s]:
  683. entitiy_to_slot_dict[v] = s
  684. return entities_flat, entitiy_to_slot_dict
  685. def constraint_same(self, truth_cons, gen_cons):
  686. if not truth_cons and not gen_cons:
  687. return True
  688. if not truth_cons or not gen_cons:
  689. return False
  690. return setsim(gen_cons, truth_cons)
  691. def match_metric(self, data):
  692. dials = self.pack_dial(data)
  693. match, total = 0, 1e-8
  694. for dial_id in dials:
  695. dial = dials[dial_id]
  696. truth_cons, gen_cons = {'1': '', '2': '', '3': ''}, None
  697. for turn_num, turn in enumerate(dial):
  698. # find the last turn which the system provide an entity
  699. if '[value' in turn['resp_gen']:
  700. gen_cons = self._normalize_constraint(
  701. turn['bspn_gen'], ignore_dontcare=True)
  702. if '[value' in turn['resp']:
  703. truth_cons = self._normalize_constraint(
  704. turn['bspn'], ignore_dontcare=True)
  705. if not gen_cons:
  706. # if no entity is provided, choose the state of the last dialog turn
  707. gen_cons = self._normalize_constraint(
  708. dial[-1]['bspn_gen'], ignore_dontcare=True)
  709. if list(truth_cons.values()) != ['', '', '']:
  710. if gen_cons == truth_cons:
  711. match += 1
  712. total += 1
  713. return match / total
  714. def clean(self, resp):
  715. # we use the same clean process as in Sequicity, SEDST, FSDM
  716. # to ensure comparable results
  717. resp = resp.replace(f'{self.reader.sos_r_token} ', '')
  718. resp = resp.replace(f' {self.reader.eos_r_token}', '')
  719. resp = f'{self.reader.sos_r_token} {resp} {self.reader.eos_r_token}'
  720. for value, slot in self.entitiy_to_slot_dict.items():
  721. resp = utils.clean_replace(resp, value, '[value_%s]' % slot)
  722. return resp
  723. class KvretEvaluator(GenericEvaluator):
  724. def __init__(self, reader):
  725. super().__init__(reader)
  726. self.entities_flat, self.entitiy_to_slot_dict = self.get_entities(
  727. self.reader.ontology_path)
  728. self.informable_slots = self.reader.otlg.informable_slots
  729. self.requestable_slots = self.reader.otlg.requestable_slots
  730. def run_metrics(self, results):
  731. metrics = {}
  732. bleu = self.bleu_metric(results)
  733. p, r, f1, goal_acc, slot_acc, db_acc = self.tracker_metric(
  734. results, normalize=True)
  735. match = self.match_metric(results)
  736. req_f1, req_p, req_r = self.request_metric(results)
  737. metrics['bleu'] = bleu
  738. metrics['match'] = match
  739. metrics['req_f1'] = req_f1
  740. metrics['joint_goal'] = goal_acc
  741. metrics['slot_accu'] = slot_acc
  742. metrics['slot-p/r/f1'] = (p, r, f1)
  743. metrics['db_acc'] = db_acc
  744. return metrics
  745. def _normalize_constraint(self,
  746. constraint,
  747. ignore_dontcare=False,
  748. intersection=True):
  749. """
  750. Normalize belief span, e.g. delete repeated words
  751. :param constraint - {'food': 'asian oritental', 'pricerange': 'cheap'}
  752. :param intersection: if true, only keeps the words that appear in th ontology
  753. we set intersection=True as in previous works
  754. :returns: normalized constraint dict
  755. e.g. - {'food': 'asian oritental', 'pricerange': 'cheap', 'area': ''}
  756. """
  757. junk = [
  758. 'good', 'great', 'quickest', 'shortest', 'route', 'week',
  759. 'fastest', 'nearest', 'next', 'closest', 'way', 'mile', 'activity',
  760. 'restaurant', 'appointment'
  761. ]
  762. normalized = {}
  763. for s in self.informable_slots:
  764. normalized[s] = ''
  765. for s, v in constraint.items():
  766. for j in junk:
  767. v = ' '.join(v.replace(j, '').split())
  768. if intersection and v not in self.entities_flat:
  769. continue
  770. if s in self.informable_slots:
  771. normalized[s] = v
  772. else:
  773. # TODO only use slot (not domain) in s for matching !!!
  774. pass
  775. return normalized
  776. def get_entities(self, entity_path):
  777. entities_flat = []
  778. entitiy_to_slot_dict = {}
  779. entitiy_to_slot_dict = self.reader.entity_dict
  780. for s in entitiy_to_slot_dict:
  781. if s not in entities_flat:
  782. entities_flat.append(s)
  783. return entities_flat, entitiy_to_slot_dict
  784. def constraint_same(self, truth_cons, gen_cons):
  785. if not truth_cons and not gen_cons:
  786. return True
  787. if not truth_cons or not gen_cons:
  788. return False
  789. return setsim(gen_cons, truth_cons)
  790. def match_metric(self, data):
  791. dials = self.pack_dial(data)
  792. match, total = 0, 1e-8
  793. for dial_id in dials:
  794. dial = dials[dial_id]
  795. truth_cons, gen_cons = {
  796. '1': '',
  797. '2': '',
  798. '3': '',
  799. '4': '',
  800. '5': '',
  801. '6': '',
  802. '7': '',
  803. '8': '',
  804. '9': '',
  805. '10': '',
  806. '11': ''
  807. }, None
  808. for turn_num, turn in enumerate(dial):
  809. # find the last turn which the system provide an entity
  810. if '[value' in turn['resp_gen']:
  811. gen_cons = self._normalize_constraint(
  812. turn['bspn_gen'], ignore_dontcare=True)
  813. if '[value' in turn['resp']:
  814. truth_cons = self._normalize_constraint(
  815. turn['bspn'], ignore_dontcare=True)
  816. if not gen_cons:
  817. # if no entity is provided, choose the state of the last dialog turn
  818. gen_cons = self._normalize_constraint(
  819. dial[-1]['bspn_gen'], ignore_dontcare=True)
  820. if list(truth_cons.values()) != [''] * 11:
  821. gen_cons = [x for x in gen_cons.values() if x]
  822. truth_cons = [x for x in truth_cons.values() if x]
  823. if self.constraint_same(gen_cons, truth_cons):
  824. match += 1
  825. total += 1
  826. return match / total
  827. def clean(self, resp):
  828. # we use the same clean process as in Sequicity, SEDST, FSDM
  829. # to ensure comparable results
  830. resp = resp.replace(f'{self.reader.sos_r_token} ', '')
  831. resp = resp.replace(f' {self.reader.eos_r_token}', '')
  832. resp = f'{self.reader.sos_r_token} {resp} {self.reader.eos_r_token}'
  833. for value, slot in self.entitiy_to_slot_dict.items():
  834. resp = utils.clean_replace(resp, value, '[value_%s]' % slot)
  835. return resp