Fix weight tying in TF-ESM #22839
Conversation
Also cc @gante in case he hates how I handled weight tying here, I don't want to break TF convention too much!
The documentation is not available anymore as the PR was closed or merged.
@Rocketknight1 I'm cool with this :D
LGTM but @amyeroberts might have more insight in the TF subtleties of this code.
All looks good to me! Thanks for adding this :)
I just have a small question about the tests.
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
    model = model_class(config)
The modification of the modeling code is controlled by self.config.tie_word_embeddings, but all the models here are using the same config. Is the toggling of tying or not tying the weights tested elsewhere? I'm quite likely missing something just judging from the diff here, though.
The reason for overriding this test is that the common test expects model.get_output_embeddings() to return a tf.keras.layers.Layer, but in this case I'm using a simple shared matrix created with add_weight, so I had to tweak the test a little. I'm actually not testing the effect of the tie_word_embeddings parameter anywhere, but maybe I should, unless it's already covered somewhere in the common tests!
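For illustration, a rough sketch of what such a toggle test could look like. This is not the test added in this PR: it assumes get_output_embeddings() returns the shared matrix as a tf.Variable when tying is enabled, and the attribute names on the embedding layer are guesses.

```python
import copy

import tensorflow as tf

from transformers import EsmConfig, TFEsmForMaskedLM


def check_tie_word_embeddings(base_config: EsmConfig):
    for tie in (True, False):
        config = copy.deepcopy(base_config)
        config.tie_word_embeddings = tie

        model = TFEsmForMaskedLM(config)
        # A dummy forward pass builds every weight, including the LM head decoder.
        model(tf.constant([[0, 1, 2]], dtype=tf.int32))

        input_embeddings = model.get_input_embeddings()
        output_embeddings = model.get_output_embeddings()

        # The attribute holding the input matrix differs between a Keras Embedding
        # layer (`.embeddings`) and custom layers (`.weight`), so this lookup is
        # deliberately defensive.
        input_matrix = getattr(input_embeddings, "weight", getattr(input_embeddings, "embeddings", None))

        # When tied, the output projection should be the very same variable as the
        # input embedding matrix, not a value-equal copy.
        assert (output_embeddings is input_matrix) == tie
```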
@@ -1102,6 +1103,11 @@ def __init__(self, config):
        self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
        self.lm_head = TFEsmLMHead(config, name="lm_head")
        if config.tie_word_embeddings:
            # Ensure word embeddings are built so that we actually have something to tie
            with tf.name_scope(os.path.join(self._name_scope(), "esm", "embeddings", "word_embeddings")):
Cries in TensorFlow 😭
I know, but it's the only option when we want to create a weight for a distant sublayer! TF only 'walks' the name scope hierarchy via the call stack when weights are built during call(); every other time you have to explicitly enter the tf.name_scope() you want.
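As a toy illustration of that behaviour (not the ESM code; the layer and scope names here are made up): when a nested layer is built during call(), Keras enters each parent's name scope for you, but a build forced from __init__ only gets the nested prefix if you enter the scope yourself.

```python
import tensorflow as tf


class Outer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.inner = tf.keras.layers.Dense(4, name="inner")
        # Building eagerly, outside call(): reproduce the scope by hand so the
        # variable names match what a call()-time build would have produced.
        with tf.name_scope(self._name_scope() + "/inner"):
            self.inner.build((None, 8))


layer = Outer(name="outer")
for variable in layer.weights:
    print(variable.name)  # expected to include the "outer/inner/..." prefix
```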
Fix weight tying in ESM
TF ESM cloned weights instead of tying them, which worked when loading from PyTorch but broke when loading from safetensors. This PR resolves the issue by correctly tying the weights when tying is enabled in the config. Fixes an ongoing CI error raised by @ydshieh.
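For readers unfamiliar with the distinction, here is a minimal sketch of cloning versus tying; the variable names are illustrative, not the actual ESM modules.

```python
import tensorflow as tf

vocab_size, hidden_size = 10, 4
word_embeddings = tf.Variable(tf.random.normal((vocab_size, hidden_size)), name="word_embeddings")

# Cloned: a separate variable initialised from the embedding matrix. Updates to
# one never reach the other, and a checkpoint that stores the shared tensor only
# once (as safetensors does) leaves this copy unpopulated.
cloned_decoder = tf.Variable(word_embeddings.read_value(), name="decoder")

# Tied: the output projection *is* the input embedding matrix.
tied_decoder = word_embeddings

hidden_states = tf.random.normal((2, hidden_size))
logits = tf.matmul(hidden_states, tied_decoder, transpose_b=True)
print(logits.shape)  # (2, 10)
```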