Add ligand RMSD metrics (#199)

* wip

* fix test

* update targettype

* simplify the rmsd fn

* update tests

* Update polaris/evaluate/metrics/docking_metrics.py

Co-authored-by: Cas Wognum <caswognum@outlook.com>

* update docking metrics

* update test

* relocate metrics

* add metrics in API

* remove warnings

* update tests

* update docstring

---------

Co-authored-by: Cas Wognum <caswognum@outlook.com>
zhu0619 and cwognum authored Sep 21, 2024
1 parent c1df924 commit e5370b7
Showing 9 changed files with 194 additions and 44 deletions.
7 changes: 5 additions & 2 deletions docs/api/evaluation.md
@@ -14,12 +14,15 @@

::: polaris.evaluate.MetricInfo

::: polaris.evaluate._metric.absolute_average_fold_error

---

::: polaris.evaluate.Metric
options:
filters: ["!^_", "!fn", "!is_multitask", "!y_type"]

---

::: polaris.evaluate.metrics.generic_metrics
::: polaris.evaluate.metrics.docking_metrics

---
53 changes: 12 additions & 41 deletions polaris/evaluate/_metric.py
@@ -3,11 +3,10 @@

import numpy as np
from pydantic import BaseModel, Field
from scipy import stats

from sklearn.metrics import (
accuracy_score,
average_precision_score,
cohen_kappa_score as sk_cohen_kappa_score,
explained_variance_score,
f1_score,
matthews_corrcoef,
@@ -18,46 +17,15 @@
balanced_accuracy_score,
)

from polaris.utils.types import DirectionType


def pearsonr(y_true: np.ndarray, y_pred: np.ndarray):
"""Calculate the Pearson r correlation."""
return stats.pearsonr(y_true, y_pred).statistic


def spearman(y_true: np.ndarray, y_pred: np.ndarray):
"""Calculate a Spearman correlation"""
return stats.spearmanr(y_true, y_pred).statistic


def absolute_average_fold_error(y_true: np.ndarray, y_pred: np.ndarray) -> float:
"""
Calculate the Absolute Average Fold Error (AAFE) metric.
It measures the fold change between predicted values and observed values.
The implementation is based on [this paper](https://pubs.acs.org/doi/10.1021/acs.chemrestox.3c00305).
Args:
y_true: The true target values of shape (n_samples,)
y_pred: The predicted target values of shape (n_samples,).
Returns:
aafe: The Absolute Average Fold Error.
"""
if len(y_true) != len(y_pred):
raise ValueError("Length of y_true and y_pred must be the same.")

if np.any(y_true == 0):
raise ValueError("`y_true` contains zero values, which would result in an `Inf` value.")

aafe = np.mean(np.abs(y_pred) / np.abs(y_true))

return aafe

from polaris.evaluate.metrics import (
cohen_kappa_score,
absolute_average_fold_error,
spearman,
pearsonr,
)
from polaris.evaluate.metrics.docking_metrics import rmsd_coverage

def cohen_kappa_score(y_true, y_pred, **kwargs):
"""Scikit-learn cohen_kappa_score wrapper with renamed arguments."""
return sk_cohen_kappa_score(y1=y_true, y2=y_pred, **kwargs)
from polaris.utils.types import DirectionType


class MetricInfo(BaseModel):
@@ -120,6 +88,9 @@ class Metric(Enum):
)
# TODO: add metrics to handle multitask multiclass predictions.

# docking related metrics
rmsd_coverage = MetricInfo(fn=rmsd_coverage, direction="max", y_type="y_pred")

@property
def fn(self) -> Callable:
"""The callable that actually computes the metric"""
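With the docking metric registered on the `Metric` enum, it is addressable by name like the existing metrics. The sketch below is not part of the commit; it only combines standard `Enum` access with the `fn` property shown in the hunk above:

from polaris.evaluate import Metric

# Look the metric up by name, as a benchmark specification listing "rmsd_coverage" would.
metric = Metric["rmsd_coverage"]

# The MetricInfo value carries the metadata declared above:
# direction="max" (higher coverage is better) and y_type="y_pred".
print(metric.value.direction, metric.value.y_type)

# The wrapped callable is exposed via the `fn` property and can be called
# directly with matching lists of docked and reference molecules.
rmsd_fn = metric.fn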
7 changes: 7 additions & 0 deletions polaris/evaluate/metrics/__init__.py
@@ -0,0 +1,7 @@
from polaris.evaluate.metrics.generic_metrics import (
cohen_kappa_score,
absolute_average_fold_error,
spearman,
pearsonr,
)
from polaris.evaluate.metrics.docking_metrics import rmsd_coverage
60 changes: 60 additions & 0 deletions polaris/evaluate/metrics/docking_metrics.py
@@ -0,0 +1,60 @@
# This module contains docking-related evaluation metrics.

from typing import Union, List

import numpy as np
from rdkit.Chem.rdMolAlign import CalcRMS

import datamol as dm


def _rmsd(mol_probe: dm.Mol, mol_ref: dm.Mol) -> float:
"""Calculate the RMSD between a predicted molecule and the closest ground-truth conformer.
The RMSD is calculated on the first conformer of the predicted molecule and only considers heavy atoms.
It is assumed that the predicted binding conformers are extracted from the docking output, where the receptor (protein) coordinates have been aligned with the original crystal structure.
Args:
mol_probe: Predicted molecule (docked ligand) with exactly one conformer.
mol_ref: Ground truth molecule (crystal ligand) with at least one conformer. If multiple conformers are
present, the lowest RMSD will be reported.
Returns:
The RMSD between the two molecules, taking symmetry into account.
"""

# Copy the molecules before modifying them.
mol_probe = dm.copy_mol(mol_probe)
mol_ref = dm.copy_mol(mol_ref)

# Remove hydrogens so that only heavy atoms contribute to the RMSD.
mol_probe = dm.remove_hs(mol_probe)
mol_ref = dm.remove_hs(mol_ref)

# Calculate the symmetry-aware RMSD.
return CalcRMS(
prbMol=mol_probe, refMol=mol_ref, symmetrizeConjugatedTerminalGroups=True, prbId=-1, refId=-1
)


def rmsd_coverage(y_pred: Union[str, List[dm.Mol]], y_true: Union[str, List[dm.Mol]], max_rsmd: float = 2):
"""
Calculate the fraction of predicted molecules whose RMSD to the reference conformer is below a threshold (2 Å by default).
It is assumed that the predicted binding conformers are extracted from the docking output, where the receptor (protein) coordinates have been aligned with the original crystal structure.
Args:
y_pred: List of predicted binding conformers.
y_true: List of ground truth binding conformers.
max_rsmd: The RMSD threshold (in Å) below which a prediction counts as acceptable.
Returns:
The fraction of predictions with an RMSD of at most `max_rsmd`.
"""

if len(y_pred) != len(y_true):
raise ValueError(
f"The list of probe molecules and the list of reference molecules have different sizes: {len(y_pred)} != {len(y_true)}"
)

rmsds = np.array(
[_rmsd(mol_probe=mol_probe, mol_ref=mol_ref) for mol_probe, mol_ref in zip(y_pred, y_true)]
)
return np.sum(rmsds <= max_rsmd) / len(rmsds)
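A quick way to exercise the new module outside a benchmark: embed two poses of the same ligand with datamol and score them with `rmsd_coverage`. This is only a sketch; the SMILES string and random seeds are illustrative and not part of the commit.

import datamol as dm

from polaris.evaluate.metrics.docking_metrics import rmsd_coverage

# Generate a "reference" and a "predicted" 3D pose for the same ligand.
smiles = "CC(=O)Oc1ccccc1C(=O)O"  # aspirin, used purely as an example
ref = dm.conformers.generate(mol=dm.to_mol(smiles), n_confs=1, random_seed=19)
pred = dm.conformers.generate(mol=dm.to_mol(smiles), n_confs=1, random_seed=42)

# With a single pair, the coverage is either 1.0 (the embeddings agree to
# within 2 Å) or 0.0 (they do not).
print(rmsd_coverage(y_pred=[pred], y_true=[ref]))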
42 changes: 42 additions & 0 deletions polaris/evaluate/metrics/generic_metrics.py
@@ -0,0 +1,42 @@
import numpy as np
from scipy import stats
from sklearn.metrics import cohen_kappa_score as sk_cohen_kappa_score


def pearsonr(y_true: np.ndarray, y_pred: np.ndarray):
"""Calculate the Pearson r correlation."""
return stats.pearsonr(y_true, y_pred).statistic


def spearman(y_true: np.ndarray, y_pred: np.ndarray):
"""Calculate a Spearman correlation"""
return stats.spearmanr(y_true, y_pred).statistic


def absolute_average_fold_error(y_true: np.ndarray, y_pred: np.ndarray) -> float:
"""
Calculate the Absolute Average Fold Error (AAFE) metric.
It measures the fold change between predicted values and observed values.
The implementation is based on [this paper](https://pubs.acs.org/doi/10.1021/acs.chemrestox.3c00305).
Args:
y_true: The true target values of shape (n_samples,)
y_pred: The predicted target values of shape (n_samples,).
Returns:
aafe: The Absolute Average Fold Error.
"""
if len(y_true) != len(y_pred):
raise ValueError("Length of y_true and y_pred must be the same.")

if np.any(y_true == 0):
raise ValueError("`y_true` contains zero values, which would result in an `Inf` value.")

aafe = np.mean(np.abs(y_pred) / np.abs(y_true))

return aafe


def cohen_kappa_score(y_true, y_pred, **kwargs):
"""Scikit-learn cohen_kappa_score wrapper with renamed arguments."""
return sk_cohen_kappa_score(y1=y_true, y2=y_pred, **kwargs)
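To make the relocated AAFE metric concrete: with true values [1, 2, 4] and predictions [2, 2, 2], the fold errors are 2, 1 and 0.5, so the AAFE is their mean, 3.5 / 3 ≈ 1.17. A minimal sketch against the new module path (the numbers are illustrative):

import numpy as np

from polaris.evaluate.metrics.generic_metrics import absolute_average_fold_error

y_true = np.array([1.0, 2.0, 4.0])
y_pred = np.array([2.0, 2.0, 2.0])

# mean(|2 / 1|, |2 / 2|, |2 / 4|) = mean(2, 1, 0.5) ≈ 1.167
print(absolute_average_fold_error(y_true=y_true, y_pred=y_pred))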
18 changes: 17 additions & 1 deletion polaris/evaluate/utils.py
@@ -42,6 +42,20 @@ def safe_mask(
return input_values[test_label][target_label][mask]


def mask_index(input_values):
"""Build a boolean mask that is False wherever a value is missing (NaN for numeric arrays, None otherwise)."""
if np.issubdtype(input_values.dtype, np.number):
mask = ~np.isnan(input_values)
else:
# For non-numeric (e.g. object) arrays, start from an all-True mask.
mask = np.full(input_values.shape, True, dtype=bool)
# Iterate over the array and mark missing (None) entries as False.
for index, value in np.ndenumerate(input_values):
if value is None:
mask[index] = False
return mask


def evaluate_benchmark(
target_cols: list[str],
metrics: list[Metric],
@@ -84,7 +98,9 @@ def evaluate_benchmark(
for target_label, y_true_target in y_true_subset.items():
# Single-task metrics for a multi-task benchmark
# In such a setting, there can be NaN values, which we thus have to filter out.
mask = ~np.isnan(y_true_target)

mask = mask_index(y_true_target)

score = metric(
y_true=y_true_target[mask],
y_pred=safe_mask(y_pred, test_label, target_label, mask),
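The `mask_index` helper above drives the NaN filtering in `evaluate_benchmark`; a small sketch of its behaviour for the two kinds of arrays it handles (the values are illustrative):

import numpy as np

from polaris.evaluate.utils import mask_index

# Numeric arrays: missing entries are NaN.
print(mask_index(np.array([0.5, np.nan, 1.2])))  # [ True False  True]

# Object arrays (e.g. molecules in a docking benchmark): missing entries are None.
print(mask_index(np.array(["mol-a", None, "mol-b"], dtype=object)))  # [ True False  True]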
1 change: 1 addition & 0 deletions polaris/utils/types.py
@@ -162,6 +162,7 @@ class TargetType(Enum):

REGRESSION = "regression"
CLASSIFICATION = "classification"
DOCKING = "docking"


class TaskType(Enum):
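The new target type resolves from its string value like the existing members; a trivial sketch relying only on standard `Enum` semantics:

from polaris.utils.types import TargetType

# By-value lookup maps the string "docking" to the new member.
assert TargetType("docking") is TargetType.DOCKING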
29 changes: 29 additions & 0 deletions tests/conftest.py
@@ -15,6 +15,8 @@
from polaris.dataset import ColumnAnnotation, CompetitionDataset, DatasetV1
from polaris.experimental._dataset_v2 import DatasetV2
from polaris.utils.types import HubOwner
from polaris.dataset.converters import SDFConverter
from polaris.dataset import DatasetFactory


def check_version(artifact):
@@ -381,3 +383,30 @@ def test_multi_task_benchmark_multiple_test_sets(test_dataset):
)
check_version(benchmark)
return benchmark


@pytest.fixture(scope="function")
def test_docking_dataset(tmpdir, sdf_files, test_org_owner):
# toy docking dataset
factory = DatasetFactory(tmpdir.join("ligands.zarr"))

converter = SDFConverter(mol_prop_as_cols=True)
factory.register_converter("sdf", converter)
factory.add_from_files(sdf_files, axis=0)
dataset = factory.build()
check_version(dataset)
return dataset


@pytest.fixture(scope="function")
def test_docking_benchmark(test_docking_dataset):
benchmark = SingleTaskBenchmarkSpecification(
name="single-task-single-set-benchmark",
dataset=test_docking_dataset,
metrics=["rmsd_coverage"],
split=([], [0, 1]),
target_cols=["molecule"],
input_cols=["smiles"],
)
check_version(benchmark)
return benchmark
21 changes: 21 additions & 0 deletions tests/test_evaluate.py
@@ -4,6 +4,8 @@
import pandas as pd
import pytest

import datamol as dm

import polaris as po
from polaris.benchmark import (
MultiTaskBenchmarkSpecification,
@@ -185,3 +187,22 @@ def test_metric_y_types(
test_single_task_benchmark_clf.metrics = [Metric.f1]
result = test_single_task_benchmark_clf.evaluate(y_pred=predictions, y_prob=probabilities)
assert result.results.Score.values[0] == Metric.f1.fn(y_true=test_y, y_pred=predictions)


def test_metrics_docking(test_docking_benchmark: SingleTaskBenchmarkSpecification, caffeine, ibuprofen):
_, test = test_docking_benchmark.get_train_test_split()

predictions = np.array([caffeine, ibuprofen])
result = test_docking_benchmark.evaluate(y_pred=predictions)

for metric in test_docking_benchmark.metrics:
assert metric in result.results.Metric.tolist()

# Sanity check: the fixture poses fall within the 2 Å threshold, so coverage should be 1.
assert result.results.Score.values[0] == 1

# Freshly embedded conformers are expected to deviate from the reference poses, so coverage drops to 0.
conf_caffeine = dm.conformers.generate(mol=caffeine, n_confs=1, random_seed=333)
conf_ibuprofen = dm.conformers.generate(mol=ibuprofen, n_confs=1, random_seed=333)
predictions = np.array([conf_caffeine, conf_ibuprofen])
result = test_docking_benchmark.evaluate(y_pred=predictions)
assert result.results.Score.values[0] == 0
