-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* wip * fix test * update targettype * simplify the rmsd fn * update tests * Update polaris/evaluate/metrics/docking_metrics.py Co-authored-by: Cas Wognum <caswognum@outlook.com> * update docking metrics * update test * relocate metrics * add metrics in API * remove warnings * update tests * update docstring --------- Co-authored-by: Cas Wognum <caswognum@outlook.com>
- Loading branch information
Showing
9 changed files
with
194 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from polaris.evaluate.metrics.generic_metrics import ( | ||
cohen_kappa_score, | ||
absolute_average_fold_error, | ||
spearman, | ||
pearsonr, | ||
) | ||
from polaris.evaluate.metrics.docking_metrics import rmsd_coverage |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# This module provides docking-related evaluation metrics. | ||
|
||
from typing import Union, List | ||
|
||
import numpy as np | ||
from rdkit.Chem.rdMolAlign import CalcRMS | ||
|
||
import datamol as dm | ||
|
||
|
||
def _rmsd(mol_probe: dm.Mol, mol_ref: dm.Mol) -> float:
    """Compute the heavy-atom RMSD between a predicted ligand pose and its reference.

    Both molecules are copied and stripped of hydrogens before the comparison,
    so the objects passed in by the caller are not modified. The value is
    obtained from RDKit's ``CalcRMS`` with
    ``symmetrizeConjugatedTerminalGroups=True``.

    It is assumed that the predicted pose was extracted from docking output
    whose receptor (protein) coordinates are already aligned with the original
    crystal structure.

    NOTE(review): both conformer ids are passed as ``-1``; confirm how
    references carrying several conformers are expected to be handled.

    Args:
        mol_probe: Predicted molecule (docked ligand).
        mol_ref: Ground truth molecule (crystal ligand).

    Returns:
        The RMS value reported by ``CalcRMS`` for the two hydrogen-free molecules.
    """
    # Work on hydrogen-free copies so that only heavy atoms enter the RMSD.
    probe_heavy = dm.remove_hs(dm.copy_mol(mol_probe))
    ref_heavy = dm.remove_hs(dm.copy_mol(mol_ref))

    rms_value = CalcRMS(
        prbMol=probe_heavy,
        refMol=ref_heavy,
        symmetrizeConjugatedTerminalGroups=True,
        prbId=-1,
        refId=-1,
    )
    return rms_value
|
||
|
||
def rmsd_coverage(y_pred: Union[str, List[dm.Mol]], y_true: Union[str, List[dm.Mol]], max_rsmd: float = 2):
    """Calculate the fraction of predicted poses whose RMSD to the reference is within a threshold.

    Each predicted molecule is paired with the reference molecule at the same
    index and scored with ``_rmsd``. The coverage is the share of pairs whose
    RMSD is less than or equal to ``max_rsmd`` (2 Å by default).

    It is assumed that the predicted binding conformers are extracted from the
    docking output, where the receptor (protein) coordinates have been aligned
    with the original crystal structure.

    Args:
        y_pred: List of predicted binding conformers.
        y_true: List of ground truth binding conformers.
        max_rsmd: The RMSD threshold below (or at) which a pose counts as covered.

    Returns:
        The fraction of pairs with an RMSD less than or equal to ``max_rsmd``.

    Raises:
        ValueError: If ``y_pred`` and ``y_true`` do not have the same length.
    """
    # Raise (rather than assert) so a length mismatch is never silently
    # truncated by ``zip`` below.
    if len(y_pred) != len(y_true):
        raise ValueError(
            f"The list of probing molecules and the list of reference molecules are different sizes. {len(y_pred)} != {len(y_true)} "
        )

    rmsds = np.array(
        [_rmsd(mol_probe=mol_probe, mol_ref=mol_ref) for mol_probe, mol_ref in zip(y_pred, y_true)]
    )
    return np.sum(rmsds <= max_rsmd) / len(rmsds)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import numpy as np | ||
from scipy import stats | ||
from sklearn.metrics import cohen_kappa_score as sk_cohen_kappa_score | ||
|
||
|
||
def pearsonr(y_true: np.ndarray, y_pred: np.ndarray):
    """Return the Pearson correlation coefficient between ``y_true`` and ``y_pred``.

    Thin wrapper around ``scipy.stats.pearsonr`` that keeps only the
    ``statistic`` field of the result and discards the rest.
    """
    result = stats.pearsonr(y_true, y_pred)
    return result.statistic
|
||
|
||
def spearman(y_true: np.ndarray, y_pred: np.ndarray):
    """Return the Spearman rank correlation between ``y_true`` and ``y_pred``.

    Thin wrapper around ``scipy.stats.spearmanr`` that keeps only the
    ``statistic`` field of the result and discards the rest.
    """
    result = stats.spearmanr(y_true, y_pred)
    return result.statistic
|
||
|
||
def absolute_average_fold_error(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    Calculate the Absolute Average Fold Error (AAFE) metric.

    The value returned is the mean of ``|y_pred| / |y_true|`` over all samples.
    The implementation is based on
    [this paper](https://pubs.acs.org/doi/10.1021/acs.chemrestox.3c00305).

    NOTE(review): AAFE is commonly defined as ``10 ** mean(|log10(y_pred / y_true)|)``;
    this function computes the mean absolute ratio instead. Confirm which is intended.

    Args:
        y_true: The true target values of shape (n_samples,)
        y_pred: The predicted target values of shape (n_samples,).

    Returns:
        aafe: The Absolute Average Fold Error.

    Raises:
        ValueError: If the inputs differ in length, or if `y_true` contains a zero.
    """
    # Coerce to arrays so the zero check below is element-wise even for plain
    # Python sequences (``[1, 0] == 0`` is just ``False`` and would skip the guard).
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    if len(y_true) != len(y_pred):
        raise ValueError("Length of y_true and y_pred must be the same.")

    if np.any(y_true == 0):
        raise ValueError("`y_true` contains zero which will result `Inf` value.")

    aafe = np.mean(np.abs(y_pred) / np.abs(y_true))

    return aafe
|
||
|
||
def cohen_kappa_score(y_true, y_pred, **kwargs):
    """Compute Cohen's kappa through scikit-learn under ``y_true``/``y_pred`` argument names.

    ``y_true`` and ``y_pred`` are forwarded as scikit-learn's first and second
    label arguments; any extra keyword arguments are passed through unchanged.
    """
    return sk_cohen_kappa_score(y_true, y_pred, **kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters