-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
121 lines (95 loc) · 4.74 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""
Utilities such as model resolver
"""
from chemicalx.models import (DeepSynergy,
EPGCNDS, DeepDDS, MatchMaker, DeepDrug, DROnly,
DeepDRSynergy, EPGCNDSDR, DeepDDSDR, MatchMakerDR, DeepDrugDR
)
# Recipe for every supported model: ``(builder, flags)``.
# Each builder is called as ``builder(dataset, drug_dr_channels)``. Builders are
# lambdas so that only the requested model is instantiated, and dataset
# attributes (``context_channels`` / ``drug_channels``) are read only when that
# model actually needs them.
# ``flags`` is ordered exactly like the trailing part of model_resolver's return
# value: (context_features, drug_features, drug_molecules, drug_dr_features).
_MODEL_REGISTRY = {
    "DeepSynergy": (
        lambda ds, dr: DeepSynergy(
            context_channels=ds.context_channels, drug_channels=ds.drug_channels
        ),
        (True, True, False, False),
    ),
    "EPGCNDS": (
        lambda ds, dr: EPGCNDS(),
        (True, True, True, False),
    ),
    "DeepDDS": (
        lambda ds, dr: DeepDDS(context_channels=ds.context_channels),
        (True, True, True, False),
    ),
    "MatchMaker": (
        lambda ds, dr: MatchMaker(
            context_channels=ds.context_channels, drug_channels=ds.drug_channels
        ),
        (True, True, False, False),
    ),
    "DeepDrug": (
        lambda ds, dr: DeepDrug(),
        (False, True, True, False),
    ),
    "DROnly": (
        lambda ds, dr: DROnly(drug_dr_channels=dr),
        (False, False, False, True),
    ),
    "DeepDRSynergy": (
        lambda ds, dr: DeepDRSynergy(
            context_channels=ds.context_channels,
            drug_channels=ds.drug_channels,
            drug_dr_channels=dr,
        ),
        (True, True, False, True),
    ),
    "EPGCNDSDR": (
        lambda ds, dr: EPGCNDSDR(drug_dr_channels=dr),
        (True, True, True, True),
    ),
    "DeepDDSDR": (
        lambda ds, dr: DeepDDSDR(
            context_channels=ds.context_channels, drug_dr_channels=dr
        ),
        (True, True, True, True),
    ),
    "MatchMakerDR": (
        lambda ds, dr: MatchMakerDR(
            context_channels=ds.context_channels,
            drug_channels=ds.drug_channels,
            drug_dr_channels=dr,
        ),
        (True, True, False, True),
    ),
    "DeepDrugDR": (
        lambda ds, dr: DeepDrugDR(drug_dr_channels=dr),
        (False, True, True, True),
    ),
}


def model_resolver(dataset, model_name, drug_dr_channels=None):
    """
    Resolve a model name to a model instance plus the data-loader feature flags.

    Utility function which, given the dataset and the string name of a model,
    returns an instance of that model configured for the dataset, together with
    the boolean values the data loaders working on the dataset need.

    Args:
        dataset (chemicalx.dataset): an instance of a chemicalx dataset. Its
            ``context_channels`` / ``drug_channels`` attributes are read only
            for models whose constructor takes them.
        model_name (str): string literal naming the model to resolve, e.g.
            ``"DeepSynergy"`` or ``"DeepDDSDR"``.
        drug_dr_channels: forwarded as ``drug_dr_channels=`` to the DR model
            variants (DROnly, DeepDRSynergy, EPGCNDSDR, DeepDDSDR, MatchMakerDR,
            DeepDrugDR); ignored by all other models.

    Returns:
        tuple: ``(model, bool_context_features, bool_drug_features,
        bool_drug_molecules, bool_drug_dr_features)``.

    Raises:
        ValueError: if ``model_name`` is not one of the supported models.
    """
    try:
        build, flags = _MODEL_REGISTRY[model_name]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which the original ``==`` chain
        # also rejected with ValueError.
        raise ValueError(f"model_name is not recognized you put {model_name}") from None
    # Constructed outside the ``try`` so constructor errors propagate unchanged.
    model = build(dataset, drug_dr_channels)
    return (model, *flags)