Support multiple node type sampling in NeighborLoader
#5013
def __call__(self, index: Union[List[int], Tensor]):
    if not isinstance(index, torch.LongTensor):
        index = torch.LongTensor(index)
def __call__(self, index: Union[List[int], Tensor, HeteroNodeList]):
@wsad1 Can I get your opinion on this (note the PR is overall still somewhat WIP).
I was finding the logic here getting quite complicated. Do you have any suggestions for simplification? One way would be to split the sampler into one class for hetero and one for non-hetero, and handle any conversions in the collate_fn.
Thanks for the update. I'll check it out once I am back from a short vacation on Tuesday morning.
No rush, thank you.
test/loader/test_neighbor_loader.py
], batch_size=batch_size, directed=directed, shuffle=False)

for batch1, batch2, batch3 in zip(loader, loader2, loader3):
    assert torch.allclose(batch1['paper'].x, batch2['paper'].x)
This test is flaky: it passes locally but fails in the CI. Any suggestions for what might be better?
Thanks @Padarn for the update. I am still looking at your code to see if we can simplify things. Added some initial comments.
Thanks. FYI I'll also do this for the link loader and add examples in separate PRs, so we can iterate on the complexity later too if it seems okay but not perfect.
Hey @mananshah99 do you also want to take a look at this one? May need to merge it with #5312
Added some more comments; will take another look.
def to_hetero_list(input_nodes: List[Tuple[str, Tensor]]) -> HeteroNodeList:
    return [(node_type, i) for node_type, index in input_nodes for i in index]
There might be cases where we return nodes of only one type. Should we mention that this can't be used to train a model which predicts on multiple node types?
Hmm I didn't fully understand this - do you mean some samples from the neighbor sampler will only contain a single type?
Yeah it might contain samples from only one type.
I am not sure what the use case of this multiple node sampling is. If an envisioned use case is predicting on multiple node types in a hetero graph, that might not work.
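To make the concern above concrete, here is a minimal sketch in plain Python. It is an illustration under assumptions, not code from the PR: plain lists of ints stand in for index tensors, and `iter_batches` is a hypothetical helper that mimics an unshuffled loader taking consecutive slices of the flattened list.

```python
from typing import Iterator, List, Tuple

# Assumption: plain lists stand in for index tensors.
InputNodes = List[Tuple[str, List[int]]]
HeteroNodeList = List[Tuple[str, int]]


def to_hetero_list(input_nodes: InputNodes) -> HeteroNodeList:
    # Same comprehension as in the diff: one (node_type, node_id) pair per seed node.
    return [(node_type, i) for node_type, index in input_nodes for i in index]


def iter_batches(nodes: HeteroNodeList, batch_size: int) -> Iterator[HeteroNodeList]:
    # Hypothetical stand-in for an unshuffled loader: consecutive slices.
    for start in range(0, len(nodes), batch_size):
        yield nodes[start:start + batch_size]


input_nodes = [("paper", [0, 1, 2, 3]), ("author", [0, 1])]
flat = to_hetero_list(input_nodes)

# Record which node types appear among the seed nodes of each batch.
types_per_batch = [sorted({t for t, _ in batch}) for batch in iter_batches(flat, 2)]
print(types_per_batch)
```

In this toy setup every batch holds seed nodes of a single type. Shuffling would mix the types, but it still would not guarantee that each batch covers every node type.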
Co-authored-by: Jinu Sunil <jinu.sunil@gmail.com>
Codecov Report
@@            Coverage Diff             @@
##           master    #5013      +/-   ##
==========================================
- Coverage   85.27%   83.32%   -1.95%
==========================================
  Files         338      338
  Lines       18683    18709      +26
==========================================
- Hits        15931    15590     -341
- Misses       2752     3119     +367
Thanks for the ping @Padarn. Left a few comments, happy to chat further as well.
@@ -147,7 +152,8 @@ def _set_num_neighbors_and_num_hops(self, num_neighbors):
        # Add at least one element to the list to ensure `max` is well-defined
        self.num_hops = max([0] + [len(v) for v in num_neighbors.values()])

    def _sparse_neighbor_sample(self, index: Tensor):
    def _sparse_neighbor_sample(self, index: Union[List[int], Tensor]):
Why? Don't we convert to tensors beforehand?
Ah, because that made the logic in `__call__` quite confusing. You would first have to check whether or not it was hetero, and then handle the conversion separately if it was a mixed input. To me it seemed simpler to handle these cases in the individual functions, as it was actually less code.
batch_sizes = {
    node_type: index.numel()
    for node_type, index in index_dict.items()
}

if self.data_cls != 'custom' and issubclass(self.data_cls, Data):
    return self._sparse_neighbor_sample(index) + (index.numel(), )
return self._hetero_sparse_neighbor_sample(index_dict) + (
    batch_sizes, )
This makes sense, and you are right that it somewhat conflicts with the interface in #5312. I think this is totally okay; that interface will likely change significantly over the coming week or two (we also need to support link-level neighbor sampling, etc.).
If it is alright with you, I would propose first merging 5312, and then adapting that interface as part of this PR to support returning a dict of batch sizes for each node type. Wdyt?
Agreed! I see it's merged now, so let me rethink this a little based on what you've added.
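For readers following the thread, the idea of returning a dict of batch sizes per node type can be sketched roughly as below. This is only an illustration under assumptions: `group_by_node_type` is a hypothetical helper, plain lists stand in for index tensors, and `len()` replaces `Tensor.numel()`.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

HeteroNodeList = List[Tuple[str, int]]


def group_by_node_type(batch: HeteroNodeList) -> Dict[str, List[int]]:
    # Hypothetical helper: collect seed node ids per node type, keeping order.
    index_dict: Dict[str, List[int]] = defaultdict(list)
    for node_type, node_id in batch:
        index_dict[node_type].append(node_id)
    return dict(index_dict)


batch = [("paper", 3), ("author", 0), ("paper", 7)]
index_dict = group_by_node_type(batch)

# Mirrors the dict comprehension in the diff, with len() in place of numel().
batch_sizes = {node_type: len(index) for node_type, index in index_dict.items()}
print(batch_sizes)
```

A node type that has no seed nodes in a given batch simply does not appear in `batch_sizes` here; whether the real interface should report it as zero instead is a design choice this sketch leaves open.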
def get_mixed_sampling_nodes(data: HeteroData,
                             input_nodes: List[InputNodes]) -> SamplingNodes:
Document or perhaps rename? `mixed` isn't super clear to me (this is just getting sampling nodes for different node types, right?)
Yes, good point, and I think you're right that we could merge these two functions. I originally split them because one was being used in the sampler init, but I've removed this. Let me clean it up.
def get_node_types(input_nodes: List[Tuple[str, Tensor]]) -> List[str]:
    return [node_type for node_type, index in input_nodes]


def get_node_list(input_nodes: List[Tuple[str, Tensor]]) -> HeteroNodeList:
    return [(node_type, i) for node_type, index in input_nodes for i in index]
Do these need to be separate functions?
👍 good point
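One possible way to fold the two helpers into a single function is sketched below. The name `get_node_types_and_list` is hypothetical, and plain lists stand in for index tensors; the PR may well have resolved this differently.

```python
from typing import List, Tuple

# Assumption: plain lists stand in for index tensors.
InputNodes = List[Tuple[str, List[int]]]
HeteroNodeList = List[Tuple[str, int]]


def get_node_types_and_list(
        input_nodes: InputNodes) -> Tuple[List[str], HeteroNodeList]:
    # Hypothetical merged helper: one pass over input_nodes yields both results.
    node_types: List[str] = []
    node_list: HeteroNodeList = []
    for node_type, index in input_nodes:
        node_types.append(node_type)
        node_list.extend((node_type, i) for i in index)
    return node_types, node_list


node_types, node_list = get_node_types_and_list(
    [("paper", [0, 1]), ("author", [5])])
print(node_types, node_list)
```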
Thanks for the comments @mananshah99. Sorry I've been busy with my day job, will review and address over the weekend.
Co-authored-by: Manan Shah <manan.shah.777@gmail.com>
On second look I think I'll wait for you to finish your current refactoring PRs; the code has changed a lot and I'll have to fit into the new interface. Will focus on helping review your PRs first and then rework this one.
Thank you for accommodating :) The refactoring PRs are complete now, and the interface is mostly stable. Happy to help move this implementation over behind the new interface, it's pretty cool.
Great! I can refactor later this week. I'll probably start a new PR as I think most of the code needs to move, will definitely ask you for a review.
This PR adds functionality to allow multiple node types to be sampled in `NeighborLoader`.
The interface follows what was discussed in the roadmap (#4765).
Internally, the input is converted to a list of tuples.
This is not very efficient, but benchmarks (#4765 (comment)) showed it to be acceptable.
TODO:
None instead of providing specific nodes for some node types
Addresses #4765