Torchscript benchmark measure #6907

Conversation

@patrickvonplaten (Contributor) commented Sep 2, 2020

This PR is just there to show some benchmarking results of BertScriptableModel vs. BertModel, obtained by running the script benchmark_pytorch_scripting.py.

In a nutshell, the script does the following:

  1. Create a list of 500 or 2500 input_tensors of batch_size 1, with sequence lengths varying between 1 and 128 or between 1 and 512.
    Then take a scripted model model = torch.jit.script(BertScriptableModel(...)) and loop over all 500 / 2500 input_tensors in a standard for loop. The scripted model is warmed up by running the loop 5 times before measuring the time; the loop is then run 10 times and the fastest run is taken as the measurement (see the sketch after this list).

  2. Create a list of 64 or 512 input_tensors of batch_size 8, with sequence lengths varying between 1 and 128 or between 1 and 512.
    Then take a scripted model model = torch.jit.script(BertScriptableModel(...)) and loop over all 64 / 512 input_tensors in a standard for loop. Warm-up and timing are done the same way as in 1.

All this was done in the following environment:

====================        ENVIRONMENT INFORMATION         ====================                                                                                                                                   
- transformers_version: 3.0.0                                                                            
- framework: PyTorch                                                                                     
- use_torchscript: True                                                                                  
- framework_version: 1.6.0                                                                               
- python_version: 3.6.10                                                                                 
- system: Linux                                                                                          
- cpu: x86_64                                                                                     
- architecture: 64bit                                                                                    
- date: 2020-09-02                                                                                       
- time: 16:26:10.562635                                                                                  
- fp16: False                                                                                            
- use_multiprocessing: False                                                                             
- only_pretrain_model: False                                                                             
- cpu_ram_mb: 32088                                                                                      
- use_gpu: True                                                                                          
- num_gpus: 1                                                                                            
- gpu: TITAN RTX                                                                                         
- gpu_ram_mb: 24217                                                                                      
- gpu_power_watts: 280.0                                                                                 
- gpu_performance_state: 2                                                                               
- use_tpu: False 

=> So only on GPU.

To run this script, one can simply run:

./benchmark_pytorch_scripting.py

Important:

The "for" loop corresponds to the function defined in lines 32 - 37 of the file benchmark_pytorch_scripting.py.
This function then overwrites the function that is usually measured in benchmarks, by setting benchmark._prepare_inference_func = _prepare_inference_func in line 49.
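
For reference, a rough sketch of what that override looks like in spirit (PyTorchBenchmark / PyTorchBenchmarkArguments are the public benchmark API of transformers; the body of the replacement function is paraphrased here, with a plain BertModel standing in for the scripted model):

```python
import torch
from transformers import BertConfig, BertModel, PyTorchBenchmark, PyTorchBenchmarkArguments


def _prepare_inference_func(model_name, batch_size, sequence_length):
    # Returns the closure that the benchmark times. Instead of a single forward
    # pass on one fixed-shape batch, it loops over many variable-length inputs.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = BertModel(BertConfig()).eval().to(device)  # the PR uses torch.jit.script(BertScriptableModel(...)) here
    inputs = []
    for _ in range(batch_size):  # batch_size is reused as "number of input tensors" in this setup
        seq_len = int(torch.randint(1, sequence_length + 1, (1,)))
        inputs.append(torch.randint(model.config.vocab_size, (1, seq_len), device=device))

    def encoder_forward():
        with torch.no_grad():
            for input_ids in inputs:
                model(input_ids)

    return encoder_forward


args = PyTorchBenchmarkArguments(models=["bert-base-uncased"], batch_sizes=[500], sequence_lengths=[128])
benchmark = PyTorchBenchmark(args)
benchmark._prepare_inference_func = _prepare_inference_func  # replace what gets measured, as in line 49
# results = benchmark.run()
```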

It would be awesome if @sbrody18 could take a look at the benchmark_pytorch_scripting.py file to check whether TorchScript was used correctly.

@patrickvonplaten patrickvonplaten marked this pull request as draft September 2, 2020 15:31
@patrickvonplaten (Contributor, Author) commented Sep 2, 2020

Results for 1):

1 / 1

====================       INFERENCE - SPEED - RESULT       ====================
--------------------------------------------------------------------------------
          Model Name             Batch Size     Seq Length     Time in s   
--------------------------------------------------------------------------------
Type: multiple - Script: True       500             128            2.575     
Type: multiple - Script: True       500             512            3.898     
Type: multiple - Script: True       2500            128            13.173    
Type: multiple - Script: True       2500            512            18.263    
--------------------------------------------------------------------------------
1 / 1

====================       INFERENCE - SPEED - RESULT       ====================
--------------------------------------------------------------------------------
          Model Name             Batch Size     Seq Length     Time in s   
--------------------------------------------------------------------------------
Type: multiple - Script: False      500             128            3.733     
Type: multiple - Script: False      500             512            3.857     
Type: multiple - Script: False      2500            128            19.101    
Type: multiple - Script: False      2500            512            19.356    
--------------------------------------------------------------------------------

For the smaller sequence length (128) we see a significant speed-up (~30%); for the longer sequence length (512), the speed-up is much smaller (and only shows up for the bigger list of inputs).

@patrickvonplaten (Contributor, Author)

Results for 2):

====================       INFERENCE - SPEED - RESULT       ====================
--------------------------------------------------------------------------------
          Model Name             Batch Size     Seq Length     Time in s   
--------------------------------------------------------------------------------
 Type: batched - Script: True       512             128            0.819     
 Type: batched - Script: True       512             512            3.769     
 Type: batched - Script: True       4096            128            6.705     
 Type: batched - Script: True       4096            512            26.549    
--------------------------------------------------------------------------------
1 / 1

====================       INFERENCE - SPEED - RESULT       ====================
--------------------------------------------------------------------------------
          Model Name             Batch Size     Seq Length     Time in s   
--------------------------------------------------------------------------------
Type: batched - Script: False       512             128            0.837     
Type: batched - Script: False       512             512             3.88     
Type: batched - Script: False       4096            128             6.75     
Type: batched - Script: False       4096            512            27.162    
--------------------------------------------------------------------------------

Here, no clear speed gain can be seen.

@sbrody18 commented Sep 2, 2020

I'm not sure I understand all the interactions in the benchmarking framework, but I think in line 9 (non-script model) we should be returning torch.jit.trace(model, sample_input), not the untraced model. And the sample input would have to be max_length for it to work. That's where most of the gain comes from.
Then the comparison is between using torch.jit.trace() and torch.jit.script(). Or maybe I'm missing some code that does that elsewhere?
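
Concretely, something along these lines (just a sketch; torchscript=True is set so the stock BertModel can be traced):

```python
import torch
from transformers import BertConfig, BertModel

max_length = 512
model = BertModel(BertConfig(torchscript=True)).eval()

# Instead of returning the eager model for the non-script baseline, return a
# traced one. Tracing needs a concrete example input; using seq_len == max_length
# means real inputs padded to max_length will match the traced shapes.
sample_input = torch.randint(model.config.vocab_size, (1, max_length), dtype=torch.long)
traced_model = torch.jit.trace(model, sample_input)
```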

@patrickvonplaten (Contributor, Author) commented Sep 3, 2020

Okay, yeah, that makes sense! I changed the benchmarking script accordingly and now get the following results:

====================       INFERENCE - SPEED - RESULT       ====================
--------------------------------------------------------------------------------
          Model Name             Batch Size     Seq Length     Time in s   
--------------------------------------------------------------------------------
Type: multiple - Script: True       500             128            1.793     
Type: multiple - Script: True       500             512            3.628     
Type: multiple - Script: True       2500            128            8.774     
Type: multiple - Script: True       2500            512            19.471    
--------------------------------------------------------------------------------
1 / 1

====================       INFERENCE - SPEED - RESULT       ====================
--------------------------------------------------------------------------------
          Model Name             Batch Size     Seq Length     Time in s   
--------------------------------------------------------------------------------
Type: multiple - Trace: True      500             128             1.83     
Type: multiple - Trace: True      500             512            3.783     
Type: multiple - Trace: True      2500            128            9.083     
Type: multiple - Trace: True      2500            512            20.569    
--------------------------------------------------------------------------------

and

====================       INFERENCE - SPEED - RESULT       ====================
--------------------------------------------------------------------------------
          Model Name             Batch Size     Seq Length     Time in s   
--------------------------------------------------------------------------------
 Type: batched - Script: True       512             128            1.043     
 Type: batched - Script: True       512             512            4.913     
 Type: batched - Script: True       4096            128            8.499     
 Type: batched - Script: True       4096            512            34.187    
--------------------------------------------------------------------------------
1 / 1

====================       INFERENCE - SPEED - RESULT       ====================
--------------------------------------------------------------------------------
          Model Name             Batch Size     Seq Length     Time in s   
--------------------------------------------------------------------------------
Type: batched - Trace: True       512             128            1.046     
Type: batched - Trace: True       512             512            4.916     
Type: batched - Trace: True       4096            128            8.042     
Type: batched - Trace: True       4096            512            30.874    
--------------------------------------------------------------------------------

=> So my understanding is now that torch.jit.trace(...) is much more efficient for dynamic input shapes than not using torch.jit at all, but I also don't see how torch.jit.script(...) is better than torch.jit.trace(...). If our models are compatible with torch.jit.trace(...), why do we need a model that is compatible with torch.jit.script(...)? It is definitely more convenient to just call torch.jit.script(model) without having to provide any input_ids, but I'm not 100% sure whether that is worth a huge refactoring.
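
To make the convenience point concrete, this is the difference I mean (sketch; the scripted call would only work with the scriptable model variant from this PR, not with the stock BertModel):

```python
import torch
from transformers import BertConfig, BertModel

model = BertModel(BertConfig(torchscript=True)).eval()

# Tracing: needs concrete example inputs, whose shapes get baked into the trace.
dummy_input_ids = torch.randint(model.config.vocab_size, (1, 128), dtype=torch.long)
traced = torch.jit.trace(model, dummy_input_ids)

# Scripting: no example inputs needed, but the model code itself must be
# TorchScript-compatible, which is what BertScriptableModel is for.
# scripted = torch.jit.script(model)  # would fail on the stock BertModel
```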

also cc @sgugger @LysandreJik

@sbrody18 commented Sep 4, 2020

We saw different behavior in our experiments a few months ago. Will try to reproduce and update here.

@patrickvonplaten (Contributor, Author)

> We saw different behavior in our experiments a few months ago. Will try to reproduce and update here.

Was torch.jit.script() much faster than torch.jit.trace() in your experiments?

@sbrody18 commented Sep 4, 2020

In our experiments, using trace(model, example_input) would result in a model that would only accept a sequence of the same length as example_sequence, whereas script(model) had no such restriction. This is the case mentioned in your documentation here: https://huggingface.co/transformers/torchscript.html#dummy-inputs-and-standard-lengths

What that meant in practice is that you needed to trace with an example sequence of length = max_length, and then pad every example of length < max_length with zeros. Since the speed of the model is basically linear in the sequence length, for a set of inputs with varying sequence lengths, using script() instead of trace() reduced the runtime to roughly avg_len/max_length of the traced runtime (i.e. a speed-up of about max_length/avg_len).

Upon further investigation, it looks like when we ran these experiments several months ago, we were using Torch 1.2. It looks like in Torch 1.3 the fixed-length problem is no longer an issue for your BERT models (we still encounter it with other model architectures we build). So there's no longer a big speed gain from script() vs trace().

There are still some good reasons for preferring script() to trace() - scripting is guaranteed to capture the model codepath logic, whereas tracing might miss a logic branch if the example input doesn't flow through it. Also, currently tracing your models produces several warnings like the one below. But I'm not sure if those on their own are enough of a motivation to make major changes in your code base.

TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
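
A toy example of the branch issue (not BERT-specific), following the usual decision-gate pattern from the TorchScript introduction:

```python
import torch


class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        # Data-dependent control flow: which branch runs depends on the input values.
        if x.sum() > 0:
            return x
        else:
            return -x


pos, neg = torch.ones(3), -torch.ones(3)

traced = torch.jit.trace(MyDecisionGate(), pos)  # only the x.sum() > 0 branch is recorded (emits a TracerWarning)
scripted = torch.jit.script(MyDecisionGate())    # both branches are compiled

print(traced(neg))    # tensor([-1., -1., -1.])  <- wrong, the trace froze the positive branch
print(scripted(neg))  # tensor([1., 1., 1.])     <- correct, takes the else branch
```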

@patrickvonplaten (Contributor, Author)

@sgugger - what are your thoughts on this?

@sgugger (Collaborator) commented Sep 8, 2020

I think adding the scriptable layers seems cleaner to make sure everything works right with scripting/tracing. Not the approach in this PR but the other one linked in a comment (@sbrody18 I don't know if you saw my PR rebasing this branch on master). It ends up with most of the changes being helpful for reading the code (type annotations and asserts) and a few extra classes for the scriptable layers, but not much added code.

@sbrody18 commented Sep 8, 2020

@sgugger I agree - I think the extra benefit of the type annotations and None-checking really helps prevent bugs and makes the code better.
I saw your PR late Friday and didn't have time to look into it. Will try to do so by end of day.

stale bot commented Nov 7, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the wontfix label Nov 7, 2020