Weights of Inner Optimizers Not Saved #2094

Closed
BinyanHu opened this issue Aug 15, 2020 · 9 comments · Fixed by #2126
Labels: bug (Something isn't working), optimizers

Comments

BinyanHu commented Aug 15, 2020

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 16.04 & Windows 10
  • TensorFlow version and how it was installed (source or binary): 2.3.0 from source
  • TensorFlow-Addons version and how it was installed (source or binary): 0.11.1 from source
  • Python version: 3.7
  • Is GPU used? (yes/no): yes

Describe the bug

Resuming a training process requires restoring the optimizer states, so that training continues exactly from the previous state without any loss of accuracy. Currently, the Keras model-saving interface keras.Model.save_weights checkpoints both the network parameters and the optimizer weights. However, when one optimizer is wrapped inside another, the inner optimizer's weights cannot be saved this way.

For example, I tried to use the Ranger optimizer, which is built by wrapping RAdam with Lookahead:

```python
import tensorflow_addons as tfa

# Ranger: RAdam wrapped by Lookahead
optimizer = tfa.optimizers.Lookahead(tfa.optimizers.RectifiedAdam())
```

I noticed a performance drop when resuming training and found that the weights of the inner RAdam were not saved into the checkpoint. (I checked the .index file in the checkpoint folder: there are no variable names like "m" and "v", only "slow", which holds the Lookahead weights.) Therefore, after loading the weights from file and restarting fitting, the RAdam weights are reinitialized from scratch. This could be because the weights of the inner optimizer are not automatically tracked.
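The suspected cause can be illustrated with a toy, pure-Python model of dependency-based checkpointing. This is not TensorFlow's tracking machinery: Trackable, track, checkpoint_keys, ToyRAdam, ToyLookahead and the inner_optimizer key are hypothetical names, and each toy optimizer holds one scalar per slot instead of one slot variable per model weight. It only assumes what the report suggests, namely that a save follows explicitly registered dependencies, so an inner optimizer kept as a plain attribute contributes no keys.

```python
# Toy model of dependency-based checkpointing (NOT TensorFlow's API).
# Saving walks only explicitly registered dependencies, so an object kept
# as a plain attribute contributes no keys. All names are hypothetical.


class Trackable:
    def __init__(self):
        self._deps = {}       # registered children: name -> Trackable
        self._variables = {}  # this object's own variables: name -> value

    def track(self, name, obj):
        """Register obj as a checkpoint dependency under `name`."""
        self._deps[name] = obj
        return obj


def checkpoint_keys(obj, prefix=""):
    """Return the sorted variable keys a save would write starting from obj."""
    keys = [prefix + name for name in obj._variables]
    for name, child in obj._deps.items():
        keys += checkpoint_keys(child, prefix + name + "/")
    return sorted(keys)


class ToyRAdam(Trackable):
    def __init__(self):
        super().__init__()
        self._variables = {"m": 0.0, "v": 0.0}  # moment slots (one scalar each)


class ToyLookahead(Trackable):
    def __init__(self, inner, track_inner):
        super().__init__()
        self._variables = {"slow": 0.0}  # Lookahead's own slot
        self._optimizer = inner          # plain attribute: not a dependency
        if track_inner:
            self.track("inner_optimizer", inner)  # hypothetical key name


untracked = ToyLookahead(ToyRAdam(), track_inner=False)
tracked = ToyLookahead(ToyRAdam(), track_inner=True)
print(checkpoint_keys(untracked))  # ['slow']
print(checkpoint_keys(tracked))    # ['inner_optimizer/m', 'inner_optimizer/v', 'slow']
```

The untracked variant reproduces the symptom seen in the .index file (only "slow"); registering the inner optimizer as a dependency makes its "m" and "v" slots reachable from the root.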

Experiments

I trained two LeNets on the Fashion-MNIST dataset with identical configurations except for the optimizers. Both training runs were interrupted midway and then resumed.

Fig. TensorBoard training curves. Blue: Ranger (Lookahead + RAdam); orange: RAdam.

Note the "bump" in the Ranger curve, caused by the reinitialization of the RAdam weights. Evidently, the weights of the inner optimizer are not saved correctly.

bhack (Contributor) commented Aug 15, 2020

Can you prepare a minimal PR with a new test to cover your case?

bhack (Contributor) commented Aug 15, 2020

So that we could check if it is similar to #1911.

BinyanHu (Author) commented

> So that we could check if it is similar to #1911

Our issues are similar, but the real problem here is the missing weights of the inner RAdam optimizer.

First, I reran my program with status.assert_consumed() and got the following error:

AssertionError:
Unresolved object in checkpoint (root).optimizer.iter: attributes {
  name: "VARIABLE_VALUE"
  full_name: "iter"
  checkpoint_key: "optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE"
}

This is the same as #1911. It happens because the variable iter has not been created yet when the weights are loaded; if we call assert_consumed after fitting the model, the error goes away. The follow-up warnings in that issue, that the "slow" slots of all network parameters were not used, are caused merely by the fact that non-training mode does not need to load the optimizer states, which is not a problem. In short, calling assert_consumed right after loading the weights does not reveal the problem.
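Why assert_consumed fails right after loading, and passes after fitting, can be sketched with a toy model of deferred restoration. This is a simplification, not TensorFlow's implementation: ToyStatus, on_create and the key names are hypothetical. Values whose variables already exist are restored at once; the rest are queued until the variable is created, for example when the first training step creates the optimizer's iter variable.

```python
# Toy model of deferred restoration (NOT TensorFlow's implementation).
# Saved values whose variables do not exist yet are queued and applied when
# the variable is created; assert_consumed fails while anything is queued.


class ToyStatus:
    def __init__(self, saved, existing):
        self.pending = {}
        for key, value in saved.items():
            if key in existing:
                existing[key] = value      # variable exists: restore now
            else:
                self.pending[key] = value  # defer until the variable exists

    def on_create(self, variables, key, initial):
        """Create a variable, consuming a queued saved value if present."""
        variables[key] = self.pending.pop(key, initial)

    def assert_consumed(self):
        if self.pending:
            raise AssertionError(f"Unresolved keys: {sorted(self.pending)}")


saved = {"dense/kernel": 0.5, "optimizer/iter": 1000}
variables = {"dense/kernel": 0.0}  # optimizer/iter is created lazily
status = ToyStatus(saved, variables)

try:
    status.assert_consumed()       # right after loading: iter not created yet
except AssertionError as err:
    print(err)                     # Unresolved keys: ['optimizer/iter']

status.on_create(variables, "optimizer/iter", 0)  # e.g. first training step
status.assert_consumed()           # passes now
print(variables)                   # {'dense/kernel': 0.5, 'optimizer/iter': 1000}
```

In this toy, values that were never written to the checkpoint (such as untracked inner-optimizer slots) never enter pending, so this check cannot flag them.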

Second, learning-rate warmup can help RAdam re-accumulate its mean and variance statistics with small steps instead of "messing up" the network weights during the first few steps after resuming. This partly mitigates the loss of the RAdam weights, but it is definitely not a proper fix.
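A minimal sketch of linear warmup, assuming the step count restarts at the resume point; warmup_lr and its default values are hypothetical and not part of TensorFlow Addons. Small early updates give the freshly reinitialized moment estimates time to rebuild, but the lost optimizer state itself is not recovered.

```python
# Linear learning-rate warmup: ramp up to base_lr over warmup_steps, then
# hold it constant. Hypothetical helper, not a TensorFlow/TFA API.


def warmup_lr(step, base_lr=1e-3, warmup_steps=1000):
    """Learning rate for 0-based `step` (counted from the resume point)."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    return base_lr


# Rates at a few steps after resuming training.
for step in (0, 499, 999, 5000):
    print(step, warmup_lr(step))
```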

In addition, I checked the sizes of the checkpoint files: Ranger 3381 KB and RAdam 5070 KB. Since Ranger adds an extra "slow" slot, its checkpoint should be larger than the RAdam one, not smaller; this indicates that the RAdam weights are missing.
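The size argument can be made quantitative with a back-of-the-envelope check. It assumes the checkpoint size is dominated by tensors with one entry per model parameter and that RAdam keeps only its m and v slots (no AMSGrad slot); small scalars such as iter and file overhead are ignored.

```python
# Rough check of the reported checkpoint sizes (assumption: size is
# dominated by per-parameter tensors; scalars and overhead are ignored).

radam_kb, ranger_kb = 5070, 3381

per_copy_kb = radam_kb / 3              # RAdam stores params + m + v

expected_full_ranger = 4 * per_copy_kb  # params + slow + m + v
expected_slow_only = 2 * per_copy_kb    # params + slow (inner slots missing)

print(round(expected_full_ranger))      # 6760
print(round(expected_slow_only))        # 3380
print(round(ranger_kb / radam_kb, 3))   # 0.667
```

The observed Ranger size (3381 KB) is within 1 KB of the params-plus-slow estimate and about half of the complete estimate, consistent with the m and v slots of the inner RAdam being absent.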

I think the reason is evident here. If a PR is still needed, how should the test be conducted? Would saving and loading a model with a Lookahead-wrapped optimizer with slots be enough to demonstrate the problem?

bhack (Contributor) commented Aug 15, 2020

The Lookahead tests currently have no serialization test.
So I think you can add a small one and let it fail in https://github.com/tensorflow/addons/blob/master/tensorflow_addons/optimizers/tests/lookahead_test.py

bhack (Contributor) commented Aug 15, 2020

Check whether some of the original author's tests could be useful: https://github.com/CyberZHG/keras-lookahead/blob/master/tests/test_optimizers.py

bhack (Contributor) commented Aug 15, 2020

/cc @CyberZHG

bhack (Contributor) commented Aug 15, 2020

Also check that you are restoring the custom objects on load, e.g. custom_objects={'RAdam': RAdam, 'Lookahead': Lookahead}.

WindQAQ (Member) commented Aug 16, 2020

Hi @BinyanHu, thanks for investigating this. Can you provide the minimal code snippet to reproduce the issue, e.g. the way you save the model? Thank you!

WindQAQ added the bug and optimizers labels Aug 16, 2020

AakashKumarNain (Member) commented
> AssertionError:
> Unresolved object in checkpoint (root).optimizer.iter: attributes {
>   name: "VARIABLE_VALUE"
>   full_name: "iter"
>   checkpoint_key: "optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE"
> }

I think this warning appears because the value you pass to your optimizer is a float. You can use tf.Variable() for that.

On the other hand, I feel this is the real issue here:

> Plus, I just checked the sizes of the checkpoint files: Ranger 3381kb and RAdam 5070kb. With an extra slot "slow", the size of the Ranger checkpoint should not be smaller, indicating that the weights of RAdam are missing.

bhack added a commit that referenced this issue Aug 27, 2020
bhack added a commit that referenced this issue Aug 27, 2020
bhack mentioned this issue Aug 27, 2020
bhack linked a pull request Aug 27, 2020 that will close this issue
WindQAQ pushed a commit that referenced this issue Sep 1, 2020
* Update lookahead.py

  Initial fix of #2094 and #2102

* Fix linting

* Resolve name conflict with mixed precision

* Track baseline optimizer in avg
jrruijli pushed a commit to jrruijli/addons that referenced this issue Dec 23, 2020
4 participants