diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 1734ebc..b57ed81 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -24,9 +24,9 @@ jobs:
         conda-channels: anaconda, conda-forge, bioconda
     - name: Create environment
       run: |
-        conda install -n base conda-libmamba-solver
+        conda install -n base mamba -c conda-forge
         conda clean --all -y
-        conda env create --experimental-solver=libmamba -f environment.yml
+        mamba env create -f environment.yml
     - name: Run tests with pytest
       run: |
         source activate metamers
@@ -48,9 +48,9 @@ jobs:
         conda-channels: anaconda, conda-forge, bioconda
     - name: Create environment
       run: |
-        conda install -n base conda-libmamba-solver
+        conda install -n base mamba
         conda clean --all -y
-        conda env create --experimental-solver=libmamba -f environment.yml
+        mamba env create -f environment.yml
     - name: Setup FFmpeg
       uses: FedericoCarboni/setup-ffmpeg@v1
     - name: Download data
@@ -83,11 +83,11 @@ jobs:
         conda-channels: anaconda, conda-forge, bioconda
     - name: Create environment
       run: |
-        conda install -n base conda-libmamba-solver
+        conda install -n base mamba
         conda clean --all -y
-        conda env create --experimental-solver=libmamba -f environment.yml
+        mamba env create -f environment.yml
         source activate metamers
-        conda install --experimental-solver=libmamba jupyter nbconvert
+        mamba install jupyter nbconvert
     - name: modify config.yml
       run: |
         sed -i 's|DATA_DIR:.*|DATA_DIR: "data/metamers"|g' config.yml
diff --git a/Snakefile b/Snakefile
index 70a0943..1ebb189 100644
--- a/Snakefile
+++ b/Snakefile
@@ -181,7 +181,11 @@ rule test_setup:
         with open(log[0], 'w', buffering=1) as log_file:
             with contextlib.redirect_stdout(log_file), contextlib.redirect_stderr(log_file):
                 print("Copying outputs from %s to %s" % (op.dirname(input[0]), output[0]))
-                shutil.copytree(op.dirname(input[0]), output[0])
+                # Copy over only those files that were generated by the call
+                # with the correct gpu_num wildcard, ignore the others
+                # (wildcards.gpu_num can only take values 0 or 1)
+                ignore_gpu = {0: '*_gpu-1*', 1: '*_gpu-0*'}[int(wildcards.gpu_num)]
+                shutil.copytree(op.dirname(input[0]), output[0], ignore=shutil.ignore_patterns(ignore_gpu))


 rule all_refs:
@@ -559,8 +563,8 @@ def get_mem_estimate(wildcards, partition=None):
     if int(wildcards.gpu) == 0:
         # in this case, we *do not* want to specify memory (we'll get the
         # whole node allocated but slurm could still kill the job if we go
-        # over requested memory)
-        mem = ''
+        # over requested memory). setting mem=0 requests all memory on node
+        mem = '0'
     else:
         # we'll be plugging this right into the mem request to slurm, so it
         # needs to be exactly correct
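
For context on the copytree change above: shutil.ignore_patterns builds a callable that copytree consults in each directory it visits, and whose return value is the set of entries to skip. A minimal sketch of that selection logic, with hypothetical file names (not taken from this repo):

    import shutil

    # ignore_patterns(*patterns) returns fn(dirpath, names) -> set of names to skip
    ignore = shutil.ignore_patterns('*_gpu-1*')
    names = ['metamer_gpu-0.png', 'metamer_gpu-1.png', 'log.txt']
    print(ignore('some/dir', names))  # {'metamer_gpu-1.png'}

So with gpu_num == 0, everything except files matching '*_gpu-1*' gets copied, and vice versa.
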
diff --git a/environment.yml b/environment.yml
index 016a6a1..8540b7a 100644
--- a/environment.yml
+++ b/environment.yml
@@ -27,7 +27,6 @@ dependencies:
   - gputil>=1.4.0
   - opt_einsum>=3.3
   - pip:
-    - torchcontrib==0.0.2
     - pyrtools>=1.0
     - git+https://github.com/LabForComputationalVision/plenoptic.git@main
     - flatten_dict
diff --git a/extra_packages/plenoptic_part/synthesize/metamer.py b/extra_packages/plenoptic_part/synthesize/metamer.py
index 353da9e..59f7936 100644
--- a/extra_packages/plenoptic_part/synthesize/metamer.py
+++ b/extra_packages/plenoptic_part/synthesize/metamer.py
@@ -148,7 +148,7 @@ def _init_synthesized_signal(self, initial_image, clamper=RangeClamper((0, 1)),

     def synthesize(self, initial_image=None, seed=0, max_iter=100, learning_rate=.01,
                    scheduler=True, optimizer='SGD', optimizer_kwargs={}, swa=False,
-                   swa_kwargs={}, clamper=RangeClamper((0, 1)),
+                   swa_start=10, swa_freq=1, swa_lr=.005, clamper=RangeClamper((0, 1)),
                    clamp_each_iter=True, store_progress=False, save_progress=False,
                    save_path='metamer.pt', loss_thresh=1e-4, loss_change_iter=50,
                    fraction_removed=0., loss_change_thresh=1e-2, loss_change_fraction=1.,
@@ -191,9 +191,12 @@ def synthesize(self, initial_image=None, seed=0, max_iter=100, learning_rate=.01
             the specific optimizer you're using
         swa : bool, optional
             whether to use stochastic weight averaging or not
-        swa_kwargs : dict, optional
-            Dictionary of keyword arguments to pass to the SWA object. See
-            torchcontrib.optim.SWA docs for more info.
+        swa_start : int, optional
+            the iteration to start using stochastic weight averaging
+        swa_freq : int, optional
+            how frequently to update the parameters of the averaged model
+        swa_lr : float, optional
+            learning rate used by the SWA scheduler
         clamper : plenoptic.Clamper or None, optional
             Clamper makes a change to the image in order to ensure that it
             stays reasonable. The classic example (and default
@@ -276,7 +279,7 @@ def synthesize(self, initial_image=None, seed=0, max_iter=100, learning_rate=.01

         # initialize the optimizer
         self._init_optimizer(optimizer, learning_rate, scheduler, clip_grad_norm,
-                             optimizer_kwargs, swa, swa_kwargs)
+                             optimizer_kwargs, swa, swa_start, swa_freq, swa_lr)

         # get ready to store progress
         self._init_store_progress(store_progress, save_progress, save_path)
@@ -302,7 +305,7 @@ def synthesize(self, initial_image=None, seed=0, max_iter=100, learning_rate=.01
         pbar.close()

         if self._swa:
-            self._optimizer.swap_swa_sgd()
+            self.synthesized_signal = self._swa_model.module.synthesized_signal

         # finally, stack the saved_* attributes
         self._finalize_stored_progress()
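
The removed torchcontrib call (swap_swa_sgd) copied the running average back into the optimized tensor; with torch.optim.swa_utils the average instead lives on the AveragedModel's .module attribute, which is what the new assignment reads from. A toy check of that accounting (the Sig class and tensor values are illustrative only, standing in for the DummyModel defined in synthesis.py below):

    import torch
    from torch.optim.swa_utils import AveragedModel

    class Sig(torch.nn.Module):  # plays the role of DummyModel
        def __init__(self, x):
            super().__init__()
            self.synthesized_signal = torch.nn.Parameter(x)

    live = Sig(torch.zeros(3))
    avg = AveragedModel(live)
    with torch.no_grad():
        live.synthesized_signal += 2.   # pretend an optimizer step
    avg.update_parameters(live)         # first update: average = current value
    with torch.no_grad():
        live.synthesized_signal += 2.   # another "step", now 4
    avg.update_parameters(live)         # running mean of {2, 4}
    print(avg.module.synthesized_signal.detach())  # tensor([3., 3., 3.])
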
diff --git a/extra_packages/plenoptic_part/synthesize/synthesis.py b/extra_packages/plenoptic_part/synthesize/synthesis.py
index edcabd4..f0ebe80 100644
--- a/extra_packages/plenoptic_part/synthesize/synthesis.py
+++ b/extra_packages/plenoptic_part/synthesize/synthesis.py
@@ -6,10 +6,7 @@
 from torch import optim
 import numpy as np
 import warnings
-try:
-    import torchcontrib
-except ModuleNotFoundError:
-    warnings.warn("Unable to import torchcontrib, will be unable to use SWA!")
+from torch.optim.swa_utils import AveragedModel, SWALR
 import plenoptic as po
 from plenoptic.simulate.models.naive import Identity
 from ..tools.optim import l2_norm
@@ -21,6 +18,14 @@
 from ..tools.clamps import RangeClamper


+class DummyModel(torch.nn.Module):
+    """Dummy model to wrap synthesized image, for SWA."""
+    def __init__(self, synthesized_signal):
+        super().__init__()
+        # this is already a parameter
+        self.synthesized_signal = synthesized_signal
+
+
 class Synthesis(metaclass=abc.ABCMeta):
     r"""Abstract super-class for synthesis methods

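
DummyModel exists because AveragedModel averages the parameters of an nn.Module, while the thing being optimized here is a bare image tensor. A self-contained sketch of the pattern the rest of this file implements, on a toy quadratic objective (the Wrapper class, hyperparameter values, and objective are all illustrative, not the repo's code):

    import torch
    from torch.optim.swa_utils import AveragedModel, SWALR

    class Wrapper(torch.nn.Module):  # same trick as DummyModel above
        def __init__(self, img):
            super().__init__()
            self.img = torch.nn.Parameter(img)

    model = Wrapper(torch.randn(1, 1, 8, 8))
    opt = torch.optim.SGD(model.parameters(), lr=.1)
    swa_model = AveragedModel(model)      # running average of the parameters
    swa_sched = SWALR(opt, swa_lr=.005)   # anneals the lr to a constant SWA value

    swa_start, swa_freq = 10, 1
    for i in range(50):
        loss = model.img.pow(2).sum()     # toy objective: pull the image to zero
        opt.zero_grad()
        loss.backward()
        opt.step()
        if i > swa_start:
            if i % swa_freq == 0:
                swa_model.update_parameters(model)  # fold current params into the average
            swa_sched.step()

    final_img = swa_model.module.img.detach()  # the averaged image

The sketch creates the AveragedModel up front for simplicity; the diff below instead creates it lazily on the first post-swa_start iteration, which has the same effect since the first update_parameters call overwrites the initial snapshot.
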
@@ -494,8 +499,8 @@ def _finalize_stored_progress(self):

     @abc.abstractmethod
     def synthesize(self, seed=0, max_iter=100, learning_rate=1, scheduler=True, optimizer='Adam',
-                   optimizer_kwargs={}, swa=False, swa_kwargs={}, clamper=RangeClamper((0, 1)),
-                   clamp_each_iter=True, store_progress=False,
+                   optimizer_kwargs={}, swa=False, swa_start=10, swa_freq=1, swa_lr=.5,
+                   clamper=RangeClamper((0, 1)), clamp_each_iter=True, store_progress=False,
                    save_progress=False, save_path='synthesis.pt', loss_thresh=1e-4,
                    loss_change_iter=50, fraction_removed=0., loss_change_thresh=1e-2,
                    loss_change_fraction=1., coarse_to_fine=False, clip_grad_norm=False):
@@ -534,10 +539,14 @@ def synthesize(self, seed=0, max_iter=100, learning_rate=1, scheduler=True, opti
             addition to learning_rate). What these should be depends on the
             specific optimizer you're using
         swa : bool, optional
-            whether to use stochastic weight averaging or not
-        swa_kwargs : dict, optional
-            Dictionary of keyword arguments to pass to the SWA object. See
-            torchcontrib.optim.SWA docs for more info.
+            whether to use stochastic weight averaging or not, see
+            https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/
+        swa_start : int, optional
+            the iteration to start using stochastic weight averaging
+        swa_freq : int, optional
+            how frequently to update the parameters of the averaged model
+        swa_lr : float, optional
+            learning rate used by the SWA scheduler
         clamper : plenoptic.Clamper or None, optional
             Clamper makes a change to the image in order to ensure that it
             stays reasonable. The classic example (and default
@@ -618,7 +627,7 @@ def synthesize(self, seed=0, max_iter=100, learning_rate=1, scheduler=True, opti
                                       loss_change_fraction, loss_change_thresh,
                                       loss_change_iter)
         # initialize the optimizer
         self._init_optimizer(optimizer, learning_rate, scheduler, clip_grad_norm,
-                             optimizer_kwargs, swa, swa_kwargs)
+                             optimizer_kwargs, swa, swa_start, swa_freq, swa_lr)
         # get ready to store progress
         self._init_store_progress(store_progress, save_progress)
@@ -752,7 +761,8 @@ def representation_error(self, iteration=None, **kwargs):
         return rep_error

     def _init_optimizer(self, optimizer, lr, scheduler=True, clip_grad_norm=False,
-                        optimizer_kwargs={}, swa=False, swa_kwargs={}):
+                        optimizer_kwargs={}, swa=False, swa_start=10, swa_freq=1,
+                        swa_lr=.5):
         """Initialize the optimizer and learning rate scheduler

         This gets called at the beginning of synthesize() and can also
@@ -806,15 +816,18 @@ def _init_optimizer(self, optimizer, lr, scheduler=True, clip_grad_norm=False,
             passed to the optimizer's initializer
         swa : bool, optional
             whether to use stochastic weight averaging or not
-        swa_kwargs : dict, optional
-            Dictionary of keyword arguments to pass to the SWA object.
+        swa_start : int, optional
+            the iteration to start using stochastic weight averaging
+        swa_freq : int, optional
+            how frequently to update the parameters of the averaged model
+        swa_lr : float, optional
+            learning rate used by the SWA scheduler

         """
         # there's a weird scoping issue that happens if we don't copy the
         # dictionary, where it can accidentally persist across instances of the
         # object, which messes all sorts of things up
         optimizer_kwargs = optimizer_kwargs.copy()
-        swa_kwargs = swa_kwargs.copy()
         # if lr is None, we're resuming synthesis from earlier, and we
         # want to start with the last learning rate. however, we also
         # want to keep track of the initial learning rate, since we use
@@ -854,8 +867,12 @@ def _init_optimizer(self, optimizer, lr, scheduler=True, clip_grad_norm=False,
         else:
             raise Exception("Don't know how to handle optimizer %s!" % optimizer)
         self._swa = swa
+        self._swa_counter = 0
         if swa:
-            self._optimizer = torchcontrib.optim.SWA(self._optimizer, **swa_kwargs)
+            self._swa_start = swa_start
+            self._swa_freq = swa_freq
+            self._swa_model = None
+            self._scheduler = SWALR(self._optimizer, swa_lr=swa_lr)
             warnings.warn("When using SWA, can't also use LR scheduler")
         else:
             if scheduler:
@@ -871,8 +888,8 @@ def _init_optimizer(self, optimizer, lr, scheduler=True, clip_grad_norm=False,
             # initial_lr here
             init_optimizer_kwargs = {'optimizer': optimizer, 'lr': initial_lr,
                                      'scheduler': scheduler, 'swa': swa,
-                                     'swa_kwargs': swa_kwargs,
-                                     'optimizer_kwargs': optimizer_kwargs}
+                                     'swa_start': swa_start, 'swa_freq': swa_freq,
+                                     'swa_lr': swa_lr, 'optimizer_kwargs': optimizer_kwargs}
         self._init_optimizer_kwargs = init_optimizer_kwargs
         if clip_grad_norm is True:
             self.clip_grad_norm = 1
@@ -1000,7 +1017,16 @@ def _optimizer_step(self, pbar=None, **kwargs):
         g = self.synthesized_signal.grad.detach()
         # optionally step the scheduler
         if self._scheduler is not None:
-            self._scheduler.step(loss.item())
+            if self._swa:
+                if self._swa_counter > self._swa_start:
+                    if self._swa_model is None:
+                        self._dummy_model = DummyModel(self.synthesized_signal)
+                        self._swa_model = AveragedModel(self._dummy_model)
+                    elif self._swa_counter % self._swa_freq == 0:
+                        self._swa_model.update_parameters(self._dummy_model)
+                    self._scheduler.step()
+            else:
+                self._scheduler.step(loss.item())

         if self.coarse_to_fine and self.scales[0] != 'all':
             with torch.no_grad():
@@ -1017,10 +1043,13 @@ def _optimizer_step(self, pbar=None, **kwargs):
         postfix_dict.update(dict(loss="%.4e" % abs(loss.item()),
                                  gradient_norm="%.4e" % g.norm().item(),
                                  learning_rate=self._optimizer.param_groups[0]['lr'],
-                                 pixel_change=f"{pixel_change:.04e}", **kwargs))
+                                 pixel_change=f"{pixel_change:.04e}",
+                                 swa_counter=self._swa_counter,
+                                 **kwargs))
         # add extra info here if you want it to show up in progress bar
         if pbar is not None:
             pbar.set_postfix(**postfix_dict)
+        self._swa_counter += 1
         return loss, g.norm(), self._optimizer.param_groups[0]['lr'], pixel_change

     @abc.abstractmethod
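
The branching in _optimizer_step keeps two incompatible scheduler call conventions apart: the non-SWA scheduler here is stepped with the loss, ReduceLROnPlateau-style, while SWALR.step takes no metric at all. A minimal illustration of the two signatures (the optimizer and parameter are placeholders):

    import torch
    from torch.optim.swa_utils import SWALR

    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=.1)

    plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(opt)
    plateau.step(0.5)   # needs a metric to decide whether to decay the lr

    swa_sched = SWALR(opt, swa_lr=.005)
    swa_sched.step()    # no metric; anneals toward the constant swa_lr
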
diff --git a/foveated_metamers/create_metamers.py b/foveated_metamers/create_metamers.py
index 3f1baa9..2c41335 100644
--- a/foveated_metamers/create_metamers.py
+++ b/foveated_metamers/create_metamers.py
@@ -1121,12 +1121,13 @@ def main(model_name, scaling, image, seed=0, min_ecc=.5, max_ecc=15, learning_ra
                                          clamp_each_iter=clamp_each_iter,
                                          save_progress=save_progress,
                                          optimizer=optimizer,
-                                         swa=swa, swa_kwargs=swa_kwargs,
+                                         swa=swa,
                                          fraction_removed=fraction_removed,
                                          loss_change_fraction=loss_change_fraction,
                                          loss_change_thresh=loss_change_thresh,
                                          coarse_to_fine=coarse_to_fine,
-                                         save_path=inprogress_path)
+                                         save_path=inprogress_path,
+                                         **swa_kwargs)
     duration = time.time() - start_time
     # make sure everything's on the cpu for saving
     metamer = metamer.to('cpu')
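
This call site keeps accepting a swa_kwargs dict and simply unpacks it into the new explicit keywords, which works as long as the dict's keys match the new parameter names exactly. A hypothetical stand-in for synthesize showing the equivalence (values are illustrative):

    # stand-in signature mirroring the change to synthesize()
    def synthesize(swa=False, swa_start=10, swa_freq=1, swa_lr=.005, **kwargs):
        return (swa, swa_start, swa_freq, swa_lr)

    swa_kwargs = {'swa_start': 20, 'swa_freq': 2, 'swa_lr': .005}
    assert synthesize(swa=True, **swa_kwargs) == synthesize(swa=True, swa_start=20,
                                                            swa_freq=2, swa_lr=.005)
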