-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds uncertainty plots which allow uncertainty on the prevalence of groups to be assessed (#8)
Includes all code, tests, docs and examples for the uncertainty plot.
- Loading branch information
Showing
14 changed files
with
306 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Example script: renders the uncertainty plots used in the documentation.
# Each section draws one figure, saves it under ./docs/img/ and shows it.
from lorepy import uncertainty_plot

from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
import pandas as pd

# Load iris dataset and convert to dataframe
iris_obj = load_iris()
iris_df = pd.DataFrame(iris_obj.data, columns=iris_obj.feature_names)

# Replace the integer targets by their species names so the panels get readable titles
iris_df["species"] = [iris_obj.target_names[s] for s in iris_obj.target]

# Default uncertainty plot
uncertainty_plot(data=iris_df, x="sepal width (cm)", y="species", iterations=100)
plt.savefig("./docs/img/uncertainty_default.png", dpi=150)
plt.show()

# Using jackknife instead of resample to assess uncertainty
# mode="jackknife" has to be passed explicitly: the default mode is "resample",
# in which case jackknife_fraction is ignored.
uncertainty_plot(
    data=iris_df,
    x="sepal width (cm)",
    y="species",
    iterations=100,
    mode="jackknife",
    jackknife_fraction=0.8,
)
plt.savefig("./docs/img/uncertainty_jackknife.png", dpi=150)
plt.show()

# Uncertainty plot with custom colors
from matplotlib.colors import ListedColormap

colormap = ListedColormap(["red", "green", "blue"])
uncertainty_plot(
    data=iris_df,
    x="sepal width (cm)",
    y="species",
    iterations=100,
    mode="resample",
    colormap=colormap,
)
plt.savefig("./docs/img/uncertainty_custom_color.png", dpi=150)
plt.show()

# Uncertainty plot with a confounder
uncertainty_plot(
    data=iris_df,
    x="sepal width (cm)",
    y="species",
    iterations=100,
    mode="resample",
    confounders=[("petal width (cm)", 1)],
)
plt.savefig("./docs/img/uncertainty_confounder.png", dpi=150)
plt.show()

# Uncertainty plot with a custom classifier
from sklearn.svm import SVC

# probability=True is required so the SVC exposes predict_proba()
svc = SVC(probability=True)

uncertainty_plot(
    data=iris_df,
    x="sepal width (cm)",
    y="species",
    iterations=100,
    mode="resample",
    clf=svc,
)
plt.savefig("./docs/img/uncertainty_custom_classifier.png", dpi=150)
plt.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from .lorepy import loreplot | ||
from .uncertainty import uncertainty_plot |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
import numpy as np | ||
import pandas as pd | ||
from matplotlib import pyplot as plt | ||
from pandas import DataFrame | ||
from sklearn.linear_model import LogisticRegression | ||
from sklearn.model_selection import train_test_split | ||
from sklearn.utils import resample | ||
|
||
from lorepy.lorepy import _get_area_df, _prepare_data | ||
|
||
|
||
def _get_uncertainty_data(
    x: str,
    X_reg,
    y_reg,
    x_range,
    mode="resample",
    jackknife_fraction: float = 0.8,
    iterations: int = 100,
    confounders=None,
    clf=None,
):
    """
    Refit a classifier on repeated random samples of the data and summarise, per x-value and category, the spread
    of the values returned by _get_area_df.

    :param x: Name of the x feature; used as the id column when reshaping the per-iteration results
    :param X_reg: Features passed to the classifier's fit()
    :param y_reg: Labels passed to the classifier's fit()
    :param x_range: Range of x-values, forwarded to _get_area_df
    :param mode: "resample" draws a bootstrap sample (with replacement) each iteration, "jackknife" keeps a random
        subset of the data each iteration (default = "resample")
    :param jackknife_fraction: Fraction of the data kept per iteration, only used when mode is "jackknife"
    :param iterations: Number of times the classifier is refit, must be at least 1 (default = 100)
    :param confounders: List of tuples with the feature and reference value, forwarded to _get_area_df.
        None (default) is treated as an empty list
    :param clf: Classifier implementing fit() (and whatever _get_area_df requires). If None a new
        LogisticRegression is created for every iteration. A provided classifier is refit in place each iteration.
    :return: DataFrame with the columns x, "variable", "min", "mean", "max", "low_95", "high_95", "low_50"
        and "high_50"
    :raises ValueError: if iterations is smaller than 1
    :raises NotImplementedError: if mode is neither "jackknife" nor "resample"
    """
    # Without at least one iteration there is nothing to aggregate (pd.concat would fail with an opaque message)
    if iterations < 1:
        raise ValueError(f"iterations must be at least 1, got {iterations}")

    # The mode does not change between iterations, so reject unsupported values once up front
    if mode not in ("jackknife", "resample"):
        raise NotImplementedError(
            f"Mode {mode} is unsupported, only jackknife and resample are valid modes"
        )

    # Avoid a shared mutable default argument
    confounders = [] if confounders is None else confounders

    areas = []
    for _ in range(iterations):
        if mode == "jackknife":
            # Keep a random subset (without replacement) of the requested size
            X_keep, _, y_keep, _ = train_test_split(
                X_reg, y_reg, train_size=jackknife_fraction
            )
        else:
            # Bootstrap: sample with replacement
            X_keep, y_keep = resample(X_reg, y_reg, replace=True)

        # NOTE(review): multi_class appears to be deprecated in recent scikit-learn releases
        # - confirm against the supported versions before changing it
        lg = LogisticRegression(multi_class="multinomial") if clf is None else clf
        lg.fit(X_keep, y_keep)
        new_area = _get_area_df(lg, x, x_range, confounders=confounders).reset_index()

        areas.append(new_area)

    # Long format: one row per (iteration, x-value, category); melt() names the new columns "variable" and "value"
    long_df = pd.concat(areas).melt(id_vars=[x]).sort_values(x)

    output = (
        long_df.groupby([x, "variable"])
        .agg(
            min=pd.NamedAgg(column="value", aggfunc="min"),
            mean=pd.NamedAgg(column="value", aggfunc="mean"),
            max=pd.NamedAgg(column="value", aggfunc="max"),
            # Bounds of the central 95% and 50% of the sampled values
            low_95=pd.NamedAgg(column="value", aggfunc=lambda v: np.percentile(v, 2.5)),
            high_95=pd.NamedAgg(
                column="value", aggfunc=lambda v: np.percentile(v, 97.5)
            ),
            low_50=pd.NamedAgg(column="value", aggfunc=lambda v: np.percentile(v, 25)),
            high_50=pd.NamedAgg(column="value", aggfunc=lambda v: np.percentile(v, 75)),
        )
        .reset_index()
    )

    return output
|
||
|
||
def uncertainty_plot(
    data: DataFrame,
    x: str,
    y: str,
    x_range=None,
    mode="resample",
    jackknife_fraction=0.8,
    iterations=100,
    confounders=None,
    colormap=None,
    clf=None,
):
    """
    Code to create a multi-panel plot, one panel for each category, with the prevalence of that category across the
    range of x-values, along with the uncertainty (intervals containing 50% and 95% of the samples are shown)

    :param data: Pandas dataframe with data
    :param x: Needs to be a numerical feature
    :param y: Categorical feature
    :param x_range: Either None (range will be selected automatically) or a tuple with min and max value for the x-axis
    :param mode: Sampling method, either "resample" (bootstrap) or "jackknife" (default = "resample")
    :param jackknife_fraction: Fraction of data to retain for each jackknife sample (default = 0.8), ignored unless mode is "jackknife"
    :param iterations: Number of iterations for resampling or jackknife (default = 100)
    :param confounders: List of tuples with the feature and reference value e.g., [("BMI", 25)] will use a reference of 25 for plots. None (default) means no confounders
    :param colormap: Colormap to use for the plot, default is None in which case matplotlib's "tab10" will be used. The colormap is called with the panel index (0, 1, 2, ...)
    :param clf: Provide a different scikit-learn classifier for the function. Should implement the predict_proba() and fit(). If None a LogisticRegression will be used.
    :return: A tuple containing the figure and axes objects
    """
    # Avoid a shared mutable default argument
    confounders = [] if confounders is None else confounders

    X_reg, y_reg, r = _prepare_data(data, x, y, confounders)

    # Fall back to the range determined by _prepare_data
    if x_range is None:
        x_range = r

    plot_df = _get_uncertainty_data(
        x,
        X_reg,
        y_reg,
        x_range,
        mode=mode,
        jackknife_fraction=jackknife_fraction,
        iterations=iterations,
        confounders=confounders,
        clf=clf,
    )

    categories = plot_df.variable.unique()

    fig, axs = plt.subplots(ncols=len(categories), sharex=True, sharey=True)

    # plt.subplots returns a bare Axes instead of an array when there is only one panel;
    # iterate over a 1-d view so indexing works either way (axs itself is returned unchanged)
    panels = np.atleast_1d(axs)

    cmap = plt.get_cmap("tab10") if colormap is None else colormap

    for idx, category in enumerate(categories):
        cat_df = plot_df[plot_df.variable == category]
        ax = panels[idx]
        color = cmap(idx)

        # Wide, faint band: 95% interval; narrower, darker band: 50% interval
        ax.fill_between(
            cat_df[x], cat_df["low_95"], cat_df["high_95"], alpha=0.1, color=color
        )
        ax.fill_between(
            cat_df[x], cat_df["low_50"], cat_df["high_50"], alpha=0.2, color=color
        )
        ax.plot(cat_df[x], cat_df["mean"], color=color)
        ax.set_title(category)
        ax.set_xlabel(x)

        ax.set_xlim(*x_range)
        ax.set_ylim(0, 1)

    return fig, axs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.