Move TF building to an actual build() method #23760
Conversation
e3068b1 to aa599e4
This should be ready to review now! Some tests are failing, but that looks like Hub connection issues.
Thanks for your PR. I think it needs more TF-expert eyes, so tagging @amyeroberts.
The changes in Longformer and LED are very big, so they should go in their own PR to make future blame easier.
```python
)
out = tf.matmul(x, w, transpose_b=True)
if b is not None:
    out += b
```
Is the transpose of b unnecessary here?
Fixed!
src/transformers/tf_utils.py (Outdated)
```
@@ -22,6 +23,8 @@
logger = logging.get_logger(__name__)

build_context = threading.local()
```
This is not used anywhere here?
Ah, you're right, sorry! This is leftover from an earlier approach I was trying.
Fixed!
Nice :) Generally looks like a good clean-up, but I'm not a fan of the current API - it overrides default keras API behaviour in a non-obvious way. I left a more detailed comment in modeling_tf_utils.py.
One question I have is about the change in the conditional logic checks in the TF modeling code, i.e. removing `tf.cond(...)` - is this necessary with the new build logic, or just an update made alongside it?
```python
    lambda: tf.transpose(pixel_values, perm=(0, 2, 3, 1)),
    lambda: pixel_values,
)
if shape_list(pixel_values)[1] == num_channels:
```
I don't think the updated check is compatible with graph mode / compilation of models for XLA
`shape_list` is actually quite smart (because I wasn't the one who wrote it) - it returns a list with the static shape for each dimension when it is known at compile time, and the dynamic shape when it isn't. In an XLA compilation pass shapes are fully static, and so the static shape will always be fully known. As a result, all `shape_list` calls will just return the static shape in an XLA context and not introduce data-dependent conditionals.
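(For readers following along, here is a minimal sketch of the behaviour described above, using a hypothetical `shape_list_sketch` helper - the real helper lives in `src/transformers/tf_utils.py` and may differ in detail.)

```python
import tensorflow as tf

def shape_list_sketch(x: tf.Tensor) -> list:
    """Return static dims where known, falling back to dynamic dims otherwise."""
    static = x.shape.as_list()      # e.g. [None, 3, 224, 224] - None means unknown
    dynamic = tf.shape(x)           # symbolic tensor holding the runtime shape
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]
```

Under XLA every dimension is static, so the returned list contains plain Python ints and a check like `shape_list(pixel_values)[1] == num_channels` stays a compile-time comparison rather than a data-dependent conditional.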
Nice! But also don't put yourself down - I thought you'd written it :) I fear we shall never rid ourselves of `shape_list`.
```python
def build_with_dummies(self, dummies=None):
    if self.built_with_dummies and dummies is None:
        return
    if dummies is None:
        dummies = self.dummy_inputs
    self(dummies, training=False)
    self.built_with_dummies = True
```
I think the API here is a bit confusing:

- Allowing `dummies` to essentially be any input
- Rebuilding if `dummies` is None
- Having `input_shape` as an argument for `build` and then not using it

I would rework to have `build` match the logic of the parent class, and `build_with_dummies` just use the model's dummy values. This way `build` remains generic and `build_with_x` specifically builds with `x`, e.g. something like:

```python
def build(self, input_shape):
    if self.built or call_context().in_call:
        self.built = True
        return
    self(input_shape, training=False)
    self.built = True
    self.built_with_dummies = True

def build_with_dummies(self):
    if self.built_with_dummies:
        return
    self(self.dummy_inputs, training=False)
    self.built_with_dummies = True
    self.built = True
```

This would:

- Change all the calls from `model.build()` to `model.build_with_dummies()`. We can then remove all the comments next to the build calls explaining we're using dummy inputs.
- Remove the need to call `super().build(input_shape)` when we want the old build logic.
- Remove the need to set the input_shape to `None` in all the current build methods.

Also - why have the distinction between `built` and `built_with_dummies`?
Actually, I should explain my reasoning for some of the changes here - you're probably right that I can improve the API, though!

Firstly, the removal of `tf.cond`: […] Incredibly, though, […] Removing […]

Secondly, the distinction between `built` and `built_with_dummies`: […]
😑 In this case, let's rid ourselves of this pseudolayer! I'm pro the if/else changes :)
Yep, that's what I would go for. Would it be possible to still have some of the logic to exit early if already built? Or would this be too tricky to handle to be worth it?
I think we could, but it's probably not necessary - the only cases where we build the model with specific inputs are in weird PT-TF crossloading functions, which should always be called during or near model init anyway, so I think it's fine if there's a risk of a little bit of duplicated work there to save on overall code complexity.
…), which also sets it
4386387 to c76c308
@amyeroberts Done!
Thanks for iterating - nice cleanup!
Just a small Q about `build` logic for LED / Longformer.
with tf.name_scope("query_global"): | ||
self.query_global.build((self.config.hidden_size,)) | ||
with tf.name_scope("key_global"): | ||
self.key_global.build((self.config.hidden_size,)) | ||
with tf.name_scope("value_global"): | ||
self.value_global.build((self.config.hidden_size,)) |
Two silly questions:

- why do these layers need to be built separately here?
- why no `super().build(input_shape)` call in the method?
Good question! The answer is that it's hard to get dummy inputs that correctly touch all of those layers, so they tend to be left un-built unless we explicitly build them.
As for the `super().build()`, I just forgot! The base `build()` method doesn't really do anything, but you're right that I should probably still call it just in case.
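A hedged sketch of the pattern being discussed (illustrative class and names, not the PR's actual Longformer/LED code): explicitly build sub-layers that dummy inputs never reach, then call `super().build()` so the layer is still marked as built.

```python
import tensorflow as tf

class GlobalAttentionSketch(tf.keras.layers.Layer):
    def __init__(self, hidden_size, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        # Projection that normal dummy inputs would never route through
        self.query_global = tf.keras.layers.Dense(hidden_size, name="query_global")

    def build(self, input_shape=None):
        # Build the rarely-touched layer explicitly, under its own name scope
        # so its weight names stay stable.
        with tf.name_scope("query_global"):
            self.query_global.build((self.hidden_size,))
        super().build(input_shape)  # sets self.built = True
```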
Also, this PR looks ready, but I'm going to let it sit for a couple of days to make sure the CI is working again after my last library-breaking PR, then merge it.
Change of plans: The CI is working except for OOM errors during building for some of the pipelines, and since this cleans up building a bit we're going to merge this one too and see if it helps. If it doesn't, I'll open a new PR to see if I can lower the memory usage in the affected models.
```
@@ -69,11 +74,14 @@
if parse(tf.__version__).minor >= 13:
    from keras import backend as K
    from keras.__internal__ import KerasTensor
    from keras.engine.base_layer_utils import call_context
```
This would break since `keras` 2.13 has moved the import to `keras.src.engine`. See #23663
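One common way to tolerate both layouts is a guarded import - a sketch only; the exact guard the library ends up using may differ (see #23663):

```python
try:
    # keras >= 2.13 moved internals under keras.src
    from keras.src.engine.base_layer_utils import call_context
except ImportError:
    # older keras / tf.keras layout
    from keras.engine.base_layer_utils import call_context
```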
Thanks for the catch, I'll make the fix ASAP!
```python
    lambda: _make_causal_mask(input_shape, past_key_values_length=past_key_values_length),
    lambda: _expand_mask(tf.ones((batch_size, seq_len + past_key_values_length)), tgt_len=seq_len),
)
if seq_len > 1:
```
Hi @Rocketknight1, I am unfortunately having an issue with this change.
When I build a functional keras model using the Whisper encoder & decoder layers, I cannot serialize the model because of this change as it raises the error:
Using a symbolic `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
Here is a minimal reproducible example to raise the error:
```python
from transformers import TFWhisperModel
import tensorflow as tf

whisper = TFWhisperModel.from_pretrained("openai/whisper-tiny")
inp = tf.keras.Input((80, 3000))
stack = whisper.get_encoder()(inp)
decoder_input_ids = tf.ones((tf.shape(inp)[0], 1), dtype=tf.int32) * whisper.config.decoder_start_token_id
stack = whisper.get_decoder()(input_ids=decoder_input_ids, encoder_hidden_states=stack.last_hidden_state)
model = tf.keras.Model(inp, stack)
model.summary()
model.save("whisper-tiny-custom")
```
What do you think?
I will open an issue for this to be referenced!
I opened a corresponding issue
* A fun new PR where I break the entire codebase again
* A fun new PR where I break the entire codebase again
* Handle cross-attention
* Move calls to model(model.dummy_inputs) to the new build() method
* Seeing what fails with the build context thing
* make fix-copies
* Let's see what fails with new build methods
* Fix the pytorch crossload build calls
* Fix the overridden build methods in vision_text_dual_encoder
* Make sure all our build methods set self.built or call super().build(), which also sets it
* make fix-copies
* Remove finished TODO
* Tentatively remove unneeded (?) line
* Transpose b in deberta correctly and remove unused threading local
* Get rid of build_with_dummies and all it stands for
* Rollback some changes to TF-PT crossloading
* Correctly call super().build()
This has been a longstanding dream of mine: to move all TF model building into a proper `build()` method, using symbolic tensors instead of actual dummies. This would allow us to, among other things, stop our very hacky overriding of `save_spec`, as well as allowing us to build our TF models with zero device flops (although the speedup may be system-dependent, as we do have some compile time with this approach). It would make our models much closer to the Keras standard, which would stop Chollet casting curses upon me from afar.

In the past, we've run into serious problems with tensor names moving around when we tried this - I think I've figured out why, though, and I have a couple of ideas to resolve that without lots of hacky edge-case code.
This is an extremely draft PR that will break everything until I finish testing it properly!
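To illustrate the symbolic-build idea with a toy example (not transformers code): calling a Keras model on a symbolic `Input` traces it and creates its weights without pushing any real tensors through the device.

```python
import tensorflow as tf

toy = tf.keras.Sequential([tf.keras.layers.Dense(4)])
sym = tf.keras.Input(shape=(8,))   # symbolic placeholder, no real data
_ = toy(sym)                       # traces the model and creates its variables
print(len(toy.weights))            # 2 (kernel + bias) - built with zero device flops
```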
Update: Using symbolic tensors is much slower - it works in most cases, but increases the time it takes for our tests to run by a factor of ~4, which is probably not acceptable. Instead, I'm going to rework this PR to move to a standard `build()` method using actual dummies. With some optimizations, I believe we can make this work while still preserving most of the benefits of this PR, including not repeating the build unnecessarily and adding the ability to override `build()` to speed up our slowest models.
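A minimal sketch, under my own assumptions, of the reworked approach described in the update - running the model's dummy inputs from inside `build()` itself. The real transformers implementation differs in detail, and `dummy_inputs` here is a stand-in for the library's per-model dummies.

```python
import tensorflow as tf

class DummyBuiltModelSketch(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)

    @property
    def dummy_inputs(self):
        # Hypothetical dummy batch; real models derive this from their config
        return tf.ones((1, 3), dtype=tf.float32)

    def call(self, inputs, training=False):
        return self.dense(inputs)

    def build(self, input_shape=None):
        if self.built:
            return  # don't repeat the build unnecessarily
        self.built = True  # set first so the call below doesn't re-enter build()
        self(self.dummy_inputs, training=False)

model = DummyBuiltModelSketch()
model.build()              # weights now exist without the caller supplying inputs
print(len(model.weights))  # 2
```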