Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

load model from path/str to generate Card #205

Merged
merged 12 commits into from
Nov 23, 2022
5 changes: 4 additions & 1 deletion docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ v0.3
- Use ``huggingface_hub`` v0.10.1 for model cards, drop ``modelcards``
dependency. :pr:`162` by `Benjamin Bossan`_.
- Add source links to API documentation. :pr:`172` by `Ayyuce Demirbas`_.
- Add support to load model if given Path/str to ``model`` argument in
:mod:`skops.card` . :pr:`205` by `prajjwal mishra`_.


v0.2
Expand Down Expand Up @@ -59,4 +61,5 @@ Contributors
~~~~~~~~~~~~

:user:`Adrin Jalali <adrinjalali>`, :user:`Merve Noyan <merveenoyan>`,
:user:`Benjamin Bossan <BenjaminBossan>`, :user:`Ayyuce Demirbas <ayyucedemirbas>`
:user:`Benjamin Bossan <BenjaminBossan>`, :user:`Ayyuce Demirbas <ayyucedemirbas>`,
:user:`Prajjwal Mishra <p-mishra1>`
65 changes: 60 additions & 5 deletions skops/card/_model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@
import re
import shutil
import tempfile
import zipfile
from dataclasses import dataclass
from pathlib import Path
from reprlib import Repr
from typing import Any, Optional, Union

import joblib
from huggingface_hub import ModelCard, ModelCardData
from sklearn.utils import estimator_html_repr
from tabulate import tabulate # type: ignore

import skops
from skops.io import load

# Repr attributes can be used to control the behavior of repr
aRepr = Repr()
Expand Down Expand Up @@ -162,6 +165,41 @@ def metadata_from_config(config_path: Union[str, Path]) -> ModelCardData:
return card_data


def _load_model(model: Any) -> Any:
"""Loads the mddel if provided a file path, if already a model instance return it
unmodified.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
unmodified.
unmodified.


Parameters
----------
model : pathlib.path, str, or sklearn estimator
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
model : pathlib.path, str, or sklearn estimator
model : pathlib.Path, str, or sklearn estimator

Path/str or the actual model instance. if a Path or str, loads the model.

Returns
-------
model : object
Model instance.

"""

if not isinstance(model, (Path, str)):
return model

model_path = Path(model)
if not model_path.exists():
raise FileNotFoundError(f"File is not present: {model_path}")

try:
if zipfile.is_zipfile(model_path):
model = load(model_path)
else:
model = joblib.load(model_path)
except Exception as ex:
msg = f'An "{type(ex).__name__}" occured during model loading.'
raise RuntimeError(msg) from ex

return model


class Card:
"""Model card class that will be used to generate model card.

Expand All @@ -172,8 +210,11 @@ class Card:

Parameters
----------
model: estimator object
Model that will be documented.
model: pathlib.path, str, or sklearn estimator object
``Path``/``str`` of the model or the actual model instance that will be
documented. If a ``Path`` or ``str`` is provided, model will be loaded
on first use.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
documented. If a ``Path`` or ``str`` is provided, model will be loaded
on first use.
documented. If a ``Path`` or ``str`` is provided, model will be loaded.

After the change, it will be loaded each time, so "on first use" is not correct anymore.


Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change


model_diagram: bool, default=True
Set to True if model diagram should be plotted in the card.
Expand Down Expand Up @@ -263,6 +304,18 @@ def __init__(
self._extra_sections: list[tuple[str, Any]] = []
self.metadata = metadata or ModelCardData()

def get_model(self) -> Any:
"""Returns sklearn estimator object if ``Path``/``str``
is provided.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
is provided.
is provided.


Returns
-------
model : Object
Model instance.
"""
model = _load_model(self.model)
return model

def add(self, **kwargs: str) -> "Card":
"""Takes values to fill model card template.

