db_ops.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import random
  4. import sqlite3
  5. import json
  6. from .ontology import all_domains, db_domains
  7. class MultiWozDB(object):
  8. def __init__(self, db_dir, db_paths):
  9. self.dbs = {}
  10. self.sql_dbs = {}
  11. for domain in all_domains:
  12. with open(
  13. os.path.join(db_dir, db_paths[domain]), 'r',
  14. encoding='utf-8') as f:
  15. self.dbs[domain] = json.loads(f.read().lower())
  16. def oneHotVector(self, domain, num):
  17. """Return number of available entities for particular domain."""
  18. vector = [0, 0, 0, 0]
  19. if num == '':
  20. return vector
  21. if domain != 'train':
  22. if num == 0:
  23. vector = [1, 0, 0, 0]
  24. elif num == 1:
  25. vector = [0, 1, 0, 0]
  26. elif num <= 3:
  27. vector = [0, 0, 1, 0]
  28. else:
  29. vector = [0, 0, 0, 1]
  30. else:
  31. if num == 0:
  32. vector = [1, 0, 0, 0]
  33. elif num <= 5:
  34. vector = [0, 1, 0, 0]
  35. elif num <= 10:
  36. vector = [0, 0, 1, 0]
  37. else:
  38. vector = [0, 0, 0, 1]
  39. return vector
  40. def addBookingPointer(self, turn_da):
  41. """Add information about availability of the booking option."""
  42. # Booking pointer
  43. # Do not consider booking two things in a single turn.
  44. vector = [0, 0]
  45. if turn_da.get('booking-nobook'):
  46. vector = [1, 0]
  47. if turn_da.get('booking-book') or turn_da.get('train-offerbooked'):
  48. vector = [0, 1]
  49. return vector
  50. def addDBPointer(self, domain, match_num, return_num=False):
  51. """Create database pointer for all related domains."""
  52. # if turn_domains is None:
  53. # turn_domains = db_domains
  54. if domain in db_domains:
  55. vector = self.oneHotVector(domain, match_num)
  56. else:
  57. vector = [0, 0, 0, 0]
  58. return vector
  59. def addDBIndicator(self, domain, match_num, return_num=False):
  60. """Create database indicator for all related domains."""
  61. # if turn_domains is None:
  62. # turn_domains = db_domains
  63. if domain in db_domains:
  64. vector = self.oneHotVector(domain, match_num)
  65. else:
  66. vector = [0, 0, 0, 0]
  67. # '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]'
  68. if vector == [0, 0, 0, 0]:
  69. indicator = '[db_nores]'
  70. else:
  71. indicator = '[db_%s]' % vector.index(1)
  72. return indicator
  73. def get_match_num(self, constraints, return_entry=False):
  74. """Create database pointer for all related domains."""
  75. match = {'general': ''}
  76. entry = {}
  77. # if turn_domains is None:
  78. # turn_domains = db_domains
  79. for domain in all_domains:
  80. match[domain] = ''
  81. if domain in db_domains and constraints.get(domain):
  82. matched_ents = self.queryJsons(domain, constraints[domain])
  83. match[domain] = len(matched_ents)
  84. if return_entry:
  85. entry[domain] = matched_ents
  86. if return_entry:
  87. return entry
  88. return match
  89. def pointerBack(self, vector, domain):
  90. # multi domain implementation
  91. # domnum = cfg.domain_num
  92. if domain.endswith(']'):
  93. domain = domain[1:-1]
  94. if domain != 'train':
  95. nummap = {0: '0', 1: '1', 2: '2-3', 3: '>3'}
  96. else:
  97. nummap = {0: '0', 1: '1-5', 2: '6-10', 3: '>10'}
  98. if vector[:4] == [0, 0, 0, 0]:
  99. report = ''
  100. else:
  101. num = vector.index(1)
  102. report = domain + ': ' + nummap[num] + '; '
  103. if vector[-2] == 0 and vector[-1] == 1:
  104. report += 'booking: ok'
  105. if vector[-2] == 1 and vector[-1] == 0:
  106. report += 'booking: unable'
  107. return report
  108. def queryJsons(self,
  109. domain,
  110. constraints,
  111. exactly_match=True,
  112. return_name=False):
  113. """Returns the list of entities for a given domain
  114. based on the annotation of the belief state
  115. constraints: dict e.g. {'pricerange': 'cheap', 'area': 'west'}
  116. """
  117. # query the db
  118. if domain == 'taxi':
  119. return [{
  120. 'taxi_colors':
  121. random.choice(self.dbs[domain]['taxi_colors']),
  122. 'taxi_types':
  123. random.choice(self.dbs[domain]['taxi_types']),
  124. 'taxi_phone': [random.randint(1, 9) for _ in range(10)]
  125. }]
  126. if domain == 'police':
  127. return self.dbs['police']
  128. if domain == 'hospital':
  129. if constraints.get('department'):
  130. for entry in self.dbs['hospital']:
  131. if entry.get('department') == constraints.get(
  132. 'department'):
  133. return [entry]
  134. else:
  135. return []
  136. valid_cons = False
  137. for v in constraints.values():
  138. if v not in ['not mentioned', '']:
  139. valid_cons = True
  140. if not valid_cons:
  141. return []
  142. match_result = []
  143. if 'name' in constraints:
  144. for db_ent in self.dbs[domain]:
  145. if 'name' in db_ent:
  146. cons = constraints['name']
  147. dbn = db_ent['name']
  148. if cons == dbn:
  149. db_ent = db_ent if not return_name else db_ent['name']
  150. match_result.append(db_ent)
  151. return match_result
  152. for db_ent in self.dbs[domain]:
  153. match = True
  154. for s, v in constraints.items():
  155. if s == 'name':
  156. continue
  157. if s in ['people', 'stay'] or (domain == 'hotel' and s == 'day') or \
  158. (domain == 'restaurant' and s in ['day', 'time']):
  159. # These inform slots belong to "book info",which do not exist in DB
  160. # "book" is according to the user goal,not DB
  161. continue
  162. skip_case = {
  163. "don't care": 1,
  164. "do n't care": 1,
  165. 'dont care': 1,
  166. 'not mentioned': 1,
  167. 'dontcare': 1,
  168. '': 1
  169. }
  170. if skip_case.get(v):
  171. continue
  172. if s not in db_ent:
  173. # logging.warning('Searching warning: slot %s not in %s db'%(s, domain))
  174. match = False
  175. break
  176. # v = 'guesthouse' if v == 'guest house' else v
  177. # v = 'swimmingpool' if v == 'swimming pool' else v
  178. v = 'yes' if v == 'free' else v
  179. if s in ['arrive', 'leave']:
  180. try:
  181. h, m = v.split(
  182. ':'
  183. ) # raise error if time value is not xx:xx format
  184. v = int(h) * 60 + int(m)
  185. except Exception:
  186. match = False
  187. break
  188. time = int(db_ent[s].split(':')[0]) * 60 + int(
  189. db_ent[s].split(':')[1])
  190. if s == 'arrive' and v > time:
  191. match = False
  192. if s == 'leave' and v < time:
  193. match = False
  194. else:
  195. if exactly_match and v != db_ent[s]:
  196. match = False
  197. break
  198. elif v not in db_ent[s]:
  199. match = False
  200. break
  201. if match:
  202. match_result.append(db_ent)
  203. if not return_name:
  204. return match_result
  205. else:
  206. if domain == 'train':
  207. match_result = [e['id'] for e in match_result]
  208. else:
  209. match_result = [e['name'] for e in match_result]
  210. return match_result
  211. def querySQL(self, domain, constraints):
  212. if not self.sql_dbs:
  213. for dom in db_domains:
  214. db = 'db/{}-dbase.db'.format(dom)
  215. conn = sqlite3.connect(db)
  216. c = conn.cursor()
  217. self.sql_dbs[dom] = c
  218. sql_query = 'select * from {}'.format(domain)
  219. flag = True
  220. for key, val in constraints.items():
  221. if val == '' \
  222. or val == 'dontcare' \
  223. or val == 'not mentioned' \
  224. or val == "don't care" \
  225. or val == 'dont care' \
  226. or val == "do n't care":
  227. pass
  228. else:
  229. if flag:
  230. sql_query += ' where '
  231. val2 = val.replace("'", "''")
  232. # val2 = normalize(val2)
  233. if key == 'leaveAt':
  234. sql_query += r' ' + key + ' > ' + r"'" + val2 + r"'"
  235. elif key == 'arriveBy':
  236. sql_query += r' ' + key + ' < ' + r"'" + val2 + r"'"
  237. else:
  238. sql_query += r' ' + key + '=' + r"'" + val2 + r"'"
  239. flag = False
  240. else:
  241. val2 = val.replace("'", "''")
  242. # val2 = normalize(val2)
  243. if key == 'leaveAt':
  244. sql_query += r' and ' + key + ' > ' + r"'" + val2 + r"'"
  245. elif key == 'arriveBy':
  246. sql_query += r' and ' + key + ' < ' + r"'" + val2 + r"'"
  247. else:
  248. sql_query += r' and ' + key + '=' + r"'" + val2 + r"'"
  249. try: # "select * from attraction where name = 'queens college'"
  250. print(sql_query)
  251. return self.sql_dbs[domain].execute(sql_query).fetchall()
  252. except Exception:
  253. return [] # TODO test it
  254. if __name__ == '__main__':
  255. dbPATHs = {
  256. 'attraction': 'db/attraction_db_processed.json',
  257. 'hospital': 'db/hospital_db_processed.json',
  258. 'hotel': 'db/hotel_db_processed.json',
  259. 'police': 'db/police_db_processed.json',
  260. 'restaurant': 'db/restaurant_db_processed.json',
  261. 'taxi': 'db/taxi_db_processed.json',
  262. 'train': 'db/train_db_processed.json',
  263. }
  264. db = MultiWozDB(dbPATHs)
  265. while True:
  266. constraints = {}
  267. inp = input(
  268. 'input belief state in fomat: domain-slot1=value1;slot2=value2...\n'
  269. )
  270. domain, cons = inp.split('-')
  271. for sv in cons.split(';'):
  272. s, v = sv.split('=')
  273. constraints[s] = v
  274. # res = db.querySQL(domain, constraints)
  275. res = db.queryJsons(domain, constraints, return_name=True)
  276. report = []
  277. reidx = {
  278. 'hotel': 8,
  279. 'restaurant': 6,
  280. 'attraction': 5,
  281. 'train': 1,
  282. }
  283. print(constraints)
  284. print(res)
  285. print('count:', len(res), '\nnames:', report)