forked from hengruizhang98/CCA-SSG
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
61 lines (49 loc) · 1.89 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import numpy as np
import torch as th
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
from dgl.data import AmazonCoBuyPhotoDataset, AmazonCoBuyComputerDataset
from dgl.data import CoauthorCSDataset, CoauthorPhysicsDataset
def load(name):
    """Load a DGL node-classification dataset together with a data split.

    Args:
        name: Dataset key. 'cora', 'citeseer' and 'pubmed' select the
            citation-graph datasets; 'photo', 'comp', 'cs' and 'physics'
            select the Amazon co-buy / Coauthor datasets.

    Returns:
        A tuple ``(graph, feat, labels, num_class, train_idx, val_idx,
        test_idx)``. ``feat`` and ``labels`` are popped from
        ``graph.ndata`` (keys 'feat' and 'label'), so the returned graph no
        longer carries them. The three ``*_idx`` values are torch tensors of
        node indices.

    Raises:
        ValueError: If ``name`` is not one of the supported keys.
    """
    citegraph = ['cora', 'citeseer', 'pubmed']
    cograph = ['photo', 'comp', 'cs', 'physics']

    # Validate up front: otherwise an unknown key falls through every branch
    # and surfaces later as an unrelated "variable referenced before
    # assignment" error.
    if name not in citegraph and name not in cograph:
        raise ValueError(
            f"Unknown dataset name {name!r}; expected one of "
            f"{citegraph + cograph}"
        )

    dataset_classes = {
        'cora': CoraGraphDataset,
        'citeseer': CiteseerGraphDataset,
        'pubmed': PubmedGraphDataset,
        'photo': AmazonCoBuyPhotoDataset,
        'comp': AmazonCoBuyComputerDataset,
        'cs': CoauthorCSDataset,
        'physics': CoauthorPhysicsDataset,
    }
    dataset = dataset_classes[name]()
    graph = dataset[0]

    if name in citegraph:
        # These datasets carry their split as node masks in graph.ndata;
        # the masks are removed from the graph and converted to indices.
        train_mask = graph.ndata.pop('train_mask')
        val_mask = graph.ndata.pop('val_mask')
        test_mask = graph.ndata.pop('test_mask')

        # squeeze(1) drops only the trailing axis produced by nonzero(), so
        # the result stays 1-D even when a mask has exactly one set entry
        # (a bare squeeze() would collapse that case to a 0-d tensor).
        train_idx = th.nonzero(train_mask, as_tuple=False).squeeze(1)
        val_idx = th.nonzero(val_mask, as_tuple=False).squeeze(1)
        test_idx = th.nonzero(test_mask, as_tuple=False).squeeze(1)
    else:
        # No masks are read for these datasets: build a random split with
        # 10% train, 10% validation and the remaining nodes as test.
        train_ratio = 0.1
        val_ratio = 0.1

        N = graph.number_of_nodes()
        train_end = int(N * train_ratio)
        # Cumulative offset marking where the validation slice ends.
        val_end = int(N * (train_ratio + val_ratio))

        # Uses NumPy's global random state, so the split depends on any
        # seed the caller set through np.random.
        idx = np.arange(N)
        np.random.shuffle(idx)

        train_idx = th.tensor(idx[:train_end])
        val_idx = th.tensor(idx[train_end:val_end])
        test_idx = th.tensor(idx[val_end:])

    num_class = dataset.num_classes
    feat = graph.ndata.pop('feat')
    labels = graph.ndata.pop('label')

    return graph, feat, labels, num_class, train_idx, val_idx, test_idx