Skip to content

Commit

Permalink
ENH permutation importance (#142)
Browse files Browse the repository at this point in the history
Add a new method add_permutation_importances to the Card class that adds a plot
of sklearn.inspection.permutation_importance to the model card.

Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 30, 2023
1 parent 9407ba7 commit 81558aa
Show file tree
Hide file tree
Showing 9 changed files with 240 additions and 8 deletions.
5 changes: 4 additions & 1 deletion docs/model_card.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ plots, save them on disk and then add them to the card by passing the path name
to the :meth:`.Card.add_plot` method. For tables, you can pass either
dictionaries with the key being the header and the values being list of row
entries, or a pandas ``DataFrame``; use the :meth:`.Card.add_table` method for
this.
this. If you would like to add permutation importance results, pass the output
of :func:`sklearn.inspection.permutation_importance` to
:meth:`.Card.add_permutation_importances`; it creates a boxplot of the
importances, saves it to disk and adds it to the model card for you. If you
want to have multiple importance plots, pass a different file name and title
for each plot.

To add content to an existing subsection, or create a new subsection, use a
``"/"`` to indicate the subsection. E.g. let's assume you would like to add a
Expand Down
9 changes: 9 additions & 0 deletions examples/plot_model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.experimental import enable_halving_search_cv # noqa
from sklearn.inspection import permutation_importance
from sklearn.metrics import (
ConfusionMatrixDisplay,
accuracy_score,
Expand Down Expand Up @@ -153,6 +154,14 @@
**{"Model description/Evaluation Results/Confusion Matrix": "confusion_matrix.png"}
)

importances = permutation_importance(model, X_test, y_test, n_repeats=10)
model_card.add_permutation_importances(
importances,
X_test.columns,
plot_file=Path(local_repo) / "importance.png",
plot_name="Permutation Importance",
)

cv_results = model.cv_results_
clf_report = classification_report(
y_test, y_pred, output_dict=True, target_names=["malignant", "benign"]
Expand Down
58 changes: 56 additions & 2 deletions skops/card/_model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from skops.card._templates import CONTENT_PLACEHOLDER, SKOPS_TEMPLATE, Templates
from skops.io import load
from skops.utils.importutils import import_or_raise

# Repr attributes can be used to control the behavior of repr
aRepr = Repr()
Expand Down Expand Up @@ -206,7 +207,7 @@ def split_subsection_names(key: str) -> list[str]:


def _getting_started_code(
file_name: str, model_format: Literal["pickle", "skops"], indent=" "
file_name: str, model_format: Literal["pickle", "skops"], indent: str = " "
) -> list[str]:
# get lines of code required to load the model
lines = [
Expand Down Expand Up @@ -1085,11 +1086,64 @@ def add_metrics(
"You can find the details about evaluation process and "
"the evaluation results."
)

self._metrics.update(kwargs)
self._add_metrics(section, self._metrics, description=description)
return self

def add_permutation_importances(
    self,
    permutation_importances,
    columns: Sequence[str],
    plot_file: str = "permutation_importances.png",
    plot_name: str = "Permutation Importances",
    overwrite: bool = False,
) -> "Card":
    """Plot permutation importances and add the plot to the model card.

    A horizontal boxplot of the importances, sorted by their mean, is saved
    to ``plot_file`` and then added to the card under ``plot_name``.

    Parameters
    ----------
    permutation_importances : sklearn.utils.Bunch
        Output of :func:`sklearn.inspection.permutation_importance`.

    columns : sequence of str (e.g. list or pandas.Index)
        Column names of the data used to generate importances.

    plot_file : str or pathlib.Path
        Filename for the plot.

    plot_name : str
        Name of the plot.

    overwrite : bool (default=False)
        Whether to overwrite the permutation importance plot file, if a plot
        by that name already exists.

    Returns
    -------
    self : object
        Card object.

    Raises
    ------
    ValueError
        If ``plot_file`` already exists and ``overwrite`` is not set.
    """
    plt = import_or_raise("matplotlib.pyplot", "permutation importance")

    if Path(plot_file).exists() and not overwrite:
        raise ValueError(
            f"{str(plot_file)} already exists. Set `overwrite` to `True` or pass a"
            " different filename for the plot."
        )

    sorted_importances_idx = permutation_importances.importances_mean.argsort()
    # Look the names up one by one: indexing with the whole index array only
    # works for array-likes (ndarray, pandas.Index), not for plain lists.
    sorted_labels = [columns[idx] for idx in sorted_importances_idx]

    fig, ax = plt.subplots()
    try:
        ax.boxplot(
            x=permutation_importances.importances[sorted_importances_idx].T,
            labels=sorted_labels,
            vert=False,
        )
        ax.set_title(plot_name)
        ax.set_xlabel("Decrease in Score")
        # save this specific figure instead of whatever pyplot considers current
        fig.savefig(plot_file)
    finally:
        # pyplot keeps every figure it creates alive until it is closed
        plt.close(fig)

    self.add_plot(**{plot_name: plot_file})

    return self

def _add_metrics(
self,
section: str,
Expand Down
97 changes: 96 additions & 1 deletion skops/card/tests/test_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import textwrap
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pytest
import sklearn
from huggingface_hub import CardData, metadata_load
from sklearn.datasets import load_iris
from sklearn.inspection import permutation_importance
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.metrics import f1_score, make_scorer
from sklearn.neighbors import KNeighborsClassifier

from skops import hub_utils
Expand Down Expand Up @@ -403,6 +404,96 @@ def test_add_twice(self, model_card):
assert text1 == text2


def test_permutation_importances(
    iris_estimator, iris_data, model_card, destination_path
):
    """A permutation importance plot shows up as an image in the rendered card."""
    features, target = iris_data
    plot_path = Path(destination_path) / "importance.png"
    importances = permutation_importance(
        iris_estimator, features, target, n_repeats=10, random_state=42, n_jobs=2
    )

    # plot file and plot name are passed positionally here
    model_card.add_permutation_importances(
        importances, features.columns, plot_path, "Permutation Importance"
    )

    rendered = model_card.render()
    assert f"![Permutation Importance]({plot_path}" in rendered


def test_multiple_permutation_importances(
    iris_estimator, iris_data, model_card, destination_path
):
    """Two importance plots with different files and names both get rendered."""
    features, target = iris_data
    default_plot = Path(destination_path) / "importance.png"
    f1_plot = Path(destination_path) / "f1_importance.png"
    shared_kwargs = dict(n_repeats=10, random_state=42, n_jobs=2)

    # first plot: default scoring, default plot name
    default_importances = permutation_importance(
        iris_estimator, features, target, **shared_kwargs
    )
    model_card.add_permutation_importances(
        default_importances, features.columns, plot_file=default_plot
    )

    # second plot: micro-averaged f1 scoring with a custom plot name
    f1_scorer = make_scorer(f1_score, average="micro")
    f1_importances = permutation_importance(
        iris_estimator, features, target, scoring=f1_scorer, **shared_kwargs
    )
    model_card.add_permutation_importances(
        f1_importances,
        features.columns,
        plot_file=f1_plot,
        plot_name="Permutation Importance on f1",
    )

    # check for default one
    assert f"![Permutation Importances]({default_plot}" in model_card.render()
    # check for F1
    assert f"![Permutation Importance on f1]({f1_plot}" in model_card.render()


def test_duplicate_permutation_importances(
    iris_estimator, iris_data, model_card, destination_path
):
    """Writing a second plot to an existing file without overwrite raises."""
    features, target = iris_data
    importances = permutation_importance(
        iris_estimator, features, target, n_repeats=10, random_state=42, n_jobs=2
    )
    plot_path = os.path.join(destination_path, "importance.png")
    expected_error = (
        "already exists. Set `overwrite` to `True` or pass a"
        " different filename for the plot."
    )

    # the first call creates the file ...
    model_card.add_permutation_importances(
        importances, features.columns, plot_file=plot_path
    )

    # ... so the second call with the same file name has to fail
    with pytest.raises(ValueError, match=expected_error):
        model_card.add_permutation_importances(
            importances,
            features.columns,
            plot_file=plot_path,
            plot_name="Permutation Importance on f1",
        )


def test_duplicate_permutation_importances_overwrite(
    iris_estimator, iris_data, model_card, destination_path
):
    """With ``overwrite=True`` an existing plot file can be reused."""
    features, target = iris_data
    importances = permutation_importance(
        iris_estimator, features, target, n_repeats=10, random_state=42, n_jobs=2
    )
    plot_path = os.path.join(destination_path, "importance.png")

    # the first call creates the file
    model_card.add_permutation_importances(
        importances, features.columns, plot_file=plot_path
    )

    # same file again, this time explicitly allowing the overwrite
    second_call_kwargs = {
        "plot_file": plot_path,
        "plot_name": "Permutation Importance on f1",
        "overwrite": True,
    }
    model_card.add_permutation_importances(
        importances, features.columns, **second_call_kwargs
    )

    rendered = model_card.render()
    assert f"![Permutation Importance on f1]({plot_path}" in rendered


class TestAddGetStartedCode:
"""Tests for getting started code"""

Expand Down Expand Up @@ -856,13 +947,17 @@ def test_delete_empty_key_subsection_raises(self, model_card):

class TestAddPlot:
def test_add_plot(self, destination_path, model_card):
import matplotlib.pyplot as plt

plt.plot([4, 5, 6, 7])
plt.savefig(Path(destination_path) / "fig1.png")
model_card = model_card.add_plot(fig1="fig1.png")
plot_content = model_card.select("fig1").content.format()
assert plot_content == "![fig1](fig1.png)"

def test_add_plot_to_existing_section(self, destination_path, model_card):
import matplotlib.pyplot as plt

plt.plot([4, 5, 6, 7])
plt.savefig(Path(destination_path) / "fig1.png")
model_card = model_card.add_plot(**{"Model description/Figure 1": "fig1.png"})
Expand Down
1 change: 0 additions & 1 deletion skops/card/tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def test_example_model_cards(tmp_path, file_name):
path = Path(os.getcwd()) / "skops" / "card" / "tests" / "examples"
file0 = path / file_name
diff = (path / file_name).with_suffix(".md.diff")

parsed_card = parse_modelcard(file0)
file1 = tmp_path / "readme-parsed.md"
parsed_card.save(file1)
Expand Down
29 changes: 28 additions & 1 deletion skops/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import builtins
from unittest.mock import patch

import pytest
Expand All @@ -7,7 +8,8 @@
def pandas_not_installed():
# patch import so that it raises an ImportError when trying to import
# pandas. This works because pandas is only imported lazily.
orig_import = __import__

orig_import = builtins.__import__

def mock_import(name, *args, **kwargs):
if name == "pandas":
Expand All @@ -16,3 +18,28 @@ def mock_import(name, *args, **kwargs):

with patch("builtins.__import__", side_effect=mock_import):
yield


@pytest.fixture
def matplotlib_not_installed():
    """Make importing matplotlib fail while the test using this fixture runs.

    Instead of deleting cached modules and mocking ``builtins.__import__``,
    the relevant ``sys.modules`` entries are temporarily set to ``None``. The
    import system raises ``ModuleNotFoundError`` (an ``ImportError``) for such
    entries, both for ``import`` statements and ``importlib.import_module``.
    """
    import sys

    # Only the matplotlib package itself and its submodules are blocked;
    # unrelated packages whose name merely starts with "matplotlib" (e.g.
    # matplotlib_inline) are left alone. Cached submodules are blocked as well
    # so that a previously imported ``matplotlib.pyplot`` is not served from
    # the module cache.
    blocked = {
        name: None
        for name in list(sys.modules)
        if name == "matplotlib" or name.startswith("matplotlib.")
    }
    blocked["matplotlib"] = None

    # patch.dict puts the original sys.modules content back on exit, so modules
    # imported before the test stay the very same objects afterwards and no
    # re-import is needed during teardown.
    with patch.dict(sys.modules, blocked):
        yield
4 changes: 2 additions & 2 deletions skops/hub_utils/_hf_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def _create_config(
"text-regression",
],
data,
model_format: Literal[ # type: ignore
model_format: Literal[
"skops",
"pickle",
"auto",
Expand Down Expand Up @@ -337,7 +337,7 @@ def init(
"text-regression",
],
data,
model_format: Literal[ # type: ignore
model_format: Literal[
"skops",
"pickle",
"auto",
Expand Down
29 changes: 29 additions & 0 deletions skops/utils/importutils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from importlib import import_module


def import_or_raise(module, feature_name):
    """Import a module and return it, or raise an error with an install hint.

    Parameters
    ----------
    module: str
        Name of the module, may be a dotted path such as ``"package.sub"``.

    feature_name: str
        Name of the feature module is required for; it is used in the error
        message.

    Returns
    -------
    module
        The imported module.

    Raises
    ------
    ModuleNotFoundError
        Is raised if importing the given module fails with an ImportError.
    """
    try:
        return import_module(module)
    except ImportError as import_error:
        # only the top level package name is mentioned in the message
        package, _, _ = module.partition(".")
        message = (
            f"{feature_name.capitalize()} requires {package} to be installed. In order"
            f" to use {feature_name}, you need to install the package in your current"
            " python environment."
        )
        raise ModuleNotFoundError(message) from import_error
16 changes: 16 additions & 0 deletions skops/utils/tests/test_importutils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest

from skops.utils.importutils import import_or_raise


@pytest.mark.usefixtures("matplotlib_not_installed")
def test_import_or_raise():
    """A module that cannot be imported results in a helpful error message."""
    expected_message = (
        "Permutation importance requires matplotlib to be installed. In order"
        " to use permutation importance, you need to install the package in"
        " your current python environment."
    )
    with pytest.raises(ModuleNotFoundError, match=expected_message):
        import_or_raise("matplotlib", "permutation importance")

0 comments on commit 81558aa

Please sign in to comment.