Skip to content

Commit

Permalink
suppress unnecessary warnings (#2333)
Browse files Browse the repository at this point in the history
  • Loading branch information
fsx950223 authored Oct 9, 2021
1 parent 8e73b9d commit 20eb284
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tensorflow_addons/callbacks/average_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,13 @@ def _save_model(self, *args, **kwargs):
assert isinstance(optimizer, AveragedOptimizerWrapper)

if self.update_weights:
optimizer.assign_average_vars(self.model.variables)
optimizer.assign_average_vars(self.model.trainable_weights)
return super()._save_model(*args, **kwargs)
else:
# Note: `model.get_weights()` gives us the weights (non-ref)
# whereas `model.variables` returns references to the variables.
non_avg_weights = self.model.get_weights()
optimizer.assign_average_vars(self.model.variables)
optimizer.assign_average_vars(self.model.trainable_weights)
# result is currently None, since `super._save_model` doesn't
# return anything, but this may change in the future.
result = super()._save_model(*args, **kwargs)
Expand Down

0 comments on commit 20eb284

Please sign in to comment.