Skip to content

Commit

Permalink
Base optimizer tracking (#2126)
Browse files Browse the repository at this point in the history
* Update lookahead.py

Inital fix of 
#2094
#2102

* Fix linting

* Resolve name conflict with mixed prexision

* Track baseline optimizer in avg
  • Loading branch information
bhack authored Sep 1, 2020
1 parent ae05276 commit 2bf57f8
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 0 deletions.
1 change: 1 addition & 0 deletions tensorflow_addons/optimizers/average_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
raise TypeError("sequential_update must be of bool type")

self._optimizer = optimizer
self._track_trackable(self._optimizer, "awg_optimizer")

if sequential_update is not None:
warnings.warn(
Expand Down
1 change: 1 addition & 0 deletions tensorflow_addons/optimizers/lookahead.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(
self._set_hyper("sync_period", sync_period)
self._set_hyper("slow_step_size", slow_step_size)
self._initialized = False
self._track_trackable(self._optimizer, "lh_base_optimizer")

def _create_slots(self, var_list):
self._optimizer._create_slots(
Expand Down

0 comments on commit 2bf57f8

Please sign in to comment.