Expand Down Expand Up @@ -412,7 +465,9 @@ def _generate_card(self) -> ModelCard:
'clf.predict(pd.DataFrame.from_dict(config["sklearn"]["example_input"]))'
)
if self.model_diagram is True:
model_plot_div = re.sub(r"\n\s+", "", str(estimator_html_repr(self.model)))
model_plot_div = re.sub(
r"\n\s+", "", str(estimator_html_repr(self.get_model()))
)
if model_plot_div.count("sk-top-container") == 1:
model_plot_div = model_plot_div.replace(
"sk-top-container", 'sk-top-container" style="overflow: auto;'
Expand Down Expand Up @@ -497,7 +552,7 @@ def _extract_estimator_config(self) -> str:
str:
Markdown table of hyperparameters.
"""
hyperparameter_dict = self.model.get_params(deep=True)
hyperparameter_dict = self.get_model().get_params(deep=True)
return _clean_table(
tabulate(
list(hyperparameter_dict.items()),
Expand All @@ -520,7 +575,7 @@ def __repr__(self) -> str:
# create repr for model
model = getattr(self, "model", None)
if model:
model_str = self._strip_blank(repr(model))
model_str = self._strip_blank(repr(self.get_model()))
model_repr = aRepr.repr(f" model={model_str},").strip('"').strip("'")
else:
model_repr = None
Expand Down
104 changes: 103 additions & 1 deletion skops/card/tests/test_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import skops
from skops import hub_utils
from skops.card import Card, metadata_from_config
from skops.card._model_card import PlotSection, TableSection
from skops.card._model_card import PlotSection, TableSection, _load_model
from skops.io import dump


Expand All @@ -26,6 +26,29 @@ def fit_model():
return reg


def save_model_to_file(model_instance, suffix):
p-mishra1 marked this conversation as resolved.
Show resolved Hide resolved
save_file_handle, save_file = tempfile.mkstemp(suffix=suffix, prefix="skops-test")
if suffix in (".pkl", ".pickle"):
with open(save_file, "wb") as f:
pickle.dump(model_instance, f)
elif suffix == ".skops":
dump(model_instance, save_file)
return save_file_handle, save_file


@pytest.mark.parametrize("suffix", [".pkl", ".pickle", ".skops"])
def test_load_model(suffix):
model0 = LinearRegression(n_jobs=123)
_, save_file = save_model_to_file(model0, suffix)
loaded_model_str = _load_model(save_file)
save_file_path = Path(save_file)
loaded_model_path = _load_model(save_file_path)
BenjaminBossan marked this conversation as resolved.
Show resolved Hide resolved

assert loaded_model_str.n_jobs == model0.n_jobs
assert loaded_model_path.n_jobs == model0.n_jobs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert loaded_model_str.n_jobs == model0.n_jobs
assert loaded_model_path.n_jobs == model0.n_jobs
assert loaded_model_str.n_jobs == 123
assert loaded_model_path.n_jobs == 123

This test is a tiny bit more robust this way (in case that model0 is being mutated by _load_model for some reason).

assert loaded_model_path.n_jobs == loaded_model_str.n_jobs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is not needed, right? If the two asserts above succeed, this cannot possibly fail.



@pytest.fixture
def model_card(model_diagram=True):
model = fit_model()
Expand Down Expand Up @@ -405,6 +428,85 @@ def test_with_metadata(self, card: Card, meth):
assert result == expected


class TestCardModelAttribute:
@pytest.fixture
def card(self):
model = LinearRegression(fit_intercept=False)
card = Card(model=model)
card.add(
model_description="A description",
model_card_authors="Jane Doe",
)
card.add_plot(
roc_curve="ROC_curve.png",
confusion_matrix="confusion_matrix.jpg",
)
card.add_table(search_results={"split": [1, 2, 3], "score": [4, 5, 6]})
return card

def path_to_card(self, path):
card = Card(model=path)
card.add(
model_description="A description",
model_card_authors="Jane Doe",
)
card.add_plot(
roc_curve="ROC_curve.png",
confusion_matrix="confusion_matrix.jpg",
)
card.add_table(search_results={"split": [1, 2, 3], "score": [4, 5, 6]})
return card

@pytest.mark.parametrize("meth", [repr, str])
@pytest.mark.parametrize("suffix", [".pkl", ".skops"])
def test_model_card_repr(self, card: Card, meth, suffix):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if this test is strictly necessary. Maybe it is sufficient to show that calling repr or str on a card that has a path/str for model works? The fact that it's the same output is implied by the other tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed it to only repr from card_from _path.

result_from_model = meth(card)

file_handle, file_name = save_model_to_file(card.model, suffix)

os.close(file_handle)

card_from_path = self.path_to_card(file_name)
result_from_path = meth(card_from_path)
print(result_from_path)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove

expected = (
"Card(\n"
" model=LinearRegression(fit_intercept=False),\n"
" model_description='A description',\n"
" model_card_authors='Jane Doe',\n"
" roc_curve='ROC_curve.png',\n"
" confusion_matrix='confusion_matrix.jpg',\n"
" search_results=Table(3x2),\n"
")"
)
assert result_from_model == expected
assert result_from_path == expected

@pytest.mark.parametrize("suffix", [".pkl", ".skops"])
@pytest.mark.parametrize("meth", [repr, str])
def test_load_model_exception(self, meth, suffix):
file_handle, file_name = tempfile.mkstemp(suffix=suffix, prefix="skops-test")

os.close(file_handle)

with pytest.raises(Exception, match="occured during model loading."):
card = Card(file_name)
meth(card)

@pytest.mark.parametrize("meth", [repr, str])
def test_load_model_file_not_found(self, meth):
file_handle, file_name = tempfile.mkstemp(suffix=".pkl", prefix="skops-test")

os.close(file_handle)
os.remove(file_name)

with pytest.raises(FileNotFoundError) as excinfo:
card = Card(file_name)
meth(card)

assert file_name in str(excinfo.value)


class TestPlotSection:
def test_format_path_is_str(self):
section = PlotSection(alt_text="some title", path="path/plot.png")
Expand Down