Skip to content

Commit

Permalink
Added --font and --plot_output. Fixed SHAP
Browse files Browse the repository at this point in the history
  • Loading branch information
efriman committed Sep 22, 2023
1 parent 3084f27 commit adf8822
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 66 deletions.
98 changes: 63 additions & 35 deletions PeakPredict/overlap_peaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,20 @@ def parse_args_overlap_peaks():
required=False,
help="""Relative size of plots. Adjust if they don't look right (too many/few features)""",
)
parser.add_argument(
"--plot_output",
type=str,
default="png",
required=False,
help="""Plot file type output (png/pdf/svg)""",
)
parser.add_argument(
"--font",
type=str,
default="DejaVu Sans",
required=False,
help="""Font""",
)

return parser

Expand Down Expand Up @@ -347,6 +361,8 @@ def main():
else:
random_state = None

sns.set(style="ticks", font=args.font)

corr_matrix, predictions, feature_importance, model = predict_features(
input_table,
predict_column=args.predict_column,
Expand All @@ -370,9 +386,9 @@ def main():
mask=nanmask,
figsize=(args.plot_size, args.plot_size),
)
g.savefig(f"{args.outdir}/{args.outname}_corr_features.png", dpi=300)
g.savefig(f"{args.outdir}/{args.outname}_corr_features.{args.plot_output}", dpi=300)
logging.info(
f"Saved predictor correlations as {args.outdir}/{args.outname}_corr_features.png"
f"Saved predictor correlations as {args.outdir}/{args.outname}_corr_features.{args.plot_output}"
)

