Commit: add warmup

lucidrains committed Dec 10, 2023
1 parent d015c73 commit f9c6467
Showing 3 changed files with 44 additions and 3 deletions.
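In brief: this commit adds linear learning-rate warmup (via the pytorch-warmup package) and an optional user-supplied learning-rate scheduler for both the main optimizer and the discriminator optimizer, saves and restores their state alongside the optimizers, and steps them after each optimizer step that the accelerator did not skip.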
44 changes: 42 additions & 2 deletions magvit2_pytorch/trainer.py
@@ -6,9 +6,11 @@
 from torch import nn
 from torch.nn import Module
 from torch.utils.data import Dataset, random_split
+from torch.optim.lr_scheduler import LambdaLR, LRScheduler
+import pytorch_warmup as warmup
 
 from beartype import beartype
-from beartype.typing import Optional, Literal, Union
+from beartype.typing import Optional, Literal, Union, Type
 
 from magvit2_pytorch.optimizer import get_optimizer

@@ -35,6 +37,8 @@
     Literal['images']
 ]
 
+ConstantLRScheduler = partial(LambdaLR, lr_lambda = lambda step: 1.)
+
 DEFAULT_DDP_KWARGS = DistributedDataParallelKwargs(
     find_unused_parameters = True
 )
@@ -75,6 +79,9 @@ def __init__(
         num_frames = 17,
         use_wandb_tracking = False,
         discr_start_after_step = 0.,
+        warmup_steps = 1000,
+        scheduler: Optional[Type[LRScheduler]] = None,
+        scheduler_kwargs: dict = dict(),
         accelerate_kwargs: dict = dict(),
         ema_kwargs: dict = dict(),
         optimizer_kwargs: dict = dict(),
@@ -149,6 +156,20 @@ def __init__(
         self.optimizer = get_optimizer(model.parameters(), lr = learning_rate, **optimizer_kwargs)
         self.discr_optimizer = get_optimizer(model.discr_parameters(), lr = learning_rate, **optimizer_kwargs)
 
+        # warmup
+
+        self.warmup = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps)
+        self.discr_warmup = warmup.LinearWarmup(self.discr_optimizer, warmup_period = warmup_steps)
+
+        # schedulers
+
+        if exists(scheduler):
+            self.scheduler = scheduler(self.optimizer, **scheduler_kwargs)
+            self.discr_scheduler = scheduler(self.discr_optimizer, **scheduler_kwargs)
+        else:
+            self.scheduler = ConstantLRScheduler(self.optimizer)
+            self.discr_scheduler = ConstantLRScheduler(self.discr_optimizer)
+
         # training related params
 
         self.batch_size = batch_size
@@ -270,7 +291,12 @@ def save(self, path, overwrite = True):
             model = self.unwrapped_model.state_dict(),
             ema_model = self.ema_model.state_dict(),
             optimizer = self.optimizer.state_dict(),
-            discr_optimizer = self.discr_optimizer.state_dict()
+            discr_optimizer = self.discr_optimizer.state_dict(),
+            warmup = self.warmup.state_dict(),
+            scheduler = self.scheduler.state_dict(),
+            discr_warmup = self.discr_warmup.state_dict(),
+            discr_scheduler = self.discr_scheduler.state_dict(),
+            step = self.step.item()
         )
 
         for ind, opt in enumerate(self.multiscale_discr_optimizers):
@@ -288,10 +314,16 @@ def load(self, path):
         self.ema_model.load_state_dict(pkg['ema_model'])
         self.optimizer.load_state_dict(pkg['optimizer'])
         self.discr_optimizer.load_state_dict(pkg['discr_optimizer'])
+        self.warmup.load_state_dict(pkg['warmup'])
+        self.scheduler.load_state_dict(pkg['scheduler'])
+        self.discr_warmup.load_state_dict(pkg['discr_warmup'])
+        self.discr_scheduler.load_state_dict(pkg['discr_scheduler'])
 
         for ind, opt in enumerate(self.multiscale_discr_optimizers):
             opt.load_state_dict(pkg[f'multiscale_discr_optimizer_{ind}'])
 
+        self.step.copy_(pkg['step'])
+
     def train_step(self, dl_iter):
         self.model.train()

@@ -339,6 +371,10 @@ def train_step(self, dl_iter):
 
         self.optimizer.step()
 
+        if not self.accelerator.optimizer_step_was_skipped:
+            with self.warmup.dampening():
+                self.scheduler.step()
+
         # update ema model
 
         self.wait()
@@ -396,6 +432,10 @@ def train_step(self, dl_iter):
 
         self.discr_optimizer.step()
 
+        if not self.accelerator.optimizer_step_was_skipped:
+            with self.discr_warmup.dampening():
+                self.discr_scheduler.step()
+
         if self.has_multiscale_discrs:
             for multiscale_discr_optimizer in self.multiscale_discr_optimizers:
                 multiscale_discr_optimizer.step()
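With these changes, warmup is always on (warmup_steps defaults to 1000), and scheduler expects a scheduler class rather than an instance, which the trainer instantiates once per optimizer; omit it to keep a constant learning rate after warmup. A hedged usage sketch follows; the warmup_steps, scheduler, and scheduler_kwargs arguments come from this commit, while the tokenizer configuration, dataset path, and remaining trainer arguments are illustrative placeholders:

from torch.optim.lr_scheduler import CosineAnnealingLR

from magvit2_pytorch import VideoTokenizer, VideoTokenizerTrainer

tokenizer = VideoTokenizer(image_size = 128)  # placeholder config; see the repository README for full arguments

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder = '/path/to/videos',       # placeholder path
    batch_size = 4,                           # assumed trainer argument
    num_train_steps = 100_000,                # assumed trainer argument
    warmup_steps = 1000,                      # new in this commit
    scheduler = CosineAnnealingLR,            # new: a class, instantiated per optimizer
    scheduler_kwargs = dict(T_max = 100_000)  # new: forwarded to the scheduler
)

trainer.train()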
2 changes: 1 addition & 1 deletion magvit2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.1.45'
+__version__ = '0.2.0'
1 change: 1 addition & 0 deletions setup.py
@@ -24,6 +24,7 @@
     'beartype',
     'einops>=0.7.0',
     'ema-pytorch>=0.2.4',
+    'pytorch-warmup',
    'gateloop-transformer>=0.0.25',
     'kornia',
     'opencv-python',
