
[Benchmark] GPT2LMHeadModel (gpt2-medium) forward pass inference became 9% slower compared to 2.8.0 release #11310

Closed
LSinev opened this issue Apr 19, 2021 · 3 comments

LSinev (Contributor) commented Apr 19, 2021

🖥 Benchmarking GPT2LMHeadModel

Benchmark

Direct GPT2LMHeadModel model call (and model.generate() as well)

Set-up

GPU: GTX 1080
PyTorch 1.4.0
transformers releases 2.8.0, 3.5.1, 4.5.1, and the latest master branch
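
The exact environment can be dumped with a snippet along these lines (a sketch; all calls are standard platform/torch/transformers attributes):

import platform
import torch
import transformers

# Report the interpreter, framework, and GPU actually used for the runs.
print("python:", platform.python_version())
print("torch:", torch.__version__, "| cuda:", torch.version.cuda)
print("transformers:", transformers.__version__)
print("gpu:", torch.cuda.get_device_name(0))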

Code to reproduce

import timeit

import numpy as np
import torch
from transformers import __version__ as trans_version
from transformers import GPT2LMHeadModel

print("transformers:", trans_version)
model = GPT2LMHeadModel.from_pretrained("gpt2-medium")
print(model.__class__)
model.to("cuda")
model.eval()
rounding = 3

# Greedy generation: a random 1014-token prompt, generated out to 1024 tokens
# (exercises the past-key-values cache), 30 repeats of 1 call each.
timed_result = timeit.repeat(
    stmt="""model.generate(input_ids=inp_t,
               max_length=1024,
               min_length=1024,
               do_sample=False,
               early_stopping=False, pad_token_id=50256, eos_token_id=50256)""",
    setup="""inp = np.random.randint(low=1, high=50255, size=1014); inp_t = torch.LongTensor(inp).unsqueeze(0).to("cuda")""",
    repeat=30, number=1, globals=globals())

# Plain forward pass: a full 1024-token input, 30 repeats of 10 calls each.
timed_model_result = timeit.repeat(
    stmt="""with torch.no_grad():
    model(input_ids=inp_t)""",
    setup="""inp = np.random.randint(low=1, high=50255, size=1024); inp_t = torch.LongTensor(inp).unsqueeze(0).to("cuda")""",
    repeat=30, number=10, globals=globals())

print('GPT2LMmedium model.generate (using caching) 1014 input, generate to 1024 (mean ± 3std):',
      str(np.round(np.mean(timed_result), rounding)) + '±' + str(np.round(3 * np.std(timed_result), rounding)))
print('GPT2LMmedium model call, 1024 input 10 times (mean ± 3std):',
      str(np.round(np.mean(timed_model_result), rounding)) + '±' + str(np.round(3 * np.std(timed_model_result), rounding)))
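
One caveat worth noting: CUDA kernels launch asynchronously, so pure wall-clock timing can smear work across repeats. A stricter variant of the forward-pass measurement would look like this (a sketch; torch.cuda.synchronize() is standard PyTorch, the rest is unchanged from the code above):

# Same forward-pass measurement, but draining the GPU queue before the
# timer stops, so each repeat is charged only for its own kernels.
timed_model_result = timeit.repeat(
    stmt="""with torch.no_grad():
    model(input_ids=inp_t)
torch.cuda.synchronize()""",
    setup="""inp = np.random.randint(low=1, high=50255, size=1024); inp_t = torch.LongTensor(inp).unsqueeze(0).to("cuda"); torch.cuda.synchronize()""",
    repeat=30, number=10, globals=globals())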

Results

While the model.generate() code path has improved and now runs noticeably faster, the plain forward pass used in a direct model call became about 9% slower on master (4.6.0.dev0: 1.991 s vs. 1.821 s on 2.8.0 for 10 calls):

transformers: 2.8.0
<class 'transformers.modeling_gpt2.GPT2LMHeadModel'>
GPT2LMmedium model.generate (using caching) 1014 input, generate to 1024 (mean ± 3std): 0.557±0.037
GPT2LMmedium model call, 1024 input 10 times (mean ± 3std): 1.821±0.017

transformers: 3.5.1
<class 'transformers.modeling_gpt2.GPT2LMHeadModel'>
GPT2LMmedium model.generate (using caching) 1014 input, generate to 1024 (mean ± 3std): 0.37±0.003
GPT2LMmedium model call, 1024 input 10 times (mean ± 3std): 1.849±0.012

transformers: 4.5.1
<class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'>
GPT2LMmedium model.generate (using caching) 1014 input, generate to 1024 (mean ± 3std): 0.36±0.003
GPT2LMmedium model call, 1024 input 10 times (mean ± 3std): 1.823±0.013

transformers: 4.6.0.dev0
<class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'>
GPT2LMmedium model.generate (using caching) 1014 input, generate to 1024 (mean ± 3std): 0.367±0.004
GPT2LMmedium model call, 1024 input 10 times (mean ± 3std): 1.991±0.013
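
To localize where the extra time goes, one option (a sketch only, not yet run on this setup) is the PyTorch autograd profiler, which is already available in PyTorch 1.4; comparing its output between 4.5.1 and master should point at the regressed op:

# Profile a single no-grad forward pass and print the ops
# sorted by total CUDA time (model, np, torch as in the script above).
inp = np.random.randint(low=1, high=50255, size=1024)
inp_t = torch.LongTensor(inp).unsqueeze(0).to("cuda")
with torch.no_grad():
    with torch.autograd.profiler.profile(use_cuda=True) as prof:
        model(input_ids=inp_t)
print(prof.key_averages().table(sort_by="cuda_time_total"))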

LSinev (Contributor, Author) commented Apr 19, 2021

@patil-suraj Could you please check whether this slowdown of the GPT2LMHeadModel model call was introduced by your PR #11225?

patil-suraj (Contributor) commented

Hi @LSinev

Thank you for posting the detailed issue. I will take a look.

patil-suraj self-assigned this Apr 19, 2021
github-actions commented

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
