Base optimizer tracking #2126

Merged: 4 commits merged into master on Sep 1, 2020
Conversation

@bhack (Contributor) commented Aug 27, 2020

Description

Brief Description of the PR:
The inner (base) optimizer is not tracked for checkpointing.

Fixes #2094
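
For illustration, a minimal sketch (not the exact diff in this PR) of how a wrapper optimizer can register its inner optimizer as a checkpoint dependency; the class and dependency names below are made up:

import tensorflow as tf


class WrapperOptimizer(tf.keras.optimizers.Optimizer):
    """Sketch of a wrapper that tracks its inner (base) optimizer."""

    def __init__(self, optimizer, name="WrapperOptimizer", **kwargs):
        super().__init__(name, **kwargs)
        if isinstance(optimizer, str):
            optimizer = tf.keras.optimizers.get(optimizer)
        self._optimizer = optimizer
        # Registering the inner optimizer as a trackable dependency lets
        # tf.train.Checkpoint / save_weights include its variables.
        self._track_trackable(optimizer, "base_optimizer", overwrite=True)

    # _create_slots, apply_gradients, get_config, ... omitted in this sketch.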

Type of change

Checklist:

  • [x] I've properly formatted my code according to the guidelines
    • [x] By running Black + Flake8
    • [x] By running pre-commit hooks
  • This PR addresses an already submitted issue for TensorFlow Addons
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • This PR contains modifications to C++ custom-ops

How Has This Been Tested?

Tested in #2102

If you're adding a bugfix or new feature please describe the tests that you ran to verify your changes:
It still needs to pass the save and load_model case.

@bhack bhack linked an issue Aug 27, 2020 that may be closed by this pull request
@bot-of-gabrieldemarmiesse

@CyberZHG

You are an owner of some files modified in this pull request.
Would you kindly review the changes whenever you have the time?
Thank you very much.

@bhack changed the title from "Lookahead base optmizer tracking" to "Lookahead base optimizer tracking" Aug 27, 2020
@bhack (Contributor Author) commented Aug 28, 2020

/cc @reedwm Can you give us some feedback here? What do we need to correctly support Keras save and load_model?
Do we need to introduce something like your _DelegatingTrackableMixin in https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py#L58?

I see that in that case in TensorFlow you are only testing checkpointing, not save and load_model.

See the user-proposed tests in #2102.

@reedwm (Member) commented Aug 28, 2020

The purpose of _DelegatingTrackableMixin is to make the checkpoint format of the LossScaleOptimizer the same as the checkpoint format of the inner optimizer. This allows you to, e.g., save a checkpoint of a float32 model with an Adam optimizer, then load the checkpoint into the equivalent mixed precision model with a LossScaleOptimizer wrapping an Adam optimizer.

I'm not very familiar with the Lookahead optimizer, but it seems this use case doesn't apply: I don't think people would want to train partially with Lookahead and partially without, and this would be hard anyway since Lookahead has an extra set of "slow weights". If I'm correct, the current approach in this PR is fine.

Note we test a SavedModel with LossScaleOptimizer here.
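
For concreteness, a minimal sketch of that interchange, assuming the experimental mixed-precision API of the time; the paths are illustrative:

import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as mixed_precision

# Save a checkpoint taken with a plain Adam optimizer.
adam = tf.keras.optimizers.Adam()
tf.train.Checkpoint(optimizer=adam).write("/tmp/adam_ckpt")

# Because the wrapper shares the inner optimizer's checkpoint format,
# the same checkpoint can be read back into a LossScaleOptimizer
# that wraps an Adam optimizer.
lso = mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(), loss_scale="dynamic")
tf.train.Checkpoint(optimizer=lso).read("/tmp/adam_ckpt")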

@bhack (Contributor Author) commented Aug 28, 2020

@reedwm Thank you for the feedback. I am going to check your SavedModel test that I missed.

@bhack (Contributor Author) commented Aug 28, 2020

@reedwm I've tried to expand your test in the style of the user-proposed one, and it passes on TensorFlow: tensorflow/tensorflow#42749

In our Lookahead case at #2102, instead, it passes only with Checkpoint, with the save_weights/load_weights case, or with model.save and load_model in h5 format.
So I think we are still missing something for model.save in the tf format.

Any hint?

@reedwm (Member) commented Aug 28, 2020

Hmm I'm not sure. What is the error?

Just a guess, but maybe the issue is because you return additional weights from the weights property:

@property
def weights(self):
    return self._weights + self._optimizer.weights

AFAIK, no other optimizer does this. Maybe it's causing SavedModel to be confused.

Also last time I checked, SavedModel does not save slot variables, but checkpoints do. Make sure your SavedModel test is not relying on slot variables being restored.

Also note LossScaleOptimizer used to use an approach similar to this PR. tensorflow/tensorflow@f9e99f4 changed it, but before that commit SavedModel still worked, so this approach should be possible.

/CC @k-w-w any ideas what the issue could be?

@bhack (Contributor Author) commented Aug 28, 2020

As you can see, we have slots in

def _create_slots(self, var_list):

@bhack (Contributor Author) commented Aug 28, 2020

@allenlavoie (Member) commented:

I don't think that's related. IIRC it was because Checkpoint(model=tf.keras.Model()).save is not compatible with Model().save_weights (since Model is the root of the save in the second case).

What's the error/issue if you try to save with SavedModel?
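
A small sketch of the incompatibility described above (paths are illustrative): the two calls write checkpoints with different roots, so their keys don't line up.

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])

# The Checkpoint object is the root here, so keys are prefixed with "model/...".
tf.train.Checkpoint(model=model).save("/tmp/checkpoint_root/ckpt")

# The Model itself is the root here, so that prefix is absent.
model.save_weights("/tmp/model_root/ckpt")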

@bhack (Contributor Author) commented Aug 29, 2020

The optimizer or the base/inner optimizer (or both) doesn't seem to be in the correct state after load_model in the tf format.
It seems fine with h5, Checkpoint, or save_weights and load_weights.
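
For reference, a minimal sketch of the paths being compared, assuming a small compiled Keras model; file names are illustrative:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
model.compile(optimizer="adam", loss="mse")

# Paths that appear to restore the optimizer correctly:
model.save("model.h5", save_format="h5")               # HDF5
model.save_weights("weights_ckpt")                      # TF-format weights checkpoint
tf.train.Checkpoint(model=model).save("ckpt/ckpt")      # explicit Checkpoint

# Path where the inner optimizer's state appears to be lost:
model.save("saved_model_dir", save_format="tf")         # SavedModel
restored = tf.keras.models.load_model("saved_model_dir")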

@allenlavoie (Member) commented:

Hrm. Does tf.saved_model.save / tf.saved_model.load have the same issue? There's a bit of extra logic in Model.save/load_model that could be causing issues here.

@bhack (Contributor Author) commented Aug 29, 2020

With tf.saved_model.save it is the same.

@bhack (Contributor Author) commented Aug 29, 2020

@allenlavoie What exactly is the RestoredOptimizer object?

@bhack (Contributor Author) commented Aug 29, 2020

@allenlavoie (Member) commented:

RestoredOptimizer is just a container for holding the restored slot variables. If nothing owned the tf.Variable objects they would just go out of scope immediately and be deleted. It doesn't implement any of the other Optimizer methods.

Maybe this line is the issue: https://github.com/tensorflow/tensorflow/blob/be2870d0db035d12a207039922d76ae908b1101c/tensorflow/python/keras/optimizer_v2/optimizer_v2.py#L1391

When we tell SavedModel to load something back as a Python object, it takes a Python function to set attributes with. It's calling _set_hyper and passing the inner optimizer, and maybe that's not doing anything reasonable. You could try using a more complicated setter that e.g. typechecks for OptimizerV2s and uses setattr+_track_trackable or something.
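
A rough sketch of the kind of setter suggested here (not code from TensorFlow; the function and attribute names are made up), assuming it would replace the plain _set_hyper callback used when reviving the optimizer from a SavedModel:

from tensorflow.python.keras.optimizer_v2 import optimizer_v2


def _revive_setter(optimizer, name, value):
    # If the revived attribute is itself an optimizer (e.g. a wrapped base
    # optimizer), attach and track it instead of treating it as a hyperparameter.
    if isinstance(value, optimizer_v2.OptimizerV2):
        setattr(optimizer, "_" + name, value)
        optimizer._track_trackable(value, name, overwrite=True)
    else:
        optimizer._set_hyper(name, value)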

@bhack (Contributor Author) commented Aug 29, 2020

I can't find many occurrences of the RestoredOptimizer object in the code. Where is it created?

What is the logic of:

      if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer):
        raise NotImplementedError(
            'As of now, Optimizers loaded from SavedModel cannot be saved. '
            'If you\'re calling `model.save` or `tf.keras.models.save_model`,'
            ' please set the `include_optimizer` option to `False`. For '
            '`tf.saved_model.save`, delete the optimizer from the model.')

@bhack (Contributor Author) commented Aug 29, 2020

/cc @omalleyt12 What is the use case of this condition that you introduced?

@bhack (Contributor Author) commented Aug 29, 2020

OK, never mind, I've seen the logic of RestoredOptimizer.

@bhack (Contributor Author) commented Aug 29, 2020

I've just extracted a very small code snippet, removing the Lookahead wrapping.
As you can see, it works with the plain optimizer="SGD" but not with optimizer="adam", so I suppose it is something related to slots:

import tempfile

import numpy as np
import tensorflow as tf

from tensorflow_addons.optimizers import Lookahead  # unused here: the Lookahead wrapping has been removed


def _init_model(optimizer, init_w):
    model = tf.keras.models.Sequential()
    dense = tf.keras.layers.Dense(input_shape=(3,), units=1)
    model.add(dense)
    dense.set_weights([init_w, np.zeros(1,)])
    model.compile(optimizer, loss="mse")
    return model


optimizer = "adam"  # passes with "SGD", fails with "adam"
np.random.seed(0)
x = np.random.standard_normal((10000, 3))
w = np.random.standard_normal((3, 1))
y = np.dot(x, w) + np.random.standard_normal((10000, 1)) * 1e-4

init_w = np.random.standard_normal((3, 1))

# Reference model: trained for 2 epochs without interruption.
model = _init_model(optimizer, init_w)
model.fit(x, y, epochs=2, shuffle=False)

with tempfile.TemporaryDirectory() as ckpt_dir:
    # Second model: 1 epoch, save in SavedModel format, reload, then 1 more epoch.
    new_model = _init_model(optimizer, init_w)
    new_model.fit(x, y, epochs=1, shuffle=False)
    new_model.save(ckpt_dir)
    new_model = tf.keras.models.load_model(ckpt_dir)
    new_model.fit(x, y, epochs=1, shuffle=False)
    # Both runs should reach the same state if the optimizer (including its
    # slot variables) survives the save/load round trip.
    assert np.allclose(model.predict(x), new_model.predict(x))

@bhack (Contributor Author) commented Aug 29, 2020

@reedwm With Adam, the TensorFlow test based on your test is also failing, if my change there is correct: tensorflow/tensorflow#42749.

@bhack bhack changed the title Lookahead base optimizer tracking Base optimizer tracking Aug 30, 2020
@bhack bhack marked this pull request as ready for review August 30, 2020 10:53
@bhack bhack requested a review from Squadrick as a code owner August 30, 2020 10:53
@reedwm (Member) commented Aug 31, 2020

I commented in tensorflow/tensorflow#42749 (comment). Since slot variables are not restored and Adam has slot variables, the sample code you gave is expected to fail.
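
A quick way to see why SGD passes and Adam fails in the snippet above (illustrative): plain SGD keeps no slot variables, while Adam keeps m and v slots for every variable, and those slots are what a SavedModel round trip does not restore.

import tensorflow as tf

var = tf.Variable([1.0])
sgd = tf.keras.optimizers.SGD()    # momentum=0.0 -> no slot variables
adam = tf.keras.optimizers.Adam()  # creates 'm' and 'v' slots per variable

# Build optimizer state by applying one gradient step with each optimizer.
for opt in (sgd, adam):
    opt.apply_gradients([(tf.constant([0.1]), var)])

print(sgd.get_slot_names())   # []
print(adam.get_slot_names())  # ['m', 'v']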

@bhack (Contributor Author) commented Aug 31, 2020

OK, I commented there as well.

@bhack (Contributor Author) commented Sep 1, 2020

I think this could be merged as is.
Meanwhile, we are drafting a general warning in TensorFlow at tensorflow/tensorflow#42846 for the SavedModel case.

@bhack bhack requested review from seanpmorgan and WindQAQ September 1, 2020 13:53
@WindQAQ (Member) left a comment

Thanks 😃

@WindQAQ WindQAQ merged commit 2bf57f8 into master Sep 1, 2020
@seanpmorgan seanpmorgan deleted the bhack-patch-1 branch September 1, 2020 17:53
jrruijli pushed a commit to jrruijli/addons that referenced this pull request Dec 23, 2020
* Update lookahead.py

Initial fix of
tensorflow#2094
tensorflow#2102

* Fix linting

* Resolve name conflict with mixed precision

* Track baseline optimizer in avg
Successfully merging this pull request may close these issues:

  • Weights of Inner Optimizers Not Saved