predictions.to_csv(
Expand All @@ -392,12 +408,12 @@ def main():
plt.tight_layout()
plt.xticks(rotation=90)
plt.savefig(
f"{args.outdir}/{args.outname}_confusion_matrix_{args.predict_column}_{args.model}.png",
f"{args.outdir}/{args.outname}_confusion_matrix_{args.predict_column}_{args.model}.{args.plot_output}",
dpi=300,
bbox_inches="tight",
)
logging.info(
f"Saved confusion matrix as {args.outdir}/{args.outname}_confusion_matrix_{args.predict_column}_{args.model}.png"
f"Saved confusion matrix as {args.outdir}/{args.outname}_confusion_matrix_{args.predict_column}_{args.model}.{args.plot_output}"
)
except ValueError:
warnings.warn(
Expand All @@ -419,67 +435,79 @@ def main():
)
plt.xticks(rotation=45, ha="right")
plt.savefig(
f"{args.outdir}/{args.outname}_feature_importance_{args.model}.png",
f"{args.outdir}/{args.outname}_feature_importance_{args.model}.{args.plot_output}",
dpi=300,
bbox_inches="tight",
)
logging.info(
f"Saved feature importance as {args.outdir}/{args.outname}_feature_importance_{args.model}.tsv and {args.outdir}/{args.outname}_feature_importance_{args.model}.png"
f"Saved feature importance as {args.outdir}/{args.outname}_feature_importance_{args.model}.tsv and {args.outdir}/{args.outname}_feature_importance_{args.model}.{args.plot_output}"
)

if args.shap is not None:
logging.info(f"Calculating SHAP values (can be slow for big datasets)")
logging.info(f"Calculating SHAP values")
if args.shap == ["approximate"]:
approximate=True
logging.info("Using approximate=True for shap_values")
logging.info("using approximate=True")
elif args.shap == []:
approximate = False
else:
approximate = False
logging.info("--shap can only be empty or 'approximate', ignoring")

plot_size = (
None if args.plot_size == 1 else (args.plot_size, args.plot_size)
)
shap_values = shap.Explainer(model).shap_values(
predictions[predictor_columns], approximate=approximate
)
plot_size = None if args.plot_size == 1 else (args.plot_size, args.plot_size)

train_index = [idx for idx in input_table.index if idx not in predictions.index]
X_train = input_table.loc[train_index, predictor_columns]

try:
shap_values = shap.Explainer(model, X_train).shap_values(predictions[predictor_columns],
approximate=approximate)
except TypeError:
try:
shap_values = shap.Explainer(model, X_train).shap_values(predictions[predictor_columns])
if args.shap == ["approximate"]:
logging.info("No approximate SHAP for this kind of model")
except TypeError:
raise ValueError("Can't calculate SHAP for this kind of model in PeakPredict")

plt.figure()
shap.summary_plot(
shap_values,
predictions[predictor_columns],
class_names=model.classes_,
show=False,
plot_type="bar",
plot_size=plot_size,
)
plt.legend(loc=(1.04, 0))
plt.savefig(
f"{args.outdir}/{args.outname}_SHAP_{args.model}.png",
f"{args.outdir}/{args.outname}_SHAP_{args.model}.{args.plot_output}",
dpi=300,
bbox_inches="tight",
)
logging.info(
f"Saved SHAP values as {args.outdir}/{args.outname}_SHAP_{args.model}.png"
f"Saved SHAP values as {args.outdir}/{args.outname}_SHAP_{args.model}.{args.plot_output}"
)
for i in range(len(shap_values)):
name = model.classes_[i]
plt.figure()
shap.summary_plot(
shap_values[i],
predictions[predictor_columns],
show=False,
plot_size=plot_size,
)
plt.title(label=name)
plt.savefig(
f"{args.outdir}/{args.outname}_SHAP_{name}_{args.model}.png",
dpi=300,
bbox_inches="tight",
)
logging.info(
f"Saved SHAP values as {args.outdir}/{args.outname}_SHAP_{name}_{args.model}.png"
)
try:
for i in range(len(shap_values)):
name = model.classes_[i]
plt.figure()
shap.summary_plot(
shap_values[i],
predictions[predictor_columns],
show=False,
plot_size=plot_size,
)
plt.title(label=name)
plt.savefig(
f"{args.outdir}/{args.outname}_SHAP_{name}_{args.model}.{args.plot_output}",
dpi=300,
bbox_inches="tight",
)
logging.info(
f"Saved SHAP values as {args.outdir}/{args.outname}_SHAP_{name}_{args.model}.{args.plot_output}"
)
except:
pass


if __name__ == "__main__":
Expand Down
93 changes: 62 additions & 31 deletions PeakPredict/predict_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,20 @@ def parse_args_predict_features():
required=False,
help="""Relative size of plots. Adjust if they don't look right (too many/few features)""",
)
parser.add_argument(
"--plot_output",
type=str,
default="png",
required=False,
help="""Plot file type output (png/pdf/svg)""",
)
parser.add_argument(
"--font",
type=str,
default="DejaVu Sans",
required=False,
help="""Font""",
)

return parser

Expand Down Expand Up @@ -179,6 +193,8 @@ def main():
.sample(args.maximum_per_category, random_state=random_state)
.reset_index(drop=True)
)

sns.set(style="ticks", font=args.font)

corr_matrix, predictions, feature_importance, model = predict_features(
input_table,
Expand All @@ -203,9 +219,9 @@ def main():
mask=nanmask,
figsize=(args.plot_size, args.plot_size),
)
g.savefig(f"{args.outdir}/{args.outname}_corr_features.png", dpi=300)
g.savefig(f"{args.outdir}/{args.outname}_corr_features.{args.plot_output}", dpi=300)
logging.info(
f"Saved predictor correlations as {args.outdir}/{args.outname}_corr_features.png"
f"Saved predictor correlations as {args.outdir}/{args.outname}_corr_features.{args.plot_output}"
)

