TF: XLA model output differs when certain outputs are passed #16838
Comments
How significant are the differences? Would it pass with 1e-1? |
Tried twice, for |
Just tried it here. On CPU:
On GPU (3090, using TensorFloat32):
|
My best guess is that there are two separate issues:
|
Wait so XLA works on GPU, but not on CPU? That's very weird |
@gante Probably the following code and outputs could help you spot the places more easily.
Code:
import numpy as np
import tensorflow as tf
from transformers import TFT5Model, T5Tokenizer
from transformers.utils.generic import ModelOutput
checkpoint = "t5-base"
tokenizer = T5Tokenizer.from_pretrained(checkpoint)
model = TFT5Model.from_pretrained(checkpoint)
# Ugly hack to return all outputs
model.config.output_hidden_states = True
model.config.output_attentions = True
model = TFT5Model.from_pretrained(checkpoint, config=model.config)
model_xla = tf.function(model, jit_compile=True)
# tokenizer.pad_token_id = tokenizer.eos_token_id
pad_token_id = tokenizer.pad_token_id
sentence_1 = "Translate English to German: I have a cat, two dogs, three horses, and four birds."
sentence_2 = "Translate English to German: I have a cat, two dogs, and three horses."
ids_single = tokenizer([sentence_1], return_tensors="tf", padding=True).input_ids
decoder_ids_single = tf.zeros((1, 1), dtype=tf.int32)
# attention_single = tf.cast(tf.math.not_equal(ids_single, pad_token_id), dtype=tf.int32) # as computed in generate
attention_single = tf.cast(tf.ones_like(ids_single), dtype=tf.int32) # as computed in generate
ids_pair = tokenizer([sentence_1, sentence_2], return_tensors="tf", padding=True).input_ids
decoder_ids_pair = tf.zeros((2, 1), dtype=tf.int32)
# attention_pair = tf.cast(tf.math.not_equal(ids_pair, pad_token_id), dtype=tf.int32) # as computed in generate
attention_pair = tf.cast(tf.ones_like(ids_pair), dtype=tf.int32)
# case 3 FAILING: with batch size = 1 and attention mask, XLA and non-XLA outputs should match but do not
outputs = model(input_ids=ids_single, decoder_input_ids=decoder_ids_single, attention_mask=attention_single, output_hidden_states=True, output_attentions=True)
outputs_xla = model_xla(input_ids=ids_single, decoder_input_ids=decoder_ids_single, attention_mask=attention_single, output_hidden_states=True, output_attentions=True)
# Please ignore the bad naming - this is just a quick copy from the test script
def check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
    # Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`).
    if isinstance(tf_outputs, ModelOutput):
        tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
        pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
        # (if without the hack) XLA models don't return full outputs at this moment ... need to ignore them at this moment
        # keys = tuple(set(tf_keys).intersection(pt_keys))
        # tf_outputs = tuple([tf_outputs[k] for k in keys])
        # pt_outputs = tuple([pt_outputs[k] for k in keys])
        # convert to the case of `tuple`
        # appending each key to the current (string) `names`
        attributes = tuple([f"{name}.{k}" for k in tf_keys])
        check_pt_tf_outputs(tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes)
    # Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
    elif type(tf_outputs) in [tuple, list]:
        if attributes is not None:
            # case 1: each output has assigned name (e.g. a tuple form of a `ModelOutput`)
            pass
        else:
            # case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `names`
            attributes = tuple([f"{name}_{idx}" for idx in range(len(tf_outputs))])
        for tf_output, pt_output, attr in zip(tf_outputs, pt_outputs, attributes):
            check_pt_tf_outputs(tf_output, pt_output, model_class, tol=tol, name=attr)
    elif isinstance(tf_outputs, tf.Tensor):
        tf_outputs = tf_outputs.numpy()
        pt_outputs = pt_outputs.numpy()
        # deal with NumPy's scalars to make replacing nan values by 0 work.
        if np.isscalar(tf_outputs):
            tf_outputs = np.array([tf_outputs])
            pt_outputs = np.array([pt_outputs])
        # zero out NaNs on both sides so they do not poison the max difference
        tf_nans = np.isnan(tf_outputs)
        pt_nans = np.isnan(pt_outputs)
        pt_outputs[tf_nans] = 0
        tf_outputs[tf_nans] = 0
        pt_outputs[pt_nans] = 0
        tf_outputs[pt_nans] = 0
        max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
        print(f"{name}: {max_diff}")
    else:
        raise ValueError(
            f"`tf_outputs` should be an instance of `ModelOutput`, a `tuple` or `list`, or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead.")
check_pt_tf_outputs(outputs, outputs_xla, model_class=TFT5Model)
Outputs:
outputs.last_hidden_state: 2.800762176513672
outputs.past_key_values_0_0: 4.291534423828125e-06
outputs.past_key_values_0_1: 1.0728836059570312e-06
outputs.past_key_values_0_2: 3.4570693969726562e-06
outputs.past_key_values_0_3: 3.337860107421875e-06
outputs.past_key_values_1_0: 0.4949379563331604
outputs.past_key_values_1_1: 0.8448842763900757
outputs.past_key_values_1_2: 4.291534423828125e-06
outputs.past_key_values_1_3: 4.887580871582031e-06
outputs.past_key_values_2_0: 0.4911351203918457
outputs.past_key_values_2_1: 0.5065852403640747
outputs.past_key_values_2_2: 4.76837158203125e-06
outputs.past_key_values_2_3: 5.7220458984375e-06
outputs.past_key_values_3_0: 0.47093653678894043
outputs.past_key_values_3_1: 0.5624567270278931
outputs.past_key_values_3_2: 4.410743713378906e-06
outputs.past_key_values_3_3: 5.9604644775390625e-06
outputs.past_key_values_4_0: 0.775518536567688
outputs.past_key_values_4_1: 0.934751570224762
outputs.past_key_values_4_2: 5.7220458984375e-06
outputs.past_key_values_4_3: 7.152557373046875e-06
outputs.past_key_values_5_0: 1.0620229244232178
outputs.past_key_values_5_1: 1.1955945491790771
outputs.past_key_values_5_2: 5.7220458984375e-06
outputs.past_key_values_5_3: 9.059906005859375e-06
outputs.past_key_values_6_0: 1.5020784139633179
outputs.past_key_values_6_1: 1.768876552581787
outputs.past_key_values_6_2: 6.4373016357421875e-06
outputs.past_key_values_6_3: 8.344650268554688e-06
outputs.past_key_values_7_0: 1.9831377267837524
outputs.past_key_values_7_1: 1.7343039512634277
outputs.past_key_values_7_2: 6.67572021484375e-06
outputs.past_key_values_7_3: 1.0251998901367188e-05
outputs.past_key_values_8_0: 2.3230268955230713
outputs.past_key_values_8_1: 2.937762498855591
outputs.past_key_values_8_2: 5.7220458984375e-06
outputs.past_key_values_8_3: 9.775161743164062e-06
outputs.past_key_values_9_0: 2.8203392028808594
outputs.past_key_values_9_1: 5.384043216705322
outputs.past_key_values_9_2: 5.9604644775390625e-06
outputs.past_key_values_9_3: 1.33514404296875e-05
outputs.past_key_values_10_0: 4.303163528442383
outputs.past_key_values_10_1: 10.02894401550293
outputs.past_key_values_10_2: 6.198883056640625e-06
outputs.past_key_values_10_3: 1.430511474609375e-05
outputs.past_key_values_11_0: 4.163003921508789
outputs.past_key_values_11_1: 7.657519817352295
outputs.past_key_values_11_2: 4.76837158203125e-06
outputs.past_key_values_11_3: 1.9073486328125e-05
outputs.decoder_hidden_states_0: 0.0
outputs.decoder_hidden_states_1: 2151.3359375
outputs.decoder_hidden_states_2: 2724.79736328125
outputs.decoder_hidden_states_3: 4147.70751953125
outputs.decoder_hidden_states_4: 6162.63720703125
outputs.decoder_hidden_states_5: 7066.3046875
outputs.decoder_hidden_states_6: 7329.43603515625
outputs.decoder_hidden_states_7: 7471.92333984375
outputs.decoder_hidden_states_8: 7749.91162109375
outputs.decoder_hidden_states_9: 8324.51953125
outputs.decoder_hidden_states_10: 8609.3359375
outputs.decoder_hidden_states_11: 7732.30224609375
outputs.decoder_hidden_states_12: 2.800762176513672
outputs.decoder_attentions_0: 0.0
outputs.decoder_attentions_1: 0.0
outputs.decoder_attentions_2: 0.0
outputs.decoder_attentions_3: 0.0
outputs.decoder_attentions_4: 0.0
outputs.decoder_attentions_5: 0.0
outputs.decoder_attentions_6: 0.0
outputs.decoder_attentions_7: 0.0
outputs.decoder_attentions_8: 0.0
outputs.decoder_attentions_9: 0.0
outputs.decoder_attentions_10: 0.0
outputs.decoder_attentions_11: 0.0
outputs.cross_attentions_0: 0.9293187856674194
outputs.cross_attentions_1: 0.8967262506484985
outputs.cross_attentions_2: 0.7246492505073547
outputs.cross_attentions_3: 0.9164008498191833
outputs.cross_attentions_4: 0.8164070248603821
outputs.cross_attentions_5: 0.7364302277565002
outputs.cross_attentions_6: 0.6568543314933777
outputs.cross_attentions_7: 0.6275004744529724
outputs.cross_attentions_8: 0.6810514330863953
outputs.cross_attentions_9: 0.631909966468811
outputs.cross_attentions_10: 0.4159456491470337
outputs.cross_attentions_11: 0.39396628737449646
outputs.encoder_last_hidden_state: 5.960464477539062e-07
outputs.encoder_hidden_states_0: 0.0
outputs.encoder_hidden_states_1: 0.000244140625
outputs.encoder_hidden_states_2: 0.0003662109375
outputs.encoder_hidden_states_3: 0.00048828125
outputs.encoder_hidden_states_4: 0.00048828125
outputs.encoder_hidden_states_5: 0.00048828125
outputs.encoder_hidden_states_6: 0.0009765625
outputs.encoder_hidden_states_7: 0.00048828125
outputs.encoder_hidden_states_8: 0.001953125
outputs.encoder_hidden_states_9: 0.001953125
outputs.encoder_hidden_states_10: 0.0078125
outputs.encoder_hidden_states_11: 0.0078125
outputs.encoder_hidden_states_12: 5.960464477539062e-07
outputs.encoder_attentions_0: 5.066394805908203e-07
outputs.encoder_attentions_1: 5.364418029785156e-07
outputs.encoder_attentions_2: 7.152557373046875e-07
outputs.encoder_attentions_3: 5.960464477539062e-07
outputs.encoder_attentions_4: 5.662441253662109e-07
outputs.encoder_attentions_5: 5.960464477539062e-07
outputs.encoder_attentions_6: 5.364418029785156e-07
outputs.encoder_attentions_7: 6.258487701416016e-07
outputs.encoder_attentions_8: 8.642673492431641e-07
outputs.encoder_attentions_9: 5.960464477539062e-07
outputs.encoder_attentions_10: 7.152557373046875e-07
outputs.encoder_attentions_11: 5.960464477539062e-07 |
Thank you for your suggestions, you have solved the puzzle 🙏 The winning suggestion award goes to @Rocketknight1 -- XLA on CPU is indeed buggy. I've spun up an Nvidia T4 ( = no As a result of this thread, I was thinking of:
Greedy search translating correctly with GPU: Greedy search failing with CPU: Sample behaving okay with GPU (sampling 10 outputs for the first sentence input): |
@gante I think they certainly would be interested, but we'd have to localize the bug a little more! If you could fix an input and make a minimal single module that showed the buggy behaviour, you should definitely report that upstream. I totally understand if that's not a priority with everything else on your plate, though! |
Cool, great job guys in locating the error! I don't think it's a good idea to raise an error / exception if XLA is enabled on CPU. XLA should work on CPU - why wouldn't it? To me this clearly looks like a TF bug and quite a big one actually. IMO, lots of people debug their code on CPU in XLA so I think it is pretty important that it works on CPU. Also we need to test XLA on CPU as well so that it runs on CircleCI IMO. cc @sanchit-gandhi, who is working quite a bit with XLA at the moment. |
It shouldn't be too difficult to locate where the difference is coming from since we know that without attention_mask it works, no? |
Change this line
to
will solve the problem. This gives the same weights (on CPU + XLA) as the ones computed on a GPU machine (both non-XLA and XLA). I tested this trick with @gante's code samples. I also looked at the expected values for [8.03906238e-04 4.91665269e-04 6.60848498e-01 7.20867813e-02, ...]; without this, on CPU + XLA, we get [0.04347826, 0.04347826, 0.04347826, 0.04347826, ...]. I guess some trick (about the numerical stability of softmax) is not applied for XLA + CPU.
The code I use:
import numpy as np
import tensorflow as tf
from transformers import TFT5Model, T5Tokenizer
from transformers.utils.generic import ModelOutput
checkpoint = "t5-base"
tokenizer = T5Tokenizer.from_pretrained(checkpoint)
model = TFT5Model.from_pretrained(checkpoint)
# Ugly hack to return all outputs
model.config.output_hidden_states = True
model.config.output_attentions = True
model = TFT5Model.from_pretrained(checkpoint, config=model.config)
model_xla = tf.function(model, jit_compile=True)
# tokenizer.pad_token_id = tokenizer.eos_token_id
pad_token_id = tokenizer.pad_token_id
sentence_1 = "I have a cat, two dogs"
sentence_2 = "I have a cat"
sentence_1 = "Translate English to German: I have a cat, two dogs, three horses, and four birds."
sentence_2 = "Translate English to German: I have a cat, two dogs, and three horses."
ids_single = tokenizer([sentence_1], return_tensors="tf", padding=True).input_ids
decoder_ids_single = tf.zeros((1, 1), dtype=tf.int32)
# attention_single = tf.cast(tf.math.not_equal(ids_single, pad_token_id), dtype=tf.int32) # as computed in generate
attention_single = tf.cast(tf.ones_like(ids_single), dtype=tf.int32) # as computed in generate
decoder_attention_single = tf.cast(tf.ones_like(decoder_ids_single), dtype=tf.int32) # as computed in generate
ids_pair = tokenizer([sentence_1, sentence_2], return_tensors="tf", padding=True).input_ids
decoder_ids_pair = tf.zeros((2, 1), dtype=tf.int32)
# attention_pair = tf.cast(tf.math.not_equal(ids_pair, pad_token_id), dtype=tf.int32) # as computed in generate
attention_pair = tf.cast(tf.ones_like(ids_pair), dtype=tf.int32)
decoder_attention_pair = tf.cast(tf.ones_like(decoder_ids_pair), dtype=tf.int32) # as computed in generate
# case 3 FAILING: with batch size = 1 and attention mask, XLA and non-XLA outputs should match but do not
outputs = model(input_ids=ids_single, decoder_input_ids=decoder_ids_single, attention_mask=attention_single, decoder_attention_mask=decoder_attention_single, output_hidden_states=True, output_attentions=True)
outputs_xla = model_xla(input_ids=ids_single, decoder_input_ids=decoder_ids_single, attention_mask=attention_single, decoder_attention_mask=decoder_attention_single, output_hidden_states=True, output_attentions=True) |
As @patrickvonplaten mentioned, it's pretty imperative to have XLA working on CPU for any kind of debugging: there are all sorts of debugging methods that pull values back to the host and perform checks on an op-by-op basis (see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#nans). These are pretty crucial for understanding the inner workings of a compiled function that you wouldn't otherwise see if running XLA purely on an accelerator. Also, when running JAX/Flax on CPU, the internal computations for matrix multiplications and convolutions always use the highest floating-point precision. When you move to a TPU, the floating-point precision is lowered by default. We need to be able to test our code on CPU to run at the highest precision, especially for any sort of PT-Flax equivalence tests (see #15754). I'm not familiar with how TF treats matmul precisions, but it's these sorts of considerations that mean running XLA on CPU is pretty essential! |
Great find @ydshieh! We should talk to the TF guys about this no? |
(Sorry, I accidentally edited @patrickvonplaten's comment above.) Yes. Let's extract (or create) some inputs, and reproduce the issue with only the softmax part. |
This is great @ydshieh 🔥 I'm going to build a toy example and open an issue in TF, linking to this thread. |
Pinned down the problem: it is due to the softmax with numerically masked (= large negative) inputs, on XLA+CPU. I've opened an issue on TensorFlow (as backlinked above), which contains a simple reproducible example. Meanwhile, avoid XLA+CPU :D |
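For readers unfamiliar with numerical masking, here is a minimal, framework-free sketch of what a masked softmax is expected to return. It uses only the Python standard library; `softmax`, `masked_softmax`, and the sample values are illustrative names and data, not code from transformers or TensorFlow. It also checks one arithmetic observation: the value 0.04347826 reported above for CPU + XLA equals 1/23 at the printed precision, which is what a uniform distribution over 23 positions would give. That is an observation, not a confirmed diagnosis.

```python
import math

LARGE_PENALTY = -1e9  # same order of magnitude as the mask penalty used in this thread


def softmax(xs):
    """Reference softmax that subtracts the max first (the usual stability trick)."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def masked_softmax(xs, keep):
    """Add a large negative penalty where keep == 0, then apply softmax."""
    masked = [x + (1 - k) * LARGE_PENALTY for x, k in zip(xs, keep)]
    return softmax(masked)


logits = [0.5, -1.0, 2.0, 0.0]  # illustrative values only
keep = [1, 1, 1, 0]             # last position is masked out
probs = masked_softmax(logits, keep)
print(probs)  # the masked position should get (numerically) zero probability

# Arithmetic check of the uniform value seen in the thread.
print(round(1 / 23, 8))  # 0.04347826
```

A correct implementation keeps the relative ordering of the unmasked logits and sends the masked position to zero; a flat output indicates that the information in the logits was lost somewhere in the computation.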
If this would take a long time for the TF team to fix, we might use a wrapped version of |
I'm sure somewhere hidden there is a tf softmax that is stable on XLA. We could then create a custom |
It should work! The toy example below adds said wrapper (with
import tensorflow as tf
LARGE_PENALTY = -1e9
def stable_softmax(x):
    return tf.nn.softmax(x + 1)
def masked_softmax(x, boolean_mask):
    numerical_mask = (1. - tf.cast(boolean_mask, dtype=tf.float32)) * LARGE_PENALTY
    masked_x = x + numerical_mask
    return stable_softmax(masked_x)
xla_masked_softmax = tf.function(masked_softmax, jit_compile=True)
xla_stable_softmax = tf.function(stable_softmax, jit_compile=True)
x = tf.random.normal((1, 10))
# same outcome regardless of the boolean mask here
boolean_mask = tf.convert_to_tensor([[1] * 9 + [0] * 1], dtype=tf.int32)
# passes
numerical_mask = (1. - tf.cast(boolean_mask, dtype=tf.float32)) * LARGE_PENALTY
masked_x = x + numerical_mask
xla_out = xla_stable_softmax(masked_x)
out = stable_softmax(masked_x)
print(tf.math.reduce_max(tf.math.abs(xla_out - out)).numpy())
assert tf.experimental.numpy.allclose(xla_out, out)
# The stable softmax has the same output as the original fn
unstable_out = tf.nn.softmax(masked_x)
print(tf.math.reduce_max(tf.math.abs(unstable_out - out)).numpy())
assert tf.experimental.numpy.allclose(unstable_out, out)
# passes (with the + 1 in the softmax)
xla_out = xla_masked_softmax(x, boolean_mask)
out = masked_softmax(x, boolean_mask)
print(tf.math.reduce_max(tf.math.abs(xla_out - out)).numpy())
assert tf.experimental.numpy.allclose(xla_out, out)
Opening a PR soon with this temporary fix, and will replace ALL softmax calls with this wrapped version. |
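The wrapper can add a constant safely because softmax is invariant to shifting all inputs by the same amount: softmax(x + c) = softmax(x), since the factor exp(c) cancels between numerator and denominator. The sketch below illustrates this in plain Python. `shifted_softmax` is a hypothetical stand-in for the `x + 1` wrapper, not the TensorFlow function, and the logits are made-up values.

```python
import math


def softmax(xs):
    """Reference softmax with max subtraction."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def shifted_softmax(xs, shift=1.0):
    """Hypothetical stand-in for the `x + 1` wrapper: shift every input, then softmax."""
    return softmax([x + shift for x in xs])


logits = [0.3, -1.2, 2.5, -1e9]  # last entry mimics a numerically masked position
plain = softmax(logits)
shifted = shifted_softmax(logits)

# The two results should agree up to floating-point rounding.
max_diff = max(abs(a - b) for a, b in zip(plain, shifted))
print(max_diff)
```

Because the shift is mathematically a no-op, any effect it has on the XLA + CPU result has to come from how the computation is compiled or executed, not from the function being computed.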
|
The problem with |
@ydshieh I agree that it should be more stable numerically, but I'd rather add a fixed constant. Perhaps not |
OK, good point @gante. And my suggestion didn't work well even with your code above! So I'm fine with using a constant. |
From further experimentation, I think the reason the small constant works has nothing to do with numerical stability - I think inserting an addition just changes the particular compiled program that XLA generates, and so avoids this issue. |
Depending on the passed inputs, the output of an XLA-compiled model may significantly differ from its non-XLA counterpart. This suggests we should add tests for XLA-output equivalence, just like we do with e.g. PT-TF, as it is not guaranteed.
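As a rough illustration of what such an equivalence test could look like, here is a framework-free sketch that walks two nested output structures and compares them within a tolerance. `max_abs_diff` and `assert_equivalent` are hypothetical helper names, plain Python floats stand in for tensors, and the default tolerance of 1e-5 mirrors the `tol=1e-5` default in the check function posted in this thread.

```python
def max_abs_diff(a, b):
    """Largest absolute element-wise difference between two nested structures."""
    if isinstance(a, (list, tuple)):
        if len(a) != len(b):
            raise ValueError(f"length mismatch: {len(a)} vs {len(b)}")
        return max((max_abs_diff(x, y) for x, y in zip(a, b)), default=0.0)
    return abs(a - b)


def assert_equivalent(reference, candidate, tol=1e-5):
    """Fail if any element differs by more than `tol`; return the max difference."""
    diff = max_abs_diff(reference, candidate)
    assert diff <= tol, f"max difference {diff} exceeds tolerance {tol}"
    return diff


# Made-up stand-ins for non-XLA (reference) and XLA (candidate) outputs.
reference = ([0.10, 0.20], [[1.0, 2.0], [3.0, 4.0]])
candidate = ([0.10, 0.20 + 1e-7], [[1.0, 2.0], [3.0, 4.0 - 2e-7]])
print(assert_equivalent(reference, candidate))
```

In a real test, the leaves would be tensors converted to arrays, and the tolerance would likely need to depend on the device and precision in use.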
At the moment, this blocks further developments in generate() (can't reliably reproduce non-XLA results with XLA). I will assess this problem for T5 (first model where I've noticed this), then check whether it is present for other key models, and finally add equivalence tests.
cc @patrickvonplaten @Rocketknight1 (feel free to pitch in with ideas and suggestions)
Example for reproducibility (updated: assert diff < x -> print diff):