This repository contains the source code for the paper ASTormer: An AST Structure-aware Transformer Decoder for Text-to-SQL. If you find it useful, please cite our work:
```bibtex
@misc{cao2023astormer,
      title={ASTormer: An AST Structure-aware Transformer Decoder for Text-to-SQL},
      author={Ruisheng Cao and Hanchong Zhang and Hongshen Xu and Jieyu Li and Da Ma and Lu Chen and Kai Yu},
      year={2023},
      eprint={2310.18662},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```
Note that this work focuses on leveraging small-sized pre-trained models and labeled training data to train a specialized, interpretable and efficient local text-to-SQL parser in low-resource scenarios, instead of chasing SOTA performance. For better results, please try LLMs with in-context learning (such as DINSQL and ACTSQL), or resort to larger encoder-decoder architectures with billions of parameters (such as Picard-3B and RESDSQL-3B). Due to a shift in the author's research focus in the LLM era, this project will no longer be maintained.
The following commands are also provided in `setup.sh`.
- First, create the conda environment `astormer`:

```sh
$ conda create -n astormer python=3.8
$ conda activate astormer
$ pip install torch==1.8.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
$ pip install -r requirements.txt
```
- Next, download third-party dependencies:

```sh
$ python -c "import stanza; stanza.download('en')"
$ python -c "import nltk; nltk.download('stopwords'); nltk.download('punkt')"
```
- Download the required pre-trained language models from the Hugging Face Model Hub, such as `electra-small-discriminator` and `chinese-electra-180g-small-discriminator`, into the `pretrained_models` directory (please ensure that Git LFS is installed):

```sh
$ mkdir -p pretrained_models && cd pretrained_models
$ git lfs install
$ git clone https://huggingface.co/google/electra-small-discriminator
```
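Once cloned, the model directory can be loaded offline. The snippet below is a minimal sketch using the standard Hugging Face `transformers` API; it is not code from this repository:

```python
# Minimal sketch (not the repo's code): load the locally cloned PLM with the
# standard Hugging Face transformers API, using the directory layout created above.
from transformers import AutoModel, AutoTokenizer

plm_path = "pretrained_models/electra-small-discriminator"   # local clone, no network access needed
tokenizer = AutoTokenizer.from_pretrained(plm_path)
encoder = AutoModel.from_pretrained(plm_path)

inputs = tokenizer("How many singers do we have ?", return_tensors="pt")
outputs = encoder(**inputs)
print(outputs.last_hidden_state.shape)   # (1, num_subword_tokens, hidden_size)
```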
- Create a new directory `data` to store all prevalent cross-domain multi-table text-to-SQL data, including Spider, SParC, CoSQL, DuSQL and Chase. Next, download, unzip and rename spider.zip, sparc.zip, cosql_dataset.zip, DuSQL.zip and Chase.zip, as well as their databases (Spider-testsuite-database and Chase-database), into the directory `data`.
- Variants of the Spider dev set, e.g., SpiderSyn, SpiderDK and SpiderRealistic, can also be downloaded and included at the evaluation stage.
- These default paths can be changed by modifying the dict `CONFIG_PATHS` in `nsts/transition_system.py` (a hypothetical sketch of this mapping follows the directory layout below).
- The directory `data` should be organized as follows:
```
data/
├── spider/
│   ├── database/              # all databases, one directory for each db_id
│   ├── database-testsuite/    # test-suite databases
│   └── *.json                 # datasets or tables; variants of the dev set such as dev_syn.json are also downloaded and placed here
├── sparc/
│   ├── database/
│   └── *.json
├── cosql/
│   ├── database/
│   ├── sql_state_tracking/
│   │   └── *.json             # train and dev datasets
│   ├── [other directories]/
│   └── tables.json
├── dusql/
│   └── *.json
└── chase/
    ├── database/
    └── *.json
```
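As a purely hypothetical illustration of the kind of dataset-to-path mapping such a dict provides (the real `CONFIG_PATHS` in `nsts/transition_system.py` may use different keys and structure):

```python
# Hypothetical sketch only: the actual CONFIG_PATHS and its keys are defined in
# nsts/transition_system.py and may be organized differently.
CONFIG_PATHS = {
    'spider': {
        'train': 'data/spider/train.json',
        'dev': 'data/spider/dev.json',
        'tables': 'data/spider/tables.json',
        'db_dir': 'data/spider/database',
    },
    # ... analogous entries for sparc, cosql, dusql and chase
}
```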
- Dataset preprocessing, including:
  - Merge `data/spider/train_spider.json` and `data/spider/train_others.json` into one single dataset `data/spider/train.json`
  - Dataset and database format transformation for the Chinese benchmarks DuSQL and Chase
  - Fix some annotation errors in SQLs and type errors in database schemas
  - Re-parse the SQL query into a unified JSON format for all benchmarks. We modify and unify the format of the `sql` field (see `nsts/parse_sql_to_json.py` for details; a hedged sketch of these tuple shapes follows this list), including:
    - For a single condition, the parsed tuple is changed from `(not_op, op_id, val_unit, val1, val2)` into `(agg_id, op_id, val_unit, val1, val2)`. The `not_op` is removed and integrated into `op_id`, such as `not in` and `not like`.
    - For FROM conditions where the value is a column id, the target `val1` must be a column list `(agg_id, col_id, isDistinct(bool))` to distinguish it from integer values.
    - For the ORDER BY clause, the parsed tuple is changed from `('asc'/'desc', [val_unit1, val_unit2, ...])` to `('asc'/'desc', [(agg_id, val_unit1), (agg_id, val_unit2), ...])`.
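To make these format changes concrete, here is a small, hand-constructed sketch of the tuple shapes; every index and value below is a placeholder, and the real enumerations are defined in `nsts/parse_sql_to_json.py`:

```python
# Illustrative, hand-constructed examples of the tuple shapes described above; all
# concrete indices/values are placeholders, not taken from the actual parser.
agg_id, op_id, col_id = 0, 1, 2                 # placeholder indices into the parser's agg/op/column tables
col_unit = (agg_id, col_id, False)              # (agg_id, col_id, isDistinct), as in the original Spider format
val_unit = (0, col_unit, None)                  # (unit_op, col_unit1, col_unit2)

# single WHERE/HAVING condition: the standalone not_op flag is dropped,
# negation is folded into op_id (e.g., "not in", "not like")
old_condition = (True, op_id, val_unit, "value", None)    # (not_op, op_id, val_unit, val1, val2)
new_condition = (agg_id, op_id, val_unit, "value", None)  # (agg_id, op_id, val_unit, val1, val2)

# FROM condition whose value refers to a column: val1 is a column triple, not a bare integer
from_condition = (agg_id, op_id, val_unit, (0, col_id, False), None)

# ORDER BY clause: each val_unit now carries its own agg_id
old_order_by = ('desc', [val_unit])
new_order_by = ('desc', [(agg_id, val_unit)])

print(new_condition)
print(new_order_by)
```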
- It takes less than 10 minutes to preprocess each dataset (tokenization, schema linking and value linking). We use the PLM tokenizer to tokenize questions and schema items; schema linking is performed at the word level instead of the BPE/subword token level (see the toy sketch after the command below).
```sh
$ ./run/run_preprocessing.sh
```
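As a rough illustration of what word-level (rather than subword-level) schema linking means, the toy snippet below matches question words against schema item names. It is a self-contained sketch with made-up examples, not the repository's preprocessing code:

```python
# Toy sketch (not the repo's code) of word-level schema linking: question words are
# matched against schema item names word by word, independently of how the PLM
# tokenizer later splits both sides into BPE/subword pieces.
question = "How many concerts are there in stadium 1 ?"
schema_items = ["concert", "stadium", "stadium id", "singer in concert"]
stopwords = {"how", "many", "are", "there", "in", "the", "?"}   # tiny stand-in for nltk stopwords

links = []
for i, word in enumerate(w.lower() for w in question.split()):
    if word in stopwords:
        continue
    for item in schema_items:
        item_words = item.lower().split()
        if word in item_words:                                   # exact word match
            links.append((i, item, "exact"))
        elif any(word in w or w in word for w in item_words):    # partial word match
            links.append((i, item, "partial"))

print(links)   # e.g., the question word "concerts" partially matches the table name "concert"
```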
To train ASTormer with small/base/large series pre-trained language models, respectively:
- `dataset` can be chosen from `['spider', 'sparc', 'cosql', 'dusql', 'chase']`
- `plm` is the name of the pre-trained language model under the directory `pretrained_models`. Please conform to the choice of PLMs in the preprocessing script (`run/run_preprocessing.sh`).
```sh
# swv means utilizing static word embeddings, extracted from small-series models such as electra-small-discriminator
$ ./run/run_train_and_eval_swv.sh [dataset] [plm]

# DDP is not needed, a single 2080Ti GPU is enough
$ ./run/run_train_and_eval_small.sh [dataset] [plm]

# if DDP is used, please specify the environment variables below, e.g., one machine with two GPUs
$ GPU_PER_NODE=2 NUM_NODES=1 NODE_RANK=0 MASTER_ADDR="127.0.0.1" MASTER_PORT=23456 ./run/run_train_and_eval_base.sh [dataset] [plm]

# if DDP is used, please specify the environment variables below, e.g., two machines each with two GPUs
$ GPU_PER_NODE=2 NUM_NODES=2 NODE_RANK=0 MASTER_ADDR=[node0_ip] MASTER_PORT=23456 ./run/run_train_and_eval_large.sh [dataset] [plm]
$ GPU_PER_NODE=2 NUM_NODES=2 NODE_RANK=1 MASTER_ADDR=[node0_ip] MASTER_PORT=23456 ./run/run_train_and_eval_large.sh [dataset] [plm]
```
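For reference, environment variables like these typically feed the standard PyTorch distributed initialization. The snippet below is a generic sketch of that pattern, not necessarily how this repository's launch scripts wire it up:

```python
# Generic sketch of the standard PyTorch DDP setup that variables such as
# GPU_PER_NODE / NUM_NODES / NODE_RANK / MASTER_ADDR / MASTER_PORT usually drive;
# the repo's scripts may organize this differently.
import os
import torch
import torch.distributed as dist

gpu_per_node = int(os.environ.get("GPU_PER_NODE", 1))
num_nodes = int(os.environ.get("NUM_NODES", 1))
node_rank = int(os.environ.get("NODE_RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))   # set per process by the launcher

world_size = gpu_per_node * num_nodes
rank = node_rank * gpu_per_node + local_rank

dist.init_process_group(
    backend="nccl",
    init_method=f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}",
    world_size=world_size,
    rank=rank,
)
torch.cuda.set_device(local_rank)
```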
For inference, see `run/run_eval.sh` (evaluation on the preprocessed dev dataset) and `run/run_eval_from_scratch.sh` (only SQL prediction, for test-set submission):
- `saved_model_dir` is the directory containing the saved arguments (`params.json`) and model parameters (`model.bin`)

```sh
$ ./run/run_eval.sh [saved_model_dir]
$ ./run/run_eval_from_scratch.sh [saved_model_dir]
```
For both training and inference, you can also use the prepared Docker environment `rhythmcao/astormer:v0.3`:

```sh
$ docker pull rhythmcao/astormer:v0.3
$ docker run -it -v $PWD:/workspace rhythmcao/astormer:v0.3 /bin/bash
```
We are grateful to the flexible semantic parser TranX, which inspired our work.