
TensorFlow training/inference optimization #7605

Closed
wants to merge 18 commits into from

Conversation

jplu
Contributor

@jplu jplu commented Oct 6, 2020

What does this PR do?

This PR fixes a performance issue where some operations were executed on CPU instead of GPU, leaving the GPU idle. This optimization is feasible thanks to the recent update we made to the way we load the TF weights.

@patrickvonplaten I have made a few changes in the TFLongformer model, and I'm sure it can be further optimized the same way (see TFLongformerSelfAttention), but as I don't know much about how this model works, can you take a look at whether the same optimization can be applied?

Fixes #6771
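The change described above can be sketched roughly as follows, with hypothetical BERT-base shapes (the actual layers and shapes in the PR may differ); `EinsumDense` lived under `tf.keras.layers.experimental` in TF 2.3 and was promoted to `tf.keras.layers` in later releases:

```python
import tensorflow as tf

# EinsumDense moved out of experimental in TF 2.4; handle both locations.
try:
    EinsumDense = tf.keras.layers.EinsumDense
except AttributeError:
    EinsumDense = tf.keras.layers.experimental.EinsumDense

hidden_size, num_heads = 768, 12          # hypothetical BERT-base shapes
head_size = hidden_size // num_heads

# One EinsumDense produces the per-head projection directly, so the model
# no longer needs a separate reshape + transpose + matmul after a Dense
# layer (the pattern that was being placed on CPU).
query = EinsumDense(
    equation="abc,cde->abde",  # (batch, seq, hidden) x (hidden, heads, head) -> (batch, seq, heads, head)
    output_shape=(None, num_heads, head_size),
    bias_axes="de",
)

x = tf.random.uniform((2, 16, hidden_size))
out = query(x)  # per-head projection: (batch, seq, heads, head_size)
```

Note that the kernel of such a layer has shape `(hidden, heads, head_size)` instead of the `(hidden, hidden)` of a plain Dense layer, which is why the weight-loading code had to change.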

Member

@LysandreJik LysandreJik left a comment


If that's all it takes, that's fantastic! Did you manage to obtain the performance improvements that were initially mentioned thanks to this?

Also, I'm realizing now that we don't have integration testing for our TensorFlow models, and this seems like a situation where having some is needed. Could we work on adding these tests for the models modified here first, and then add them to the rest of the models?

Something like what is done in tests/test_modeling_roberta.py, using tiny models.

I can help you work on it if you're lacking time!

@jplu
Contributor Author

jplu commented Oct 6, 2020

If that's all it takes, that's fantastic! Did you manage to obtain the performance improvements that were initially mentioned thanks to this?

On my machine with my GPU, yes.

Also I'm realizing now that we don't have integration testing for our TensorFlow models, and this seems like a situation where having some would be needed. Could we work on adding these tests for the models modified here at first, and then add it to the rest of the models?

Sure! That's a good idea!

I can help you work on it if you're lacking time!

I would appreciate it if you have time, yes 😃

@LysandreJik
Member

Okay, I will take a look at writing the integration tests sometime tonight. Will let you know!

@ydshieh
Collaborator

ydshieh commented Oct 6, 2020

@jplu

For learning purposes, I am wondering which operations were done on CPU instead of GPU. I saw you changed Dense to EinsumDense in several places and removed several shape-changing operations. Is shape changing done on CPU, and does EinsumDense avoid this? Could you give me some information about this, so I can read and learn about it? Thanks.

@jplu
Contributor Author

jplu commented Oct 6, 2020

@chiapas

If you take a look at #6771, it is quite well detailed. The issue was coming from a transpose+matmul that was done on CPU. EinsumDense allows you to do all these computations directly inside the layer, but at the cost of changing the shapes of the original layers; that's why we modified the way we load the TF models.

To do this PR, I basically took the original BERT implementation right here as an example.

@jplu
Contributor Author

jplu commented Oct 6, 2020

Thanks a lot @LysandreJik !!

As I'm currently working on from-scratch LM training for TF models, I don't have much time to really focus on this.

@ydshieh
Collaborator

ydshieh commented Oct 6, 2020

transpose+matmul

@jplu Thanks. I am surprised that this transpose+matmul was done on CPU.

@ydshieh
Collaborator

ydshieh commented Oct 6, 2020

Thanks a lot @LysandreJik !!

As I'm currently working on from-scratch LM training for TF models, I don't have much time to really focus on this.

@jplu You also work on LM training for TF models? I plan to go back to a pending PR #6955 I created, once test_trainer_tf.py is done. Do PR #6955 and your work on LM training for TF models overlap? Currently that PR is still empty, though.

@jplu
Contributor Author

jplu commented Oct 6, 2020

@chiapas This is exactly what I'm doing, and the models need some rework; that's why I'm mostly focused on BERT, to have at least one model working.

I just finished the data pipeline with random masking generation yesterday.
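For reference, a BERT-style dynamic masking step for such a pipeline could be sketched as below; the token ids and probabilities are the usual hypothetical defaults, not jplu's actual code:

```python
import tensorflow as tf

MASK_ID, VOCAB_SIZE, MLM_PROB = 103, 30522, 0.15  # hypothetical defaults

def mask_tokens(input_ids):
    """Select 15% of tokens; of those, 80% -> [MASK], 10% -> random id,
    10% left unchanged. Unselected positions get label -100 (ignored)."""
    shape = tf.shape(input_ids)
    selected = tf.random.uniform(shape) < MLM_PROB
    labels = tf.where(selected, input_ids, -100 * tf.ones_like(input_ids))
    # 80% of the selected positions become [MASK].
    masked = selected & (tf.random.uniform(shape) < 0.8)
    out = tf.where(masked, MASK_ID * tf.ones_like(input_ids), input_ids)
    # Half of the remaining selected positions become a random token.
    randomized = selected & ~masked & (tf.random.uniform(shape) < 0.5)
    random_ids = tf.random.uniform(shape, maxval=VOCAB_SIZE, dtype=input_ids.dtype)
    out = tf.where(randomized, random_ids, out)
    return out, labels

dataset = tf.data.Dataset.from_tensor_slices(tf.constant([[5, 6, 7, 8]] * 4))
dataset = dataset.batch(2).map(mask_tokens)  # masking is re-drawn every pass
```

Doing the masking inside `tf.data` (rather than once at preprocessing time) gives the dynamic masking behavior where each epoch sees a different mask.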

@ydshieh
Collaborator

ydshieh commented Oct 6, 2020

@chiapas This is exactly what I'm doing, and the models need some rework; that's why I'm mostly focused on BERT, to have at least one model working.

I just finished the data pipeline with random masking generation yesterday.

Ah, ok. I guess my PR was pending too long, and it is my fault for not communicating with you first. I planned to do this once I finished a notebook on Kaggle, Masked, My Dear Watson - MLM with TPU, which also works on MLM.

Since you have already made more progress (and you are also an HF member), it is better for you to continue. However, if there is something I can contribute to this TF LM task, I would love to do it.

@jplu
Contributor Author

jplu commented Oct 6, 2020

Since you have already made more progress (and you are also an HF member), it is better for you to continue. However, if there is something I can contribute to this TF LM task, I would love to do it.

Thanks! I will let you know.

@patrickvonplaten
Contributor

patrickvonplaten commented Oct 6, 2020

That's awesome! I will see what results the TF benchmark scripts give before/after this PR.

Strongly agree with @LysandreJik that we should add integration tests before merging this PR.

@patrickvonplaten
Copy link
Contributor

patrickvonplaten commented Oct 6, 2020

I ran the benchmarks: python examples/benchmarking/run_benchmark_tf.py --models bert-base-cased --env_print in the following environment:

- transformers_version: 3.3.1                                                                                                                                                                                      
- framework: TensorFlow                                                                                                                                                                                            
- eager_mode: False                                                                                                                                                                                                
- use_xla: False                                                                                                                                                                                                   
- framework_version: 2.3.0                                                                                                                                                                                         
- python_version: 3.6.10                                                                                                                                                                                           
- system: Linux                                                                                                                                                                                                    
- cpu: x86_64                                                                                                                                                                                                      
- architecture: 64bit                                                                                                                                                                                              
- date: 2020-10-06                   
- time: 19:06:48.378935                                                                                  
- fp16: False                                                                                                                                                                                                      
- use_multiprocessing: True                                                                                                                                                                                        
- only_pretrain_model: False                                                                                                                                                                                       
- cpu_ram_mb: 32088                                 
- use_gpu: True                                                                                          
- num_gpus: 1                                                                                            
- gpu: TITAN RTX                                                                                         
- gpu_ram_mb: 24217                                                                                      
- gpu_power_watts: 280.0                                                                                 
- gpu_performance_state: 8                                                                               
- use_tpu: False 

Currently, on master:

====================       INFERENCE - SPEED - RESULT       ====================
--------------------------------------------------------------------------------
          Model Name             Batch Size     Seq Length     Time in s   
--------------------------------------------------------------------------------
       bert-base-cased               8               8             0.085     
       bert-base-cased               8               32            0.166     
       bert-base-cased               8              128            0.513     
       bert-base-cased               8              512            2.629     
--------------------------------------------------------------------------------                                                                                                                                 

In this tf-optim branch, the results are:

====================       INFERENCE - SPEED - RESULT       ====================
--------------------------------------------------------------------------------
          Model Name             Batch Size     Seq Length     Time in s   
--------------------------------------------------------------------------------
       bert-base-cased               8               8             0.088     
       bert-base-cased               8               32            0.176     
       bert-base-cased               8              128            0.531     
       bert-base-cased               8              512            3.028     
--------------------------------------------------------------------------------

=> So the speed results are more or less identical with the way the benchmarks are run.

I don't compile the model with Keras, but just add the @tf.function decorator to the function to transform it into graph mode. So I'm not sure what to think of that... => @jplu - could you maybe check the benchmark script and see if you can get a speed-up there? Or whether the benchmark script is wrong?

python examples/benchmarking/run_benchmark_tf.py --models bert-base-cased --env_print
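For context, the graph-mode measurement the benchmark script relies on works roughly like the sketch below, with a trivial stand-in model (the real script wraps the transformer's forward pass):

```python
import time
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(8)])  # stand-in model

@tf.function
def forward(x):
    # Decorating with tf.function traces the call into a TF graph, instead
    # of compiling the whole model with Keras.
    return model(x, training=False)

x = tf.random.uniform((8, 8))  # (batch_size, seq_length) analogue
forward(x)                     # first call pays the one-off tracing cost
start = time.perf_counter()
for _ in range(10):
    forward(x)
avg_time = (time.perf_counter() - start) / 10
```

This measures eager-client graph execution only; it does not exercise the extra graph optimizations that a TF Serving deployment applies, which is the gap jplu points out below.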

@jplu
Contributor Author

jplu commented Oct 6, 2020

The benchmark script is OK, but to see the difference you have to create a saved_model and run the model in TF Serving. Your benchmark doesn't take into account all the optimizations TF Serving applies for inference.

We should update the benchmark script to include:

  • Saved model creation
  • Running the saved model with the TF Serving tool
  • Adapting the benchmark to include gRPC calls to use the model from TF Serving
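The first of those steps (SavedModel export with a serving signature) might look like the sketch below, with a stand-in module rather than a real transformer; serving it would then mean pointing `tensorflow_model_server` at the parent directory and benchmarking gRPC/REST calls against it:

```python
import os
import tempfile
import tensorflow as tf

class ExportModule(tf.Module):
    """Stand-in for a transformer model, just to show the export shape."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.random.uniform((8, 4)))

    @tf.function(input_signature=[tf.TensorSpec((None, 8), tf.float32)])
    def serve(self, x):
        return {"logits": tf.matmul(x, self.w)}

module = ExportModule()
# TF Serving expects a numeric version subdirectory under the model base path.
export_dir = os.path.join(tempfile.mkdtemp(), "1")
tf.saved_model.save(module, export_dir, signatures={"serving_default": module.serve})

# Sanity check: reload and call the serving signature the way a client would.
infer = tf.saved_model.load(export_dir).signatures["serving_default"]
out = infer(x=tf.random.uniform((2, 8)))
```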

@jplu
Contributor Author

jplu commented Nov 9, 2020

Will be integrated into the PR #7753

@jplu jplu closed this Nov 9, 2020
@jplu jplu deleted the tf-optim branch November 9, 2020 13:02
Successfully merging this pull request may close these issues.

Exported TF Bert model is much slower than that exported from Google's Bert
4 participants