Skip to content

Commit 3d5f855

Browse files
Padarnrusty1s
andauthored
Rewrite tests to not depend on currently broken dblp dataset (#5250)
* Rewrite tests to not depend on currently broken dblp dataset * update changelog * Update test/data/test_lightning_datamodule.py Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de> * Update CHANGELOG.md Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de> * add skip for broken test * add mark in skip * add mark in skip Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
1 parent 39c7e88 commit 3d5f855

File tree

3 files changed

+24
-5
lines changed

3 files changed

+24
-5
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1111
- Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))
1212
- Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
1313
### Changed
14+
- Changed tests relying on `dblp` datasets to instead use synthetic data. ([#5250](https://github.com/pyg-team/pytorch_geometric/pull/5250))
1415
### Removed
1516

1617
## [2.1.0] - 2022-08-17

test/data/test_lightning_datamodule.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch.nn.functional as F
66

77
from torch_geometric.data import (
8+
HeteroData,
89
LightningDataset,
910
LightningLinkData,
1011
LightningNodeData,
@@ -18,6 +19,12 @@
1819
LightningModule = torch.nn.Module
1920

2021

22+
def get_edge_index(num_src_nodes, num_dst_nodes, num_edges):
23+
row = torch.randint(num_src_nodes, (num_edges, ), dtype=torch.long)
24+
col = torch.randint(num_dst_nodes, (num_edges, ), dtype=torch.long)
25+
return torch.stack([row, col], dim=0)
26+
27+
2128
class LinearGraphModule(LightningModule):
2229
def __init__(self, in_channels, hidden_channels, out_channels):
2330
super().__init__()
@@ -273,10 +280,18 @@ def test_lightning_hetero_node_data(get_dataset):
273280
@withCUDA
274281
@onlyFullTest
275282
@withPackage('pytorch_lightning')
276-
def test_lightning_hetero_link_data(get_dataset):
277-
# TODO: Add more datasets.
278-
dataset = get_dataset(name='DBLP')
279-
data = dataset[0]
283+
def test_lightning_hetero_link_data():
284+
torch.manual_seed(12345)
285+
286+
data = HeteroData()
287+
288+
data['paper'].x = torch.arange(10)
289+
data['author'].x = torch.arange(10)
290+
data['term'].x = torch.arange(10)
291+
292+
data['paper', 'author'].edge_index = get_edge_index(10, 10, 10)
293+
data['author', 'paper'].edge_index = get_edge_index(10, 10, 10)
294+
data['paper', 'term'].edge_index = get_edge_index(10, 10, 10)
280295

281296
datamodule = LightningLinkData(
282297
data,

test/loader/test_hgt_loader.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pytest
23
import torch
34
from torch_sparse import SparseTensor
45

@@ -56,8 +57,9 @@ def test_hgt_loader():
5657
for batch in loader:
5758
assert isinstance(batch, HeteroData)
5859

59-
# Test node type selection:
60+
# Test node and types:
6061
assert set(batch.node_types) == {'paper', 'author'}
62+
assert set(batch.edge_types) == set(data.edge_types)
6163

6264
assert len(batch['paper']) == 2
6365
assert batch['paper'].x.size() == (40, ) # 20 + 4 * 5
@@ -177,6 +179,7 @@ def forward(self, x, edge_index, edge_weight):
177179
assert torch.allclose(out1, out2, atol=1e-6)
178180

179181

182+
@pytest.mark.skip("'dblp' dataset is broken")
180183
@withPackage('torch_sparse>=0.6.15')
181184
def test_hgt_loader_on_dblp(get_dataset):
182185
data = get_dataset(name='dblp')[0]

0 commit comments

Comments
 (0)