Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement DeepDDI model #63

Merged
merged 9 commits into from
Feb 2, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 78 additions & 2 deletions chemicalx/models/deepddi.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,90 @@
"""An implementation of the DeepDDI model."""

from .base import UnimplementedModel
import torch

from chemicalx.data import DrugPairBatch
from chemicalx.models import Model
__all__ = [
"DeepDDI",
]


class DeepDDI(UnimplementedModel):
class DeepDDI(Model):
"""An implementation of the DeepDDI model.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please change the docstring to include paper reference.


.. seealso:: https://github.com/AstraZeneca/chemicalx/issues/2
"""

def __init__(
self,
*,
drug_channels: int,
hidden_channels: int = 2048,
hidden_layers_num: int = 9,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we setup a smaller number of layers and lower hidden channel number?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cthoyt can we have 4 layer with 32 hidden channels?

out_channels: int = 1,
):
"""Instantiate the DeepDDI model.

:param drug_channels: The number of drug features.
:param hidden_channels: The number of hidden layer neurons.
:param hidden_layers_num: The number of hidden layers.
:param out_channels: The number of output channels.
"""

super(DeepDDI, self).__init__()
assert hidden_layers_num > 1
dnn = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would strongly suggest using a fixed number of layers. E.g. 3 or 4.

dnn.extend([
cthoyt marked this conversation as resolved.
Show resolved Hide resolved
torch.nn.Linear(drug_channels * 2, hidden_channels),
torch.nn.ReLU(),
torch.nn.BatchNorm1d(
num_features=hidden_channels,
affine=True,
momentum=None
),
torch.nn.ReLU()
])
for _ in range(hidden_layers_num - 1):
dnn.extend([
cthoyt marked this conversation as resolved.
Show resolved Hide resolved
torch.nn.Linear(hidden_channels, hidden_channels),
torch.nn.ReLU(),
torch.nn.BatchNorm1d(
num_features=hidden_channels,
affine=True,
momentum=None
),
torch.nn.ReLU()
])
dnn.extend([
torch.nn.Linear(hidden_channels, out_channels),
torch.nn.Sigmoid()
])
self.dnn = torch.nn.Sequential(*dnn)

def unpack(self, batch: DrugPairBatch):
"""Return the context features, left drug features and right drug features."""
return (
batch.drug_features_left,
batch.drug_features_right,
)

def forward(
self,
drug_features_left: torch.FloatTensor,
drug_features_right: torch.FloatTensor,
) -> torch.FloatTensor:
"""
Run a forward pass of the DeepDDI model.

Args:
drug_features_left (torch.FloatTensor): A matrix of head drug features.
drug_features_right (torch.FloatTensor): A matrix of tail drug features.
Returns:
hidden (torch.FloatTensor): A column vector of predicted interaction scores.
"""
input_feature = torch.cat([drug_features_left, drug_features_right], 1)
cthoyt marked this conversation as resolved.
Show resolved Hide resolved
hidden = self.dnn(input_feature)
return hidden

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove spaces. The PR does not follow linting. Please use black to format your code. When you push again you can test it on the repo.



28 changes: 28 additions & 0 deletions examples/deepddi_examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Example with DeepDDI."""

from chemicalx import pipeline
from chemicalx.data import DrugbankDDI
from chemicalx.models import DeepDDI


def main():
"""Train and evaluate the DeepSynergy model."""
dataset = DrugbankDDI()
model = DeepDDI(drug_channels=dataset.drug_channels, hidden_layers_num=2)
cthoyt marked this conversation as resolved.
Show resolved Hide resolved
results = pipeline(
dataset=dataset,
model=model,
batch_size=5120,
epochs=100,
context_features=False,
drug_features=True,
drug_molecules=False,
metrics=[
"roc_auc",
]
)
results.summarize()


if __name__ == "__main__":
main()
18 changes: 16 additions & 2 deletions tests/unit/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,22 @@ def test_ssiddi(self):

def test_deepddi(self):
"""Test DeepDDI."""
model = DeepDDI(x=2)
assert model.x == 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Create a test also for hidden_layer_num=1.

model = DeepDDI(
drug_channels=self.loader.drug_channels,
hidden_channels=16,
hidden_layers_num=2,
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001)
model.train()
loss = torch.nn.BCELoss()
for batch in self.generator:
optimizer.zero_grad()
prediction = model(batch.drug_features_left, batch.drug_features_right)
output = loss(prediction, batch.labels)
output.backward()
optimizer.step()
assert prediction.shape[0] == batch.labels.shape[0]

def test_deepdrug(self):
"""Test DeepDrug."""
Expand Down