Add support for Tensor learning rates and gradients with mixed types. #4876

Open · wants to merge 1 commit into base: main
6 changes: 4 additions & 2 deletions RELEASE.md
@@ -26,8 +26,10 @@ and this project adheres to

### Fixed

* A bug where `tff.learning.optimizers.build_adafactor(...)` would update its
step counter twice upon every invocation of `.next()`.
* A bug where `tff.learning.optimizers.build_adafactor` would update its step
counter twice upon every invocation of `.next()`.
* A bug where tensor learning rates for `tff.learning.optimizers.build_sgdm`
would fail with mixed dtype gradients.

### Removed

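The `build_sgdm` fix noted above is exercised by the new tests in this PR. As a minimal usage sketch (assuming the public `tff.learning.optimizers` entry point and eager execution), a Tensor learning rate can now drive updates over weights with mixed dtypes:

```python
import tensorflow as tf
import tensorflow_federated as tff

# Weights with mixed dtypes (float32, float64, bfloat16), as in the new tests.
weights = (
    tf.constant([0.1], dtype=tf.float32),
    tf.constant(1.0, dtype=tf.float64),
    tf.constant([10.0, 10.0], dtype=tf.bfloat16),
)

# A Tensor learning rate; previously this combination failed on a dtype mismatch.
optimizer = tff.learning.optimizers.build_sgdm(
    learning_rate=tf.constant(0.1, dtype=tf.float32)
)

state = optimizer.initialize(weights)
# Zero gradients keep the structure and dtypes aligned with the weights.
state, updated_weights = optimizer.next(
    state, weights, tf.nest.map_structure(tf.zeros_like, weights)
)
```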
27 changes: 14 additions & 13 deletions tensorflow_federated/python/learning/optimizers/adagrad.py
@@ -27,7 +27,7 @@
_HPARAMS_KEYS = [optimizer.LEARNING_RATE_KEY, _EPSILON_KEY]

State = TypeVar('State', bound=collections.OrderedDict[str, Any])
Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, float])
Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, Any])


class _Adagrad(optimizer.Optimizer[State, optimizer.Weights, Hparams]):
@@ -40,31 +40,35 @@ def __init__(
epsilon: optimizer.Float = 1e-7,
):
"""Initializes SGD optimizer."""
if learning_rate < 0.0:
if not tf.is_symbolic_tensor(learning_rate) and learning_rate < 0.0:
raise ValueError(
f'Adagrad `learning_rate` must be nonnegative, found {learning_rate}.'
)
if initial_preconditioner_value < 0.0:
if (
not tf.is_symbolic_tensor(initial_preconditioner_value)
and initial_preconditioner_value < 0.0
):
raise ValueError(
'Adagrad `initial_preconditioner_value` must be nonnegative, found '
f'{initial_preconditioner_value}.'
)
if epsilon < 0.0:
if not tf.is_symbolic_tensor(epsilon) and epsilon < 0.0:
raise ValueError(f'Adagrad epsilon must be nonnegative, found {epsilon}.')
self._lr = learning_rate
self._initial_precond = initial_preconditioner_value
self._epsilon = epsilon

def initialize(self, specs: Any) -> State:
initial_preconditioner = tf.nest.map_structure(
lambda s: tf.ones(s.shape, s.dtype) * self._initial_precond, specs
lambda s: tf.ones(s.shape, s.dtype)
* tf.cast(self._initial_precond, s.dtype),
specs,
)
state = collections.OrderedDict([
return collections.OrderedDict([
(optimizer.LEARNING_RATE_KEY, self._lr),
(_EPSILON_KEY, self._epsilon),
(_PRECONDITIONER_KEY, initial_preconditioner),
])
return state

def next(
self, state: State, weights: optimizer.Weights, gradients: Any
@@ -82,7 +86,9 @@ def _adagrad_update(w, p, g):
if g is None:
return w, p
p = p + tf.math.square(g)
w = w - lr * g / tf.math.sqrt(p + epsilon)
w = w - tf.cast(lr, g.dtype) * g / tf.math.sqrt(
p + tf.cast(epsilon, p.dtype)
)
return w, p

updated_weights, updated_preconditioner = nest_utils.map_at_leaves(
@@ -99,11 +105,6 @@ def get_hparams(self, state: State) -> Hparams:
return collections.OrderedDict([(k, state[k]) for k in _HPARAMS_KEYS])

def set_hparams(self, state: State, hparams: Hparams) -> State:
# TODO: b/245962555 - Find an alternative to `update_struct` if it
# interferes with typing guarantees.
# We use `tff.structure.update_struct` (rather than something like
# `copy.deepcopy`) to ensure that this can be called within a
# `tff.Computation`.
return structure.update_struct(state, **hparams)


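For reference, a standalone sketch of the per-leaf Adagrad update from the diff above, isolating the new casts so that float32 hyperparameters can be applied to, e.g., float64 weights and gradients (names here are illustrative, not part of the TFF API):

```python
import tensorflow as tf

def adagrad_leaf_update(w, p, g, lr, epsilon):
  """One Adagrad step for a single weight/preconditioner/gradient leaf."""
  p = p + tf.math.square(g)
  # Cast the scalar hyperparameters into the dtypes of the leaf tensors.
  w = w - tf.cast(lr, g.dtype) * g / tf.math.sqrt(p + tf.cast(epsilon, p.dtype))
  return w, p

w = tf.constant([1.0, 2.0], dtype=tf.float64)
p = tf.ones_like(w) * 0.1                   # initial preconditioner value
g = tf.constant([0.5, -0.5], dtype=tf.float64)
lr = tf.constant(0.01, dtype=tf.float32)    # float32 hyperparameters applied
eps = tf.constant(1e-7, dtype=tf.float32)   # to float64 weights/gradients
print(adagrad_leaf_update(w, p, g, lr, eps))
```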
20 changes: 18 additions & 2 deletions tensorflow_federated/python/learning/optimizers/adagrad_test.py
@@ -145,8 +145,8 @@ def random_vector():
genarator.normal(shape=s.shape, dtype=s.dtype) for s in weight_spec
]

intial_weight = random_vector()
model_variables_fn = lambda: [tf.Variable(v) for v in intial_weight]
initial_weight = random_vector()
model_variables_fn = lambda: [tf.Variable(v) for v in initial_weight]
gradients = [random_vector() for _ in range(steps)]
tff_optimizer_fn = lambda: adagrad.build_adagrad(0.01)
keras_optimizer_fn = lambda: tf.keras.optimizers.Adagrad(0.01)
@@ -227,6 +227,22 @@ def test_set_get_hparams_is_no_op(self, spec):
updated_state = optimizer.set_hparams(state, hparams)
self.assertEqual(state, updated_state)

def test_lr_with_different_weight_dtypes(self):
weights = (
tf.constant([0.1], dtype=tf.float32),
tf.constant(1.0, dtype=tf.float64),
tf.constant([10.0, 10.0], dtype=tf.bfloat16),
)
adagrad_optimizer = adagrad.build_adagrad(
learning_rate=tf.constant(0.1, dtype=tf.float32),
initial_preconditioner_value=tf.constant(0.1, dtype=tf.float32),
epsilon=tf.constant(0.1, dtype=tf.float64),
)
state = adagrad_optimizer.initialize(weights)
adagrad_optimizer.next(
state, weights, tf.nest.map_structure(tf.zeros_like, weights)
)


if __name__ == '__main__':
tf.test.main()
34 changes: 17 additions & 17 deletions tensorflow_federated/python/learning/optimizers/adam.py
@@ -36,7 +36,7 @@
]

State = TypeVar('State', bound=collections.OrderedDict[str, Any])
Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, float])
Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, Any])


class _Adam(optimizer.Optimizer[State, optimizer.Weights, Hparams]):
@@ -50,19 +50,19 @@ def __init__(
epsilon: optimizer.Float = 1e-7,
):
"""Initializes Adam optimizer."""
if learning_rate < 0.0:
if not tf.is_symbolic_tensor(learning_rate) and learning_rate < 0.0:
raise ValueError(
f'Adam `learning_rate` must be nonnegative, found {learning_rate}.'
)
if beta_1 < 0.0 or beta_1 > 1.0:
if not tf.is_symbolic_tensor(beta_1) and (beta_1 < 0.0 or beta_1 > 1.0):
raise ValueError(
f'Adam `beta_1` must be in the range [0.0, 1.0], found {beta_1}.'
)
if beta_2 < 0.0 or beta_2 > 1.0:
if not tf.is_symbolic_tensor(beta_2) and (beta_2 < 0.0 or beta_2 > 1.0):
raise ValueError(
f'Adam `beta_2` must be in the range [0.0, 1.0], found {beta_2}.'
)
if epsilon < 0.0:
if not tf.is_symbolic_tensor(epsilon) and epsilon < 0.0:
raise ValueError(f'Adam `epsilon` must be nonnegative, found {epsilon}.')
self._lr = learning_rate
self._beta_1 = beta_1
@@ -76,7 +76,7 @@ def initialize(self, specs: Any) -> State:
initial_preconditioner = tf.nest.map_structure(
lambda s: tf.zeros(s.shape, s.dtype), specs
)
state = collections.OrderedDict([
return collections.OrderedDict([
(optimizer.LEARNING_RATE_KEY, self._lr),
(_BETA_1_KEY, self._beta_1),
(_BETA_2_KEY, self._beta_2),
@@ -85,7 +85,6 @@ def initialize(self, specs: Any) -> State:
(_ACCUMULATOR_KEY, initial_accumulator),
(_PRECONDITIONER_KEY, initial_preconditioner),
])
return state

def next(
self, state: State, weights: optimizer.Weights, gradients: Any
@@ -103,18 +102,24 @@ def next(
optimizer.check_weights_state_match(
weights, preconditioner, 'preconditioner'
)
if tf.is_tensor(beta_1):
casted_step = tf.cast(step, beta_1.dtype)
else:
casted_step = step
normalized_lr = (
lr
* tf.math.sqrt((1 - tf.math.pow(beta_2, tf.cast(step, tf.float32))))
/ (1 - tf.math.pow(beta_1, tf.cast(step, tf.float32)))
* tf.math.sqrt((1.0 - tf.math.pow(beta_2, casted_step)))
/ (1.0 - tf.math.pow(beta_1, casted_step))
)

def _adam_update(w, a, p, g):
if g is None:
return w, a, p
a = a + (g - a) * (1 - beta_1)
p = p + (tf.math.square(g) - p) * (1 - beta_2)
w = w - normalized_lr * a / (tf.math.sqrt(p) + epsilon)
a = a + (g - a) * (1 - tf.cast(beta_1, a.dtype))
p = p + (tf.math.square(g) - p) * (1 - tf.cast(beta_2, p.dtype))
w = w - tf.cast(normalized_lr, a.dtype) * a / (
tf.math.sqrt(p) + tf.cast(epsilon, p.dtype)
)
return w, a, p

updated_weights, updated_accumulator, updated_preconditioner = (
@@ -142,11 +147,6 @@ def get_hparams(self, state: State) -> Hparams:
return collections.OrderedDict([(k, state[k]) for k in _HPARAMS_KEYS])

def set_hparams(self, state: State, hparams: Hparams) -> State:
# TODO: b/245962555 - Find an alternative to `update_struct` if it
# interferes with typing guarantees.
# We use `tff.structure.update_struct` (rather than something like
# `copy.deepcopy`) to ensure that this can be called within a
# `tff.Computation`.
return structure.update_struct(state, **hparams)


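The casted-step logic above can be read in isolation. A small sketch (illustrative variable names, Tensor hyperparameters assumed) of how the bias-corrected step size is computed once the integer step counter is cast into the dtype of `beta_1`/`beta_2`:

```python
import tensorflow as tf

step = tf.constant(3, dtype=tf.int32)        # optimizer step counter
lr = tf.constant(1e-3, dtype=tf.float32)
beta_1 = tf.constant(0.9, dtype=tf.float32)
beta_2 = tf.constant(0.999, dtype=tf.float32)

# Cast the integer step to the betas' dtype before tf.math.pow.
casted_step = tf.cast(step, beta_1.dtype)
normalized_lr = (
    lr
    * tf.math.sqrt(1.0 - tf.math.pow(beta_2, casted_step))
    / (1.0 - tf.math.pow(beta_1, casted_step))
)

# The per-leaf update then casts `normalized_lr`, `beta_1`, `beta_2`, and
# `epsilon` into each leaf's dtype, as in `_adam_update` above.
```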
25 changes: 20 additions & 5 deletions tensorflow_federated/python/learning/optimizers/adam_test.py
@@ -55,9 +55,7 @@ def test_math(self):
for _ in range(4):
state, weights = optimizer.next(state, weights, gradients)
history.append(weights)
self.assertAllClose(
[[1.0], [0.9000007], [0.8000017], [0.700002], [0.600003]], history
)
self.assertAllClose([[1.0], [0.9], [0.8], [0.7], [0.6]], history)

@parameterized.named_parameters(
('scalar_spec', _SCALAR_SPEC),
@@ -142,8 +140,8 @@ def random_vector():
genarator.normal(shape=s.shape, dtype=s.dtype) for s in weight_spec
]

intial_weight = random_vector()
model_variables_fn = lambda: [tf.Variable(v) for v in intial_weight]
initial_weight = random_vector()
model_variables_fn = lambda: [tf.Variable(v) for v in initial_weight]
gradients = [random_vector() for _ in range(steps)]
tff_optimizer_fn = lambda: adam.build_adam(0.01, 0.9, 0.999)
keras_optimizer_fn = lambda: tf.keras.optimizers.Adam(0.01, 0.9, 0.999)
@@ -225,6 +223,23 @@ def test_set_get_hparams_is_no_op(self, spec):
updated_state = optimizer.set_hparams(state, hparams)
self.assertEqual(state, updated_state)

def test_lr_with_different_weight_dtypes(self):
weights = (
tf.constant([0.1], dtype=tf.float32),
tf.constant(1.0, dtype=tf.float64),
tf.constant([10.0, 10.0], dtype=tf.bfloat16),
)
adam_optimizer = adam.build_adam(
learning_rate=tf.constant(0.1, dtype=tf.float32),
beta_1=tf.constant(0.1, dtype=tf.float32),
beta_2=tf.constant(0.1, dtype=tf.float32),
epsilon=tf.constant(0.1, dtype=tf.float64),
)
state = adam_optimizer.initialize(weights)
adam_optimizer.next(
state, weights, tf.nest.map_structure(tf.zeros_like, weights)
)


if __name__ == '__main__':
tf.test.main()
15 changes: 6 additions & 9 deletions tensorflow_federated/python/learning/optimizers/sgdm.py
@@ -26,7 +26,7 @@
_ACCUMULATOR_KEY = 'accumulator'

State = TypeVar('State', bound=collections.OrderedDict[str, Any])
Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, float])
Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, Any])


class _SGD(optimizer.Optimizer[State, optimizer.Weights, Hparams]):
@@ -38,14 +38,16 @@ def __init__(
momentum: Optional[optimizer.Float] = None,
):
"""Initializes SGD optimizer."""
if learning_rate < 0.0:
if not tf.is_symbolic_tensor(learning_rate) and learning_rate < 0.0:
raise ValueError(
f'SGD `learning_rate` must be nonnegative, found {learning_rate}.'
)
if momentum:
# We should only track momentum as a hparam in the case that it is both
# specified and nonzero.
if momentum < 0.0 or momentum > 1.0:
if not tf.is_symbolic_tensor(momentum) and (
momentum < 0.0 or momentum > 1.0
):
raise ValueError(
'SGD `momentum` must be `None` or in the range [0, 1], found '
f'{momentum}.'
@@ -77,7 +79,7 @@ def next(
def _sgd_update(w, g):
if g is None:
return w
return w - lr * g
return w - tf.cast(lr, dtype=g.dtype) * g

updated_weights = nest_utils.map_at_leaves(
_sgd_update, weights, gradients
@@ -111,11 +113,6 @@ def get_hparams(self, state: State) -> Hparams:
return collections.OrderedDict([(k, state[k]) for k in self._hparams_keys])

def set_hparams(self, state: State, hparams: Hparams) -> State:
# TODO: b/245962555 - Find an alternative to `update_struct` if it
# interferes with typing guarantees.
# We use `tff.structure.update_struct` (rather than something like
# `copy.deepcopy`) to ensure that this can be called within a
# `tff.Computation`.
return structure.update_struct(state, **hparams)


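The hyperparameter validation added throughout this PR follows one pattern: range checks run only for concrete (non-symbolic) values and are skipped when the value is a symbolic tensor, whose magnitude is unknown at tracing time. A minimal sketch of that pattern (the helper name is illustrative, not part of the TFF API):

```python
import tensorflow as tf

def check_nonnegative(value, name):
  # Skip the check for symbolic tensors; their values are unknown when tracing.
  if not tf.is_symbolic_tensor(value) and value < 0.0:
    raise ValueError(f'`{name}` must be nonnegative, found {value}.')

check_nonnegative(0.1, 'learning_rate')               # concrete float: checked
check_nonnegative(tf.constant(0.1), 'learning_rate')  # eager tensor: checked

@tf.function
def trace_with_symbolic_lr(lr):
  check_nonnegative(lr, 'learning_rate')  # symbolic tensor: check is skipped
  return lr

trace_with_symbolic_lr(tf.constant(0.1))
```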
20 changes: 17 additions & 3 deletions tensorflow_federated/python/learning/optimizers/sgdm_test.py
@@ -43,7 +43,7 @@ def test_get_hparams_momentum(self, momentum_value):
optimizer = sgdm.build_sgdm(0.01, momentum=momentum_value)
state = optimizer.initialize(_SCALAR_SPEC)
hparams = optimizer.get_hparams(state)
# Whether we specify None momentum or momentum 0.0, we shouldnt track the
# Whether we specify None momentum or momentum 0.0, we shouldn't track the
# extra accumulator state. The implementation of next checks for the
# presence or absence of momentum key--it should not be there in either
# case.
@@ -177,8 +177,8 @@ def random_vector():
genarator.normal(shape=s.shape, dtype=s.dtype) for s in weight_spec
]

intial_weight = random_vector()
model_variables_fn = lambda: [tf.Variable(v) for v in intial_weight]
initial_weight = random_vector()
model_variables_fn = lambda: [tf.Variable(v) for v in initial_weight]
gradients = [random_vector() for _ in range(steps)]
tff_optimizer_fn = lambda: sgdm.build_sgdm(learning_rate, momentum)

@@ -306,6 +306,20 @@ def test_set_get_hparams_is_no_op_with_momentum(self, spec):
updated_state = optimizer.set_hparams(state, hparams)
self.assertEqual(state, updated_state)

def test_lr_with_different_weight_dtypes(self):
weights = (
tf.constant([0.1], dtype=tf.float32),
tf.constant(1.0, dtype=tf.float64),
tf.constant([10.0, 10.0], dtype=tf.bfloat16),
)
sgdm_optimizer = sgdm.build_sgdm(
learning_rate=tf.constant(0.1, dtype=tf.float32)
)
state = sgdm_optimizer.initialize(weights)
sgdm_optimizer.next(
state, weights, tf.nest.map_structure(tf.zeros_like, weights)
)


if __name__ == '__main__':
tf.test.main()