Skip to content

Commit

Permalink
fix: LR logging in CustomTensorBoard with PCGrad (jpata#137)
Browse files Browse the repository at this point in the history
The learning rate was no longer being logged since switching
to using PCGrad as default optimizer. This commit fixes that.
  • Loading branch information
erwulff authored Sep 2, 2022
1 parent 3b972dc commit ec3d0e3
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions mlpf/tfmodel/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,37 @@ def __init__(self, *args, **kwargs):

def _collect_learning_rate(self, logs):
logs = logs or {}

# Get the optimizer. This is necessary since when using e.g. PCGrad,
# the base optimizer is in self.model.optimizer.optimizer.optimizer
# instead of self.model.optimizer
opt = None
if hasattr(self.model.optimizer, "lr"):
opt = self.model.optimizer
elif hasattr(self.model.optimizer.optimizer, "optimizer"):
opt = self.model.optimizer.optimizer.optimizer
assert opt is not None

if hasattr(opt, "lr"):

lr_schedule = getattr(self.model.optimizer, "lr", None)
lr_schedule = getattr(opt, "lr", None)
if isinstance(lr_schedule, tf.keras.optimizers.schedules.LearningRateSchedule):
logs["learning_rate"] = np.float64(tf.keras.backend.get_value(lr_schedule(self.model.optimizer.iterations)))
logs["learning_rate"] = np.float64(tf.keras.backend.get_value(lr_schedule(opt.iterations)))
else:
logs.update({"learning_rate": np.float64(tf.keras.backend.eval(self.model.optimizer.lr))})
logs.update({"learning_rate": np.float64(tf.keras.backend.eval(opt.lr))})

# Log momentum if the optimizer has it
try:
logs.update({"momentum": np.float64(tf.keras.backend.eval(self.model.optimizer.momentum))})
logs.update({"momentum": np.float64(tf.keras.backend.eval(opt.momentum))})
except AttributeError:
pass

# In Adam, the momentum parameter is called beta_1
if isinstance(self.model.optimizer, tf.keras.optimizers.Adam):
logs.update({"adam_beta_1": np.float64(tf.keras.backend.eval(self.model.optimizer.beta_1))})
if isinstance(opt, tf.keras.optimizers.Adam):
logs.update({"adam_beta_1": np.float64(tf.keras.backend.eval(opt.beta_1))})

if hasattr(self.model.optimizer, "loss_scale"):
logs.update({"loss_scale": np.float64(self.model.optimizer.loss_scale.numpy())})
if hasattr(opt, "loss_scale"):
logs.update({"loss_scale": np.float64(opt.loss_scale.numpy())})

return logs

Expand Down

0 comments on commit ec3d0e3

Please sign in to comment.