Skip to content

Commit

Permalink
mob next [ci-skip] [ci skip] [skip ci]
Browse files Browse the repository at this point in the history
lastFile:tests/unit/test_project.py
  • Loading branch information
auguste-probabl committed Sep 9, 2024
1 parent 55c761b commit de4fdad
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 25 deletions.
54 changes: 41 additions & 13 deletions src/skore/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,57 @@


@dataclass(frozen=True)
class Item:
class JSONItem:
"""An item is a value that is stored in the project."""

serialized: str
raw: Any = None
raw_class_name: str | None = None
raw: Any
raw_class_name: str


def transform(o: Any) -> Item | None:
@dataclass(frozen=True)
class DataFrameItem:
serialized: str
raw: Any
raw_class_name: str


@dataclass(frozen=True)
class NumpyArrayItem:
serialized: str
raw: Any
raw_class_name: str


@dataclass(frozen=True)
class SklearnModelItem:
serialized: str
raw: Any
raw_class_name: str


Item = JSONItem | DataFrameItem | NumpyArrayItem | SklearnModelItem


def serialize(o: Any) -> Item | None:
"""Transform an object into an item."""
try:
serialized = json.dumps(o)
return Item(raw=o, raw_class_name="primitive", serialized=serialized)
return JSONItem(raw=o, raw_class_name="jsonable", serialized=serialized)
except TypeError:
import numpy
import pandas
import sklearn
import skops.io

if isinstance(o, pandas.DataFrame):
return Item(
return DataFrameItem(
raw=o,
raw_class_name="pandas.DataFrame",
serialized=o.to_json(orient="split"),
)
if isinstance(o, numpy.ndarray):
return Item(
return NumpyArrayItem(
raw=o,
raw_class_name="numpy.ndarray",
serialized=json.dumps(o.tolist()),
Expand All @@ -42,7 +66,7 @@ def transform(o: Any) -> Item | None:
sk_dump = skops.io.dumps(o)
serialized_model = base64.b64encode(sk_dump).decode("ascii")
html_representation = sklearn.utils.estimator_html_repr(o)
return Item(
return SklearnModelItem(
raw=o,
raw_class_name="sklearn.base.BaseEstimator",
serialized=json.dumps(
Expand All @@ -52,29 +76,33 @@ def transform(o: Any) -> Item | None:
return None


def untransform(serialized: str, raw_class_name: str) -> Item:
def deserialize(serialized: str, raw_class_name: str) -> Item:
"""Transform a serialized Item back to an object based on the given class name."""
raw = None

match raw_class_name:
case "primitive":
case "jsonable":
raw = json.loads(serialized)
cls = JSONItem
case "pandas.DataFrame":
import pandas

raw = pandas.read_json(serialized, orient="split")
cls = DataFrameItem
case "numpy.ndarray":
import numpy

raw = numpy.array(json.loads(serialized))
cls = NumpyArrayItem
case "sklearn.base.BaseEstimator":
import skops.io

o = json.loads(serialized)
unserialized = base64.b64decode(o["skops"])
raw = skops.io.loads(unserialized)
cls = SklearnModelItem

return Item(raw=raw, raw_class_name=raw_class_name, serialized=serialized)
return cls(raw=raw, raw_class_name=raw_class_name, serialized=serialized)


class Project:
Expand All @@ -84,7 +112,7 @@ def __init__(self):
self.storage = {}

def put(self, key: str, value: Any):
self.put_item(key, transform(value))
self.put_item(key, serialize(value))

def put_item(self, key: str, item: Item):
self.storage[key] = (item.raw_class_name, item.serialized)
Expand All @@ -93,7 +121,7 @@ def get(self, key: str) -> Any:
return self.get_item(key).raw

def get_item(self, key: str) -> Item:
return untransform(*self.storage[key])
return deserialize(*self.storage[key])

def list_keys(self) -> List[str]:
"""List all keys in the project."""
Expand Down
24 changes: 12 additions & 12 deletions tests/unit/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,19 @@
import pandas
import pandas.testing
import sklearn.svm
from skore.project import Item, transform, untransform
from skore.project import Item, serialize, deserialize, SklearnModelItem


def test_transform_primitive():
o = 3
actual = transform(o)
actual = serialize(o)
expected = Item(raw=3, raw_class_name="primitive", serialized="3")
assert actual == expected


def test_transform_pandas_dataframe():
o = pandas.DataFrame([{"key": "value"}])
actual = transform(o)
actual = serialize(o)
expected = Item(
raw=o,
raw_class_name="pandas.DataFrame",
Expand All @@ -29,7 +29,7 @@ def test_transform_pandas_dataframe():

def test_transform_numpy_ndarray():
o = numpy.array([1, 2, 3])
actual = transform(o)
actual = serialize(o)
expected = Item(
raw=o,
raw_class_name="numpy.ndarray",
Expand All @@ -44,8 +44,8 @@ def test_transform_sklearn_base_baseestimator(monkeypatch):
monkeypatch.setattr("skops.io.dumps", lambda _: b"")

o = sklearn.svm.SVC()
actual = transform(o)
expected = Item(
actual = serialize(o)
expected = SklearnModelItem(
raw=o,
raw_class_name="sklearn.base.BaseEstimator",
serialized=json.dumps(
Expand All @@ -61,9 +61,9 @@ def test_transform_sklearn_base_baseestimator(monkeypatch):

def test_untransform_primitive():
o = 3
transformed = transform(o)
transformed = serialize(o)
assert (
untransform(
deserialize(
serialized=transformed.serialized, raw_class_name=transformed.raw_class_name
).raw
== o
Expand All @@ -72,7 +72,7 @@ def test_untransform_primitive():

def test_untransform_pandas_dataframe():
o = pandas.DataFrame([{"key": "value"}])
item = untransform(
item = deserialize(
serialized=o.to_json(orient="split"),
raw_class_name="pandas.DataFrame",
)
Expand All @@ -82,7 +82,7 @@ def test_untransform_pandas_dataframe():

def test_untransform_numpy_ndarray():
o = numpy.array([1, 2, 3])
item = untransform(
item = deserialize(
serialized=json.dumps(o.tolist()),
raw_class_name="numpy.ndarray",
)
Expand All @@ -92,7 +92,7 @@ def test_untransform_numpy_ndarray():

def test_untransform_sklearn_model():
o = sklearn.svm.SVC()
t = transform(o)
u = untransform(t.serialized, t.raw_class_name)
t = serialize(o)
u = deserialize(t.serialized, t.raw_class_name)

assert isinstance(u.raw, sklearn.svm.SVC)

0 comments on commit de4fdad

Please sign in to comment.