Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let ImbalancedSampler accept torch.Tensor as input #5138

Merged
merged 12 commits into from
Aug 8, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Let `ImbalancedSampler` accept `torch.Tensor` as input ([#5138](https://github.com/pyg-team/pytorch_geometric/pull/5138))
- `NeighborSampler` supports graphs without edges ([#5072](https://github.com/pyg-team/pytorch_geometric/pull/5072))
- Added the `MeanSubtractionNorm` layer ([#5068](https://github.com/pyg-team/pytorch_geometric/pull/5068))
- Added `pyg_lib.segment_matmul` integration within `RGCNConv` ([#5052](https://github.com/pyg-team/pytorch_geometric/pull/5052), [#5096](https://github.com/pyg-team/pytorch_geometric/pull/5096))
Expand Down
28 changes: 20 additions & 8 deletions test/loader/test_imbalanced_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,21 @@ def test_dataloader_with_imbalanced_sampler():
sampler = ImbalancedSampler(data_list)
loader = DataLoader(data_list, batch_size=10, sampler=sampler)

ys: List[Tensor] = []
for batch in loader:
ys.append(batch.y)
y = torch.cat([batch.y for batch in loader])

histogram = torch.cat(ys).bincount()
histogram = y.bincount()
prob = histogram / histogram.sum()

assert histogram.sum() == len(data_list)
assert prob.min() > 0.4 and prob.max() < 0.6

# Test with label tensor as input:
torch.manual_seed(12345)
sampler = ImbalancedSampler(torch.tensor([data.y for data in data_list]))
loader = DataLoader(data_list, batch_size=10, sampler=sampler)

assert torch.allclose(y, torch.cat([batch.y for batch in loader]))


def test_neighbor_loader_with_imbalanced_sampler():
zeros = torch.zeros(10, dtype=torch.long)
Expand All @@ -41,17 +46,24 @@ def test_neighbor_loader_with_imbalanced_sampler():
edge_index = torch.empty((2, 0), dtype=torch.long)
data = Data(edge_index=edge_index, y=y, num_nodes=y.size(0))

# Test with data instance as input:
torch.manual_seed(12345)
sampler = ImbalancedSampler(data)
loader = NeighborLoader(data, batch_size=10, sampler=sampler,
num_neighbors=[-1])

ys: List[Tensor] = []
for batch in loader:
ys.append(batch.y)
y = torch.cat([batch.y for batch in loader])

histogram = torch.cat(ys).bincount()
histogram = y.bincount()
prob = histogram / histogram.sum()

assert histogram.sum() == data.num_nodes
assert prob.min() > 0.4 and prob.max() < 0.6

# Test with label tensor as input:
torch.manual_seed(12345)
sampler = ImbalancedSampler(data.y)
loader = NeighborLoader(data, batch_size=10, sampler=sampler,
num_neighbors=[-1])

assert torch.allclose(y, torch.cat([batch.y for batch in loader]))
25 changes: 21 additions & 4 deletions torch_geometric/loader/imbalanced_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,23 @@ class distribution.
batch_size=64, num_neighbors=[-1, -1],
sampler=sampler, ...)

You can also pass in the class labels directly as a :class:`torch.Tensor`:

.. code-block:: python

from torch_geometric.loader import NeighborLoader, ImbalancedSampler

sampler = ImbalancedSampler(data.y)
loader = NeighborLoader(data, input_nodes=data.train_mask,
batch_size=64, num_neighbors=[-1, -1],
sampler=sampler, ...)

Args:
dataset (Dataset or Data): The dataset from which to sample the data,
either given as a :class:`~torch_geometric.data.Dataset` or
:class:`~torch_geometric.data.Data` object.
dataset (Dataset or Data or Tensor): The dataset or class distribution
from which to sample the data, given either as a
:class:`~torch_geometric.data.Dataset`,
:class:`~torch_geometric.data.Data`, or :class:`torch.Tensor`
object.
input_nodes (Tensor, optional): The indices of nodes that are used by
the corresponding loader, *e.g.*, by
:class:`~torch_geometric.loader.NeighborLoader`.
Expand All @@ -50,7 +63,7 @@ class distribution.
"""
def __init__(
self,
dataset: Union[Data, Dataset, List[Data]],
dataset: Union[Dataset, Data, List[Data], Tensor],
input_nodes: Optional[Tensor] = None,
num_samples: Optional[int] = None,
):
Expand All @@ -60,6 +73,10 @@ def __init__(
assert dataset.num_nodes == y.numel()
y = y[input_nodes] if input_nodes is not None else y

elif isinstance(dataset, Tensor):
y = dataset.view(-1)
y = y[input_nodes] if input_nodes is not None else y
rusty1s marked this conversation as resolved.
Show resolved Hide resolved

elif isinstance(dataset, InMemoryDataset):
y = dataset.data.y.view(-1)
assert len(dataset) == y.numel()
Expand Down