PyTorch implementation of DeltaLSTM and Column-Balanced Targeted Dropout

gaochangw/DeltaLSTM-CBTD

DeltaLSTM-CBTD

PyTorch Implementation of DeltaLSTM (./nnmodels/deltalstm.py) and CBTD (./utils/util.py)

Prerequisites

Install Miniconda

wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
chmod +x Miniconda3-latest-Linux-x86_64.sh
./Miniconda3-latest-Linux-x86_64.sh

Create an environment using the following command:

conda create -n pt python=3.8 numpy matplotlib pandas tqdm h5py \
    scipy jupyter scikit-learn tensorboard editdistance \
    pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch

Activate the environment before running the script.

conda activate pt

Install the required pip package:

pip install sphfile
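
Before running the experiments, it can help to confirm that the activated environment actually provides the packages installed above. The following is a minimal sketch, not part of the repository: the helper `find_missing` and the list `REQUIRED_MODULES` are hypothetical names, and the import names (for example `sklearn` for scikit-learn and `torch` for pytorch) are assumed to correspond to the packages in the conda and pip commands.

```python
import importlib.util

# Import names assumed to correspond to the conda/pip packages listed above
# (e.g. scikit-learn is imported as "sklearn", pytorch as "torch").
REQUIRED_MODULES = [
    "numpy", "matplotlib", "pandas", "tqdm", "h5py", "scipy",
    "sklearn", "tensorboard", "editdistance",
    "torch", "torchvision", "torchaudio", "sphfile",
]


def find_missing(modules):
    """Return the module names that cannot be located in the active environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = find_missing(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules were found.")
```

Run it with the `pt` environment activated; any name it reports can be installed with the corresponding conda or pip command.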

Run

Navigate to the project folder and run experiments with main.py, specifying the target dataset and step. The following commands, run in order, pretrain an LSTM network and then retrain it as a DeltaLSTM:

  1. Data Preparation:
python main.py --dataset_name timit --step data_prep
  2. Feature Extraction:
python main.py --dataset_name timit --step feat_ext
  3. Pretrain on LSTM:
python main.py --dataset_name timit --step train --phase pretrain --hid_type LSTM
  4. Test LSTM:
python main.py --dataset_name timit --step test --phase pretrain --hid_type LSTM
  5. Retrain on DeltaLSTM:
python main.py --dataset_name timit --step train --phase retrain --hid_type DeltaLSTM --hid_type_pretrain LSTM
  6. Test DeltaLSTM:
python main.py --dataset_name timit --step test --phase retrain --hid_type DeltaLSTM
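
The steps above can also be chained from a small driver script. This is a sketch under stated assumptions rather than something shipped with the repository: `build_commands` and `run_pipeline` are hypothetical helpers, and the only flags used are the ones shown in the commands above. It assumes the script is started from the project folder so that main.py resolves.

```python
import subprocess
import sys


def build_commands(dataset_name="timit"):
    """Return the six main.py invocations from the steps above, in order."""
    # sys.executable keeps the interpreter of the currently active environment.
    base = [sys.executable, "main.py", "--dataset_name", dataset_name]
    return [
        base + ["--step", "data_prep"],
        base + ["--step", "feat_ext"],
        base + ["--step", "train", "--phase", "pretrain", "--hid_type", "LSTM"],
        base + ["--step", "test", "--phase", "pretrain", "--hid_type", "LSTM"],
        base + ["--step", "train", "--phase", "retrain",
                "--hid_type", "DeltaLSTM", "--hid_type_pretrain", "LSTM"],
        base + ["--step", "test", "--phase", "retrain", "--hid_type", "DeltaLSTM"],
    ]


def run_pipeline(dataset_name="timit", dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd in build_commands(dataset_name):
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops the pipeline at the first failing step.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_pipeline(dry_run=True)
```

With `dry_run=True` the script only prints the commands, which is a cheap way to review them; set `dry_run=False` to execute them. Stopping at the first failure matters here because each later step depends on the output of the earlier ones.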
