Skip to content

Commit

Permalink
CASTER layer implementation (#73)
Browse files Browse the repository at this point in the history
* CASTER layer implementation

- only supervised training stage
- input dimensionality assumed to be correct

* Apply black and reorganize

* Move loss into its own module

* Update caster.py

* Reduce diff on citation

* Implement DeepDrug model (#68)

* WIP: model forward pass works, not tested

* added dropout and batch norm

* WIP: made DeepDrug example, not tested

* moved to using layers only, not GCN torchdrug model

* docstring

* added dropout and made context feats optional

* added DeepDrug unit test

* deepdrug self attribute fix

* docstring update

* unpack method update (when no context feats used)

* isort

* fixed test setting (context_channels)

* fixed testing without context

* black

* RST fix

* RST fix

* more pythonic loop + swap i to _

* removed context feat support in DeepDrug

* removed context handling from testing DeepDrug

* fixed examples DeepDrug, no context handling, decreased epochs 100->20

* removed unused import

* used a wrapper for calling the same layers on pairs of batches

* used a wrapper for calling the same layers on pairs of batches

* docstring fix

* Abstract process applied to left and right sides

* Apply black

* Cleanup

Co-authored-by: Charles Tapley Hoyt <cthoyt@gmail.com>

* Add GCN-BMP (#71)

* linting

* GCNBMP Scatter Reduction fix

* Using Rel Conv Layers instead of RGCN Model (avoid unnecessary sum readouts)

* Added docstrings and fixed highway update implementation

* Make number of relationship configurable

* little help of black for linting

* Cleaning up useless imports

* Sharing attention between right and left side

* Adding reference to GCNBMP docstring

* Type hinting everything

* Fixing docstring in example

* - Removing type hints in docstrings as they were added to signatures
- Chunked iteration of the BMP backbone for better readability

* Adding more-itertools as a dependency

* Using pairwise for encoder construction

* Adding missing docstrings

* Fixing linting and precommit hook

* Fixing the citation back to what is in main

* Tests,formatting,example

* Tests,formatting,example

* GCNBMP

* Cleanup

Co-authored-by: kcvc236 <kcvc236@seskscpg057.prim.scp>
Co-authored-by: Rozemberczki <kmdb028@astrazeneca.net>
Co-authored-by: kcvc236 <kcvc236@seskscpg059.prim.scp>
Co-authored-by: Charles Tapley Hoyt <cthoyt@gmail.com>

* Implement DeepDDI model (#63)

* update: Add deepddi model

* update: Add deepddi examples

* update: Add deepddi test case

* Style: deepddi model

* Style: deepddi model

* Style: deepddi_examples.py

* Update deepddi.py

* Update deepddi.py

Co-authored-by: Charles Tapley Hoyt <cthoyt@gmail.com>

* CASTER review fixes

* flake8 fixes

* CASTER: typing fix

Co-authored-by: Andriy Nikolov <kgsq682@astrazeneca.net>
Co-authored-by: Charles Tapley Hoyt <cthoyt@gmail.com>
Co-authored-by: Piotr Grabowski <3966940+kajocina@users.noreply.github.com>
Co-authored-by: Michaël Ughetto <michael.ughetto@astrazeneca.com>
Co-authored-by: kcvc236 <kcvc236@seskscpg057.prim.scp>
Co-authored-by: Rozemberczki <kmdb028@astrazeneca.net>
Co-authored-by: kcvc236 <kcvc236@seskscpg059.prim.scp>
Co-authored-by: walter <32014404+hzcheney@users.noreply.github.com>
  • Loading branch information
9 people authored Feb 3, 2022
1 parent 424b440 commit 6463147
Show file tree
Hide file tree
Showing 5 changed files with 226 additions and 5 deletions.
59 changes: 59 additions & 0 deletions chemicalx/loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Custom loss modules for chemicalx."""

from typing import Tuple

import torch
from torch.nn.modules.loss import _Loss

__all__ = [
"CASTERSupervisedLoss",
]


class CASTERSupervisedLoss(_Loss):
    """An implementation of the custom loss function for the supervised learning stage of the CASTER algorithm.

    The algorithm is described in [huang2020]_. The loss function combines three separate loss functions on
    different model outputs: class prediction loss, input reconstruction loss, and dictionary projection loss.

    .. [huang2020] Huang, K., *et al.* (2020). `CASTER: Predicting drug interactions
       with chemical substructure representation <https://doi.org/10.1609/aaai.v34i01.5412>`_.
       *AAAI 2020 - 34th AAAI Conference on Artificial Intelligence*, 702–709.
    """

    def __init__(
        self, recon_loss_coeff: float = 1e-1, proj_coeff: float = 1e-1, lambda1: float = 1e-2, lambda2: float = 1e-1
    ):
        """Initialize the custom loss function for the supervised learning stage of the CASTER algorithm.

        :param recon_loss_coeff: coefficient for the reconstruction loss
        :param proj_coeff: coefficient for the projection loss
        :param lambda1: regularization coefficient for the projection loss
        :param lambda2: regularization coefficient for the augmented projection loss
        """
        # NOTE: the base class is given ``reduction="none"`` although ``forward`` always
        # returns a single combined value; the argument is kept unchanged so that the
        # ``reduction`` attribute seen by existing callers stays the same.
        super().__init__(reduction="none")
        self.recon_loss_coeff = recon_loss_coeff
        self.proj_coeff = proj_coeff
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        # Binary cross entropy is shared by the prediction and the reconstruction terms.
        self.loss = torch.nn.BCELoss()

    def forward(self, x: Tuple[torch.FloatTensor, ...], target: torch.Tensor) -> torch.FloatTensor:
        """Perform a forward pass of the loss calculation for the supervised learning stage of the CASTER algorithm.

        The result is the sum of the prediction loss (BCE between ``score`` and ``target``), the weighted
        reconstruction loss (BCE between ``recon`` and ``drug_pair_features``), and the weighted projection
        loss (norm of the residual between the latent drug pair features and their projection on the latent
        dictionary, plus L1 regularization of the code and Frobenius regularization of the dictionary).

        :param x: a tuple of six tensors returned by the model forward pass (see CASTER.forward() method):
            ``(score, recon, code, dictionary_features_latent, drug_pair_features_latent, drug_pair_features)``
        :param target: target labels
        :return: combined loss value
        :raises ValueError: if ``x`` does not contain exactly six elements (raised by tuple unpacking)
        """
        score, recon, code, dictionary_features_latent, drug_pair_features_latent, drug_pair_features = x
        batch_size, _ = drug_pair_features.shape

        loss_prediction = self.loss(score, target.float())
        loss_reconstruction = self.recon_loss_coeff * self.loss(recon, drug_pair_features)

        # Project every code onto *its own* dictionary. A plain ``matmul(code, dictionary)`` with a
        # 2-D code and a 3-D (per-pair) dictionary broadcasts to (batch, batch, latent), which
        # compares each pair against the projections of all other pairs and needs O(batch^2) memory.
        # Adding a singleton row dimension keeps the product per-sample; for a shared 2-D
        # dictionary the result is identical to ``matmul(code, dictionary)``.
        projected = torch.matmul(code.unsqueeze(-2), dictionary_features_latent).squeeze(-2)

        loss_projection = self.proj_coeff * (
            torch.norm(drug_pair_features_latent - projected)
            + self.lambda1 * torch.sum(torch.abs(code)) / batch_size
            + self.lambda2 * torch.norm(dictionary_features_latent, p="fro") / batch_size
        )
        return loss_prediction + loss_reconstruction + loss_projection
125 changes: 122 additions & 3 deletions chemicalx/models/caster.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,137 @@
"""An implementation of the CASTER model."""

from .base import UnimplementedModel
from typing import Tuple

import torch

from chemicalx.data import DrugPairBatch
from chemicalx.models import Model

__all__ = [
"CASTER",
]


class CASTER(UnimplementedModel):
class CASTER(Model):
    """An implementation of the CASTER model from [huang2020]_.

    .. seealso:: This model was suggested in https://github.com/AstraZeneca/chemicalx/issues/17

    .. [huang2020] Huang, K., *et al.* (2020). `CASTER: Predicting drug interactions
       with chemical substructure representation <https://doi.org/10.1609/aaai.v34i01.5412>`_.
       *AAAI 2020 - 34th AAAI Conference on Artificial Intelligence*, 702–709.
    """

    def __init__(
        self,
        *,
        drug_channels: int,
        encoder_hidden_channels: int = 500,
        encoder_output_channels: int = 50,
        decoder_hidden_channels: int = 500,
        hidden_channels: int = 1024,
        out_channels: int = 1,
        lambda3: float = 1e-5,
        magnifying_factor: int = 100,
    ):
        """Instantiate the CASTER model.

        :param drug_channels: The number of drug features (recognised frequent substructures).
            The original implementation recognised 1722 basis substructures in the BIOSNAP experiment.
        :param encoder_hidden_channels: The number of hidden layer neurons in the encoder module.
        :param encoder_output_channels: The number of output layer neurons in the encoder module.
        :param decoder_hidden_channels: The number of hidden layer neurons in the decoder module.
        :param hidden_channels: The number of hidden layer neurons in the predictor module.
        :param out_channels: The number of output channels.
        :param lambda3: regularisation coefficient in the dictionary encoder module.
        :param magnifying_factor: The magnifying factor coefficient applied to the predictor module input.
        """
        super().__init__()
        self.lambda3 = lambda3
        self.magnifying_factor = magnifying_factor
        self.drug_channels = drug_channels

        # encoder: drug_channels -> encoder_hidden_channels -> encoder_output_channels
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(self.drug_channels, encoder_hidden_channels),
            torch.nn.ReLU(True),
            torch.nn.Linear(encoder_hidden_channels, encoder_output_channels),
        )

        # decoder: encoder_output_channels -> decoder_hidden_channels -> drug_channels
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(encoder_output_channels, decoder_hidden_channels),
            torch.nn.ReLU(True),
            torch.nn.Linear(decoder_hidden_channels, drug_channels),
        )

        # predictor: seven linear layers; each of the five hidden blocks is
        # BatchNorm1d -> Linear -> ReLU, and the output passes through a sigmoid.
        predictor_layers = [
            torch.nn.Linear(self.drug_channels, hidden_channels),
            torch.nn.ReLU(True),
        ]
        # in the original paper, the output of the last hidden layer before the output was fixed at 64 channels
        block_widths = [hidden_channels] * 4 + [64]
        for block_out_channels in block_widths:
            predictor_layers.append(torch.nn.BatchNorm1d(hidden_channels))
            predictor_layers.append(torch.nn.Linear(hidden_channels, block_out_channels))
            predictor_layers.append(torch.nn.ReLU(True))
        predictor_layers.append(torch.nn.Linear(64, out_channels))
        predictor_layers.append(torch.nn.Sigmoid())
        self.predictor = torch.nn.Sequential(*predictor_layers)

    def unpack(self, batch: DrugPairBatch) -> Tuple[torch.FloatTensor]:
        """Return the "functional representation" of drug pairs, as defined in the original implementation.

        :param batch: batch of drug pairs
        :return: each pair is represented as a single vector with x^i = 1 if either x_1^i >= 1 or x_2^i >= 1
        """
        pair_representation = (torch.maximum(batch.drug_features_left, batch.drug_features_right) >= 1.0).float()
        return (pair_representation,)

    def dictionary_encoder(
        self, drug_pair_features_latent: torch.FloatTensor, dictionary_features_latent: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Perform a forward pass of the dictionary encoder submodule.

        Computes, per drug pair, the closed-form ridge-regularised projection
        ``z @ ((D @ D^T + lambda3 * I)^-1 @ D)^T`` of the latent pair features ``z`` on the
        latent dictionary ``D``.

        :param drug_pair_features_latent: encoder output for the input drug_pair_features
            (batch_size x encoder_output_channels)
        :param dictionary_features_latent: per-pair latent dictionary as built in :meth:`forward`
            (batch_size x drug_channels x encoder_output_channels)
        :return: sparse code X_o: (batch_size x drug_channels)
        """
        # Build the identity on the same device/dtype as the dictionary; a bare ``torch.eye(n)``
        # is always created on the CPU and fails as soon as the model runs on an accelerator.
        identity = torch.eye(
            self.drug_channels,
            device=dictionary_features_latent.device,
            dtype=dictionary_features_latent.dtype,
        )
        dict_feat_squared = torch.matmul(dictionary_features_latent, dictionary_features_latent.transpose(2, 1))
        dict_feat_squared_inv = torch.inverse(dict_feat_squared + self.lambda3 * identity)
        dict_feat_closed_form = torch.matmul(dict_feat_squared_inv, dictionary_features_latent)
        # (batch, 1, latent) @ (batch, latent, drug_channels) -> (batch, 1, drug_channels)
        r = drug_pair_features_latent[:, None, :].matmul(dict_feat_closed_form.transpose(2, 1)).squeeze(1)
        return r

    def forward(self, drug_pair_features: torch.FloatTensor) -> Tuple[torch.FloatTensor, ...]:
        """Run a forward pass of the CASTER model.

        :param drug_pair_features: functional representation of each drug pair (see unpack method)
        :return: a tuple of tensors including:

            - prediction_scores: predicted target scores for each drug pair
            - reconstructed: input drug pair vectors reconstructed by the encoder-decoder chain
            - dictionary_encoded: drug pair features encoded by the dictionary encoder submodule
            - dictionary_features_latent: projection of the encoded drug pair features using the dictionary basis
            - drug_pair_features_latent: encoder output for the input drug_pair_features
            - drug_pair_features: the input unpacked drug_pair_features (needed for loss calculation)
        """
        drug_pair_features_latent = self.encoder(drug_pair_features)
        # The dictionary basis is the identity over the drug channels; it must live on the same
        # device/dtype as the input, otherwise the encoder call fails off-CPU.
        basis = torch.eye(
            self.drug_channels,
            device=drug_pair_features.device,
            dtype=drug_pair_features.dtype,
        )
        dictionary_features_latent = self.encoder(basis)
        # Mask the encoded basis per pair: (drug_channels, latent) * (batch, drug_channels, 1)
        # broadcasts to (batch, drug_channels, latent).
        dictionary_features_latent = dictionary_features_latent.mul(drug_pair_features[:, :, None])
        drug_pair_features_reconstructed = self.decoder(drug_pair_features_latent)
        reconstructed = torch.sigmoid(drug_pair_features_reconstructed)
        dictionary_encoded = self.dictionary_encoder(drug_pair_features_latent, dictionary_features_latent)
        prediction_scores = self.predictor(self.magnifying_factor * dictionary_encoded)

        return (
            prediction_scores,
            reconstructed,
            dictionary_encoded,
            dictionary_features_latent,
            drug_pair_features_latent,
            drug_pair_features,
        )
3 changes: 3 additions & 0 deletions chemicalx/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""A collection of full training and evaluation pipelines."""

import collections.abc
import json
import time
from dataclasses import dataclass
Expand Down Expand Up @@ -164,6 +165,8 @@ def pipeline(
predictions = []
for batch in test_generator:
prediction = model(*model.unpack(batch))
if isinstance(prediction, collections.abc.Sequence):
prediction = prediction[0]
prediction = prediction.detach().cpu().numpy()
identifiers = batch.identifiers
identifiers["prediction"] = prediction
Expand Down
30 changes: 30 additions & 0 deletions examples/caster_examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Example with CASTER."""

from chemicalx import pipeline
from chemicalx.data import DrugCombDB
from chemicalx.loss import CASTERSupervisedLoss
from chemicalx.models import CASTER


def main():
    """Run the CASTER model through the pipeline on DrugCombDB for one epoch and summarize the results."""
    drug_comb_db = DrugCombDB()
    caster = CASTER(drug_channels=drug_comb_db.drug_channels)
    # Only drug features are requested; context features and drug molecules are switched off.
    pipeline(
        dataset=drug_comb_db,
        model=caster,
        loss_cls=CASTERSupervisedLoss,
        batch_size=5120,
        epochs=1,
        context_features=False,
        drug_features=True,
        drug_molecules=False,
        metrics=["roc_auc"],
    ).summarize()


# Run the example only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
main()
14 changes: 12 additions & 2 deletions tests/unit/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import chemicalx.models
from chemicalx import pipeline
from chemicalx.data import DatasetLoader, DrugComb, DrugCombDB
from chemicalx.loss import CASTERSupervisedLoss
from chemicalx.models import (
CASTER,
EPGCNDS,
Expand Down Expand Up @@ -153,8 +154,17 @@ def setUp(self):

def test_caster(self):
"""Test CASTER."""
# NOTE(review): the next two lines appear to be the removed side of this diff hunk
# (the placeholder test for the unimplemented model); the lines after them are the
# added replacement.
model = CASTER(x=2)
assert model.x == 2
# Build the model with 256 drug feature channels and an Adam optimizer over its parameters.
model = CASTER(drug_channels=256)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001)
model.train()
loss = CASTERSupervisedLoss()
# One optimization step per batch from the generator prepared on the test case.
for batch in self.generator:
optimizer.zero_grad()
# The model returns a tuple; the whole tuple is handed to the CASTER loss.
prediction = model(*model.unpack(batch))
output = loss(prediction, batch.labels)
output.backward()
optimizer.step()
# The first tuple element (prediction scores) must have one row per label.
# NOTE(review): indentation is lost in this view, so it is unclear whether this
# assertion runs for every batch or only after the loop - confirm against the file.
assert prediction[0].shape[0] == batch.labels.shape[0]

def test_epgcnds(self):
"""Test EPGCNDS."""
Expand Down

0 comments on commit 6463147

Please sign in to comment.