-
Notifications
You must be signed in to change notification settings - Fork 88
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* CASTER layer implementation - only supervised training stage - input dimensionality assumed to be correct * Apply black and reorganize * Move loss into its own module * Update caster.py * Reduce diff on citation * Implement DeepDrug model (#68) * WIP: model forward pass works, not tested * added dropout and batch norm * WIP: made DeepDrug example, not tested * moved to using layers only, not GCN torchdrug model * docstring * added dropout and made context feats optional * added DeepDrug unit test * deepdrug self attribute fix * docstring update * unpack method update (when no context feats used) * isort * fixed test setting (context_channels) * fixed testing without context * black * RST fix * RST fix * more pythonic loop + swap i to _ * removed context feat support in DeepDrug * removed context handling from testing DeepDrug * fixed examples DeepDrug, no context handling, decreased epochs 100->20 * removed unused import * used a wrapper for calling the same layers on pairs of batches * used a wrapper for calling the same layers on pairs of batches * docstring fix * Abstract process applied to left and right sides * Apply black * Cleanup Co-authored-by: Charles Tapley Hoyt <cthoyt@gmail.com> * Add GCN-BMP (#71) * linting * GCNBMP Scatter Reduction fix * Using Rel Conv Layers instead of RGCN Model (avoid unecessary sum readouts) * Added docstrings and fixed highway update implementation * Make number of relationship configurable * little help of black for linting * Cleaning upuseless imports * Sharing attention between right and left side * Adding reference to GCNBMP docstring * Type hinting everything * Fixing docstring in example * - Removing type hints in docstrings as they were added to signatures - Chunked iteration of the BMP backbone for better readability * Ading more-itertools as a dependecy * Using pairwise for encoder construction * Adding missing docstrings * Fixing linting and precommit hook * Fixing the citation back to what is in main * Tests,formatting,example * Tests,formatting,example * GCNBMP * Cleanup Co-authored-by: kcvc236 <kcvc236@seskscpg057.prim.scp> Co-authored-by: Rozemberczki <kmdb028@astrazeneca.net> Co-authored-by: kcvc236 <kcvc236@seskscpg059.prim.scp> Co-authored-by: Charles Tapley Hoyt <cthoyt@gmail.com> * Implement DeepDDI model (#63) * update: Add deepddi model * update: Add deepddi examples * update: Add deepddi test case * Style: deepddi model * Style: deepddi model * Style: deepddi_examples.py * Update deepddi.py * Update deepddi.py Co-authored-by: Charles Tapley Hoyt <cthoyt@gmail.com> * CASTER review fixes * flake8 fixes * CASTER: typing fix Co-authored-by: Andriy Nikolov <kgsq682@astrazeneca.net> Co-authored-by: Charles Tapley Hoyt <cthoyt@gmail.com> Co-authored-by: Piotr Grabowski <3966940+kajocina@users.noreply.github.com> Co-authored-by: Michaël Ughetto <michael.ughetto@astrazeneca.com> Co-authored-by: kcvc236 <kcvc236@seskscpg057.prim.scp> Co-authored-by: Rozemberczki <kmdb028@astrazeneca.net> Co-authored-by: kcvc236 <kcvc236@seskscpg059.prim.scp> Co-authored-by: walter <32014404+hzcheney@users.noreply.github.com>
- Loading branch information
1 parent
424b440
commit 6463147
Showing
5 changed files
with
226 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
"""Custom loss modules for chemicalx.""" | ||
|
||
from typing import Tuple | ||
|
||
import torch | ||
from torch.nn.modules.loss import _Loss | ||
|
||
__all__ = [ | ||
"CASTERSupervisedLoss", | ||
] | ||
|
||
|
||
class CASTERSupervisedLoss(_Loss): | ||
"""An implementation of the custom loss function for the supervised learning stage of the CASTER algorithm. | ||
The algorithm is described in [huang2020]_. The loss function combines three separate loss functions on | ||
different model outputs: class prediction loss, input reconstruction loss, and dictionary projection loss. | ||
.. [huang2020] Huang, K., *et al.* (2020). `CASTER: Predicting drug interactions | ||
with chemical substructure representation <https://doi.org/10.1609/aaai.v34i01.5412>`_. | ||
*AAAI 2020 - 34th AAAI Conference on Artificial Intelligence*, 702–709. | ||
""" | ||
|
||
def __init__( | ||
self, recon_loss_coeff: float = 1e-1, proj_coeff: float = 1e-1, lambda1: float = 1e-2, lambda2: float = 1e-1 | ||
): | ||
""" | ||
Initialize the custom loss function for the supervised learning stage of the CASTER algorithm. | ||
:param recon_loss_coeff: coefficient for the reconstruction loss | ||
:param proj_coeff: coefficient for the projection loss | ||
:param lambda1: regularization coefficient for the projection loss | ||
:param lambda2: regularization coefficient for the augmented projection loss | ||
""" | ||
super().__init__(reduction="none") | ||
self.recon_loss_coeff = recon_loss_coeff | ||
self.proj_coeff = proj_coeff | ||
self.lambda1 = lambda1 | ||
self.lambda2 = lambda2 | ||
self.loss = torch.nn.BCELoss() | ||
|
||
def forward(self, x: Tuple[torch.FloatTensor, ...], target: torch.Tensor) -> torch.FloatTensor: | ||
"""Perform a forward pass of the loss calculation for the supervised learning stage of the CASTER algorithm. | ||
:param x: a tuple of tensors returned by the model forward pass (see CASTER.forward() method) | ||
:param target: target labels | ||
:return: combined loss value | ||
""" | ||
score, recon, code, dictionary_features_latent, drug_pair_features_latent, drug_pair_features = x | ||
batch_size, _ = drug_pair_features.shape | ||
loss_prediction = self.loss(score, target.float()) | ||
loss_reconstruction = self.recon_loss_coeff * self.loss(recon, drug_pair_features) | ||
loss_projection = self.proj_coeff * ( | ||
torch.norm(drug_pair_features_latent - torch.matmul(code, dictionary_features_latent)) | ||
+ self.lambda1 * torch.sum(torch.abs(code)) / batch_size | ||
+ self.lambda2 * torch.norm(dictionary_features_latent, p="fro") / batch_size | ||
) | ||
loss = loss_prediction + loss_reconstruction + loss_projection | ||
return loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,137 @@ | ||
"""An implementation of the CASTER model.""" | ||
|
||
from .base import UnimplementedModel | ||
from typing import Tuple | ||
|
||
import torch | ||
|
||
from chemicalx.data import DrugPairBatch | ||
from chemicalx.models import Model | ||
|
||
__all__ = [ | ||
"CASTER", | ||
] | ||
|
||
|
||
class CASTER(UnimplementedModel): | ||
class CASTER(Model): | ||
"""An implementation of the CASTER model from [huang2020]_. | ||
.. seealso:: This model was suggested in https://github.com/AstraZeneca/chemicalx/issues/15 | ||
.. seealso:: This model was suggested in https://github.com/AstraZeneca/chemicalx/issues/17 | ||
.. [huang2020] Huang, K., *et al.* (2020). `CASTER: Predicting drug interactions | ||
with chemical substructure representation <https://doi.org/10.1609/aaai.v34i01.5412>`_. | ||
*AAAI 2020 - 34th AAAI Conference on Artificial Intelligence*, 702–709. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
*, | ||
drug_channels: int, | ||
encoder_hidden_channels: int = 500, | ||
encoder_output_channels: int = 50, | ||
decoder_hidden_channels: int = 500, | ||
hidden_channels: int = 1024, | ||
out_channels: int = 1, | ||
lambda3: float = 1e-5, | ||
magnifying_factor: int = 100, | ||
): | ||
"""Instantiate the CASTER model. | ||
:param drug_channels: The number of drug features (recognised frequent substructures). | ||
The original implementation recognised 1722 basis substructures in the BIOSNAP experiment. | ||
:param encoder_hidden_channels: The number of hidden layer neurons in the encoder module. | ||
:param encoder_output_channels: The number of output layer neurons in the encoder module. | ||
:param decoder_hidden_channels: The number of hidden layer neurons in the decoder module. | ||
:param hidden_channels: The number of hidden layer neurons in the predictor module. | ||
:param out_channels: The number of output channels. | ||
:param lambda3: regularisation coefficient in the dictionary encoder module. | ||
:param magnifying_factor: The magnifying factor coefficient applied to the predictor module input. | ||
""" | ||
super().__init__() | ||
self.lambda3 = lambda3 | ||
self.magnifying_factor = magnifying_factor | ||
self.drug_channels = drug_channels | ||
|
||
# encoder | ||
self.encoder = torch.nn.Sequential( | ||
torch.nn.Linear(self.drug_channels, encoder_hidden_channels), | ||
torch.nn.ReLU(True), | ||
torch.nn.Linear(encoder_hidden_channels, encoder_output_channels), | ||
) | ||
|
||
# decoder | ||
self.decoder = torch.nn.Sequential( | ||
torch.nn.Linear(encoder_output_channels, decoder_hidden_channels), | ||
torch.nn.ReLU(True), | ||
torch.nn.Linear(decoder_hidden_channels, drug_channels), | ||
) | ||
|
||
# predictor: eight layer NN | ||
predictor_layers = [] | ||
predictor_layers.append(torch.nn.Linear(self.drug_channels, hidden_channels)) | ||
predictor_layers.append(torch.nn.ReLU(True)) | ||
for i in range(1, 6): | ||
predictor_layers.append(torch.nn.BatchNorm1d(hidden_channels)) | ||
if i < 5: | ||
predictor_layers.append(torch.nn.Linear(hidden_channels, hidden_channels)) | ||
else: | ||
# in the original paper, the output of the last hidden layer before the output was fixed at 64 channels | ||
predictor_layers.append(torch.nn.Linear(hidden_channels, 64)) | ||
predictor_layers.append(torch.nn.ReLU(True)) | ||
predictor_layers.append(torch.nn.Linear(64, out_channels)) | ||
predictor_layers.append(torch.nn.Sigmoid()) | ||
self.predictor = torch.nn.Sequential(*predictor_layers) | ||
|
||
def unpack(self, batch: DrugPairBatch) -> Tuple[torch.FloatTensor]: | ||
"""Return the "functional representation" of drug pairs, as defined in the original implementation. | ||
:param batch: batch of drug pairs | ||
:return: each pair is represented as a single vector with x^i = 1 if either x_1^i >= 1 or x_2^i >= 1 | ||
""" | ||
pair_representation = (torch.maximum(batch.drug_features_left, batch.drug_features_right) >= 1.0).float() | ||
return (pair_representation,) | ||
|
||
def dictionary_encoder( | ||
self, drug_pair_features_latent: torch.FloatTensor, dictionary_features_latent: torch.FloatTensor | ||
) -> torch.FloatTensor: | ||
"""Perform a forward pass of the dictionary encoder submodule. | ||
:param drug_pair_features_latent: encoder output for the input drug_pair_features | ||
(batch_size x encoder_output_channels) | ||
:param dictionary_features_latent: projection of the drug_pair_features using the dictionary basis | ||
(encoder_output_channels x drug_channels) | ||
:return: sparse code X_o: (batch_size x drug_channels) | ||
""" | ||
dict_feat_squared = torch.matmul(dictionary_features_latent, dictionary_features_latent.transpose(2, 1)) | ||
dict_feat_squared_inv = torch.inverse(dict_feat_squared + self.lambda3 * (torch.eye(self.drug_channels))) | ||
dict_feat_closed_form = torch.matmul(dict_feat_squared_inv, dictionary_features_latent) | ||
r = drug_pair_features_latent[:, None, :].matmul(dict_feat_closed_form.transpose(2, 1)).squeeze(1) | ||
return r | ||
|
||
def forward(self, drug_pair_features: torch.FloatTensor) -> Tuple[torch.FloatTensor, ...]: | ||
"""Run a forward pass of the CASTER model. | ||
:param drug_pair_features: functional representation of each drug pair (see unpack method) | ||
:return: (Tuple[torch.FloatTensor): a tuple of tensors including: | ||
prediction_scores: predicted target scores for each drug pair | ||
reconstructed: input drug pair vectors reconstructed by the encoder-decoder chain | ||
dictionary_encoded: drug pair features encoded by the dictionary encoder submodule | ||
dictionary_features_latent: projection of the encoded drug pair features using the dictionary basis | ||
drug_pair_features_latent: encoder output for the input drug_pair_features | ||
drug_pair_features: a copy of the input unpacked drug_pair_features (needed for loss calculation) | ||
""" | ||
drug_pair_features_latent = self.encoder(drug_pair_features) | ||
dictionary_features_latent = self.encoder(torch.eye(self.drug_channels)) | ||
dictionary_features_latent = dictionary_features_latent.mul(drug_pair_features[:, :, None]) | ||
drug_pair_features_reconstructed = self.decoder(drug_pair_features_latent) | ||
reconstructed = torch.sigmoid(drug_pair_features_reconstructed) | ||
dictionary_encoded = self.dictionary_encoder(drug_pair_features_latent, dictionary_features_latent) | ||
prediction_scores = self.predictor(self.magnifying_factor * dictionary_encoded) | ||
|
||
return ( | ||
prediction_scores, | ||
reconstructed, | ||
dictionary_encoded, | ||
dictionary_features_latent, | ||
drug_pair_features_latent, | ||
drug_pair_features, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
"""Example with CASTER.""" | ||
|
||
from chemicalx import pipeline | ||
from chemicalx.data import DrugCombDB | ||
from chemicalx.loss import CASTERSupervisedLoss | ||
from chemicalx.models import CASTER | ||
|
||
|
||
def main(): | ||
"""Train and evaluate the CASTER model.""" | ||
dataset = DrugCombDB() | ||
model = CASTER(drug_channels=dataset.drug_channels) | ||
results = pipeline( | ||
dataset=dataset, | ||
model=model, | ||
loss_cls=CASTERSupervisedLoss, | ||
batch_size=5120, | ||
epochs=1, | ||
context_features=False, | ||
drug_features=True, | ||
drug_molecules=False, | ||
metrics=[ | ||
"roc_auc", | ||
], | ||
) | ||
results.summarize() | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters