
Custom propagate is not taken into account by the jit compiler in PyG 2.5.0 and 2.5.1 #9077

Closed
migalkin opened this issue Mar 20, 2024 · 11 comments

@migalkin
Contributor

migalkin commented Mar 20, 2024

🐛 Describe the bug

Setup: ULTRA, and in particular its Generalized Relational Convolution layer with a custom rspmm CUDA/CPU kernel, works flawlessly with PyTorch 2.1 and PyG 2.4.

After updating the environment to torch 2.2.1 and PyG 2.5.0 / 2.5.1, the JIT compilation no longer takes the custom propagate function implemented in the layer into account. The compiled layer file in ~/.cache/pyg/message_passing/ is generated from the original propagate function of MessagePassing and never invokes the custom propagate.

This triggers a number of follow-up errors, such as missing index and dim_size kwargs for the aggregate function, which are normally collected by self._collect.

Moreover, even after explicitly passing all the necessary kwargs in the self.propagate call, inference time on the standard FB15k-237 dataset increases from 3 s to 180 s (on an M2 Max laptop) 📉. A few questions, therefore:

  • Is it expected behavior that the JIT compiler ignores a custom propagate function in the layer?
  • Is JIT compilation now enabled by default for all PyG models?
  • Is there a way to disable JIT compilation of PyG models (at least for models with custom kernels)?
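
For context, a minimal sketch of the pattern involved (an illustrative layer, not ULTRA's actual code): a MessagePassing subclass that overrides propagate to dispatch to its own kernel, which is the override the generated 2.5 template appears to bypass:

import torch
from torch_geometric.nn import MessagePassing

class CustomPropagateConv(MessagePassing):
    # Hypothetical layer, for illustration only.
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr="add")
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=self.lin(x))

    def propagate(self, edge_index, x, size=None):
        # PyG 2.4 calls this override; a real layer would dispatch to a
        # fused rspmm kernel here instead of the default message/aggregate path.
        return super().propagate(edge_index, x=x, size=size)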

Thanks for looking into that!

P.S. PyG 2.5.0 compiles layers into ~/.cache/pyg/message_passing, while 2.5.1 compiles into /var/folder/<some_gibberish> - is that expected?

Versions

PyTorch version: 2.2.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.2.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.1.0.2.5)
CMake version: version 3.25.3
Libc version: N/A

Python version: 3.9.18 | packaged by conda-forge | (main, Dec 23 2023, 16:35:41)  [Clang 16.0.6 ] (64-bit runtime)
Python platform: macOS-14.2.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M2 Max

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.2.1
[pip3] torch_cluster==1.6.3
[pip3] torch_geometric==2.5.0
[pip3] torch_scatter==2.1.2
[pip3] torch_sparse==0.6.18
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] torch                     2.2.1                    pypi_0    pypi
[conda] torch-cluster             1.6.3                    pypi_0    pypi
[conda] torch-geometric           2.5.0                    pypi_0    pypi
[conda] torch-scatter             2.1.2                    pypi_0    pypi
[conda] torch-sparse              0.6.18                   pypi_0    pypi
@migalkin migalkin added the bug label Mar 20, 2024
@migalkin migalkin changed the title Custom propagate is not taken into account by the jit compiler in PyG 2.5.1 Custom propagate is not taken into account by the jit compiler in PyG 2.5.0 and 2.5.1 Mar 20, 2024
@rusty1s
Member

rusty1s commented Mar 20, 2024

Mh, thanks for bringing this up. Can you do me a favor and check if #9079 resolves your issues?

@migalkin
Contributor Author

Yes, that fixes the JIT issue!

The downside is that inference is still 2x slower than with 2.4, even though the code changes are minimal (replacing inspector.distribute() from older versions with inspector.collect_param_data()).
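
For reference, the rename amounts to a one-line change inside the custom propagate (a hedged sketch; coll_dict and msg_kwargs are illustrative names, and the actual call sites live in the ULTRA layer):

# PyG <= 2.4.x:
msg_kwargs = self.inspector.distribute('message', coll_dict)

# PyG >= 2.5.0:
msg_kwargs = self.inspector.collect_param_data('message', coll_dict)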

2.4 version:
[screenshot: tqdm inference timing]

#9079 branch:
[screenshot: tqdm inference timing]

Perhaps there are other moving parts involved that changed in the newer version; it's hard to say more without profiling.

@rusty1s
Member

rusty1s commented Mar 21, 2024

Can you share some information on how I can benchmark this?

@migalkin
Contributor Author

migalkin commented Mar 21, 2024

Sure, I created the pyg2.5 branch in the repo: https://github.com/DeepGraphLearning/ULTRA/tree/pyg2.5

From there, I run this:

python script/run.py -c config/transductive/inference.yaml --dataset FB15k237 --epochs 0 --bpe null --gpus null --ckpt /<your path to the repos>/ULTRA/ckpts/ultra_4g.pth --bs 64

and look at the tqdm stats.

@rusty1s
Member

rusty1s commented Mar 25, 2024

Thanks. Will try to reproduce, and create PyG 2.5.3 afterwards.

@rusty1s
Member

rusty1s commented Mar 26, 2024

Thanks. I looked into this. I cannot spot any performance degradation within MessagePassing, but your generalized rspmm kernel is a lot slower, probably because of a changed feature dimension:

  • on main: output of generalized rspmm is torch.Size([14541, 512])
  • on pyg2.5: output of generalized rspmm is torch.Size([14541, 4096])

@migalkin
Contributor Author

Ah, the default batch size in the main branch's config/transductive/inference.yaml is 8 instead of 64 (hence the flattened feature dimension is 64 × batch_size = 64 × 8 = 512). Could you please try with batch size 64 on main?
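
For the record, a quick sanity check of the flattened dimension (assuming hidden_dim = 64 and the flattening described above; plain arithmetic, not ULTRA code):

num_nodes, hidden_dim = 14541, 64   # FB15k-237 entities, ULTRA hidden size

batch_size = 8    # default in main's config/transductive/inference.yaml
assert hidden_dim * batch_size == 512    # -> torch.Size([14541, 512])

batch_size = 64   # value used on the pyg2.5 branch benchmark
assert hidden_dim * batch_size == 4096   # -> torch.Size([14541, 4096])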

@rusty1s
Member

rusty1s commented Mar 27, 2024

Mh, for me both models take equally long to run (relation_model around 0.16s and entity_model around 3.51s), and tqdm outputs equal runtime as well.

@rusty1s
Member

rusty1s commented Mar 27, 2024

Can you double check on your end? What does

import time

import torch

torch.cuda.synchronize()  # finish any pending GPU work before timing
t = time.perf_counter()
t_pred = model(test_data, t_batch)
h_pred = model(test_data, h_batch)
torch.cuda.synchronize()  # wait for both forward passes to complete
print(time.perf_counter() - t)

return for you in both versions?

@migalkin
Contributor Author

Confirmed, on GPUs both versions take pretty much the same time (running on GPUs is the most standard setup anyway):

  • main: [screenshot: GPU timing output]
  • pyg2.5: [screenshot: GPU timing output]

The slowdown is then probably due to some issue with newer PyTorch versions on M1/M2, but that's definitely out of scope for this issue.

TL;DR for future readers: the custom propagate issue was fixed, and PyG 2.5 does not affect performance on GPUs 🎉

Thanks for looking into this!

@rusty1s
Member

rusty1s commented Mar 28, 2024

🙏
