From adf882271fde450a599e99afd952c43565ad1421 Mon Sep 17 00:00:00 2001 From: efriman Date: Fri, 22 Sep 2023 11:05:05 +0100 Subject: [PATCH] Added --font and --plot_output. Fixed SHAP --- PeakPredict/overlap_peaks.py | 98 +++++++++++++++++++++------------ PeakPredict/predict_features.py | 93 ++++++++++++++++++++----------- 2 files changed, 125 insertions(+), 66 deletions(-) diff --git a/PeakPredict/overlap_peaks.py b/PeakPredict/overlap_peaks.py index 1a78cbf..70f59fa 100644 --- a/PeakPredict/overlap_peaks.py +++ b/PeakPredict/overlap_peaks.py @@ -169,6 +169,20 @@ def parse_args_overlap_peaks(): required=False, help="""Relative size of plots. Adjust if they don't look right (too many/few features)""", ) + parser.add_argument( + "--plot_output", + type=str, + default="png", + required=False, + help="""Plot file type output (png/pdf/svg)""", + ) + parser.add_argument( + "--font", + type=str, + default="DejaVu Sans", + required=False, + help="""Font""", + ) return parser @@ -347,6 +361,8 @@ def main(): else: random_state = None + sns.set(style="ticks", font=args.font) + corr_matrix, predictions, feature_importance, model = predict_features( input_table, predict_column=args.predict_column, @@ -370,9 +386,9 @@ def main(): mask=nanmask, figsize=(args.plot_size, args.plot_size), ) - g.savefig(f"{args.outdir}/{args.outname}_corr_features.png", dpi=300) + g.savefig(f"{args.outdir}/{args.outname}_corr_features.{args.plot_output}", dpi=300) logging.info( - f"Saved predictor correlations as {args.outdir}/{args.outname}_corr_features.png" + f"Saved predictor correlations as {args.outdir}/{args.outname}_corr_features.{args.plot_output}" ) predictions.to_csv( @@ -392,12 +408,12 @@ def main(): plt.tight_layout() plt.xticks(rotation=90) plt.savefig( - f"{args.outdir}/{args.outname}_confusion_matrix_{args.predict_column}_{args.model}.png", + f"{args.outdir}/{args.outname}_confusion_matrix_{args.predict_column}_{args.model}.{args.plot_output}", dpi=300, bbox_inches="tight", ) logging.info( - f"Saved confusion matrix as {args.outdir}/{args.outname}_confusion_matrix_{args.predict_column}_{args.model}.png" + f"Saved confusion matrix as {args.outdir}/{args.outname}_confusion_matrix_{args.predict_column}_{args.model}.{args.plot_output}" ) except ValueError: warnings.warn( @@ -419,67 +435,79 @@ def main(): ) plt.xticks(rotation=45, ha="right") plt.savefig( - f"{args.outdir}/{args.outname}_feature_importance_{args.model}.png", + f"{args.outdir}/{args.outname}_feature_importance_{args.model}.{args.plot_output}", dpi=300, bbox_inches="tight", ) logging.info( - f"Saved feature importance as {args.outdir}/{args.outname}_feature_importance_{args.model}.tsv and {args.outdir}/{args.outname}_feature_importance_{args.model}.png" + f"Saved feature importance as {args.outdir}/{args.outname}_feature_importance_{args.model}.tsv and {args.outdir}/{args.outname}_feature_importance_{args.model}.{args.plot_output}" ) if args.shap is not None: - logging.info(f"Calculating SHAP values (can be slow for big datasets)") + logging.info(f"Calculating SHAP values") if args.shap == ["approximate"]: approximate=True - logging.info("Using approximate=True for shap_values") + logging.info("using approximate=True") elif args.shap == []: approximate = False else: approximate = False logging.info("--shap can only be empty or 'approximate', ignoring") - plot_size = ( - None if args.plot_size == 1 else (args.plot_size, args.plot_size) - ) - shap_values = shap.Explainer(model).shap_values( - predictions[predictor_columns], approximate=approximate - ) + plot_size = None if args.plot_size == 1 else (args.plot_size, args.plot_size) + + train_index = [idx for idx in input_table.index if idx not in predictions.index] + X_train = input_table.loc[train_index, predictor_columns] + + try: + shap_values = shap.Explainer(model, X_train).shap_values(predictions[predictor_columns], + approximate=approximate) + except TypeError: + try: + shap_values = shap.Explainer(model, X_train).shap_values(predictions[predictor_columns]) + if args.shap == ["approximate"]: + logging.info("No approximate SHAP for this kind of model") + except TypeError: + raise ValueError("Can't calculate SHAP for this kind of model in PeakPredict") + plt.figure() shap.summary_plot( shap_values, predictions[predictor_columns], class_names=model.classes_, show=False, - plot_type="bar", plot_size=plot_size, ) plt.legend(loc=(1.04, 0)) plt.savefig( - f"{args.outdir}/{args.outname}_SHAP_{args.model}.png", + f"{args.outdir}/{args.outname}_SHAP_{args.model}.{args.plot_output}", dpi=300, bbox_inches="tight", ) logging.info( - f"Saved SHAP values as {args.outdir}/{args.outname}_SHAP_{args.model}.png" + f"Saved SHAP values as {args.outdir}/{args.outname}_SHAP_{args.model}.{args.plot_output}" ) - for i in range(len(shap_values)): - name = model.classes_[i] - plt.figure() - shap.summary_plot( - shap_values[i], - predictions[predictor_columns], - show=False, - plot_size=plot_size, - ) - plt.title(label=name) - plt.savefig( - f"{args.outdir}/{args.outname}_SHAP_{name}_{args.model}.png", - dpi=300, - bbox_inches="tight", - ) - logging.info( - f"Saved SHAP values as {args.outdir}/{args.outname}_SHAP_{name}_{args.model}.png" - ) + try: + for i in range(len(shap_values)): + name = model.classes_[i] + plt.figure() + shap.summary_plot( + shap_values[i], + predictions[predictor_columns], + show=False, + plot_size=plot_size, + ) + plt.title(label=name) + plt.savefig( + f"{args.outdir}/{args.outname}_SHAP_{name}_{args.model}.{args.plot_output}", + dpi=300, + bbox_inches="tight", + ) + logging.info( + f"Saved SHAP values as {args.outdir}/{args.outname}_SHAP_{name}_{args.model}.{args.plot_output}" + ) + except: + pass if __name__ == "__main__": diff --git a/PeakPredict/predict_features.py b/PeakPredict/predict_features.py index ea0ff6a..9e348f9 100644 --- a/PeakPredict/predict_features.py +++ b/PeakPredict/predict_features.py @@ -121,6 +121,20 @@ def parse_args_predict_features(): required=False, help="""Relative size of plots. Adjust if they don't look right (too many/few features)""", ) + parser.add_argument( + "--plot_output", + type=str, + default="png", + required=False, + help="""Plot file type output (png/pdf/svg)""", + ) + parser.add_argument( + "--font", + type=str, + default="DejaVu Sans", + required=False, + help="""Font""", + ) return parser @@ -179,6 +193,8 @@ def main(): .sample(args.maximum_per_category, random_state=random_state) .reset_index(drop=True) ) + + sns.set(style="ticks", font=args.font) corr_matrix, predictions, feature_importance, model = predict_features( input_table, @@ -203,9 +219,9 @@ def main(): mask=nanmask, figsize=(args.plot_size, args.plot_size), ) - g.savefig(f"{args.outdir}/{args.outname}_corr_features.png", dpi=300) + g.savefig(f"{args.outdir}/{args.outname}_corr_features.{args.plot_output}", dpi=300) logging.info( - f"Saved predictor correlations as {args.outdir}/{args.outname}_corr_features.png" + f"Saved predictor correlations as {args.outdir}/{args.outname}_corr_features.{args.plot_output}" ) predictions.to_csv( @@ -224,12 +240,12 @@ def main(): plt.tight_layout() plt.xticks(rotation=90) plt.savefig( - f"{args.outdir}/{args.outname}_confusion_matrix_{args.predict_column}_{args.model}.png", + f"{args.outdir}/{args.outname}_confusion_matrix_{args.predict_column}_{args.model}.{args.plot_output}", dpi=300, bbox_inches="tight", ) logging.info( - f"Saved confusion matrix as {args.outdir}/{args.outname}_confusion_matrix_{args.predict_column}_{args.model}.png" + f"Saved confusion matrix as {args.outdir}/{args.outname}_confusion_matrix_{args.predict_column}_{args.model}.{args.plot_output}" ) except ValueError: warnings.warn( @@ -251,19 +267,19 @@ def main(): ) plt.xticks(rotation=45, ha="right") plt.savefig( - f"{args.outdir}/{args.outname}_feature_importance_{args.model}.png", + f"{args.outdir}/{args.outname}_feature_importance_{args.model}.{args.plot_output}", dpi=300, bbox_inches="tight", ) logging.info( - f"Saved feature importance as {args.outdir}/{args.outname}_feature_importance_{args.model}.tsv and {args.outdir}/{args.outname}_feature_importance_{args.model}.png" + f"Saved feature importance as {args.outdir}/{args.outname}_feature_importance_{args.model}.tsv and {args.outdir}/{args.outname}_feature_importance_{args.model}.{args.plot_output}" ) if args.shap is not None: - logging.info(f"Calculating SHAP values (can be slow for big datasets)") + logging.info(f"Calculating SHAP values") if args.shap == ["approximate"]: approximate=True - logging.info("Using approximate=True for shap_values") + logging.info("using approximate=True") elif args.shap == []: approximate = False else: @@ -271,44 +287,59 @@ def main(): logging.info("--shap can only be empty or 'approximate', ignoring") plot_size = None if args.plot_size == 1 else (args.plot_size, args.plot_size) - shap_values = shap.Explainer(model).shap_values(predictions[predictor_columns], - approximate=approximate) + + train_index = [idx for idx in input_table.index if idx not in predictions.index] + X_train = input_table.loc[train_index, predictor_columns] + + try: + shap_values = shap.Explainer(model, X_train).shap_values(predictions[predictor_columns], + approximate=approximate) + except TypeError: + try: + shap_values = shap.Explainer(model, X_train).shap_values(predictions[predictor_columns]) + if args.shap == ["approximate"]: + logging.info("No approximate SHAP for this kind of model") + except TypeError: + raise ValueError("Can't calculate SHAP for this kind of model in PeakPredict") + plt.figure() shap.summary_plot( shap_values, predictions[predictor_columns], class_names=model.classes_, show=False, - plot_type="bar", plot_size=plot_size, ) plt.legend(loc=(1.04, 0)) plt.savefig( - f"{args.outdir}/{args.outname}_SHAP_{args.model}.png", + f"{args.outdir}/{args.outname}_SHAP_{args.model}.{args.plot_output}", dpi=300, bbox_inches="tight", ) logging.info( - f"Saved SHAP values as {args.outdir}/{args.outname}_SHAP_{args.model}.png" + f"Saved SHAP values as {args.outdir}/{args.outname}_SHAP_{args.model}.{args.plot_output}" ) - for i in range(len(shap_values)): - name = model.classes_[i] - plt.figure() - shap.summary_plot( - shap_values[i], - predictions[predictor_columns], - show=False, - plot_size=plot_size, - ) - plt.title(label=name) - plt.savefig( - f"{args.outdir}/{args.outname}_SHAP_{name}_{args.model}.png", - dpi=300, - bbox_inches="tight", - ) - logging.info( - f"Saved SHAP values as {args.outdir}/{args.outname}_SHAP_{name}_{args.model}.png" - ) + try: + for i in range(len(shap_values)): + name = model.classes_[i] + plt.figure() + shap.summary_plot( + shap_values[i], + predictions[predictor_columns], + show=False, + plot_size=plot_size, + ) + plt.title(label=name) + plt.savefig( + f"{args.outdir}/{args.outname}_SHAP_{name}_{args.model}.{args.plot_output}", + dpi=300, + bbox_inches="tight", + ) + logging.info( + f"Saved SHAP values as {args.outdir}/{args.outname}_SHAP_{name}_{args.model}.{args.plot_output}" + ) + except: + pass if __name__ == "__main__":