
Move TF building to an actual build() method #23760

Merged: 17 commits into main from tf_functional_builds on Jun 6, 2023
Conversation

@Rocketknight1 Rocketknight1 (Member) commented May 25, 2023

This has been a longstanding dream of mine: to move all TF model building into a proper build() method, using symbolic tensors instead of actual dummies. Among other things, this would let us stop our very hacky overriding of save_spec, and would let us build our TF models with zero device FLOPs (although the speedup may be system-dependent, as we do incur some compile time with this approach). It would bring our models much closer to the Keras standard, which would stop Chollet casting curses upon me from afar.

In the past, we've run into serious problems with tensor names moving around when we tried this - I think I've figured out why, though, and I have a couple of ideas to resolve that without lots of hacky edge-case code.

This is an extremely draft PR that will break everything until I finish testing it properly!

Update: Using symbolic tensors turned out to be much slower - it works in most cases, but it increases our test runtime by a factor of ~4, which is probably not acceptable. Instead, I'm going to rework this PR to use a standard build() method with actual dummies. With some optimizations, I believe we can make this work while still preserving most of the benefits of this PR, including not repeating the build unnecessarily and adding the ability to override build() to speed up our slowest models.
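
For illustration, the core pattern looks something like this (a rough sketch, not the actual transformers implementation - the class and layer here are made up):

    import tensorflow as tf

    class SketchTFModel(tf.keras.Model):
        def __init__(self):
            super().__init__()
            self.dense = tf.keras.layers.Dense(4)

        @property
        def dummy_inputs(self):
            # Hypothetical dummy batch; real models derive shapes from their config
            return tf.ones((2, 3), dtype=tf.float32)

        def build(self, input_shape=None):
            if self.built:
                return  # don't repeat the build unnecessarily
            self.built = True  # set first so the call below doesn't re-enter build()
            self(self.dummy_inputs, training=False)

        def call(self, inputs, training=False):
            return self.dense(inputs)

    model = SketchTFModel()
    model.build()  # creates all weights by running the dummies through call()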

@HuggingFaceDocBuilderDev commented May 25, 2023

The documentation is not available anymore as the PR was closed or merged.

@Rocketknight1 Rocketknight1 changed the title from "Try moving all TF building to symbolic tensors" to "Move TF building to an actual build() method" on May 26, 2023
@Rocketknight1 Rocketknight1 force-pushed the tf_functional_builds branch from e3068b1 to aa599e4 on May 30, 2023 13:19
@Rocketknight1 Rocketknight1 requested review from gante and sgugger May 30, 2023 13:28
@Rocketknight1 (Member, Author)

This should be ready to review now! Some tests are failing, but that looks like Hub connection issues.

@sgugger sgugger left a comment (Collaborator)

Thanks for your PR. I think it needs more TF-expert eyes, so tagging @amyeroberts.

The changes in Longformer and LED are very big, so they should go in their own PR to make future git blame easier.

)
out = tf.matmul(x, w, transpose_b=True)
if b is not None:
    out += b
Collaborator:

Is the transpose of b unnecessary here?

Member Author:

Fixed!

@@ -22,6 +23,8 @@

logger = logging.get_logger(__name__)

build_context = threading.local()
Collaborator:

This is not used anywhere here?

Member Author:

Ah, you're right, sorry! This is leftover from an earlier approach I was trying.

Member Author:

Fixed!

@amyeroberts amyeroberts left a comment (Collaborator)

Nice :) Generally this looks like a good clean-up, but I'm not a fan of the current API - it overrides default Keras API behaviour in a non-obvious way. I left a more detailed comment in modeling_tf_utils.py.

One question I have is about the change in conditional logic in the TF modeling code, i.e. removing tf.cond(...) - is this necessary for the new build logic, or just an update made alongside it?

lambda: tf.transpose(pixel_values, perm=(0, 2, 3, 1)),
lambda: pixel_values,
)
if shape_list(pixel_values)[1] == num_channels:
Collaborator:

I don't think the updated check is compatible with graph mode / XLA compilation of models.

Member Author:

shape_list is actually quite smart (because I wasn't the one who wrote it) - it returns the static shape for each dimension when that's known at trace time, and the dynamic shape when it isn't. In an XLA compilation pass, shapes are fully static, so the static shape is always fully known. As a result, shape_list calls just return the static shape in an XLA context and don't introduce data-dependent conditionals.
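
For reference, it's roughly this (paraphrased from memory, so treat it as a sketch rather than the exact source):

    def shape_list(tensor):
        # Prefer static dimensions; fall back to dynamic tf.shape() entries
        dynamic = tf.shape(tensor)
        if tensor.shape == tf.TensorShape(None):
            return dynamic
        static = tensor.shape.as_list()
        return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]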

Collaborator:

Nice! But also don't put yourself down - I thought you'd written it :) I fear we shall never rid ourselves of shape_list

Comment on lines 1159 to 1165
def build_with_dummies(self, dummies=None):
    if self.built_with_dummies and dummies is None:
        return
    if dummies is None:
        dummies = self.dummy_inputs
    self(dummies, training=False)
    self.built_with_dummies = True
Collaborator:

I think the API here is a bit confusing:

  • Allowing dummies to be essentially any input
  • Rebuilding if dummies is None
  • Having input_shape as an argument for build and then not using it

I would rework this so that build matches the logic of the parent class, and build_with_dummies just uses the model's dummy values. That way build remains generic and build_with_x specifically builds with x, e.g. something like:

    def build(self, input_shape):
        if self.built or call_context().in_call:
            self.built = True
            return

        self(input_shape, training=False)
        self.built = True
        self.built_with_dummies = True

    def build_with_dummies(self):
        if self.built_with_dummies:
            return

        self(self.dummy_inputs, training=False)
        self.built_with_dummies = True
        self.built = True

This would:

  • Change all the calls from model.build() to model.build_with_dummies(). We can then remove all the comments next to the build calls explaining that we're using dummy inputs.
  • Remove the need to call super().build(input_shape) when we want the old build logic.
  • Remove the need to set input_shape to None in all the current build methods.

Also - why have the distinction between built and built_with_dummies?

@Rocketknight1 Rocketknight1 (Member, Author) commented May 30, 2023

Actually, I should explain my reasoning for some of the changes here - you're probably right that I can improve the API, though!

Firstly, the removal of tf.cond is no longer a necessary part of this PR, but it is good practice (Longformer and LED are the only two models in all of Transformers that use it in their modelling code). The reason is the Keras call stack: in the __call__ method of any TF module, Keras appends that layer to the call stack and enters that layer's namespace. This means that if you have self.bert, and that calls self.encoder, and that calls self.attn, Keras will be in the bert/encoder/attn namespace.

Incredibly, though, tf.cond counts as a layer with its own namespace, but only when the tf.cond is not being eagerly evaluated. In my initial PR, I was trying to replace our dummies with symbolic TF tensors, which meant the tf.cond was not evaluated at compile time, but instead had to be compiled as a conditional in the model graph. The result is that all layer weights inside the conditional got encapsulated in a /cond.1/ namespace. This broke compatibility with existing checkpoints.

Removing tf.cond helped, but to be safe I added a manual build to those layers to directly control the weight naming regardless of what the call stack thought it should be. As a result, I could probably revert the tf.cond changes, but I think it's preferable not to - we should keep tf.cond out of modelling code and use if statements instead, which TF can still compile into graph conditionals when it can't resolve the branch at compile time. tf.cond is fine in generation code, where no weight names are created.
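
To make the contrast concrete, here's a hedged sketch based on the pixel_values check from the diff above:

    # Python `if` on shape_list: when the channel dim is static, the branch is
    # resolved at trace time and no cond namespace is ever created
    if shape_list(pixel_values)[1] == num_channels:
        pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))

    # tf.cond on a symbolic predicate: compiled as a graph conditional, and any
    # layer weights created inside the branches pick up an extra cond namespace
    pixel_values = tf.cond(
        tf.equal(tf.shape(pixel_values)[1], num_channels),
        lambda: tf.transpose(pixel_values, perm=(0, 2, 3, 1)),
        lambda: pixel_values,
    )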

Secondly, the distinction between build() and build_with_dummies() is a bit of an ugly hack - I think I could probably remove build_with_dummies() entirely, but there was a piece of the TF-PT crossloading code that only worked if it could build the model with specific inputs of its choice. I added build_with_dummies() to support that, with a separate built_with_dummies flag to make sure that any repeated calls wouldn't waste more time. However, it would probably make more sense to just manually pass the inputs through the model in those particular crossloading functions and delete the method and the flag. WDYT?

@amyeroberts (Collaborator)

> tf.cond counts as a layer with its own namespace, but only when the tf.cond is not being eagerly evaluated.

😑

In this case, let's rid ourselves of this pseudolayer! I'm pro the if/else changes :)

> it would probably make more sense to just manually pass the inputs through the model in those particular crossloading functions and delete the method and the flag. WDYT?

Yep, that's what I would go for. Would it be possible to still have some of the logic to exit early if already built? Or would that be too tricky to be worth it?

@Rocketknight1 (Member, Author)

I think we could, but it's probably not necessary - the only cases where we build the model with specific inputs are in weird PT-TF crossloading functions, which should always be called during or near model init anyway, so I think it's fine if there's a risk of a little bit of duplicated work there to save on overall code complexity.

@Rocketknight1 Rocketknight1 force-pushed the tf_functional_builds branch from 4386387 to c76c308 on June 2, 2023 14:37
@Rocketknight1 (Member, Author)

@amyeroberts Done! build_with_dummies is no more

@amyeroberts amyeroberts left a comment (Collaborator)

Thanks for iterating - nice cleanup!

Just a small Q about build logic for LED / Longformer

Comment on lines +714 to +719
with tf.name_scope("query_global"):
    self.query_global.build((self.config.hidden_size,))
with tf.name_scope("key_global"):
    self.key_global.build((self.config.hidden_size,))
with tf.name_scope("value_global"):
    self.value_global.build((self.config.hidden_size,))
Collaborator:

Two silly questions:

  • why do these layers need to be built separately here?
  • why no super().build(input_shape) call in the method?

Member Author:

Good question! The answer is that it's hard to construct dummy inputs that reliably touch all of those layers, so they tend to be left unbuilt unless we build them explicitly.

As for the super().build(), I just forgot! The base build() method doesn't really do anything, but you're right that I should probably still call it just in case.
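
Concretely, the corrected method would just end with the parent call - a sketch mirroring the diff above:

    def build(self, input_shape):
        with tf.name_scope("query_global"):
            self.query_global.build((self.config.hidden_size,))
        with tf.name_scope("key_global"):
            self.key_global.build((self.config.hidden_size,))
        with tf.name_scope("value_global"):
            self.value_global.build((self.config.hidden_size,))
        super().build(input_shape)  # mostly just marks the layer as built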

@Rocketknight1 (Member, Author)

Also, this PR looks ready, but I'm going to let it sit for a couple of days to make sure the CI is working again after my last library-breaking PR, then merge it.

@Rocketknight1 (Member, Author)

Change of plans: The CI is working except for OOM errors during building for some of the pipelines, and since this PR cleans up building a bit, we're going to merge it and see if it helps. If it doesn't, I'll open a new PR to see if I can lower the memory usage in the affected models.

@Rocketknight1 Rocketknight1 merged commit 4a55e47 into main Jun 6, 2023
@Rocketknight1 Rocketknight1 deleted the tf_functional_builds branch June 6, 2023 17:30
@@ -69,11 +74,14 @@
if parse(tf.__version__).minor >= 13:
    from keras import backend as K
    from keras.__internal__ import KerasTensor
    from keras.engine.base_layer_utils import call_context
@frostming frostming commented Jun 7, 2023

This will break, since Keras 2.13 has moved the import to keras.src.engine.

See #23663

Member Author:

Thanks for the catch, I'll make the fix ASAP!
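
Presumably the fix is just a version-guarded import path, something like this sketch (the keras.src path is taken from @frostming's comment; the exact fix may differ):

    import tensorflow as tf
    from packaging.version import parse

    if parse(tf.__version__).minor >= 13:
        # Keras 2.13 moved its internals under keras.src
        from keras.src.engine.base_layer_utils import call_context
    else:
        from keras.engine.base_layer_utils import call_context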

lambda: _make_causal_mask(input_shape, past_key_values_length=past_key_values_length),
lambda: _expand_mask(tf.ones((batch_size, seq_len + past_key_values_length)), tgt_len=seq_len),
)
if seq_len > 1:

Hi @Rocketknight1, I'm unfortunately having an issue with this change. When I build a functional Keras model using the Whisper encoder & decoder layers, I can no longer serialize the model, because this change raises the error:

Using a symbolic `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.

Here is a minimal reproducible example to raise the error:

from transformers import TFWhisperModel
import tensorflow as tf

whisper = TFWhisperModel.from_pretrained("openai/whisper-tiny")
inp = tf.keras.Input((80, 3000))
stack = whisper.get_encoder()(inp)
decoder_input_ids = tf.ones((tf.shape(inp)[0], 1), dtype=tf.int32) * whisper.config.decoder_start_token_id
stack = whisper.get_decoder()(input_ids=decoder_input_ids, encoder_hidden_states=stack.last_hidden_state)
model = tf.keras.Model(inp, stack)
model.summary()
model.save("whisper-tiny-custom")

What do you think?
I will open an issue for this to be referenced!


I opened a corresponding issue

novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
* A fun new PR where I break the entire codebase again

* A fun new PR where I break the entire codebase again

* Handle cross-attention

* Move calls to model(model.dummy_inputs) to the new build() method

* Seeing what fails with the build context thing

* make fix-copies

* Let's see what fails with new build methods

* Fix the pytorch crossload build calls

* Fix the overridden build methods in vision_text_dual_encoder

* Make sure all our build methods set self.built or call super().build(), which also sets it

* make fix-copies

* Remove finished TODO

* Tentatively remove unneeded (?) line

* Transpose b in deberta correctly and remove unused threading local

* Get rid of build_with_dummies and all it stands for

* Rollback some changes to TF-PT crossloading

* Correctly call super().build()