# ontology.py
# Copyright (c) Alibaba, Inc. and its affiliates.
  2. all_domains = [
  3. 'restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital'
  4. ]
  5. all_domains_with_bracket = ['[{}]'.format(item) for item in all_domains]
  6. db_domains = ['restaurant', 'hotel', 'attraction', 'train']
  7. placeholder_tokens = [
  8. '<go_r>', '<go_b>', '<go_a>', '<go_d>', '<eos_u>', '<eos_r>', '<eos_b>',
  9. '<eos_a>', '<eos_d>', '<eos_q>', '<sos_u>', '<sos_r>', '<sos_b>',
  10. '<sos_a>', '<sos_d>', '<sos_q>'
  11. ]
  12. normlize_slot_names = {
  13. 'car type': 'car',
  14. 'entrance fee': 'price',
  15. 'duration': 'time',
  16. 'leaveat': 'leave',
  17. 'arriveby': 'arrive',
  18. 'trainid': 'id'
  19. }
  20. requestable_slots = {
  21. 'taxi': ['car', 'phone'],
  22. 'police': ['postcode', 'address', 'phone'],
  23. 'hospital': ['address', 'phone', 'postcode'],
  24. 'hotel': [
  25. 'address', 'postcode', 'internet', 'phone', 'parking', 'type',
  26. 'pricerange', 'stars', 'area', 'reference'
  27. ],
  28. 'attraction':
  29. ['price', 'type', 'address', 'postcode', 'phone', 'area', 'reference'],
  30. 'train': ['time', 'leave', 'price', 'arrive', 'id', 'reference'],
  31. 'restaurant': [
  32. 'phone', 'postcode', 'address', 'pricerange', 'food', 'area',
  33. 'reference'
  34. ]
  35. }
  36. all_reqslot = [
  37. 'car', 'address', 'postcode', 'phone', 'internet', 'parking', 'type',
  38. 'pricerange', 'food', 'stars', 'area', 'reference', 'time', 'leave',
  39. 'price', 'arrive', 'id'
  40. ]
  41. informable_slots = {
  42. 'taxi': ['leave', 'destination', 'departure', 'arrive'],
  43. 'police': [],
  44. 'hospital': ['department'],
  45. 'hotel': [
  46. 'type', 'parking', 'pricerange', 'internet', 'stay', 'day', 'people',
  47. 'area', 'stars', 'name'
  48. ],
  49. 'attraction': ['area', 'type', 'name'],
  50. 'train': ['destination', 'day', 'arrive', 'departure', 'people', 'leave'],
  51. 'restaurant':
  52. ['food', 'pricerange', 'area', 'name', 'time', 'day', 'people']
  53. }
  54. all_infslot = [
  55. 'type', 'parking', 'pricerange', 'internet', 'stay', 'day', 'people',
  56. 'area', 'stars', 'name', 'leave', 'destination', 'departure', 'arrive',
  57. 'department', 'food', 'time'
  58. ]
  59. all_slots = all_reqslot + [
  60. 'stay', 'day', 'people', 'name', 'destination', 'departure', 'department'
  61. ]
  62. get_slot = {}
  63. for s in all_slots:
  64. get_slot[s] = 1
  65. # mapping slots in dialogue act to original goal slot names
  66. da_abbr_to_slot_name = {
  67. 'addr': 'address',
  68. 'fee': 'price',
  69. 'post': 'postcode',
  70. 'ref': 'reference',
  71. 'ticket': 'price',
  72. 'depart': 'departure',
  73. 'dest': 'destination',
  74. }
  75. dialog_acts = {
  76. 'restaurant': [
  77. 'inform', 'request', 'nooffer', 'recommend', 'select', 'offerbook',
  78. 'offerbooked', 'nobook'
  79. ],
  80. 'hotel': [
  81. 'inform', 'request', 'nooffer', 'recommend', 'select', 'offerbook',
  82. 'offerbooked', 'nobook'
  83. ],
  84. 'attraction': ['inform', 'request', 'nooffer', 'recommend', 'select'],
  85. 'train':
  86. ['inform', 'request', 'nooffer', 'offerbook', 'offerbooked', 'select'],
  87. 'taxi': ['inform', 'request'],
  88. 'police': ['inform', 'request'],
  89. 'hospital': ['inform', 'request'],
  90. # 'booking': ['book', 'inform', 'nobook', 'request'],
  91. 'general': ['bye', 'greet', 'reqmore', 'welcome'],
  92. }
  93. all_acts = []
  94. for acts in dialog_acts.values():
  95. for act in acts:
  96. if act not in all_acts:
  97. all_acts.append(act)
  98. dialog_act_params = {
  99. 'inform': all_slots + ['choice', 'open'],
  100. 'request': all_infslot + ['choice', 'price'],
  101. 'nooffer': all_slots + ['choice'],
  102. 'recommend': all_reqslot + ['choice', 'open'],
  103. 'select': all_slots + ['choice'],
  104. # 'book': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'],
  105. 'nobook': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'],
  106. 'offerbook': all_slots + ['choice'],
  107. 'offerbooked': all_slots + ['choice'],
  108. 'reqmore': [],
  109. 'welcome': [],
  110. 'bye': [],
  111. 'greet': [],
  112. }
  113. dialog_act_all_slots = all_slots + ['choice', 'open']
  114. # special slot tokens in belief span
  115. # no need of this, just covert slot to [slot] e.g. pricerange -> [pricerange]
  116. slot_name_to_slot_token = {}
  117. # eos tokens definition
  118. eos_tokens = {
  119. 'user': '<eos_u>',
  120. 'user_delex': '<eos_u>',
  121. 'resp': '<eos_r>',
  122. 'resp_gen': '<eos_r>',
  123. 'pv_resp': '<eos_r>',
  124. 'bspn': '<eos_b>',
  125. 'bspn_gen': '<eos_b>',
  126. 'pv_bspn': '<eos_b>',
  127. 'bsdx': '<eos_b>',
  128. 'bsdx_gen': '<eos_b>',
  129. 'pv_bsdx': '<eos_b>',
  130. 'qspn': '<eos_q>',
  131. 'qspn_gen': '<eos_q>',
  132. 'pv_qspn': '<eos_q>',
  133. 'aspn': '<eos_a>',
  134. 'aspn_gen': '<eos_a>',
  135. 'pv_aspn': '<eos_a>',
  136. 'dspn': '<eos_d>',
  137. 'dspn_gen': '<eos_d>',
  138. 'pv_dspn': '<eos_d>'
  139. }
  140. # sos tokens definition
  141. sos_tokens = {
  142. 'user': '<sos_u>',
  143. 'user_delex': '<sos_u>',
  144. 'resp': '<sos_r>',
  145. 'resp_gen': '<sos_r>',
  146. 'pv_resp': '<sos_r>',
  147. 'bspn': '<sos_b>',
  148. 'bspn_gen': '<sos_b>',
  149. 'pv_bspn': '<sos_b>',
  150. 'bsdx': '<sos_b>',
  151. 'bsdx_gen': '<sos_b>',
  152. 'pv_bsdx': '<sos_b>',
  153. 'qspn': '<sos_q>',
  154. 'qspn_gen': '<sos_q>',
  155. 'pv_qspn': '<sos_q>',
  156. 'aspn': '<sos_a>',
  157. 'aspn_gen': '<sos_a>',
  158. 'pv_aspn': '<sos_a>',
  159. 'dspn': '<sos_d>',
  160. 'dspn_gen': '<sos_d>',
  161. 'pv_dspn': '<sos_d>'
  162. }
  163. # db tokens definition
  164. db_tokens = [
  165. '<sos_db>', '<eos_db>', '[book_nores]', '[book_fail]', '[book_success]',
  166. '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]'
  167. ]
  168. # understand tokens definition
  169. def get_understand_tokens(prompt_num_for_understand):
  170. understand_tokens = []
  171. for i in range(prompt_num_for_understand):
  172. understand_tokens.append(f'<understand_{i}>')
  173. return understand_tokens
  174. # policy tokens definition
  175. def get_policy_tokens(prompt_num_for_policy):
  176. policy_tokens = []
  177. for i in range(prompt_num_for_policy):
  178. policy_tokens.append(f'<policy_{i}>')
  179. return policy_tokens
  180. # all special tokens definition
  181. def get_special_tokens(other_tokens):
  182. special_tokens = [
  183. '<go_r>', '<go_b>', '<go_a>', '<go_d>', '<eos_u>', '<eos_r>',
  184. '<eos_b>', '<eos_a>', '<eos_d>', '<eos_q>', '<sos_u>', '<sos_r>',
  185. '<sos_b>', '<sos_a>', '<sos_d>', '<sos_q>'
  186. ] + db_tokens + other_tokens
  187. return special_tokens