predictions.to_csv(
Expand All @@ -224,12 +240,12 @@ def main():
plt.tight_layout()
plt.xticks(rotation=90)
plt.savefig(
f"{args.outdir}/{args.outname}_confusion_matrix_{args.predict_column}_{args.model}.png",
f"{args.outdir}/{args.outname}_confusion_matrix_{args.predict_column}_{args.model}.{args.plot_output}",
dpi=300,
bbox_inches="tight",
)
logging.info(
f"Saved confusion matrix as {args.outdir}/{args.outname}_confusion_matrix_{args.predict_column}_{args.model}.png"
f"Saved confusion matrix as {args.outdir}/{args.outname}_confusion_matrix_{args.predict_column}_{args.model}.{args.plot_output}"
)
except ValueError:
warnings.warn(
Expand All @@ -251,64 +267,79 @@ def main():
)
plt.xticks(rotation=45, ha="right")
plt.savefig(
f"{args.outdir}/{args.outname}_feature_importance_{args.model}.png",
f"{args.outdir}/{args.outname}_feature_importance_{args.model}.{args.plot_output}",
dpi=300,
bbox_inches="tight",
)
logging.info(
f"Saved feature importance as {args.outdir}/{args.outname}_feature_importance_{args.model}.tsv and {args.outdir}/{args.outname}_feature_importance_{args.model}.png"
f"Saved feature importance as {args.outdir}/{args.outname}_feature_importance_{args.model}.tsv and {args.outdir}/{args.outname}_feature_importance_{args.model}.{args.plot_output}"
)

if args.shap is not None:
logging.info(f"Calculating SHAP values (can be slow for big datasets)")
logging.info(f"Calculating SHAP values")
if args.shap == ["approximate"]:
approximate=True
logging.info("Using approximate=True for shap_values")
logging.info("using approximate=True")
elif args.shap == []:
approximate = False
else:
approximate = False
logging.info("--shap can only be empty or 'approximate', ignoring")

plot_size = None if args.plot_size == 1 else (args.plot_size, args.plot_size)
shap_values = shap.Explainer(model).shap_values(predictions[predictor_columns],
approximate=approximate)

train_index = [idx for idx in input_table.index if idx not in predictions.index]
X_train = input_table.loc[train_index, predictor_columns]

try:
shap_values = shap.Explainer(model, X_train).shap_values(predictions[predictor_columns],
approximate=approximate)
except TypeError:
try:
shap_values = shap.Explainer(model, X_train).shap_values(predictions[predictor_columns])
if args.shap == ["approximate"]:
logging.info("No approximate SHAP for this kind of model")
except TypeError:
raise ValueError("Can't calculate SHAP for this kind of model in PeakPredict")

plt.figure()
shap.summary_plot(
shap_values,
predictions[predictor_columns],
class_names=model.classes_,
show=False,
plot_type="bar",
plot_size=plot_size,
)
plt.legend(loc=(1.04, 0))
plt.savefig(
f"{args.outdir}/{args.outname}_SHAP_{args.model}.png",
f"{args.outdir}/{args.outname}_SHAP_{args.model}.{args.plot_output}",
dpi=300,
bbox_inches="tight",
)
logging.info(
f"Saved SHAP values as {args.outdir}/{args.outname}_SHAP_{args.model}.png"
f"Saved SHAP values as {args.outdir}/{args.outname}_SHAP_{args.model}.{args.plot_output}"
)
for i in range(len(shap_values)):
name = model.classes_[i]
plt.figure()
shap.summary_plot(
shap_values[i],
predictions[predictor_columns],
show=False,
plot_size=plot_size,
)
plt.title(label=name)
plt.savefig(
f"{args.outdir}/{args.outname}_SHAP_{name}_{args.model}.png",
dpi=300,
bbox_inches="tight",
)
logging.info(
f"Saved SHAP values as {args.outdir}/{args.outname}_SHAP_{name}_{args.model}.png"
)
try:
for i in range(len(shap_values)):
name = model.classes_[i]
plt.figure()
shap.summary_plot(
shap_values[i],
predictions[predictor_columns],
show=False,
plot_size=plot_size,
)
plt.title(label=name)
plt.savefig(
f"{args.outdir}/{args.outname}_SHAP_{name}_{args.model}.{args.plot_output}",
dpi=300,
bbox_inches="tight",
)
logging.info(
f"Saved SHAP values as {args.outdir}/{args.outname}_SHAP_{name}_{args.model}.{args.plot_output}"
)
except:
pass


if __name__ == "__main__":
Expand Down

0 comments on commit adf8822

Please sign in to comment.