load model from path/str to generate Card (#205)
Solves #96

Add the possibility to instantiate a model card with a model passed as a str or Path.

The model will be loaded from disk when it is accessed. For now, the model is not
cached, so passing an already loaded model is still recommended.
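A minimal usage sketch of the behavior this commit enables (illustrative only; the file name and toy data are hypothetical):

from pathlib import Path

from sklearn.linear_model import LinearRegression

from skops.card import Card
from skops.io import dump

# Persist a fitted estimator to disk.
model = LinearRegression().fit([[0.0], [1.0], [2.0]], [0.0, 1.0, 2.0])
dump(model, "model.skops")

# New: a str or Path is accepted; the model is read from disk whenever it is
# needed, e.g. for the hyperparameter table or the repr below.
card = Card(model=Path("model.skops"))
print(card)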
p-mishra1 authored Nov 23, 2022
1 parent 7e9ff2f commit 757b940
Showing 3 changed files with 146 additions and 7 deletions.
5 changes: 4 additions & 1 deletion docs/changes.rst
@@ -27,6 +27,8 @@ v0.3
- Use ``huggingface_hub`` v0.10.1 for model cards, drop ``modelcards``
dependency. :pr:`162` by `Benjamin Bossan`_.
- Add source links to API documentation. :pr:`172` by `Ayyuce Demirbas`_.
- Add support for loading the model if a ``Path``/``str`` is given to the ``model``
argument in :mod:`skops.card`. :pr:`205` by `Prajjwal Mishra`_.


v0.2
@@ -61,4 +63,5 @@ Contributors
~~~~~~~~~~~~

:user:`Adrin Jalali <adrinjalali>`, :user:`Merve Noyan <merveenoyan>`,
:user:`Benjamin Bossan <BenjaminBossan>`, :user:`Ayyuce Demirbas <ayyucedemirbas>`
:user:`Benjamin Bossan <BenjaminBossan>`, :user:`Ayyuce Demirbas <ayyucedemirbas>`,
:user:`Prajjwal Mishra <p-mishra1>`
63 changes: 58 additions & 5 deletions skops/card/_model_card.py
@@ -5,16 +5,19 @@
import re
import shutil
import tempfile
import zipfile
from dataclasses import dataclass
from pathlib import Path
from reprlib import Repr
from typing import Any, Optional, Union

import joblib
from huggingface_hub import ModelCard, ModelCardData
from sklearn.utils import estimator_html_repr
from tabulate import tabulate # type: ignore

import skops
from skops.io import load

# Repr attributes can be used to control the behavior of repr
aRepr = Repr()
@@ -162,6 +165,41 @@ def metadata_from_config(config_path: Union[str, Path]) -> ModelCardData:
return card_data


def _load_model(model: Any) -> Any:
"""Loads the mddel if provided a file path, if already a model instance return it
unmodified.
Parameters
----------
model : pathlib.Path, str, or sklearn estimator
Path/str of the saved model, or the actual model instance. If a Path or str is
given, the model is loaded from that file.
Returns
-------
model : object
Model instance.
"""

if not isinstance(model, (Path, str)):
return model

model_path = Path(model)
if not model_path.exists():
raise FileNotFoundError(f"File is not present: {model_path}")

try:
if zipfile.is_zipfile(model_path):
model = load(model_path)
else:
model = joblib.load(model_path)
except Exception as ex:
msg = f'An "{type(ex).__name__}" occured during model loading.'
raise RuntimeError(msg) from ex

return model
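A minimal sketch of how the dispatch above behaves (illustrative only, not part of the diff; the file names are hypothetical):

from sklearn.linear_model import LinearRegression

est = LinearRegression()

# A model instance is returned unmodified, without touching the disk.
assert _load_model(est) is est

# ".skops" files are zip archives, so zipfile.is_zipfile routes them to skops.io.load;
# anything else (e.g. a ".pkl" written with pickle or joblib) falls back to joblib.load.
# est = _load_model(Path("model.skops"))
# est = _load_model("model.pkl")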


class Card:
"""Model card class that will be used to generate model card.
@@ -172,8 +210,9 @@ class Card:
Parameters
----------
model: estimator object
Model that will be documented.
model: pathlib.Path, str, or sklearn estimator object
``Path``/``str`` of the model or the actual model instance that will be
documented. If a ``Path`` or ``str`` is provided, the model will be loaded.
model_diagram: bool, default=True
Set to True if model diagram should be plotted in the card.
@@ -263,6 +302,18 @@ def __init__(
self._extra_sections: list[tuple[str, Any]] = []
self.metadata = metadata or ModelCardData()

def get_model(self) -> Any:
"""Returns sklearn estimator object if ``Path``/``str``
is provided.
Returns
-------
model : Object
Model instance.
"""
model = _load_model(self.model)
return model

def add(self, **kwargs: str) -> "Card":
"""Takes values to fill model card template.
@@ -412,7 +463,9 @@ def _generate_card(self) -> ModelCard:
'clf.predict(pd.DataFrame.from_dict(config["sklearn"]["example_input"]))'
)
if self.model_diagram is True:
model_plot_div = re.sub(r"\n\s+", "", str(estimator_html_repr(self.model)))
model_plot_div = re.sub(
r"\n\s+", "", str(estimator_html_repr(self.get_model()))
)
if model_plot_div.count("sk-top-container") == 1:
model_plot_div = model_plot_div.replace(
"sk-top-container", 'sk-top-container" style="overflow: auto;'
@@ -497,7 +550,7 @@ def _extract_estimator_config(self) -> str:
str:
Markdown table of hyperparameters.
"""
hyperparameter_dict = self.model.get_params(deep=True)
hyperparameter_dict = self.get_model().get_params(deep=True)
return _clean_table(
tabulate(
list(hyperparameter_dict.items()),
@@ -520,7 +573,7 @@ def __repr__(self) -> str:
# create repr for model
model = getattr(self, "model", None)
if model:
model_str = self._strip_blank(repr(model))
model_str = self._strip_blank(repr(self.get_model()))
model_repr = aRepr.repr(f" model={model_str},").strip('"').strip("'")
else:
model_repr = None
85 changes: 84 additions & 1 deletion skops/card/tests/test_card.py
@@ -15,7 +15,7 @@
import skops
from skops import hub_utils
from skops.card import Card, metadata_from_config
from skops.card._model_card import PlotSection, TableSection
from skops.card._model_card import PlotSection, TableSection, _load_model
from skops.io import dump


@@ -26,6 +26,30 @@ def fit_model():
return reg


def save_model_to_file(model_instance, suffix):
save_file_handle, save_file = tempfile.mkstemp(suffix=suffix, prefix="skops-test")
if suffix in (".pkl", ".pickle"):
with open(save_file, "wb") as f:
pickle.dump(model_instance, f)
elif suffix == ".skops":
dump(model_instance, save_file)
return save_file_handle, save_file


@pytest.mark.parametrize("suffix", [".pkl", ".pickle", ".skops"])
def test_load_model(suffix):
model0 = LinearRegression(n_jobs=123)
_, save_file = save_model_to_file(model0, suffix)
loaded_model_str = _load_model(save_file)
save_file_path = Path(save_file)
loaded_model_path = _load_model(save_file_path)
loaded_model_instance = _load_model(model0)

assert loaded_model_str.n_jobs == 123
assert loaded_model_path.n_jobs == 123
assert loaded_model_instance.n_jobs == 123


@pytest.fixture
def model_card(model_diagram=True):
model = fit_model()
@@ -405,6 +429,65 @@ def test_with_metadata(self, card: Card, meth):
assert result == expected


class TestCardModelAttribute:
def path_to_card(self, path):
card = Card(model=path)
card.add(
model_description="A description",
model_card_authors="Jane Doe",
)
card.add_plot(
roc_curve="ROC_curve.png",
confusion_matrix="confusion_matrix.jpg",
)
card.add_table(search_results={"split": [1, 2, 3], "score": [4, 5, 6]})
return card

@pytest.mark.parametrize("meth", [repr, str])
@pytest.mark.parametrize("suffix", [".pkl", ".skops"])
def test_model_card_repr(self, meth, suffix):
model = LinearRegression(fit_intercept=False)
file_handle, file_name = save_model_to_file(model, suffix)
os.close(file_handle)
card_from_path = self.path_to_card(file_name)
result_from_path = meth(card_from_path)
expected = (
"Card(\n"
" model=LinearRegression(fit_intercept=False),\n"
" model_description='A description',\n"
" model_card_authors='Jane Doe',\n"
" roc_curve='ROC_curve.png',\n"
" confusion_matrix='confusion_matrix.jpg',\n"
" search_results=Table(3x2),\n"
")"
)
assert result_from_path == expected

@pytest.mark.parametrize("suffix", [".pkl", ".skops"])
@pytest.mark.parametrize("meth", [repr, str])
def test_load_model_exception(self, meth, suffix):
file_handle, file_name = tempfile.mkstemp(suffix=suffix, prefix="skops-test")

os.close(file_handle)

with pytest.raises(Exception, match="occurred during model loading."):
card = Card(file_name)
meth(card)

@pytest.mark.parametrize("meth", [repr, str])
def test_load_model_file_not_found(self, meth):
file_handle, file_name = tempfile.mkstemp(suffix=".pkl", prefix="skops-test")

os.close(file_handle)
os.remove(file_name)

with pytest.raises(FileNotFoundError) as excinfo:
card = Card(file_name)
meth(card)

assert file_name in str(excinfo.value)


class TestPlotSection:
def test_format_path_is_str(self):
section = PlotSection(alt_text="some title", path="path/plot.png")
