Generative Replay #931

Merged 58 commits on Apr 8, 2022
Commits
0dc03fa
Add templates, plugins and models for GR using VAE.
travela Feb 26, 2022
945777f
Update __init__.py and imports.
travela Feb 26, 2022
4d6f498
PEP8 formatting.
travela Mar 2, 2022
82d5dc8
Incorporate plugins and generator strategy
travela Mar 2, 2022
2d63e89
Add documentation to VAE model.
travela Mar 3, 2022
2cef28a
Introduce latent dimension variable for VAE encoder.
travela Mar 3, 2022
8fb9575
Fix from last commit.
travela Mar 3, 2022
6ccff43
Documentation;
travela Mar 3, 2022
bf2c20e
Fix introduced bug.
travela Mar 3, 2022
bc8d7ee
Add boolean to VAETraining call.
travela Mar 3, 2022
baf80e4
Fix 2.0
travela Mar 3, 2022
7ba6f96
Try to move the GenerativeReplayPlugin call outside of the VAETrainin…
travela Mar 3, 2022
5c9b523
Remove redundant code.
travela Mar 3, 2022
5d4e91b
Merge pull request #1 from travela/refactor
travela Mar 3, 2022
8473084
Document all GR plugins.
travela Mar 5, 2022
2d677bd
Module header.
travela Mar 5, 2022
660d6ba
Merge branch 'refactor' into generative_replay
travela Mar 5, 2022
0b96f38
Make VAE more general: any input shape is allowed.
travela Mar 6, 2022
1af464e
Removing reliance on VAEPlugin.
travela Mar 6, 2022
5edd13c
Set default evaluator for VAE to None.
travela Mar 6, 2022
12bcc4e
Add interactive logger to VAETraining.
travela Mar 6, 2022
9de2963
Bug fix: Generator doesn't have to label its replay data; Remove VAEP…
travela Mar 6, 2022
5bec8a9
Doc.
travela Mar 6, 2022
ff87dbb
Change CI workflow to run unittest for generative_replay branch.
travela Mar 8, 2022
171bd39
Create splitMNIST example.
travela Mar 8, 2022
e2eddf3
Create VAE on MNIST example for GenerativeReplayPlugin.
travela Mar 8, 2022
b383954
Change VAE loss function
travela Mar 8, 2022
387e711
detach() samples.
travela Mar 8, 2022
e73d3be
save plot of generated samples and try to open window.
travela Mar 9, 2022
7265e23
Lower number of exp for testing
travela Mar 9, 2022
fe892f0
train 3 exp.
travela Mar 9, 2022
bb1179d
Fix bug
travela Mar 9, 2022
450cc8c
Save all plots in a single file.
travela Mar 9, 2022
279d56e
Change TrainGeneratorAfterExpPlugin name; use current_experience to d…
travela Mar 9, 2022
3d874ad
[General] VAE model exports; try to remove device
travela Mar 9, 2022
7aa3565
Update generator.py
travela Mar 9, 2022
4331d82
Pass device to VAE.
travela Mar 9, 2022
bc8c217
Merge branch 'remove_device' of github.com:travela/avalanche into rem…
travela Mar 9, 2022
ca418ea
Pass device in strategy to VAE
travela Mar 9, 2022
6d81799
[General] Docstring
travela Mar 9, 2022
3982c69
Merge pull request #2 from travela/remove_device
travela Mar 9, 2022
3d66711
Reverse unit-test.yml changes.
travela Mar 9, 2022
e910cf1
Remove commented line
travela Mar 9, 2022
bba78b9
Resolve Requested Changes (#3)
travela Mar 23, 2022
c9f5a23
Move general modules from VAE to utils.
travela Mar 27, 2022
d9c1f1c
Add condition as an input to the abstract generator class.
travela Mar 27, 2022
3f096b1
Clarify confusing generator generator_strategy naming.
travela Mar 27, 2022
03150d1
Update documentation of the GenerativeReplayPlugin
travela Mar 27, 2022
4f5246f
before_training doc string
travela Mar 27, 2022
746e4e3
Renaming of VAE models.
travela Mar 27, 2022
d36e541
Remove TrainGeneratorAfterExpPlugin plugins indexing.
travela Mar 27, 2022
fddf3cc
Merge pull request #5 from travela/requested_changes
travela Mar 27, 2022
99427ac
Make increasing replay batch size optional.
travela Apr 1, 2022
1a8becf
Pass new arguments from strategy to plugin.
travela Apr 1, 2022
a8fb34f
Merge pull request #6 from travela/optional_enhancements
travela Apr 1, 2022
e0a4029
Merge branch 'master' into generative_replay
travela Apr 1, 2022
91504b0
Merge branch 'generative_replay' of github.com:travela/avalanche into…
travela Apr 1, 2022
dfa1d69
update multihead test
AntonioCarta Apr 8, 2022
1 change: 1 addition & 0 deletions avalanche/models/__init__.py
@@ -19,3 +19,4 @@
from .base_model import BaseModel
from .helper_method import as_multitask
from .pnn import PNN
from .generator import *
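
With the wildcard import above, every symbol listed in generator.py's __all__ (see the next file) becomes importable directly from the models package. A minimal illustration, assuming this branch is installed:

# MlpVAE and VAE_loss are re-exported through avalanche.models
from avalanche.models import MlpVAE, VAE_loss
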
193 changes: 193 additions & 0 deletions avalanche/models/generator.py
@@ -0,0 +1,193 @@
################################################################################
# Copyright (c) 2021 ContinualAI. #
# Copyrights licensed under the MIT License. #
# See the accompanying LICENSE file for terms. #
# #
# Date: 03-03-2022 #
# Author: Florian Mies #
# Website: https://github.com/travela #
################################################################################

"""

File to place any kind of generative models
and their respective helper functions.

"""

from abc import abstractmethod
import torch
import torch.nn as nn
from torchvision import transforms
from avalanche.models.utils import MLP, Flatten
from avalanche.models.base_model import BaseModel


class Generator(BaseModel):
"""
A base abstract class for generators
"""

@abstractmethod
def generate(self, batch_size=None, condition=None):
"""
Lets the generator produce random samples.
The output is either a single sample or, if provided,
a batch of samples of size "batch_size".

:param batch_size: Number of samples to generate.
:param condition: Optional condition for a conditional generator
(e.g. a class label).
"""


###########################
# VARIATIONAL AUTOENCODER #
###########################


class VAEMLPEncoder(nn.Module):
'''
Encoder part of the VAE; computes the latent representation of the input.

:param shape: Shape of the input to the network: (channels, height, width)
:param latent_dim: Dimension of last hidden layer
'''

def __init__(self, shape, latent_dim=128):
super(VAEMLPEncoder, self).__init__()
flattened_size = torch.Size(shape).numel()
self.encode = nn.Sequential(
Flatten(),
nn.Linear(in_features=flattened_size, out_features=400),
nn.BatchNorm1d(400),
nn.LeakyReLU(),
MLP([400, latent_dim])
)

def forward(self, x, y=None):
x = self.encode(x)
return x


class VAEMLPDecoder(nn.Module):
'''
Decoder part of the VAE. Mirrors the encoder.

:param shape: Shape of the output: (channels, height, width).
:param nhid: Dimension of the latent input to the decoder.
'''

def __init__(self, shape, nhid=16):
super(VAEMLPDecoder, self).__init__()
flattened_size = torch.Size(shape).numel()
self.shape = shape
self.decode = nn.Sequential(
MLP([nhid, 64, 128, 256, flattened_size], last_activation=False),
nn.Sigmoid())
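# Note: this applies the canonical MNIST normalization statistics
# (mean 0.1307, std 0.3081) to the decoder's [0, 1] output, presumably
# so that generated samples match the normalized input space.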
self.invTrans = transforms.Compose([
transforms.Normalize((0.1307,), (0.3081,))
])

def forward(self, z, y=None):
if (y is None):
return self.invTrans(self.decode(z).view(-1, *self.shape))
else:
return self.invTrans(self.decode(torch.cat((z, y), dim=1))
.view(-1, *self.shape))


class MlpVAE(Generator, nn.Module):
'''
Variational autoencoder module:
fully-connected and suited for any input shape and type.

The encoder only computes the latent representations,
and we then have two possible output heads:
one for the usual output distribution and one for classification.
The latter is an extension of the conventional VAE and incorporates
a classifier into the network.
More details can be found in: https://arxiv.org/abs/1809.10635
'''

def __init__(self, shape, nhid=16, n_classes=10, device="cpu"):
"""
:param shape: Shape of each input sample
:param nhid: Dimension of latent space of Encoder.
:param n_classes: Number of classes;
defines the classification head's dimension.
:param device: Device used for sampling latent noise (default: "cpu").
"""
super(MlpVAE, self).__init__()
self.dim = nhid
self.device = device
self.encoder = VAEMLPEncoder(shape, latent_dim=128)
self.calc_mean = MLP([128, nhid], last_activation=False)
self.calc_logvar = MLP([128, nhid], last_activation=False)
self.classification = MLP([128, n_classes], last_activation=False)
self.decoder = VAEMLPDecoder(shape, nhid)

def get_features(self, x):
"""
Return the encoder's features for input x.
"""
return self.encoder(x)

def generate(self, batch_size=None):
"""
Generate random samples.
Output is either a single sample if batch_size=None,
else it is a batch of samples of size "batch_size".
"""
z = torch.randn((batch_size, self.dim)).to(
self.device) if batch_size else torch.randn((1, self.dim)).to(
self.device)
res = self.decoder(z)
if not batch_size:
res = res.squeeze(0)
return res

def sampling(self, mean, logvar):
"""
VAE 'reparametrization trick': z = mean + sigma * eps with eps ~ N(0, I).
"""
eps = torch.randn(mean.shape).to(self.device)
# logvar is the log-variance, so the standard deviation is exp(logvar / 2)
sigma = torch.exp(0.5 * logvar)
return mean + eps * sigma

def forward(self, x):
"""
Forward pass: encode the input, sample a latent code and decode it.
"""
representations = self.encoder(x)
mean, logvar = self.calc_mean(
representations), self.calc_logvar(representations)
z = self.sampling(mean, logvar)
return self.decoder(z), mean, logvar


# Loss functions
BCE_loss = nn.BCELoss(reduction="sum")
MSE_loss = nn.MSELoss(reduction="sum")
CE_loss = nn.CrossEntropyLoss()


def VAE_loss(X, forward_output):
'''
Loss function of a VAE using mean squared error for the reconstruction loss.
This is the criterion for the VAE training loop.

:param X: Original input batch.
:param forward_output: Return value of a VAE.forward() call.
Triplet consisting of (X_hat, mean, logvar), i.e.
(reconstructed input after passing through encoder and decoder,
mean of the latent distribution,
logvar of the latent distribution)
'''
X_hat, mean, logvar = forward_output
reconstruction_loss = MSE_loss(X_hat, X)
KL_divergence = 0.5 * torch.sum(-1 - logvar + torch.exp(logvar) + mean**2)
return reconstruction_loss + KL_divergence


__all__ = ["MlpVAE", "VAE_loss"]
45 changes: 44 additions & 1 deletion avalanche/models/utils.py
@@ -1,6 +1,7 @@
from avalanche.benchmarks.utils import AvalancheDataset
from avalanche.models.dynamic_modules import MultiTaskModule, DynamicModule
import torch.nn as nn
from collections import OrderedDict


def avalanche_forward(model, x, task_labels):
@@ -59,4 +60,46 @@ def add_hooks(self, model):
)


__all__ = ["avalanche_forward", "FeatureExtractorBackbone"]
class Flatten(nn.Module):
'''
Simple nn.Module to flatten each tensor of a batch of tensors.
'''

def __init__(self):
super(Flatten, self).__init__()

def forward(self, x):
batch_size = x.shape[0]
return x.view(batch_size, -1)


class MLP(nn.Module):
'''
Simple nn.Module to create a multi-layer perceptron
with BatchNorm and ReLU activations.

:param hidden_size: A list with the number of neurons in each layer.
:type hidden_size: list[int]
:param last_activation: Indicates whether to add BatchNorm and ReLU
after the last layer.
:type last_activation: bool
'''

def __init__(self, hidden_size, last_activation=True):
super(MLP, self).__init__()
q = []
for i in range(len(hidden_size)-1):
in_dim = hidden_size[i]
out_dim = hidden_size[i+1]
q.append(("Linear_%d" % i, nn.Linear(in_dim, out_dim)))
if (i < len(hidden_size)-2) or ((i == len(hidden_size) - 2)
and (last_activation)):
q.append(("BatchNorm_%d" % i, nn.BatchNorm1d(out_dim)))
q.append(("ReLU_%d" % i, nn.ReLU(inplace=True)))
self.mlp = nn.Sequential(OrderedDict(q))

def forward(self, x):
return self.mlp(x)


__all__ = ["avalanche_forward", "FeatureExtractorBackbone", "MLP", "Flatten"]
2 changes: 2 additions & 0 deletions avalanche/training/plugins/__init__.py
@@ -13,5 +13,7 @@
from .lfl import LFLPlugin
from .early_stopping import EarlyStoppingPlugin
from .lr_scheduling import LRSchedulerPlugin
from .generative_replay import GenerativeReplayPlugin, \
TrainGeneratorAfterExpPlugin
from .rwalk import RWalkPlugin
from .mas import MASPlugin
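
The two new plugins are now importable from avalanche.training.plugins. A hedged sketch of the intended usage, following Avalanche's standard plugin mechanism; the constructor arguments (e.g. how a generator strategy is supplied) live in generative_replay.py and are assumed here rather than shown by this diff:

from avalanche.training.plugins import (
    GenerativeReplayPlugin,
    TrainGeneratorAfterExpPlugin,
)

# Assumption: default construction is allowed and the plugins are attached to
# a strategy via the usual plugins=[...] argument, like other Avalanche plugins.
# Check generative_replay.py for the actual signatures.
replay_plugin = GenerativeReplayPlugin()
generator_update_plugin = TrainGeneratorAfterExpPlugin()
# strategy = Naive(model, optimizer, criterion,
#                  plugins=[replay_plugin, generator_update_plugin])
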