Skip to content

Commit

Permalink
mob next [ci-skip] [ci skip] [skip ci]
Browse files Browse the repository at this point in the history
lastFile:tests/unit/test_project.py
  • Loading branch information
rouk1 committed Sep 9, 2024
1 parent 5741248 commit 55c761b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
14 changes: 9 additions & 5 deletions src/skore/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,20 @@ def untransform(serialized: str, raw_class_name: str) -> Item:
raw = json.loads(serialized)
case "pandas.DataFrame":
import pandas

raw = pandas.read_json(serialized, orient="split")
case "numpy.ndarray":
import numpy

raw = numpy.array(json.loads(serialized))
case "sklearn.base.BaseEstimator":
import skops.io

o = json.loads(serialized)
unserialized = base64.b64decode(o["skops"])
raw = skops.io.loads(unserialized)

return Item(
raw=raw,
raw_class_name=raw_class_name,
serialized=serialized
)
return Item(raw=raw, raw_class_name=raw_class_name, serialized=serialized)


class Project:
Expand Down
16 changes: 11 additions & 5 deletions tests/unit/test_project.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json

import numpy
import numpy.testing
import pandas
import pandas.testing
import sklearn.svm
Expand Down Expand Up @@ -76,10 +77,7 @@ def test_untransform_pandas_dataframe():
raw_class_name="pandas.DataFrame",
)

pandas.testing.assert_frame_equal(
item.raw,
o
)
pandas.testing.assert_frame_equal(item.raw, o)


def test_untransform_numpy_ndarray():
Expand All @@ -89,4 +87,12 @@ def test_untransform_numpy_ndarray():
raw_class_name="numpy.ndarray",
)

assert o == item.raw
numpy.testing.assert_array_equal(o, item.raw)


def test_untransform_sklearn_model():
o = sklearn.svm.SVC()
t = transform(o)
u = untransform(t.serialized, t.raw_class_name)

assert isinstance(u.raw, sklearn.svm.SVC)

0 comments on commit 55c761b

Please sign in to comment.