include_optimizer parameter introduced to push_to_hub_keras() #616
Conversation
Should solve the issue 👍
src/huggingface_hub/keras_mixin.py
Outdated
```diff
@@ -50,7 +52,7 @@ def save_pretrained_keras(
     with open(path, "w") as f:
         json.dump(config, f)

-    tf.keras.models.save_model(model, save_directory)
+    tf.keras.models.save_model(model, save_directory, include_optimizer)
```
The third argument of `save_model` is `overwrite` (see the docs), so we'll have to pass `include_optimizer` as a keyword argument.
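Concretely, the fix being asked for is a one-line change (a sketch; `model`, `save_directory`, and `include_optimizer` are the names already in the diff above):

```python
# The positional signature is save_model(model, filepath, overwrite=True,
# include_optimizer=True, ...), so a third positional argument would bind
# to `overwrite`. Pass include_optimizer by keyword instead:
tf.keras.models.save_model(model, save_directory, include_optimizer=include_optimizer)
```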
Thanks for this PR! Suggestion for the future: before sending PRs for review, it's always a good idea to test them at least locally, since you might realize they don't work for some reason (wrong argument, weird TF bug, etc.). Anyway, overall this looks good 👍 feel free to open the PR for official review once you've tested it and added some automated tests as well.
Hey folks, my two cents on this would be to use a `model_save_kwargs` dict that gets forwarded to `tf.keras.models.save_model()`, as sketched below. WDYT?
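A rough sketch of that idea, assuming the kwargs dict is named `model_save_kwargs` as in the docstring further down (this is an illustration, not the merged implementation):

```python
import tensorflow as tf

def save_pretrained_keras(model, save_directory, include_optimizer=False, **model_save_kwargs):
    # include_optimizer is surfaced explicitly (False by default, so models
    # with custom optimizers serialize cleanly); anything else the user
    # wants, e.g. save_traces=False, is forwarded to Keras untouched.
    tf.keras.models.save_model(
        model,
        save_directory,
        include_optimizer=include_optimizer,
        **model_save_kwargs,
    )
```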
@osanseviero @ariG23498's suggestion is neat. I don't know which choice to pick: we'll set it to `False` by default to handle custom optimizers (like the ones that tfa offers), so the user has to set it to `True` if they want to push a model with one of the default optimizers. Maybe take
Yes, @merveenoyan, I think that sounds like the right approach and is consistent with what we do in transformers and other places (setting it to `False` by default).
I tested this in multiple ways: first by pushing the model directly with `push_to_hub_keras`, and second by passing `include_optimizer` explicitly. I also saved a model without the optimizer but with the traces included, and when loading it I was asked to compile manually. I wonder if the UX is good enough at this point; I think it's just fine (so I need someone to criticize this well). @ariG23498 @osanseviero
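For reference, the manual re-compile step described above looks roughly like this (the repo id and compile arguments are made up for illustration):

```python
from huggingface_hub import from_pretrained_keras

# A model pushed with include_optimizer=False but with traces saved loads
# fine; Keras just warns that no training configuration was found and asks
# you to compile before resuming training.
model = from_pretrained_keras("merve/optimizer-test")  # hypothetical repo id
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
```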
Co-authored-by: Omar Sanseviero <osanseviero@users.noreply.github.com>
@merveenoyan I think this works! Thanks for the great work 😄
LGTM
I'm writing automated tests at the moment; I will not merge yet. 🤓
@osanseviero last week I removed the pytest assertions from the two new test cases (they weren't going to be raised anyway). Should I merge?
Thanks a lot for this PR! Overall looks good, but I think the tests could improve a bit 😄
src/huggingface_hub/keras_mixin.py
Outdated
```
include_optimizer (:obj:`bool`, `optional`):
    Whether or not to include the optimizer during serialization.
model_save_kwargs (:obj:`dict`, `optional`):
    Arguments other than the default arguments that can be passed to tf.keras.models.save_model().
```
Same here :)
self.assertIn("saved_model.pb", files) | ||
self.assertIn("keras_metadata.pb", files) | ||
self.assertEqual(len(files), 4) | ||
|
A second test that shows that the optimizer is not there by default would be a good idea
Feel free to pin the `black` version.
AFTER FIGHTING BLACK FOR HOURS @osanseviero
Thanks a lot for this PR! I left a couple of improvement ideas but looks almost ready to merge! 🚀 🚀
tests/test_keras_integration.py
Outdated
```python
model = from_pretrained_keras(f"{WORKING_REPO_DIR}/{REPO_NAME}")
self.assertIsNone(model.optimizer)
```
I would suggest avoiding testing this as well in the same test. Maybe let's have a separate test, `test_save_pretrained_does_not_save_optimizer_state_by_default`, which has this (a sketch follows below).
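A hedged sketch of that test, reusing the assertions quoted above (`WORKING_REPO_DIR`, `REPO_NAME`, and the dummy-model helper are assumed to come from the existing test file):

```python
def test_save_pretrained_does_not_save_optimizer_state_by_default(self):
    model = self.dummy_model()  # assumed helper that builds a compiled toy model
    save_pretrained_keras(model, f"{WORKING_REPO_DIR}/{REPO_NAME}")

    # With include_optimizer defaulting to False, the reloaded model
    # should come back without any optimizer attached.
    loaded = from_pretrained_keras(f"{WORKING_REPO_DIR}/{REPO_NAME}")
    self.assertIsNone(loaded.optimizer)
```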
Co-authored-by: Omar Sanseviero <osanseviero@gmail.com>
@osanseviero on my local machine (I checked the black version) the quality test passes, but here it fails. Any guesses?
This PR addresses the discussion we had in #598.
I did a quick and dirty implementation and haven't tested it yet, so it's a WIP; I want to hear the review first.
For context: when this parameter is set to True (the Keras default), models with custom optimizers (like the ones built with tfa) fail to serialize, so I set it to False here and made it an optional parameter.
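A usage sketch of the new default (the repo name is illustrative):

```python
from huggingface_hub import push_to_hub_keras

# Default: optimizer state is skipped, so a model compiled with a custom
# optimizer (e.g. one from tensorflow-addons) can be pushed without errors.
push_to_hub_keras(model, "optimizer-test")

# Opt in explicitly when the optimizer state should travel with the model.
push_to_hub_keras(model, "optimizer-test", include_optimizer=True)
```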
cc: @osanseviero