Skip to content

Commit

Permalink
mob next [ci-skip] [ci skip] [skip ci]
Browse files Browse the repository at this point in the history
lastFile:tests/unit/test_project.py
  • Loading branch information
augustebaum committed Sep 10, 2024
1 parent 5327ec7 commit 3285335
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 45 deletions.
20 changes: 8 additions & 12 deletions src/skore/project.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
""""""


import base64
import json
from dataclasses import dataclass
Expand All @@ -24,6 +23,7 @@ class ItemType(StrEnum):
@dataclass(frozen=True)
class Item:
"""An item is a value that is stored in the project."""

serialized: str
raw: Any
item_type: ItemType
Expand Down Expand Up @@ -126,24 +126,20 @@ def put(self, key: str, value: Any):
def put_item(self, key: str, item: Item):
self.storage.setitem(
key,
(
str(item.item_type),
item.media_type,
item.serialized,
),
{
"item_type": str(item.item_type),
"media_type": item.media_type,
"serialized": item.serialized,
},
)

def get(self, key: str) -> Any:
return self.get_item(key).raw

def get_item(self, key: str) -> Item:
item_type, media_type, serialized = self.storage.getitem(key)
item = self.storage.getitem(key)

return deserialize(
item_type,
media_type,
serialized,
)
return deserialize(**item)

def list_keys(self) -> List[str]:
"""List all keys in the project."""
Expand Down
99 changes: 66 additions & 33 deletions tests/unit/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@
from matplotlib import pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from skore.project import Item, ItemType, Project, deserialize, serialize
from skore.storage.filesystem import FileSystem
from skore.storage.non_persistent_storage import NonPersistentStorage


def test_transform_primitive():
o = 3
actual = serialize(o)
expected = Item(raw=3, raw_class_name="jsonable", serialized="3")
expected = Item(raw=3, item_type=ItemType.JSON, serialized="3")
assert actual == expected


Expand All @@ -24,7 +23,7 @@ def test_transform_pandas_dataframe():
actual = serialize(o)
expected = Item(
raw=o,
raw_class_name="pandas.DataFrame",
item_type=ItemType.PANDAS_DATAFRAME,
serialized=o.to_json(orient="split"),
)

Expand All @@ -36,7 +35,7 @@ def test_transform_numpy_ndarray():
actual = serialize(o)
expected = Item(
raw=o,
raw_class_name="numpy.ndarray",
item_type=ItemType.NUMPY_ARRAY,
serialized=json.dumps(o.tolist()),
)

Expand All @@ -51,7 +50,7 @@ def test_transform_sklearn_base_baseestimator(monkeypatch):
actual = serialize(o)
expected = Item(
raw=o,
raw_class_name="sklearn.base.BaseEstimator",
item_type=ItemType.SKLEARN_BASE_ESTIMATOR,
serialized=json.dumps(
{
"skops": "",
Expand Down Expand Up @@ -84,7 +83,9 @@ def test_untransform_primitive():
transformed = serialize(o)
assert (
deserialize(
serialized=transformed.serialized, raw_class_name=transformed.raw_class_name
serialized=transformed.serialized,
item_type=transformed.item_type,
media_type=None,
).raw
== o
)
Expand All @@ -94,7 +95,8 @@ def test_untransform_pandas_dataframe():
o = pandas.DataFrame([{"key": "value"}])
item = deserialize(
serialized=o.to_json(orient="split"),
raw_class_name="pandas.DataFrame",
item_type=ItemType.PANDAS_DATAFRAME,
media_type=None,
)

pandas.testing.assert_frame_equal(item.raw, o)
Expand All @@ -104,7 +106,8 @@ def test_untransform_numpy_ndarray():
o = numpy.array([1, 2, 3])
item = deserialize(
serialized=json.dumps(o.tolist()),
raw_class_name="numpy.ndarray",
item_type=ItemType.NUMPY_ARRAY,
media_type=None,
)

numpy.testing.assert_array_equal(o, item.raw)
Expand All @@ -113,7 +116,7 @@ def test_untransform_numpy_ndarray():
def test_untransform_sklearn_model():
o = sklearn.svm.SVC()
t = serialize(o)
u = deserialize(t.serialized, t.raw_class_name)
u = deserialize(serialized=t.serialized, item_type=t.item_type, media_type=None)

assert isinstance(u.raw, sklearn.svm.SVC)

Expand All @@ -129,9 +132,11 @@ def savefig(*args, **kwargs):

s = serialize(fig)

assert deserialize(s.serialized, s.raw_class_name, s.media_type) == Item(
assert deserialize(
serialized=s.serialized, item_type=s.item_type, media_type=s.media_type
) == Item(
raw=None,
raw_class_name="matplotlib.figure.Figure",
item_type=ItemType.MEDIA,
serialized="",
media_type="image/svg+xml",
)
Expand Down Expand Up @@ -180,28 +185,56 @@ def savefig(*args, **kwargs):
project.put("rf_model", model) # ScikitLearnModelItem

assert project.storage.content == {
"string_item": ("jsonable", '"Hello, World!"', None),
"int_item": ("jsonable", "42", None),
"float_item": ("jsonable", "3.14", None),
"bool_item": ("jsonable", "true", None),
"list_item": ("jsonable", "[1, 2, 3]", None),
"dict_item": ("jsonable", '{"key": "value"}', None),
"pandas_df": (
"pandas.DataFrame",
'{"columns":["A","B"],"index":[0,1,2],"data":[[1,4],[2,5],[3,6]]}',
None,
),
"numpy_array": ("numpy.ndarray", "[1, 2, 3, 4, 5]", None),
"mpl_figure": (
"matplotlib.figure.Figure",
"",
"image/svg+xml",
),
"rf_model": (
"sklearn.base.BaseEstimator",
'{"skops": "", "html": ""}',
"text/html",
),
"string_item": {
"item_type": ItemType.JSON,
"serialized": '"Hello, World!"',
"media_type": None,
},
"int_item": {
"item_type": ItemType.JSON,
"serialized": "42",
"media_type": None,
},
"float_item": {
"item_type": ItemType.JSON,
"serialized": "3.14",
"media_type": None,
},
"bool_item": {
"item_type": ItemType.JSON,
"serialized": "true",
"media_type": None,
},
"list_item": {
"item_type": ItemType.JSON,
"serialized": "[1, 2, 3]",
"media_type": None,
},
"dict_item": {
"item_type": ItemType.JSON,
"serialized": '{"key": "value"}',
"media_type": None,
},
"pandas_df": {
"item_type": ItemType.JSON,
"serialized": '{"columns":["A","B"],"index":[0,1,2],"data":[[1,4],[2,5],[3,6]]}',
"media_type": None,
},
"numpy_array": {
"item_type": ItemType.NUMPY_ARRAY,
"serialized": "[1, 2, 3, 4, 5]",
"media_type": None,
},
"mpl_figure": {
"item_type": ItemType.MEDIA,
"serialized": "",
"media_type": "image/svg+xml",
},
"rf_model": {
"item_type": ItemType.SKLEARN_BASE_ESTIMATOR,
"serialized": '{"skops": "", "html": ""}',
"media_type": "text/html",
},
}

assert project.get("string_item") == "Hello, World!"
Expand Down

0 comments on commit 3285335

Please sign in to comment.