diff --git a/tensorflow_addons/callbacks/average_model_checkpoint.py b/tensorflow_addons/callbacks/average_model_checkpoint.py index 40870a296e..0ecf70ba2e 100644 --- a/tensorflow_addons/callbacks/average_model_checkpoint.py +++ b/tensorflow_addons/callbacks/average_model_checkpoint.py @@ -82,13 +82,13 @@ def _save_model(self, *args, **kwargs): assert isinstance(optimizer, AveragedOptimizerWrapper) if self.update_weights: - optimizer.assign_average_vars(self.model.variables) + optimizer.assign_average_vars(self.model.trainable_weights) return super()._save_model(*args, **kwargs) else: # Note: `model.get_weights()` gives us the weights (non-ref) # whereas `model.variables` returns references to the variables. non_avg_weights = self.model.get_weights() - optimizer.assign_average_vars(self.model.variables) + optimizer.assign_average_vars(self.model.trainable_weights) # result is currently None, since `super._save_model` doesn't # return anything, but this may change in the future. result = super()._save_model(*args, **kwargs)