dataloader.py
import pandas as pd
import torch.utils.data as data
import torch
import numpy as np
from functools import partial
from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer
from utils import integer_label_protein


class DTIDataset(data.Dataset):
    def __init__(self, list_IDs, df, max_drug_nodes=290):
        self.list_IDs = list_IDs
        self.df = df
        self.max_drug_nodes = max_drug_nodes

        self.atom_featurizer = CanonicalAtomFeaturizer()
        self.bond_featurizer = CanonicalBondFeaturizer(self_loop=True)
        self.fc = partial(smiles_to_bigraph, add_self_loop=True)

    def __len__(self):
        return len(self.list_IDs)

    def __getitem__(self, index):
        index = self.list_IDs[index]

        # Build a DGL bigraph from the SMILES string with canonical atom/bond features.
        v_d = self.df.iloc[index]['SMILES']
        v_d = self.fc(smiles=v_d, node_featurizer=self.atom_featurizer, edge_featurizer=self.bond_featurizer)

        # Pad the graph to max_drug_nodes with virtual nodes. Real nodes get an extra
        # 0-valued indicator bit appended to their features; virtual nodes get all-zero
        # features plus a 1-valued indicator bit (74 is the CanonicalAtomFeaturizer size).
        actual_node_feats = v_d.ndata.pop('h')
        num_actual_nodes = actual_node_feats.shape[0]
        num_virtual_nodes = self.max_drug_nodes - num_actual_nodes
        virtual_node_bit = torch.zeros([num_actual_nodes, 1])
        actual_node_feats = torch.cat((actual_node_feats, virtual_node_bit), 1)
        v_d.ndata['h'] = actual_node_feats
        virtual_node_feat = torch.cat((torch.zeros(num_virtual_nodes, 74), torch.ones(num_virtual_nodes, 1)), 1)
        v_d.add_nodes(num_virtual_nodes, {"h": virtual_node_feat})
        v_d = v_d.add_self_loop()

        # Encode the protein sequence as a fixed-length integer label vector.
        v_p = self.df.iloc[index]['Protein']
        v_p = integer_label_protein(v_p)

        y = self.df.iloc[index]["Y"]
        # y = torch.Tensor([y])
        return v_d, v_p, y
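
# Illustrative usage sketch: wrapping DTIDataset in a torch DataLoader. The CSV path,
# batch size, and the graph_collate helper are hypothetical; only the 'SMILES',
# 'Protein' and 'Y' column names come from __getitem__ above.
#
#   import dgl
#   from torch.utils.data import DataLoader
#
#   df = pd.read_csv("data/train.csv")            # hypothetical path
#   dataset = DTIDataset(df.index.values, df)
#
#   def graph_collate(batch):
#       # Batch the padded DGL graphs; stack protein encodings and labels.
#       graphs, proteins, labels = map(list, zip(*batch))
#       return dgl.batch(graphs), torch.tensor(np.array(proteins)), torch.tensor(labels)
#
#   loader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=graph_collate)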


class MultiDataLoader(object):
    """Draws one batch from each of several DataLoaders per step, restarting any
    loader that runs out, for a fixed number of steps per epoch."""

    def __init__(self, dataloaders, n_batches):
        if n_batches <= 0:
            raise ValueError("n_batches should be > 0")
        self._dataloaders = dataloaders
        self._n_batches = np.maximum(1, n_batches)
        self._init_iterators()

    def _init_iterators(self):
        self._iterators = [iter(dl) for dl in self._dataloaders]

    def _get_nexts(self):
        def _get_next_dl_batch(di, dl):
            try:
                batch = next(dl)
            except StopIteration:
                # This loader is exhausted: restart it and take its first batch.
                new_dl = iter(self._dataloaders[di])
                self._iterators[di] = new_dl
                batch = next(new_dl)
            return batch

        return [_get_next_dl_batch(di, dl) for di, dl in enumerate(self._iterators)]

    def __iter__(self):
        for _ in range(self._n_batches):
            yield self._get_nexts()
        self._init_iterators()

    def __len__(self):
        return self._n_batches
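
# Illustrative usage sketch for MultiDataLoader: draw paired batches from two loaders
# per step. "source_loader" and "target_loader" are hypothetical DataLoader instances
# (e.g. source- and target-domain data).
#
#   multi_loader = MultiDataLoader(
#       dataloaders=[source_loader, target_loader],
#       n_batches=max(len(source_loader), len(target_loader)),
#   )
#   for batch_s, batch_t in multi_loader:
#       ...  # each step yields one batch per loader; exhausted loaders restart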