diff --git a/keras/optimizers/optimizer_experimental/BUILD b/keras/optimizers/optimizer_experimental/BUILD
index 9de98511f2a..1efc5af611a 100644
--- a/keras/optimizers/optimizer_experimental/BUILD
+++ b/keras/optimizers/optimizer_experimental/BUILD
@@ -15,6 +15,7 @@ py_library(
         "adadelta.py",
         "adagrad.py",
         "adam.py",
+        "adamw.py",
         "optimizer.py",
         "rmsprop.py",
         "sgd.py",
diff --git a/keras/optimizers/optimizer_experimental/adamw.py b/keras/optimizers/optimizer_experimental/adamw.py
new file mode 100644
index 00000000000..c6ed7408693
--- /dev/null
+++ b/keras/optimizers/optimizer_experimental/adamw.py
@@ -0,0 +1,229 @@
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""AdamW optimizer implementation."""
+
+from keras.optimizers.optimizer_experimental import optimizer
+from keras.utils import generic_utils
+import tensorflow.compat.v2 as tf
+# pylint: disable=g-direct-tensorflow-import
+from tensorflow.python.util.tf_export import keras_export
+
+
+@generic_utils.register_keras_serializable()
+@keras_export('keras.optimizers.experimental.AdamW', v1=[])
+class AdamW(optimizer.Optimizer):
+  r"""Optimizer that implements the AdamW algorithm.
+
+  AdamW optimization is a stochastic gradient descent method that is based on
+  adaptive estimation of first-order and second-order moments with an added
+  method to decay weights per the techniques discussed in the paper,
+  'Decoupled Weight Decay Regularization' by
+  [Loshchilov & Hutter, 2019](https://arxiv.org/abs/1711.05101).
+
+  According to
+  [Kingma et al., 2014](http://arxiv.org/abs/1412.6980),
+  the underlying Adam method is "*computationally
+  efficient, has little memory requirement, invariant to diagonal rescaling of
+  gradients, and is well suited for problems that are large in terms of
+  data/parameters*".
+
+  Attributes:
+    learning_rate: A `tf.Tensor`, floating point value, a schedule that is a
+      `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
+      that takes no arguments and returns the actual value to use. The
+      learning rate. Defaults to 0.001.
+    weight_decay: A `tf.Tensor`, floating point value. The weight decay.
+      Defaults to 0.004.
+    beta_1: A float value or a constant float tensor, or a callable
+      that takes no arguments and returns the actual value to use. The
+      exponential decay rate for the 1st moment estimates. Defaults to 0.9.
+    beta_2: A float value or a constant float tensor, or a callable
+      that takes no arguments and returns the actual value to use. The
+      exponential decay rate for the 2nd moment estimates. Defaults to 0.999.
+    epsilon: A small constant for numerical stability. This epsilon is
+      "epsilon hat" in the Kingma and Ba paper (in the formula just before
+      Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults to
+      1e-7.
+    amsgrad: Boolean. Whether to apply AMSGrad variant of this algorithm from
+      the paper "On the Convergence of Adam and beyond". Defaults to `False`.
+    clipnorm: see the `clipnorm` argument of `optimizer_experimental.Optimizer`.
+    clipvalue: see the `clipvalue` argument of
+      `optimizer_experimental.Optimizer`.
+    global_clipnorm: see the `global_clipnorm` argument of
+      `optimizer_experimental.Optimizer`.
+    use_ema: see the `use_ema` argument of `optimizer_experimental.Optimizer`.
+    ema_momentum: see the `ema_momentum` argument of
+      `optimizer_experimental.Optimizer`.
+    ema_overwrite_frequency: see the `ema_overwrite_frequency` argument of
+      `optimizer_experimental.Optimizer`.
+    jit_compile: see the `jit_compile` argument of
+      `optimizer_experimental.Optimizer`.
+    name: Optional name prefix for the operations created when applying
+      gradients. Defaults to `"AdamW"`.
+    **kwargs: see the `**kwargs` argument of `optimizer_experimental.Optimizer`.
+
+  Reference:
+    - [Loshchilov & Hutter, 2019](https://arxiv.org/abs/1711.05101)
+    - [Kingma et al., 2014](http://arxiv.org/abs/1412.6980) for `adam`
+    - [Reddi et al., 2018](
+        https://openreview.net/pdf?id=ryQu7f-RZ) for `amsgrad`.
+
+  Notes:
+
+  The default value of 1e-7 for epsilon might not be a good default in
+  general. For example, when training an Inception network on ImageNet a
+  current good choice is 1.0 or 0.1. Note that since Adam uses the
+  formulation just before Section 2.1 of the Kingma and Ba paper rather than
+  the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
+  hat" in the paper.
+
+  The sparse implementation of this algorithm (used when the gradient is an
+  IndexedSlices object, typically because of `tf.gather` or an embedding
+  lookup in the forward pass) does apply momentum to variable slices even if
+  they were not used in the forward pass (meaning they have a gradient equal
+  to zero). Momentum decay (beta1) is also applied to the entire momentum
+  accumulator. This means that the sparse behavior is equivalent to the dense
+  behavior (in contrast to some momentum implementations which ignore momentum
+  unless a variable slice was actually used).
+  """
+
+  def __init__(self,
+               learning_rate=0.001,
+               weight_decay=0.004,
+               beta_1=0.9,
+               beta_2=0.999,
+               epsilon=1e-7,
+               amsgrad=False,
+               clipnorm=None,
+               clipvalue=None,
+               global_clipnorm=None,
+               use_ema=False,
+               ema_momentum=0.99,
+               ema_overwrite_frequency=None,
+               jit_compile=False,
+               name='AdamW',
+               **kwargs):
+    super(AdamW, self).__init__(
+        name=name,
+        clipnorm=clipnorm,
+        clipvalue=clipvalue,
+        global_clipnorm=global_clipnorm,
+        use_ema=use_ema,
+        ema_momentum=ema_momentum,
+        ema_overwrite_frequency=ema_overwrite_frequency,
+        jit_compile=jit_compile,
+        **kwargs)
+    self._learning_rate = self._build_learning_rate(learning_rate)
+    self.weight_decay = weight_decay
+    self.beta_1 = beta_1
+    self.beta_2 = beta_2
+    self.epsilon = epsilon
+    self.amsgrad = amsgrad
+
+    if self.weight_decay is None:
+      raise ValueError('Missing value of `weight_decay` which is required and'
+                       ' must be a float value.')
+
+  def build(self, var_list):
+    """Initialize optimizer variables.
+
+    AdamW optimizer has 3 types of variables: momentums, velocities and
+    velocity_hat (only set when amsgrad is applied).
+
+    Args:
+      var_list: list of model variables to build AdamW variables on.
+    """
+    super().build(var_list)
+    if hasattr(self, '_built') and self._built:
+      return
+    self._built = True
+    self._momentums = []
+    self._velocities = []
+    for var in var_list:
+      self._momentums.append(
+          self.add_variable_from_reference(
+              model_variable=var, variable_name='m'))
+      self._velocities.append(
+          self.add_variable_from_reference(
+              model_variable=var, variable_name='v'))
+    if self.amsgrad:
+      self._velocity_hats = []
+      for var in var_list:
+        self._velocity_hats.append(
+            self.add_variable_from_reference(
+                model_variable=var, variable_name='vhat'))
+
+  def update_step(self, gradient, variable):
+    """Update step given gradient and the associated model variable."""
+    if self._var_key(variable) not in self._index_dict:
+      raise KeyError(f'Optimizer cannot recognize variable {variable.name}, '
+                     f'this usually means you are calling an optimizer '
+                     f'previously used on a different model. Please try '
+                     f'creating a new optimizer instance.')
+    beta_1_power = None
+    beta_2_power = None
+    lr = tf.cast(self.learning_rate, variable.dtype)
+    local_step = tf.cast(self.iterations + 1, variable.dtype)
+    beta_1_power = tf.pow(tf.cast(self.beta_1, variable.dtype), local_step)
+    beta_2_power = tf.pow(tf.cast(self.beta_2, variable.dtype), local_step)
+
+    var_key = self._var_key(variable)
+    m = self._momentums[self._index_dict[var_key]]
+    v = self._velocities[self._index_dict[var_key]]
+
+    alpha = (lr * tf.sqrt(1 - beta_2_power) / (1 - beta_1_power))
+
+    # Apply decoupled weight decay: var -= lr * wd * var.
+    if self.weight_decay != 0:
+      wd = tf.cast(self.weight_decay, variable.dtype)
+      variable.assign_sub(variable * wd * lr)
+
+    if isinstance(gradient, tf.IndexedSlices):
+      # Sparse gradients.
+      m.assign_add(-m * (1 - self.beta_1))
+      m.scatter_add(
+          tf.IndexedSlices(gradient.values * (1 - self.beta_1),
+                           gradient.indices))
+      v.assign_add(-v * (1 - self.beta_2))
+      v.scatter_add(
+          tf.IndexedSlices(
+              tf.square(gradient.values) * (1 - self.beta_2), gradient.indices))
+      if self.amsgrad:
+        v_hat = self._velocity_hats[self._index_dict[var_key]]
+        v_hat.assign(tf.maximum(v_hat, v))
+        v = v_hat
+      variable.assign_sub((m * alpha) / (tf.sqrt(v) + self.epsilon))
+    else:
+      # Dense gradients.
+      m.assign_add((gradient - m) * (1 - self.beta_1))
+      v.assign_add((tf.square(gradient) - v) * (1 - self.beta_2))
+      if self.amsgrad:
+        v_hat = self._velocity_hats[self._index_dict[var_key]]
+        v_hat.assign(tf.maximum(v_hat, v))
+        v = v_hat
+      variable.assign_sub((m * alpha) / (tf.sqrt(v) + self.epsilon))
+
+  def get_config(self):
+    config = super(AdamW, self).get_config()
+
+    config.update({
+        'learning_rate': self._serialize_hyperparameter(self._learning_rate),
+        'weight_decay': self.weight_decay,
+        'beta_1': self.beta_1,
+        'beta_2': self.beta_2,
+        'epsilon': self.epsilon,
+        'amsgrad': self.amsgrad,
+    })
+    return config
diff --git a/keras/optimizers/optimizer_experimental/optimizer_test.py b/keras/optimizers/optimizer_experimental/optimizer_test.py
index ee1dd3b5394..4a18e4a6111 100644
--- a/keras/optimizers/optimizer_experimental/optimizer_test.py
+++ b/keras/optimizers/optimizer_experimental/optimizer_test.py
@@ -13,6 +13,7 @@
 from keras.optimizers.optimizer_experimental import adadelta as adadelta_new
 from keras.optimizers.optimizer_experimental import adagrad as adagrad_new
 from keras.optimizers.optimizer_experimental import adam as adam_new
+from keras.optimizers.optimizer_experimental import adamw as adamw_new
 from keras.optimizers.optimizer_experimental import rmsprop as rmsprop_new
 from keras.optimizers.optimizer_experimental import sgd as sgd_new
 from keras.optimizers.optimizer_v2 import adadelta as adadelta_old
@@ -41,22 +42,29 @@
 adadelta_new_fn = tf.__internal__.test.combinations.NamedObject(
     "experimentaladadelta",
     lambda: adadelta_new.Adadelta(  # pylint: disable=g-long-lambda
-        0.002, use_ema=True, ema_overwrite_frequency=None))
+        0.002,
+        use_ema=True,
+        ema_overwrite_frequency=None))
 adagrad_new_fn = tf.__internal__.test.combinations.NamedObject(
     "experimentaladagrad", lambda: adagrad_new.Adagrad(0.002))
 adam_new_fn = tf.__internal__.test.combinations.NamedObject(
     "experimentaladam", lambda: adam_new.Adam(0.002))
+adamw_new_fn = tf.__internal__.test.combinations.NamedObject(
+    "experimentaladamw", lambda: adamw_new.AdamW(0.002, weight_decay=0.004))
 rmsprop_new_fn = tf.__internal__.test.combinations.NamedObject(
     "experimentalrmsprop", lambda: rmsprop_new.RMSprop(0.002))
 sgd_new_fn = tf.__internal__.test.combinations.NamedObject(
     "experimentalsgdaverage",
     lambda: sgd_new.SGD(  # pylint: disable=g-long-lambda
-        0.002, use_ema=True, ema_overwrite_frequency=1))
+        0.002,
+        use_ema=True,
+        ema_overwrite_frequency=1))
 
 OPTIMIZER_FN = [
     adadelta_new_fn,
     adagrad_new_fn,
     adam_new_fn,
+    adamw_new_fn,
     rmsprop_new_fn,
     sgd_new_fn,
 ]
@@ -163,6 +171,13 @@ def testSetIterations(self):
     with self.assertRaisesRegex(RuntimeError, "Cannot set*"):
       optimizer.iterations = 2
 
+  def testPassingMissingWDError(self):
+    with self.assertRaises(ValueError):
+      _ = adamw_new.AdamW(0.01, weight_decay=None)
+
+    with self.assertRaisesRegex(ValueError, "Missing value of"):
+      _ = adamw_new.AdamW(0.01, weight_decay=None)
+
   def testMovingAverageOptimizer(self):
     optimizer = sgd_new.SGD(
         learning_rate=1,
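
For reviewers, a minimal usage sketch of the optimizer added above. It assumes a TensorFlow build that includes this change (the class is exported as `tf.keras.optimizers.experimental.AdamW` by the `keras_export` decorator in the diff); the toy model and random data are illustrative only and not part of the change.

import tensorflow as tf

# Illustrative data: 256 samples, 20 features, 10 classes.
x = tf.random.normal((256, 20))
y = tf.random.uniform((256,), maxval=10, dtype=tf.int64)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(10),
])

# Decoupled weight decay: each update_step first applies
# `var -= learning_rate * weight_decay * var`, then the usual Adam update.
optimizer = tf.keras.optimizers.experimental.AdamW(
    learning_rate=0.001, weight_decay=0.004)

model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"])
model.fit(x, y, batch_size=32, epochs=1)

Unlike attaching an L2 regularizer to individual layers, the decay here is applied directly in the optimizer step, so it is decoupled from the loss and from the adaptive gradient scaling, which is the behavior described in Loshchilov & Hutter (2019).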