def _collect_learning_rate(self, logs):
    """Add the optimizer's current hyperparameters to ``logs`` and return it.

    The following keys are written when the corresponding value is available
    on the resolved optimizer: ``learning_rate``, ``momentum``,
    ``adam_beta_1`` (Adam only) and ``loss_scale``.

    Args:
        logs: Dictionary of metrics to extend, or ``None``/empty, in which
            case a new dictionary is created.

    Returns:
        The dictionary with the hyperparameter entries added.

    Raises:
        AttributeError: If ``self.model.optimizer`` has neither an ``lr``
            attribute nor a wrapped ``optimizer`` attribute to read it from.
    """
    logs = logs or {}

    # Resolve the optimizer that carries the hyperparameters. When a wrapper
    # is used (e.g. with PCGrad the base optimizer sits in
    # self.model.optimizer.optimizer.optimizer), the top-level object has no
    # "lr", so follow the ".optimizer" chain down to the innermost object
    # instead of assuming a fixed nesting depth.
    opt = self.model.optimizer
    if not hasattr(opt, "lr"):
        if not hasattr(opt, "optimizer"):
            # Explicit error instead of `assert`, which is removed under -O.
            raise AttributeError(
                f"optimizer of type {type(opt).__name__!r} has neither an 'lr' "
                "attribute nor a wrapped 'optimizer' attribute"
            )
        # Track visited objects so a self-referencing wrapper cannot loop forever.
        visited = {id(opt)}
        while hasattr(opt, "optimizer") and id(opt.optimizer) not in visited:
            opt = opt.optimizer
            visited.add(id(opt))

    if hasattr(opt, "lr"):
        lr = opt.lr
        if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
            # A schedule is a callable of the step; evaluate it at the current iteration.
            logs["learning_rate"] = np.float64(tf.keras.backend.get_value(lr(opt.iterations)))
        else:
            logs["learning_rate"] = np.float64(tf.keras.backend.eval(lr))

    # Log momentum if the optimizer has it. Only the attribute lookup is
    # guarded, so an error raised while evaluating the value is not hidden.
    try:
        momentum = opt.momentum
    except AttributeError:
        pass
    else:
        logs["momentum"] = np.float64(tf.keras.backend.eval(momentum))

    # In Adam, the momentum parameter is called beta_1
    if isinstance(opt, tf.keras.optimizers.Adam):
        logs["adam_beta_1"] = np.float64(tf.keras.backend.eval(opt.beta_1))

    # NOTE(review): loss_scale is only looked up on the resolved optimizer; if
    # a loss-scaling wrapper sits between the top-level and the base optimizer
    # it is not seen here -- confirm whether that case needs logging.
    if hasattr(opt, "loss_scale"):
        logs["loss_scale"] = np.float64(opt.loss_scale.numpy())

    return logs