Fix T5 and BART for TF #9063
Conversation
Way cleaner! Thanks a lot for doing this!
tests/test_modeling_tf_common.py
key = np.array(key, dtype=bool)
pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long)
Any reason to change a bool to a tensor? We can just keep it as a boolean, no?
This is because the use_cache parameter returned by the _prepare_for_class method in the T5 test file is now a primitive boolean and no longer a tensor. Thanks to these changes the value is not forced to be a tensor anymore :)
Yes, I understand that. PyTorch accepts regular booleans in its models, which is why I'm confused about converting it to a PyTorch tensor here.
This is because of the .numpy() call used to convert TF tensors to NumPy arrays. Plain booleans don't have this attribute.
My question is: why not just pass pt_inputs_dict[name] = key?
Ah ok!! Simply because I didn't know 😄 Just pushed the update^^
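For context, a minimal sketch of the two options discussed in this thread, assuming a boolean use_cache flag (names are illustrative, not the exact test code):

import numpy as np
import torch

key = True  # e.g. the use_cache flag returned by _prepare_for_class

# Old approach: force every value into a tensor; this is where the
# bool -> tensor casts and .numpy() assumptions crept in.
as_tensor = torch.from_numpy(np.array(key, dtype=bool)).to(torch.long)  # tensor(1)

# New approach: PyTorch models accept plain Python booleans for such flags,
# so the value can be passed through unchanged.
pt_inputs_dict = {}
pt_inputs_dict["use_cache"] = key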
Overall way cleaner. Happy to see the cast_bool_to_primitive go!
Thanks for working on this.
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
use_cache = use_cache if use_cache is not None else self.use_cache
return_dict = return_dict if return_dict is not None else self.config.return_dict
if use_cache:
    assert not training, "Training + use cache are incompatible"
# check attention mask and invert
use_cache = cast_bool_to_primitive(use_cache)
Does removing this mean the TFBartDecoder will not be able to handle the output attentions/hidden states/cache/return dict parameters on its own, but will instead rely on being called by the TFBart{Model, ForConditionalGeneration} classes?
If @patrickvonplaten is doing the same refactor for TF BART as he did for the PT version, then one of the enhancements it offers is being able to use the (TF)BartDecoder as a standalone model. Will this change prevent that?
If the (TF)BartDecoder aims to be used as a standalone model, then input_processing must be added. I can add it just in case; it won't harm the current behavior, it's just some extra processing. Should I do it?
I think @patrickvonplaten will probably take care of that in his PR, so no problem here. I would still wait for his review before merging!
I haven't looked at having either TFBart or TFT5 as an "encoder-only" or "decoder-only" model yet because a) there is no TFEncoderDecoder model and b) because of the issues that will be solved in this PR. So I'm 100% fine with deleting it here for now.
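For reference, a rough sketch of what wiring input_processing into a standalone decoder could look like, based only on the signature visible further down in this diff (def input_processing(func, config, input_ids, **kwargs)). The stand-in helper and ToyDecoder below are hypothetical and much simpler than the real code:

import tensorflow as tf

def input_processing(func, config, input_ids, **kwargs):
    # Hypothetical stand-in: the real helper validates types, unpacks dicts, etc.
    output = {"input_ids": input_ids}
    output.update(kwargs)
    return output

class ToyDecoder(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config

    def call(self, input_ids, attention_mask=None, use_cache=None, training=False):
        # Standalone use would need the same argument normalization that the
        # top-level TFBart{Model, ForConditionalGeneration} classes perform.
        inputs = input_processing(
            func=self.call,
            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=use_cache,
            training=training,
        )
        use_cache = inputs["use_cache"] if inputs["use_cache"] is not None else self.config.use_cache
        return inputs["input_ids"], use_cache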
decoder_hidden_states=decoder_outputs[2],
decoder_attentions=decoder_outputs[3],
encoder_last_hidden_state=inputs["encoder_outputs"][0],
encoder_hidden_states=inputs["encoder_outputs"][1],
encoder_attentions=inputs["encoder_outputs"][2],
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions,
Infinitely cleaner
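The attribute access above works because the submodules now return output objects instead of positional tuples, so fields are read by name rather than by index. A small self-contained illustration of the idea (a toy dataclass, not the library's actual ModelOutput class):

from dataclasses import dataclass
from typing import Optional, Tuple

import tensorflow as tf

@dataclass
class ToySeq2SeqOutput:
    last_hidden_state: tf.Tensor
    hidden_states: Optional[Tuple[tf.Tensor, ...]] = None
    attentions: Optional[Tuple[tf.Tensor, ...]] = None

out = ToySeq2SeqOutput(last_hidden_state=tf.zeros((1, 4, 8)))
print(out.last_hidden_state.shape)  # read by name, no fragile outputs[0] / outputs[2] indexing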
@@ -354,7 +354,8 @@ def input_processing(func, config, input_ids, **kwargs):
    if isinstance(v, allowed_types) or v is None:
        output[k] = v
    else:
        raise ValueError(f"Data of type {type(v)} is not allowed only tf.Tensor is accepted for {k}.")
        print(k, v)
I think we can delete the print(k, v), no?
@@ -1366,31 +1367,6 @@ def get_initializer(initializer_range: float = 0.02) -> tf.initializers.TruncatedNormal:
    return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)

def cast_bool_to_primitive(bool_variable: Union[tf.Tensor, bool], default_tensor_to_true=False) -> bool:
yeeees :-)
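For readers without the removed code at hand, a rough sketch of what a helper like cast_bool_to_primitive had to do, and why it is no longer needed once flags stay as plain Python bools (sketch only; the removed function handled more cases):

import tensorflow as tf

def cast_bool_to_primitive_sketch(bool_variable):
    # If the flag arrived as a tf.Tensor, pull its Python value out.
    # The .numpy() call is exactly what breaks under graph compilation,
    # which is why keeping the flags as primitive bools removes the problem.
    if isinstance(bool_variable, tf.Tensor):
        return bool(bool_variable.numpy())
    return bool(bool_variable)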
@@ -1046,7 +1029,7 @@ def __init__(self, config, *inputs, **kwargs):
    self.use_cache = config.use_cache
    # final_logits_bias is registered as a buffer in pytorch, so not trainable for the sake of consistency.
    self.final_logits_bias = self.add_weight(
        name="/final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
great catch!
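For context, a self-contained sketch of registering such a non-trainable bias with Keras add_weight, mirroring a PyTorch buffer (toy layer and vocab size, not the library code):

import tensorflow as tf

class ToyLMHead(tf.keras.layers.Layer):
    def __init__(self, vocab_size=128, **kwargs):
        super().__init__(**kwargs)
        # Kept non-trainable to mirror the PyTorch buffer.
        self.final_logits_bias = self.add_weight(
            name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False
        )

    def call(self, logits):
        return logits + self.final_logits_bias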
tests/test_modeling_tf_common.py
@@ -574,7 +591,10 @@ def check_hidden_states_output(config, inputs_dict, model_class):
    self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)

hidden_states = outputs[-1]
if hasattr(outputs, "hidden_states"):
Can we have an if model.config.is_encoder_decoder case here instead? Seems more in line with the PT tests, and we should test both decoder_hidden_states and encoder_hidden_states for encoder-decoder models.
Better?
I should have addressed everybody's comments :)
if self.is_encoder_decoder:
    hidden_states = outputs.decoder_hidden_states
else:
    hidden_states = outputs.hidden_states
Sorry, I meant something a bit more like this:
if model.config.is_encoder_decoder:
    encoder_hidden_states = outputs.encoder_hidden_states
    decoder_hidden_states = outputs.decoder_hidden_states

    self.assertEqual(config.output_attentions, False)
    self.assertEqual(len(encoder_hidden_states), expected_num_layers)
    self.assertListEqual(
        list(encoder_hidden_states[0].shape[-2:]),
        [self.model_tester.seq_length, self.model_tester.hidden_size],
    )
    self.assertEqual(len(decoder_hidden_states), expected_num_layers)
    self.assertListEqual(
        list(decoder_hidden_states[0].shape[-2:]),
        [self.model_tester.seq_length, self.model_tester.hidden_size],
    )
else:
    hidden_states = outputs.hidden_states
    self.assertEqual(config.output_attentions, False)
    self.assertEqual(len(hidden_states), expected_num_layers)
    self.assertListEqual(
        list(hidden_states[0].shape[-2:]),
        [self.model_tester.seq_length, self.model_tester.hidden_size],
    )
1) It's always good to check both the encoder and decoder hidden states for encoder_decoder models, and 2) I prefer to stop using the self.is_encoder_decoder flag (the config better defines whether a model is an encoder-decoder - not really the test case IMO).
Ok, just pushed the fix!!!
Just have one comment left for the test. After this I think we can merge and I'll rebase my TFBart refactor PR on the new changes here :-)
What does this PR do?
This PR fixes the TensorFlow implementations of T5 and BART to make them compliant with graph compilation and execution, so that a SavedModel can be created for each of them.
The slow tests test_saved_model_with_hidden_states_output and test_saved_model_with_attentions_output are now passing for both models.
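As a quick illustration of what this enables, a sketch of exporting one of the fixed models as a SavedModel (the checkpoint name and output path are just examples):

import tensorflow as tf
from transformers import TFT5ForConditionalGeneration

model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
# Graph compilation/execution now works, so the model can be exported as a
# TensorFlow SavedModel and reloaded independently of the Python class.
tf.saved_model.save(model, "t5_saved_model")
reloaded = tf.saved_model.load("t5_saved_model")

The slow tests mentioned above can then be run with something like RUN_SLOW=1 pytest tests/test_modeling_tf_t5.py -k saved_model.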