
transformers incompatible with master (head of trunk) tensorflow & keras 3 #28296

Closed
ekuznetsov139 opened this issue Jan 1, 2024 · 9 comments · Fixed by #28588

ekuznetsov139 commented Jan 1, 2024

I am trying to get transformers working with head-of-trunk tensorflow, which requires Keras 3 (I'm using keras-nightly 3.0.3.dev2023123103), and I'm running into issues that seem to be caused by changes in Keras's internal behavior. Neither 4.36.2 nor head-of-trunk transformers works.

My test script is simply:

from transformers import GPT2TokenizerFast, TFGPT2LMHeadModel
import tensorflow as tf

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", mask_token='#')
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
model.compile(optimizer=optimizer, loss="passthrough", metrics=[])

This works with transformers 4.36.2, tensorflow 2.14, keras 2.14.

With head-of-trunk TF and 4.36.2, I get:

    model = TFGPT2LMHeadModel.from_pretrained("gpt2")
  File "/usr/local/lib/python3.9/dist-packages/transformers/modeling_tf_utils.py", line 2919, in from_pretrained
    model.build()  # build the network with dummy inputs
  File "/usr/local/lib/python3.9/dist-packages/keras/src/layers/layer.py", line 223, in build_wrapper
    original_build_method(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/transformers/modeling_tf_utils.py", line 1134, in build
    if self.built or call_context().in_call:
TypeError: 'NoneType' object is not callable

This is evidently because https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/keras_deps.py#L40 is no longer being called from keras 3.0.x, so https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_tf_utils.py#L1133 returns None.
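
If it helps anyone else, the following guard works around the crash (a sketch of a hypothetical patch to TFPreTrainedModel.build() in modeling_tf_utils.py, not the eventual fix):

from tensorflow.python.util.keras_deps import get_call_context_function

def build(self, input_shape=None):
    # Under Keras 3 nothing registers the call-context callback anymore, so
    # get_call_context_function() returns None; don't call it unconditionally.
    call_context = get_call_context_function()
    in_call = call_context is not None and call_context().in_call
    if self.built or in_call:
        self.built = True
    else:
        self.built = True
        # ... original dummy-input build logic unchanged ...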

I can bypass this, but then I run into a new problem:

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFGPT2LMHeadModel: ['h.4.mlp.c_proj.bias', 'h.10.attn.c_attn.weight', <.....>,  'h.9.attn.c_attn.bias', 'h.0.attn.c_attn.bias']

I did some tracing, and the cause is that, when the code hits https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_tf_utils.py#L2905, tf_model.trainable_weights is empty, so transformers can't load any weights into it. I tried moving the block at lines 2915-2919 above the load call, but it had no effect.
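
For what it's worth, tracing the model once on its dummy inputs is the quickest way to check whether variables get created at all (a sketch; dummy_inputs is an existing TFPreTrainedModel property):

# Force variable creation before the PT->TF name-matching load. Under Keras 2
# this populates tf_model.trainable_weights; under Keras 3 it evidently no
# longer does, which is the symptom described above.
tf_model(tf_model.dummy_inputs, training=False)
print(len(tf_model.trainable_weights))  # 0 here is the bug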

Then I tried head-of-trunk transformers. It fails too, but with different symptoms. First, there is:

  File "/usr/local/lib/python3.9/dist-packages/transformers/modeling_tf_utils.py", line 2889, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
  File "/usr/local/lib/python3.9/dist-packages/transformers/models/gpt2/modeling_tf_gpt2.py", line 847, in __init__
    super().__init__(config, *inputs, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/transformers/modeling_tf_utils.py", line 1150, in __init__
    self._set_save_spec(self.input_signature)
  File "/usr/local/lib/python3.9/dist-packages/tensorflow/python/trackable/base.py", line 205, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/keras/src/backend/tensorflow/layer.py", line 34, in _set_save_spec
    for key, kwarg in kwargs.items():
AttributeError: 'NoneType' object has no attribute 'items'

The problem is that, at
https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L1150, you're calling
self._set_save_spec(self.input_signature)

and hitting https://github.com/keras-team/keras/blob/v3.0.2/keras/backend/tensorflow/layer.py#L16

def _set_save_spec(self, inputs, args=None, kwargs=None)

which is declared with the default parameter 'kwargs=None', but really expects kwargs to be a dict. The logical workaround is
self._set_save_spec(self.input_signature, kwargs={})
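
Alternatively, without editing transformers, a shim applied to the class before from_pretrained() can normalize the arguments (a hypothetical monkey-patch sketch, not the eventual fix):

import functools

def patch_set_save_spec(cls):
    # Keras 3's _set_save_spec(inputs, args=None, kwargs=None) iterates
    # kwargs.items() without a None check; substitute empty containers.
    original = cls._set_save_spec

    @functools.wraps(original)
    def wrapper(self, inputs, args=None, kwargs=None):
        return original(self, inputs, args=args or [], kwargs=kwargs or {})

    cls._set_save_spec = wrapper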

This gets me to problem number 2:

    original_build_method(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/transformers/modeling_tf_utils.py", line 3217, in build
    self.weight = self.add_weight(
TypeError: add_weight() got multiple values for argument 'shape'

This happens because Keras has reordered the arguments of Layer.add_weight():
https://github.com/keras-team/keras/blob/v2.15.0/keras/engine/base_layer.py#L553
https://github.com/keras-team/keras/blob/v3.0.2/keras/layers/layer.py#L448

so you need to add an explicit name= in https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L3217 and again in https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L3220.
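
Concretely, Keras 2 declares add_weight(name=None, shape=None, ...) while Keras 3 declares add_weight(shape=None, initializer=None, ..., name=None), so a positional name now collides with shape=. Passing everything by keyword is valid under both; a sketch (the shape and initializer here are illustrative, not the exact code at that line):

self.weight = self.add_weight(
    name="weight",  # keyword, so the Keras 3 reordering cannot misbind it
    shape=[self.vocab_size, self.hidden_size],
    initializer=get_initializer(self.initializer_range),
)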

Unfortunately, even that does not let me load the model, because some kind of glitch prevents the TF model from setting its weight names correctly, so I get this error:

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFGPT2LMHeadModel: ['h.1.attn.c_attn.weight', 'h.2.attn.c_proj.weight', 'h.7.attn.c_proj.bias', 'h.0.attn.c_proj.weight', 'h.8.attn.c_attn.weight', 'h.0.mlp.c_proj.weight', 'h.1.ln_2.bias', 'h.10.attn.c_attn.weight', 'h.2.mlp.c_proj.weight', 'h.8.ln_1.weight', 'h.1.ln_2.weight', 'h.6.mlp.c_fc.bias', 'h.10.ln_1.bias', 'h.10.mlp.c_proj.weight', 'h.3.ln_2.bias', 'h.4.ln_1.weight', 'h.5.mlp.c_proj.weight', 'h.3.attn.c_proj.bias', 'h.2.ln_2.weight', 'h.3.mlp.c_proj.bias', 'h.4.attn.c_proj.bias', 'h.11.attn.c_attn.weight', 'h.9.ln_2.bias', 'h.0.ln_2.bias', 'h.0.attn.c_attn.bias', 'h.4.attn.c_attn.bias', 'h.6.mlp.c_proj.bias', 'h.3.attn.c_attn.weight', 'h.11.ln_2.weight', 'h.11.ln_2.bias', 'h.0.ln_1.weight', 'h.4.mlp.c_proj.weight', 'h.8.attn.c_attn.bias', 'h.4.attn.c_attn.weight', 'h.5.mlp.c_proj.bias', 'h.11.mlp.c_proj.weight', 'h.11.attn.c_proj.weight', 'h.8.attn.c_proj.weight', 'h.3.ln_1.bias', 'h.8.ln_1.bias', 'h.5.ln_2.weight', 'h.3.attn.c_attn.bias', 'h.8.mlp.c_fc.bias', 'h.11.mlp.c_fc.bias', 'h.6.ln_2.bias', 'h.9.mlp.c_fc.weight', 'h.1.ln_1.bias', 'h.3.attn.c_proj.weight', 'h.1.mlp.c_fc.bias', 'h.0.mlp.c_fc.bias', 'h.8.mlp.c_proj.weight', 'h.7.mlp.c_fc.bias', 'h.1.mlp.c_fc.weight', 'h.10.mlp.c_fc.bias', 'h.0.attn.c_attn.weight', 'h.11.attn.c_attn.bias', 'h.5.attn.c_attn.weight', 'h.6.mlp.c_proj.weight', 'h.4.ln_2.bias', 'h.5.mlp.c_fc.weight', 'h.8.mlp.c_fc.weight', 'h.11.attn.c_proj.bias', 'h.3.mlp.c_fc.bias', 'h.2.ln_1.weight', 'h.0.attn.c_proj.bias', 'h.0.mlp.c_fc.weight', 'h.6.attn.c_attn.bias', 'h.2.ln_2.bias', 'h.8.ln_2.weight', 'h.1.mlp.c_proj.weight', 'h.7.ln_1.bias', 'h.6.mlp.c_fc.weight', 'h.7.attn.c_attn.weight', 'h.6.attn.c_attn.weight', 'h.4.ln_1.bias', 'h.2.mlp.c_proj.bias', 'h.7.attn.c_proj.weight', 'h.9.ln_1.bias', 'h.4.mlp.c_fc.weight', 'h.6.ln_1.bias', 'h.9.mlp.c_proj.bias', 'h.10.mlp.c_proj.bias', 'h.11.mlp.c_proj.bias', 'h.4.ln_2.weight', 'h.6.attn.c_proj.weight', 'h.9.attn.c_attn.weight', 'h.9.attn.c_proj.weight', 'h.11.ln_1.bias', 'wpe.weight', 'h.8.attn.c_proj.bias', 'h.7.ln_1.weight', 'h.10.ln_2.bias', 'h.0.mlp.c_proj.bias', 'h.0.ln_2.weight', 'h.4.mlp.c_proj.bias', 'h.6.ln_1.weight', 'h.7.mlp.c_proj.bias', 'h.8.ln_2.bias', 'h.8.mlp.c_proj.bias', 'h.5.ln_1.weight', 'h.9.mlp.c_proj.weight', 'h.5.attn.c_attn.bias', 'h.2.ln_1.bias', 'h.1.attn.c_proj.weight', 'h.9.ln_1.weight', 'h.11.ln_1.weight', 'h.5.attn.c_proj.weight', 'h.4.mlp.c_fc.bias', 'h.5.ln_2.bias', 'h.2.attn.c_attn.weight', 'h.7.attn.c_attn.bias', 'h.7.mlp.c_proj.weight', 'h.1.mlp.c_proj.bias', 'h.5.attn.c_proj.bias', 'h.11.mlp.c_fc.weight', 'h.10.attn.c_proj.weight', 'h.3.ln_1.weight', 'h.10.attn.c_proj.bias', 'h.3.mlp.c_fc.weight', 'h.4.attn.c_proj.weight', 'h.2.attn.c_attn.bias', 'h.3.ln_2.weight', 'h.10.attn.c_attn.bias', 'h.3.mlp.c_proj.weight', 'h.1.attn.c_proj.bias', 'h.2.mlp.c_fc.bias', 'h.9.ln_2.weight', 'h.5.ln_1.bias', 'h.10.ln_1.weight', 'h.7.mlp.c_fc.weight', 'ln_f.bias', 'h.2.attn.c_proj.bias', 'h.0.ln_1.bias', 'h.7.ln_2.bias', 'h.7.ln_2.weight', 'h.6.attn.c_proj.bias', 'h.10.mlp.c_fc.weight', 'wte.weight', 'h.9.mlp.c_fc.bias', 'h.1.ln_1.weight', 'h.6.ln_2.weight', 'h.1.attn.c_attn.bias', 'h.9.attn.c_attn.bias', 'h.2.mlp.c_fc.weight', 'h.5.mlp.c_fc.bias', 'ln_f.weight', 'h.9.attn.c_proj.bias', 'h.10.ln_2.weight']
- This IS expected if you are initializing TFGPT2LMHeadModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFGPT2LMHeadModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
Some weights or buffers of the TF 2.0 model TFGPT2LMHeadModel were not initialized from the PyTorch model and are newly initialized: ['weight', 'weight', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias', 'weight', 'bias']
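
The bare names in that second list are the tell: the variables have lost the nested layer path (something like transformer/h_._0/attn/c_attn/weight under Keras 2) that the PT-to-TF name matcher keys on. A quick diagnostic sketch:

# Under Keras 2 each variable name carries the full layer hierarchy; under
# Keras 3 (without fixes) they collapse to bare "weight"/"bias", so the
# cross-loading name matcher cannot pair them with the PyTorch keys.
for w in model.weights[:10]:
    print(w.name, tuple(w.shape))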

ekuznetsov139 commented Jan 1, 2024

Here's what needs to be done:

ROCm@0b3d80d

Some of this may break Keras 2 operation; I began adding version checks but have not had time to do it properly. I had to disable jit_compile because I was getting XLA-related errors and this was an easy way out; I need to investigate and fix that problem as well.
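
For the version checks, something like this should be enough to gate the Keras-2-only branches (a sketch; transformers may settle on a different detection mechanism):

from packaging import version
import keras

# True on keras-nightly 3.x and on the Keras 3 that future TF releases bundle.
IS_KERAS_3 = version.parse(keras.__version__).major >= 3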

This gets me as far as being able to train for at least one epoch. Loss values seem to be off, but the model loads and trains, and the loss goes down over time.

@ekuznetsov139 (Author)

I'll just keep talking to myself here, never mind me.

ROCm@a488708

It trains, apparently correctly, on multiple GPUs (using tf.distribute.MirroredStrategy) and with XLA enabled.

Reported loss is multiplied by the number of GPUs, and I can't quite work out why.
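
If it's the usual MirroredStrategy pitfall - and that's an assumption on my part, I haven't confirmed it in this code path - the per-replica loss is being summed rather than averaged over the global batch. tf.nn.compute_average_loss is the standard remedy:

import tensorflow as tf

def replica_loss(labels, logits, global_batch_size):
    # Dividing by the *global* batch size keeps the reported loss independent
    # of how many replicas the batch was split across.
    per_example = tf.keras.losses.sparse_categorical_crossentropy(
        labels, logits, from_logits=True
    )
    return tf.nn.compute_average_loss(per_example, global_batch_size=global_batch_size)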

The bigger issue, however, is that mixed precision is broken:

    File "/usr/local/lib/python3.9/dist-packages/keras/src/backend/tensorflow/trainer.py", line 105, in one_step_on_data  **
        return self.train_step(data)
    File "/usr/local/lib/python3.9/dist-packages/transformers/modeling_tf_utils.py", line 1703, in train_step
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
    File "/usr/local/lib/python3.9/dist-packages/keras/src/optimizers/base_optimizer.py", line 206, in apply_gradients
        self.apply(grads, trainable_variables)
    File "/usr/local/lib/python3.9/dist-packages/keras/src/optimizers/loss_scale_optimizer.py", line 183, in apply
        ops.cond(finite, handle_finite_grads, handle_non_finite_grads)
    File "/usr/local/lib/python3.9/dist-packages/keras/src/ops/core.py", line 594, in cond
        return Cond()(pred, true_fn, false_fn)
    File "/usr/local/lib/python3.9/dist-packages/keras/src/utils/traceback_utils.py", line 123, in error_handler
        raise e.with_traceback(filtered_tb) from None
    File "/usr/local/lib/python3.9/dist-packages/keras/src/backend/tensorflow/optimizer.py", line 82, in _internal_apply_gradients
        tf.__internal__.distribute.interim.maybe_merge_call(

    RuntimeError: Exception encountered when calling Cond.call().
    
    `merge_call` called while defining a new graph or a tf.function. This can often happen if the function `fn` passed to `strategy.run()` contains a nested `@tf.function`, and the nested `@tf.function` contains a synchronization point, such as aggregating gradients (e.g, optimizer.apply_gradients), or if the function `fn` uses a control flow statement which contains a synchronization point in the body. Such behaviors are not yet supported. Instead, please avoid nested `tf.function`s or control flow statements that may potentially cross a synchronization boundary, for example, wrap the `fn` passed to `strategy.run` or the entire `strategy.run` inside a `tf.function` or move the control flow out of `fn`. If you are subclassing a `tf.keras.Model`, please avoid decorating overridden methods `test_step` and `train_step` in `tf.function`.

Going to tackle this one today.
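
For reference, the structure that error message asks for looks roughly like this (a generic sketch, assuming a strategy and an undecorated per-replica train_step already exist):

@tf.function  # single outer tf.function; nothing nested crosses a sync point
def distributed_train_step(dist_inputs):
    per_replica_losses = strategy.run(train_step, args=(dist_inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)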

@ArthurZucker (Collaborator)

thanks! cc @Rocketknight1

@ekuznetsov139 (Author)

OK, the mixed precision issue from my last post was actually fixed in Keras; I was only seeing it because I had a somewhat outdated Keras 3 build on the system (2023-10-17 instead of 2023-12-31).

There was another issue with mixed precision that only affected testing (I have fixed it in the GPT-2 pathway; it may affect other models). Saving was also broken. Here's a patch that fixes everything I found, except the loss being multiplied by the number of GPUs:

4fa1260

@ArthurZucker (Collaborator)

You should open a PR with the patch! 🤗 (linking this issue)

@ekuznetsov139 (Author)

Will do once I'm satisfied that I've resolved all the issues.

@Rocketknight1 (Member)

Hi @ekuznetsov139, thanks for the investigation here - this looks really good! Just to give you some context, the reason the errors change in the latest main version of transformers is that I've been working on Keras 3 PRs behind the scenes as well. The biggest one is that we now use proper build() methods for all our TF models instead of building them with dummy inputs - this avoids lots of issues related to name hierarchies that changed in Keras 3. You can see some of the PRs here:

I think the plan from here is that in our TensorFlow code, we're going to completely remove all direct imports of keras and only use from tensorflow import keras. This ties the Keras version to the TF version, although we will still need to support Keras 3, since we understand the built-in version of Keras will be Keras 3 starting from TF 2.16.
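
In practice that policy is just the following convention (a sketch, not a specific patch):

# Never `import keras` directly; go through TF so the two stay in lockstep.
# With TF <= 2.15 this resolves to Keras 2; from TF 2.16 it is expected to
# resolve to Keras 3.
from tensorflow import keras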

Our primary goal is to ensure that Keras 3 doesn't break backward compatibility for TF code, even if we don't fully support other frameworks with Keras 3. Once backward compatibility is secure, we have plans to fully support Keras 3, which will probably require a community push to make full Keras ports of all of our models that don't use any TensorFlow ops - there's a partial PR at #26224 but it's on hold because of the number of other backward compatibility issues that need to be resolved first.

@lingluodlut

Hi @ekuznetsov139, I also ran into the same problems when using TensorFlow with Keras 3 to load transformers models. Did you manage to fix it?

@Rocketknight1 (Member)

Hi @lingluodlut @ekuznetsov139, I believe this is the last PR we need: #28588

Note that we still won't have full Keras 3 support, but once this PR is merged, Transformers will at least continue working when Keras 3 is installed.
