Hetero R-GCN not matching OGB MAG Leaderboard performance #3511
Probably related to the effort in #3371.
I have some additional information. One thing the PyG implementation was doing differently is that instead of reversing [...]. Another significant difference is that in the dgl rgcn_hetero example, there is a single projection matrix for the [...].
One peculiar thing I've noticed is that the training accuracy for PyG seems to reach a maximum of around 80%, while DGL is getting >90%. This suggests DGL is overfitting more, yet PyG doesn't appear to have any additional forms of regularization that would explain the gap.
Another update: I'm now able to slightly beat PyG. The only change was in capturing performance statistics in the same way. Previously I was reporting performance at the end of the 3 epochs, but they actually report train/test performance of the epoch in which the validation accuracy was maximized (code). However, there still remains a discrepancy: the training accuracy is still significantly higher for DGL and it appears to overfit with fewer epochs. This may be an artifact of differences in initialization methods. I'll investigate this next.
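For illustration, a minimal sketch of that reporting convention, with placeholder training/evaluation helpers standing in for the real loops:

```python
import random

def train_one_epoch():
    # Placeholder for the real training loop; returns training accuracy.
    return random.random()

def evaluate(split):
    # Placeholder for evaluation on the "valid" or "test" split; returns accuracy.
    return random.random()

best_val, best_train, best_test = 0.0, 0.0, 0.0
for epoch in range(3):
    train_acc = train_one_epoch()
    val_acc, test_acc = evaluate("valid"), evaluate("test")
    # Keep the train/test accuracy from the epoch with the best validation accuracy,
    # rather than reporting whatever the final epoch happened to produce.
    if val_acc > best_val:
        best_val, best_train, best_test = val_acc, train_acc, test_acc

print(f"train={best_train:.4f}  valid={best_val:.4f}  test={best_test:.4f}")
```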
There were discrepancies between the methods used for initialization (Kaiming for PyG linear layers, Xavier for DGL linear layers). I have corrected this, but it did not have a significant impact on the results, particularly the training overfitting.
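A minimal sketch of what aligning the two schemes might look like, e.g. forcing Xavier-uniform initialization on a PyTorch linear layer (the layer sizes below are only illustrative):

```python
import torch.nn as nn

lin = nn.Linear(128, 64)              # PyTorch defaults to Kaiming-uniform here
nn.init.xavier_uniform_(lin.weight)   # switch to Xavier to match the DGL layers
nn.init.zeros_(lin.bias)
```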
Next I'll investigate the size of the gradients.
Interestingly, PyG seems to have larger gradients later in the training process than DGL. I compared the average L2 norm of the gradient for each parameter tensor across an epoch. At the end of the first epoch, the gradients are about the same magnitude; however, by the end of the 3rd epoch, the gradients are significantly larger for PyG. This is likely just a reflection of the training-loss differences: for both DGL and PyG, the average training loss is about 2.3 at the end of epoch 1, but at the end of epoch 3 it is 0.5 for DGL and 1.4 for PyG. I assume PyG's larger loss explains the larger gradients, but I'm still not clear why the training loss is so much lower for DGL.
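A minimal sketch of how such per-parameter gradient norms could be accumulated over an epoch (the helper name and bookkeeping are illustrative, not from the original code):

```python
from collections import defaultdict

def accumulate_grad_norms(model, sums, counts):
    # Call once per minibatch, after loss.backward(), to record the L2 norm
    # of the gradient of every parameter tensor.
    for name, param in model.named_parameters():
        if param.grad is not None:
            sums[name] += param.grad.detach().norm(p=2).item()
            counts[name] += 1

grad_sums, grad_counts = defaultdict(float), defaultdict(int)
# ... the training loop calls accumulate_grad_norms(model, grad_sums, grad_counts) ...
# At the end of an epoch, grad_sums[name] / grad_counts[name] is the average
# L2 norm of that parameter's gradient across the epoch.
```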
Hi @zjost, could you try it with CPU? One possible reason is that PyG's CUDA kernel uses atomic operations for reduction, which can be nondeterministic and might hurt convergence. Testing on CPU should be less affected, or at least not affected by the kernel implementation.
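A minimal sketch of what such a CPU check might look like (the model/graph variable names are placeholders):

```python
import torch

torch.manual_seed(0)          # fix the seed so repeated runs are comparable
device = torch.device("cpu")  # avoid nondeterministic GPU atomics in the reduction kernels
# model = model.to(device)    # placeholder: move the model ...
# g = g.to(device)            # ... and the DGL graph to CPU before training
```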
Taking a look at your implementation and the PyG example, there may be a few additional potential implementation discrepancies: [...]
Thanks for the tips! I'll address the results of each item: [...]

The result of all this is that the distribution of Training Accuracy at the end of the 3rd epoch is now [...]. Next I'll try to manually run some forward/backward passes using a handful of nodes to see if there are any other clues. Thanks again for all the great suggestions--I would have never found the "model.train" thing.
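For reference, a minimal sketch of the model.train()/model.eval() toggling mentioned above; it matters because dropout is only active in training mode (the tiny model here is illustrative):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.5))  # illustrative model

model.train()                 # dropout active during training steps
# ... forward / backward / optimizer.step() ...

model.eval()                  # dropout disabled for validation and test
with torch.no_grad():
    out = model(torch.randn(4, 8))
```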
A few more discoveries. If I manually set identical parameters between the DGL and PyG models and pass the same minibatch through (with all neighbors used, no sampling), I get identical outputs and gradients for all the parameters. This suggests there's no fundamental computation problem in the forward/backward pass. One difference I've noticed is that the numbers of sampled neighbors do not match. This seems to be because with a fanout of [...]
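A minimal sketch of that kind of parity check, assuming the two models expose their parameters in the same order (the helper names are illustrative):

```python
import torch

@torch.no_grad()
def copy_parameters(src, dst):
    # Make both models start from identical weights.
    for p_src, p_dst in zip(src.parameters(), dst.parameters()):
        p_dst.copy_(p_src)

def grads_match(model_a, model_b, atol=1e-6):
    # After calling loss.backward() on both models with the same minibatch,
    # check that every parameter gradient agrees within tolerance.
    return all(
        torch.allclose(pa.grad, pb.grad, atol=atol)
        for pa, pb in zip(model_a.parameters(), model_b.parameters())
        if pa.grad is not None and pb.grad is not None
    )
```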
Final update. I found that if I scaled the fan-out numbers in DGL by the number of (incoming) edge types, I get about the same neighborhood sizes. If I re-run 10 training iterations using [...]:

This shows for the first time that the training accuracy at the end of 3 epochs matches the PyG results, which adds credibility to the idea that the training data was overfitting more quickly in DGL because more nodes were involved in the computation and were therefore getting embedding gradients. However, the final test accuracy has dropped significantly, from 48% to 45%. I suspect the difference is that sampling N nodes for each edge type is different than sampling the same total number of nodes irrespective of edge type.

Final question: would there be value in me submitting a PR for this example that uses the Leaderboard architecture and achieves the higher final test accuracy? Beyond answering that, feel free to close this issue.
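For context, a sketch of how per-edge-type fan-outs can be made explicit in DGL's neighbor sampler; the edge types and numbers below are illustrative, not the exact settings used in these experiments:

```python
from dgl.dataloading import MultiLayerNeighborSampler

# In DGL, an integer fan-out applies per edge type, so a node with k incoming
# edge types can draw up to k * fanout neighbors per layer. Passing a dict
# keyed by canonical edge type makes that budget explicit.
fanout = {
    ("author", "writes", "paper"): 25,
    ("paper", "cites", "paper"): 25,
    ("paper", "has_topic", "field_of_study"): 25,
}
sampler = MultiLayerNeighborSampler([fanout, fanout])  # one dict per GNN layer
```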
Hi @zjost, thanks for the great effort. It would be greatly appreciated if you could submit a PR adding an example for ogbn-mag in this folder.
Closing this issue since I created a PR. Thanks again for the help!
🐛 Bug
In an effort to directly reproduce the "NeighborSampling (R-GCN aggr)" results using DGL, I get a test performance of 40.17 ± 0.62, compared to 46.85 ± 0.48 for the PyG code linked on the leaderboard.
To Reproduce
Steps to reproduce the behavior:
My implementation, adapted from examples/pytorch/rgcn-hetero/entity_classify_mb.py, can be found in this repo. Run with python rgcn_hetero_dgl.py.
Expected behavior
I made an effort to exactly duplicate the PyG model architecture and parameter choices, so I expect to get comparable performance.
Environment
Training on a single GPU.