Skip to content

Commit 3abf4bc

Browse files
authored
EMA model stepping updated to keep track of current step (open-mmlab#64)
ema model stepping done automatically now
1 parent 94566e6 commit 3abf4bc

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

examples/train_unconditional.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def transforms(examples):
130130
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
131131
optimizer.step()
132132
lr_scheduler.step()
133-
ema_model.step(model, global_step)
133+
ema_model.step(model)
134134
optimizer.zero_grad()
135135
progress_bar.update(1)
136136
progress_bar.set_postfix(

src/diffusers/training_utils.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def __init__(
4343
self.averaged_model = self.averaged_model.to(device=device)
4444

4545
self.decay = 0.0
46+
self.optimization_step = 0
4647

4748
def get_decay(self, optimization_step):
4849
"""
@@ -57,11 +58,11 @@ def get_decay(self, optimization_step):
5758
return max(self.min_value, min(value, self.max_value))
5859

5960
@torch.no_grad()
60-
def step(self, new_model, optimization_step):
61+
def step(self, new_model):
6162
ema_state_dict = {}
6263
ema_params = self.averaged_model.state_dict()
6364

64-
self.decay = self.get_decay(optimization_step)
65+
self.decay = self.get_decay(self.optimization_step)
6566

6667
for key, param in new_model.named_parameters():
6768
if isinstance(param, dict):
@@ -85,3 +86,4 @@ def step(self, new_model, optimization_step):
8586
ema_state_dict[key] = param
8687

8788
self.averaged_model.load_state_dict(ema_state_dict, strict=False)
89+
self.optimization_step += 1

0 commit comments

Comments (0)