From 76f0f0cb487fb2790adc9aea1b7a093f608ad775 Mon Sep 17 00:00:00 2001 From: Carl Buchholz <32228189+aGuyLearning@users.noreply.github.com> Date: Fri, 13 Dec 2024 10:01:03 +0100 Subject: [PATCH] updated variable names and moved test to better file --- ehrapy/plot/_survival_analysis.py | 46 ++++++++++++++-------------- tests/plot/test_catplot.py | 11 ------- tests/plot/test_survival_analysis.py | 17 ++++++++++ 3 files changed, 40 insertions(+), 34 deletions(-) create mode 100644 tests/plot/test_survival_analysis.py diff --git a/ehrapy/plot/_survival_analysis.py b/ehrapy/plot/_survival_analysis.py index 285b57ac..6dacc124 100644 --- a/ehrapy/plot/_survival_analysis.py +++ b/ehrapy/plot/_survival_analysis.py @@ -340,28 +340,28 @@ def cox_ph_forestplot( .. image:: /_static/docstring_previews/coxph_forestplot.png """ - data = cox_ph.summary + coxph_summary = cox_ph.summary auc_col = "coef" if labels is None: - labels = data.index + labels = coxph_summary.index tval = [] ytick = [] - for i in range(len(data)): - if not np.isnan(data[auc_col][i]): + for i in range(len(coxph_summary)): + if not np.isnan(coxph_summary[auc_col][i]): if ( - (isinstance(data[auc_col][i], float)) - & (isinstance(data["coef lower 95%"][i], float)) - & (isinstance(data["coef upper 95%"][i], float)) + (isinstance(coxph_summary[auc_col][i], float)) + & (isinstance(coxph_summary["coef lower 95%"][i], float)) + & (isinstance(coxph_summary["coef upper 95%"][i], float)) ): tval.append( [ - round(data[auc_col][i], decimal), + round(coxph_summary[auc_col][i], decimal), ( "(" - + str(round(data["coef lower 95%"][i], decimal)) + + str(round(coxph_summary["coef lower 95%"][i], decimal)) + ", " - + str(round(data["coef upper 95%"][i], decimal)) + + str(round(coxph_summary["coef upper 95%"][i], decimal)) + ")" ), ] @@ -369,8 +369,8 @@ def cox_ph_forestplot( else: tval.append( [ - data[auc_col][i], - ("(" + str(data["coef lower 95%"][i]) + ", " + str(data["coef upper 95%"][i]) + ")"), + coxph_summary[auc_col][i], + ("(" + str(coxph_summary["coef lower 95%"][i]) + ", " + str(coxph_summary["coef upper 95%"][i]) + ")"), ] ) ytick.append(i) @@ -378,22 +378,22 @@ def cox_ph_forestplot( tval.append([" ", " "]) ytick.append(i) - maxi = round(((pd.to_numeric(data["coef upper 95%"])).max() + 0.1), 2) # setting x-axis maximum + x_axis_upper_bound = round(((pd.to_numeric(coxph_summary["coef upper 95%"])).max() + 0.1), 2) - mini = round(((pd.to_numeric(data["coef lower 95%"])).min() - 0.1), 1) # setting x-axis minimum + x_axis_lower_bound = round(((pd.to_numeric(coxph_summary["coef lower 95%"])).min() - 0.1), 1) fig = plt.figure(figsize=fig_size) gspec = gridspec.GridSpec(1, 6) plot = plt.subplot(gspec[0, 0:4]) # plot of data tabl = plt.subplot(gspec[0, 4:]) # table - plot.set_ylim(-1, (len(data))) # spacing out y-axis properly + plot.set_ylim(-1, (len(coxph_summary))) # spacing out y-axis properly plot.axvline(1, color="gray", zorder=1) - lower_diff = data[auc_col] - data["coef lower 95%"] - upper_diff = data["coef upper 95%"] - data[auc_col] + lower_diff = coxph_summary[auc_col] - coxph_summary["coef lower 95%"] + upper_diff = coxph_summary["coef upper 95%"] - coxph_summary[auc_col] plot.errorbar( - data[auc_col], - data.index, + coxph_summary[auc_col], + coxph_summary.index, xerr=[lower_diff, upper_diff], marker="None", zorder=2, @@ -401,15 +401,15 @@ def cox_ph_forestplot( linewidth=0, elinewidth=1, ) - plot.scatter(data[auc_col], data.index, c=color, s=(size * 25), marker=marker, zorder=3, edgecolors="None") + plot.scatter(coxph_summary[auc_col], coxph_summary.index, c=color, s=(size * 25), marker=marker, zorder=3, edgecolors="None") plot.xaxis.set_ticks_position("bottom") plot.yaxis.set_ticks_position("left") plot.get_xaxis().set_major_formatter(ticker.ScalarFormatter()) plot.get_xaxis().set_minor_formatter(ticker.NullFormatter()) plot.set_yticks(ytick) - plot.set_xlim([mini, maxi]) - plot.set_xticks([mini, 1, maxi]) - plot.set_xticklabels([mini, 1, maxi]) + plot.set_xlim([x_axis_lower_bound, x_axis_upper_bound]) + plot.set_xticks([x_axis_lower_bound, 1, x_axis_upper_bound]) + plot.set_xticklabels([x_axis_lower_bound, 1, x_axis_upper_bound]) plot.set_yticklabels(labels) plot.tick_params(axis="y", labelsize=text_size) plot.yaxis.set_ticks_position("none") diff --git a/tests/plot/test_catplot.py b/tests/plot/test_catplot.py index 90b2b5f3..8e569928 100644 --- a/tests/plot/test_catplot.py +++ b/tests/plot/test_catplot.py @@ -15,14 +15,3 @@ def test_catplot_vanilla(adata_mini, check_same_image): tol=2e-1, ) - -def test_coxph_forestplot(mimic_2, check_same_image): - adata_subset = mimic_2[:, ["mort_day_censored", "censor_flg", "gender_num", "afib_flg", "day_icu_intime_num"]] - coxph = ep.tl.cox_ph(adata_subset, duration_col="mort_day_censored", event_col="censor_flg") - fig, ax = ep.pl.cox_ph_forestplot(coxph, fig_size=(12, 3), t_adjuster=0.15, marker="o", size=2, text_size=14) - - check_same_image( - fig=fig, - base_path=f"{_TEST_IMAGE_PATH}/coxph_forestplot", - tol=2e-1, - ) diff --git a/tests/plot/test_survival_analysis.py b/tests/plot/test_survival_analysis.py new file mode 100644 index 00000000..982bc7a8 --- /dev/null +++ b/tests/plot/test_survival_analysis.py @@ -0,0 +1,17 @@ +from pathlib import Path + +import ehrapy as ep + +CURRENT_DIR = Path(__file__).parent +_TEST_IMAGE_PATH = f"{CURRENT_DIR}/_images" + +def test_coxph_forestplot(mimic_2, check_same_image): + adata_subset = mimic_2[:, ["mort_day_censored", "censor_flg", "gender_num", "afib_flg", "day_icu_intime_num"]] + coxph = ep.tl.cox_ph(adata_subset, duration_col="mort_day_censored", event_col="censor_flg") + fig, ax = ep.pl.cox_ph_forestplot(coxph, fig_size=(12, 3), t_adjuster=0.15, marker="o", size=2, text_size=14) + + check_same_image( + fig=fig, + base_path=f"{_TEST_IMAGE_PATH}/coxph_forestplot", + tol=2e-1, + ) \ No newline at end of file