| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os
- import random
- import re
- from typing import Any, Dict, List
- import torch
- from modelscope.utils.constant import ModeKeys
- from .base import OfaBasePreprocessor
- from .utils.bridge_content_encoder import get_database_matches
- from .utils.get_tables import dump_db_json_schema
- class OfaTextToSqlPreprocessor(OfaBasePreprocessor):
- r"""
- OFA preprocessor for text to sql tasks
- """
- def __init__(self,
- cfg,
- model_dir,
- mode=ModeKeys.INFERENCE,
- *args,
- **kwargs):
- """preprocess the data
- Args:
- cfg(modelscope.utils.config.ConfigDict) : model config
- model_dir (str): model path,
- mode: preprocessor mode (model mode)
- """
- super(OfaTextToSqlPreprocessor, self).__init__(cfg, model_dir, mode,
- *args, **kwargs)
- self.instruction_text = self.cfg.model.get('prompt',
- ' . generating sql code.')
- self.max_struct_length = self.cfg.get('max_struct_length', 256)
- self.separator = '\t'
- self.db_schema_cache = {}
- self.database_path = os.path.join(
- os.path.abspath(model_dir), 'database')
- def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
- if self.mode == ModeKeys.TRAIN:
- return self._build_train_sample(data)
- else:
- return self._build_infer_sample(data)
- def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
- r"""
- build sample for training tasks.
- step 1. Get the input question and database id from text input
- step 2. Get the database structure input
- step 3. Add a pseudo ids for every input.
- step 4. Calculate the target and previous output items.
- """
- assert 'text' in self.column_map and 'text' in data, \
- 'there must be `text` column in task key map and source data'
- text = data[self.column_map['text']] # equal data['text']
- texts = text.split(self.separator)
- assert len(
- texts
- ) == 3, 'invalid input, should contain query, question and database id'
- query, question, db_id = texts
- # construct struct input
- if db_id not in self.db_schema_cache:
- self.db_schema_cache[db_id] = dump_db_json_schema(
- self.database_path + '/' + db_id + '/' + db_id + '.sqlite',
- db_id)
- question = ' '.join(question.strip().split()[:self.max_src_length])
- seq_inputs = seq2seq_input(query, question, db_id, self.database_path,
- self.db_schema_cache[db_id], self.cfg.model,
- True)
- struct_in = seq_inputs['struct_in']
- text = seq_inputs['text_in']
- seq_out = seq_inputs['seq_out']
- db_struct = seq_inputs['db_struct']
- text = '{} ; structured knowledge: {}'.format(
- text, struct_in) + self.instruction_text
- src_item = self.tokenize_text(text + self.instruction_text)
- src_item = src_item[:(self.max_src_length + self.max_struct_length
- + 20)]
- tgt_item = self.tokenize_text(
- ' {}'.format(seq_out), add_bos=False,
- add_eos=False)[:self.max_tgt_length]
- target_item = torch.cat([tgt_item, self.eos_item])
- prev_output_item = torch.cat([self.bos_item, tgt_item])
- sample = {
- 'id': 0.0,
- 'source': src_item,
- 'target': target_item,
- 'prev_output_tokens': prev_output_item,
- 'db_struct': db_struct
- }
- return sample
- def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
- r"""
- build sample for inference tasks.
- step 1. Get the input question and database id from text input
- step 2. Get the database structure input
- step 3. Add a pseudo ids for every input.
- """
- assert 'text' in self.column_map and 'text' in data, \
- 'there must be `text` column in task key map and source data'
- text = data[self.column_map['text']] # equal data['text']
- db_id = data.get(self.column_map['database'], 'culture_company')
- db_id = db_id.strip()
- # construct struct input
- if db_id not in self.db_schema_cache:
- self.db_schema_cache[db_id] = dump_db_json_schema(
- self.database_path + '/' + db_id + '/' + db_id + '.sqlite',
- db_id)
- text = ' '.join(text.strip().split()[:self.max_src_length])
- seq_inputs = seq2seq_input(None, text, db_id, self.database_path,
- self.db_schema_cache[db_id], self.cfg.model)
- struct_in = seq_inputs['struct_in']
- db_struct = seq_inputs['db_struct']
- text = '{} ; structured knowledge: {}'.format(
- text, struct_in) + self.instruction_text
- src_item = self.tokenize_text(text + self.instruction_text)
- src_item = src_item[:(self.max_src_length + self.max_struct_length
- + 20)]
- sample = {'id': 0.0, 'source': src_item, 'db_struct': db_struct}
- if 'solution' in self.column_map and self.column_map[
- 'solution'] in data:
- sample['label'] = ' {}'.format(data[self.column_map['solution']])
- return sample
- def seq2seq_input(query,
- question,
- db_id,
- db_path,
- schema,
- args,
- is_train=False):
- ex = form_input_for_construction(query, question, db_id, db_path, schema)
- serialized_schema = spider_add_serialized_schema(
- ex, args)['serialized_schema'].strip()
- if not is_train:
- return {
- 'struct_in': serialized_schema,
- 'text_in': question,
- 'db_struct': ex
- }
- question, seq_out = spider_pre_process_one_function(ex, args)
- return {
- 'struct_in': serialized_schema,
- 'text_in': question,
- 'seq_out': seq_out,
- 'db_struct': ex
- }
- def spider_pre_process_one_function(item: dict, args):
- prefix = ''
- seq_out = spider_get_target(
- query=item['query'],
- db_id=item['db_id'],
- normalize_query=True,
- target_with_db_id=args.target_with_db_id,
- )
- return prefix + item['question'].strip(), seq_out
- def spider_get_target(
- query: str,
- db_id: str,
- normalize_query: bool,
- target_with_db_id: bool,
- ) -> str:
- _normalize = normalize if normalize_query else (lambda x: x)
- return f'{db_id} | {_normalize(query)}' if target_with_db_id else _normalize(
- query)
- def normalize(query: str) -> str:
- def comma_fix(s):
- # Remove spaces in front of commas
- return s.replace(' , ', ', ')
- def white_space_fix(s):
- # Remove double and triple spaces
- return ' '.join(s.split())
- def lower(s):
- # Convert everything except text between (single or double) quotation marks to lower case
- return re.sub(r"\b(?<!['\"])(\w+)(?!['\"])\b",
- lambda match: match.group(1).lower(), s)
- return comma_fix(white_space_fix(lower(query)))
- def spider_add_serialized_schema(ex: dict, args) -> dict:
- if getattr(args, 'schema_serialization_with_nl'):
- serialized_schema = serialize_schema_natural_language(
- question=ex['question'],
- db_path=ex['db_path'],
- db_id=ex['db_id'],
- db_column_names=ex['db_column_names'],
- db_table_names=ex['db_table_names'],
- db_primary_keys=ex['db_primary_keys'],
- db_foreign_keys=ex['db_foreign_keys'],
- schema_serialization_with_db_content=args.
- schema_serialization_with_db_content,
- normalize_query=True,
- )
- else:
- serialized_schema = serialize_schema(
- question=ex['question'],
- db_path=ex['db_path'],
- db_id=ex['db_id'],
- db_column_names=ex['db_column_names'],
- db_table_names=ex['db_table_names'],
- schema_serialization_type='peteshaw',
- schema_serialization_randomized=False,
- schema_serialization_with_db_id=True,
- schema_serialization_with_db_content=args.
- schema_serialization_with_db_content,
- normalize_query=True,
- )
- return {'serialized_schema': serialized_schema}
- def serialize_schema_natural_language(
- question: str,
- db_path: str,
- db_id: str,
- db_column_names: Dict[str, str],
- db_table_names: List[str],
- db_primary_keys,
- db_foreign_keys,
- schema_serialization_with_db_content: bool = False,
- normalize_query: bool = True,
- ) -> str:
- overall_description = f'{db_id} contains tables such as ' \
- f'{", ".join([name.lower() if normalize_query else name for name in db_table_names])}.'
- def table_description_primary_key_template(primary_key):
- return f'{primary_key} is the primary key.'
- def table_description(name, column_names):
- return f'Table {name} has columns such as {", ".join(column_names)}.'
- def value_description(cv_pairs):
- return f'{"".join(["The {} contains values such as {}.".format(column, value) for column, value in cv_pairs])}'
- def foreign_key_description(table_1, column_1, table_2, column_2):
- return f'The {column_1} of {table_1} is the foreign key of {column_2} of {table_2}.'
- db_primary_keys = db_primary_keys['column_id']
- db_foreign_keys = list(
- zip(db_foreign_keys['column_id'], db_foreign_keys['other_column_id']))
- descriptions = [overall_description]
- db_table_name_strs = []
- db_column_name_strs = []
- value_sep = ', '
- for table_id, table_name in enumerate(db_table_names):
- table_name_str = table_name.lower() if normalize_query else table_name
- db_table_name_strs.append(table_name_str)
- columns = []
- column_value_pairs = []
- primary_keys = []
- for column_id, (x, y) in enumerate(
- zip(db_column_names['table_id'],
- db_column_names['column_name'])):
- if column_id == 0:
- continue
- column_str = y.lower() if normalize_query else y
- db_column_name_strs.append(column_str)
- if x == table_id:
- columns.append(column_str)
- if column_id in db_primary_keys:
- primary_keys.append(column_str)
- if schema_serialization_with_db_content:
- matches = get_database_matches(
- question=question,
- table_name=table_name,
- column_name=y,
- db_path=(db_path + '/' + db_id + '/' + db_id
- + '.sqlite'),
- )
- if matches:
- column_value_pairs.append(
- (column_str, value_sep.join(matches)))
- table_description_columns_str = table_description(
- table_name_str, columns)
- descriptions.append(table_description_columns_str)
- table_description_primary_key_str = table_description_primary_key_template(
- ', '.join(primary_keys))
- descriptions.append(table_description_primary_key_str)
- if len(column_value_pairs) > 0:
- value_description_str = value_description(column_value_pairs)
- descriptions.append(value_description_str)
- for x, y in db_foreign_keys:
- # get the table and column of x
- x_table_name = db_table_name_strs[db_column_names['table_id'][x]]
- x_column_name = db_column_name_strs[x]
- # get the table and column of y
- y_table_name = db_table_name_strs[db_column_names['table_id'][y]]
- y_column_name = db_column_name_strs[y]
- foreign_key_description_str = foreign_key_description(
- x_table_name, x_column_name, y_table_name, y_column_name)
- descriptions.append(foreign_key_description_str)
- return ' '.join(descriptions)
- def serialize_schema(
- question: str,
- db_path: str,
- db_id: str,
- db_column_names: Dict[str, str],
- db_table_names: List[str],
- schema_serialization_type: str = 'peteshaw',
- schema_serialization_randomized: bool = False,
- schema_serialization_with_db_id: bool = True,
- schema_serialization_with_db_content: bool = False,
- normalize_query: bool = True,
- ) -> str:
- if schema_serialization_type == 'verbose':
- db_id_str = 'Database: {db_id}. '
- table_sep = '. '
- table_str = 'Table: {table}. Columns: {columns}'
- column_sep = ', '
- column_str_with_values = '{column} ({values})'
- column_str_without_values = '{column}'
- value_sep = ', '
- elif schema_serialization_type == 'peteshaw':
- # see https://github.com/google-research/language/blob/master/language/nqg/tasks/spider/append_schema.py#L42
- db_id_str = ' | {db_id}'
- table_sep = ''
- table_str = ' | {table} : {columns}'
- column_sep = ' , '
- column_str_with_values = '{column} ( {values} )'
- column_str_without_values = '{column}'
- value_sep = ' , '
- else:
- raise NotImplementedError
- def get_column_str(table_name: str, column_name: str) -> str:
- column_name_str = column_name.lower(
- ) if normalize_query else column_name
- if schema_serialization_with_db_content:
- # print("testing")
- matches = get_database_matches(
- question=question,
- table_name=table_name,
- column_name=column_name,
- db_path=(db_path + '/' + db_id + '/' + db_id + '.sqlite'),
- )
- if matches:
- return column_str_with_values.format(
- column=column_name_str, values=value_sep.join(matches))
- else:
- return column_str_without_values.format(column=column_name_str)
- else:
- return column_str_without_values.format(column=column_name_str)
- tables = [
- table_str.format(
- table=table_name.lower() if normalize_query else table_name,
- columns=column_sep.join(
- map(
- lambda y: get_column_str(
- table_name=table_name, column_name=y[1]),
- filter(
- lambda y: y[0] == table_id,
- zip(
- db_column_names['table_id'],
- db_column_names['column_name'],
- ),
- ),
- )),
- ) for table_id, table_name in enumerate(db_table_names)
- ]
- if schema_serialization_randomized:
- random.shuffle(tables)
- if schema_serialization_with_db_id:
- serialized_schema = db_id_str.format(
- db_id=db_id) + table_sep.join(tables)
- else:
- serialized_schema = table_sep.join(tables)
- return serialized_schema
- def form_input_for_construction(query, question, db_id, db_path, schema):
- return {
- 'query':
- query,
- 'question':
- question,
- 'db_id':
- db_id,
- 'db_path':
- db_path,
- 'db_table_names':
- schema['table_names_original'],
- 'db_column_names': {
- 'table_id': [
- table_id
- for table_id, column_name in schema['column_names_original']
- ],
- 'column_name': [
- column_name
- for table_id, column_name in schema['column_names_original']
- ]
- },
- 'db_column_types':
- schema['column_types'],
- 'db_primary_keys': [{
- 'column_id': column_id
- } for column_id in schema['primary_keys']],
- 'db_foreign_keys': {
- 'column_id': [
- column_id
- for column_id, other_column_id in schema['foreign_keys']
- ],
- 'other_column_id': [
- other_column_id
- for column_id, other_column_id in schema['foreign_keys']
- ]
- },
- }
|