Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Triplet sampling in NeighborLoader #6004

Merged
merged 14 commits into from
Nov 23, 2022
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.2.0] - 2022-MM-DD
### Added
- Added triplet sampling in `LinkNeighborLoader` ([#6004](https://github.com/pyg-team/pytorch_geometric/pull/6004))
- Added `FusedAggregation` of simple scatter reductions ([#6036](https://github.com/pyg-team/pytorch_geometric/pull/6036))
- Added `HeteroData` support for `to_captum_model` and added `to_captum_input` ([#5934](https://github.com/pyg-team/pytorch_geometric/pull/5934))
- Added `HeteroData` support in `RandomNodeLoader` ([#6007](https://github.com/pyg-team/pytorch_geometric/pull/6007))
Expand Down
2 changes: 1 addition & 1 deletion examples/graph_sage_unsup_ppi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# Group all training graphs into a single graph to perform sampling:
train_data = Batch.from_data_list(train_dataset)
loader = LinkNeighborLoader(train_data, batch_size=2048, shuffle=True,
neg_sampling_ratio=0.5, num_neighbors=[10, 10],
neg_sampling_ratio=1.0, num_neighbors=[10, 10],
num_workers=6, persistent_workers=True)

# Evaluation loaders (one datapoint corresponds to a graph)
Expand Down
5 changes: 0 additions & 5 deletions test/data/test_lightning_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,7 @@ def test_lightning_hetero_link_data():
)
assert isinstance(datamodule.neighbor_sampler, NeighborSampler)
for batch in datamodule.train_dataloader():
assert 'edge_label' in batch['author', 'paper']
assert 'edge_label_index' in batch['author', 'paper']
break

data['author'].time = torch.arange(data['author'].num_nodes)
data['paper'].time = torch.arange(data['paper'].num_nodes)
Expand All @@ -349,10 +347,8 @@ def test_lightning_hetero_link_data():
)

for batch in datamodule.train_dataloader():
assert 'edge_label' in batch['author', 'paper']
assert 'edge_label_index' in batch['author', 'paper']
assert 'edge_label_time' in batch['author', 'paper']
break


@withPackage('pytorch_lightning')
Expand Down Expand Up @@ -388,5 +384,4 @@ def test_lightning_hetero_link_data_custom_store():
)

batch = next(iter(datamodule.train_dataloader()))
assert 'edge_label' in batch['author', 'paper']
assert 'edge_label_index' in batch['author', 'paper']
204 changes: 188 additions & 16 deletions test/loader/test_link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def unique_edge_pairs(edge_index):


@pytest.mark.parametrize('directed', [True]) # TODO re-enable undirected mode
@pytest.mark.parametrize('neg_sampling_ratio', [0.0, 1.0])
@pytest.mark.parametrize('neg_sampling_ratio', [None, 1.0])
def test_homogeneous_link_neighbor_loader(directed, neg_sampling_ratio):
pos_edge_index = get_edge_index(100, 50, 500)
neg_edge_index = get_edge_index(100, 50, 500)
Expand All @@ -39,7 +39,7 @@ def test_homogeneous_link_neighbor_loader(directed, neg_sampling_ratio):
num_neighbors=[-1] * 2,
batch_size=20,
edge_label_index=edge_label_index,
edge_label=edge_label if neg_sampling_ratio == 0.0 else None,
edge_label=edge_label if neg_sampling_ratio is None else None,
directed=directed,
neg_sampling_ratio=neg_sampling_ratio,
shuffle=True,
Expand All @@ -60,7 +60,7 @@ def test_homogeneous_link_neighbor_loader(directed, neg_sampling_ratio):
assert batch.edge_attr.min() >= 0
assert batch.edge_attr.max() < 500

if neg_sampling_ratio == 0.0:
if neg_sampling_ratio is None:
assert batch.edge_label_index.size(1) == 20

# Assert positive samples are present in the original graph:
Expand All @@ -82,7 +82,7 @@ def test_homogeneous_link_neighbor_loader(directed, neg_sampling_ratio):


@pytest.mark.parametrize('directed', [True]) # TODO re-enable undirected mode
@pytest.mark.parametrize('neg_sampling_ratio', [0.0, 1.0])
@pytest.mark.parametrize('neg_sampling_ratio', [None, 1.0])
def test_heterogeneous_link_neighbor_loader(directed, neg_sampling_ratio):
data = HeteroData()

Expand Down Expand Up @@ -111,10 +111,9 @@ def test_heterogeneous_link_neighbor_loader(directed, neg_sampling_ratio):

for batch in loader:
assert isinstance(batch, HeteroData)
assert len(batch) == 6
if neg_sampling_ratio == 0.0:
assert len(batch) == 5 + (1 if neg_sampling_ratio is not None else 0)
if neg_sampling_ratio is None:
# Assert only positive samples are present in the original graph:
assert batch['paper', 'author'].edge_label.sum() == 0
edge_index = unique_edge_pairs(batch['paper', 'author'].edge_index)
edge_label_index = batch['paper', 'author'].edge_label_index
edge_label_index = unique_edge_pairs(edge_label_index)
Expand Down Expand Up @@ -218,9 +217,7 @@ def test_temporal_heterogeneous_link_neighbor_loader():
)
for batch in loader:
# Check if each seed edge has a different batch:
assert int(batch['paper'].batch.max()) + 1 == 32 + 16
# Check if each seed edge has a different source and destination node:
assert batch['paper'].num_nodes >= 2 * (32 + 16)
assert int(batch['paper'].batch.max()) + 1 == 32
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why isn't this 64? Won't the negative samples also have a disjoint batch?

Copy link
Member Author

@rusty1s rusty1s Nov 23, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I personally think it makes more sense to group the negatives and positives into their own batch, in particular because they now share time information.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I see. But it seems weird that we now have more than 1 seed edge with the same batch number, even though their computation graphs are in fact disjoint. But we can debate this later.


author_max = batch['author'].time.max()
edge_max = batch['paper', 'paper'].edge_label_time.max()
Expand Down Expand Up @@ -272,7 +269,6 @@ def test_custom_heterogeneous_link_neighbor_loader(FeatureStore, GraphStore):
edge_label_index=('paper', 'to', 'author'),
batch_size=20,
directed=True,
neg_sampling_ratio=0,
)

loader2 = LinkNeighborLoader(
Expand All @@ -281,7 +277,6 @@ def test_custom_heterogeneous_link_neighbor_loader(FeatureStore, GraphStore):
edge_label_index=('paper', 'to', 'author'),
batch_size=20,
directed=True,
neg_sampling_ratio=0,
)

assert str(loader1) == str(loader2)
Expand Down Expand Up @@ -312,9 +307,8 @@ def test_homogeneous_link_neighbor_loader_no_edges():

for batch in loader:
assert isinstance(batch, Data)
assert len(batch) == 4
assert len(batch) == 3
assert batch.input_id.numel() == 20
assert batch.num_nodes <= 40
assert batch.edge_label_index.size(1) == 20
assert batch.num_nodes == batch.edge_label_index.unique().numel()

Expand All @@ -329,9 +323,187 @@ def test_heterogeneous_link_neighbor_loader_no_edges():

for batch in loader:
assert isinstance(batch, HeteroData)
assert len(batch) == 4
assert batch['paper'].num_nodes <= 40
assert len(batch) == 3
assert batch['paper', 'paper'].input_id.numel() == 20
assert batch['paper', 'paper'].edge_label_index.size(1) == 20
assert batch['paper'].num_nodes == batch[
'paper', 'paper'].edge_label_index.unique().numel()


@withPackage('pyg_lib')
@pytest.mark.parametrize('disjoint', [False, True])
@pytest.mark.parametrize('temporal', [False, True])
@pytest.mark.parametrize('amount', [1, 2])
def test_homo_link_neighbor_loader_triplet(disjoint, temporal, amount):
    """Check triplet negative sampling of `LinkNeighborLoader` on `Data`.

    Every batch is expected to expose `src_index`, `dst_pos_index` and
    `dst_neg_index`. The temporal variant is only exercised together with
    disjoint mode; the remaining combination returns immediately.
    """
    if temporal and not disjoint:
        return

    graph = Data()
    graph.x = torch.arange(100)
    graph.edge_index = get_edge_index(100, 100, 500)
    graph.edge_attr = torch.arange(500)

    if temporal:
        time_attr = 'time'
        graph.time = torch.arange(graph.num_nodes)

        # Each seed edge is stamped 50 steps after the later of its two
        # endpoints:
        src_time = graph.time[graph.edge_index[0]]
        dst_time = graph.time[graph.edge_index[1]]
        edge_label_time = torch.max(src_time, dst_time) + 50
    else:
        time_attr = None
        edge_label_time = None

    batch_size = 20
    loader = LinkNeighborLoader(
        graph,
        num_neighbors=[-1] * 2,
        batch_size=batch_size,
        edge_label_index=graph.edge_index,
        edge_label_time=edge_label_time,
        time_attr=time_attr,
        directed=True,
        disjoint=disjoint,
        neg_sampling=dict(strategy='triplet', amount=amount),
        shuffle=True,
    )

    assert str(loader) == 'LinkNeighborLoader()'
    assert len(loader) == 500 / batch_size

    # Expectations that do not depend on the individual batch:
    expected_len = 7 + (1 if disjoint else 0) + (2 if temporal else 0)
    if amount == 1:
        expected_neg_size = (batch_size, )
    else:
        expected_neg_size = (batch_size, amount)
    max_seed_nodes = 2 * batch_size + batch_size * amount
    increasing = torch.arange(batch_size)

    for batch in loader:
        assert isinstance(batch, Data)
        assert len(batch) == expected_len

        # `src_index` and `dst_pos_index` must map back to the seed edges:
        seed_src = graph.edge_index[0, batch.input_id]
        seed_dst = graph.edge_index[1, batch.input_id]
        assert torch.equal(batch.x[batch.src_index], seed_src)
        assert torch.equal(batch.x[batch.dst_pos_index], seed_dst)

        # `dst_neg_index` must reference nodes that exist in the batch:
        assert batch.dst_neg_index.size() == expected_neg_size
        assert batch.dst_neg_index.min() >= 0
        assert batch.dst_neg_index.max() < batch.num_nodes

        if disjoint:
            # Seed nodes come first, ordered as source, positive
            # destination, negative destination:
            assert batch.src_index.min() == 0
            assert batch.src_index.max() == batch_size - 1

            assert batch.dst_pos_index.min() == batch_size
            assert batch.dst_pos_index.max() == 2 * batch_size - 1

            assert batch.dst_neg_index.min() == 2 * batch_size
            assert batch.dst_neg_index.max() == max_seed_nodes - 1

            assert batch.batch.min() == 0
            assert batch.batch.max() == batch_size - 1

            # Within every group of seed nodes, `batch` counts upwards:
            for start in range(0, max_seed_nodes, batch_size):
                chunk = batch.batch[start:start + batch_size]
                assert torch.equal(chunk, increasing)

        if temporal:
            # No node of a subgraph may be newer than its seed time:
            for i in range(batch_size):
                latest = batch.time[batch.batch == i].max()
                assert latest <= batch.seed_time[i]


@withPackage('pyg_lib')
@pytest.mark.parametrize('disjoint', [False, True])
@pytest.mark.parametrize('temporal', [False, True])
@pytest.mark.parametrize('amount', [1, 2])
def test_hetero_link_neighbor_loader_triplet(disjoint, temporal, amount):
    """Check triplet negative sampling of `LinkNeighborLoader` on `HeteroData`.

    Seed edges are taken from the `('paper', 'paper')` edge type, and the
    triplet indices are read from the `'paper'` node store. The temporal
    variant is only exercised together with disjoint mode; the remaining
    combination returns immediately.
    """
    if temporal and not disjoint:
        return

    graph = HeteroData()

    graph['paper'].x = torch.arange(100)
    graph['author'].x = torch.arange(100, 300)

    graph['paper', 'paper'].edge_index = get_edge_index(100, 100, 500)
    graph['paper', 'author'].edge_index = get_edge_index(100, 200, 1000)
    graph['author', 'paper'].edge_index = get_edge_index(200, 100, 1000)

    if temporal:
        time_attr = 'time'
        graph['paper'].time = torch.arange(graph['paper'].num_nodes)
        graph['author'].time = torch.arange(graph['author'].num_nodes)

        # Each seed edge is stamped 50 steps after the later of its two
        # endpoints:
        seed_edge_index = graph['paper', 'paper'].edge_index
        src_time = graph['paper'].time[seed_edge_index[0]]
        dst_time = graph['paper'].time[seed_edge_index[1]]
        edge_label_time = torch.max(src_time, dst_time) + 50
    else:
        time_attr = None
        edge_label_time = None

    batch_size = 20
    loader = LinkNeighborLoader(
        graph,
        num_neighbors=[-1] * 2,
        batch_size=batch_size,
        edge_label_index=('paper', 'paper'),
        edge_label_time=edge_label_time,
        time_attr=time_attr,
        directed=True,
        disjoint=disjoint,
        neg_sampling=dict(strategy='triplet', amount=amount),
        shuffle=True,
    )

    assert str(loader) == 'LinkNeighborLoader()'
    assert len(loader) == 500 / batch_size

    # Expectations that do not depend on the individual batch:
    expected_len = 6 + (1 if disjoint else 0) + (2 if temporal else 0)
    if amount == 1:
        expected_neg_size = (batch_size, )
    else:
        expected_neg_size = (batch_size, amount)
    max_seed_nodes = 2 * batch_size + batch_size * amount
    increasing = torch.arange(batch_size)

    for batch in loader:
        assert isinstance(batch, HeteroData)
        assert len(batch) == expected_len

        papers = batch['paper']
        seed_edges = batch['paper', 'paper']

        # `src_index` and `dst_pos_index` must map back to the seed edges:
        full_edge_index = graph['paper', 'paper'].edge_index
        seed_src = full_edge_index[0, seed_edges.input_id]
        seed_dst = full_edge_index[1, seed_edges.input_id]
        assert torch.equal(papers.x[papers.src_index], seed_src)
        assert torch.equal(papers.x[papers.dst_pos_index], seed_dst)

        # `dst_neg_index` must reference nodes that exist in the batch:
        assert papers.dst_neg_index.size() == expected_neg_size
        assert papers.dst_neg_index.min() >= 0
        assert papers.dst_neg_index.max() < papers.num_nodes

        if disjoint:
            # Seed nodes come first, ordered as source, positive
            # destination, negative destination:
            assert papers.src_index.min() == 0
            assert papers.src_index.max() == batch_size - 1

            assert papers.dst_pos_index.min() == batch_size
            assert papers.dst_pos_index.max() == 2 * batch_size - 1

            assert papers.dst_neg_index.min() == 2 * batch_size
            assert papers.dst_neg_index.max() == max_seed_nodes - 1

            assert papers.batch.min() == 0
            assert papers.batch.max() == batch_size - 1

            # Within every group of seed nodes, `batch` counts upwards:
            for start in range(0, max_seed_nodes, batch_size):
                chunk = papers.batch[start:start + batch_size]
                assert torch.equal(chunk, increasing)

        if temporal:
            # No node of a subgraph may be newer than its seed time:
            for i in range(batch_size):
                latest = papers.time[papers.batch == i].max()
                assert latest <= papers.seed_time[i]
Loading