
Hetero R-GCN not matching OGB MAG Leaderboard performance #3511

Closed
zjost opened this issue Nov 16, 2021 · 12 comments
Labels
bug:unconfirmed May be a bug. Need further investigation.

Comments

@zjost
Contributor

zjost commented Nov 16, 2021

🐛 Bug

In an effort to directly reproduce the "NeighborSampling (R-GCN aggr)" results using DGL, I get a test accuracy of 40.17 ± 0.62, compared to 46.85 ± 0.48 for the PyG code linked on the leaderboard.

To Reproduce

Steps to reproduce the behavior:

  1. I adapted the example in the DGL examples directory, examples/pytorch/rgcn-hetero/entity_classify_mb.py. My implementation can be found in this repo; run it with python rgcn_hetero_dgl.py

Expected behavior

I made an effort to exactly duplicate the PyG model architecture and parameter choices, so I expected comparable performance.

Environment

Training on single GPU.

  • DGL 0.7.0 (pip)
  • PyTorch 1.9.0+cu111
  • Python 3.6.9
  • Linux OS
@jermainewang jermainewang added the bug:unconfirmed (May be a bug. Need further investigation.) label on Nov 17, 2021
@jermainewang
Member

Probably related to the effort in #3371

@zjost
Contributor Author

zjost commented Nov 17, 2021

I have some additional information. One thing the PyG implementation does differently is that instead of reversing (paper, cites, paper) into a new relation like (paper, rev-cites, paper), it simply makes the citation edges undirected (code). Making this change increased my test accuracy from 40 to ~44.
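For reference, a minimal sketch of what that change looks like on the DGL side; the citation endpoints would normally come from the ogbn-mag edge index, but dummy tensors are used here so the snippet is self-contained:

```python
# Keep ('paper', 'cites', 'paper') but make it undirected by adding the reverse
# direction under the same edge type, rather than adding a 'rev-cites' relation.
import torch
import dgl

# In the real setup these come from the ogbn-mag citation edge index.
cites_src = torch.tensor([0, 1, 2])
cites_dst = torch.tensor([1, 2, 0])

src = torch.cat([cites_src, cites_dst])
dst = torch.cat([cites_dst, cites_src])

g = dgl.heterograph({('paper', 'cites', 'paper'): (src, dst)})
print(g)  # 3 paper nodes, 6 directed edges encoding 3 undirected citations
```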

Another significant difference is that in the DGL rgcn-hetero example there is a single projection matrix, self.loop_weight (code), shared across all self-connections, whereas PyG uses a separate linear layer for each node type (code). To be explicit, there is a separate relational linear layer for each edge type and a separate "self" linear layer for each node type (a rough sketch appears after the table below). With this setup, I'm able to match the number of model parameters between the PyG and DGL implementations. However, over a trial of 10 runs, performance is as follows:

      train_acc  valid_acc  test_acc
mean   0.932635   0.457381  0.438895
std    0.001636   0.003455  0.003325
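The parameterization described above looks roughly like this (a hypothetical module, not the exact code from my repo; ntypes and etypes are assumed to be lists of node-type and edge-type name strings):

```python
import torch.nn as nn

# One linear layer per edge type for the relational messages, plus one "self"
# linear layer per node type, instead of a single shared self.loop_weight.
class HeteroRGCNLayerParams(nn.Module):  # hypothetical name
    def __init__(self, ntypes, etypes, in_feats, out_feats):
        super().__init__()
        self.rel_linears = nn.ModuleDict(
            {etype: nn.Linear(in_feats, out_feats) for etype in etypes})
        self.self_linears = nn.ModuleDict(
            {ntype: nn.Linear(in_feats, out_feats) for ntype in ntypes})
```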

One peculiar thing I've noticed is that the training accuracy for PyG seems to reach a maximum of around 80%, while DGL is getting >90%. This suggests a larger degree of over-fitting in the DGL model, but PyG doesn't appear to use any additional regularization.

@zjost
Contributor Author

zjost commented Nov 18, 2021

Another update: I'm now able to slightly beat PyG. The only change was capturing performance statistics in the same way. Previously I was reporting performance after the final (3rd) epoch, but they actually report the train/test performance from the epoch in which the validation accuracy was maximized (code).
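In pseudocode, that reporting scheme looks roughly like this (train_one_epoch_and_eval is a hypothetical stand-in for one epoch of training followed by evaluation on the three splits):

```python
# Record train/test accuracy from the epoch with the best validation accuracy,
# rather than from the final epoch. train_one_epoch_and_eval() is hypothetical.
best_valid, best_train, best_test = 0.0, 0.0, 0.0
for epoch in range(3):
    train_acc, valid_acc, test_acc = train_one_epoch_and_eval()
    if valid_acc > best_valid:
        best_valid, best_train, best_test = valid_acc, train_acc, test_acc
```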

However, there still remains a discrepancy: the training accuracy is still significantly higher for DGL and it appears to overfit with fewer epochs. This may be an artifact of differences in initialization methods. I'll investigate this next.

Highest Train: 93.37 ± 0.20
Highest Valid: 49.03 ± 0.29
  Final Train: 76.15 ± 11.98
   Final Test: 47.66 ± 0.47

@zjost
Contributor Author

zjost commented Nov 18, 2021

There were discrepancies between the methods used for initialization (Kaiming for PyG linear layers, Xavier for DGL linear layers). I have corrected this, but it did not have a significant impact on the results, particularly the training overfitting.

Highest Train: 93.68 ± 0.15
Highest Valid: 49.09 ± 0.48
  Final Train: 72.34 ± 11.93
   Final Test: 47.99 ± 0.68
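Roughly, aligning the initialization amounts to switching the DGL-side linear weights to the Kaiming-uniform scheme that torch.nn.Linear uses by default (my assumption about what the PyG layers end up with):

```python
import math
import torch.nn as nn

def init_linear_kaiming(linear: nn.Linear):
    # Same scheme as torch.nn.Linear's default reset_parameters().
    nn.init.kaiming_uniform_(linear.weight, a=math.sqrt(5))
    if linear.bias is not None:
        fan_in = linear.weight.size(1)
        bound = 1.0 / math.sqrt(fan_in) if fan_in > 0 else 0.0
        nn.init.uniform_(linear.bias, -bound, bound)
```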

Next I'll investigate the size of the gradients.

@zjost
Contributor Author

zjost commented Nov 18, 2021

Interestingly, PyG seems to have larger gradients later in the training process than DGL. Below I compare the average L2 norm of the gradient for each parameter Tensor across an epoch. At the end of the first epoch, gradients are about the same magnitude:

[Figure: average per-parameter gradient L2 norms at the end of epoch 1]

However, by the end of the 3rd epoch, gradients are significantly larger for PyG.

[Figure: average per-parameter gradient L2 norms at the end of epoch 3]

This is likely just a reflection of the training loss differences. For both DGL and PyG, the training loss average is about 2.3 at the end of epoch 1. However, at the end of epoch 3, the training loss is 0.5 for DGL and 1.4 for PyG. I assume PyG's larger loss explains the larger gradients, but I'm still not clear why the training loss is so much lower for DGL.
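For reference, the measurement was along these lines (a sketch; batches and loss_fn are hypothetical stand-ins for the minibatch iterator and the loss computation):

```python
from collections import defaultdict

def average_grad_norms(model, batches, loss_fn):
    """Average L2 norm of each parameter tensor's gradient over one epoch (sketch)."""
    sums, count = defaultdict(float), 0
    for batch in batches:
        model.zero_grad()
        loss = loss_fn(model, batch)  # hypothetical: forward pass + loss
        loss.backward()
        count += 1
        for name, p in model.named_parameters():
            if p.grad is not None:
                sums[name] += p.grad.detach().norm(2).item()
    return {name: s / count for name, s in sums.items()}
```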

@VoVAllen
Collaborator

VoVAllen commented Nov 19, 2021

@zjost Hi, could you try it on CPU?

One possible reason is that PyG's CUDA kernel uses atomic operations for reduction, which can be nondeterministic and might hurt convergence. When testing on CPU, the results should be less affected, or at least not affected by the kernel implementation.

@mufeili
Member

mufeili commented Nov 19, 2021

Taking a look at your implementation and the PyG example, there are a few additional potential implementation discrepancies:

  • The to_undirected function here removes duplicate edges, if any, while your implementation does not.
  • You passed split_idx['train'] to NodeDataLoader while the PyG example passed split_idx['train']['paper'] to NeighborSampler. It's likely split_idx['train'] only contains the paper node IDs, though.
  • You did not call model.train() after the first time you called model.eval() in test, which matters when you use dropout (see the sketch after this list).
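On the last point, a minimal sketch of the intended mode switching (hypothetical helper names, not the code from the example):

```python
import torch

def run_training(model, train_loader, evaluate, num_epochs=3):
    for epoch in range(num_epochs):
        model.train()            # re-enable dropout after the previous eval pass
        for batch in train_loader:
            ...                  # forward / backward / optimizer.step()
        model.eval()             # dropout disabled for evaluation
        with torch.no_grad():
            evaluate(model)
```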

@zjost
Contributor Author

zjost commented Nov 19, 2021

Thanks for the tips! I'll address the results of each item:

  • Using CPU vs GPU had no observable effect
  • It's correct that I had duplicate edges (about 40k of them). Removing them did not have an observable effect, but I will keep the deduplication.
  • split_idx['train'] only includes paper nodes, so there is no difference that I can see
  • It's correct that model.train() should have been called at the start of each epoch, since the test function switched to model.eval() and effectively turned off dropout. This reduced the training accuracy at the end of 3 epochs from 94 to 85.
    • This also explains why things would be more or less the same after epoch 1 but then begin to overfit in subsequent epochs, since dropout was effectively disabled.

The result of all this is that the training accuracy at the end of the 3rd epoch is now 84.38 ± 0.91, compared to 80.03 ± 0.40 for PyG. The difference is still statistically significant, but much closer than before.

Next I'll try to manually run some forward/backward passes using a handful of nodes to see if there are any other clues. Thanks again for all the great suggestions; I would never have found the model.train() issue on my own.

@zjost
Contributor Author

zjost commented Nov 20, 2021

A few more discoveries. If I manually set identical parameters in the DGL and PyG models and pass the same minibatch through both (with all neighbors used, no sampling), I get identical outputs and gradients for all the parameters. This suggests there is no fundamental computation problem in the forward/backward pass.
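The check was along these lines (a sketch; the mapping between the two models' parameter names is a hypothetical placeholder that has to be built by hand):

```python
import torch

def copy_matching_params(src_model, dst_model, name_map):
    """Copy src parameters into dst; name_map maps dst parameter names to src names."""
    src_params = dict(src_model.named_parameters())
    with torch.no_grad():
        for dst_name, dst_param in dst_model.named_parameters():
            dst_param.copy_(src_params[name_map[dst_name]])

def assert_same_grads(model_a, model_b, name_map, atol=1e-6):
    """Compare gradients after both models ran backward on the same minibatch."""
    grads_a = {n: p.grad for n, p in model_a.named_parameters()}
    for name_b, p_b in model_b.named_parameters():
        assert torch.allclose(grads_a[name_map[name_b]], p_b.grad, atol=atol), name_b
```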

One difference I've noticed is that the numbers of sampled neighbors do not match. This seems to be because, with a fanout of [25, 20], DGL samples that many neighbors per edge type, whereas PyG samples that many in total, ignoring edge type. The consequence is that DGL touches many more nodes. This likely means more node embeddings receive gradient updates from the training batches, which would explain the faster over-fitting on the training nodes and the faster convergence in general.
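To illustrate (a sketch assuming DGL 0.7's MultiLayerNeighborSampler, not the exact code from my repo):

```python
import dgl

# With an integer fanout, DGL samples that many neighbors for every edge type
# separately, so a node with several incoming relations pulls in far more
# neighbors per layer than PyG's single 25/20 budget. Dividing the fanout by
# the number of incoming edge types would roughly match the total budget.
sampler = dgl.dataloading.MultiLayerNeighborSampler([25, 20])
```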

@zjost
Contributor Author

zjost commented Nov 23, 2021

Final update. I found that if I divide the fan-out numbers in DGL by the number of (incoming) edge types, I get about the same neighborhood sizes. If I re-run 10 training runs using fanout = [6, 5], I get the following results:

All runs:
Highest Train: 80.09 ± 0.29
Highest Valid: 45.67 ± 0.44
  Final Train: 70.14 ± 6.29
   Final Test: 44.77 ± 0.55

This shows for the first time that the training accuracy at the end of 3 epochs matches the PyG results, which adds credibility to the idea that DGL was overfitting the training data more quickly because more nodes were involved in the computation and therefore received embedding gradients.

However, the Final Test accuracy has dropped significantly, from 48% to 45%. I suspect the difference is that sampling N neighbors for each edge type is not the same as sampling N*num_edge_types neighbors overall: in the first case edge types are represented equally, whereas in the second the distribution of edge types reflects the relative frequencies observed in the underlying graph.

Final question: would there be value in me submitting a PR for this example that uses the Leaderboard architecture and achieves the higher final Test accuracy? Beyond answering that, feel free to close this issue.

@mufeili
Member

mufeili commented Nov 29, 2021

Hi @zjost, thanks for the great effort. It would be greatly appreciated if you could submit a PR and add an example for ogbn-mag in this folder.

@zjost
Contributor Author

zjost commented Dec 3, 2021

Closing this issue since I created a PR. Thanks again for the help!

@zjost zjost closed this as completed Dec 3, 2021