Support multiple node type sampling in NeighborLoader #5013

Closed
wants to merge 17 commits into from

Contributor

@Padarn Padarn commented Jul 20, 2022

This PR adds support for sampling from multiple node types in NeighborLoader.

The interface follows what was discussed in the roadmap (#4765):

NeighborLoader(
    input_nodes=[
        ('paper', torch.LongTensor([0, 1, 2])),
        ('author', torch.LongTensor([0, 1, 2])),
    ],
    ...
)

Internally, it converts this to a list of tuples.

[('paper', 0), ('paper', 1),....]

This is not very efficient, but the benchmarks in #4765 (comment) showed it to be acceptable.
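
As a rough illustration of that conversion (a minimal sketch, not the PR's actual code), the snippet below flattens per-type seed indices into one list of tuples. Plain Python lists stand in for the `torch.LongTensor` indices, and `flatten_input_nodes` is a hypothetical name chosen for this example.

```python
from typing import List, Tuple

# Hypothetical stand-in for the PR's HeteroNodeList type alias.
HeteroNodeList = List[Tuple[str, int]]


def flatten_input_nodes(
        input_nodes: List[Tuple[str, List[int]]]) -> HeteroNodeList:
    # Emit one (node_type, node_index) tuple per seed node, in input order.
    return [(node_type, i) for node_type, index in input_nodes for i in index]


flat = flatten_input_nodes([('paper', [0, 1, 2]), ('author', [0, 1, 2])])
print(flat)
```

The flattened list has one entry per seed node, which is why it costs more memory than keeping one tensor per node type.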

TODO:

  • Add tests
  • Add support for None instead of providing specific nodes for some node types

Addresses #4765

-    def __call__(self, index: Union[List[int], Tensor]):
-        if not isinstance(index, torch.LongTensor):
-            index = torch.LongTensor(index)
+    def __call__(self, index: Union[List[int], Tensor, HeteroNodeList]):
Contributor Author

@Padarn Padarn Aug 9, 2022

@wsad1 Can I get your opinion on this (note the PR is overall still somewhat WIP).

I was finding the logic here getting quite complicated. Do you have any suggestions for simplification? One way would be to split the sampler into one class for hetero and one for non-hetero, and handle any conversions in the collate_fn.

Member

Thanks for the update. I'll check it out once I am back from a short vacation on Tuesday morning.

Contributor Author

No rush, thank you.

@Padarn Padarn changed the title [WIP] Support multiple node type sampling in NeighborLoader Support multiple node type sampling in NeighborLoader Aug 14, 2022
@Padarn Padarn requested review from rusty1s and wsad1 August 14, 2022 04:15
@Padarn Padarn force-pushed the padarn/neighbour-multi-node branch from af2eec9 to 043ec22 Compare August 14, 2022 06:44
], batch_size=batch_size, directed=directed, shuffle=False)

for batch1, batch2, batch3 in zip(loader, loader2, loader3):
    assert torch.allclose(batch1['paper'].x, batch2['paper'].x)
Contributor Author

This test is flaky: it passes locally but not in the CI. Any suggestions for what might be better?

Member

@wsad1 wsad1 left a comment

Thanks @Padarn for the update. I am still looking at your code to see if we can simplify things. Added some initial comments.

test/loader/test_neighbor_loader.py (resolved)

Contributor Author

Padarn commented Aug 19, 2022

Thanks. FYI I'll also do this for the link loader and add examples in separate PRs, so we can iterate on the complexity later too if it seems okay but not perfect.

Contributor Author

Padarn commented Sep 3, 2022

Hey @mananshah99 do you also want to take a look at this one? May need to merge it with #5312

Member

@wsad1 wsad1 left a comment

Added some more comments; I will take another look.

Comment on lines 554 to 546
def to_hetero_list(input_nodes: List[Tuple[str, Tensor]]) -> HeteroNodeList:
    return [(node_type, i) for node_type, index in input_nodes for i in index]
Member

There might be cases where we return nodes of only one type. Should we mention that this can't be used to train a model which predicts on multiple node types?

Contributor Author

Hmm I didn't fully understand this - do you mean some samples from the neighbor sampler will only contain a single type?

Member

Yeah, it might contain samples from only one type.
I am not sure what the use case of this multiple node sampling is. If an envisioned use case is predicting on multiple node types in a hetero graph, that might not work.
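
The concern can be reproduced with a small sketch. It assumes the flattened `(node_type, index)` list is cut into consecutive batches without shuffling; `iter_batches` is a hypothetical helper written for this example and is not part of the PR.

```python
from typing import Iterator, List, Tuple

NodeList = List[Tuple[str, int]]


def iter_batches(nodes: NodeList, batch_size: int) -> Iterator[NodeList]:
    # Cut the flat seed list into consecutive chunks of `batch_size`.
    for start in range(0, len(nodes), batch_size):
        yield nodes[start:start + batch_size]


flat = [('paper', 0), ('paper', 1), ('paper', 2),
        ('author', 0), ('author', 1), ('author', 2)]
batches = list(iter_batches(flat, batch_size=3))

# Collect the distinct node types that appear as seeds in each batch.
types_per_batch = [sorted({node_type for node_type, _ in batch})
                   for batch in batches]
print(types_per_batch)  # [['paper'], ['author']]
```

With shuffling enabled the mix per batch would vary, but nothing in this scheme guarantees that every batch contains seeds of every type.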

torch_geometric/loader/neighbor_loader.py (outdated, resolved)
@Padarn Padarn force-pushed the padarn/neighbour-multi-node branch from 67a3a54 to 6ce3cf0 Compare September 5, 2022 12:24

codecov bot commented Sep 5, 2022

Codecov Report

Merging #5013 (6ce3cf0) into master (79617e0) will decrease coverage by 1.94%.
The diff coverage is 86.00%.

❗ Current head 6ce3cf0 differs from pull request most recent head 9af624f. Consider uploading reports for the commit 9af624f to get more accurate results

@@            Coverage Diff             @@
##           master    #5013      +/-   ##
==========================================
- Coverage   85.27%   83.32%   -1.95%     
==========================================
  Files         338      338              
  Lines       18683    18709      +26     
==========================================
- Hits        15931    15590     -341     
- Misses       2752     3119     +367     
| Impacted Files | Coverage Δ |
|---|---|
| torch_geometric/data/lightning_datamodule.py | 48.82% <ø> (ø) |
| torch_geometric/loader/neighbor_loader.py | 92.22% <84.78%> (-2.55%) ⬇️ |
| torch_geometric/typing.py | 100.00% <100.00%> (ø) |
| torch_geometric/nn/models/dimenet_utils.py | 0.00% <0.00%> (-75.52%) ⬇️ |
| torch_geometric/nn/models/dimenet.py | 14.51% <0.00%> (-53.00%) ⬇️ |
| torch_geometric/profile/profile.py | 37.89% <0.00%> (-26.32%) ⬇️ |
| torch_geometric/nn/conv/utils/typing.py | 81.25% <0.00%> (-17.50%) ⬇️ |
| torch_geometric/nn/inits.py | 67.85% <0.00%> (-7.15%) ⬇️ |
| torch_geometric/transforms/add_self_loops.py | 94.44% <0.00%> (-5.56%) ⬇️ |
| torch_geometric/nn/resolver.py | 88.88% <0.00%> (-5.56%) ⬇️ |

... and 12 more

Contributor

@mananshah99 mananshah99 left a comment

Thanks for the ping @Padarn. Left a few comments, happy to chat further as well.

@@ -147,7 +152,8 @@ def _set_num_neighbors_and_num_hops(self, num_neighbors):
        # Add at least one element to the list to ensure `max` is well-defined
        self.num_hops = max([0] + [len(v) for v in num_neighbors.values()])

-    def _sparse_neighbor_sample(self, index: Tensor):
+    def _sparse_neighbor_sample(self, index: Union[List[int], Tensor]):
Contributor

Why? Don't we convert to tensors beforehand?

Contributor Author

Ah, because that made the logic in __call__ quite confusing. You would first have to check whether or not the input was hetero, and then handle the conversion separately if it was a mixed input. To me it seemed simpler to handle these cases in the individual functions, as it was actually less code.

torch_geometric/loader/neighbor_loader.py (outdated, resolved)
Comment on lines +219 to +225
batch_sizes = {
    node_type: index.numel()
    for node_type, index in index_dict.items()
}

if self.data_cls != 'custom' and issubclass(self.data_cls, Data):
    return self._sparse_neighbor_sample(index) + (index.numel(), )
return self._hetero_sparse_neighbor_sample(index_dict) + (
    batch_sizes, )
Contributor

This makes sense, and you are right that it somewhat conflicts with the interface in #5312. I think this is totally okay; that interface will likely change significantly over the coming week or two (we also need to support link-level neighbor sampling, etc.).

If it is alright with you, I would propose first merging #5312, and then adapting that interface as part of this PR to support returning a dict of batch sizes for each node type. Wdyt?

Contributor Author

Agreed! I see it's merged now, so let me rethink this a little based on what you've added.
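
To make the proposed "dict of batch sizes for each node type" concrete, here is a minimal sketch of regrouping one batch of `(node_type, index)` seeds. It assumes plain Python lists in place of tensors, so `len` stands in for `Tensor.numel()`. `group_by_node_type` is a hypothetical helper, not a function from this PR or from #5312.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def group_by_node_type(batch: List[Tuple[str, int]]) -> Dict[str, List[int]]:
    # Collect the seed indices of each node type, preserving batch order.
    index_dict: Dict[str, List[int]] = defaultdict(list)
    for node_type, i in batch:
        index_dict[node_type].append(i)
    return dict(index_dict)


batch = [('paper', 2), ('author', 0), ('paper', 5)]
index_dict = group_by_node_type(batch)

# Number of seed nodes per node type in this batch.
batch_sizes = {node_type: len(index)
               for node_type, index in index_dict.items()}
print(index_dict, batch_sizes)
```

A node type with no seeds in a given batch is simply absent from both dicts in this sketch; whether it should instead appear with size 0 is a design choice the thread leaves open.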

Comment on lines +534 to +535
def get_mixed_sampling_nodes(data: HeteroData,
                             input_nodes: List[InputNodes]) -> SamplingNodes:
Contributor

Document or perhaps rename? mixed isn't super clear to me (this is just getting sampling nodes for different node types, right?)

Contributor Author

Yes, good point, and I think you're right that we could merge these two functions. I originally split them because one was being used in the sampler init, but I've removed that. Let me clean it up.

Comment on lines +541 to +546
def get_node_types(input_nodes: List[Tuple[str, Tensor]]) -> List[str]:
    return [node_type for node_type, index in input_nodes]


def get_node_list(input_nodes: List[Tuple[str, Tensor]]) -> HeteroNodeList:
    return [(node_type, i) for node_type, index in input_nodes for i in index]
Contributor

Do these need to be separate functions?

Contributor Author

👍 good point
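
One possible shape for a merged helper, as a sketch only: a single pass that returns both the node types and the flattened node list. The name `split_input_nodes` is hypothetical, and plain lists again stand in for index tensors.

```python
from typing import List, Tuple


def split_input_nodes(
    input_nodes: List[Tuple[str, List[int]]],
) -> Tuple[List[str], List[Tuple[str, int]]]:
    node_types: List[str] = []
    node_list: List[Tuple[str, int]] = []
    for node_type, index in input_nodes:
        # Record the type once, then one (type, index) tuple per seed node.
        node_types.append(node_type)
        node_list.extend((node_type, i) for i in index)
    return node_types, node_list


node_types, node_list = split_input_nodes([('paper', [0, 1]), ('author', [7])])
print(node_types, node_list)
```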

Contributor Author

Padarn commented Sep 10, 2022

Thanks for the comments @mananshah99. Sorry I've been busy with my day job, will review and address over the weekend.

Co-authored-by: Manan Shah <manan.shah.777@gmail.com>
Contributor Author

Padarn commented Sep 10, 2022

On second look, I think I'll wait for you to finish your current refactoring PRs: the code has changed a lot and I'll have to fit into the new interface. I will focus on helping review your PRs first and then rework this one.

@mananshah99
Contributor

On second look, I think I'll wait for you to finish your current refactoring PRs: the code has changed a lot and I'll have to fit into the new interface. I will focus on helping review your PRs first and then rework this one.

Thank you for accommodating :) The refactoring PRs are complete now, and the interface is mostly stable. Happy to help move this implementation over behind the new interface, it's pretty cool.

Contributor Author

Padarn commented Sep 21, 2022

Great! I can refactor later this week. I'll probably start a new PR, as I think most of the code needs to move. I will definitely ask you for a review.

@Padarn Padarn closed this Sep 24, 2022