SWA fix #3

Merged (9 commits, Jun 29, 2022)
14 changes: 7 additions & 7 deletions .github/workflows/tests.yml
@@ -24,9 +24,9 @@ jobs:
conda-channels: anaconda, conda-forge, bioconda
- name: Create environment
run: |
- conda install -n base conda-libmamba-solver
+ conda install -n base mamba -c conda-forge
conda clean --all -y
- conda env create --experimental-solver=libmamba -f environment.yml
+ mamba env create -f environment.yml
- name: Run tests with pytest
run: |
source activate metamers
@@ -48,9 +48,9 @@ jobs:
conda-channels: anaconda, conda-forge, bioconda
- name: Create environment
run: |
- conda install -n base conda-libmamba-solver
+ conda install -n base mamba
conda clean --all -y
- conda env create --experimental-solver=libmamba -f environment.yml
+ mamba env create -f environment.yml
- name: Setup FFmpeg
uses: FedericoCarboni/setup-ffmpeg@v1
- name: Download data
@@ -83,11 +83,11 @@ jobs:
conda-channels: anaconda, conda-forge, bioconda
- name: Create environment
run: |
- conda install -n base conda-libmamba-solver
+ conda install -n base mamba
conda clean --all -y
- conda env create --experimental-solver=libmamba -f environment.yml
+ mamba env create -f environment.yml
source activate metamers
- conda install --experimental-solver=libmamba jupyter nbconvert
+ mamba install jupyter nbconvert
- name: modify config.yml
run: |
sed -i 's|DATA_DIR:.*|DATA_DIR: "data/metamers"|g' config.yml
10 changes: 7 additions & 3 deletions Snakefile
@@ -181,7 +181,11 @@ rule test_setup:
with open(log[0], 'w', buffering=1) as log_file:
with contextlib.redirect_stdout(log_file), contextlib.redirect_stderr(log_file):
print("Copying outputs from %s to %s" % (op.dirname(input[0]), output[0]))
- shutil.copytree(op.dirname(input[0]), output[0])
+ # Copy over only those files that were generated by the call
+ # with the correct gpu_num wildcard, ignore the others
+ # (wildcards.gpu_num can only take values 0 or 1)
+ ignore_gpu = {0: '*_gpu-1*', 1: '*_gpu-0*'}[int(wildcards.gpu_num)]
+ shutil.copytree(op.dirname(input[0]), output[0], ignore=shutil.ignore_patterns(ignore_gpu))


rule all_refs:
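A minimal standalone sketch of the shutil.ignore_patterns behavior the new copy step relies on (the directory names below are hypothetical, not the rule's actual paths):

```python
import shutil

# Keep only the outputs matching the requested gpu_num wildcard (0 or 1)
# by ignoring files produced with the other value.
gpu_num = 0
ignore_gpu = {0: '*_gpu-1*', 1: '*_gpu-0*'}[gpu_num]
shutil.copytree('test_setup_outputs',        # hypothetical source directory
                'data/metamers/test_setup',  # hypothetical destination (must not exist yet)
                ignore=shutil.ignore_patterns(ignore_gpu))
```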
@@ -559,8 +563,8 @@ def get_mem_estimate(wildcards, partition=None):
if int(wildcards.gpu) == 0:
# in this case, we *do not* want to specify memory (we'll get the
# whole node allocated but slurm could still kill the job if we go
- # over requested memory)
- mem = ''
+ # over requested memory). setting mem=0 requests all memory on node
+ mem = '0'
else:
# we'll be plugging this right into the mem request to slurm, so it
# needs to be exactly correct
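As the new comment notes, SLURM interprets a memory request of 0 as "all of the memory on the node", so returning the string '0' still yields a valid --mem value while removing the cap. A rough sketch of that branch, with the rest of the function abridged and the fallback value purely illustrative:

```python
def get_mem_estimate(wildcards, partition=None):
    # Abridged sketch of the relevant branch; the real function contains
    # more logic than this hunk shows.
    if int(wildcards.gpu) == 0:
        return '0'    # SLURM reads --mem=0 as "all memory on the node"
    return '16GB'     # hypothetical exact per-job estimate
```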
1 change: 0 additions & 1 deletion environment.yml
@@ -27,7 +27,6 @@ dependencies:
- gputil>=1.4.0
- opt_einsum>=3.3
- pip:
- - torchcontrib==0.0.2
- pyrtools>=1.0
- git+https://github.com/LabForComputationalVision/plenoptic.git@main
- flatten_dict
15 changes: 9 additions & 6 deletions extra_packages/plenoptic_part/synthesize/metamer.py
@@ -148,7 +148,7 @@ def _init_synthesized_signal(self, initial_image, clamper=RangeClamper((0, 1)),

def synthesize(self, initial_image=None, seed=0, max_iter=100, learning_rate=.01,
scheduler=True, optimizer='SGD', optimizer_kwargs={}, swa=False,
- swa_kwargs={}, clamper=RangeClamper((0, 1)),
+ swa_start=10, swa_freq=1, swa_lr=.005, clamper=RangeClamper((0, 1)),
clamp_each_iter=True, store_progress=False, save_progress=False,
save_path='metamer.pt', loss_thresh=1e-4, loss_change_iter=50,
fraction_removed=0., loss_change_thresh=1e-2, loss_change_fraction=1.,
@@ -191,9 +191,12 @@ def synthesize(self, initial_image=None, seed=0, max_iter=100, learning_rate=.01
the specific optimizer you're using
swa : bool, optional
whether to use stochastic weight averaging or not
- swa_kwargs : dict, optional
- Dictionary of keyword arguments to pass to the SWA object. See
- torchcontrib.optim.SWA docs for more info.
+ swa_start : int, optional
+ the iteration at which to start stochastic weight averaging
+ swa_freq : int, optional
+ how frequently to update the parameters of the averaged model
+ swa_lr : float, optional
+ learning rate used for SWA
clamper : plenoptic.Clamper or None, optional
Clamper makes a change to the image in order to ensure that
it stays reasonable. The classic example (and default
@@ -276,7 +279,7 @@ def synthesize(self, initial_image=None, seed=0, max_iter=100, learning_rate=.01

# initialize the optimizer
self._init_optimizer(optimizer, learning_rate, scheduler, clip_grad_norm,
- optimizer_kwargs, swa, swa_kwargs)
+ optimizer_kwargs, swa, swa_start, swa_freq, swa_lr)

# get ready to store progress
self._init_store_progress(store_progress, save_progress, save_path)
@@ -302,7 +305,7 @@ def synthesize(self, initial_image=None, seed=0, max_iter=100, learning_rate=.01
pbar.close()

if self._swa:
- self._optimizer.swap_swa_sgd()
+ self.synthesized_signal = self._swa_model.module.synthesized_signal

# finally, stack the saved_* attributes
self._finalize_stored_progress()
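With this change, the settings that previously went into a swa_kwargs dict are now separate keyword arguments. A hedged sketch of a call with the new signature (the metamer object is assumed to be constructed elsewhere and the values are illustrative):

```python
# `metamer` is an already-initialized Metamer instance (construction omitted)
metamer.synthesize(max_iter=1000, learning_rate=.01, optimizer='SGD',
                   swa=True, swa_start=10, swa_freq=1, swa_lr=.005)
```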
69 changes: 49 additions & 20 deletions extra_packages/plenoptic_part/synthesize/synthesis.py
@@ -6,10 +6,7 @@
from torch import optim
import numpy as np
import warnings
- try:
- import torchcontrib
- except ModuleNotFoundError:
- warnings.warn("Unable to import torchcontrib, will be unable to use SWA!")
+ from torch.optim.swa_utils import AveragedModel, SWALR
import plenoptic as po
from plenoptic.simulate.models.naive import Identity
from ..tools.optim import l2_norm
@@ -21,6 +18,14 @@
from ..tools.clamps import RangeClamper


+ class DummyModel(torch.nn.Module):
+ """Dummy model to wrap synthesized image, for SWA."""
+ def __init__(self, synthesized_signal):
+ super().__init__()
+ # this is already a parameter
+ self.synthesized_signal = synthesized_signal


class Synthesis(metaclass=abc.ABCMeta):
r"""Abstract super-class for synthesis methods

@@ -494,8 +499,8 @@ def _finalize_stored_progress(self):

@abc.abstractmethod
def synthesize(self, seed=0, max_iter=100, learning_rate=1, scheduler=True, optimizer='Adam',
- optimizer_kwargs={}, swa=False, swa_kwargs={}, clamper=RangeClamper((0, 1)),
- clamp_each_iter=True, store_progress=False,
+ optimizer_kwargs={}, swa=False, swa_start=10, swa_freq=1, swa_lr=.5,
+ clamper=RangeClamper((0, 1)), clamp_each_iter=True, store_progress=False,
save_progress=False, save_path='synthesis.pt', loss_thresh=1e-4,
loss_change_iter=50, fraction_removed=0., loss_change_thresh=1e-2,
loss_change_fraction=1., coarse_to_fine=False, clip_grad_norm=False):
@@ -534,10 +539,14 @@ def synthesize(self, seed=0, max_iter=100, learning_rate=1, scheduler=True, opti
addition to learning_rate). What these should be depend on
the specific optimizer you're using
swa : bool, optional
- whether to use stochastic weight averaging or not
- swa_kwargs : dict, optional
- Dictionary of keyword arguments to pass to the SWA object. See
- torchcontrib.optim.SWA docs for more info.
+ whether to use stochastic weight averaging or not, see
+ https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/
+ swa_start : int, optional
+ the iteration at which to start stochastic weight averaging
+ swa_freq : int, optional
+ how frequently to update the parameters of the averaged model
+ swa_lr : float, optional
+ learning rate used for SWA
clamper : plenoptic.Clamper or None, optional
Clamper makes a change to the image in order to ensure that
it stays reasonable. The classic example (and default
@@ -618,7 +627,7 @@ def synthesize(self, seed=0, max_iter=100, learning_rate=1, scheduler=True, opti
loss_change_fraction, loss_change_thresh, loss_change_iter)
# initialize the optimizer
self._init_optimizer(optimizer, learning_rate, scheduler, clip_grad_norm,
- optimizer_kwargs, swa, swa_kwargs)
+ optimizer_kwargs, swa, swa_start, swa_freq, swa_lr)
# get ready to store progress
self._init_store_progress(store_progress, save_progress)

@@ -752,7 +761,8 @@ def representation_error(self, iteration=None, **kwargs):
return rep_error

def _init_optimizer(self, optimizer, lr, scheduler=True, clip_grad_norm=False,
- optimizer_kwargs={}, swa=False, swa_kwargs={}):
+ optimizer_kwargs={}, swa=False, swa_start=10, swa_freq=1,
+ swa_lr=.5):
"""Initialize the optimzer and learning rate scheduler

This gets called at the beginning of synthesize() and can also
@@ -806,15 +816,18 @@ def _init_optimizer(self, optimizer, lr, scheduler=True, clip_grad_norm=False,
passed to the optimizer's initializer
swa : bool, optional
whether to use stochastic weight averaging or not
- swa_kwargs : dict, optional
- Dictionary of keyword arguments to pass to the SWA object.
+ swa_start : int, optional
+ the iteration at which to start stochastic weight averaging
+ swa_freq : int, optional
+ how frequently to update the parameters of the averaged model
+ swa_lr : float, optional
+ learning rate used for SWA

"""
# there's a weird scoping issue that happens if we don't copy the
# dictionary, where it can accidentally persist across instances of the
# object, which messes all sorts of things up
optimizer_kwargs = optimizer_kwargs.copy()
- swa_kwargs = swa_kwargs.copy()
# if lr is None, we're resuming synthesis from earlier, and we
# want to start with the last learning rate. however, we also
# want to keep track of the initial learning rate, since we use
@@ -854,8 +867,12 @@ def _init_optimizer(self, optimizer, lr, scheduler=True, clip_grad_norm=False,
else:
raise Exception("Don't know how to handle optimizer %s!" % optimizer)
self._swa = swa
+ self._swa_counter = 0
if swa:
- self._optimizer = torchcontrib.optim.SWA(self._optimizer, **swa_kwargs)
+ self._swa_start = swa_start
+ self._swa_freq = swa_freq
+ self._swa_model = None
+ self._scheduler = SWALR(self._optimizer, swa_lr=swa_lr)
warnings.warn("When using SWA, can't also use LR scheduler")
else:
if scheduler:
@@ -871,8 +888,8 @@ def _init_optimizer(self, optimizer, lr, scheduler=True, clip_grad_norm=False,
# initial_lr here
init_optimizer_kwargs = {'optimizer': optimizer, 'lr': initial_lr,
'scheduler': scheduler, 'swa': swa,
- 'swa_kwargs': swa_kwargs,
- 'optimizer_kwargs': optimizer_kwargs}
+ 'swa_start': swa_start, 'swa_freq': swa_freq,
+ 'swa_lr': swa_lr, 'optimizer_kwargs': optimizer_kwargs}
self._init_optimizer_kwargs = init_optimizer_kwargs
if clip_grad_norm is True:
self.clip_grad_norm = 1
@@ -1000,7 +1017,16 @@ def _optimizer_step(self, pbar=None, **kwargs):
g = self.synthesized_signal.grad.detach()
# optionally step the scheduler
if self._scheduler is not None:
- self._scheduler.step(loss.item())
+ if self._swa:
+ if self._swa_counter > self._swa_start:
+ if self._swa_model is None:
+ self._dummy_model = DummyModel(self.synthesized_signal)
+ self._swa_model = AveragedModel(self._dummy_model)
+ elif self._swa_counter % self._swa_freq == 0:
+ self._swa_model.update_parameters(self._dummy_model)
+ self._scheduler.step()
+ else:
+ self._scheduler.step(loss.item())

if self.coarse_to_fine and self.scales[0] != 'all':
with torch.no_grad():
@@ -1017,10 +1043,13 @@ def _optimizer_step(self, pbar=None, **kwargs):
postfix_dict.update(dict(loss="%.4e" % abs(loss.item()),
gradient_norm="%.4e" % g.norm().item(),
learning_rate=self._optimizer.param_groups[0]['lr'],
pixel_change=f"{pixel_change:.04e}", **kwargs))
pixel_change=f"{pixel_change:.04e}",
swa_counter=self._swa_counter,
**kwargs))
# add extra info here if you want it to show up in progress bar
if pbar is not None:
pbar.set_postfix(**postfix_dict)
+ self._swa_counter += 1
return loss, g.norm(), self._optimizer.param_groups[0]['lr'], pixel_change

@abc.abstractmethod
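For reference, the new code follows the standard torch.optim.swa_utils recipe. Below is a minimal, self-contained sketch of that pattern with a generic model, optimizer, and loss (not this repo's objects), mirroring the swa_start/swa_freq bookkeeping above:

```python
import torch
from torch.optim.swa_utils import AveragedModel, SWALR

model = torch.nn.Linear(10, 1)                  # stand-in for DummyModel
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
swa_model = AveragedModel(model)                # running average of model's parameters
swa_scheduler = SWALR(optimizer, swa_lr=0.005)  # anneals the learning rate toward swa_lr

swa_start, swa_freq = 10, 1
for step in range(100):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 10)).pow(2).mean()  # dummy objective
    loss.backward()
    optimizer.step()
    if step > swa_start:
        if step % swa_freq == 0:
            swa_model.update_parameters(model)      # fold current weights into the average
        swa_scheduler.step()

# the averaged parameters live on swa_model.module
print(swa_model.module.weight.mean())
```

In the PR, the wrapped module is DummyModel, whose only parameter is the synthesized image, so swa_model.module.synthesized_signal ends up holding the averaged image that Metamer.synthesize copies back once optimization finishes.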
5 changes: 3 additions & 2 deletions foveated_metamers/create_metamers.py
@@ -1121,12 +1121,13 @@ def main(model_name, scaling, image, seed=0, min_ecc=.5, max_ecc=15, learning_ra
clamp_each_iter=clamp_each_iter,
save_progress=save_progress,
optimizer=optimizer,
- swa=swa, swa_kwargs=swa_kwargs,
+ swa=swa,
fraction_removed=fraction_removed,
loss_change_fraction=loss_change_fraction,
loss_change_thresh=loss_change_thresh,
coarse_to_fine=coarse_to_fine,
- save_path=inprogress_path)
+ save_path=inprogress_path,
+ **swa_kwargs)
duration = time.time() - start_time
# make sure everything's on the cpu for saving
metamer = metamer.to('cpu')
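Because synthesize now takes the SWA settings as individual keyword arguments, the caller keeps them in a dict and unpacks it, which is what the change above does. A small illustrative sketch (the dict contents are hypothetical defaults):

```python
swa_kwargs = {'swa_start': 10, 'swa_freq': 1, 'swa_lr': .005}
metamer.synthesize(max_iter=1000, optimizer='SGD',
                   swa=True, **swa_kwargs)
```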