Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let ImbalancedSampler accept torch.Tensor as input #5138

Merged
merged 12 commits into from
Aug 8, 2022

Conversation

EdisonLeeeee
Copy link
Contributor

@EdisonLeeeee EdisonLeeeee commented Aug 4, 2022

Hi, I modified torch_geometric.loader.ImbalancedSampler to accept torch.Tensor object, i.e., the class distribution as input. Now it is supported to use the class labels to initialize the sampler, as like:

from torch_geometric.loader import NeighborLoader, ImbalancedSampler

sampler = ImbalancedSampler(data.y, input_nodes=data.train_mask)
loader = NeighborLoader(data.y, input_nodes=data.train_mask,
                        batch_size=64, num_neighbors=[-1, -1],
                        sampler=sampler, ...)    
                                sampler=sampler, ...)  

And the issue in #5108 should be fixed with minimal adaptation:

from torch_geometric.datasets import DBLP
from torch_geometric.loader import ImbalancedSampler, NeighborLoader

dataset = DBLP(root='~/data/pygdata/dblp')
data = dataset[0]

# NOTE: input the node labels instead of the dataset `data['author']`
sampler = ImbalancedSampler(data['author'].y, input_nodes=data['author'].train_mask) 

train_loader = NeighborLoader(data, input_nodes=('author', data['author'].train_mask),
                              num_neighbors=[10, 10], 
                              sampler=sampler, batch_size=20)

@codecov
Copy link

codecov bot commented Aug 4, 2022

Codecov Report

Merging #5138 (f8de132) into master (7692969) will increase coverage by 0.00%.
The diff coverage is 100.00%.

❗ Current head f8de132 differs from pull request most recent head 59099ee. Consider uploading reports for the commit 59099ee to get more accurate results

@@           Coverage Diff           @@
##           master    #5138   +/-   ##
=======================================
  Coverage   82.98%   82.98%           
=======================================
  Files         333      333           
  Lines       18368    18371    +3     
=======================================
+ Hits        15243    15246    +3     
  Misses       3125     3125           
Impacted Files Coverage Δ
torch_geometric/loader/imbalanced_sampler.py 88.46% <100.00%> (+1.50%) ⬆️

📣 Codecov can now indicate which changes are the most critical in Pull Requests. Learn more

Copy link
Contributor Author

@EdisonLeeeee EdisonLeeeee left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix several typos

@EdisonLeeeee EdisonLeeeee changed the title Fix ImbalancedSampled to accept torch.Tensor as input (#5108) Fix ImbalancedSampled to accept torch.Tensor as input Aug 4, 2022
@rusty1s rusty1s changed the title Fix ImbalancedSampled to accept torch.Tensor as input Let ImbalancedSampler accept torch.Tensor as input Aug 6, 2022
Copy link
Member

@rusty1s rusty1s left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@rusty1s rusty1s enabled auto-merge (squash) August 8, 2022 14:39
@rusty1s rusty1s merged commit 6e88368 into pyg-team:master Aug 8, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

AttributeError: 'str' object has no attribute 'y' when using torch_geometric.loader.ImbalancedSampler
3 participants