Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve survival analysis interface #825

Merged
merged 31 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
a970b5b
updated kmf to match method signature
aGuyLearning Nov 13, 2024
7434bde
updated notebook
aGuyLearning Nov 13, 2024
5add2c3
updated ehrapy tutorial commit
aGuyLearning Nov 13, 2024
150a7f7
updated docu for new method signature
aGuyLearning Nov 13, 2024
b66fb44
added outputs to survival analysis
aGuyLearning Nov 13, 2024
0c8e6d6
correctly passing on fitting options
aGuyLearning Nov 13, 2024
95c2b74
pull request fixes.
aGuyLearning Nov 20, 2024
6085e96
added legacy suport
aGuyLearning Nov 20, 2024
16b7d5f
added kmf function legacy support in tests and added new kaplan_meier…
aGuyLearning Nov 27, 2024
579c220
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 27, 2024
82a6e3c
updated notebook
aGuyLearning Nov 27, 2024
c604074
added stacklevel to deprecation warning
aGuyLearning Nov 27, 2024
f6b5a89
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 27, 2024
7322f15
added deprecation warning in comment
aGuyLearning Nov 27, 2024
75f8000
Merge branch 'main' into enhancement/issue-822
Zethson Nov 28, 2024
1b80ff1
Update ehrapy/plot/_survival_analysis.py
eroell Dec 1, 2024
a26f6bc
Update ehrapy/plot/_survival_analysis.py
eroell Dec 1, 2024
1442983
Update ehrapy/plot/_survival_analysis.py
eroell Dec 1, 2024
972b71b
Update ehrapy/plot/_survival_analysis.py
eroell Dec 1, 2024
915df91
Update tests/tools/test_sa.py
eroell Dec 1, 2024
a3502b5
doc adjustments
eroell Dec 1, 2024
969eeb9
Merge branch 'main' into enhancement/issue-822
eroell Dec 1, 2024
1940ace
change name of kmf plot to kaplan_meier, some adjustments
eroell Dec 1, 2024
18f9292
introduce keyword only for univariate sa
eroell Dec 1, 2024
8c14039
correct docstring
eroell Dec 1, 2024
6f291bc
update submodule
eroell Dec 1, 2024
08e5949
add lifelines intersphinx mappings
eroell Dec 1, 2024
315e564
Update ehrapy/tools/_sa.py
Zethson Dec 2, 2024
b9e5bfb
Update ehrapy/tools/_sa.py
Zethson Dec 2, 2024
7bbe627
Update ehrapy/tools/_sa.py
Zethson Dec 2, 2024
5e14096
Merge branch 'main' into enhancement/issue-822
Zethson Dec 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
16 changes: 7 additions & 9 deletions ehrapy/plot/_survival_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,22 +186,20 @@ def kmf(
# So we need to flip `censor_flg` before passing it to KaplanMeierFitter

>>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
>>> kmf = ep.tl.kmf(adata[:, ["mort_day_censored"]].X, adata[:, ["censor_flg"]].X)
>>> kmf = ep.tl.kmf(adata, "mort_day_censored", "censor_flg")
eroell marked this conversation as resolved.
Show resolved Hide resolved
>>> ep.pl.kmf(
... [kmf], color=["r"], xlim=[0, 700], ylim=[0, 1], xlabel="Days", ylabel="Proportion Survived", show=True
... )

.. image:: /_static/docstring_previews/kmf_plot_1.png

>>> T = adata[:, ["mort_day_censored"]].X
>>> E = adata[:, ["censor_flg"]].X
>>> groups = adata[:, ["service_unit"]].X
>>> ix1 = groups == "FICU"
>>> ix2 = groups == "MICU"
>>> ix3 = groups == "SICU"
>>> kmf_1 = ep.tl.kmf(T[ix1], E[ix1], label="FICU")
>>> kmf_2 = ep.tl.kmf(T[ix2], E[ix2], label="MICU")
>>> kmf_3 = ep.tl.kmf(T[ix3], E[ix3], label="SICU")
>>> adata_ficu = adata[groups == "FICU"]
>>> adata_micu = adata[groups == "MICU"]
>>> adata_sicu = adata[groups == "SICU"]
>>> kmf_1 = ep.tl.kmf(adata_ficu, "mort_day_censored", "censor_flg", label="FICU")
eroell marked this conversation as resolved.
Show resolved Hide resolved
>>> kmf_2 = ep.tl.kmf(adata_micu, "mort_day_censored", "censor_flg", label="MICU")
eroell marked this conversation as resolved.
Show resolved Hide resolved
>>> kmf_3 = ep.tl.kmf(adata_sicu, "mort_day_censored", "censor_flg", label="SICU")
eroell marked this conversation as resolved.
Show resolved Hide resolved
>>> ep.pl.kmf([kmf_1, kmf_2, kmf_3], ci_show=[False,False,False], color=['k','r', 'g'],
...     xlim=[0, 750], ylim=[0, 1], xlabel="Days", ylabel="Proportion Survived")

Expand Down
216 changes: 175 additions & 41 deletions ehrapy/tools/_sa.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Literal

import numpy as np # This package is implicitly used
Expand Down Expand Up @@ -116,15 +117,19 @@ def glm(


def kmf(
aGuyLearning marked this conversation as resolved.
Show resolved Hide resolved
Zethson marked this conversation as resolved.
Show resolved Hide resolved
durations: Iterable,
event_observed: Iterable | None = None,
timeline: Iterable = None,
entry: Iterable | None = None,
aGuyLearning marked this conversation as resolved.
Show resolved Hide resolved
adata: AnnData,
duration_col: str,
event_col: str | None = None,
timeline: list[float] | None = None,
entry: str | None = None,
label: str | None = None,
alpha: float | None = None,
ci_labels: tuple[str, str] = None,
aGuyLearning marked this conversation as resolved.
Show resolved Hide resolved
weights: Iterable | None = None,
censoring: Literal["right", "left"] = None,
aGuyLearning marked this conversation as resolved.
Show resolved Hide resolved
ci_labels: list[str] | None = None,
weights: list[float] | None = None,
fit_options: dict | None = None,
censoring: Literal["right", "left"] = "right",
durations: Iterable | None = None,
event_observed: Iterable | None = None,
) -> KaplanMeierFitter:
"""Fit the Kaplan-Meier estimate for the survival function.

Expand All @@ -135,8 +140,9 @@ def kmf(
https://lifelines.readthedocs.io/en/latest/fitters/univariate/KaplanMeierFitter.html#module-lifelines.fitters.kaplan_meier_fitter

Args:
durations: length n -- duration (relative to subject's birth) the subject was alive for.
event_observed: True if the death was observed, False if the event was lost (right-censored). Defaults to all True if event_observed is equal to `None`.
adata: AnnData object with necessary columns `duration_col` and `event_col`.
duration_col: The name of the column in the AnnData object that contains the subjects’ lifetimes.
event_col: The name of the column in the AnnData object that contains the subjects’ death observation.
timeline: Return the best estimate at the values in `timeline` (positively increasing).
Zethson marked this conversation as resolved.
Show resolved Hide resolved
entry: Relative time when a subject entered the study. This is useful for left-truncated (not left-censored) observations.
If None, all members of the population entered study when they were "born".
Expand All @@ -145,8 +151,12 @@ def kmf(
ci_labels: Add custom column names to the generated confidence intervals as a length-2 list: [<lower-bound name>, <upper-bound name>] (default: <label>_lower_<1-alpha/2>).
weights: If providing a weighted dataset. For example, instead of providing every subject
as a single element of `durations` and `event_observed`, one could weigh subjects differently.
censoring: 'right' for fitting the model to a right-censored dataset.
'left' for fitting the model to a left-censored dataset (default: fit the model to a right-censored dataset).
fit_options: Additional keyword arguments to pass into the estimator.
censoring: 'right' for fitting the model to a right-censored dataset (default, calls `fit`).
'left' for fitting the model to a left-censored dataset (calls `fit_left_censoring`).
durations: length n -- duration (relative to subject's birth) the subject was alive for. (Deprecated legacy argument, use `duration_col` instead.)
event_observed: True if the death was observed, False if the event was lost (right-censored). Defaults to all True if `event_observed` is `None`. (Deprecated legacy argument, use `event_col` instead.)


Returns:
Fitted KaplanMeierFitter.
eroell marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -156,34 +166,58 @@ def kmf(
>>> adata = ep.dt.mimic_2(encoded=False)
>>> # Flip 'censor_flg' because 0 = death and 1 = censored
Zethson marked this conversation as resolved.
Show resolved Hide resolved
>>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
>>> kmf = ep.tl.kmf(adata[:, ["mort_day_censored"]].X, adata[:, ["censor_flg"]].X)
>>> kmf = ep.tl.kmf(adata, "mort_day_censored", "censor_flg", label="Mortality")
"""
kmf = KaplanMeierFitter()
if censoring == "None" or "right":
kmf.fit(
durations=durations,
event_observed=event_observed,
timeline=timeline,
entry=entry,
label=label,
alpha=alpha,
ci_labels=ci_labels,
weights=weights,
# legacy support
if durations is not None:
# legacy warning
warnings.warn(
"The `durations` and `event_observed` arguments are deprecated, please use `duration_col` and `event_col` instead.",
DeprecationWarning,
stacklevel=2,
)
elif censoring == "left":
kmf.fit_left_censoring(
durations=durations,
event_observed=event_observed,
timeline=timeline,
entry=entry,
label=label,
alpha=alpha,
ci_labels=ci_labels,
weights=weights,
kmf = KaplanMeierFitter()
if censoring == "None" or "right":
aGuyLearning marked this conversation as resolved.
Show resolved Hide resolved
kmf.fit(
durations=durations,
event_observed=event_observed,
timeline=timeline,
entry=entry,
label=label,
alpha=alpha,
ci_labels=ci_labels,
weights=weights,
)
elif censoring == "left":
kmf.fit_left_censoring(
durations=durations,
event_observed=event_observed,
timeline=timeline,
entry=entry,
label=label,
alpha=alpha,
ci_labels=ci_labels,
weights=weights,
)

return kmf
else:
return _univariate_model(
adata,
duration_col,
event_col,
KaplanMeierFitter,
True,
timeline,
entry,
label,
alpha,
ci_labels,
weights,
fit_options,
censoring,
)

return kmf


def test_kmf_logrank(
kmf_A: KaplanMeierFitter,
Expand Down Expand Up @@ -376,7 +410,21 @@ def log_logistic_aft(adata: AnnData, duration_col: str, event_col: str, entry_co
)


def _univariate_model(adata: AnnData, duration_col: str, event_col: str, model_class, accept_zero_duration=True):
def _univariate_model(
adata: AnnData,
duration_col: str,
event_col: str,
model_class,
accept_zero_duration=True,
timeline: list[float] | None = None,
entry: str | None = None,
label: str | None = None,
alpha: float | None = None,
ci_labels: list[str] | None = None,
weights: list[float] | None = None,
fit_options: dict | None = None,
censoring: Literal["right", "left"] = "right",
):
"""Convenience function for univariate models."""
df = anndata_to_df(adata)

Expand All @@ -386,12 +434,38 @@ def _univariate_model(adata: AnnData, duration_col: str, event_col: str, model_c
E = df[event_col]

model = model_class()
model.fit(T, event_observed=E)
function_name = "fit" if censoring == "right" else "fit_left_censoring"
aGuyLearning marked this conversation as resolved.
Show resolved Hide resolved
# get fit function, default to fit if not found
fit_function = getattr(model, function_name, model.fit)

fit_function(
T,
event_observed=E,
timeline=timeline,
entry=entry,
label=label,
alpha=alpha,
ci_labels=ci_labels,
weights=weights,
fit_options=fit_options,
)

return model


def nelson_aalen(adata: AnnData, duration_col: str, event_col: str) -> NelsonAalenFitter:
def nelson_aalen(
adata: AnnData,
duration_col: str,
event_col: str,
aGuyLearning marked this conversation as resolved.
Show resolved Hide resolved
timeline: list[float] | None = None,
aGuyLearning marked this conversation as resolved.
Show resolved Hide resolved
entry: str | None = None,
label: str | None = None,
alpha: float | None = None,
ci_labels: list[str] | None = None,
weights: list[float] | None = None,
fit_options: dict | None = None,
censoring: Literal["right", "left"] = "right",
) -> NelsonAalenFitter:
"""Employ the Nelson-Aalen estimator to estimate the cumulative hazard function from censored survival data

The Nelson-Aalen estimator is a non-parametric method used in survival analysis to estimate the cumulative hazard function.
Expand All @@ -404,6 +478,17 @@ def nelson_aalen(adata: AnnData, duration_col: str, event_col: str) -> NelsonAal
duration_col: The name of the column in the AnnData object that contains the subjects’ lifetimes.
event_col: The name of the column in the AnnData object that contains the subjects’ death observation.
If left as None, assume all individuals are uncensored.
timeline: Return the best estimate at the values in `timeline` (positively increasing).
Zethson marked this conversation as resolved.
Show resolved Hide resolved
entry: Relative time when a subject entered the study. This is useful for left-truncated (not left-censored) observations.
If None, all members of the population entered study when they were "born".
label: A string to name the column of the estimate.
alpha: The alpha value in the confidence intervals. Overrides the initializing alpha for this call to fit only.
ci_labels: Add custom column names to the generated confidence intervals as a length-2 list: [<lower-bound name>, <upper-bound name>] (default: <label>_lower_<1-alpha/2>).
weights: If providing a weighted dataset. For example, instead of providing every subject
as a single element of `durations` and `event_observed`, one could weigh subjects differently.
fit_options: Additional keyword arguments to pass into the estimator.
censoring: 'right' for fitting the model to a right-censored dataset (default, calls `fit`).
'left' for fitting the model to a left-censored dataset (calls `fit_left_censoring`).

Returns:
Fitted NelsonAalenFitter.
eroell marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -415,10 +500,36 @@ def nelson_aalen(adata: AnnData, duration_col: str, event_col: str) -> NelsonAal
>>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
>>> naf = ep.tl.nelson_aalen(adata, "mort_day_censored", "censor_flg")
"""
return _univariate_model(adata, duration_col, event_col, NelsonAalenFitter)

return _univariate_model(
adata,
duration_col,
event_col,
NelsonAalenFitter,
True,
timeline=timeline,
entry=entry,
label=label,
alpha=alpha,
ci_labels=ci_labels,
weights=weights,
fit_options=fit_options,
censoring=censoring,
)


def weibull(adata: AnnData, duration_col: str, event_col: str) -> WeibullFitter:
def weibull(
adata: AnnData,
duration_col: str,
event_col: str,
timeline: list[float] | None = None,
entry: str | None = None,
label: str | None = None,
alpha: float | None = None,
ci_labels: list[str] | None = None,
weights: list[float] | None = None,
fit_options: dict | None = None,
) -> WeibullFitter:
"""Employ the Weibull model in univariate survival analysis to understand event occurrence dynamics.

In contrast to the non-parametric Nelson-Aalen estimator, the Weibull model employs a parametric approach with shape and scale parameters,
Expand All @@ -434,6 +545,16 @@ def weibull(adata: AnnData, duration_col: str, event_col: str) -> WeibullFitter:
duration_col: Name of the column in the AnnData object that contains the subjects’ lifetimes.
event_col: Name of the column in the AnnData object that contains the subjects’ death observation.
If left as None, assume all individuals are uncensored.
adata: AnnData object with necessary columns `duration_col` and `event_col`.
timeline: Return the best estimate at the values in `timeline` (positively increasing).
Zethson marked this conversation as resolved.
Show resolved Hide resolved
entry: Relative time when a subject entered the study. This is useful for left-truncated (not left-censored) observations.
If None, all members of the population entered study when they were "born".
label: A string to name the column of the estimate.
alpha: The alpha value in the confidence intervals. Overrides the initializing alpha for this call to fit only.
ci_labels: Add custom column names to the generated confidence intervals as a length-2 list: [<lower-bound name>, <upper-bound name>] (default: <label>_lower_<1-alpha/2>).
weights: If providing a weighted dataset. For example, instead of providing every subject
as a single element of `durations` and `event_observed`, one could weigh subjects differently.
fit_options: Additional keyword arguments to pass into the estimator.

Returns:
Fitted WeibullFitter.
eroell marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -445,4 +566,17 @@ def weibull(adata: AnnData, duration_col: str, event_col: str) -> WeibullFitter:
>>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
>>> wf = ep.tl.weibull(adata, "mort_day_censored", "censor_flg")
"""
return _univariate_model(adata, duration_col, event_col, WeibullFitter, accept_zero_duration=False)
return _univariate_model(
adata,
duration_col,
event_col,
WeibullFitter,
accept_zero_duration=False,
timeline=timeline,
entry=entry,
label=label,
alpha=alpha,
ci_labels=ci_labels,
weights=weights,
fit_options=fit_options,
)
2 changes: 1 addition & 1 deletion tests/tools/test_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _sa_func_test(self, sa_function, sa_class, mimic_2_sa):

def test_kmf(self, mimic_2_sa):
adata, _, _ = mimic_2_sa
aGuyLearning marked this conversation as resolved.
Show resolved Hide resolved
kmf = ep.tl.kmf(adata[:, ["mort_day_censored"]].X, adata[:, ["censor_flg"]].X)
kmf = ep.tl.kmf(adata, "mort_day_censored", "censor_flg")
self._sa_function_assert(kmf, KaplanMeierFitter)

def test_cox_ph(self, mimic_2_sa):
Expand Down
Loading