-
Notifications
You must be signed in to change notification settings - Fork 2
/
common.py
33 lines (22 loc) · 1.23 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from models.ginconv import GINConvNet
from utils import *
# Dataset name. load() appends '_train' / '_test' to it to select the
# TestbedDataset splits. (partition() has a parameter of the same name that
# shadows this global inside that function.)
dataset = 'kiba'
def create_model(normalisation, device="cpu"):
    """Build a GINConvNet and move it to *device*.

    Args:
        normalisation: Forwarded positionally to the GINConvNet constructor.
        device: Target passed to ``.to()``; defaults to ``"cpu"``.

    Returns:
        Whatever ``GINConvNet(normalisation).to(device)`` returns
        (presumably the module itself, as with torch.nn.Module.to - confirm).
    """
    network = GINConvNet(normalisation)
    return network.to(device)
def partition(dataset, num_partitions, seed):
    """Randomly split *dataset* into ``num_partitions`` disjoint subsets.

    Every partition except the last holds ``len(dataset) // num_partitions``
    items. The last one also absorbs the remainder, so the sizes always sum
    to ``len(dataset)``. The split is reproducible because a freshly seeded
    ``torch.Generator`` drives the shuffle.

    Args:
        dataset: Anything ``torch.utils.data.random_split`` accepts
            (supports ``len()`` and integer indexing).
        num_partitions: Number of partitions to produce; must be >= 1.
        seed: Seed for the generator used to permute the indices.

    Returns:
        The list of subsets produced by ``torch.utils.data.random_split``,
        one per partition, in order.

    Raises:
        ValueError: If ``num_partitions`` is smaller than 1.
    """
    if num_partitions < 1:
        # Before: 0 raised ZeroDivisionError, and negative values silently
        # produced a single partition. Reject both explicitly.
        raise ValueError(f"num_partitions must be >= 1, got {num_partitions}")
    total = len(dataset)
    # Integer floor division: avoids the float round-trip of int(total / n).
    base = total // num_partitions
    sizes = [base] * (num_partitions - 1)
    # The last partition takes whatever is left over.
    sizes.append(total - sum(sizes))
    generator = torch.Generator().manual_seed(seed)
    return torch.utils.data.random_split(dataset, sizes, generator=generator)
def load(num_partitions, seed, path=None):
    """Load the train/test TestbedDataset splits for the module-level dataset.

    Args:
        num_partitions: Number of partitions to cut each split into. Only
            used when *path* is None.
        seed: Seed forwarded to ``partition()``. Only used when *path* is None.
        path: Root directory passed to TestbedDataset. When None, ``'data'``
            is used and both splits are partitioned.

    Returns:
        If *path* is None, a list of ``(train_part, test_part)`` tuples that
        pairs the i-th train partition with the i-th test partition.
        Otherwise, the unpartitioned ``(train, test)`` tuple.
    """
    partitioned = path is None
    root = 'data' if partitioned else path
    train_set = TestbedDataset(root=root, dataset=dataset + '_train')
    test_set = TestbedDataset(root=root, dataset=dataset + '_test')
    if not partitioned:
        # NOTE(review): num_partitions and seed are ignored on this branch,
        # and the return shape differs from the partitioned one - confirm
        # this asymmetry is intended.
        return train_set, test_set
    train_parts = partition(train_set, num_partitions, seed)
    test_parts = partition(test_set, num_partitions, seed)
    return [(train_part, test_part) for train_part, test_part in zip(train_parts, test_parts)]