20240820 Update: HGRN2 has been accepted by COLM 2024.
Official implementation of HGRN2: Gated Linear RNNs with State Expansion
This repo does not contain the model code itself, but only scripts and instructions on how to reproduce the results of the paper. The overall directory structure is as follows:
We list the main experimental results in the table below; for the complete experimental results, please refer to the paper.
The overall network architecture is as follows:
The modification in HGRN2 is very simple: compared to HGRN1, the element-wise hidden state is expanded into a matrix-valued state, and the recurrence becomes an outer-product update (sketched below).
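A sketch of the expanded recurrence, with notation assumed here (forget gate $\mathbf{f}_t$, input $\mathbf{i}_t$, output gate $\mathbf{g}_t$, expanded state $\mathbf{S}_t$); see the paper for the exact formulation:

$$
\mathbf{S}_t = \mathrm{Diag}(\mathbf{f}_t)\,\mathbf{S}_{t-1} + (1-\mathbf{f}_t)\,\mathbf{i}_t^{\top}, \qquad \mathbf{o}_t = \mathbf{g}_t^{\top}\mathbf{S}_t,
$$

so the scalar-valued state of HGRN1 ($h_t = f_t \odot h_{t-1} + (1-f_t)\odot i_t$) becomes the matrix-valued state $\mathbf{S}_t$.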
Key insights:
- Expanding the memory (state) size is quite important.
- The outer product is a parameter-efficient way to expand the state.
- This transitions the model from a linear RNN to linear attention: the output gate plays the role of Q, (1 - forget gate) plays the role of K, and the input state plays the role of V (a minimal PyTorch sketch follows this list).
- No extra parameters are needed to represent the forget gate, unlike GLA/Mamba.
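A minimal, non-fused PyTorch sketch of this view (function and tensor names here are illustrative, not taken from the released kernels):

```python
import torch

def hgrn2_recurrence(g, f, i):
    """Naive sketch of the HGRN2-style outer-product recurrence.

    g: output gate, shape (T, d)  -> plays the role of Q
    f: forget gate, shape (T, d)  -> (1 - f) plays the role of K
    i: input,       shape (T, e)  -> plays the role of V
    Returns outputs of shape (T, e).
    """
    T, d = g.shape
    e = i.shape[-1]
    S = torch.zeros(d, e)  # expanded (matrix-valued) state
    outs = []
    for t in range(T):
        # decay the state with the forget gate, then add the outer-product update
        S = f[t].unsqueeze(-1) * S + torch.outer(1 - f[t], i[t])
        # read the state out with the output gate
        outs.append(g[t] @ S)
    return torch.stack(outs)
```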
See hgru2-pytorch. In order to reproduce the experimental results, please use the reproduce branch! The other implementations come from fla; thanks to yzhangcs for the implementation.
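For example, to fetch that branch before installing (these commands simply combine the repo URL used in the setup steps below with the branch name mentioned above):
git clone https://github.com/Doraemonzzz/hgru2-pytorch
cd hgru2-pytorch
git checkout reproduce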
Our experiments use several conda environments.
First build the conda environment based on the yaml file:
conda env create --file lra.yaml
Then, install hgru2-pytorch:
conda activate lra
git clone https://github.com/Doraemonzzz/hgru2-pytorch
cd hgru2-pytorch
pip install .
Build the conda environment based on the yaml file:
conda env create --file im.yaml
Then, install hgru2-pytorch:
conda activate im
git clone https://github.com/Doraemonzzz/hgru2-pytorch
cd hgru2-pytorch
pip install .
For the wikitext-103 experiment, the main version dependencies are:
torch==2.0.1
triton==2.0.0
triton-nightly==2.1.0.dev20230728172942
After setting up the basic environment, you need to use our version of fairseq:
git clone https://github.com/OpenNLPLab/fairseq-evo.git
cd fairseq-evo
pip install -e .
For the MQAR experiment, the main version dependencies are:
torch==2.0.1
triton==2.1.0
After setting up the basic environment, you also need fla:
git clone https://github.com/sustcsonglin/flash-linear-attention
cd flash-linear-attention
pip install -e .
First download the wikitext-103 dataset:
git clone https://huggingface.co/datasets/OpenNLPLab/wikitext-103
Use the following command to train the language model:
bash script_lm.sh arch num_gpus data_dir
where `arch` is chosen from:
- `hgrn2_lm_expand2`
- `hgrn2_lm_outproduct_1`
- `hgrn2_lm_outproduct_2`
- `hgrn2_lm_outproduct_4`
- `hgrn2_lm_outproduct_8`
- `hgrn2_lm_outproduct_16`
- `hgrn2_lm_outproduct_32`
- `hgrn2_lm_outproduct_64`
- `hgrn2_lm_outproduct_128`
`num_gpus` is the number of GPUs, and `data_dir` is the path to the wikitext-103 dataset.
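For example, to train the `hgrn2_lm_outproduct_128` architecture on 8 GPUs with the dataset cloned into the current directory (the GPU count and path here are assumptions, adjust them to your setup):
bash script_lm.sh hgrn2_lm_outproduct_128 8 wikitext-103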
First clone the following codebase:
git clone https://github.com/OpenNLPLab/im.git
Then change `PROG` and `DATA` in `script_im.sh`, and finally run the following script:
python run_im.py
Download the raw data:
wget https://storage.googleapis.com/long-range-arena/lra_release.gz
mv lra_release.gz lra_release.tar.gz
tar -xvf lra_release.tar.gz
Or download the preprocessed data:
git clone https://huggingface.co/datasets/OpenNLPLab/lra
Clone the following repo:
git clone https://github.com/OpenNLPLab/lra.git
cd lra
git checkout release_torch2
Change the `DATA_PATH` and `program_path` in `script_lra_others.sh` and `script_lra_image`.
Use the following script to run the experiments:
python run_lra.py
First change the `code_dir` and `cache_dir` in `script_mqar.sh`, then run the following script:
bash script_mqar.sh
If you find our repository or paper valuable, please cite it using the following BibTeX.
@misc{2404.07904,
Author = {Zhen Qin and Songlin Yang and Weixuan Sun and Xuyang Shen and Dong Li and Weigao Sun and Yiran Zhong},
Title = {HGRN2: Gated Linear RNNs with State Expansion},
Year = {2024},
Eprint = {arXiv:2404.07904},
}