-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
load model from path/str to generate Card #205
Changes from 10 commits
842a1e1
e644e5c
0f447e3
88c81c1
e87f00e
626efcc
169c317
4181bcc
8e8ff22
ca506e8
eaf1989
5cb59a3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -5,16 +5,19 @@ | |||||||
import re | ||||||||
import shutil | ||||||||
import tempfile | ||||||||
import zipfile | ||||||||
from dataclasses import dataclass | ||||||||
from pathlib import Path | ||||||||
from reprlib import Repr | ||||||||
from typing import Any, Optional, Union | ||||||||
|
||||||||
import joblib | ||||||||
from huggingface_hub import ModelCard, ModelCardData | ||||||||
from sklearn.utils import estimator_html_repr | ||||||||
from tabulate import tabulate # type: ignore | ||||||||
|
||||||||
import skops | ||||||||
from skops.io import load | ||||||||
|
||||||||
# Repr attributes can be used to control the behavior of repr | ||||||||
aRepr = Repr() | ||||||||
|
@@ -162,6 +165,41 @@ def metadata_from_config(config_path: Union[str, Path]) -> ModelCardData: | |||||||
return card_data | ||||||||
|
||||||||
|
||||||||
def _load_model(model: Any) -> Any: | ||||||||
"""Loads the model if provided a file path; if already a model instance, return it | ||||||||
unmodified. | ||||||||
|
||||||||
Parameters | ||||||||
---------- | ||||||||
model : pathlib.Path, str, or sklearn estimator | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
Path/str or the actual model instance. If a Path or str, loads the model. | ||||||||
|
||||||||
Returns | ||||||||
------- | ||||||||
model : object | ||||||||
Model instance. | ||||||||
|
||||||||
""" | ||||||||
|
||||||||
if not isinstance(model, (Path, str)): | ||||||||
return model | ||||||||
|
||||||||
model_path = Path(model) | ||||||||
if not model_path.exists(): | ||||||||
raise FileNotFoundError(f"File is not present: {model_path}") | ||||||||
|
||||||||
try: | ||||||||
if zipfile.is_zipfile(model_path): | ||||||||
model = load(model_path) | ||||||||
else: | ||||||||
model = joblib.load(model_path) | ||||||||
except Exception as ex: | ||||||||
msg = f'An "{type(ex).__name__}" occured during model loading.' | ||||||||
raise RuntimeError(msg) from ex | ||||||||
|
||||||||
return model | ||||||||
|
||||||||
|
||||||||
class Card: | ||||||||
"""Model card class that will be used to generate model card. | ||||||||
|
||||||||
|
@@ -172,8 +210,11 @@ class Card: | |||||||
|
||||||||
Parameters | ||||||||
---------- | ||||||||
model: estimator object | ||||||||
Model that will be documented. | ||||||||
model: pathlib.Path, str, or sklearn estimator object | ||||||||
``Path``/``str`` of the model or the actual model instance that will be | ||||||||
documented. If a ``Path`` or ``str`` is provided, the model will be loaded | ||||||||
from that file each time it is needed. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
After the change, it will be loaded each time, so "on first use" is not correct anymore. |
||||||||
|
||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
||||||||
model_diagram: bool, default=True | ||||||||
Set to True if model diagram should be plotted in the card. | ||||||||
|
@@ -263,6 +304,18 @@ def __init__( | |||||||
self._extra_sections: list[tuple[str, Any]] = [] | ||||||||
self.metadata = metadata or ModelCardData() | ||||||||
|
||||||||
def get_model(self) -> Any: | ||||||||
"""Returns the sklearn estimator object, loading it from file first if a | ||||||||
``Path``/``str`` was provided. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
||||||||
Returns | ||||||||
------- | ||||||||
model : object | ||||||||
Model instance. | ||||||||
""" | ||||||||
model = _load_model(self.model) | ||||||||
return model | ||||||||
|
||||||||
def add(self, **kwargs: str) -> "Card": | ||||||||
"""Takes values to fill model card template. | ||||||||
|
||||||||
|
@@ -412,7 +465,9 @@ def _generate_card(self) -> ModelCard: | |||||||
'clf.predict(pd.DataFrame.from_dict(config["sklearn"]["example_input"]))' | ||||||||
) | ||||||||
if self.model_diagram is True: | ||||||||
model_plot_div = re.sub(r"\n\s+", "", str(estimator_html_repr(self.model))) | ||||||||
model_plot_div = re.sub( | ||||||||
r"\n\s+", "", str(estimator_html_repr(self.get_model())) | ||||||||
) | ||||||||
if model_plot_div.count("sk-top-container") == 1: | ||||||||
model_plot_div = model_plot_div.replace( | ||||||||
"sk-top-container", 'sk-top-container" style="overflow: auto;' | ||||||||
|
@@ -497,7 +552,7 @@ def _extract_estimator_config(self) -> str: | |||||||
str: | ||||||||
Markdown table of hyperparameters. | ||||||||
""" | ||||||||
hyperparameter_dict = self.model.get_params(deep=True) | ||||||||
hyperparameter_dict = self.get_model().get_params(deep=True) | ||||||||
return _clean_table( | ||||||||
tabulate( | ||||||||
list(hyperparameter_dict.items()), | ||||||||
|
@@ -520,7 +575,7 @@ def __repr__(self) -> str: | |||||||
# create repr for model | ||||||||
model = getattr(self, "model", None) | ||||||||
if model: | ||||||||
model_str = self._strip_blank(repr(model)) | ||||||||
model_str = self._strip_blank(repr(self.get_model())) | ||||||||
model_repr = aRepr.repr(f" model={model_str},").strip('"').strip("'") | ||||||||
else: | ||||||||
model_repr = None | ||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -15,7 +15,7 @@ | |||||||||
import skops | ||||||||||
from skops import hub_utils | ||||||||||
from skops.card import Card, metadata_from_config | ||||||||||
from skops.card._model_card import PlotSection, TableSection | ||||||||||
from skops.card._model_card import PlotSection, TableSection, _load_model | ||||||||||
from skops.io import dump | ||||||||||
|
||||||||||
|
||||||||||
|
@@ -26,6 +26,29 @@ def fit_model(): | |||||||||
return reg | ||||||||||
|
||||||||||
|
||||||||||
def save_model_to_file(model_instance, suffix): | ||||||||||
p-mishra1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
save_file_handle, save_file = tempfile.mkstemp(suffix=suffix, prefix="skops-test") | ||||||||||
if suffix in (".pkl", ".pickle"): | ||||||||||
with open(save_file, "wb") as f: | ||||||||||
pickle.dump(model_instance, f) | ||||||||||
elif suffix == ".skops": | ||||||||||
dump(model_instance, save_file) | ||||||||||
return save_file_handle, save_file | ||||||||||
|
||||||||||
|
||||||||||
@pytest.mark.parametrize("suffix", [".pkl", ".pickle", ".skops"]) | ||||||||||
def test_load_model(suffix): | ||||||||||
model0 = LinearRegression(n_jobs=123) | ||||||||||
_, save_file = save_model_to_file(model0, suffix) | ||||||||||
loaded_model_str = _load_model(save_file) | ||||||||||
save_file_path = Path(save_file) | ||||||||||
loaded_model_path = _load_model(save_file_path) | ||||||||||
BenjaminBossan marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
|
||||||||||
assert loaded_model_str.n_jobs == model0.n_jobs | ||||||||||
assert loaded_model_path.n_jobs == model0.n_jobs | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
This test is a tiny bit more robust this way (in case that |
||||||||||
assert loaded_model_path.n_jobs == loaded_model_str.n_jobs | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test is not needed, right? If the two asserts above succeed, this cannot possibly fail. |
||||||||||
|
||||||||||
|
||||||||||
@pytest.fixture | ||||||||||
def model_card(model_diagram=True): | ||||||||||
model = fit_model() | ||||||||||
|
@@ -405,6 +428,85 @@ def test_with_metadata(self, card: Card, meth): | |||||||||
assert result == expected | ||||||||||
|
||||||||||
|
||||||||||
class TestCardModelAttribute: | ||||||||||
@pytest.fixture | ||||||||||
def card(self): | ||||||||||
model = LinearRegression(fit_intercept=False) | ||||||||||
card = Card(model=model) | ||||||||||
card.add( | ||||||||||
model_description="A description", | ||||||||||
model_card_authors="Jane Doe", | ||||||||||
) | ||||||||||
card.add_plot( | ||||||||||
roc_curve="ROC_curve.png", | ||||||||||
confusion_matrix="confusion_matrix.jpg", | ||||||||||
) | ||||||||||
card.add_table(search_results={"split": [1, 2, 3], "score": [4, 5, 6]}) | ||||||||||
return card | ||||||||||
|
||||||||||
def path_to_card(self, path): | ||||||||||
card = Card(model=path) | ||||||||||
card.add( | ||||||||||
model_description="A description", | ||||||||||
model_card_authors="Jane Doe", | ||||||||||
) | ||||||||||
card.add_plot( | ||||||||||
roc_curve="ROC_curve.png", | ||||||||||
confusion_matrix="confusion_matrix.jpg", | ||||||||||
) | ||||||||||
card.add_table(search_results={"split": [1, 2, 3], "score": [4, 5, 6]}) | ||||||||||
return card | ||||||||||
|
||||||||||
@pytest.mark.parametrize("meth", [repr, str]) | ||||||||||
@pytest.mark.parametrize("suffix", [".pkl", ".skops"]) | ||||||||||
def test_model_card_repr(self, card: Card, meth, suffix): | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure if this test is strictly necessary. Maybe it is sufficient to show that calling There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I changed it to only repr from card_from _path. |
||||||||||
result_from_model = meth(card) | ||||||||||
|
||||||||||
file_handle, file_name = save_model_to_file(card.model, suffix) | ||||||||||
|
||||||||||
os.close(file_handle) | ||||||||||
|
||||||||||
card_from_path = self.path_to_card(file_name) | ||||||||||
result_from_path = meth(card_from_path) | ||||||||||
print(result_from_path) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove |
||||||||||
expected = ( | ||||||||||
"Card(\n" | ||||||||||
" model=LinearRegression(fit_intercept=False),\n" | ||||||||||
" model_description='A description',\n" | ||||||||||
" model_card_authors='Jane Doe',\n" | ||||||||||
" roc_curve='ROC_curve.png',\n" | ||||||||||
" confusion_matrix='confusion_matrix.jpg',\n" | ||||||||||
" search_results=Table(3x2),\n" | ||||||||||
")" | ||||||||||
) | ||||||||||
assert result_from_model == expected | ||||||||||
assert result_from_path == expected | ||||||||||
|
||||||||||
@pytest.mark.parametrize("suffix", [".pkl", ".skops"]) | ||||||||||
@pytest.mark.parametrize("meth", [repr, str]) | ||||||||||
def test_load_model_exception(self, meth, suffix): | ||||||||||
file_handle, file_name = tempfile.mkstemp(suffix=suffix, prefix="skops-test") | ||||||||||
|
||||||||||
os.close(file_handle) | ||||||||||
|
||||||||||
with pytest.raises(Exception, match="occured during model loading."): | ||||||||||
card = Card(file_name) | ||||||||||
meth(card) | ||||||||||
|
||||||||||
@pytest.mark.parametrize("meth", [repr, str]) | ||||||||||
def test_load_model_file_not_found(self, meth): | ||||||||||
file_handle, file_name = tempfile.mkstemp(suffix=".pkl", prefix="skops-test") | ||||||||||
|
||||||||||
os.close(file_handle) | ||||||||||
os.remove(file_name) | ||||||||||
|
||||||||||
with pytest.raises(FileNotFoundError) as excinfo: | ||||||||||
card = Card(file_name) | ||||||||||
meth(card) | ||||||||||
|
||||||||||
assert file_name in str(excinfo.value) | ||||||||||
|
||||||||||
|
||||||||||
class TestPlotSection: | ||||||||||
def test_format_path_is_str(self): | ||||||||||
section = PlotSection(alt_text="some title", path="path/plot.png") | ||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.