Skip to content

Commit

Permalink
updated variable names and moved test to better file
Browse files Browse the repository at this point in the history
  • Loading branch information
aGuyLearning committed Dec 13, 2024
1 parent 66815c4 commit 76f0f0c
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 34 deletions.
46 changes: 23 additions & 23 deletions ehrapy/plot/_survival_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,76 +340,76 @@ def cox_ph_forestplot(
.. image:: /_static/docstring_previews/coxph_forestplot.png
"""
data = cox_ph.summary
coxph_summary = cox_ph.summary
auc_col = "coef"

if labels is None:
labels = data.index
labels = coxph_summary.index
tval = []
ytick = []
for i in range(len(data)):
if not np.isnan(data[auc_col][i]):
for i in range(len(coxph_summary)):
if not np.isnan(coxph_summary[auc_col][i]):
if (
(isinstance(data[auc_col][i], float))
& (isinstance(data["coef lower 95%"][i], float))
& (isinstance(data["coef upper 95%"][i], float))
(isinstance(coxph_summary[auc_col][i], float))
& (isinstance(coxph_summary["coef lower 95%"][i], float))
& (isinstance(coxph_summary["coef upper 95%"][i], float))
):
tval.append(
[
round(data[auc_col][i], decimal),
round(coxph_summary[auc_col][i], decimal),
(
"("
+ str(round(data["coef lower 95%"][i], decimal))
+ str(round(coxph_summary["coef lower 95%"][i], decimal))
+ ", "
+ str(round(data["coef upper 95%"][i], decimal))
+ str(round(coxph_summary["coef upper 95%"][i], decimal))
+ ")"
),
]
)
else:
tval.append(
[
data[auc_col][i],
("(" + str(data["coef lower 95%"][i]) + ", " + str(data["coef upper 95%"][i]) + ")"),
coxph_summary[auc_col][i],
("(" + str(coxph_summary["coef lower 95%"][i]) + ", " + str(coxph_summary["coef upper 95%"][i]) + ")"),
]
)
ytick.append(i)
else:
tval.append([" ", " "])
ytick.append(i)

maxi = round(((pd.to_numeric(data["coef upper 95%"])).max() + 0.1), 2) # setting x-axis maximum
x_axis_upper_bound = round(((pd.to_numeric(coxph_summary["coef upper 95%"])).max() + 0.1), 2)

mini = round(((pd.to_numeric(data["coef lower 95%"])).min() - 0.1), 1) # setting x-axis minimum
x_axis_lower_bound = round(((pd.to_numeric(coxph_summary["coef lower 95%"])).min() - 0.1), 1)

fig = plt.figure(figsize=fig_size)
gspec = gridspec.GridSpec(1, 6)
plot = plt.subplot(gspec[0, 0:4]) # plot of data
tabl = plt.subplot(gspec[0, 4:]) # table
plot.set_ylim(-1, (len(data))) # spacing out y-axis properly
plot.set_ylim(-1, (len(coxph_summary))) # spacing out y-axis properly

plot.axvline(1, color="gray", zorder=1)
lower_diff = data[auc_col] - data["coef lower 95%"]
upper_diff = data["coef upper 95%"] - data[auc_col]
lower_diff = coxph_summary[auc_col] - coxph_summary["coef lower 95%"]
upper_diff = coxph_summary["coef upper 95%"] - coxph_summary[auc_col]
plot.errorbar(
data[auc_col],
data.index,
coxph_summary[auc_col],
coxph_summary.index,
xerr=[lower_diff, upper_diff],
marker="None",
zorder=2,
ecolor=ecolor,
linewidth=0,
elinewidth=1,
)
plot.scatter(data[auc_col], data.index, c=color, s=(size * 25), marker=marker, zorder=3, edgecolors="None")
plot.scatter(coxph_summary[auc_col], coxph_summary.index, c=color, s=(size * 25), marker=marker, zorder=3, edgecolors="None")
plot.xaxis.set_ticks_position("bottom")
plot.yaxis.set_ticks_position("left")
plot.get_xaxis().set_major_formatter(ticker.ScalarFormatter())
plot.get_xaxis().set_minor_formatter(ticker.NullFormatter())
plot.set_yticks(ytick)
plot.set_xlim([mini, maxi])
plot.set_xticks([mini, 1, maxi])
plot.set_xticklabels([mini, 1, maxi])
plot.set_xlim([x_axis_lower_bound, x_axis_upper_bound])
plot.set_xticks([x_axis_lower_bound, 1, x_axis_upper_bound])
plot.set_xticklabels([x_axis_lower_bound, 1, x_axis_upper_bound])
plot.set_yticklabels(labels)
plot.tick_params(axis="y", labelsize=text_size)
plot.yaxis.set_ticks_position("none")
Expand Down
11 changes: 0 additions & 11 deletions tests/plot/test_catplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,3 @@ def test_catplot_vanilla(adata_mini, check_same_image):
tol=2e-1,
)


def test_coxph_forestplot(mimic_2, check_same_image):
adata_subset = mimic_2[:, ["mort_day_censored", "censor_flg", "gender_num", "afib_flg", "day_icu_intime_num"]]
coxph = ep.tl.cox_ph(adata_subset, duration_col="mort_day_censored", event_col="censor_flg")
fig, ax = ep.pl.cox_ph_forestplot(coxph, fig_size=(12, 3), t_adjuster=0.15, marker="o", size=2, text_size=14)

check_same_image(
fig=fig,
base_path=f"{_TEST_IMAGE_PATH}/coxph_forestplot",
tol=2e-1,
)
17 changes: 17 additions & 0 deletions tests/plot/test_survival_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from pathlib import Path

import ehrapy as ep

CURRENT_DIR = Path(__file__).parent
_TEST_IMAGE_PATH = f"{CURRENT_DIR}/_images"

def test_coxph_forestplot(mimic_2, check_same_image):
adata_subset = mimic_2[:, ["mort_day_censored", "censor_flg", "gender_num", "afib_flg", "day_icu_intime_num"]]
coxph = ep.tl.cox_ph(adata_subset, duration_col="mort_day_censored", event_col="censor_flg")
fig, ax = ep.pl.cox_ph_forestplot(coxph, fig_size=(12, 3), t_adjuster=0.15, marker="o", size=2, text_size=14)

check_same_image(
fig=fig,
base_path=f"{_TEST_IMAGE_PATH}/coxph_forestplot",
tol=2e-1,
)

0 comments on commit 76f0f0c

Please sign in to comment.