text2sql.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import random
  4. import re
  5. from typing import Any, Dict, List
  6. import torch
  7. from modelscope.utils.constant import ModeKeys
  8. from .base import OfaBasePreprocessor
  9. from .utils.bridge_content_encoder import get_database_matches
  10. from .utils.get_tables import dump_db_json_schema
  11. class OfaTextToSqlPreprocessor(OfaBasePreprocessor):
  12. r"""
  13. OFA preprocessor for text to sql tasks
  14. """
  15. def __init__(self,
  16. cfg,
  17. model_dir,
  18. mode=ModeKeys.INFERENCE,
  19. *args,
  20. **kwargs):
  21. """preprocess the data
  22. Args:
  23. cfg(modelscope.utils.config.ConfigDict) : model config
  24. model_dir (str): model path,
  25. mode: preprocessor mode (model mode)
  26. """
  27. super(OfaTextToSqlPreprocessor, self).__init__(cfg, model_dir, mode,
  28. *args, **kwargs)
  29. self.instruction_text = self.cfg.model.get('prompt',
  30. ' . generating sql code.')
  31. self.max_struct_length = self.cfg.get('max_struct_length', 256)
  32. self.separator = '\t'
  33. self.db_schema_cache = {}
  34. self.database_path = os.path.join(
  35. os.path.abspath(model_dir), 'database')
  36. def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
  37. if self.mode == ModeKeys.TRAIN:
  38. return self._build_train_sample(data)
  39. else:
  40. return self._build_infer_sample(data)
  41. def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
  42. r"""
  43. build sample for training tasks.
  44. step 1. Get the input question and database id from text input
  45. step 2. Get the database structure input
  46. step 3. Add a pseudo ids for every input.
  47. step 4. Calculate the target and previous output items.
  48. """
  49. assert 'text' in self.column_map and 'text' in data, \
  50. 'there must be `text` column in task key map and source data'
  51. text = data[self.column_map['text']] # equal data['text']
  52. texts = text.split(self.separator)
  53. assert len(
  54. texts
  55. ) == 3, 'invalid input, should contain query, question and database id'
  56. query, question, db_id = texts
  57. # construct struct input
  58. if db_id not in self.db_schema_cache:
  59. self.db_schema_cache[db_id] = dump_db_json_schema(
  60. self.database_path + '/' + db_id + '/' + db_id + '.sqlite',
  61. db_id)
  62. question = ' '.join(question.strip().split()[:self.max_src_length])
  63. seq_inputs = seq2seq_input(query, question, db_id, self.database_path,
  64. self.db_schema_cache[db_id], self.cfg.model,
  65. True)
  66. struct_in = seq_inputs['struct_in']
  67. text = seq_inputs['text_in']
  68. seq_out = seq_inputs['seq_out']
  69. db_struct = seq_inputs['db_struct']
  70. text = '{} ; structured knowledge: {}'.format(
  71. text, struct_in) + self.instruction_text
  72. src_item = self.tokenize_text(text + self.instruction_text)
  73. src_item = src_item[:(self.max_src_length + self.max_struct_length
  74. + 20)]
  75. tgt_item = self.tokenize_text(
  76. ' {}'.format(seq_out), add_bos=False,
  77. add_eos=False)[:self.max_tgt_length]
  78. target_item = torch.cat([tgt_item, self.eos_item])
  79. prev_output_item = torch.cat([self.bos_item, tgt_item])
  80. sample = {
  81. 'id': 0.0,
  82. 'source': src_item,
  83. 'target': target_item,
  84. 'prev_output_tokens': prev_output_item,
  85. 'db_struct': db_struct
  86. }
  87. return sample
  88. def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
  89. r"""
  90. build sample for inference tasks.
  91. step 1. Get the input question and database id from text input
  92. step 2. Get the database structure input
  93. step 3. Add a pseudo ids for every input.
  94. """
  95. assert 'text' in self.column_map and 'text' in data, \
  96. 'there must be `text` column in task key map and source data'
  97. text = data[self.column_map['text']] # equal data['text']
  98. db_id = data.get(self.column_map['database'], 'culture_company')
  99. db_id = db_id.strip()
  100. # construct struct input
  101. if db_id not in self.db_schema_cache:
  102. self.db_schema_cache[db_id] = dump_db_json_schema(
  103. self.database_path + '/' + db_id + '/' + db_id + '.sqlite',
  104. db_id)
  105. text = ' '.join(text.strip().split()[:self.max_src_length])
  106. seq_inputs = seq2seq_input(None, text, db_id, self.database_path,
  107. self.db_schema_cache[db_id], self.cfg.model)
  108. struct_in = seq_inputs['struct_in']
  109. db_struct = seq_inputs['db_struct']
  110. text = '{} ; structured knowledge: {}'.format(
  111. text, struct_in) + self.instruction_text
  112. src_item = self.tokenize_text(text + self.instruction_text)
  113. src_item = src_item[:(self.max_src_length + self.max_struct_length
  114. + 20)]
  115. sample = {'id': 0.0, 'source': src_item, 'db_struct': db_struct}
  116. if 'solution' in self.column_map and self.column_map[
  117. 'solution'] in data:
  118. sample['label'] = ' {}'.format(data[self.column_map['solution']])
  119. return sample
  120. def seq2seq_input(query,
  121. question,
  122. db_id,
  123. db_path,
  124. schema,
  125. args,
  126. is_train=False):
  127. ex = form_input_for_construction(query, question, db_id, db_path, schema)
  128. serialized_schema = spider_add_serialized_schema(
  129. ex, args)['serialized_schema'].strip()
  130. if not is_train:
  131. return {
  132. 'struct_in': serialized_schema,
  133. 'text_in': question,
  134. 'db_struct': ex
  135. }
  136. question, seq_out = spider_pre_process_one_function(ex, args)
  137. return {
  138. 'struct_in': serialized_schema,
  139. 'text_in': question,
  140. 'seq_out': seq_out,
  141. 'db_struct': ex
  142. }
  143. def spider_pre_process_one_function(item: dict, args):
  144. prefix = ''
  145. seq_out = spider_get_target(
  146. query=item['query'],
  147. db_id=item['db_id'],
  148. normalize_query=True,
  149. target_with_db_id=args.target_with_db_id,
  150. )
  151. return prefix + item['question'].strip(), seq_out
  152. def spider_get_target(
  153. query: str,
  154. db_id: str,
  155. normalize_query: bool,
  156. target_with_db_id: bool,
  157. ) -> str:
  158. _normalize = normalize if normalize_query else (lambda x: x)
  159. return f'{db_id} | {_normalize(query)}' if target_with_db_id else _normalize(
  160. query)
  161. def normalize(query: str) -> str:
  162. def comma_fix(s):
  163. # Remove spaces in front of commas
  164. return s.replace(' , ', ', ')
  165. def white_space_fix(s):
  166. # Remove double and triple spaces
  167. return ' '.join(s.split())
  168. def lower(s):
  169. # Convert everything except text between (single or double) quotation marks to lower case
  170. return re.sub(r"\b(?<!['\"])(\w+)(?!['\"])\b",
  171. lambda match: match.group(1).lower(), s)
  172. return comma_fix(white_space_fix(lower(query)))
  173. def spider_add_serialized_schema(ex: dict, args) -> dict:
  174. if getattr(args, 'schema_serialization_with_nl'):
  175. serialized_schema = serialize_schema_natural_language(
  176. question=ex['question'],
  177. db_path=ex['db_path'],
  178. db_id=ex['db_id'],
  179. db_column_names=ex['db_column_names'],
  180. db_table_names=ex['db_table_names'],
  181. db_primary_keys=ex['db_primary_keys'],
  182. db_foreign_keys=ex['db_foreign_keys'],
  183. schema_serialization_with_db_content=args.
  184. schema_serialization_with_db_content,
  185. normalize_query=True,
  186. )
  187. else:
  188. serialized_schema = serialize_schema(
  189. question=ex['question'],
  190. db_path=ex['db_path'],
  191. db_id=ex['db_id'],
  192. db_column_names=ex['db_column_names'],
  193. db_table_names=ex['db_table_names'],
  194. schema_serialization_type='peteshaw',
  195. schema_serialization_randomized=False,
  196. schema_serialization_with_db_id=True,
  197. schema_serialization_with_db_content=args.
  198. schema_serialization_with_db_content,
  199. normalize_query=True,
  200. )
  201. return {'serialized_schema': serialized_schema}
  202. def serialize_schema_natural_language(
  203. question: str,
  204. db_path: str,
  205. db_id: str,
  206. db_column_names: Dict[str, str],
  207. db_table_names: List[str],
  208. db_primary_keys,
  209. db_foreign_keys,
  210. schema_serialization_with_db_content: bool = False,
  211. normalize_query: bool = True,
  212. ) -> str:
  213. overall_description = f'{db_id} contains tables such as ' \
  214. f'{", ".join([name.lower() if normalize_query else name for name in db_table_names])}.'
  215. def table_description_primary_key_template(primary_key):
  216. return f'{primary_key} is the primary key.'
  217. def table_description(name, column_names):
  218. return f'Table {name} has columns such as {", ".join(column_names)}.'
  219. def value_description(cv_pairs):
  220. return f'{"".join(["The {} contains values such as {}.".format(column, value) for column, value in cv_pairs])}'
  221. def foreign_key_description(table_1, column_1, table_2, column_2):
  222. return f'The {column_1} of {table_1} is the foreign key of {column_2} of {table_2}.'
  223. db_primary_keys = db_primary_keys['column_id']
  224. db_foreign_keys = list(
  225. zip(db_foreign_keys['column_id'], db_foreign_keys['other_column_id']))
  226. descriptions = [overall_description]
  227. db_table_name_strs = []
  228. db_column_name_strs = []
  229. value_sep = ', '
  230. for table_id, table_name in enumerate(db_table_names):
  231. table_name_str = table_name.lower() if normalize_query else table_name
  232. db_table_name_strs.append(table_name_str)
  233. columns = []
  234. column_value_pairs = []
  235. primary_keys = []
  236. for column_id, (x, y) in enumerate(
  237. zip(db_column_names['table_id'],
  238. db_column_names['column_name'])):
  239. if column_id == 0:
  240. continue
  241. column_str = y.lower() if normalize_query else y
  242. db_column_name_strs.append(column_str)
  243. if x == table_id:
  244. columns.append(column_str)
  245. if column_id in db_primary_keys:
  246. primary_keys.append(column_str)
  247. if schema_serialization_with_db_content:
  248. matches = get_database_matches(
  249. question=question,
  250. table_name=table_name,
  251. column_name=y,
  252. db_path=(db_path + '/' + db_id + '/' + db_id
  253. + '.sqlite'),
  254. )
  255. if matches:
  256. column_value_pairs.append(
  257. (column_str, value_sep.join(matches)))
  258. table_description_columns_str = table_description(
  259. table_name_str, columns)
  260. descriptions.append(table_description_columns_str)
  261. table_description_primary_key_str = table_description_primary_key_template(
  262. ', '.join(primary_keys))
  263. descriptions.append(table_description_primary_key_str)
  264. if len(column_value_pairs) > 0:
  265. value_description_str = value_description(column_value_pairs)
  266. descriptions.append(value_description_str)
  267. for x, y in db_foreign_keys:
  268. # get the table and column of x
  269. x_table_name = db_table_name_strs[db_column_names['table_id'][x]]
  270. x_column_name = db_column_name_strs[x]
  271. # get the table and column of y
  272. y_table_name = db_table_name_strs[db_column_names['table_id'][y]]
  273. y_column_name = db_column_name_strs[y]
  274. foreign_key_description_str = foreign_key_description(
  275. x_table_name, x_column_name, y_table_name, y_column_name)
  276. descriptions.append(foreign_key_description_str)
  277. return ' '.join(descriptions)
  278. def serialize_schema(
  279. question: str,
  280. db_path: str,
  281. db_id: str,
  282. db_column_names: Dict[str, str],
  283. db_table_names: List[str],
  284. schema_serialization_type: str = 'peteshaw',
  285. schema_serialization_randomized: bool = False,
  286. schema_serialization_with_db_id: bool = True,
  287. schema_serialization_with_db_content: bool = False,
  288. normalize_query: bool = True,
  289. ) -> str:
  290. if schema_serialization_type == 'verbose':
  291. db_id_str = 'Database: {db_id}. '
  292. table_sep = '. '
  293. table_str = 'Table: {table}. Columns: {columns}'
  294. column_sep = ', '
  295. column_str_with_values = '{column} ({values})'
  296. column_str_without_values = '{column}'
  297. value_sep = ', '
  298. elif schema_serialization_type == 'peteshaw':
  299. # see https://github.com/google-research/language/blob/master/language/nqg/tasks/spider/append_schema.py#L42
  300. db_id_str = ' | {db_id}'
  301. table_sep = ''
  302. table_str = ' | {table} : {columns}'
  303. column_sep = ' , '
  304. column_str_with_values = '{column} ( {values} )'
  305. column_str_without_values = '{column}'
  306. value_sep = ' , '
  307. else:
  308. raise NotImplementedError
  309. def get_column_str(table_name: str, column_name: str) -> str:
  310. column_name_str = column_name.lower(
  311. ) if normalize_query else column_name
  312. if schema_serialization_with_db_content:
  313. # print("testing")
  314. matches = get_database_matches(
  315. question=question,
  316. table_name=table_name,
  317. column_name=column_name,
  318. db_path=(db_path + '/' + db_id + '/' + db_id + '.sqlite'),
  319. )
  320. if matches:
  321. return column_str_with_values.format(
  322. column=column_name_str, values=value_sep.join(matches))
  323. else:
  324. return column_str_without_values.format(column=column_name_str)
  325. else:
  326. return column_str_without_values.format(column=column_name_str)
  327. tables = [
  328. table_str.format(
  329. table=table_name.lower() if normalize_query else table_name,
  330. columns=column_sep.join(
  331. map(
  332. lambda y: get_column_str(
  333. table_name=table_name, column_name=y[1]),
  334. filter(
  335. lambda y: y[0] == table_id,
  336. zip(
  337. db_column_names['table_id'],
  338. db_column_names['column_name'],
  339. ),
  340. ),
  341. )),
  342. ) for table_id, table_name in enumerate(db_table_names)
  343. ]
  344. if schema_serialization_randomized:
  345. random.shuffle(tables)
  346. if schema_serialization_with_db_id:
  347. serialized_schema = db_id_str.format(
  348. db_id=db_id) + table_sep.join(tables)
  349. else:
  350. serialized_schema = table_sep.join(tables)
  351. return serialized_schema
  352. def form_input_for_construction(query, question, db_id, db_path, schema):
  353. return {
  354. 'query':
  355. query,
  356. 'question':
  357. question,
  358. 'db_id':
  359. db_id,
  360. 'db_path':
  361. db_path,
  362. 'db_table_names':
  363. schema['table_names_original'],
  364. 'db_column_names': {
  365. 'table_id': [
  366. table_id
  367. for table_id, column_name in schema['column_names_original']
  368. ],
  369. 'column_name': [
  370. column_name
  371. for table_id, column_name in schema['column_names_original']
  372. ]
  373. },
  374. 'db_column_types':
  375. schema['column_types'],
  376. 'db_primary_keys': [{
  377. 'column_id': column_id
  378. } for column_id in schema['primary_keys']],
  379. 'db_foreign_keys': {
  380. 'column_id': [
  381. column_id
  382. for column_id, other_column_id in schema['foreign_keys']
  383. ],
  384. 'other_column_id': [
  385. other_column_id
  386. for column_id, other_column_id in schema['foreign_keys']
  387. ]
  388. },
  389. }