Tests: upgrade test_eager_matches_sdpa_generate
#34386
@@ -15,6 +15,7 @@
import copy
import gc
import inspect
import tempfile
import unittest
@@ -33,6 +34,7 @@ | |
require_torch_gpu, | ||
require_torch_multi_accelerator, | ||
require_torch_multi_gpu, | ||
require_torch_sdpa, | ||
slow, | ||
torch_device, | ||
) | ||
|
@@ -2046,6 +2048,86 @@ def test_inherits_generation_mixin(self):
        for model_class in self.all_generative_model_classes:
            self.assertTrue("GenerationMixin" in str(model_class.__bases__))
    @require_torch_sdpa
    @slow
    def test_eager_matches_sdpa_generate(self):
        max_new_tokens = 30

        for model_class in self.all_generative_model_classes:
            if not model_class._supports_sdpa:
                self.skipTest(f"{model_class.__name__} does not support SDPA")

            config, original_inputs_dict = self.prepare_config_and_inputs_for_generate()
            inputs_dict = {}
            for input_name, input_data in original_inputs_dict.items():
                if isinstance(input_data, torch.Tensor) and input_data.dtype in [torch.float32, torch.bfloat16]:
                    inputs_dict[input_name] = input_data.to(torch.float16)
                else:
                    inputs_dict[input_name] = input_data
            main_input = inputs_dict[model_class.main_input_name]
Comment on lines +2060 to +2067: Uses …
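A minimal self-contained sketch of this casting step, with toy tensors (the key names here are illustrative, not taken from the PR): float32/bfloat16 tensors are cast to float16, while integer tensors and non-tensors pass through unchanged.

import torch

original_inputs_dict = {
    "input_ids": torch.tensor([[1, 2, 3]]),  # integer ids: left untouched
    "pixel_values": torch.rand(1, 3, 4, 4),  # float32: cast to float16
    "use_cache": True,                       # non-tensor: left untouched
}
inputs_dict = {}
for input_name, input_data in original_inputs_dict.items():
    if isinstance(input_data, torch.Tensor) and input_data.dtype in [torch.float32, torch.bfloat16]:
        inputs_dict[input_name] = input_data.to(torch.float16)
    else:
        inputs_dict[input_name] = input_data

assert inputs_dict["pixel_values"].dtype == torch.float16
assert inputs_dict["input_ids"].dtype == torch.int64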
            # make sure that all models have enough positions for generation
            if hasattr(config, "max_position_embeddings"):
                config.max_position_embeddings = max_new_tokens + main_input.shape[1] + 1

            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                del model
                gc.collect()

                generate_kwargs = {
                    "max_new_tokens": max_new_tokens,
                    "do_sample": False,
                    "return_dict_in_generate": True,
                    "output_scores": True,
                }

                model_sdpa = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    low_cpu_mem_usage=True,
                ).to(torch_device)
                res_sdpa = model_sdpa.generate(**inputs_dict, **generate_kwargs)
                del model_sdpa
                gc.collect()

                model_eager = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    low_cpu_mem_usage=True,
                    attn_implementation="eager",
                ).to(torch_device)
                res_eager = model_eager.generate(**inputs_dict, **generate_kwargs)
                del model_eager
                gc.collect()

                # Eager and SDPA are very similar, but not exactly the same. Because we are using random models, this
                # test would be flaky if we only checked the sequences. Two situations in which this test passes:
                # 1. The sequences are the same
                # 2. The sequences are different, but the scores up until the first mismatch are nearly identical
                output_matches = res_eager.sequences == res_sdpa.sequences
                has_matching_outputs = output_matches.all()
                has_matching_scores = None
                if not has_matching_outputs:
                    input_length = main_input.shape[1]
                    for batch_idx in range(res_eager.sequences.shape[0]):
                        batch_matches = output_matches[batch_idx]
                        if batch_matches.all():
                            continue
                        first_mismatch_idx = batch_matches.int().argmin()  # gets the index of the first False
                        first_mismatch_idx -= input_length  # scores doesn't include data regarding input tokens
                        sdpa_first_mismatch_scores = res_sdpa.scores[first_mismatch_idx][batch_idx]
                        eager_first_mismatch_scores = res_eager.scores[first_mismatch_idx][batch_idx]
                        has_matching_scores = torch.allclose(
                            sdpa_first_mismatch_scores, eager_first_mismatch_scores, rtol=1e-3, atol=1e-3
                        )
                        if not has_matching_scores:
                            break
Comment on lines +2106 to +2127: flakiness handling as explained in the PR header
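As a toy illustration of that fallback (hypothetical tensors, not part of the PR): when the two generated sequences disagree at one position, the index of the first mismatch is located and shifted past the prompt, so it can index into the per-step scores.

import torch

input_length = 2  # prompt tokens, excluded from the scores tuple
eager_sequences = torch.tensor([[5, 7, 11, 13, 17]])
sdpa_sequences = torch.tensor([[5, 7, 11, 19, 17]])

output_matches = eager_sequences == sdpa_sequences
batch_matches = output_matches[0]
first_mismatch_idx = batch_matches.int().argmin()  # argmin on 0/1 values -> first False
first_mismatch_idx -= input_length                 # index into the scores tuple
print(first_mismatch_idx.item())                   # 1, i.e. the second generated token

Note that argmin on an all-True row would return 0, which is why the loop in the test first skips rows where batch_matches.all() holds.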
                self.assertTrue(has_matching_outputs or has_matching_scores)
    def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1):
        # we can be sure what the batch size is from the main input, but seq length depends on model type and whether input is text/audio/image
        # so we infer the actual text seq length from model_tester, same way as it is done in `test_modeling_common.py` tests
Comment: (this is mostly copy-paste, going to comment the sections that are changed)
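For context, a hedged sketch of the eager-vs-SDPA comparison pattern this test automates, on a real but illustrative checkpoint (assumes a transformers version in which gpt2 supports SDPA; any SDPA-capable causal LM works):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="pt")

# Same weights, two attention implementations
model_sdpa = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation="sdpa")
model_eager = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation="eager")

out_sdpa = model_sdpa.generate(**inputs, max_new_tokens=10, do_sample=False)
out_eager = model_eager.generate(**inputs, max_new_tokens=10, do_sample=False)
print((out_sdpa == out_eager).all())  # usually True, but not guaranteed — hence the score fallback

The test's two-tier check exists precisely because this equality is not guaranteed: SDPA and eager attention use different kernels, so greedy decoding can diverge at a near-tie even when the score distributions are numerically close.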