diff --git a/sonnet/src/moving_averages.py b/sonnet/src/moving_averages.py index 4605e68..c0b7158 100644 --- a/sonnet/src/moving_averages.py +++ b/sonnet/src/moving_averages.py @@ -14,7 +14,7 @@ # ============================================================================ """Exponential moving average for Sonnet.""" -from typing import Optional +from typing import Optional, cast from sonnet.src import metrics from sonnet.src import once @@ -61,8 +61,8 @@ def __init__(self, decay: types.FloatLike, name: Optional[str] = None): self._counter = tf.Variable( 0, trainable=False, dtype=tf.int64, name="counter") - self._hidden = None - self.average = None + self._hidden: tf.Variable = cast(tf.Variable, None) + self.average: tf.Variable = cast(tf.Variable, None) def update(self, value: tf.Tensor): """Applies EMA to the value given.""" @@ -82,8 +82,10 @@ def value(self) -> tf.Tensor: def reset(self): """Resets the EMA.""" self._counter.assign(tf.zeros_like(self._counter)) - self._hidden.assign(tf.zeros_like(self._hidden)) - self.average.assign(tf.zeros_like(self.average)) + if self._hidden is not None: + self._hidden.assign(tf.zeros_like(self._hidden)) + if self.average is not None: + self.average.assign(tf.zeros_like(self.average)) @once.once def initialize(self, value: tf.Tensor):