
Proper build() methods for TF #27794

Merged
Rocketknight1 merged 83 commits into main from proper_tf_weight_building on Dec 14, 2023

Conversation

@Rocketknight1 (Member) commented on Dec 1, 2023:

TensorFlow builds weights lazily. This means that layers do not have an input_dim argument and do not create weight tensors in the model __init__(). Instead, the layers wait until their build() method is called, which usually happens implicitly the first time the layer receives an input. Layers use the shape of the first input they see, or the value explicitly passed to their build() method, to infer their input dim and build their weight tensors.
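For illustration, here is what lazy building looks like with a plain Keras layer (a minimal example, not code from this PR):

```python
import tensorflow as tf

layer = tf.keras.layers.Dense(units=4)  # note: no input dim is given
print(layer.weights)                    # [] - nothing has been created yet

# Building happens explicitly...
layer.build(input_shape=(None, 8))
# ...or implicitly on the first call, e.g. layer(tf.zeros((1, 8)))

print([w.shape for w in layer.weights])  # [TensorShape([8, 4]), TensorShape([4])]
```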

Up until now, almost none of our TF models had explicit build() methods. This meant that weights were built implicitly when the model was called, which required lots of tiny hacks all over the codebase:

  • We had to do an entire forward pass inside from_pretrained() to prepare the model weights so that we could load a checkpoint
  • We had to be careful about call stacks and name scopes to ensure that models did not accidentally build themselves inside an existing call/name context and destroy their weight names. This meant our code had to sniff the existing TF call stack, which (among many other issues) completely breaks Keras 3.
  • Several models had explicit calls to tf.name_scope() inside their forward pass (!) to control their weight names, which only worked because the weights were always built there

This had always been a big chunk of tech debt that I'd wanted to fix, but it was such a large task that I never really had time. However, with Keras 3 approaching, it became quite urgent. I tried getting GPT-4 to figure out the build() shapes automatically, but it generally failed, so I resorted to ast-based static analysis of the PyTorch and TF modeling files: cross-matching each TF layer against its PyTorch counterpart, using the input-size arguments from the PyTorch code to automatically create and populate new build() methods, and then doing a manual pass afterwards to fix up the remaining issues.
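To give a flavour of the cross-matching idea, here is a toy sketch (an illustration only, not the actual script used for this PR, which handled far more layer types than nn.Linear): it pulls the input-size argument out of each nn.Linear assignment in a PyTorch modeling file, which can then be matched by attribute name against the corresponding TF layer to fill in its build() shape.

```python
import ast

def linear_input_dims(pytorch_source: str) -> dict:
    """Map attribute names like 'dense' to the first argument of the nn.Linear call
    that creates them, e.g. {'dense': 'config.hidden_size'}."""
    dims = {}
    for node in ast.walk(ast.parse(pytorch_source)):
        # Match assignments of the form: self.<attr> = nn.Linear(<in_features>, ...)
        if isinstance(node, ast.Assign) and isinstance(node.value, ast.Call):
            call = node.value
            if isinstance(call.func, ast.Attribute) and call.func.attr == "Linear" and call.args:
                for target in node.targets:
                    if isinstance(target, ast.Attribute):
                        dims[target.attr] = ast.unparse(call.args[0])  # Python 3.9+
    return dims

torch_src = """
class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
"""
print(linear_input_dims(torch_src))  # {'dense': 'config.hidden_size'}
```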

As a result, after this PR:

  • All models now have correct build() methods
  • Weight names are correct even if models are called weirdly, because we can now tightly control the build() hierarchy
  • Probably the single biggest source of TF bugs we had is gone
  • No more forward passes when building models, including with from_pretrained()! Should make model loading significantly faster, especially on CPU, and should help in the CI.
  • A major Keras 3 blocker is removed.

While I was working on this PR, I also encountered some other issues that I fixed in passing:

  • Added a build_in_name_scope() method and refactored some tests/methods to use it instead. Calling this method yields the same name hierarchy as implicitly calling build() during a forward pass, whereas directly calling model.build() does not (because TF enters a name_scope in __call__()). A simplified sketch of the method follows this list.
  • Updated TFAdaptivePool2D for Data2Vec, which should massively improve that model's performance
  • Fixed some details in the TFSequenceSummary and TFConv1D classes. These are mostly used by older models.
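For reference, the new method is conceptually just a wrapper that enters the model's own name scope before building (a simplified sketch; the method added to modeling_tf_utils.py may carry extra bookkeeping):

```python
import tensorflow as tf

class TFPreTrainedModelSketch(tf.keras.Model):
    def build_in_name_scope(self):
        # Reproduces the name hierarchy you get when build() is triggered by a
        # forward pass, because Keras' __call__() enters tf.name_scope(self.name)
        # before building. Calling model.build() directly skips that scope.
        with tf.name_scope(self.name):
            self.build(input_shape=None)
```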

Note to reviewers: Most of this PR was generated automatically and just consists of walls of new build() methods. You can generally trust that these methods are correct so long as the CI is green, so you hopefully don't have to read them all - there are >11,000 lines of them! The main things to review are the changes in core files like modeling_tf_utils.py, the new build_in_name_scope() method, etc.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@Rocketknight1 force-pushed the proper_tf_weight_building branch 2 times, most recently from 6966265 to 0f16397 on December 6, 2023 20:54
```python
# (diff context the comment below refers to - part of the condition in the refactored summarization test)
    and len(summarizer.model.trainable_weights) > 0
    and "GPU" in summarizer.model.trainable_weights[0].device
```

Rocketknight1 (Member Author):

Quick note to any reviewers so they can make sense of this one: The problem is that get_gpu_count() is supposed to be framework-neutral, but actually checks the frameworks in order using is_torch_available() etc. and returns the GPU count for the first framework that matches.

This is very risky for TensorFlow, because TF environments will often have Torch as well, and if Torch is present then the Torch GPU count is returned instead, which may be 0 even if TF is running on GPU.

I just refactored the check to not use that function.
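In other words, the test now asks TF (or the model itself) directly instead of going through get_gpu_count(). A rough sketch of the two kinds of check involved (an illustration, not the exact test code):

```python
import tensorflow as tf

def tf_model_is_on_gpu(model: tf.keras.Model) -> bool:
    # Check where the built model's weights actually live, rather than
    # whichever framework get_gpu_count() happens to find first.
    return len(model.trainable_weights) > 0 and "GPU" in model.trainable_weights[0].device

# A device-level check is also possible:
tf_has_gpu = len(tf.config.list_physical_devices("GPU")) > 0
```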


```python
# Copied from:
# https://gist.github.com/Rocketknight1/43abbe6e73f1008e6e459486e01e0ceb
class TFAdaptiveAvgPool1D(tf.keras.layers.Layer):
```

@Rocketknight1 (Member Author) commented on Dec 7, 2023:

Another note for reviewers: TF doesn't have an AdaptivePool layer like Torch does. I realized while updating the class in this PR that we were still using my old TF version of the layer, which is very inefficient. I wrote a much more performant version later, and so I took the opportunity to do the replacement here (it also fixed some naming issues that the old layer had)
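For anyone unfamiliar with adaptive pooling: output bin i of an AdaptiveAvgPool1d over a length-L input averages input[floor(i*L/out) : ceil((i+1)*L/out)]. A deliberately naive TF sketch of those semantics (for clarity only - not the vectorised layer added in this PR - assuming a statically known sequence length and channels-last inputs):

```python
import tensorflow as tf

class NaiveAdaptiveAvgPool1D(tf.keras.layers.Layer):
    def __init__(self, output_size: int, **kwargs):
        super().__init__(**kwargs)
        self.output_size = output_size

    def call(self, inputs):  # inputs: (batch, length, channels), length known statically
        length = inputs.shape[-2]
        outputs = []
        for i in range(self.output_size):
            start = (i * length) // self.output_size
            end = -((-(i + 1) * length) // self.output_size)  # ceil((i+1)*L/out)
            outputs.append(tf.reduce_mean(inputs[:, start:end, :], axis=-2))
        return tf.stack(outputs, axis=-2)

# e.g. NaiveAdaptiveAvgPool1D(4)(tf.reshape(tf.range(16.0), (1, 16, 1)))
# averages four consecutive values into each of the 4 output positions.
```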

@Rocketknight1 (Member Author):

oh my god the tests pass i didn't think this was ever going to happen

@gante (Member) left a comment:

Nice and clean 🔥 I can't imagine the time you took to find this "simple" solution 👀

@Rocketknight1 (Member Author):

Thanks! There's one thing left to add - some layers are buildable with int shapes in Keras 2, but that always fails in Keras 3. I'm going to do a quick replacement so that those become actual shapes (with extra None dimensions) - it should have no impact on Keras 2 either way.
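For illustration, the replacement looks roughly like this (a hypothetical example with a plain Dense layer, not a line from the diff):

```python
import tensorflow as tf

hidden_size = 768
dense = tf.keras.layers.Dense(4)

# Keras 2 tolerates a bare int as a build shape, but Keras 3 rejects it:
# dense.build(hidden_size)

# Expanded form with explicit (unknown) leading dimensions works either way:
dense.build([None, None, hidden_size])
print([w.shape for w in dense.weights])  # [TensorShape([768, 4]), TensorShape([4])]
```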

@Rocketknight1 (Member Author):

Quick update: The build shapes all have proper ranks instead of just being ints now, but our old method of controlling names with tf.name_scope() isn't working for Keras 3 - I've asked Chollet what the recommended solution there is

@amyeroberts requested review from amyeroberts and removed the request for ArthurZucker on December 14, 2023 13:54
@Rocketknight1 (Member Author):

Got a solution, but I think it fits better in another PR! I'm gonna merge this one for now and see what shakes out in the nightly CI, while I work on the next phase of Keras 3 compatibility.

@Rocketknight1 merged commit 050e0b4 into main on Dec 14, 2023 - 21 checks passed
@Rocketknight1 deleted the proper_tf_weight_building branch on December 14, 2023 15:17
@ArthurZucker (Collaborator):

Not sure I get why this was merged?

@amyeroberts (Collaborator) left a comment:

As promised, a quick follow up review :) Thanks for this massive piece of work! Layer naming / building is definitely easier to follow like this ❤️

Just a few comments, mostly nits. Main ones are about saving the config in all of the layer modules and where the early return happens in some methods.

```python
if self.built:
    return
self.built = True
if getattr(self, "summary", None) is not None:
```

Collaborator:

What's the significance of "summary" here?

```diff
@@ -146,7 +146,7 @@ def __init__(self, config: AlbertConfig, **kwargs):
         self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
         self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
 
-    def build(self, input_shape: tf.TensorShape):
+    def build(self, input_shape=None):
```

Collaborator:

Nit - shouldn't this become an Optional?

Suggested change
```diff
-    def build(self, input_shape=None):
+    def build(self, input_shape: Optional[tf.TensorShape] = None):
```

```diff
@@ -246,6 +251,7 @@ def __init__(self, config: AlbertConfig, **kwargs):
         # Two different dropout probabilities; see https://github.com/google-research/albert/blob/master/modeling.py#L971-L993
         self.attention_dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
         self.output_dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
+        self.config = config
```

Collaborator:

Storing the whole config in each layer can start to make things a lot bigger if it's large, e.g. id2label with many classes. In the case of TF with safetensors, are we just instantiating the class from the library and loading the weights, i.e. we don't store all the class attributes when saving?

Rocketknight1 (Member Author):

This is a good point, but I think it's okay! Our save_pretrained methods should just save the weights, and not save attributes of the layers like this. Also, setting self.config = config just creates a reference to the same underlying config object, so it doesn't use extra memory when the model is initialized either.

Rocketknight1 (Member Author):

In general, a lot of our models do self.config = config in the __init__ anyway! I just needed to add a lot more of it so that the build() methods could see the config vars they need.

```diff
@@ -965,10 +1086,18 @@ def __init__(self, config, input_embeddings, **kwargs):
         # an output-only bias for each token.
         self.decoder = input_embeddings
 
-    def build(self, input_shape):
+    def build(self, input_shape=None):
         self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
```

Collaborator:

Shouldn't this go after the `if self.built` check?

Rocketknight1 (Member Author):

Probably! In general, this happens when the class had an existing build method, in which case my new method is appended to the end of it. Existing build methods in the codebase don't always have an if built: clause. It shouldn't actually cause too many problems, but I could consider a pass to fix it up if it does.
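For context, the auto-generated methods broadly follow this pattern - guard first, then each sublayer built inside its own name scope (an illustrative sketch with a hypothetical dense sublayer, not a specific method from the diff):

```python
def build(self, input_shape=None):
    if self.built:
        return
    self.built = True
    if getattr(self, "dense", None) is not None:
        with tf.name_scope(self.dense.name):
            self.dense.build([None, None, self.config.hidden_size])
```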

```diff
@@ -169,7 +169,12 @@ def build(self, input_shape: tf.TensorShape = None):
             name="embeddings",
         )
 
-        super().build(input_shape)
+        if self.built:
```

Collaborator:

Same here - can we return quickly without calling self.add_weight if self.built is True?

```diff
@@ -766,7 +877,21 @@ def build(self, input_shape: tf.TensorShape = None):
             name="logit_scale",
         )
 
-        super().build(input_shape)
+        if self.built:
```

Collaborator:

Same question here. I'm not going to comment every time I see it in case it's nothing, but if it isn't, you'll have to go through the PR to find all the cases!

```python
# If a specific input shape is passed in, we need to modify it to account for padding
# Not necessary if those portions of the shape are None
if input_shape[-2] is not None:
    input_shape[-2] += self.explicit_padding * 2
super().build(input_shape)
```

Collaborator:

Do we still want this super().build() call here?

Rocketknight1 (Member Author):

Probably not, but it should be harmless - I don't even think the super() method really does anything in TF anymore!

iantbutler01 pushed a commit to BismuthCloud/transformers that referenced this pull request Dec 16, 2023
* Add a convenience method for building in your own name scope

* Second attempt at auto layer building

* Revert "Second attempt at auto layer building"

This reverts commit e03a3aa.

* Attempt poedator#3

* Revert "Attempt poedator#3"

This reverts commit b9df7a0.

* Add missing attributes that we're going to need later

* Add some attributes we're going to need later

* A fourth attempt! Feel the power flow through you!

* Revert "A fourth attempt! Feel the power flow through you!"

This reverts commit 6bf4aaf.

* Add more values we'll need later

* TF refactor that we'll need later

* Revert "TF refactor that we'll need later"

This reverts commit ca07202.

* Revert "Revert "TF refactor that we'll need later""

This reverts commit 1beb0f3.

* make fixup

* Attempt five!

* Revert "Attempt five!"

This reverts commit 3302207.

* Attempt six - this time don't add empty methods

* Revert "Attempt six - this time don't add empty methods"

This reverts commit 67d6012.

* Attempt seven - better base model class detection!

* Revert "Attempt seven - better base model class detection!"

This reverts commit 5f14845.

* Another attribute we'll need later

* Try again with the missing attribute!

* Revert "Try again with the missing attribute!"

This reverts commit 760c6f3.

* This is the attempt that will pierce the heavens!

* Revert "This is the attempt that will pierce the heavens!"

This reverts commit c868bb6.

* Attempt seven - snag list is steadily decreasing

* Revert "Attempt seven - snag list is steadily decreasing"

This reverts commit 46fbd97.

* Attempt eight - will an empty snag list do it?

* Revert "Attempt eight - will an empty snag list do it?"

This reverts commit 7c8a3c2.

* Fixes to Hubert issues that cause problems later

* Trying again with Conv1D/SeparableConv fixes

* Revert "Trying again with Conv1D/SeparableConv fixes"

This reverts commit 55092bc.

* Apply the build shape fixes to Wav2Vec2 as well

* One more attempt!

* Revert "One more attempt!"

This reverts commit 5ac3e4c.

* Another attempt!

* Revert "Another attempt!"

This reverts commit ea16d89.

* Let's see how many failures we get without the internal build method

* Fix OpenAI

* Fix MobileBERT

* (Mostly) fix GroupVIT

* Fix BLIP

* One more BLIP fix

* One more BLIP fix!

* Fix Regnet

* Finally fully fix GroupViT

* Fix Data2Vec and add the new AdaptivePool

* Fix Segformer

* Fix Albert

* Fix Deberta/DebertaV2

* Fix XLM

* Actually fix XLM

* Fix Flaubert

* Fix lxmert

* Fix Resnet

* Fix ConvBERT

* Fix ESM

* Fix Convnext / ConvnextV2

* Fix SAM

* Fix Efficientformer

* Fix LayoutLMv3

* Fix speech_to_text

* Fix mpnet and mobilevit

* Fix Swin

* Fix CTRL

* Fix CVT

* Fix DPR

* Fix Wav2Vec2

* Fix T5

* Fix Hubert

* Fix GPT2

* Fix Whisper

* Fix DeiT

* Fix the encoder-decoder / dual-encoder classes

* make fix-copies

* build in name scope

* Fix summarization test

* Fix tied weight names for BART + Blenderbot

* Fix tied weight name building

* Fix to TFESM weight building

* Update TF SAM

* Expand all the shapes out into Big Boy Shapes
staghado pushed a commit to staghado/transformers that referenced this pull request Jan 15, 2024
ZJaume added a commit to bitextor/bicleaner-ai that referenced this pull request Apr 16, 2024
This fixes security issues #274, #275, #276.

Can't upgrade to a higher version because this change seems to break
model loading and some layers are failing to load:

huggingface/transformers#27794