Skip to content

Commit

Permalink
mob next [ci-skip] [ci skip] [skip ci]
Browse files Browse the repository at this point in the history
lastFile:tests/unit/test_project.py
  • Loading branch information
thomass-dev committed Sep 11, 2024
1 parent 4fb5052 commit 33f3ef9
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
9 changes: 2 additions & 7 deletions src/skore/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ class ItemType(StrEnum):
PANDAS_DATAFRAME = auto()
NUMPY_ARRAY = auto()
SKLEARN_BASE_ESTIMATOR = auto()
ALTAIR_CHART = auto()
MEDIA = auto()


Expand Down Expand Up @@ -74,9 +73,9 @@ def serialize(o: Any) -> Item:
if isinstance(o, altair.vegalite.v5.schema.core.TopLevelSpec):
return Item(
raw=o,
item_type=ItemType.ALTAIR_CHART,
item_type=ItemType.MEDIA,
serialized=json.dumps(o.to_dict()),
media_type=None,
media_type="application/vnd.vega.v5+json",
)
if isinstance(o, matplotlib.figure.Figure):
with StringIO() as output:
Expand Down Expand Up @@ -125,10 +124,6 @@ def deserialize(
o = json.loads(serialized)
unserialized = base64.b64decode(o["skops"])
raw = skops.io.loads(unserialized)
case ItemType.ALTAIR_CHART:
import altair

raw = altair.Chart.from_dict(json.loads(serialized))
case _:
raw = None

Expand Down
12 changes: 8 additions & 4 deletions tests/unit/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pandas
import pandas.testing
import sklearn.svm
import skops
from matplotlib import pyplot as plt
from PIL import Image
from sklearn.ensemble import RandomForestClassifier
Expand Down Expand Up @@ -152,7 +153,6 @@ def savefig(*args, **kwargs):

monkeypatch.setattr("matplotlib.figure.Figure.savefig", savefig)
monkeypatch.setattr("sklearn.utils.estimator_html_repr", lambda _: "")
# monkeypatch.setattr("skops.io.dumps", lambda _: b"")

project = Project(NonPersistentStorage())
project.put("string_item", "Hello, World!") # JSONItem
Expand Down Expand Up @@ -205,8 +205,12 @@ def savefig(*args, **kwargs):
# Add a scikit-learn model
model = RandomForestClassifier()
model.fit(numpy.array([[1, 2], [3, 4]]), [0, 1])
model_dump = skops.io.dumps(model)
serialized_model = base64.b64encode(model_dump).decode("ascii")
project.put("rf_model", model) # ScikitLearnModelItem

breakpoint()

assert project.storage.content == {
"string_item": {
"item_type": str(ItemType.JSON),
Expand Down Expand Up @@ -255,13 +259,13 @@ def savefig(*args, **kwargs):
},
"rf_model": {
"item_type": str(ItemType.SKLEARN_BASE_ESTIMATOR),
"serialized": '{"skops": "", "html": ""}',
"serialized": f"{'skops': '{serialized_model}', 'html': ''}",
"media_type": "text/html",
},
"vega_chart": {
"item_type": str(ItemType.ALTAIR_CHART),
"item_type": str(ItemType.MEDIA),
"serialized": json.dumps(altair_chart.to_dict()),
"media_type": None,
"media_type": "application/vnd.vega.v5+json",
},
"pil_image": {
"item_type": str(ItemType.MEDIA),
Expand Down

0 comments on commit 33f3ef9

Please sign in to comment.