Comparison of different methods for benchmarking #6218
My main reason for not using PyTorch's …
I won't be surprised if the delta of 856MB that you reported is the size of the loaded CUDA/cuDNN kernels. Could you please run a test to measure just that kernel load? I have an old Titan X card, so for me it's around 600MB; on the newer Tesla T4 it's ~1GB, so your 856MB fits in that ballpark.
Yeah, taking your code:

```python
#!/usr/bin/env python3
from transformers import is_torch_available
import torch

if is_torch_available():
    from transformers import (
        PyTorchBenchmarkArguments,
        PyTorchBenchmark,
    )

MODEL_ID = "facebook/bart-base"
ss = 8
bs = 1

benchmark_args = PyTorchBenchmarkArguments(
    models=[MODEL_ID],
    training=False,
    no_inference=False,
    sequence_lengths=[ss],
    batch_sizes=[bs],
    no_multi_process=False,
    no_cuda=False,
    no_speed=False,
)
benchmark = PyTorchBenchmark(benchmark_args)

# measure cudnn kernel memory consumption
# now we have a baseline that can be subtracted from all the other usages
def run_cuda_kernel_load():
    torch.ones((1, 1)).cuda()

mem, _ = benchmark._measure_memory(run_cuda_kernel_load)
mem_load = mem.bytes
print(f"Mem on load: {mem_load >> 20}MB")
```

gives me a result of 796 MB -> so there are still ~60 MB that are not explained by the CUDA/cuDNN kernel loading plus PyTorch's own allocations.
So we have two options here:

1) Keep the current measurement and run

```python
def run_cuda_kernel_load():
    torch.ones((1, 1)).cuda()
```

to measure the CUDA/cuDNN kernel loading and subtract it as a baseline. The advantage is that we can leave the same functionality for TF.

2) Switch to PyTorch's own memory measurement, which does not include the CUDA/cuDNN kernel load.
I guess 2) is the better option though -> it seems safer and two measurements can be done at once. Maybe we should also only return this result by default and optionally let the user decide if he/she also wants the MB required to load CUDA/cuDNN. In the graphs on the model pages we would then include the MB required to load the CUDA/cuDNN kernels.

After thinking a bit more about RAM measurements when running on GPU, I think it actually makes more sense to only do this in combination with the new torch profiler: https://pytorch.org/docs/stable/autograd.html#profiler. I tried out the profiler and it gives very detailed measurements for both CPU and GPU time and memory. For me this profiler is very useful for analysis of the code, e.g. which layer consumes how much memory/time and how much time is spent on CPU vs. GPU.

So overall, IMO it would be nice to have two use cases:

a) Measure peak memory usage (either on GPU or CPU) -> get one number for CPU and one number for GPU (see the sketch below).

b) Do an in-detail analysis -> run the new torch profiler: https://pytorch.org/docs/stable/autograd.html#profiler. For PyTorch this can replace the line-by-line tracing completely IMO and cover the case where the user wants to track CPU as well as GPU usage when running on GPU. We would require PyTorch 1.6 for this, but that is OK IMO. This use case would also be more interesting for researchers and "experts" and less so for ML "engineers" with less of a research background.

I think the default should be to run only a), whereas the user could optionally turn on b) for in-depth analysis. Since TF does not (yet) have these tools, we will still have to rely on what we currently have for TF, but for PyTorch I'm happy to switch to actual PyTorch tools to track memory, since they seem to give very similar/equal results.

Also looping in @LysandreJik @julien-c and @sshleifer here to hear their opinions on that.
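A minimal sketch of what use case a) could look like with PyTorch's own counters (assumes PyTorch >= 1.4 for `reset_peak_memory_stats`; the matmul is just a stand-in workload, not the benchmark's actual code):

```python
import torch

torch.cuda.reset_peak_memory_stats()

# stand-in for the forward pass / training step being benchmarked
x = torch.randn(1024, 1024, device="cuda")
y = x @ x
torch.cuda.synchronize()

print(f"peak allocated: {torch.cuda.max_memory_allocated() >> 20} MB")
print(f"peak reserved:  {torch.cuda.max_memory_reserved() >> 20} MB")
```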
Same thing goes for speed measurements: a) leave the functionality as it is for general overall speed (maybe change averaging over 30 runs to 10 runs plus some warmup) -> return one number (see the timing sketch below).
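A minimal sketch of the "warmup plus fewer timed repeats" idea, with a stand-in matmul workload rather than the benchmark's implementation:

```python
import timeit
import torch

x = torch.randn(2048, 2048, device="cuda")

def step():
    y = x @ x
    torch.cuda.synchronize()   # make sure the kernel has finished before the timer stops

# warmup: triggers lazy initialization, memory-pool growth, kernel caching, etc.
for _ in range(3):
    step()

runs = timeit.repeat(step, repeat=10, number=1)
print(f"min: {min(runs) * 1e3:.2f} ms  mean: {sum(runs) / len(runs) * 1e3:.2f} ms")
```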
BTW, the new profiler can be run like this:

```python
#!/usr/bin/env python3
from transformers import is_torch_available
import torch

if is_torch_available():
    from transformers import BertModel

def run_model():
    model = BertModel.from_pretrained("bert-base-cased")
    model.cuda()
    outputs = model(torch.tensor(32 * [128 * [0]]).cuda())
    return outputs

with torch.autograd.profiler.profile(use_cuda=True, profile_memory=True) as prof:
    run_model()

print(prof.table())
```
As I mentioned earlier, I'm not sure "reserved" is the right function, as it involves caching. Try `torch.cuda.max_memory_allocated` instead.
Plus, I'd suggest making …
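To illustrate the caching point, a small sketch showing that "reserved" keeps freed blocks in PyTorch's caching allocator while "allocated" drops back down:

```python
import torch

x = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda")  # ~256 MB tensor
del x

print(f"allocated now: {torch.cuda.memory_allocated() >> 20} MB")       # ~0 MB, tensor freed
print(f"reserved now:  {torch.cuda.memory_reserved() >> 20} MB")        # ~256 MB kept cached
print(f"max allocated: {torch.cuda.max_memory_allocated() >> 20} MB")   # peak is still ~256 MB
```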
Wrt returns, as we discussed, my suggestion is to have a full API that returns rich outputs, and then design shortcut wrappers that return just single specific bits, which would make test writing much less cluttered. Instead of the unpacking/retrieving code we have to write now, it should be possible to call a shortcut and get the single number directly: no unpacking, no retrieving.
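A purely hypothetical sketch of the "rich result plus shortcut wrapper" idea; `BenchmarkResult`, `measure` and `measure_gpu_peak_mb` are illustrative names, not the actual transformers API:

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class BenchmarkResult:
    time_s: float
    cpu_peak_mb: int
    gpu_peak_mb: int

def measure(fn: Callable[[], None]) -> BenchmarkResult:
    """Full measurement returning the rich result object (details omitted in this sketch)."""
    ...

def measure_gpu_peak_mb(fn: Callable[[], None]) -> int:
    """Shortcut wrapper: a test can assert on one number, no unpacking needed."""
    return measure(fn).gpu_peak_mb
```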
Ah, one more thing we discussed: we need to return general RAM when the benchmark is run on GPU, since memory leaks mainly happen in general RAM. So the main API should measure and return this data too, with flags/shortcuts to enable/disable its calculation and retrieval.
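A sketch of tracking general (CPU) RAM around a GPU run, assuming `psutil` is available; this is illustrative, not the benchmark's implementation:

```python
import psutil
import torch

proc = psutil.Process()
rss_before = proc.memory_info().rss

# stand-in for the GPU workload being benchmarked
out = torch.randn(1024, 1024, device="cuda") @ torch.randn(1024, 1024, device="cuda")
torch.cuda.synchronize()

rss_after = proc.memory_info().rss
print(f"general RAM delta: {(rss_after - rss_before) >> 20} MB")
```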
2 more things to consider: … , as these settings should impact performance.
It appears that the profiler has been around for quite some time. Of course, its table dump is huge and difficult to work with, and there are almost no examples of profiler use out there. Here is what I came up with so far:

…

which gives:

…

I'm not sure yet whether any of this is correct - I need to correlate it with our profiling functions. I figured out the self vs. total distinction: the profiler results have a set of `self_*` columns that exclude time/memory spent in child operators, whereas the total columns include it. So we only care about total then.

If I add the measurements from our benchmark tools for comparison, I get mostly very different results - I ran these on a T4, unlike the ones above, which were run on a TitanX:

…

The last row labelled "Ours" is the benchmark's GPU measurement. As you can see, only the speed measurements for … are comparable.

If anybody wants to continue experimenting, this is a WIP colab nb: … It has a bunch of other experiments, but if you run all of them, just watch the results of the last cell. Warning: the code is unpolished.
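For reference, one way to get condensed numbers out of the profiler instead of the raw per-call dump; a sketch using the 1.6-era autograd profiler API (`key_averages`, `self_cpu_time_total`), not necessarily the exact code from the colab above:

```python
import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-cased").cuda()
inputs = torch.zeros((32, 128), dtype=torch.long, device="cuda")

with torch.autograd.profiler.profile(use_cuda=True, profile_memory=True) as prof:
    model(inputs)

# averaged per-operator view instead of the full per-call table
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))

# one aggregate number: sum of "self" CPU times, so children are not double-counted
print(f"total self CPU time: {prof.self_cpu_time_total / 1e3:.1f} ms")
```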
One additional potential caveat for speed measurements: could async code be returning too early, i.e. before the queued CUDA kernels have actually finished?
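For reference, the usual guard against this is to synchronize before reading the clock so the asynchronously queued CUDA kernels have actually finished; a sketch, not the benchmark's code:

```python
import time
import torch

x = torch.randn(4096, 4096, device="cuda")

torch.cuda.synchronize()
start = time.perf_counter()
y = x @ x                      # queued asynchronously; returns almost immediately
torch.cuda.synchronize()       # without this, the measured time would be far too small
print(f"{(time.perf_counter() - start) * 1e3:.2f} ms")
```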
Thanks for posting this! Yes, I think we should maybe stick to our code for now regarding the total time and memory.
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
Currently, the benchmarking tools make use of multi-processing to be sure that all memory is released after each measurement, and use the `py3nvml` library to measure "peak GPU usage".

After some internal discussion, it is questionable whether the current code gives peak GPU memory usage. Thus, I ran a couple of experiments to see how torch benchmarking differs from `py3nvml`. It is known that there are differences in the memory benchmarking, as explained here: https://stackoverflow.com/questions/62257967/why-does-a-single-conv2d-with-10x10x3-take-up-850mb-of-gpu#_=_

For the comparison, the following command was run:

…

The environment information is the following:

…

a) These are the results when running the command with the current code (`py3nvml`):

…

b) These are the results when using the function `torch.cuda.max_memory_reserved(torch.cuda.current_device())` instead:

…

One can see that the difference is always 856 MB (besides one exception where it is 868 MB). I ran the `py3nvml` benchmark multiple times and the result is very stable. The same holds true when benchmarking training.

=> I tend to think that, the way the code is currently implemented, it actually gives the peak memory usage, even though I could not find proof in the https://github.com/fbcotter/py3nvml library.
@stas00 - what is your opinion on that?
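As an aside, a rough sketch of the "measure in a child process" pattern mentioned above, so that all memory is guaranteed to be released once the measurement ends; illustrative only, not the actual transformers implementation:

```python
import multiprocessing as mp

def _measure(queue):
    import torch
    torch.ones((1, 1)).cuda()            # workload to measure
    torch.cuda.synchronize()
    queue.put(torch.cuda.max_memory_reserved())

if __name__ == "__main__":
    ctx = mp.get_context("spawn")        # child starts with its own, fresh CUDA context
    queue = ctx.Queue()
    proc = ctx.Process(target=_measure, args=(queue,))
    proc.start()
    peak = queue.get()
    proc.join()                          # all GPU memory is freed when the child exits
    print(f"peak reserved in child process: {peak >> 20} MB")
```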