From 20eb2849ceaa87e134f398478620e3860937df42 Mon Sep 17 00:00:00 2001 From: who who who Date: Sun, 10 Oct 2021 03:03:24 +0800 Subject: [PATCH] suppress unnecessary warnings (#2333) --- tensorflow_addons/callbacks/average_model_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow_addons/callbacks/average_model_checkpoint.py b/tensorflow_addons/callbacks/average_model_checkpoint.py index 40870a296e..0ecf70ba2e 100644 --- a/tensorflow_addons/callbacks/average_model_checkpoint.py +++ b/tensorflow_addons/callbacks/average_model_checkpoint.py @@ -82,13 +82,13 @@ def _save_model(self, *args, **kwargs): assert isinstance(optimizer, AveragedOptimizerWrapper) if self.update_weights: - optimizer.assign_average_vars(self.model.variables) + optimizer.assign_average_vars(self.model.trainable_weights) return super()._save_model(*args, **kwargs) else: # Note: `model.get_weights()` gives us the weights (non-ref) # whereas `model.variables` returns references to the variables. non_avg_weights = self.model.get_weights() - optimizer.assign_average_vars(self.model.variables) + optimizer.assign_average_vars(self.model.trainable_weights) # result is currently None, since `super._save_model` doesn't # return anything, but this may change in the future. result = super()._save_model(*args, **kwargs)