Code for the experiments of the paper V. Kouni, G. Paraskevopoulos, H. Rauhut and G. C. Alexandropoulos, "ADMM-DAD Net: A Deep Unfolding Network for Analysis Compressed Sensing," ICASSP 2022 - 2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), Singapore, Singapore, 2022, pp. 1506-1510, doi: 10.1109/ICASSP43922.2022.9747096
.
The repository contains three main scripts for ADMM-DAD: admm_mnist.py
, admm_cifar.py
, admm_speech.py
.
The first two scripts train the model on the MNIST and CIFAR10 datasets respectively and create a folder named results_admm_dataset
(dataset=MNIST or CIFAR10
) containing examples of an original and reconstructed image from the corresponding dataset.
The admm_speech.py
a) trains the model on the SpeechCommands and TIMIT datasets, depending on the choice of the user b) creates a folder named results_speech_admm
containing examples of original and reconstructed raw speech samples from the corresponding dataset c) produces a checkpoint of the trained model. With this checkpoint, the user can run admm_extract_test_spectrograms.py
to extract spectrograms of an example test raw audio file of TIMIT.
Some additional scripts in the repository are measure_speech.py
, measure_clean.py
and evaluate_robustness.py
.
measure_speech.py
implements the preprocessing on the raw speech samples of the whole SpeechCommands and TIMIT datasets, depending on the choice of the user.
measure_clean.py
creates noiseless measurements of the raw speech samples of the whole SpeechCommands and TIMIT datasets, depending on the choice of the user.
evaluate_robustness.py
takes checkpoints of ADMM-DAD net and ISTA-net and plots the desired robustness graph.
Run
python admm_mnist.py --measurement-factor s --lamda 1e-4 --rho 1 --layers 5 --redundancy 5 --normalization NORMALIZE --learning-rate 1e-4
to train the model with MNIST (similarly with CIFAR10). s
is a CS ratio in {0.25, 0.40, 0.50} and NORMALIZE
stands for the type of desired normalization to be applied on the measurement matrix A (None for A, sqrt_m for A/sqrt(num_measurements), orth for AA^T=I). For the corresponding paper, the chosen normalization is sqrt_m.
Set DOWNLOAD_DATA=True
in measure_speech.py
or create a folder data
in your working directory, then extract speech_commands_v0.02.tar.gz
into data/SpeechCommands/speech_commands_v0.02
.
Run
python measure_speech.py --dataset speechcommands --measurement-factor s --ambient-dim 800 --sample-rate 8000 --normalization NORMALIZE
with s and NORMALIZE defined as previously.
This will create a folder in data/speechcommands_s_800_8000_NORMALIZE
.
The contents of this folder are:
├── measurement_matrix.p # Pickle containing Measurement matrix A
├── test.lmdb # Test data lmdb file
├── test.lmdb-lock
├── train.lmdb # Train data lmdb file
├── train.lmdb-lock
├── val.lmdb # Validation data lmdb file
└── val.lmdb-lock
Run
python admm_speech.py --input-folder data/speechcommands_s_800_8000_NORMALIZE --ambient-dim 800 --measurement-factor s --lamda 1e-4 --rho 1 --layers 5 --redundancy 5 --lr 1e-5
to train the model.
Get the tarball from
https://drive.google.com/file/d/1Co7I_sWqQFVl0t39fXnBnAmZhV4E1tcd/view?usp=sharing
Then move timit.tgz into the data
folder and run
tar xvf timit.tgz
Run
python measure_speech.py --dataset timit --ambient-dim ...
to segment and measure data. It will create a folder data/timit_{NUM_MEASUREMENTS}_{AMBIENT_DIM}_{SAMPLE_RATE}
.
python admm_speech.py --input_folder data/timit_200_800_8000 ...
to train the model.
Train a model with admm_speech.py
to save a checkpoint and then run
python admm_extract_test_spectrograms.py --dataset timit --input-folder data/timit_400_800_8000_sqrt_m/ --measurement-factor 0.5 --sample-rate 8000 --ambient-dim 800 --ckpt chechpoint_name.pt --output-folder my_output_folder
First, run the measure_clean.py
script as
python measure_clean.py --dataset timit --input-folder data/timit_s_800_8000_NORMALIZE
where NORMALIZE is the sqrt_m factor (see "How to run MNIST/CIFAR10" section of README) and s the number of measurements (200 or 320, for 25% or 40% CS ratio, respectively).
For 40% CS ratio (respectively for 25%), run
python evaluate_robustness.py --dataset timit --ista-input-folder data/timit_320_800_8000_ISTA/ --admm-input-folder data/timit_320_800_8000_sqrt_m/ --ista-ckpt ${HOME}/checkpoints/timit_dista/deepISTA-timit-L10-0.001-320.pt --admm-ckpt ${HOME}/checkpoints/timit_admm/admm-timit-l10-mfactor0.4-lr1e-05-rho1.0.pt --output-folder timit_robustness_specs_320_10-10L --measurement-factor 0.4 --sample-rate 8000 --ambient-dim 800 --admm-layers 10 --ista-layers 10
where deepISTA
is the script implementing ISTA-net.
The scripts implementing the baseline ISTA-net were provided to us by original authors of "Compressive Sensing and Neural Networks from a Statistical Learning Perspective", A. Behboodi, H. Rauhut, E. Schnoor, arXiv preprint, arXiv:2010.15658 (2020). For reproducibility purposes, the interested reader may contact them.
If you use the code of this repository, please cite our paper. G. Paraskevopoulos and V. Kouni contributed to write the code of this repository.