
Comparison of different methods for benchmarking #6218

Closed
patrickvonplaten opened this issue Aug 3, 2020 · 16 comments

@patrickvonplaten

Currently, the benchmarking tools use multi-processing to make sure that all memory is released after each measurement, and rely on the py3nvml library to measure "peak GPU usage".

After some internal discussion, it is questionable whether the current code actually reports peak GPU memory usage. I therefore ran a couple of experiments to see how torch-based benchmarking differs from py3nvml. It is known that there are differences in the memory benchmarking, as explained here: https://stackoverflow.com/questions/62257967/why-does-a-single-conv2d-with-10x10x3-take-up-850mb-of-gpu#_=_
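
For context, this is roughly how py3nvml sees GPU memory (a minimal sketch, not the actual benchmarking code; the helper name gpu_mem_used_mb is made up): it reports device-level usage, which includes the loaded CUDA/cuDNN kernels and anything held by PyTorch's caching allocator.

import torch
from py3nvml import py3nvml as nvml

def gpu_mem_used_mb(device_idx=0):
    # device-level "used" memory as reported by NVML, in MB
    nvml.nvmlInit()
    handle = nvml.nvmlDeviceGetHandleByIndex(device_idx)
    used_bytes = nvml.nvmlDeviceGetMemoryInfo(handle).used
    nvml.nvmlShutdown()
    return used_bytes >> 20

torch.ones((64, 128, 768)).cuda()  # allocate something on the GPU
print(f"GPU memory used (py3nvml): {gpu_mem_used_mb()} MB")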

For comparison, the following command was run:

python run_benchmark.py --models gpt2 bert-base-cased xlnet-base-cased --no_speed --save_to_csv --batch_sizes 8 64

The environment information is the following:

transformers_version 3.0.2
framework PyTorch
use_torchscript False
framework_version 1.6.0
python_version 3.6.10
system Linux
cpu x86_64
architecture 64bit
date 2020-08-03
time 14:47:20.956286
fp16 False
use_multiprocessing True
only_pretrain_model False
cpu_ram_mb 32088
use_gpu True
num_gpus 1
gpu TITAN RTX
gpu_ram_mb 24217
gpu_power_watts 280.0
gpu_performance_state 0
use_tpu False

a) These are the results when running the command with the current code (py3nvml):

model batch_size sequence_length result
gpt2 8 8 1422
gpt2 8 32 1454
gpt2 8 128 1732
gpt2 8 512 2784
gpt2 64 8 1558
gpt2 64 32 2086
gpt2 64 128 4170
gpt2 64 512 12482
bert-base-cased 8 8 1326
bert-base-cased 8 32 1360
bert-base-cased 8 128 1470
bert-base-cased 8 512 2042
bert-base-cased 64 8 1382
bert-base-cased 64 32 1640
bert-base-cased 64 128 2664
bert-base-cased 64 512 7158
xlnet-base-cased 8 8 1360
xlnet-base-cased 8 32 1422
xlnet-base-cased 8 128 1610
xlnet-base-cased 8 512 2476
xlnet-base-cased 64 8 1436
xlnet-base-cased 64 32 1830
xlnet-base-cased 64 128 3336
xlnet-base-cased 64 512 10344

b) These are the results when using the function torch.cuda.max_memory_reserved(torch.cuda.current_device()) instead (a minimal sketch of this approach follows the table):

model batch_size sequence_length result
gpt2 8 8 566
gpt2 8 32 598
gpt2 8 128 888
gpt2 8 512 1928
gpt2 64 8 702
gpt2 64 32 1230
gpt2 64 128 3314
gpt2 64 512 11626
bert-base-cased 8 8 470
bert-base-cased 8 32 504
bert-base-cased 8 128 614
bert-base-cased 8 512 1186
bert-base-cased 64 8 526
bert-base-cased 64 32 784
bert-base-cased 64 128 1808
bert-base-cased 64 512 6302
xlnet-base-cased 8 8 504
xlnet-base-cased 8 32 566
xlnet-base-cased 8 128 754
xlnet-base-cased 8 512 1620
xlnet-base-cased 64 8 580
xlnet-base-cased 64 32 974
xlnet-base-cased 64 128 2480
xlnet-base-cased 64 512 9488
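
As a reference for how such a number can be obtained with PyTorch's own counters, here is a minimal sketch (not the exact code behind the table; the model and input shape are just placeholders):

import torch
from transformers import BertModel

device = torch.device("cuda")
torch.cuda.reset_peak_memory_stats(device)

model = BertModel.from_pretrained("bert-base-cased").to(device)
input_ids = torch.zeros((8, 128), dtype=torch.long, device=device)
with torch.no_grad():
    model(input_ids)

# peak memory ever reserved by PyTorch's caching allocator during the run, in MB
print(torch.cuda.max_memory_reserved(device) >> 20)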

One can see that the difference is almost always 856 MB (with one exception where it is 868 MB). I ran the py3nvml benchmark multiple times and the results are very stable.
The same holds true when benchmarking training.

=> I tend to think that the code as currently implemented actually gives the peak memory usage, even though I could not find proof of this in the https://github.com/fbcotter/py3nvml library.

@stas00 - what is your opinion on that?

@patrickvonplaten

My main reason for not using PyTorch's max_memory_reserved function is that there is some GPU memory that is used but not accounted for.

@stas00 commented Aug 4, 2020

One can see that the difference is almost always 856 MB (with one exception where it is 868 MB)

I wouldn't be surprised if the delta of 856MB that you reported is the size of the loaded CUDA/cuDNN kernels. Could you please run a test that measures torch.ones((1, 1)).cuda()? If that's what you get, then PyTorch's tool should work just fine, without any complicated polling or potential conflicts - e.g. if 2 tests using the benchmarking framework happen to run at the same time on the same GPU, won't the current approach fail in this scenario?

I have an old Titan X card, where it's around 600MB; the newer Tesla T4 is ~1GB, so your 856MB is in the same ballpark.

@patrickvonplaten commented Aug 4, 2020

Yeah taking your code:

#!/usr/bin/env python3
from transformers import is_torch_available
import torch

if is_torch_available():
    from transformers import (
        PyTorchBenchmarkArguments,
        PyTorchBenchmark
    )

MODEL_ID = "facebook/bart-base"
ss = 8
bs = 1

benchmark_args = PyTorchBenchmarkArguments(
    models=[MODEL_ID],
    training=False,
    no_inference=False,
    sequence_lengths=[ss],
    batch_sizes=[bs],
    no_multi_process=False,
    no_cuda=False,
    no_speed=False,
)
benchmark = PyTorchBenchmark(benchmark_args)

# measure cudnn kernel memory consumption
# now we have a baseline that can be subtracted from all the other usages


def run_cuda_kernel_load():
    torch.ones((1, 1)).cuda()


mem, _ = benchmark._measure_memory(run_cuda_kernel_load)
mem_load = mem.bytes
print(f"Mem on load: {mem_load >> 20}MB")

gives me a result of 796 MB -> so there are still ~60MB of the py3nvml vs. PyTorch max_memory_reserved difference that are not explained by CUDA/cuDNN kernel loading, but this is negligible IMO.

@patrickvonplaten commented Aug 4, 2020

So we have two options here:

  1. Leave the functionality as it is and add:
def run_cuda_kernel_load():
    torch.ones((1, 1)).cuda()

to measure CUDA/CUDNN kernel loading.

The advantage is that we can leave the same functionality for TF

  2. Change to torch.cuda.max_memory_reserved + py3nvml or another tool to measure CUDA/cuDNN kernel loading.
    This option seems a bit safer, and this way multiple processes could be run on the same GPU. However, because measuring CUDA/cuDNN kernel loading cannot be done with torch.cuda.max_memory_reserved and relies on py3nvml or similar, I think we would run into the same problem here in that the result will not be correct if other processes run on the GPU. Or do you know how this can be measured without conflicting with another measurement on the same GPU at the same time?

I guess 2) is the better option though -> it seems safer and two measurements can be done at once. Maybe we should return only this result by default and optionally let the user decide whether to also get the MB required to load CUDA/cuDNN. In the graphs on the model pages we would then include the MB required to load the CUDA/cuDNN kernels.

After thinking a bit more about RAM measurements when running on GPU - I think it actually makes more sense to only do this in combination with the new torch profiler: https://pytorch.org/docs/stable/autograd.html#profiler. I tried out the profiler and it gives very detailed measurements for both CPU and GPU time and memory. For me this profiler is very useful for analyzing the code, e.g. which layer consumes how much memory / time, and how much time is spent on CPU vs. GPU.

So overall, IMO it would be nice to have 2 use cases:

a) Report peak memory usage (either on GPU or CPU) -> get one number for CPU, get one number for GPU (default to torch.cuda.max_memory_reserved and optionally add the GPU CUDA/cuDNN kernel loading memory requirement). Here, I don't think we need to report CPU memory usage when the model is run on GPU. This would be very useful for ML engineers who want to use transformers in production.

b) Do an in-detail analysis -> run the new torch profiler: https://pytorch.org/docs/stable/autograd.html#profiler. For PyTorch this can replace the line-by-line tracing completely IMO and cover the case where the user wants to track CPU as well as GPU usage when running on GPU. We would require PyTorch 1.6 for this, but that is ok IMO. This use case would also be more interesting for researchers and "experts" and less so for ML "engineers" with less research background.

I think the default should be to run only a), whereas the user could optionally turn on b) for in-depth analysis.

Since TF does not (yet) have these tools, we will still have to rely on what we currently have for TF, but for PyTorch I'm happy to switch to actual "PyTorch" tools to track memory, since they seem to give very similar/equal results to py3nvml.

Also looping in @LysandreJik @julien-c and @sshleifer here to hear their opinions on that.

@patrickvonplaten

Same thing goes for speed I guess:

a) Leave the functionality as it is for general overall speed (maybe change averaging over 30 runs to 10 runs + some warmup) -> return one number
b) Use PyTorch profiler for in-detail profiling of CPU / GPU time. User can optionally turn this on.

@patrickvonplaten

BTW, the new profiler can be run like this:

#!/usr/bin/env python3
from transformers import is_torch_available
import torch

if is_torch_available():
    from transformers import (
        BertModel
    )

def run_model():
    model = BertModel.from_pretrained("bert-base-cased")
    model.cuda()
    outputs = model(torch.tensor(32 * [128 * [0]]).cuda())
    return outputs


with torch.autograd.profiler.profile(use_cuda=True, profile_memory=True) as prof:
    run_model()

print(prof.table())

@stas00 commented Aug 4, 2020

gives me a result of 796 MB -> so there are still ~60MB of the py3nvml vs. PyTorch max_memory_reserved difference that are not explained by CUDA/cuDNN kernel loading, but this is negligible IMO.

As I mentioned earlier, I'm not sure "reserved" is the right function, as it involves caching. Try torch.cuda.memory_allocated (and for peak torch.cuda.max_memory_allocated) instead.
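
To illustrate the difference between the two counters (a minimal sketch; the numbers will vary by GPU and PyTorch version): memory_allocated only counts tensors that are currently alive, while memory_reserved also includes blocks the caching allocator keeps around after tensors are freed.

import torch

x = torch.ones((1024, 1024, 64), device="cuda")  # ~256MB of fp32
del x                                            # tensor freed, but the cache keeps the block

print(f"allocated:     {torch.cuda.memory_allocated() >> 20} MB")      # ~0 MB
print(f"reserved:      {torch.cuda.memory_reserved() >> 20} MB")       # still ~256 MB cached
print(f"max allocated: {torch.cuda.max_memory_allocated() >> 20} MB")  # peak of live tensors
print(f"max reserved:  {torch.cuda.max_memory_reserved() >> 20} MB")   # peak incl. the cache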

@stas00 commented Aug 4, 2020

I'd say option #2, plus the code from option #1, so a user can still know the overhead of the cudnn kernel load.

Thank you for mentioning the profiler and the sample code - let me study and experiment with it and then I will be able to comment on your suggestions.

@stas00 commented Aug 4, 2020

Same thing goes for speed I guess:

a) Leave the functionality as it is for general overall speed (maybe change averaging over 30 runs to 10 runs + some warmup) -> return one number

Plus, I'd suggest making n_repeats configurable, with a sensible default - e.g. when developing code I'd want to run with n_repeats=1, since currently a large model takes a really long time to __init__ when it's run 30 times.
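
Something along these lines could cover both the warmup and a configurable repeat count (a minimal sketch, not the existing PyTorchBenchmark API; measure_speed, repeat and warmup are made-up names):

import timeit
import torch

def measure_speed(func, repeat=10, warmup=2):
    # warmup runs: CUDA kernel loading, cudnn autotuning, cache population
    for _ in range(warmup):
        func()
        torch.cuda.synchronize()

    def timed():
        func()
        torch.cuda.synchronize()  # don't stop the clock before the async GPU work is done

    # report the fastest of `repeat` single runs, as timeit recommends
    return min(timeit.repeat(timed, repeat=repeat, number=1))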

@stas00 commented Aug 4, 2020

Wrt returns, as we discussed, my suggestion is to have a full API that returns rich outputs, and then design shortcut wrappers that return just single specific bits, so that test writing is much less cluttered, i.e. removing the need to write code like this, as we have to do now:

mem, _ = benchmark._measure_memory(func)
mem = mem.bytes

it should be possible to do:

mem = benchmark._measure_memory_bytes(func)

no unpacking, no retrieving.

_measure_memory_bytes is just a possible name - we could think of something else.
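
As a sketch of that idea (measure_memory_bytes below is a standalone illustration, not an existing method of the benchmark classes):

def measure_memory_bytes(benchmark, func):
    # shortcut wrapper: hides the (Memory, summary) unpacking of _measure_memory
    mem, _ = benchmark._measure_memory(func)
    return mem.bytes

# usage, with `benchmark` being a PyTorchBenchmark instance as in the snippet above:
# mem_load = measure_memory_bytes(benchmark, run_cuda_kernel_load)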

@stas00 commented Aug 4, 2020

Ah, one more thing we discussed - we need to return general RAM when the benchmark is run on GPU. Memory leaks mainly happen in general RAM. So the main API should include measuring and returning this data too, and flags/shortcuts to enable/disable the calculation and retrieval of this data.
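
A minimal sketch of measuring general RAM around a call, using psutil (just to illustrate the idea, independent of how the benchmark currently measures CPU memory):

import psutil

def cpu_ram_used_mb():
    # resident set size (RSS) of the current process, in MB
    return psutil.Process().memory_info().rss >> 20

before = cpu_ram_used_mb()
big_list = [0] * (50 * 1024 * 1024)  # ~400MB of list slots (8-byte pointers)
after = cpu_ram_used_mb()
print(f"RAM delta: {after - before} MB")

Watching this number across repeated calls of the benchmarked function is what would surface a leak in general RAM.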

@stas00 commented Aug 4, 2020

2 more things to consider:

  1. Should we control these:
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

as these settings should impact performance

  2. Allow a fixed seed arg? (a minimal sketch covering both is below)
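
A sketch of what exposing those two knobs could look like (configure_run and its arguments are made-up names, not the current benchmark arguments):

import random
import numpy as np
import torch

def configure_run(seed=42, cudnn_benchmark=False):
    # fixed seeds so repeated benchmark runs are comparable
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # cudnn settings: benchmark=True lets cuDNN autotune kernels, which affects speed
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = cudnn_benchmark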

@stas00 commented Aug 4, 2020

BTW, the new profiler can be run like this:

It appears that the profiler has been around for quite some time. Its table dump is huge and difficult to work with, though, and there are almost no examples of profiler use out there.

Here is what I came up with so far:

import torch
from transformers import BertModel

def run_model():
    model = BertModel.from_pretrained("bert-base-cased")
    model.cuda()
    model(torch.tensor(32 * [128 * [0]]).cuda())

with torch.autograd.profiler.profile(use_cuda=True, profile_memory=True) as prof:
    run_model()

# times are reported in microseconds -> convert to ms; memory is in bytes -> convert to MB
cpu_time    = sum([e.cpu_time_total    for e in prof.key_averages()]) / 1000
cuda_time   = sum([e.cuda_time_total   for e in prof.key_averages()]) / 1000
cpu_mem     = sum([e.cpu_memory_usage  for e in prof.key_averages()]) >> 20
cuda_mem    = sum([e.cuda_memory_usage for e in prof.key_averages()]) >> 20

print(f"Device |  Mem MB  | Speed ms")
print(f"CPU    | { cpu_mem:8}  | {cpu_time:8.2f}")
print(f"GPU    | {cuda_mem:8}  | {cuda_time:8.2f}")

gives

Device |  Mem MB  |  Speed ms
CPU    |        1 | 1835.97
GPU    |    13258 | 1846.77

I'm not sure yet whether any of this is correct - need to correlate with our profiling functions.

Figured out the self vs. total distinction (the profiler results have a set of self_ attributes in addition to the total_ ones):

  • Total CPU: time spent in calls to the function, including functions called by the function.
  • Self CPU: time spent in calls to the function, excluding functions called by the function.

So we only care about total then.

If I add, for comparison, the measurements from our benchmark tools, I get mostly very different results - I ran these on a T4, unlike the ones above, which were run on a Titan X:

Device: Tesla T4
Model: bert-base-cased

init
Device    | Mem MB | Speed ms
CPU       |    500 |  3339.74
GPU       |      0 |  3339.54
Ours      |    914 |  3289.81

fwd
Device    | Mem MB | Speed ms
CPU       |      0 |    97.47
GPU       |     27 |   105.76
Ours      |    920 |    12.34

fwd-bwd
Device    | Mem MB | Speed ms
CPU       |      0 |   193.97
GPU       |   1723 |   211.86
Ours      |   1540 |    24.95

The last row, labelled Ours, shows the benchmark's GPU measure_memory + measure_speed results. The first two rows are from torch.autograd.profiler.profile as shown before.

As you can see, only the speed measurements for init match; the rest are dramatically different...

If anybody wants to continue experimenting, this is a WIP colab nb:
https://colab.research.google.com/drive/1i-_lxUCuuTKn5Nhe4ENMJd5RgHVeMSZP?usp=sharing

It has a bunch of other experiments, but if you run all - just watch the results of the last cell. Warning: the code is unpolished.

@stas00 commented Aug 5, 2020

One additional potential caveat for speed measurements: could async CUDA code return too early, i.e. do we need to run torch.cuda.synchronize() before finishing the speed measurement? Most likely it isn't needed, since the measurement runs in a separate process.
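
For reference, if timing were ever done in-process, CUDA events are one way to account for the asynchronous execution (a minimal sketch; timed_gpu_call is a made-up helper):

import torch

def timed_gpu_call(func):
    # CUDA events measure time on the GPU timeline, so async kernel launches are covered
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    func()
    end.record()
    torch.cuda.synchronize()        # wait until both recorded events have completed
    return start.elapsed_time(end)  # milliseconds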

@patrickvonplaten

Thanks for posting this! Yes, I think we should maybe stick to our code for now regarding the total time and memory.
A while ago, @LysandreJik made some speed experiments: https://docs.google.com/spreadsheets/d/1sryqufw2D0XlUH4sq3e9Wnxu5EAQkaohzrJbd5HdQ_w/edit which match the results given by PyTorchBenchmark very nicely, so I'm quite positive that the speed measurements are correct.

stale bot commented Oct 10, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the wontfix label on Oct 10, 2020
stale bot closed this as completed on Oct 18, 2020