Skip to content

Commit

Permalink
mob next [ci-skip] [ci skip] [skip ci]
Browse files Browse the repository at this point in the history
lastFile:src/skore/project.py
  • Loading branch information
rouk1 committed Sep 6, 2024
1 parent 316150a commit 8a139f2
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 75 deletions.
144 changes: 70 additions & 74 deletions src/skore/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
from abc import abstractmethod
from pathlib import Path
from typing import Any, List
from typing import Any, List, Protocol

import pandas

Expand All @@ -18,25 +18,17 @@ class Item(abc.ABC):

@property
@abstractmethod
def raw(): ...

@property
@abstractmethod
def transformed(): ...

@abstractmethod
def to_json() -> str: ...

@abstractmethod
def from_json(string: str): ...
def value(self) -> Any:
"""Get the value of the item."""
pass


class JSONItem(Item):
"""An item that is a JSON serializable Python object."""

def __init__(self, value):
self.__raw = None
self.__transformed = json.dumps(value)
self.__raw = value
self.transformed = json.dumps(value)

@property
def value(self):
Expand All @@ -48,87 +40,91 @@ class DataFrameItem(Item):

def __init__(self, value):
self.__raw = None
self.__transformed = value.to_json(orient="split")
self.transformed = value.to_json(orient="split")

@property
def value(self):
return pandas.read_json(self.__transformed, orient="split")

class Transformer:
    """Turn plain Python values into the ``Item`` subclass that wraps them."""

    @staticmethod
    def transform(value) -> Item:
        """Transform a value into an Item.

        Args:
            value: The object to wrap. Its exact type selects the Item subclass.

        Returns:
            An instance of the Item subclass registered for ``type(value)``.

        Raises:
            TypeError: If no Item subclass is registered for the value's type.
        """
        # Dispatch on the type object itself rather than on its ``__name__``:
        # two unrelated classes may share a name, but never an identity.
        item_class_map = {
            dict: JSONItem,
            list: JSONItem,
            tuple: JSONItem,
            int: JSONItem,
            float: JSONItem,
            str: JSONItem,
            bool: JSONItem,
            # TODO: Add
            # pandas.DataFrame DataFrameItem
            # polars.DataFrame DataFrameItem
            # numpy.ndarray NumpyArrayItem
            # matplotlib.figure.Figure MediaItem (as SVG)
            # PIL.Image.Image MediaItem (as PNG)
            # altair.Chart MediaItem (as Vega-Lite JSON)
            # bytes + media type MediaItem
            # str + media type MediaItem
            # Scikit-learn model ScikitLearnModelItem
            #   (detect estimators, e.g. objects exposing ``fit``/``predict``)
        }

        item_class = item_class_map.get(type(value))
        if item_class is None:
            # ``Item`` is abstract and cannot be instantiated as a fallback,
            # so fail with a message that names the unsupported type.
            raise TypeError(
                f"Cannot transform value of type {type(value).__name__!r} "
                "into an Item"
            )
        return item_class(value)
return pandas.read_json(self.transformed, orient="split")


class Storable(Protocol):
    """Structural type for a stored entry: an item plus its class name."""

    item: Item
    class_name: str


class _StoredItem:
    """Concrete ``Storable``; a ``Protocol`` class cannot be instantiated."""

    def __init__(self, item: Item, class_name: str):
        self.item = item
        self.class_name = class_name


def to_storable(value) -> Storable:
    """Wrap a value in its Item subclass and record that subclass's name.

    Args:
        value: The object to store. Its exact type selects the Item subclass.

    Returns:
        A ``Storable`` whose ``item`` wraps ``value`` and whose ``class_name``
        is the name of the Item subclass that was used.

    Raises:
        TypeError: If no Item subclass is registered for the value's type.
    """
    # Dispatch on the type object itself rather than on its ``__name__``:
    # any class called "DataFrame" (e.g. from another library) would
    # otherwise be routed to the pandas-specific ``DataFrameItem``.
    item_class_map = {
        dict: JSONItem,
        list: JSONItem,
        tuple: JSONItem,
        int: JSONItem,
        float: JSONItem,
        str: JSONItem,
        bool: JSONItem,
        pandas.DataFrame: DataFrameItem,
        # TODO: Add
        # polars.DataFrame DataFrameItem
        # numpy.ndarray NumpyArrayItem
        # matplotlib.figure.Figure MediaItem (as SVG)
        # PIL.Image.Image MediaItem (as PNG)
        # altair.Chart MediaItem (as Vega-Lite JSON)
        # bytes + media type MediaItem
        # str + media type MediaItem
        # Scikit-learn model ScikitLearnModelItem
    }

    item_class = item_class_map.get(type(value))
    if item_class is None:
        # ``Item`` is abstract and cannot be instantiated as a fallback,
        # so fail with a message that names the unsupported type.
        raise TypeError(
            f"Cannot store value of type {type(value).__name__!r}: "
            "no Item class is registered for it"
        )

    # Record the Item subclass's name (not the name of ``str``, which is what
    # ``type(value_class).__name__`` produced for every input).
    return _StoredItem(item=item_class(value), class_name=item_class.__name__)


def from_storable(storable: Storable) -> Item:
    """Rebuild an Item from a Storable (placeholder, not implemented yet).

    NOTE(review): this currently returns ``storable.class_name``, which the
    ``Storable`` annotations declare as a ``str``, not an ``Item`` as the
    signature promises.
    """
    # If the class name is later resolved back to a class, prefer an explicit
    # name-to-class mapping over ``eval``.
    return storable.class_name #FIXME eval ?? : )


class Project:
"""A project is a collection of items that are stored in a storage."""

def __init__(self, transformer, storage):
self.transformer = transformer
self.storage = storage
def __init__(self):
self.storage = {}

def put(self, key, value):
def put(self, key: str, value: Any):
"""Put a value into the project."""
i = self.transformer.transform(value)
self.storage.save(
key,
i.value,
i.metadata,
)
i = to_storable(value)
self.put_item(key, i)

def put_item(self, key, item):
def put_item(self, key: str, item: Item):
"""Put an item into the project."""
self.storage.save(
key,
item,
)
self.storage[key] = (
item.transformed
) # FIXME store item class and implement a from_storable function

def get(self, key) -> Any:
def get(self, key: str) -> Any:
"""Get an item from the project."""
i = self.get_item(key)
return i.value

def get_item(self, key) -> Item:
def get_item(self, key: str) -> Item:
"""Get an item from the project."""
return from_storable(self.storage[key])

def list_keys(self) -> List[str]:
"""List all keys in the project."""
return list(self.storage.keys())

def delete_item(key: str):
def delete_item(self, key: str):
"""Delete an item from the project."""
del self.storage[key]


def load(project_name: str) -> Project:
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_project.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pandas import DataFrame
from pandas.testing import assert_frame_equal
from skore.project import Project


Expand All @@ -13,4 +14,4 @@ def test_dataframe_item():
project = Project()
df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
project.put("pandas", df)
assert project.get("pandas") == df
assert_frame_equal(project.get("pandas"), df)

0 comments on commit 8a139f2

Please sign in to comment.