Skip to content

Commit

Permalink
mob next [ci-skip] [ci skip] [skip ci]
Browse files Browse the repository at this point in the history
lastFile:tests/unit/test_project.py
  • Loading branch information
rouk1 committed Sep 9, 2024
1 parent de4fdad commit 4f2b7f0
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 40 deletions.
59 changes: 22 additions & 37 deletions src/skore/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,61 +3,40 @@
import base64
import json
from dataclasses import dataclass
from io import StringIO
from typing import Any, List


@dataclass(frozen=True)
class JSONItem:
class Item:
"""An item is a value that is stored in the project."""

serialized: str
raw: Any
raw_class_name: str
media_type: str | None = None


@dataclass(frozen=True)
class DataFrameItem:
serialized: str
raw: Any
raw_class_name: str


@dataclass(frozen=True)
class NumpyArrayItem:
serialized: str
raw: Any
raw_class_name: str


@dataclass(frozen=True)
class SklearnModelItem:
serialized: str
raw: Any
raw_class_name: str


Item = JSONItem | DataFrameItem | NumpyArrayItem | SklearnModelItem


def serialize(o: Any) -> Item | None:
def serialize(o: Any) -> Item:
"""Transform an object into an item."""
try:
serialized = json.dumps(o)
return JSONItem(raw=o, raw_class_name="jsonable", serialized=serialized)
return Item(raw=o, raw_class_name="jsonable", serialized=serialized)
except TypeError:
import matplotlib.figure
import numpy
import pandas
import sklearn
import skops.io

if isinstance(o, pandas.DataFrame):
return DataFrameItem(
return Item(
raw=o,
raw_class_name="pandas.DataFrame",
serialized=o.to_json(orient="split"),
)
if isinstance(o, numpy.ndarray):
return NumpyArrayItem(
return Item(
raw=o,
raw_class_name="numpy.ndarray",
serialized=json.dumps(o.tolist()),
Expand All @@ -66,14 +45,23 @@ def serialize(o: Any) -> Item | None:
sk_dump = skops.io.dumps(o)
serialized_model = base64.b64encode(sk_dump).decode("ascii")
html_representation = sklearn.utils.estimator_html_repr(o)
return SklearnModelItem(
return Item(
raw=o,
raw_class_name="sklearn.base.BaseEstimator",
serialized=json.dumps(
{"skops": serialized_model, "html": html_representation}
),
)
return None
if isinstance(o, matplotlib.figure.Figure):
output = StringIO()
o.savefig(output, format="svg")
return Item(
raw=o,
raw_class_name="matplotlib.figure.Figure",
serialized=output.getvalue(),
)

raise NotImplementedError(f"Type {o.__class__.__name__} is not supported yet.")


def deserialize(serialized: str, raw_class_name: str) -> Item:
Expand All @@ -83,26 +71,22 @@ def deserialize(serialized: str, raw_class_name: str) -> Item:
match raw_class_name:
case "jsonable":
raw = json.loads(serialized)
cls = JSONItem
case "pandas.DataFrame":
import pandas

raw = pandas.read_json(serialized, orient="split")
cls = DataFrameItem
case "numpy.ndarray":
import numpy

raw = numpy.array(json.loads(serialized))
cls = NumpyArrayItem
case "sklearn.base.BaseEstimator":
import skops.io

o = json.loads(serialized)
unserialized = base64.b64decode(o["skops"])
raw = skops.io.loads(unserialized)
cls = SklearnModelItem

return cls(raw=raw, raw_class_name=raw_class_name, serialized=serialized)
return Item(raw=raw, raw_class_name=raw_class_name, serialized=serialized)


class Project:
Expand All @@ -112,7 +96,8 @@ def __init__(self):
self.storage = {}

def put(self, key: str, value: Any):
self.put_item(key, serialize(value))
i = serialize(value)
self.put_item(key, i)

def put_item(self, key: str, item: Item):
self.storage[key] = (item.raw_class_name, item.serialized)
Expand Down
13 changes: 10 additions & 3 deletions tests/unit/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import pandas
import pandas.testing
import sklearn.svm
from skore.project import Item, serialize, deserialize, SklearnModelItem
from matplotlib import pyplot as plt
from skore.project import Item, deserialize, serialize


def test_transform_primitive():
o = 3
actual = serialize(o)
expected = Item(raw=3, raw_class_name="primitive", serialized="3")
expected = Item(raw=3, raw_class_name="jsonable", serialized="3")
assert actual == expected


Expand Down Expand Up @@ -45,7 +46,7 @@ def test_transform_sklearn_base_baseestimator(monkeypatch):

o = sklearn.svm.SVC()
actual = serialize(o)
expected = SklearnModelItem(
expected = Item(
raw=o,
raw_class_name="sklearn.base.BaseEstimator",
serialized=json.dumps(
Expand All @@ -59,6 +60,12 @@ def test_transform_sklearn_base_baseestimator(monkeypatch):
assert actual == expected


def test_transform_matplotlib():
fig, ax = plt.subplots()
ax.plot([1, 2, 3, 4])
s = serialize(fig)


def test_untransform_primitive():
o = 3
transformed = serialize(o)
Expand Down

0 comments on commit 4f2b7f0

Please sign in to comment.