RFC: Replace custom TF embeddings by Keras embeddings #18939
Conversation
The documentation is not available anymore as the PR was closed or merged.
@@ -137,7 +137,7 @@ def call(
         position_ids = tf.range(seq_len, delta=1, name="range")
         position_ids += past_key_values_length

-        return super().call(position_ids + self.offset)
+        return super().call(position_ids + tf.constant(self.offset, dtype=tf.int32))
(make fix-copies)
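For context, the hunk above follows the pattern of a learned positional embedding that subclasses the Keras embedding layer and offsets its indices. Below is a minimal, self-contained sketch of that pattern, assuming TF 2.x's built-in tf.keras; the class name and constructor arguments are illustrative assumptions, not the PR's actual code.

import tensorflow as tf

class LearnedPositionalEmbeddingSketch(tf.keras.layers.Embedding):
    # Hypothetical sketch (not the PR's code): a positional embedding that
    # offsets its indices, mirroring the changed line in the diff above.
    def __init__(self, num_positions, embedding_dim, offset=2, **kwargs):
        super().__init__(num_positions + offset, embedding_dim, **kwargs)
        self.offset = offset

    def call(self, seq_len, past_key_values_length=0):
        position_ids = tf.range(seq_len, delta=1, name="range")
        position_ids += past_key_values_length
        # Add the offset as an explicit int32 tensor, as in the changed line.
        return super().call(position_ids + tf.constant(self.offset, dtype=tf.int32))

emb = LearnedPositionalEmbeddingSketch(num_positions=128, embedding_dim=16)
print(emb(tf.constant(4)).shape)  # (4, 16)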
@@ -230,69 +230,6 @@ def test_model_common_attributes(self):
         name = model.get_bias()
         assert name is None

-    def test_resize_token_embeddings(self):
The exact same test exists in test_modeling_tf_common, and is touched in this PR.
My thoughts:
So overall, I think this is a good addition that cleans up a longstanding source of issues in the code, and shouldn't take too long to implement across the codebase.
It does look cleaner this way, and handling the resize token embeddings for all new models is something that should be done either way. So I'm in favor of this switch.
What does this PR do?
This is an RFC with a code example in the PR -- my primary goal is not to get the PR approved, but rather to discuss an improvement to our TF codebase, with an example that passes all tests.
Context
In our TF implementation of models with embedding layers, we rely on two custom-made classes:
- TFSharedEmbeddings -- a custom embedding layer whose added benefit is the ability to also use it as a dense layer (see the sketch after this list);
- TFWrappedEmbeddings -- used to manipulate the scope of the weights, which would normally depend on the layer where the weights are first used in an operation. Used with tied weight embeddings.
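To make the first point concrete, here is a minimal sketch of a layer exposing one weight matrix both as an embedding lookup and as an output projection. The class name and the mode argument are illustrative assumptions, not the actual TFSharedEmbeddings implementation, and the sketch assumes TF 2.x's built-in tf.keras.

import tensorflow as tf

class SharedEmbeddingsSketch(tf.keras.layers.Layer):
    # Hypothetical sketch of the dual role described above: the same weight
    # matrix maps token ids to hidden states and hidden states back to logits.
    def __init__(self, vocab_size, hidden_size, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

    def build(self, input_shape):
        self.weight = self.add_weight(
            name="weight", shape=(self.vocab_size, self.hidden_size)
        )
        super().build(input_shape)

    def call(self, inputs, mode="embedding"):
        if mode == "embedding":
            return tf.gather(self.weight, inputs)  # token ids -> hidden states
        if mode == "linear":
            return tf.matmul(inputs, self.weight, transpose_b=True)  # hidden states -> logits
        raise ValueError(f"Unknown mode: {mode}")

layer = SharedEmbeddingsSketch(vocab_size=100, hidden_size=8)
hidden = layer(tf.constant([[1, 2, 3]]))  # shape (1, 3, 8)
logits = layer(hidden, mode="linear")     # shape (1, 3, 100)

Reusing the same variable in both directions is what ties the input and output embeddings together with a single weight matrix.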
Problems with this setup include:
- reliance on TF1-era constructs (tf.compat.v1.variable_scope);
- direct manipulation of the underlying weights as a raw tf.Variable.

Proposed change
The proposal is straightforward: replace TFSharedEmbeddings by tf.keras.layers.Embedding, remove TFWrappedEmbeddings, and make the necessary adaptations. A few details to keep in mind (and that you can browse in the code):
- To avoid if/else branching in the original functions, changed functions were rewritten with _v2 prepended to their name (which should also facilitate the transition). You can see that the new functions are simpler than the originals;
- The scope manipulation previously handled by TFWrappedEmbeddings is now done with tf.name_scope. Normally, tf.name_scope appends to the existing scope -- if the scope for the current layer is foo, weights are in the form of foo/weights:0; if we add a context manager tf.name_scope("bar"), weights will be in the form of foo/bar/weights:0. However, if the argument of tf.name_scope ends with /, then it becomes a stand-alone name scope. Taking the previous example, with tf.name_scope("bar/"), weights will be in the form of bar/weights:0. This behavior has been in the TF codebase since its first commit (>7 yrs), and replacing TFWrappedEmbeddings relies on it (see the sketch after this list).
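The trailing-slash behavior is easy to check in isolation. A minimal sketch, assuming TF 2.x eager execution; the scope and variable names are placeholders rather than real model weights.

import tensorflow as tf

with tf.name_scope("foo"):
    plain = tf.Variable(0.0, name="weights")           # foo/weights:0
    with tf.name_scope("bar"):
        nested = tf.Variable(0.0, name="weights")       # foo/bar/weights:0
    with tf.name_scope("bar/"):
        # The trailing "/" makes this a stand-alone scope, independent of "foo".
        standalone = tf.Variable(0.0, name="weights")   # bar/weights:0

print(plain.name, nested.name, standalone.name)

In other words, a layer can pin its tied weights under an absolute scope regardless of where it is first called, which is the property TFWrappedEmbeddings currently exists to provide.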
I estimate about 1-2 weeks' worth of work to propagate the change, which includes:
- removing TFSharedEmbeddings and TFWrappedEmbeddings;
- fixing resize_token_embeddings, which is not implemented/is broken (and untested) in several recent TF models.

Pros/cons
(+) Simpler and smaller codebase, especially for the models;
(+) TF model code closer to PT's;
(+) Keras-native embeddings (= users and contributors can be more productive);
(+) resize_token_embeddings usable in all models (see the sketch after this list);
(-) Time spent refactoring is time not spent building new things;
(-) The solution still relies on named scopes for cross-framework weight matching, which is hacky.
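To illustrate the resize_token_embeddings point: with a plain Keras embedding, resizing reduces to building a new layer of the requested size and copying the overlapping rows. A minimal sketch, assuming TF 2.x's built-in tf.keras; the helper name is hypothetical and glosses over weight tying and name scopes.

import tensorflow as tf

def resize_embedding_sketch(old, new_num_tokens):
    # Hypothetical helper: build a new embedding of the requested size and
    # copy over the rows that both sizes have in common.
    new = tf.keras.layers.Embedding(new_num_tokens, old.output_dim)
    new.build((None,))  # creates new.embeddings
    num_to_copy = min(old.input_dim, new_num_tokens)
    new.embeddings.assign(
        tf.concat([old.embeddings[:num_to_copy], new.embeddings[num_to_copy:]], axis=0)
    )
    return new

old = tf.keras.layers.Embedding(10, 4)
old.build((None,))
new = resize_embedding_sketch(old, 12)
print(new.embeddings.shape)  # (12, 4)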