diff --git a/src/skore/project.py b/src/skore/project.py index 5ee9afd31..2fa51bdb8 100644 --- a/src/skore/project.py +++ b/src/skore/project.py @@ -9,45 +9,46 @@ class Item: """An item is a value that is stored in the project.""" + serialized: str raw: Any = None raw_class_name: str | None = None - transformed: Any = None -def get_fully_qualified_class_name(obj): - """ - Get the fully qualified class name of an object. - - This function returns the fully qualified class name of the given object, - including the module name if available. - - Parameters - ---------- - obj : object - Any Python object. - - Returns - ------- - str - The fully qualified class name as a string. - """ - module = obj.__class__.__module__ - if module is None or module == str.__class__.__module__: - return obj.__class__.__name__ - return module + "." + obj.__class__.__name__ - def transform(o: Any) -> Item: """Transform an object into an item.""" try: serialized = json.dumps(o) - except AttributeError: - value_class_name = None - match value_class_name: - case "DataFrame": - return Item(raw=o, raw_class_name=value_class_name, transformed=o.to_dict()) - case _: - return Item(raw=o, raw_class_name=value_class_name) + return Item(raw=o, raw_class_name="primitive", serialized=serialized) + except TypeError: + import numpy + import pandas + import sklearn + import skops.io + + if isinstance(o, pandas.DataFrame): + return Item( + raw=o, + raw_class_name="pandas.DataFrame", + serialized=o.to_json(orient="split"), + ) + if isinstance(o, numpy.ndarray): + return Item( + raw=o, + raw_class_name="numpy.ndarray", + serialized=json.dumps(o.tolist()), + ) + if isinstance(o, sklearn.base.BaseEstimator): + return Item( + raw=o, + raw_class_name="sklearn.base.BaseEstimator", + serialized=json.dumps( + { + "skops": skops.io.dumps(o), + "html": sklearn.utils.estimator_html_repr(o), + } + ), + ) class Project: @@ -64,7 +65,7 @@ def put(self, key: str, value: Any): def put_item(self, key: str, item: Item): """Put an item into the project.""" self.storage[key] = ( - item.transformed + item.serialized ) # FIXME store item class and implement a from_storable function def get(self, key: str) -> Any: diff --git a/tests/unit/test_project.py b/tests/unit/test_project.py index 6b89dbcbf..69d673fe1 100644 --- a/tests/unit/test_project.py +++ b/tests/unit/test_project.py @@ -1,17 +1,21 @@ -from pandas import DataFrame -from pandas.testing import assert_frame_equal -from skore.project import Project +import pandas +from skore.project import Item, transform -def test_json_item(): - project = Project() - d = {"a": 1, "b": 2} - project.put("test", d) - assert project.get("test") == d +def test_transform_primitive(): + o = 3 + actual = transform(o) + expected = Item(raw=3, raw_class_name="primitive", serialized="3") + assert actual == expected -def test_dataframe_item(): - project = Project() - df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) - project.put("pandas", df) - assert_frame_equal(project.get("pandas"), df) +def test_transform_pandas_dataframe(): + o = pandas.DataFrame() + actual = transform(o) + expected = Item( + raw=pandas.DataFrame(), raw_class_name="pandas.DataFrame", serialized="" + ) + assert actual == expected + + # o = 3.3 + # transform(o)