[trainer] figuring out why eval with --fp16_full_eval is 25% slower #10816

Open
stas00 opened this issue Mar 20, 2021 · 11 comments
Labels: Good First Issue, Good Second Issue

Comments

stas00 (Contributor) commented Mar 20, 2021

Recently the HF Trainer was extended to support full fp16 eval via --fp16_full_eval. I'd have expected it to be either as fast as or faster than eval with the fp32 model, but surprisingly I noticed a 25% slowdown when using it.

This may or may not affect DeepSpeed as well, which also runs eval in fp16, but there we can't compare against a baseline, since it only runs in fp16.

I wonder if someone would like to research where the slowdown comes from.

I'd probably isolate the model.half() call, which should be a constant cost, and focus on the rest of the eval. I'm thinking that some component doesn't handle fp16 variables well. E.g. label smoothing was problematic and should now be fixed by #10815, but I tested w/ and w/o label smoothing and it's not contributing to the slowdown.

Here are the script and the corresponding metrics.

First w/o --fp16_full_eval,

export BS=16; rm -r output_dir; PYTHONPATH=src USE_TF=0 CUDA_VISIBLE_DEVICES=0 \
./examples/seq2seq/run_translation.py --model_name_or_path t5-small --output_dir /tmp/zero3 \
--overwrite_output_dir --max_train_samples 10 --max_val_samples 100 --max_source_length 12 \
--max_target_length 128 --val_max_target_length 128 --do_train --num_train_epochs 1 \
--per_device_train_batch_size 2 --learning_rate 3e-3 --warmup_steps 8 --predict_with_generate \
--logging_steps 0 --save_steps 2 --eval_steps 1 --group_by_length --adafactor --dataset_name wmt16 \
--dataset_config ro-en --source_lang en --target_lang ro \
--source_prefix "translate English to Romanian: " --do_eval 

***** train metrics *****
  epoch                      =    1.0
  init_mem_cpu_alloc_delta   =    2MB
  init_mem_cpu_peaked_delta  =    0MB
  init_mem_gpu_alloc_delta   =  230MB
  init_mem_gpu_peaked_delta  =    0MB
  train_mem_cpu_alloc_delta  =   60MB
  train_mem_cpu_peaked_delta =   63MB
  train_mem_gpu_alloc_delta  =  231MB
  train_mem_gpu_peaked_delta =  194MB
  train_runtime              = 7.7162
  train_samples              =     10
  train_samples_per_second   =  0.648
  
***** eval metrics *****
  epoch                     =    1.0
  eval_bleu                 = 2.4612
  eval_gen_len              =  18.53
  eval_loss                 =  5.017
  eval_mem_cpu_alloc_delta  =    0MB
  eval_mem_cpu_peaked_delta =    0MB
  eval_mem_gpu_alloc_delta  =    0MB
  eval_mem_gpu_peaked_delta =  244MB
  eval_runtime              = 4.6481
  eval_samples              =    100
  eval_samples_per_second   = 21.514

now let's add --fp16_full_eval:

export BS=16; rm -r output_dir; PYTHONPATH=src USE_TF=0 CUDA_VISIBLE_DEVICES=0 \
./examples/seq2seq/run_translation.py --model_name_or_path t5-small --output_dir /tmp/zero3 \
--overwrite_output_dir --max_train_samples 10 --max_val_samples 100 --max_source_length 12 \
--max_target_length 128 --val_max_target_length 128 --do_train --num_train_epochs 1 \
--per_device_train_batch_size 2 --learning_rate 3e-3 --warmup_steps 8 --predict_with_generate \
--logging_steps 0 --save_steps 2 --eval_steps 1 --group_by_length --adafactor --dataset_name wmt16 \
--dataset_config ro-en --source_lang en --target_lang ro \
--source_prefix "translate English to Romanian: " --do_eval  \
--fp16_full_eval

***** train metrics *****
  epoch                      =    1.0
  init_mem_cpu_alloc_delta   =    2MB
  init_mem_cpu_peaked_delta  =    0MB
  init_mem_gpu_alloc_delta   =  230MB
  init_mem_gpu_peaked_delta  =    0MB
  train_mem_cpu_alloc_delta  =   60MB
  train_mem_cpu_peaked_delta =   63MB
  train_mem_gpu_alloc_delta  =  231MB
  train_mem_gpu_peaked_delta =  194MB
  train_runtime              = 7.1477
  train_samples              =     10
  train_samples_per_second   =    0.7

***** eval metrics *****
  epoch                     =    1.0
  eval_bleu                 = 2.4612
  eval_gen_len              =  18.53
  eval_loss                 = 5.0168
  eval_mem_cpu_alloc_delta  =    0MB
  eval_mem_cpu_peaked_delta =    0MB
  eval_mem_gpu_alloc_delta  = -231MB
  eval_mem_gpu_peaked_delta =  262MB
  eval_runtime              = 6.0125
  eval_samples              =    100
  eval_samples_per_second   = 16.632

As you can see, w/o --fp16_full_eval we get ~22 samples per sec and w/ it only ~17 - that's a huge difference.

I also tested with a larger sample and the gap remains constant.

The halving happens here:

model = model.half().to(self.args.device)

Thank you!

stas00 added the Good First Issue and Good Second Issue labels on Mar 20, 2021
bhadreshpsavani (Contributor) commented:

Hi @stas00,
Please let me know if this is still open and I can contribute.

stas00 (Contributor, Author) commented Mar 27, 2021

Yes, please.

bhadreshpsavani (Contributor) commented Apr 2, 2021

I reproduced this in Colab and got a 28% slowdown, but I'm still figuring out the cause.
Earlier my assumption was that this bit-reduction/quantization behavior was a device-specific thing.

stas00 (Contributor, Author) commented Apr 2, 2021

Usually in such situations I try to go either from the bottom up or from the top down. Bottom up: take just the model(**inputs) call and measure its speed w/ model vs model.half() - if it's the same, go one level up into generate, etc. Top down: start from generate and keep removing big chunks of code until you find the part that contributes to the slowdown.
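For the bottom-up variant, a minimal sketch of what I mean (untested; t5-small with an arbitrary dummy batch, CUDA assumed):

import time

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
enc = tokenizer(
    ["translate English to Romanian: hello world"] * 16,
    return_tensors="pt",
    padding=True,
).to("cuda")

def time_forward(model, iters=100):
    model = model.to("cuda").eval()
    torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(iters):
            # plain forward pass; swap in generate() when moving one level up
            model(**enc, labels=enc["input_ids"])
    torch.cuda.synchronize()
    return time.perf_counter() - start

model = T5ForConditionalGeneration.from_pretrained("t5-small")
print("fp32:", time_forward(model))
print("fp16:", time_forward(model.half()))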

You can use this tracker to bracket the operation you measure:

class TrainerMemoryTracker:  # defined in src/transformers/trainer_utils.py

But a totally different approach, which might get to the core of the issue much faster, is to use a python profiler, e.g. cProfile - that way you get full analytics on each function call, and if you compare these side by side w/ and w/o half() you might get an instant answer. Actually, now that I wrote this, I'd say start with this approach.
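For example, assuming one profile was captured per mode with python -m cProfile -o fp32.prof ... and python -m cProfile -o fp16.prof ... (the filenames are arbitrary), something like this prints the hottest calls of each for a side-by-side comparison:

import pstats

# Print the 30 most time-consuming calls from each profile so the
# fp32 and fp16 runs can be compared function by function.
for name in ("fp32.prof", "fp16.prof"):
    print(f"===== {name} =====")
    pstats.Stats(name).sort_stats("cumulative").print_stats(30)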

pcuenca (Member) commented Apr 22, 2021

I have done a few measurements on 2 different cards (a 3090 and a 2080 Ti) using various evaluation batch sizes, and I haven't observed a single culprit for this problem. Instead, I'm seeing that all the operations in the forward pass are somewhat slower with fp16, and consistently so.

Setup

  • Evaluation batch size in {4, 8, 16, 32, 64, 128}
  • 128 evaluation samples. Since I'm using powers of 2 for the batch sizes, this allows us to test from 1 batch to many batches of the same size.
  • max_length = min_length = 128. Setting min_length to 128 increases processing time, since generation cannot stop early at an EOS token (see the sketch after this list).
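For reference, pinning the generation length looks roughly like this (a sketch with t5-small and a dummy batch, not my actual benchmark script):

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small").to("cuda").half().eval()

batch = tokenizer(
    ["translate English to Romanian: hello world"] * 16,
    return_tensors="pt",
    padding=True,
).to("cuda")

# min_length == max_length suppresses EOS until 128 tokens have been produced,
# so every sample decodes the full length and timings stay comparable.
with torch.no_grad():
    out = model.generate(**batch, max_length=128, min_length=128)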

These are the results for the main operations inside the forward method of T5Block (total seconds spent in the corresponding areas; figures from the 3090 and the first 3 batch sizes, for brevity):

[image: per-operation timings inside T5Block.forward, fp16 vs fp32, on the 3090 for batch sizes 4/8/16]

The time difference depends on the batch size, but fp16 is always between 15% (for bs=64) and 26% (bs=16) slower.


Today I discovered this thread in the PyTorch forums, and repeated the test using a version of PyTorch compiled from source. Amazingly, processing is now almost twice as fast, but the difference is still there:

[image: the same per-operation timings with PyTorch compiled from source]

In this case, using a batch size of 128 (1 batch) is about 13% slower, while a batch size of 16 is 27% slower.

I'm not sure how to proceed. Does this ring a bell for anyone?

stas00 (Contributor, Author) commented Apr 23, 2021

Thank you for researching and profiling, @pcuenca!

I think the next step is the new pytorch profiler:
https://pytorch.org/blog/introducing-pytorch-profiler-the-new-and-improved-performance-tool/
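Roughly along these lines (just a sketch I haven't run; t5-small with a dummy batch stands in for the real eval loop) - run it once as below and once without .half(), then compare the tables:

import torch
from torch.profiler import ProfilerActivity, profile
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small").to("cuda").half().eval()

enc = tokenizer(
    ["translate English to Romanian: hello world"] * 16,
    return_tensors="pt",
    padding=True,
).to("cuda")

with torch.no_grad(), profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
) as prof:
    for _ in range(20):
        model(**enc, labels=enc["input_ids"])

# Operators sorted by total CUDA time - this is where fp16 should show where it loses ground.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))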

Unfortunately, at the moment I have no time to dig into it, so I hope someone will beat me to it.


re: building from source:

Indeed, I recently built pytorch from source and I don't know if it's that or something else since 1 month passed since OP was made, but I'm getting 2x speed improvement (rtx-3090) on training this task. eval is only slightly faster, but is still 25% slower @ fp16.

Also adapted the cmd line to the recently changed examples:


export BS=16; rm -r output_dir; PYTHONPATH=src USE_TF=0 CUDA_VISIBLE_DEVICES=0 \
./examples/pytorch/translation/run_translation.py --model_name_or_path t5-small --output_dir /tmp/zero3 \
--overwrite_output_dir --max_train_samples 10 --max_eval_samples 100 --max_source_length 12 \
--max_target_length 128  --do_train --num_train_epochs 1 \
--per_device_train_batch_size 2 --learning_rate 3e-3 --warmup_steps 8 --predict_with_generate \
--logging_steps 0 --save_steps 2 --eval_steps 1 --group_by_length --adafactor --dataset_name wmt16 \
--dataset_config ro-en --source_lang en --target_lang ro \
--source_prefix "translate English to Romanian: " --do_eval 

***** train metrics *****
  epoch                      =        1.0
  init_mem_cpu_alloc_delta   =     1254MB
  init_mem_cpu_peaked_delta  =      155MB
  init_mem_gpu_alloc_delta   =      230MB
  init_mem_gpu_peaked_delta  =        0MB
  train_mem_cpu_alloc_delta  =     1382MB
  train_mem_cpu_peaked_delta =      125MB
  train_mem_gpu_alloc_delta  =      231MB
  train_mem_gpu_peaked_delta =      194MB
  train_runtime              = 0:00:04.19
  train_samples              =         10
  train_samples_per_second   =      1.191

***** eval metrics *****
  epoch                     =        1.0
  eval_bleu                 =     2.2434
  eval_gen_len              =      15.69
  eval_loss                 =     3.7374
  eval_mem_cpu_alloc_delta  =        1MB
  eval_mem_cpu_peaked_delta =        0MB
  eval_mem_gpu_alloc_delta  =        0MB
  eval_mem_gpu_peaked_delta =      171MB
  eval_runtime              = 0:00:04.33
  eval_samples              =        100
  eval_samples_per_second   =     23.051

add --fp16_full_eval

export BS=16; rm -r output_dir; PYTHONPATH=src USE_TF=0 CUDA_VISIBLE_DEVICES=0 \
./examples/pytorch/translation/run_translation.py --model_name_or_path t5-small --output_dir /tmp/zero3 \
--overwrite_output_dir --max_train_samples 10 --max_eval_samples 100 --max_source_length 12 \
--max_target_length 128  --do_train --num_train_epochs 1 \
--per_device_train_batch_size 2 --learning_rate 3e-3 --warmup_steps 8 --predict_with_generate \
--logging_steps 0 --save_steps 2 --eval_steps 1 --group_by_length --adafactor --dataset_name wmt16 \
--dataset_config ro-en --source_lang en --target_lang ro \
--source_prefix "translate English to Romanian: " --do_eval --fp16_full_eval

***** train metrics *****
  epoch                      =        1.0
  init_mem_cpu_alloc_delta   =     1259MB
  init_mem_cpu_peaked_delta  =      155MB
  init_mem_gpu_alloc_delta   =      230MB
  init_mem_gpu_peaked_delta  =        0MB
  train_mem_cpu_alloc_delta  =     1380MB
  train_mem_cpu_peaked_delta =      125MB
  train_mem_gpu_alloc_delta  =      231MB
  train_mem_gpu_peaked_delta =      194MB
  train_runtime              = 0:00:03.76
  train_samples              =         10
  train_samples_per_second   =      1.326

***** eval metrics *****
  epoch                     =        1.0
  eval_bleu                 =     2.2434
  eval_gen_len              =      15.69
  eval_loss                 =     3.7383
  eval_mem_cpu_alloc_delta  =        4MB
  eval_mem_cpu_peaked_delta =        0MB
  eval_mem_gpu_alloc_delta  =     -231MB
  eval_mem_gpu_peaked_delta =      262MB
  eval_runtime              = 0:00:05.32
  eval_samples              =        100
  eval_samples_per_second   =     18.778

dsuess (Contributor) commented Jun 27, 2021

By running everything with CUDA_LAUNCH_BLOCKING=1 under the line profiler, I found that this and this check for infinite values take up more time than I expected.
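The checks in question are clamp guards of roughly the following form (reproduced from memory, so treat it as an approximation), and a crude micro-benchmark of the .any() synchronization cost with arbitrary shapes looks like this:

import time

import torch

x = torch.randn(16, 512, 512, device="cuda", dtype=torch.float16)

def run(with_check, iters=1000):
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        y = x * 2.0  # stand-in for the real block computation
        # evaluating .any() in Python forces a GPU->CPU sync on every call
        if with_check and torch.isinf(y).any():
            clamp_value = torch.finfo(y.dtype).max - 1000
            y = torch.clamp(y, min=-clamp_value, max=clamp_value)
    torch.cuda.synchronize()
    return time.perf_counter() - start

print("without isinf check:", run(False))
print("with isinf check:   ", run(True))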

After removing those checks, this is what I end up with:

$ export BS=16; rm -r output_dir; PYTHONPATH=src USE_TF=0 CUDA_VISIBLE_DEVICES=0 \
python -m cProfile -o profile.prof  ./examples/pytorch/translation/run_translation.py --model_name_or_path t5-small --output_dir /tmp/zero3 \
--overwrite_output_dir --max_train_samples 10 --max_eval_samples 1600 --max_source_length 12 \
--max_target_length 128  --do_train --num_train_epochs 1 \
--per_device_train_batch_size 4 --per_device_eval_batch_size $BS --learning_rate 3e-3 --warmup_steps 8 --predict_with_generate \
--logging_steps 0 --save_steps 2 --eval_steps 1 --group_by_length --adafactor --dataset_name wmt16 \
--dataset_config ro-en --source_lang en --target_lang ro \
--source_prefix "translate English to Romanian: " --do_eval
...
***** eval metrics *****
  epoch                   =        1.0
  eval_bleu               =     0.3251
  eval_gen_len            =    10.2375
  eval_loss               =     3.6796
  eval_runtime            = 0:01:03.89
  eval_samples            =       1600
  eval_samples_per_second =      25.04
  eval_steps_per_second   =      1.565

The same with --fp16_full_eval:

***** eval metrics *****
  epoch                   =        1.0
  eval_bleu               =     0.3258
  eval_gen_len            =    10.2406
  eval_loss               =     3.6797
  eval_runtime            = 0:01:01.43
  eval_samples            =       1600
  eval_samples_per_second =     26.043
  eval_steps_per_second   =      1.628

Note that I had to dial up the number of eval examples since this measurement was quite noisy on the shared system I used. However, the FP16 run was faster most of the time. If someone could double-check these observations under more reliable circumstances, that would be great.

stas00 (Contributor, Author) commented Jun 29, 2021

Thank you for looking into it, @dsuess!

I'm trying to figure out torch.profiler to get a better understanding using native tools.

Great to hear you found those checks to be slowdowns. Need to investigate these closer with torch.profiler.

And I also found

attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)

to be another point of slowdown. It's possible that the upcast can be removed completely, which should speed things up. But a slightly faster version is definitely:

            # current:
            attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
            # slightly faster in fp16:
            attn_weights = nn.functional.softmax(scores.float(), dim=-1, dtype=scores.dtype)

for fp16 (it makes no difference for fp32)
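A rough micro-benchmark of just these two variants (arbitrary shapes, so only the relative numbers matter):

import time

import torch
from torch import nn

scores = torch.randn(16, 8, 512, 512, device="cuda", dtype=torch.float16)

def bench(fn, iters=500):
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return time.perf_counter() - start

# upcast to fp32, softmax, cast back via type_as (the current pattern)
print("float() + type_as: ", bench(lambda: nn.functional.softmax(scores.float(), dim=-1).type_as(scores)))
# same upcast, but let softmax cast to the target dtype itself
print("dtype=scores.dtype:", bench(lambda: nn.functional.softmax(scores.float(), dim=-1, dtype=scores.dtype)))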

I will look closer into the 2 points you suggested.

But we should also run under a more realistic configuration of at least seqlen 512, and not 12 as I had it originally - with a large seqlen things change quite a lot. That is, --max_source_length 512 --max_target_length 512 (or even better, 1024).

dsuess (Contributor) commented Jan 19, 2022

Thanks for your feedback @stas00. I finally got the time to have a closer look with the pytorch profiler. I'd summarize what I found as follows:

  • the speedup we're getting for matmuls in fp16 isn't that great. This might be due to fewer kernels being executed on Tensor Cores when using FP16 (31% of kernels) compared to FP32 (74% of kernels).
  • this is made worse by additional copy/conversion operations as can be seen in the device self time for FP16 (left) vs FP32 (right):

[image: kernel device self time, FP16 (left) vs FP32 (right)]

These conversions happen in the layer norm and before the softmax, which matches your observation. I also double-checked the layer norm with this micro benchmark; it runs ~30% slower in FP16.
There's a tiny improvement, which makes the eval example run ~1% faster, but it doesn't even register in the micro benchmark.

Judging from the issue you raised, we can't run layer norm in FP16. I'd expect the same to be true for softmax, so I am unsure if we can get rid of those conversions. We may have a chance to get more out of the matmuls, so I'll try to figure out why those kernels don't run on Tensor cores despite being eligible.
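For anyone who wants to poke at the norm in isolation, here is a rough stand-in for that micro benchmark, with the RMS-style norm written out inline (an approximation of T5LayerNorm, not a copy; shapes are arbitrary):

import time

import torch
from torch import nn

class T5StyleLayerNorm(nn.Module):
    """RMS norm in the style of T5LayerNorm: no mean subtraction, no bias;
    the variance is computed in fp32 and the result is cast back to fp16."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)  # promotes to fp32
        if self.weight.dtype == torch.float16:
            x = x.to(torch.float16)
        return self.weight * x

def bench(dtype, iters=1000):
    norm = T5StyleLayerNorm(512).to("cuda").to(dtype)
    x = torch.randn(16, 512, 512, device="cuda", dtype=dtype)
    torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(iters):
            norm(x)
    torch.cuda.synchronize()
    return time.perf_counter() - start

print("fp32:", bench(torch.float32))
print("fp16:", bench(torch.float16))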


I've done all these experiments on a 3080Ti with --max_source_length 512 --max_target_length 512

stas00 (Contributor, Author) commented Jan 19, 2022

This is fantastic work, @dsuess!

Here is an additional profiling report of the same issue but under tf32: #14608 (comment)

This appears to be specific to t5 and derived models. And yes, the problem is that it uses RMSNorm, which pytorch doesn't provide, and that's why it's slow.

I made a request for a fused RMSNorm kernel here: NVIDIA/apex#1271, and once that is done the plan is to ask for it to be upstreamed into pytorch. I hope this will solve the issue.

I also tried some tricks to avoid the re-casting by deploying the existing fused functions: #14656, but I couldn't find a faster way using the existing pytorch python API.

Have you by chance tried any other architectures using the same benchmarks? e.g. gpt2 and bert as they are very distinct from t5.

dsuess (Contributor) commented Jan 22, 2022

> Here is an additional profiling report of the same issue but under tf32: #14608 (comment)

Great benchmark of the different data types, thanks for sharing.

> Have you by chance tried any other architectures using the same benchmarks? e.g. gpt2 and bert as they are very distinct from t5.

I've just tested the same script with some of the mbart variants and as expected, fp16 is faster for those.
