No perf advantage for torch.compile on examples from pytorch tutorial #1721

Open
dvrogozh opened this issue Jul 29, 2024 · 4 comments
Labels: bug (Something isn't working), performance

@dvrogozh (Contributor) commented Jul 29, 2024

I am trying the PyTorch tutorial for torch.compile() (https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#demonstrating-speedups), adapting it for the xpu backend by s/cuda/xpu. I am using pytorch/pytorch@f063027. The tutorial has performance examples demonstrating the advantage of torch.compile over eager mode on NVIDIA. Unfortunately, I don't observe similar benefits for xpu: torch.compile runs at roughly the same speed as eager mode. Are there any optimizations currently missing for XPU that affect these tutorials? This occurs for both examples in the tutorial: inference and training.
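
A minimal sketch of that substitution, assuming a PyTorch build with XPU support (illustrative only; the full adapted inference script is below):

import torch

# Device-agnostic form of the s/cuda/xpu change: the tutorial's .cuda() and
# torch.cuda.synchronize() calls become their XPU counterparts.
device = "xpu" if torch.xpu.is_available() else "cuda"

x = torch.randn(4, 3, device=device)   # tutorial: torch.randn(4, 3).cuda()
if device == "xpu":
    torch.xpu.synchronize()            # tutorial: torch.cuda.synchronize()
else:
    torch.cuda.synchronize()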

Results (inference):

eager eval time 0: 1.468490231
eager eval time 1: 0.016250838
eager eval time 2: 0.015404673
eager eval time 3: 0.01476964
eager eval time 4: 0.014657789
eager eval time 5: 0.014552059
eager eval time 6: 0.014473312
eager eval time 7: 0.014476375
eager eval time 8: 0.014540959
eager eval time 9: 0.014519486
~~~~~~~~~~
compile eval time 0: 30.085278137
compile eval time 1: 0.016572904
compile eval time 2: 0.015478853
compile eval time 3: 0.015368476
compile eval time 4: 0.015215709
compile eval time 5: 0.015356365
compile eval time 6: 0.015324649
compile eval time 7: 0.015410529
compile eval time 8: 0.015309956
compile eval time 9: 0.015434349
~~~~~~~~~~
Traceback (most recent call last):
  File "/home/dvrogozh/examples/torch/tutorials/ex5.py", line 63, in <module>
    assert(speedup > 1)
AssertionError

Script (inference):

import time
import torch

# Measure end-to-end wall-clock time of fn(); the timer stops only after
# torch.xpu.synchronize(), so queued device work is included.
def timed(fn):
    start = time.time_ns()
    result = fn()
    torch.xpu.synchronize()
    return result, (time.time_ns() - start) / 1_000_000_000  # ns -> s

# Generates random input and targets data for the model, where `b` is
# batch size.
def generate_data(b):
    return (
        torch.randn(b, 3, 128, 128).to(torch.float32).xpu(),
        torch.randint(1000, (b,)).xpu(),
    )

N_ITERS = 10

from torchvision.models import densenet121
def init_model():
    return densenet121().to(torch.float32).xpu()

model = init_model()

# Reset since we are using a different mode.
import torch._dynamo
torch._dynamo.reset()

model_opt = torch.compile(model, mode="reduce-overhead")

eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, eager_time = timed(lambda: model(inp))
    eager_times.append(eager_time)
    print(f"eager eval time {i}: {eager_time}")

print("~" * 10)

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, compile_time = timed(lambda: model_opt(inp))
    compile_times.append(compile_time)
    print(f"compile eval time {i}: {compile_time}")
print("~" * 10)

import numpy as np
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert(speedup > 1)
print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)

Note that I changed the tutorial's def timed implementation to measure end-to-end time, due to pytorch/pytorch#131840. Also note that I did try applying pytorch/pytorch#126456; it did not change the performance results for the XPU backend.
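
For reference, the tutorial's original timed() is based on CUDA events, roughly as below (paraphrased, not verbatim); this is what the wall-clock version above replaces:

import torch

# CUDA-event based timing as in the tutorial; elapsed_time() returns milliseconds,
# so divide by 1000 to report seconds.
def timed_cuda(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000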

@alexbaden (Contributor) commented:

I am not getting the same results (latest llvm-target branch, LTS driver, and pytorch/pytorch@75f64e1):

» python ex5.py                                                                                              
eager eval time 0: 1.645833594
eager eval time 1: 0.133093097
eager eval time 2: 0.133950921
eager eval time 3: 0.144729233
eager eval time 4: 0.129245809
eager eval time 5: 0.129086332
eager eval time 6: 0.123269756
eager eval time 7: 0.134237029
eager eval time 8: 0.12613883
eager eval time 9: 0.132319863
~~~~~~~~~~
(I): Detected 2048 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 2048 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 4096 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 4096 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 4096 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 2048 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 1024 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 1024 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 2048 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 2048 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 4096 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 8192 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 8192 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 1024 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 2048 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 2048 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 4096 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 8192 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 8192 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 1024 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 2048 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 2048 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 4096 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 8192 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
(I): Detected 8192 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
compile eval time 0: 100.98287199
compile eval time 1: 0.0182211
compile eval time 2: 0.015616999
compile eval time 3: 0.01512506
compile eval time 4: 0.015049746
compile eval time 5: 0.015074743
compile eval time 6: 0.015032466
compile eval time 7: 0.015071969
compile eval time 8: 0.015028161
compile eval time 9: 0.014950526
~~~~~~~~~~
(eval) eager median: 0.13270648000000002, compile median: 0.015073356, speedup: 8.804043372955567x

Perhaps there is some logging we can enable to find the difference? Can you try running with TORCH_LOGS="+dynamo" TORCHDYNAMO_VERBOSE=1?
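
For example, one possible invocation of the script above:

TORCH_LOGS="+dynamo" TORCHDYNAMO_VERBOSE=1 python ex5.py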

I did not have time to rebuild PyTorch just now, but I can also try the PyTorch commit you used, though at first glance mine is much older.

@dvrogozh (Contributor, Author) commented Jul 29, 2024

@alexbaden: you did not reproduce my eager mode results, but your torch.compile results are similar to mine. Your pytorch version is quite old, and I think eager mode simply falls back to CPU on some aten ops (silently, because you are also missing intel/torch-xpu-ops#318). You are missing at least the following torch-xpu-ops updates, which implemented a lot of aten ops:

$ git log --oneline 75f64e12030dfa6f621f1ec2b207892cf8660cdd..remotes/origin/main -- third_party/xpu.txt
dfba85c26bf Update torch-xpu-ops pin (ATen XPU implementation) (#131643)
b556d315868 Update torch-xpu-ops pin (ATen XPU implementation) (#131015)
cf090e222ea Update torch-xpu-ops pin (ATen XPU implementation) (#130333)
e98587c58d3 Update torch-xpu-ops pin (ATen XPU implementation) (#129353)

Update, FYI: I tried pytorch/pytorch@75f64e1 plus intel/torch-xpu-ops#318. The following eager aten ops fall back to CPU: aten::native_batch_norm, aten::max_pool2d_with_indices.out, aten::avg_pool2d.out, aten::_adaptive_avg_pool2d.
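
One way to double-check which of those ops lack a native XPU kernel is the sketch below; note it relies on a private dispatcher API (torch._C._dispatch_has_kernel_for_dispatch_key), so treat it as a diagnostic hack rather than a supported interface:

import torch

ops = (
    "aten::native_batch_norm",
    "aten::max_pool2d_with_indices.out",
    "aten::avg_pool2d.out",
    "aten::_adaptive_avg_pool2d",
)
for op in ops:
    # True if an XPU kernel is registered for the op; False suggests eager mode
    # has to fall back to CPU for it.
    has_xpu = torch._C._dispatch_has_kernel_for_dispatch_key(op, "XPU")
    print(op, "-> XPU kernel registered" if has_xpu else "-> no XPU kernel (CPU fallback)")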

@alexbaden (Contributor) commented:

Got it, that makes sense. Let me update PyTorch to latest main and try again.

@dvrogozh (Contributor, Author) commented Aug 9, 2024

See #1770 for a potential fix.
