-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* fix test * typo * update * update * update * update * update * update * update * update * update * update * skip tests * run cron only in master repo * fix test * update * Add test
- Loading branch information
Showing
23 changed files
with
188 additions
and
203 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import functools | ||
import os.path as osp | ||
import shutil | ||
|
||
import pytest | ||
|
||
from torch_geometric.data import Dataset | ||
|
||
|
||
def load_dataset(root: str, name: str, *args, **kwargs) -> Dataset: | ||
r"""Returns a variety of datasets according to :obj:`name`.""" | ||
if 'karate' in name.lower(): | ||
from torch_geometric.datasets import KarateClub | ||
return KarateClub(*args, **kwargs) | ||
if name.lower() in ['cora', 'citeseer', 'pubmed']: | ||
from torch_geometric.datasets import Planetoid | ||
path = osp.join(root, 'Planetoid', name) | ||
return Planetoid(path, name, *args, **kwargs) | ||
if name in ['BZR', 'ENZYMES', 'IMDB-BINARY', 'MUTAG']: | ||
from torch_geometric.datasets import TUDataset | ||
path = osp.join(root, 'TUDataset') | ||
return TUDataset(path, name, *args, **kwargs) | ||
if name in ['ego-facebook', 'soc-Slashdot0811', 'wiki-vote']: | ||
from torch_geometric.datasets import SNAPDataset | ||
path = osp.join(root, 'SNAPDataset') | ||
return SNAPDataset(path, name, *args, **kwargs) | ||
if name.lower() in ['bashapes']: | ||
from torch_geometric.datasets import BAShapes | ||
return BAShapes(*args, **kwargs) | ||
if name.lower() in ['dblp']: | ||
from torch_geometric.datasets import DBLP | ||
path = osp.join(root, 'DBLP') | ||
return DBLP(path, *args, **kwargs) | ||
if name in ['citationCiteseer', 'illc1850']: | ||
from torch_geometric.datasets import SuiteSparseMatrixCollection | ||
path = osp.join(root, 'SuiteSparseMatrixCollection') | ||
return SuiteSparseMatrixCollection(path, name=name, *args, **kwargs) | ||
|
||
raise NotImplementedError | ||
|
||
|
||
@pytest.fixture(scope='session') | ||
def get_dataset(): | ||
root = osp.join('/', 'tmp', 'pyg_test_datasets') | ||
yield functools.partial(load_dataset, root) | ||
if osp.exists(root): | ||
shutil.rmtree(root) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,21 @@ | ||
import os.path as osp | ||
import random | ||
import shutil | ||
import sys | ||
from torch_geometric.testing import onlyFullTest | ||
|
||
from torch_geometric.datasets import TUDataset | ||
|
||
|
||
def test_bzr(): | ||
root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize))) | ||
dataset = TUDataset(root, 'BZR') | ||
|
||
@onlyFullTest | ||
def test_bzr(get_dataset): | ||
dataset = get_dataset(name='BZR') | ||
assert len(dataset) == 405 | ||
assert dataset.num_features == 53 | ||
assert dataset.num_node_labels == 53 | ||
assert dataset.num_node_attributes == 0 | ||
assert dataset.num_classes == 2 | ||
assert dataset.__repr__() == 'BZR(405)' | ||
assert str(dataset) == 'BZR(405)' | ||
assert len(dataset[0]) == 3 | ||
|
||
dataset = TUDataset(root, 'BZR', use_node_attr=True) | ||
|
||
@onlyFullTest | ||
def test_bzr_with_node_attr(get_dataset): | ||
dataset = get_dataset(name='BZR', use_node_attr=True) | ||
assert dataset.num_features == 56 | ||
assert dataset.num_node_labels == 53 | ||
assert dataset.num_node_attributes == 3 | ||
|
||
shutil.rmtree(root) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,14 @@ | ||
import os.path as osp | ||
import random | ||
import shutil | ||
import sys | ||
|
||
from torch_geometric.datasets import TUDataset | ||
|
||
|
||
def test_mutag(): | ||
root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize))) | ||
dataset = TUDataset(root, 'MUTAG') | ||
|
||
def test_mutag(get_dataset): | ||
dataset = get_dataset(name='MUTAG') | ||
assert len(dataset) == 188 | ||
assert dataset.num_features == 7 | ||
assert dataset.num_classes == 2 | ||
assert dataset.__repr__() == 'MUTAG(188)' | ||
assert str(dataset) == 'MUTAG(188)' | ||
|
||
assert len(dataset[0]) == 4 | ||
assert dataset[0].edge_attr.size(1) == 4 | ||
|
||
dataset = TUDataset(root, 'MUTAG', use_node_attr=True) | ||
assert dataset.num_features == 7 | ||
|
||
shutil.rmtree(root) | ||
def test_mutag_with_node_attr(get_dataset): | ||
dataset = get_dataset(name='MUTAG', use_node_attr=True) | ||
assert dataset.num_features == 7 |
Oops, something went wrong.