BYOL weight update callback #867

Merged · 5 commits · Aug 25, 2022
50 changes: 23 additions & 27 deletions pl_bolts/callbacks/byol_updates.py
```diff
@@ -1,27 +1,27 @@
 import math
 from typing import Sequence, Union
 
+import torch.nn as nn
 from pytorch_lightning import Callback, LightningModule, Trainer
 from torch import Tensor
-from torch.nn import Module
 
 from pl_bolts.utils.stability import under_review
 
 
 @under_review()
 class BYOLMAWeightUpdate(Callback):
-    """Weight update rule from BYOL.
+    """Weight update rule from Bootstrap Your Own Latent (BYOL).
 
-    Your model should have:
+    Updates the target_network params using an exponential moving average update rule weighted by tau.
+    BYOL claims this keeps the online_network from collapsing.
+
+    The PyTorch Lightning module being trained should have:
 
     - ``self.online_network``
     - ``self.target_network``
 
-    Updates the target_network params using an exponential moving average update rule weighted by tau.
-    BYOL claims this keeps the online_network from collapsing.
-
     .. note:: Automatically increases tau from ``initial_tau`` to 1.0 with every training step
 
+    Args:
+        initial_tau (float, optional): starting tau. Auto-updates with every training step
+
     Example::
 
         # model must have 2 attributes
```
```diff
@@ -32,11 +32,10 @@ class BYOLMAWeightUpdate(Callback):
         trainer = Trainer(callbacks=[BYOLMAWeightUpdate()])
     """
 
-    def __init__(self, initial_tau: float = 0.996):
-        """
-        Args:
-            initial_tau: starting tau. Auto-updates with every training step
-        """
+    def __init__(self, initial_tau: float = 0.996) -> None:
+        if not 0.0 <= initial_tau <= 1.0:
+            raise ValueError(f"initial tau should be between 0 and 1 instead of {initial_tau}.")
+
         super().__init__()
         self.initial_tau = initial_tau
         self.current_tau = initial_tau
```
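The new guard makes an out-of-range tau fail immediately instead of silently corrupting the moving average. A quick illustration of the added check (assuming `pl_bolts` is importable; the values are arbitrary):

```python
from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate

BYOLMAWeightUpdate(initial_tau=0.5)  # accepted: 0.0 <= initial_tau <= 1.0
BYOLMAWeightUpdate(initial_tau=1.5)  # raises ValueError: initial tau should be between 0 and 1 ...
```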
```diff
@@ -53,21 +52,18 @@ def on_train_batch_end(
         online_net = pl_module.online_network
         target_net = pl_module.target_network
 
-        # update weights
+        # update target network weights
         self.update_weights(online_net, target_net)
 
         # update tau after
-        self.current_tau = self.update_tau(pl_module, trainer)
+        self.update_tau(pl_module, trainer)
 
-    def update_tau(self, pl_module: LightningModule, trainer: Trainer) -> float:
+    def update_tau(self, pl_module: LightningModule, trainer: Trainer) -> None:
+        """Update tau value for next update."""
         max_steps = len(trainer.train_dataloader) * trainer.max_epochs
-        tau = 1 - (1 - self.initial_tau) * (math.cos(math.pi * pl_module.global_step / max_steps) + 1) / 2
-        return tau
-
-    def update_weights(self, online_net: Union[Module, Tensor], target_net: Union[Module, Tensor]) -> None:
-        # apply MA weight update
-        for (name, online_p), (_, target_p) in zip(
-            online_net.named_parameters(),
-            target_net.named_parameters(),
-        ):
-            target_p.data = self.current_tau * target_p.data + (1 - self.current_tau) * online_p.data
+        self.current_tau = 1 - (1 - self.initial_tau) * (math.cos(math.pi * pl_module.global_step / max_steps) + 1) / 2
+
+    def update_weights(self, online_net: Union[nn.Module, Tensor], target_net: Union[nn.Module, Tensor]) -> None:
+        """Update target network parameters."""
+        for online_p, target_p in zip(online_net.parameters(), target_net.parameters()):
+            target_p.data = self.current_tau * target_p.data + (1.0 - self.current_tau) * online_p.data
```
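Taken together, `update_tau` and `update_weights` implement a cosine-ramped exponential moving average: tau starts at `initial_tau` and ramps to 1.0 over training, at which point the target network stops moving. For intuition, here is a minimal standalone sketch of the same two rules (not part of the PR; the names `cosine_tau` and `ema_update` are illustrative only):

```python
import math

import torch
from torch import nn


def cosine_tau(initial_tau: float, global_step: int, max_steps: int) -> float:
    # Same schedule as update_tau: tau starts at initial_tau (cos(0) = 1)
    # and reaches 1.0 at max_steps (cos(pi) = -1).
    return 1 - (1 - initial_tau) * (math.cos(math.pi * global_step / max_steps) + 1) / 2


def ema_update(online: nn.Module, target: nn.Module, tau: float) -> None:
    # Same rule as update_weights: target <- tau * target + (1 - tau) * online.
    for online_p, target_p in zip(online.parameters(), target.parameters()):
        target_p.data = tau * target_p.data + (1.0 - tau) * online_p.data


assert math.isclose(cosine_tau(0.996, 0, 1000), 0.996)   # schedule starts at initial_tau
assert math.isclose(cosine_tau(0.996, 1000, 1000), 1.0)  # and ends at 1.0, freezing the target

online, target = nn.Linear(4, 2), nn.Linear(4, 2)
ema_update(online, target, tau=0.0)  # tau = 0 overwrites the target with the online weights
assert all(torch.equal(o, t) for o, t in zip(online.parameters(), target.parameters()))
```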
73 changes: 48 additions & 25 deletions tests/callbacks/test_param_update_callbacks.py
```diff
@@ -1,33 +1,56 @@
 from copy import deepcopy
 
+import pytest
 import torch
 from torch import nn
 
 from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate
 
 
-def test_byol_ma_weight_update_callback():
-    a = nn.Linear(100, 10)
-    b = deepcopy(a)
-    a_original = deepcopy(a)
-    b_original = deepcopy(b)
-
-    # make sure a params and b params are the same
-    assert torch.equal(next(iter(a.parameters()))[0], next(iter(b.parameters()))[0])
-
-    # fake weight update
-    opt = torch.optim.SGD(a.parameters(), lr=0.1)
-    y = a(torch.randn(3, 100))
-    loss = y.sum()
-    loss.backward()
-    opt.step()
-    opt.zero_grad()
-
-    # make sure a did in fact update
-    assert not torch.equal(next(iter(a_original.parameters()))[0], next(iter(a.parameters()))[0])
-
-    # do update via callback
-    cb = BYOLMAWeightUpdate(0.8)
-    cb.update_weights(a, b)
-
-    assert not torch.equal(next(iter(b_original.parameters()))[0], next(iter(b.parameters()))[0])
+@pytest.mark.parametrize("initial_tau", [-0.1, 0.0, 0.996, 1.0, 1.1])
+def test_byol_ma_weight_single_update_callback(initial_tau, catch_warnings):
+    """Check BYOL exponential moving average weight update rule for a single update."""
+    if 0.0 <= initial_tau <= 1.0:
+        # Create simple one layer network and their copies
+        online_network = nn.Linear(100, 10)
+        target_network = deepcopy(online_network)
+        online_network_copy = deepcopy(online_network)
+        target_network_copy = deepcopy(target_network)
+
+        # Check parameters are equal
+        assert torch.equal(next(iter(online_network.parameters()))[0], next(iter(target_network.parameters()))[0])
+
+        # Simulate weight update
+        opt = torch.optim.SGD(online_network.parameters(), lr=0.1)
+        y = online_network(torch.randn(3, 100))
+        loss = y.sum()
+        loss.backward()
+        opt.step()
+        opt.zero_grad()
+
+        # Check online network update
+        assert not torch.equal(
+            next(iter(online_network.parameters()))[0], next(iter(online_network_copy.parameters()))[0]
+        )
+
+        # Update target network weights via callback
+        cb = BYOLMAWeightUpdate(initial_tau)
+        cb.update_weights(online_network, target_network)
+
+        # Check target network update according to value of tau
+        if initial_tau == 0.0:
+            assert torch.equal(next(iter(target_network.parameters()))[0], next(iter(online_network.parameters()))[0])
+        elif initial_tau == 1.0:
+            assert torch.equal(
+                next(iter(target_network.parameters()))[0], next(iter(target_network_copy.parameters()))[0]
+            )
+        else:
+            for online_p, target_p in zip(online_network.parameters(), target_network_copy.parameters()):
+                target_p.data = initial_tau * target_p.data + (1.0 - initial_tau) * online_p.data
+
+            assert torch.equal(
+                next(iter(target_network.parameters()))[0], next(iter(target_network_copy.parameters()))[0]
+            )
+    else:
+        with pytest.raises(ValueError, match="initial tau should be"):
+            cb = BYOLMAWeightUpdate(initial_tau)
```
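The parametrization covers both invalid values (-0.1, 1.1) and the boundary behavior of the update rule itself. That boundary behavior follows directly from the formula: with tau = 0 the target is overwritten by the online weights, and with tau = 1 it is left untouched. A quick hand check, independent of the test suite (tensor values are arbitrary):

```python
import torch

online = torch.tensor([1.0, 2.0])
target = torch.tensor([10.0, 20.0])

# tau = 0.0: the new target equals the online weights
assert torch.equal(0.0 * target + (1 - 0.0) * online, online)
# tau = 1.0: the new target equals the old target
assert torch.equal(1.0 * target + (1 - 1.0) * online, target)
```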