Skip to content

Commit

Permalink
Merge pull request #46 from goeckslab/svm_explainer
Browse files Browse the repository at this point in the history
Fix missing predict_proba support for some models
  • Loading branch information
qchiujunhao authored Dec 31, 2024
2 parents 97f33c3 + d0a8c2f commit 6998721
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 13 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
.env
env
.venv
*.log
*.png
*.txt
Expand Down
22 changes: 16 additions & 6 deletions tools/base_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,9 +263,13 @@ def save_html_report(self):
Best Model Plots</div>
<div class="tab" onclick="openTab(event, 'feature')">
Feature Importance</div>
<div class="tab" onclick="openTab(event, 'explainer')">
Explainer
</div>
"""
if self.plots_explainer_html:
html_content += """
"<div class="tab" onclick="openTab(event, 'explainer')">"
Explainer Plots</div>
"""
html_content += f"""
</div>
<div id="summary" class="tab-content">
<h2>Setup Parameters</h2>
Expand Down Expand Up @@ -299,13 +303,19 @@ def save_html_report(self):
<div id="feature" class="tab-content">
{feature_importance_html}
</div>
"""
if self.plots_explainer_html:
html_content += f"""
<div id="explainer" class="tab-content">
{self.plots_explainer_html}
{tree_plots}
</div>
{get_html_closing()}
"""

{get_html_closing()}
"""
else:
html_content += f"""
{get_html_closing()}
"""
with open(os.path.join(
self.output_dir, "comparison_result.html"), "w") as file:
file.write(html_content)
Expand Down
23 changes: 19 additions & 4 deletions tools/pycaret_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from pycaret.classification import ClassificationExperiment

from utils import add_hr_to_html, add_plot_to_html
from utils import add_hr_to_html, add_plot_to_html, predict_proba

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -39,6 +39,16 @@ def save_dashboard(self):

def generate_plots(self):
LOG.info("Generating and saving plots")

if not hasattr(self.best_model, "predict_proba"):
import types
self.best_model.predict_proba = types.MethodType(
predict_proba, self.best_model)
LOG.warning(
f"The model {type(self.best_model).__name__}\
does not support `predict_proba`. \
Applying monkey patch.")

plots = ['confusion_matrix', 'auc', 'threshold', 'pr',
'error', 'class_report', 'learning', 'calibration',
'vc', 'dimension', 'manifold', 'rfe', 'feature',
Expand Down Expand Up @@ -74,9 +84,14 @@ def generate_plots_explainer(self):
X_test = self.exp.X_test_transformed.copy()
y_test = self.exp.y_test_transformed

explainer = ClassifierExplainer(self.best_model, X_test, y_test)
self.expaliner = explainer
plots_explainer_html = ""
try:
explainer = ClassifierExplainer(self.best_model, X_test, y_test)
self.expaliner = explainer
plots_explainer_html = ""
except Exception as e:
LOG.error(f"Error creating explainer: {e}")
self.plots_explainer_html = None
return

try:
fig_importance = explainer.plot_importances()
Expand Down
11 changes: 8 additions & 3 deletions tools/pycaret_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,14 @@ def generate_plots_explainer(self):
X_test = self.exp.X_test_transformed.copy()
y_test = self.exp.y_test_transformed

explainer = RegressionExplainer(self.best_model, X_test, y_test)
self.expaliner = explainer
plots_explainer_html = ""
try:
explainer = RegressionExplainer(self.best_model, X_test, y_test)
self.expaliner = explainer
plots_explainer_html = ""
except Exception as e:
LOG.error(f"Error creating explainer: {e}")
self.plots_explainer_html = None
return

try:
fig_importance = explainer.plot_importances()
Expand Down
7 changes: 7 additions & 0 deletions tools/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import base64
import logging

import numpy as np

logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -155,3 +157,8 @@ def encode_image_to_base64(image_path):
"""Convert an image file to a base64 encoded string."""
with open(image_path, "rb") as img_file:
return base64.b64encode(img_file.read()).decode("utf-8")


def predict_proba(self, X):
pred = self.predict(X)
return np.array([1-pred, pred]).T

0 comments on commit 6998721

Please sign in to comment.