
RFC: Replace custom TF embeddings by Keras embeddings #18939

Merged: 5 commits merged into huggingface:main from the keras_embeddings branch on Sep 10, 2022

Conversation

@gante (Member) commented Sep 8, 2022

What does this PR do?

This is an RFC with a code example in the PR -- my primary goal is not to get the PR approved, but rather to discuss an improvement to our TF codebase, with an example that passes all tests.

Context

In our TF implementation of models with embedding layers, we rely on two custom-made classes:

  1. TFSharedEmbeddings -- a custom embedding layer whose added benefit is that it can also be used as a dense (output projection) layer (see the sketch after this list);
  2. TFWrappedEmbeddings -- used to manipulate the scope of the weights, which would normally depend on the layer where the weights are first used in an operation. Used with tied-weight embeddings.
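
For reference, TFSharedEmbeddings follows a pattern roughly like the sketch below: a single weight matrix serves both as a lookup table and, transposed, as an output projection. This is a simplified illustration written for this discussion, not the actual transformers implementation.

```python
import tensorflow as tf

# Simplified sketch of the TFSharedEmbeddings idea (illustrative only):
# one weight matrix used both as an embedding table and as a dense projection.
class SharedEmbeddingsSketch(tf.keras.layers.Layer):
    def __init__(self, vocab_size: int, hidden_size: int, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

    def build(self, input_shape):
        # Single shared matrix of shape (vocab_size, hidden_size).
        self.weight = self.add_weight(
            name="weight", shape=(self.vocab_size, self.hidden_size)
        )
        super().build(input_shape)

    def call(self, inputs, mode="embedding"):
        if mode == "embedding":
            # token ids -> hidden states
            return tf.gather(self.weight, inputs)
        if mode == "linear":
            # hidden states -> vocabulary logits, reusing the same matrix
            return tf.matmul(inputs, self.weight, transpose_b=True)
        raise ValueError(f"mode {mode} is not valid.")
```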

Problems with this setup include:

  1. Users can't use the expected Keras tools to handle embeddings;
  2. It relies on TF1 compatibility behavior to set the right name for the weights (tf.compat.v1.variable_scope);
  3. Resizing the embeddings, currently a major source of bugs, uses complex logic that consists of manipulating tf.Variable objects directly.

Proposed change

The proposal is straightforward: replace TFSharedEmbeddings with tf.keras.layers.Embedding, remove TFWrappedEmbeddings, and make the necessary adaptations. A few details to keep in mind (and that you can browse in the code):

  1. There is a whole new code path for resizing the embeddings. Instead of adding if/else branches to the original functions, the changed functions were rewritten with _v2 prepended to their names (which should also make the transition easier). You can see that the new functions are simpler than the originals;
  2. Giving the right name to the embeddings (so we can load existing weights) was the hardest part, as TF offers limited maneuverability here. To pull it off, I relied on UNDOCUMENTED behavior of tf.name_scope. Normally, tf.name_scope appends to the existing scope -- if the scope for the current layer is foo, weights are named foo/weights:0, and if we add a context manager tf.name_scope("bar"), weights are named foo/bar/weights:0. However, if the argument of tf.name_scope ends with /, it becomes a stand-alone name scope: taking the previous example, with tf.name_scope("bar/"), weights are named bar/weights:0 (see the snippet after this list). This behavior has been in the TF codebase since its first commit (>7 yrs), and replacing TFWrappedEmbeddings relies on it;
  3. The existing TF Bart assumes the input/output embeddings are tied, an assumption PT Bart does not make. I've not changed this part, so the example you can see in this PR is for models with tied weights;
  4. If you open PT Bart and compare side by side, you'll see that the implementations in the two frameworks are now more similar :)
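
To make point 2 above concrete, here is a stand-alone snippet showing the trailing-slash behavior of tf.name_scope. It is illustrative only (not code from this PR), and the scope/variable names are made up:

```python
import tensorflow as tf

with tf.name_scope("foo"):
    # Default behavior: "bar" is appended to the enclosing "foo" scope.
    with tf.name_scope("bar"):
        nested = tf.Variable(tf.zeros((2, 2)), name="weights")
    # Trailing "/": "bar/" acts as a stand-alone scope and ignores "foo".
    with tf.name_scope("bar/"):
        standalone = tf.Variable(tf.zeros((2, 2)), name="weights")

print(nested.name)      # foo/bar/weights:0
print(standalone.name)  # bar/weights:0
```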

I estimate about 1-2 weeks' worth of work to propagate the change, which includes:

  1. Replace all TFSharedEmbeddings and TFWrappedEmbeddings;
  2. Handle edge cases -- resize_token_embeddings is not implemented or is broken (and untested) in several recent TF models (see the resize sketch after this list);
  3. Remove/deprecate old code after 1 and 2 are done.
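
As mentioned in point 2, here is a rough sketch of what resizing could look like once the token embeddings are a plain tf.keras.layers.Embedding. This is a hypothetical helper written to illustrate the direction, not the actual _v2 code added in this PR:

```python
import tensorflow as tf

def resize_embedding_layer(old_embedding: tf.keras.layers.Embedding,
                           new_num_tokens: int) -> tf.keras.layers.Embedding:
    """Hypothetical sketch: build a new Embedding and copy over the old rows.

    Assumes `old_embedding` has already been built.
    """
    old_num_tokens = old_embedding.input_dim
    hidden_size = old_embedding.output_dim
    new_embedding = tf.keras.layers.Embedding(
        input_dim=new_num_tokens, output_dim=hidden_size, name=old_embedding.name
    )
    new_embedding.build((None,))
    # Copy the overlapping rows; any extra rows keep their fresh initialization.
    num_to_copy = min(old_num_tokens, new_num_tokens)
    new_embedding.embeddings.assign(
        tf.concat(
            [old_embedding.embeddings[:num_to_copy],
             new_embedding.embeddings[num_to_copy:]],
            axis=0,
        )
    )
    return new_embedding
```

No direct tf.Variable surgery and no wrapper classes are needed, which is the kind of simplification described above.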

Pros/cons

(+) Simpler and smaller codebase, especially for the models;
(+) TF model code closer to PT's;
(+) Keras-native embeddings ( = users and contributors can be more productive);
(+) resize_token_embeddings usable in all models;
(-) Time spent refactoring is time not spent building new things;
(-) The solution still relies on named scopes for cross-framework weight matching, which is hacky.

@HuggingFaceDocBuilderDev commented Sep 8, 2022

The documentation is not available anymore as the PR was closed or merged.

@@ -137,7 +137,7 @@ def call(
         position_ids = tf.range(seq_len, delta=1, name="range")
         position_ids += past_key_values_length
 
-        return super().call(position_ids + self.offset)
+        return super().call(position_ids + tf.constant(self.offset, dtype=tf.int32))
gante (Member Author) commented:
(make fix-copies)

@@ -230,69 +230,6 @@ def test_model_common_attributes(self):
         name = model.get_bias()
         assert name is None
 
-    def test_resize_token_embeddings(self):
gante (Member Author) commented:

The exact same test exists in test_modeling_tf_common and is touched in this PR.

@Rocketknight1 (Member) commented:

My thoughts:

  • I think the behaviour of tf.name_scope is intended and stable, even if it's not documented (TF documentation isn't always great). I think we can rely on that safely, and it's a lot better than using compatibility methods from v1.
  • I agree that how we're doing this right now isn't great, and this code is a big improvement.
  • I think how we use name_scope is still a little problematic. However, I don't want to make any big breaking changes there right now because the PT codebase will probably also change soon to use whatever new pickle-free state dict save format the PT devs come up with!

So overall, I think this is a good addition that cleans up a longstanding source of issues in the code, and shouldn't take too long to implement across the codebase.

gante requested review from sgugger and LysandreJik on September 8, 2022 at 16:43
@sgugger (Collaborator) left a comment:

It does look cleaner this way, and handling the resize token embeddings for all new models is something that should be done either way. So I'm in favor of this switch.

gante marked this pull request as ready for review on September 9, 2022 at 15:25
gante merged commit 00cbadb into huggingface:main on Sep 10, 2022
gante deleted the keras_embeddings branch on September 10, 2022 at 10:34
oneraghavan pushed a commit to oneraghavan/transformers that referenced this pull request Sep 26, 2022