
Shift Computation of PaddingFree Variable CuSeqLen from Flash Attention Forward to DataCollatorWithFlattening #65

Conversation

achew010
Contributor

@achew010 achew010 commented Aug 8, 2024

Description

To avoid repeatedly computing the same values for the same input across all the flash attention forward layers, this PR shifts the computation of the cu_seq_len and max_length arguments to the data collator.
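For context, a minimal sketch of the idea is below; this is not the plugin's actual code, and the collator name and the exact kwargs the model consumes are assumptions. It flattens a batch into one packed sequence and emits the cumulative sequence lengths and the longest example length once per batch, rather than having every flash attention forward re-derive them from position_ids.

```python
import torch


class FlatteningCollatorWithSeqLens:
    """Hypothetical sketch: flatten examples into a single packed sequence
    (padding-free) and precompute cu_seq_lens / max_length once per batch."""

    def __call__(self, features):
        input_ids, labels, position_ids, seq_lens = [], [], [], []
        for f in features:
            ids = list(f["input_ids"])
            seq_lens.append(len(ids))
            input_ids.extend(ids)
            # mask the first token of each example so the loss does not
            # cross document boundaries inside the packed sequence
            labels.extend([-100] + ids[1:])
            position_ids.extend(range(len(ids)))

        # cumulative boundaries [0, l0, l0+l1, ...] in the layout that
        # variable-length flash attention kernels expect
        offsets = [0]
        for n in seq_lens:
            offsets.append(offsets[-1] + n)

        return {
            "input_ids": torch.tensor([input_ids], dtype=torch.long),
            "labels": torch.tensor([labels], dtype=torch.long),
            "position_ids": torch.tensor([position_ids], dtype=torch.long),
            "cu_seq_lens": torch.tensor(offsets, dtype=torch.int32),
            "max_length": max(seq_lens),
        }
```

In this sketch, the patched flash attention forward would consume cu_seq_lens and max_length directly instead of recomputing them from position_ids on every layer.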

Running the same ORCA Math benchmark from Transformers Padding Free, this implementation change yields approximately a 2% improvement in train runtime.

| Implementation | Train Runtime (secs) | Throughput (Toks/sec) |
| --- | --- | --- |
| Transformers Native Padding Free | 788 | 1253 |
| Plugin PaddingFree (Transformers < 4.43) - Computation of Arguments at FA Forward | 567 | 1741 |
| Plugin PaddingFree (Transformers < 4.43) - Computation of Arguments at Data Collator | 555 | 1779 |

Reproduce

NOTE: The top-level model's forward method had to be patched with a wrapper function to accept the additional kwargs from the data collator. Because HFTrainer removes dataset columns that it does not detect in the model's forward method signature, it would otherwise drop these extra kwargs; to avoid the removal, the additional argument remove_unused_columns=False needs to be passed to the Trainer.
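The patch described in this note might look roughly like the sketch below; the wrapper name and the way the precomputed values are handed to the attention layers are assumptions, not the plugin's actual implementation.

```python
from transformers import TrainingArguments


def wrap_forward_for_padding_free(model):
    """Hypothetical wrapper: accept the extra collator kwargs at the
    top-level forward and stash them where a patched flash attention
    forward can read them instead of recomputing them per layer."""
    original_forward = model.forward

    def forward(*args, cu_seq_lens=None, max_length=None, **kwargs):
        # stash the precomputed values for the attention patch to pick up
        model._padding_free_args = (cu_seq_lens, max_length)
        return original_forward(*args, **kwargs)

    model.forward = forward
    return model


# Because the extra columns are not in the original forward signature,
# the Trainer must be told not to drop them.
training_args = TrainingArguments(
    output_dir="benchmark_outputs/ilab",
    remove_unused_columns=False,
)
```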

accelerate launch --config_file scripts/benchmarks/accelerate.yaml --num_processes=8 --main_process_port=29500 -m tuning.sft_trainer --padding_free huggingface --model_name_or_path mistralai/Mistral-7B-v0.1 --packing False --max_seq_len 4096 --training_data_path /workspace/data/datasets/orca-math-rd-bench.json --learning_rate 2e-5 --torch_dtype float16 --gradient_accumulation_steps 2 --use_flash_attn True --include_tokens_per_second True --num_train_epochs 1 --gradient_checkpointing True --evaluation_strategy no --save_strategy no --weight_decay 0.01 --warmup_steps 10 --adam_epsilon 1e-4 --lr_scheduler_type linear --logging_strategy steps --logging_steps 10 --per_device_train_batch_size 4 --output_dir benchmark_outputs/ilab --skip_memory_metrics False --remove_unused_columns False

TODO:

  • Consider making this implementation the default for all Transformer versions in the AADP plugin

@achew010 achew010 changed the title Shift Computation Cumulative Seq Lens from Flash Attention Forward to DataCollatorWithFlattening Shift Computation of PaddingFree Variable CuSeqLen from Flash Attention Forward to DataCollatorWithFlattening Aug 8, 2024
Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
@achew010 achew010 force-pushed the extract-cumsumlens-to-collator branch from 4e42420 to e6f4856 on August 8, 2024 07:19
@fabianlim
Contributor

closing this PR for now because there are issues identified with this impl.

@fabianlim fabianlim closed this Aug 19, 2024