Skip to content

Commit

Permalink
mob next [ci-skip] [ci skip] [skip ci]
Browse files Browse the repository at this point in the history
lastFile:tests/unit/test_project.py
  • Loading branch information
rouk1 committed Sep 9, 2024
1 parent 373377b commit 29ea203
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions src/skore/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ def transform(o: Any) -> Item | None:
serialized=json.dumps(o.tolist()),
)
if isinstance(o, sklearn.base.BaseEstimator):
sk_dump = skops.io.dumps(o)
serialized_model = base64.b64encode(sk_dump).decode("ascii")
html_representation = sklearn.utils.estimator_html_repr(o)
return Item(
raw=o,
raw_class_name="sklearn.base.BaseEstimator",
serialized=json.dumps(
{
"skops": base64.b64encode(skops.io.dumps(o)).decode("ascii"),
"html": sklearn.utils.estimator_html_repr(o),
}
{"skops": serialized_model, "html": html_representation}
),
)
return None
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ def test_transform_numpy_ndarray():

def test_transform_sklearn_base_baseestimator(monkeypatch):
monkeypatch.setattr("sklearn.utils.estimator_html_repr", lambda _: "")
monkeypatch.setattr("skops.io.dumps", lambda _: "")
monkeypatch.setattr("skops.io.dumps", lambda _: b"")

o = sklearn.svm.SVC()
actual = transform(o)
expected = Item(
raw=o,
raw_class_name="numpy.ndarray",
raw_class_name="sklearn.base.BaseEstimator",
serialized=json.dumps(
{
"skops": "",
Expand Down

0 comments on commit 29ea203

Please sign in to comment.