Fix: FA2 with packed training (#32487)
* fix check

* add tests

* [run-slow] llama, gemma2

* oops, whisper actually runs but needed some special treatment
zucchini-nlp authored and ArthurZucker committed Aug 16, 2024
1 parent 734dca2 commit 5674d8b
Showing 4 changed files with 120 additions and 6 deletions.
9 changes: 4 additions & 5 deletions src/transformers/modeling_flash_attention_utils.py
@@ -264,11 +264,10 @@ def _flash_attention_forward(
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)

# if position_ids is provided and check not all examples (row) contain only 1 sequence, and is in pre-fill/training stage
# then use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
elif (
position_ids is not None and not (position_ids[:, -1] == position_ids.size(1) - 1).all() and query_length != 1
):
# If position_ids is provided and not every row holds a single sequence: a monotonically increasing row
# most likely means one sequence, otherwise the row is packed. Additionally check that we are in the
# pre-fill/training stage. Use `flash_attn_varlen_func` to prevent cross-example attention and allow the padding-free approach.
elif position_ids is not None and not (torch.diff(position_ids, dim=-1) >= 0).all() and query_length != 1:
batch_size = query_states.size(0)
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
query_states, key_states, value_states, position_ids
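As a minimal sketch of what the new condition checks (the tensors below are illustrative, not taken from the repository): in packed, padding-free training several samples share one row and their position ids restart from 0 at every sample boundary, so the per-row differences are not all non-negative.

```python
import torch

# One row per batch element. In packed (padding-free) training, several
# samples share a row and position_ids restart from 0 at each boundary.
packed = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])   # three samples packed into one row
single = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8]])   # one ordinary sample

def looks_packed(position_ids: torch.Tensor) -> bool:
    # Mirrors the condition in the diff: a non-decreasing row means a single
    # sample; any negative step means the positions restarted, i.e. packing.
    return not (torch.diff(position_ids, dim=-1) >= 0).all()

print(looks_packed(packed))  # True  -> take the flash_attn_varlen_func path
print(looks_packed(single))  # False -> regular flash_attn_func path
```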
2 changes: 1 addition & 1 deletion src/transformers/models/phi3/modeling_phi3.py
@@ -540,7 +540,7 @@ def forward(
max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
)

cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len, position_ids=position_ids)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

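The Phi3 change above forwards `position_ids` into the rotary embedding call. A simplified sketch of why explicit positions matter for packed rows (this is not the Phi3 implementation, just an illustration of the standard RoPE angle computation):

```python
import torch

def rotary_cos_sin(position_ids: torch.Tensor, dim: int = 8, base: float = 10000.0):
    # Standard RoPE angles computed from explicit positions: each token's
    # rotation depends on its *actual* position, so packed rows (whose
    # positions restart at 0 mid-row) cannot rely on a plain arange.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = position_ids[..., None].float() * inv_freq   # (batch, seq_len, dim // 2)
    emb = torch.cat((angles, angles), dim=-1)              # (batch, seq_len, dim)
    return emb.cos(), emb.sin()

# Two packed samples in one row: positions restart at index 4, so the fifth
# token is rotated as position 0, not position 4.
cos, sin = rotary_cos_sin(torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3]]))
```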
59 changes: 59 additions & 0 deletions tests/models/whisper/test_modeling_whisper.py
@@ -1844,6 +1844,59 @@ def test_generate_output_type(self, return_dict_in_generate):
)
assert isinstance(pred_ids, expected_output_type)

@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_reuse_cache(self):
max_new_tokens = 2
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

dummy_input = inputs_dict[model_class.main_input_name][..., :10]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)

# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = dummy_input.shape[1] * 2 + max_new_tokens * 2 + 1

model = model_class(config)

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)

model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)

# run generate once to get filled cache
output = model.generate(
dummy_input,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
return_dict_in_generate=True,
)
past_key_values = output.past_key_values

# Try to continue generation from where we left off, given that we have more than 1 new token to process
# e.g. this can happen in speculative decoding when feeding candidate tokens back to target model
_ = model.generate(
dummy_input,
decoder_input_ids=output.sequences,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
past_key_values=past_key_values,
)


@require_torch
@require_torchaudio
@@ -4071,6 +4124,12 @@ def test_retain_grad_hidden_states_attentions(self):
def test_save_load_fast_init_from_base(self):
pass

@unittest.skip(
reason="FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test"
)
def test_flash_attn_2_generate_reuse_cache(self):
pass

@unittest.skip(
"Duplicated test with WhisperModelTest + the FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test"
)
56 changes: 56 additions & 0 deletions tests/test_modeling_common.py
@@ -4331,6 +4331,62 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)

@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_reuse_cache(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")

max_new_tokens = 2
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)

# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = dummy_input.shape[1] * 2 + max_new_tokens * 2 + 1

model = model_class(config)

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)

model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)

# run generate once to get filled cache
output = model.generate(
dummy_input,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
return_dict_in_generate=True,
)
past_key_values = output.past_key_values

# Try to continue generation from where we left off, given that we have more than 1 new token to process
# e.g. this can happen in speculative decoding when feeding candidate tokens back to target model
dummy_input_updated = torch.cat([dummy_input, output.sequences], dim=-1)
_ = model.generate(
dummy_input_updated,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
past_key_values=past_key_values,
)

@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
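To tie the pieces together, a hedged sketch of the padding-free input format this fix targets: two samples concatenated into a single row with `position_ids` restarting at the boundary, run under `attn_implementation="flash_attention_2"`. The checkpoint name is a placeholder; any FA2-capable causal LM works, and a CUDA device with flash-attn installed is assumed.

```python
import torch
from transformers import AutoModelForCausalLM

# Placeholder checkpoint: substitute any causal LM that supports Flash Attention 2.
model = AutoModelForCausalLM.from_pretrained(
    "your-fa2-capable-causal-lm",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to("cuda")

# Two samples of lengths 4 and 3 packed into one row, no padding tokens.
input_ids = torch.tensor([[11, 12, 13, 14, 21, 22, 23]], device="cuda")
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2]], device="cuda")

# Because position_ids restart inside the row (and no attention mask is given),
# the fixed check routes attention through flash_attn_varlen_func, so tokens
# never attend across the sample boundary even though they share one row.
out = model(input_ids=input_ids, position_ids=position_ids)
```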
