bayesiains/nsf

Neural Spline Flows

A record of the code and experiments for the paper:

C. Durkan, A. Bekasov, I. Murray, G. Papamakarios, Neural Spline Flows, NeurIPS 2019. [arXiv] [bibtex]

Work in this repository has now stopped. Please go to nflows for an updated and pip-installable normalizing flows framework for PyTorch.

Dependencies

See environment.yml for the required Conda/pip packages, or use it to create a Conda environment with all dependencies:

conda env create -f environment.yml

Tested with Python 3.5 and PyTorch 1.1.

Data

Data for density-estimation experiments is available at https://zenodo.org/record/1161203#.Wmtf_XVl8eN.

Data for VAE and image-modeling experiments is downloaded automatically using either torchvision or custom data providers.

Usage

The DATAROOT environment variable must be set before running any experiments.
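A minimal sketch of how a script might resolve dataset paths from DATAROOT, shown only to illustrate why the variable must be set first. The helper name `dataset_dir` and the one-subdirectory-per-dataset layout are assumptions, not this repository's actual code. In a POSIX shell the variable is set with, for example, `export DATAROOT=/path/to/data` before launching an experiment.

```python
import os


def dataset_dir(name):
    """Hypothetical helper: return the directory for dataset `name` under $DATAROOT.

    Assumes one subdirectory per dataset; the repository's real layout may differ.
    """
    root = os.environ.get("DATAROOT")
    if not root:
        # Fail early with a clear message instead of a confusing file-not-found error later.
        raise RuntimeError("DATAROOT is not set; export it before running experiments.")
    return os.path.join(root, name)
```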

2D toy density experiments

Use experiments/face.py or experiments/plane.py.

Density-estimation experiments

Use experiments/uci.py.

VAE experiments

Use experiments/vae_.py.

Image-modeling experiments

Use experiments/images.py.

Sacred is used to organize the image experiments; see the Sacred documentation for more information.

experiments/image_configs contains the .json configurations used for the RQ-NSF (C) experiments. For the affine baseline experiments, add coupling_layer_type='affine' to the command.
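RQ-NSF (C) and the affine baseline differ in the elementwise transform each coupling layer applies to the half of the input it modifies. The following is a minimal, pure-Python illustration of both transforms for a single scalar, written from the monotonic rational-quadratic spline formula in the paper rather than taken from this repository. In the real model a neural network conditioned on the untouched half of the input outputs the knot positions and derivatives (or the affine scale and shift); here they are fixed by hand. The function names and the `exp` parametrization of the affine scale are assumptions.

```python
import bisect
import math


def affine(x, log_scale, shift):
    """Affine transform y = x * exp(log_scale) + shift (exp parametrization assumed).

    Returns (y, log|dy/dx|).
    """
    return x * math.exp(log_scale) + shift, log_scale


def rq_spline(x, xs, ys, ds):
    """Monotonic rational-quadratic spline on [xs[0], xs[-1]], identity outside.

    xs, ys: strictly increasing knot coordinates with xs[0] == ys[0] and
    xs[-1] == ys[-1]; ds: positive knot derivatives with ds[0] == ds[-1] == 1,
    so the identity tails join smoothly. Returns (y, log|dy/dx|).
    """
    if x < xs[0] or x > xs[-1]:
        return x, 0.0  # identity tails
    # Locate the bin k with xs[k] <= x <= xs[k + 1].
    k = min(bisect.bisect_right(xs, x) - 1, len(xs) - 2)
    w = xs[k + 1] - xs[k]
    h = ys[k + 1] - ys[k]
    s = h / w             # slope of the bin
    xi = (x - xs[k]) / w  # relative position inside the bin, in [0, 1]
    t = xi * (1 - xi)
    denom = s + (ds[k + 1] + ds[k] - 2 * s) * t
    y = ys[k] + h * (s * xi ** 2 + ds[k] * t) / denom
    dydx = s ** 2 * (ds[k + 1] * xi ** 2 + 2 * s * t + ds[k] * (1 - xi) ** 2) / denom ** 2
    return y, math.log(dydx)


def log_density(x, transform):
    """Change of variables with a standard-normal base: log N(f(x)) + log|f'(x)|."""
    z, logabsdet = transform(x)
    return -0.5 * (z * z + math.log(2 * math.pi)) + logabsdet


# Hand-picked spline parameters (a network would normally produce these).
xs, ys, ds = [-3.0, -1.0, 0.5, 3.0], [-3.0, 0.0, 1.0, 3.0], [1.0, 0.5, 2.0, 1.0]
print(rq_spline(0.5, xs, ys, ds))  # (1.0, 0.693...): passes through knot (0.5, 1.0) with slope ds[2] = 2
```

Both transforms return the log absolute derivative alongside the output, which is exactly what the change-of-variables term in `log_density` needs. The spline stays monotonic, and therefore invertible, as long as the knots increase and the derivatives are positive, while offering far more flexibility per layer than a single scale and shift.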

For example, to run RQ-NSF (C) on CIFAR-10 8-bit:

python experiments/images.py with experiments/image_configs/cifar-10-8bit.json

Corresponding affine baseline run:

python experiments/images.py with experiments/image_configs/cifar-10-8bit.json coupling_layer_type='affine'
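The baseline command reuses the same JSON file and overrides a single entry. Roughly speaking, Sacred loads the file named after `with` and lets `key=value` updates on the command line take precedence, parsing each value as a Python literal when possible and otherwise keeping it as a string. The sketch below mimics that merge in plain Python; it is a simplified approximation, not Sacred's implementation, and the two-key config in it is made up rather than the contents of cifar-10-8bit.json.

```python
import ast
import json


def parse_update(arg):
    """Split a 'key=value' argument; parse value as a Python literal, else keep the string."""
    key, _, raw = arg.partition("=")
    try:
        value = ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        value = raw  # e.g. affine, after the shell has removed the quotes
    return key, value


def merge_config(json_text, updates):
    """Hypothetical helper: load the JSON config, then apply key=value updates on top."""
    config = json.loads(json_text)
    for arg in updates:
        key, value = parse_update(arg)
        config[key] = value
    return config


# Made-up config for illustration; not the real cifar-10-8bit.json.
base = '{"coupling_layer_type": "spline", "learning_rate": 0.0005}'
print(merge_config(base, ["coupling_layer_type=affine"]))
# {'coupling_layer_type': 'affine', 'learning_rate': 0.0005}
```

Because the shell strips the quotes in coupling_layer_type='affine', the script receives coupling_layer_type=affine; `affine` is not a valid literal, so it falls back to the string "affine" either way.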

To evaluate on the test set:

python experiments/images.py eval_on_test with experiments/image_configs/cifar-10-8bit.json flow_checkpoint='<saved_checkpoint>'

To sample:

python experiments/images.py sample with experiments/image_configs/cifar-10-8bit.json flow_checkpoint='<saved_checkpoint>'
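In the eval and sample commands, the single quotes around the checkpoint stop the shell from reading `<` and `>` as redirections, and the shell removes them before the script sees the argument; `<saved_checkpoint>` itself is a placeholder for the path of a checkpoint saved during training. When launching runs from Python, a list-based argv sidesteps shell quoting altogether. A small sketch using the standard library's `shlex`; the `build_command` helper and the checkpoint path are hypothetical.

```python
import shlex


def build_command(subcommand, config_path, checkpoint):
    """Hypothetical helper: argv for `images.py <subcommand> with <config> flow_checkpoint=...`."""
    return [
        "python", "experiments/images.py", subcommand,
        "with", config_path, "flow_checkpoint=" + checkpoint,
    ]


config = "experiments/image_configs/cifar-10-8bit.json"
argv = build_command("sample", config, "runs/example/checkpoint.pt")  # hypothetical path

# shlex.split applies POSIX-shell quote removal, so the quoted placeholder
# reaches the script without its quotes.
print(shlex.split("flow_checkpoint='<saved_checkpoint>'"))  # ['flow_checkpoint=<saved_checkpoint>']

# Quote each argument so the command can be pasted back into a shell.
print(" ".join(shlex.quote(a) for a in argv))
```

Passing `argv` directly to `subprocess.run(argv, check=True)` would start the run without any shell involved, so no quoting is needed in that case.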