SFR - Sparse Function-space Representation of Neural Networks

This repository contains a clean and minimal PyTorch implementation of Sparse Function-space Representation (SFR) of Neural Networks. If you'd like to use SFR, we recommend using this repo. Please see sfr-experiments for reproducing the experiments in the ICLR 2024 paper.

Function-space Parameterization of Neural Networks for Sequential Learning
Aidan Scannell*, Riccardo Mereu*, Paul Chang, Ella Tamir, Joni Pajarinen, Arno Solin
International Conference on Learning Representations (ICLR 2024)
Paper Code Website
Sparse Function-space Representation of Neural Networks
Aidan Scannell*, Riccardo Mereu*, Paul Chang, Ella Tamir, Joni Pajarinen, Arno Solin
ICML 2023 Workshop on Duality Principles for Modern Machine Learning
Paper Code Website

Install

CPU

Create an environment with:

conda env create -f env_cpu.yaml

Activate the environment with:

source activate sfr

NVIDIA GPU

Create an environment with:

conda env create -f env_nvidia.yaml

Activate the environment with:

source activate sfr
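
To confirm that PyTorch can see the GPU, a quick sanity check (not part of the repo) is:

python -c "import torch; print(torch.cuda.is_available())"

This should print True on a correctly configured NVIDIA setup.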

Usage

See the notebooks for how to use our code for both regression and classification.

Image Classification

We provide a minimal training script in train.py, which can be used to train a CNN and fit SFR on MNIST/Fashion-MNIST/CIFAR-10. It is advised to run this on a GPU.
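
For classification, SFR is set up the same way as in the regression example below, but with a categorical likelihood and output_dim set to the number of classes. The sketch below is illustrative rather than lifted from train.py: the CNN architecture and the likelihood class name (src.likelihoods.CategoricalLh) are assumptions, so check train.py and src/likelihoods.py for the exact API.

import torch
import src

# Small CNN for 28x28 grayscale inputs (e.g. MNIST); the architecture is illustrative
cnn = torch.nn.Sequential(
    torch.nn.Conv2d(1, 16, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),
    torch.nn.Conv2d(16, 32, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),
    torch.nn.Flatten(),
    torch.nn.Linear(32 * 7 * 7, 10),
)

sfr = src.SFR(
    network=cnn,
    prior=src.priors.Gaussian(params=cnn.parameters, delta=1e-4),
    likelihood=src.likelihoods.CategoricalLh(),  # assumed name; see src/likelihoods.py
    output_dim=10,  # number of classes
    num_inducing=512,
    dual_batch_size=1000,  # batch the dual-parameter computation to save memory
    jitter=1e-4,
)
# Training then mirrors the regression example below: optimise sfr.loss(x, y) over
# mini-batches, call sfr.set_data(...) once trained, and predict with sfr.predict(...).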

Example

Here's a short example:

import src
import torch

torch.set_default_dtype(torch.float64)

def func(x, noise=True):
    f = torch.sin(x * 5) / x + torch.cos(x * 10)
    if noise:
        f += torch.randn_like(f) * 0.2  # noise scale chosen for illustration
    return f

# Toy data set
X_train = torch.rand((100, 1)) * 2
Y_train = func(X_train, noise=True)
data = (X_train, Y_train)

# Training config
width = 64
num_epochs = 1000
batch_size = 16
learning_rate = 1e-3
delta = 0.00005  # prior precision
data_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(*data), batch_size=batch_size
)

# Create a neural network
network = torch.nn.Sequential(
    torch.nn.Linear(1, width),
    torch.nn.Tanh(),
    torch.nn.Linear(width, width),
    torch.nn.Tanh(),
    torch.nn.Linear(width, 1),
)

# Instantiate SFR (handles NN training/prediction as they're coupled via the prior/likelihood)
sfr = src.SFR(
    network=network,
    prior=src.priors.Gaussian(params=network.parameters, delta=delta),
    likelihood=src.likelihoods.Gaussian(sigma_noise=2),
    output_dim=1,
    num_inducing=32,
    dual_batch_size=None,  # set to an int to batch the dual-parameter computation and reduce memory
    jitter=1e-4,
)

sfr.train()
optimizer = torch.optim.Adam([{"params": sfr.parameters()}], lr=learning_rate)
for epoch_idx in range(num_epochs):
    for batch_idx, batch in enumerate(data_loader):
        x, y = batch
        loss = sfr.loss(x, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

sfr.set_data(data) # This builds the dual parameters

# Make predictions in function space
X_test = torch.linspace(-0.7, 3.5, 300, dtype=torch.float64).reshape(-1, 1)
f_mean, f_var = sfr.predict_f(X_test)

# Make predictions in output space
y_mean, y_var = sfr.predict(X_test)
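
To sanity-check the fit, you can plot the predictive mean with a two-standard-deviation band over the test inputs. This is a minimal sketch that continues from the example above; matplotlib is assumed to be installed, as it is not listed as a dependency here.

import matplotlib.pyplot as plt

f_std = f_var.sqrt()
plt.plot(X_train, Y_train, "kx", label="train data")
plt.plot(X_test, f_mean.detach(), label="predictive mean")
plt.fill_between(
    X_test.squeeze(),
    (f_mean - 2 * f_std).detach().squeeze(),
    (f_mean + 2 * f_std).detach().squeeze(),
    alpha=0.3,
    label="mean ± 2 std",
)
plt.legend()
plt.show()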

Development

Set up pre-commit by running:

pre-commit install

Now the formatter, linter, etc. will run automatically when you commit.

Citation

Please consider citing our conference paper:

@inproceedings{scannellFunction2024,
  title           = {Function-space Parameterization of Neural Networks for Sequential Learning},
  booktitle       = {Proceedings of The Twelfth International Conference on Learning Representations (ICLR 2024)},
  author          = {Aidan Scannell and Riccardo Mereu and Paul Chang and Ella Tamir and Joni Pajarinen and Arno Solin},
  year            = {2024},
  month           = {5},
}

Or our workshop paper:

@inproceedings{scannellSparse2023,
  title           = {Sparse Function-space Representation of Neural Networks},
  booktitle       = {ICML 2023 Workshop on Duality Principles for Modern Machine Learning},
  author          = {Aidan Scannell and Riccardo Mereu and Paul Chang and Ella Tamir and Joni Pajarinen and Arno Solin},
  year            = {2023},
  month           = {7},
}