-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutils.py
29 lines (23 loc) · 987 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
def index_to_mask(index, size):
    """Return a boolean mask of length ``size`` that is True at ``index``.

    Args:
        index: Positions to mark. Anything usable for index assignment on a
            1-D tensor (e.g. a 1-D integer tensor or a list of ints).
        size: Length of the returned mask.

    Returns:
        A 1-D ``torch.bool`` tensor with ``True`` at every position in
        ``index`` and ``False`` elsewhere.
    """
    # Allocate the mask on the same device as a tensor index so that index
    # assignment does not mix devices; non-tensor indices use the default.
    device = index.device if isinstance(index, torch.Tensor) else None
    mask = torch.zeros((size,), dtype=torch.bool, device=device)
    mask[index] = True
    return mask
def random_splits(data, num_classes, num_train_per_class=20,
                  num_val_per_class=30):
    """Assign new random Planetoid-style train/val/test masks to ``data``.

    For every class ``0 .. num_classes - 1`` the nodes with that label are
    shuffled, then:

    * the first ``num_train_per_class`` go to the training set,
    * the next ``num_val_per_class`` go to the validation set,
    * all remaining nodes go to the test set.

    A class with fewer nodes than requested simply contributes fewer (or no)
    nodes to the later splits; no error is raised.

    Args:
        data: Object exposing a label tensor ``y`` and a ``num_nodes``
            attribute; ``train_mask``, ``val_mask`` and ``test_mask`` are set
            on it in place. NOTE(review): presumably a torch_geometric
            ``Data`` object -- confirm against callers.
        num_classes: Number of distinct labels; labels are compared against
            ``range(num_classes)``.
        num_train_per_class: Training nodes taken per class (default 20).
        num_val_per_class: Validation nodes taken per class (default 30).

    Returns:
        The same ``data`` object, with the three mask attributes set.
    """
    # Shuffled node indices, one tensor per class.
    per_class = []
    for label in range(num_classes):
        index = (data.y == label).nonzero().view(-1)
        index = index[torch.randperm(index.size(0))]
        per_class.append(index)

    val_end = num_train_per_class + num_val_per_class
    train_index = torch.cat(
        [idx[:num_train_per_class] for idx in per_class], dim=0)
    val_index = torch.cat(
        [idx[num_train_per_class:val_end] for idx in per_class], dim=0)
    rest_index = torch.cat([idx[val_end:] for idx in per_class], dim=0)
    # The resulting mask does not depend on the order of rest_index; the
    # shuffle is kept so the amount of RNG state consumed is unchanged and
    # seeded runs keep producing the same downstream random numbers.
    rest_index = rest_index[torch.randperm(rest_index.size(0))]

    data.train_mask = index_to_mask(train_index, size=data.num_nodes)
    data.val_mask = index_to_mask(val_index, size=data.num_nodes)
    data.test_mask = index_to_mask(rest_index, size=data.num_nodes)
    return data