ASTormer: AST Structure-aware Transformer Decoder for Text-to-SQL

This repository contains the source code for the paper ASTormer: An AST Structure-aware Transformer Decoder for Text-to-SQL. It implements a grammar-based text-to-SQL parser built on a structure-aware Transformer decoder, with implementations for the prevalent cross-domain multi-table benchmarks Spider, SParC, CoSQL, DuSQL and Chase. If you find it useful, please cite our work.

@misc{cao2023astormer,
      title={ASTormer: An AST Structure-aware Transformer Decoder for Text-to-SQL}, 
      author={Ruisheng Cao and Hanchong Zhang and Hongshen Xu and Jieyu Li and Da Ma and Lu Chen and Kai Yu},
      year={2023},
      eprint={2310.18662},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

Note that this work focuses on leveraging small-sized pre-trained models and labeled training data to train a specialized, interpretable and efficient local text-to-SQL parser in low-resource scenarios, rather than chasing SOTA performance. For better results, please try LLMs with in-context learning (such as DINSQL and ACTSQL), or resort to larger encoder-decoder architectures with billions of parameters (such as Picard-3B and RESDSQL-3B). Due to a shift in the author's research focus in the LLM era, this project will no longer be maintained.

Create environment

The following commands are also provided in setup.sh.

  1. First, create the conda environment astormer:
$ conda create -n astormer python=3.8
$ conda activate astormer
$ pip install torch==1.8.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
$ pip install -r requirements.txt
  2. Next, download third-party dependencies:
$ python -c "import stanza; stanza.download('en')"
$ python -c "import nltk; nltk.download('stopwords'); nltk.download('punkt');"
  3. Download the required pre-trained language models from the Hugging Face Model Hub, such as electra-small-discriminator and chinese-electra-180g-small-discriminator, into the pretrained_models directory (please ensure that Git LFS is installed):
$ mkdir -p pretrained_models && cd pretrained_models
$ git lfs install
$ git clone https://huggingface.co/google/electra-small-discriminator
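If Git LFS is unavailable, the same checkpoint can also be fetched with the huggingface_hub Python package. This is a minimal sketch assuming a recent huggingface_hub is installed (pip install huggingface_hub if needed); it is not the project's documented download path:
$ python -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='google/electra-small-discriminator', local_dir='pretrained_models/electra-small-discriminator')"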

Download and preprocess datasets

  1. Create a new directory data to store all prevalent cross-domain multi-table text-to-SQL data, including Spider, SParC, CoSQL, DuSQL and Chase. Next, download, unzip and rename spider.zip, sparc.zip, cosql_dataset.zip, DuSQL.zip and Chase.zip, as well as their databases (Spider-testsuite-database and Chase-database), into the directory data.
  • Variants of the Spider dev set, e.g., SpiderSyn, SpiderDK and SpiderRealistic, can also be downloaded and included at the evaluation stage.
  • These default paths can be changed by modifying the dict CONFIG_PATHS in nsts/transition_system.py (see the sketch after the directory layout below).
  • The directory data should be organized as follows:
- data/
    - spider/
        - database/ # all databases, one directory for each db_id
        - database-testsuite/ # test-suite databases
        - *.json # datasets or tables, variants of dev set such as dev_syn.json are also downloaded and placed here
    - sparc/
        - database/
        - *.json
    - cosql/
        - database/
        - sql_state_tracking/
            - *.json # train and dev datasets
        - [other directories]/
        - tables.json
    - dusql/
        - *.json
    - chase/
        - database/
        - *.json
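To point the code at a different layout, only the entries in CONFIG_PATHS need to change. The real dictionary is defined in nsts/transition_system.py; the key names and nesting below are hypothetical, shown only to illustrate the kind of edit involved:
# Hypothetical sketch of the data locations; inspect nsts/transition_system.py for the actual keys.
CONFIG_PATHS = {
    'spider': {'data': 'data/spider', 'database': 'data/spider/database'},
    'sparc': {'data': 'data/sparc', 'database': 'data/sparc/database'},
    # ... entries for cosql, dusql and chase follow the same pattern
}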
  2. Dataset preprocessing, including:
  • Merge data/spider/train_spider.json and data/spider/train_others.json into a single dataset data/spider/train.json
  • Dataset and database format transformation for the Chinese benchmarks DuSQL and Chase
  • Fix some annotation errors in SQLs and type errors in database schemas
  • Re-parse each SQL query into a unified JSON format for all benchmarks. We modify and unify the format of the sql field as follows (see nsts/parse_sql_to_json.py for details, and the illustrative sketch after the preprocessing command below):
    • For a single condition, the parsed tuple is changed from (not_op, op_id, val_unit, val1, val2) to (agg_id, op_id, val_unit, val1, val2). The not_op is removed and folded into op_id, e.g., not in and not like
    • For FROM conditions where the value is a column id, the target val1 must be a column unit (agg_id, col_id, isDistinct(bool)) to distinguish it from integer values
    • For the ORDER BY clause, the parsed tuple is changed from ('asc'/'desc', [val_unit1, val_unit2, ...]) to ('asc'/'desc', [(agg_id, val_unit1), (agg_id, val_unit2), ...])
  • Preprocessing each dataset (tokenization, schema linking and value linking) takes less than 10 minutes. We use the PLM tokenizer to tokenize questions and schema items; schema linking is performed at the word level instead of the BPE/subword token level.
$ ./run/run_preprocessing.sh
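To make the unified sql format above concrete, here is a hedged sketch of the condition and ORDER BY changes; the op ids and column ids are illustrative placeholders, not the actual vocabulary used in nsts/parse_sql_to_json.py:
# Sketch only: the ids below are made-up placeholders for illustration.
AGG_NONE, COL_NAME, OP_IN, OP_NOT_IN = 0, 5, 8, 9   # hypothetical ids
val_unit = (0, (AGG_NONE, COL_NAME, False), None)   # (unit_op, col_unit1, col_unit2)

# Spider-style condition for "name NOT IN (...)": (not_op, op_id, val_unit, val1, val2)
old_cond = (True, OP_IN, val_unit, 'subquery', None)
# Unified condition: (agg_id, op_id, val_unit, val1, val2), negation folded into op_id
new_cond = (AGG_NONE, OP_NOT_IN, val_unit, 'subquery', None)

# ORDER BY: ('asc'/'desc', [val_unit, ...]) becomes ('asc'/'desc', [(agg_id, val_unit), ...])
old_order = ('desc', [val_unit])
new_order = ('desc', [(AGG_NONE, val_unit)])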

Training

To train ASTormer with small/base/large series pre-trained language models respectively:

  • dataset can be chosen from ['spider', 'sparc', 'cosql', 'dusql', 'chase']
  • plm is the name of the pre-trained language model under the directory pretrained_models. Please keep it consistent with the choice of PLM in the preprocessing script (run/run_preprocessing.sh).
# swv means utilizing static word embeddings, extracted from small-series models such as electra-small-discriminator
$ ./run/run_train_and_eval_swv.sh [dataset] [plm]
        
# DDP is not needed, a single 2080Ti GPU is enough
$ ./run/run_train_and_eval_small.sh [dataset] [plm]

# if DDP is used, please specify the environment variables below, e.g., one machine with two GPUs
$ GPU_PER_NODE=2 NUM_NODES=1 NODE_RANK=0 MASTER_ADDR="127.0.0.1" MASTER_PORT=23456 ./run/run_train_and_eval_base.sh [dataset] [plm]

# if DDP is used, please specify the environment variables below, e.g., two machines, each with two GPUs
$ GPU_PER_NODE=2 NUM_NODES=2 NODE_RANK=0 MASTER_ADDR=[node0_ip] MASTER_PORT=23456 ./run/run_train_and_eval_large.sh [dataset] [plm]
$ GPU_PER_NODE=2 NUM_NODES=2 NODE_RANK=1 MASTER_ADDR=[node0_ip] MASTER_PORT=23456 ./run/run_train_and_eval_large.sh [dataset] [plm]

Inference and Submission

For inference, see run/run_eval.sh (evaluation on the preprocessed dev dataset) and run/run_eval_from_scratch.sh (SQL prediction only, for test set submission):

  • saved_model_dir is the directory containing the saved arguments (params.json) and model parameters (model.bin)
$ ./run/run_eval.sh [saved_model_dir]

$ ./run/run_eval_from_scratch.sh [saved_model_dir]

For both training and inference, you can also use the prepared Docker environment from rhythmcao/astormer:v0.3:

$ docker pull rhythmcao/astormer:v0.3
$ docker run -it -v $PWD:/workspace rhythmcao/astormer:v0.3 /bin/bash

Acknowledgements

We are grateful to the flexible semantic parser TranX, which inspired our work.
