
[TF Longformer] Improve Speed for TF Longformer #6447

Conversation

@patrickvonplaten (Contributor) commented Aug 12, 2020

This PR:

  • adds a simple test for all TF models that verifies the forward function can be used in graph mode
  • optimizes TF Longformer by removing unnecessary calculations, such as tf.transpose() (in contrast to PyTorch, tf.transpose() allocates a new tensor and should therefore be avoided). This also cleans up the code IMO.

=> These changes lead to a speed-up of ~1.03x, which is actually not that much... more details in the benchmark below.
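The tf.transpose() point can be illustrated with a small sketch. Plain NumPy is used here since the equivalence is purely mathematical; in TF the analogous calls are tf.transpose, tf.matmul and tf.einsum, and each tf.transpose materializes a new tensor rather than returning a view. Shapes are illustrative, not Longformer's actual ones.

```python
import numpy as np

rng = np.random.default_rng(0)
q = rng.standard_normal((2, 8, 4, 16))  # (batch, seq_len, heads, head_dim)
k = rng.standard_normal((2, 8, 4, 16))

# Transpose-heavy version: move heads forward, then matmul, as many
# attention implementations do. In TF, every transpose allocates a copy.
q_t = np.transpose(q, (0, 2, 1, 3))      # (batch, heads, seq, dim)
k_t = np.transpose(k, (0, 2, 3, 1))      # (batch, heads, dim, seq)
scores_transpose = np.matmul(q_t, k_t)   # (batch, heads, seq, seq)

# einsum version: identical result, no intermediate transposed copies.
scores_einsum = np.einsum("bqhd,bkhd->bhqk", q, k)

assert np.allclose(scores_transpose, scores_einsum)
```

The einsum form lets the kernel contract over the head dimension directly, which is the kind of rewrite this PR applies inside the TF Longformer attention.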

After a lot of digging it turns out that TF XLA will not be easy to support: several kernels that this model relies on heavily, such as tf.where, are not implemented for XLA (yet). So TF Longformer will sadly not work on TPU for the moment @ibeltagy

Conclusion

For me the PR was also a good exercise to see whether TF models can be significantly sped up by removing unnecessary tensor allocations. It seems it's not really worth going through all the TF models if the improvement in speed is only around 2-3%.

@codecov (codecov bot) commented Aug 13, 2020

Codecov Report

Merging #6447 into master will increase coverage by 0.84%.
The diff coverage is 100.00%.


@@            Coverage Diff             @@
##           master    #6447      +/-   ##
==========================================
+ Coverage   78.96%   79.81%   +0.84%     
==========================================
  Files         157      157              
  Lines       28486    28479       -7     
==========================================
+ Hits        22495    22730     +235     
+ Misses       5991     5749     -242     
Impacted Files Coverage Δ
src/transformers/modeling_tf_bert.py 98.38% <100.00%> (ø)
src/transformers/modeling_tf_electra.py 98.95% <100.00%> (+73.82%) ⬆️
src/transformers/modeling_tf_longformer.py 98.67% <100.00%> (-0.03%) ⬇️
src/transformers/modeling_tf_xlm.py 18.94% <0.00%> (-74.32%) ⬇️
src/transformers/modeling_tf_flaubert.py 24.53% <0.00%> (-63.81%) ⬇️
src/transformers/modeling_roberta.py 77.37% <0.00%> (-19.71%) ⬇️
src/transformers/file_utils.py 82.41% <0.00%> (-0.26%) ⬇️
src/transformers/modeling_utils.py 88.05% <0.00%> (+0.55%) ⬆️
src/transformers/modeling_tf_utils.py 87.29% <0.00%> (+2.60%) ⬆️
src/transformers/generation_tf_utils.py 85.71% <0.00%> (+2.75%) ⬆️
... and 4 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update a75c64d...ae3bbe2.

@patrickvonplaten patrickvonplaten force-pushed the add_tf_compile_function_to_test branch from 8c8c5a0 to 85bed01 Compare August 14, 2020 07:53
# is index masked or global attention

is_index_masked = tf.math.less(attention_mask, 0)
is_index_global_attn = tf.math.greater(attention_mask, 0)
@patrickvonplaten (Contributor, Author) commented:

These operations should be done once and not be repeated for every layer -> unnecessary slow-down otherwise.
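The hoisting idea can be sketched as follows (NumPy stand-ins for tf.math.less/tf.math.greater; the mask values and the layer count are illustrative, not taken from the actual model):

```python
import numpy as np

# Longformer-style attention mask: 0 = local attention, positive = global
# attention, large negative = masked/padding (values are illustrative).
attention_mask = np.array([[0, 0, 1, 0, -10000, -10000]])

# Computed ONCE before the layer loop, then passed down to every layer,
# instead of being recomputed inside each of the N attention layers.
is_index_masked = attention_mask < 0        # tf.math.less(attention_mask, 0)
is_index_global_attn = attention_mask > 0   # tf.math.greater(attention_mask, 0)

for _ in range(12):  # each layer just reuses the precomputed masks
    pass  # layer(..., is_index_masked, is_index_global_attn)
```

The savings per call are tiny, but in graph mode every duplicated op still costs kernel launches and memory traffic, so moving invariant computations out of the per-layer code path is the cheap, safe part of this optimization.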

@@ -170,24 +168,20 @@ def call(
# normalize query
query_vectors /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32))

query_vectors = tf.transpose(
@patrickvonplaten (Contributor, Author) commented Aug 14, 2020:

Get rid of tf.transpose() as it allocates a new tensor upon calling. Should be avoided.
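The semantic difference being exploited here can be shown with a NumPy analogy (NumPy and PyTorch share the view behavior; `np.ascontiguousarray` stands in for what tf.transpose effectively does):

```python
import numpy as np

x = np.arange(12).reshape(3, 4)

# NumPy/PyTorch: transpose is a view -- no data is copied.
view = x.T
assert np.shares_memory(x, view)

# TensorFlow's tf.transpose, by contrast, always writes out a new tensor,
# so each call costs an allocation plus a full copy. That is why transposes
# that exist only to feed a later matmul are worth removing in TF.
copy = np.ascontiguousarray(x.T)  # stands in for tf.transpose's behavior
assert not np.shares_memory(x, copy)
assert (view == copy).all()
```

In PyTorch a `.transpose()` before a matmul is essentially free; porting that pattern 1:1 to TF silently turns it into a copy per call.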

@@ -134,6 +125,19 @@ def test_save_load(self):

self.assert_outputs_same(after_outputs, outputs)

def test_graph_mode(self):
@patrickvonplaten (Contributor, Author) commented:

Add a function that tests every model in graph mode.

@patrickvonplaten patrickvonplaten changed the title [TF Optim] Add tf compile function to test [TF Longformer] Improve Speed and Memory for TF Longformer Aug 14, 2020
@patrickvonplaten (Contributor, Author) commented:

Speed Benchmarking:

Running this command on the master branch:

python examples/benchmarking/run_benchmark.py --models allenai/longformer-base-4096 --no_memory --sequence_length 512 1024

on this env:

- transformers_version: 3.0.2
- framework: TensorFlow
- eager_mode: False
- use_xla: False
- framework_version: 2.2.0
- python_version: 3.8.5
- system: Linux
- cpu: x86_64
- architecture: 64bit
- date: 2020-08-14
- time: 10:32:09.525696
- fp16: False
- use_multiprocessing: True
- only_pretrain_model: False
- cpu_ram_mb: N/A
- use_gpu: True
- num_gpus: 1
- gpu: TITAN RTX
- gpu_ram_mb: 24217
- gpu_power_watts: 280.0
- gpu_performance_state: 0
- use_tpu: False

gives:

====================       INFERENCE - SPEED - RESULT       ====================                         
--------------------------------------------------------------------------------                         
          Model Name             Batch Size     Seq Length     Time in s                                 
--------------------------------------------------------------------------------                         
 allenai/longformer-base-4096        8              512            0.229                                 
 allenai/longformer-base-4096        8              1024           0.463                                 
--------------------------------------------------------------------------------

On this branch the speed is improved to:

====================       INFERENCE - SPEED - RESULT       ====================                         
--------------------------------------------------------------------------------                         
          Model Name             Batch Size     Seq Length     Time in s                                 
--------------------------------------------------------------------------------                         
 allenai/longformer-base-4096        8              512            0.223                                 
 allenai/longformer-base-4096        8              1024           0.447                                 
--------------------------------------------------------------------------------

So we can see an improvement of ca. 3%, which is not that much actually... It's interesting to see what effect removing some unnecessary tf.transpose() calls has in TF, but it's probably not worth going through all the modeling_tf_... files trying to remove tf.transpose() and similar functions.
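For reference, the "ca. 3%" figure follows directly from the two benchmark tables above:

```python
# Inference times (seconds) copied from the benchmark tables above.
master = {512: 0.229, 1024: 0.463}   # master branch
branch = {512: 0.223, 1024: 0.447}   # this PR

for seq_len in (512, 1024):
    speedup = master[seq_len] / branch[seq_len]
    saved = 100 * (1 - branch[seq_len] / master[seq_len])
    print(f"seq_len={seq_len}: {speedup:.3f}x speed-up ({saved:.1f}% faster)")
```

So roughly 2.6% at sequence length 512 and 3.5% at 1024, i.e. a ~1.03x speed-up overall.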

@patrickvonplaten patrickvonplaten requested review from jplu, sgugger and LysandreJik and removed request for jplu August 14, 2020 10:55
@patrickvonplaten patrickvonplaten changed the title [TF Longformer] Improve Speed and Memory for TF Longformer [TF Longformer] Improve Speed for TF Longformer Aug 14, 2020
@sgugger (Collaborator) left a comment:
LGTM, though I'm no expert in TF :-)

@patrickvonplaten patrickvonplaten force-pushed the add_tf_compile_function_to_test branch from 316206c to ae3bbe2 Compare August 26, 2020 16:26
@LysandreJik (Member) left a comment:
Great, LGTM! Nice test, should be pretty similar to the Keras/compile tests!

@LysandreJik LysandreJik merged commit 858b7d5 into huggingface:master Aug 26, 2020
Zigur pushed a commit to Zigur/transformers that referenced this pull request Oct 26, 2020
* add tf graph compile tests

* fix conflict

* remove more tf transpose statements

* fix conflicts

* fix comment typos

* move function to class function

* fix black

* fix black

* make style
fabiocapsouza pushed a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020
fabiocapsouza added a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020