Skip to content

Commit

Permalink
mob next [ci-skip] [ci skip] [skip ci]
Browse files Browse the repository at this point in the history
lastFile:tests/unit/test_project.py
  • Loading branch information
rouk1 committed Sep 6, 2024
1 parent 6fcb488 commit 316150a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 47 deletions.
61 changes: 14 additions & 47 deletions src/skore/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from pathlib import Path
from typing import Any, List

import pandas


def get_filepath_from_project_name(project_name) -> Path:
return None
Expand Down Expand Up @@ -33,61 +35,24 @@ class JSONItem(Item):
"""An item that is a JSON serializable Python object."""

def __init__(self, value):
self.raw = value

# to do cache
@property
def raw(self):
return self.raw
self.__raw = None
self.__transformed = json.dumps(value)

# to do cache
@property
def transformed(self):
return json.dumps(self.raw)

def value(self):
return self.__raw


class DataFrameItem(Item):
"""An item that is a pandas or polars data frame."""

def __init__(self, value):
self.__raw = None
self.__transformed = value.to_json(orient="split")

@property
def raw(self):...
# return self.__raw

@property
def transformed(self):
return self.__transformed

def from_json(string): ...

def to_json(string): ...


class NumpyArrayItem(Item):
@property
def raw(self): ...

@property
def transformed(self): ...

def to_json(self): ...

def from_json(self): ...


class MediaItem(Item):
"""An item that is a media object."""


class ScikitLearnModelItem(Item):
"""An item that is a Scikit-Learn model."""

def __init__(self, value):
self.value = value

def value(self):
return pandas.read_json(self.__transformed, orient="split")

class Transformer:
@staticmethod
Expand Down Expand Up @@ -122,8 +87,8 @@ def transform(value) -> Item:
return item_class_map[value_class](value)

# For scikit-learn models, we need to check if it's a BaseEstimator
if hasattr(value, "fit") and hasattr(value, "predict"):
return ScikitLearnModelItem(value)
# if hasattr(value, "fit") and hasattr(value, "predict"):
# return ScikitLearnModelItem(value)

# If we don't have a specific Item subclass for this type,
# we'll use the base Item class
Expand All @@ -139,9 +104,11 @@ def __init__(self, transformer, storage):

def put(self, key, value):
"""Put a value into the project."""
i = self.transformer.transform(value)
self.storage.save(
key,
self.transformer.transform(value),
i.value,
i.metadata,
)

def put_item(self, key, item):
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/test_project.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from pandas import DataFrame
from skore.project import Project


def test_json_item():
project = Project()
d = {"a": 1, "b": 2}
project.put("test", d)
assert project.get("test") == d


def test_dataframe_item():
project = Project()
df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
project.put("pandas", df)
assert project.get("pandas") == df

0 comments on commit 316150a

Please sign in to comment.