This repository contains the official implementation of TAPE, as described in the paper "Rethinking Addressing in Language Models via Contextualized Equivariant Positional Encoding" by Jiajun Zhu, Peihao Wang, Ruisi Cai, Jason D. Lee, Pan Li, and Zhangyang Wang.
conda create -n adape python=3.10
conda activate adape
pip install -r requirements.txt
pip install flash-attn --no-build-isolation
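Optionally, you can run a quick sanity check (not part of the repository scripts) to confirm that CUDA is visible and that flash-attn imports correctly:
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
python -c "import flash_attn; print(flash_attn.__version__)"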
The arithmetic task is implemented independently under the directory arithmetic/, a separate GitHub repository linked as a submodule. It has its own environment as well; please refer to its README for detailed instructions on training and evaluation.
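If arithmetic/ is empty after cloning, the submodule has likely not been fetched yet. The standard Git commands below initialize it (a minimal sketch, assuming the submodule is registered in .gitmodules):
git submodule update --init --recursive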
The scripts under script/ cover the commands for training. For example, you can start training a TAPE model (adape in the code) with the following command:
export TYPE=adape
bash script/train.sh
You can change CONFIG_NAME to choose different positional encoding variants (choose from those under config/).
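To train a different variant, only the exported variable changes. A minimal sketch (the variant name rope below is hypothetical; use a name that actually has a config under config/):
export TYPE=rope  # hypothetical variant name; must match a config under config/
bash script/train.sh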
There are three steps to get evaluation results:
- finetune pre-trained models on SCROLLS
- generate answers on the validation set
- evaluate the answers with the corresponding metric
export TYPE=adape DATASET_NAME=quality
export METRIC_DIR=scrolls/metrics
export SAVE_DIR=scrolls/quality
bash script/ft_scrolls.sh # assumes the pretrained checkpoint is under output/${TYPE}_c4; if not, set output_name=<your_output_name>
bash script/gen_scrolls.sh
python eval_scrolls.py --split validation --dataset_name $DATASET_NAME --predictions ${SAVE_DIR}/${TYPE}.json --metrics_output_dir $METRIC_DIR
You can change DATASET_NAME to choose a different dataset (choose from ['narrative_qa', 'quality', 'qasper', 'contract_nli']).
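To run the full pipeline over all four datasets, a simple loop over the commands above works. This is a sketch: it assumes ft_scrolls.sh and gen_scrolls.sh read DATASET_NAME and SAVE_DIR from the environment (as above) and that answers for each dataset are saved under scrolls/<dataset>:
export TYPE=adape
export METRIC_DIR=scrolls/metrics
for DATASET_NAME in narrative_qa quality qasper contract_nli; do
  export DATASET_NAME
  export SAVE_DIR=scrolls/${DATASET_NAME}
  bash script/ft_scrolls.sh
  bash script/gen_scrolls.sh
  python eval_scrolls.py --split validation --dataset_name $DATASET_NAME --predictions ${SAVE_DIR}/${TYPE}.json --metrics_output_dir $METRIC_DIR
done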
Similar to training from scratch, you can use the following command and select different methods:
export TYPE=adape
bash script/train_llama.sh
For perplexity evaluation of the finetuned models, you need to manually download the test data hosted by LongLoRA:
| Dataset | Split | Link |
|---|---|---|
| PG19 | test | pg19/test.bin |
| Proof-pile | test | proof-pile/test_sampled_data.bin |
Then you can use the following command:
export data=proof_pile
export model_path=output/llama_adape
bash script/eval_llama.sh
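Evaluating on PG19 follows the same pattern (a sketch; the dataset identifier pg19 is assumed to mirror proof_pile and to match what eval_llama.sh expects):
export data=pg19
export model_path=output/llama_adape
bash script/eval_llama.sh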
We also provide eval_retrieval.py for evaluation on the passkey retrieval task.
python3 eval_retrieval.py --context_size 8192 --base_model output/llama_adape --max_tokens 8192 --interval 1000
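To sweep several context lengths with the same checkpoint, the command can be wrapped in a loop (a sketch reusing only the flags shown above):
for ctx in 2048 4096 8192; do
  python3 eval_retrieval.py --context_size $ctx --base_model output/llama_adape --max_tokens $ctx --interval 1000
done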
The codebase is inherited from BiPE and LongLoRA. Thanks for their excellent work!