Implement DeepDDI model #63
@@ -1,14 +1,90 @@
 """An implementation of the DeepDDI model."""

-from .base import UnimplementedModel
+import torch
+
+from chemicalx.data import DrugPairBatch
+from chemicalx.models import Model

 __all__ = [
     "DeepDDI",
 ]


-class DeepDDI(UnimplementedModel):
+class DeepDDI(Model):
     """An implementation of the DeepDDI model.

     .. seealso:: https://github.com/AstraZeneca/chemicalx/issues/2
     """
+
+    def __init__(
+        self,
+        *,
+        drug_channels: int,
+        hidden_channels: int = 2048,
+        hidden_layers_num: int = 9,

Review comment: Can we set up a smaller number of layers and a lower hidden channel number?

Review comment: @cthoyt can we have 4 layers with 32 hidden channels?

+        out_channels: int = 1,
+    ):
+        """Instantiate the DeepDDI model.
+
+        :param drug_channels: The number of drug features.
+        :param hidden_channels: The number of hidden layer neurons.
+        :param hidden_layers_num: The number of hidden layers.
+        :param out_channels: The number of output channels.
+        """
+        super(DeepDDI, self).__init__()
+        assert hidden_layers_num > 1
+        dnn = []

Review comment: I would strongly suggest using a fixed number of layers, e.g. 3 or 4.

+        dnn.extend([
+            torch.nn.Linear(drug_channels * 2, hidden_channels),
+            torch.nn.ReLU(),
+            torch.nn.BatchNorm1d(
+                num_features=hidden_channels,
+                affine=True,
+                momentum=None,
+            ),
+            torch.nn.ReLU(),
+        ])
+        for _ in range(hidden_layers_num - 1):
+            dnn.extend([
+                torch.nn.Linear(hidden_channels, hidden_channels),
+                torch.nn.ReLU(),
+                torch.nn.BatchNorm1d(
+                    num_features=hidden_channels,
+                    affine=True,
+                    momentum=None,
+                ),
+                torch.nn.ReLU(),
+            ])
+        dnn.extend([
+            torch.nn.Linear(hidden_channels, out_channels),
+            torch.nn.Sigmoid(),
+        ])
+        self.dnn = torch.nn.Sequential(*dnn)
+
+    def unpack(self, batch: DrugPairBatch):
+        """Return the left drug features and right drug features."""
+        return (
+            batch.drug_features_left,
+            batch.drug_features_right,
+        )
+
+    def forward(
+        self,
+        drug_features_left: torch.FloatTensor,
+        drug_features_right: torch.FloatTensor,
+    ) -> torch.FloatTensor:
+        """
+        Run a forward pass of the DeepDDI model.
+
+        Args:
+            drug_features_left (torch.FloatTensor): A matrix of head drug features.
+            drug_features_right (torch.FloatTensor): A matrix of tail drug features.
+        Returns:
+            hidden (torch.FloatTensor): A column vector of predicted interaction scores.
+        """
+        input_feature = torch.cat([drug_features_left, drug_features_right], 1)
+        hidden = self.dnn(input_feature)
+        return hidden
Review comment: Remove the spaces; the PR does not follow the linting rules. Please use black to format your code. When you push again, you can test it on the repo.
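For reference, a typical invocation would be running `black` from the repository root on the touched paths, e.g. `black chemicalx/ examples/ tests/`; the exact paths are an assumption based on the files in this PR rather than something stated in the review.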
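As a rough illustration of the reviewers' suggestion above (a fixed, smaller architecture such as 4 hidden layers with 32 channels), here is a minimal usage sketch. It is not part of the PR: the drug_channels value and batch size are made up for demonstration, and it assumes the DeepDDI class from this diff is importable from chemicalx.models.

import torch

from chemicalx.models import DeepDDI  # assumes this PR's DeepDDI is installed

# Reviewer-suggested smaller configuration: 4 hidden layers of 32 channels each.
# drug_channels=256 and the batch size of 8 are illustrative values only.
model = DeepDDI(drug_channels=256, hidden_channels=32, hidden_layers_num=4)

left = torch.rand(8, 256)    # head drug feature matrix
right = torch.rand(8, 256)   # tail drug feature matrix
scores = model(left, right)  # concatenates the pair and applies the MLP
print(scores.shape)          # expected: torch.Size([8, 1])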
@@ -0,0 +1,28 @@
"""Example with DeepDDI.""" | ||
|
||
from chemicalx import pipeline | ||
from chemicalx.data import DrugbankDDI | ||
from chemicalx.models import DeepDDI | ||
|
||
|
||
def main(): | ||
"""Train and evaluate the DeepSynergy model.""" | ||
dataset = DrugbankDDI() | ||
model = DeepDDI(drug_channels=dataset.drug_channels, hidden_layers_num=2) | ||
cthoyt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
results = pipeline( | ||
dataset=dataset, | ||
model=model, | ||
batch_size=5120, | ||
epochs=100, | ||
context_features=False, | ||
drug_features=True, | ||
drug_molecules=False, | ||
metrics=[ | ||
"roc_auc", | ||
] | ||
) | ||
results.summarize() | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
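Assuming the example lives in the repository's examples/ directory (the file name here is hypothetical), it can be run directly once chemicalx and its dependencies are installed, for example with `python examples/deepddi_example.py`.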
@@ -165,8 +165,22 @@ def test_ssiddi(self):
     def test_deepddi(self):
         """Test DeepDDI."""
-        model = DeepDDI(x=2)
-        assert model.x == 2

Review comment: Create a test also for …

+        model = DeepDDI(
+            drug_channels=self.loader.drug_channels,
+            hidden_channels=16,
+            hidden_layers_num=2,
+        )
+        optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001)
+        model.train()
+        loss = torch.nn.BCELoss()
+        for batch in self.generator:
+            optimizer.zero_grad()
+            prediction = model(batch.drug_features_left, batch.drug_features_right)
+            output = loss(prediction, batch.labels)
+            output.backward()
+            optimizer.step()
+            assert prediction.shape[0] == batch.labels.shape[0]

     def test_deepdrug(self):
         """Test DeepDrug."""
Review comment: Please change the docstring to include a paper reference.
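One possible way to address this, assuming the intended citation is the original DeepDDI paper by Ryu, Kim, and Lee (PNAS, 2018), would be a class docstring along the following lines; the exact wording and citation format are a suggestion, not taken from the PR.

from chemicalx.models import Model


class DeepDDI(Model):
    """An implementation of the DeepDDI model from Ryu et al.

    .. seealso:: Ryu, J. Y., Kim, H. U., & Lee, S. Y. (2018). Deep learning improves
       prediction of drug-drug and drug-food interactions. Proceedings of the National
       Academy of Sciences, 115(18).

    .. seealso:: https://github.com/AstraZeneca/chemicalx/issues/2
    """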