From 0c8f7f55e840df30a4036e52256e493bd1e374ce Mon Sep 17 00:00:00 2001 From: Auguste Baum Date: Mon, 9 Sep 2024 14:59:43 +0200 Subject: [PATCH] mob next [ci-skip] [ci skip] [skip ci] lastFile:tests/unit/test_project.py --- pyproject.toml | 1 + requirements-doc.txt | 46 ++++++++++++++++++++++++++++++++++---- requirements-test.txt | 35 +++++++++++++++++++++++++---- requirements-tools.txt | 35 ++++++++++++++++++++++++++--- requirements.txt | 35 ++++++++++++++++++++++++++--- src/skore/project.py | 12 ++++++++-- tests/unit/test_project.py | 13 ++++++++++- 7 files changed, 160 insertions(+), 17 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 004105597..894022985 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "pydantic_numpy", "uvicorn", "rich", + "skops", ] [project.optional-dependencies] diff --git a/requirements-doc.txt b/requirements-doc.txt index b749fa165..ec09e2b26 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.12 +# This file is autogenerated by pip-compile with Python 3.11 # by the following command: # # pip-compile --extra=doc --output-file=requirements-doc.txt pyproject.toml @@ -49,6 +49,7 @@ filelock==3.15.4 # huggingface-hub # torch # transformers + # triton fonttools==4.53.1 # via matplotlib fsspec==2024.6.1 @@ -60,6 +61,7 @@ h11==0.14.0 huggingface-hub==0.24.6 # via # sentence-transformers + # skops # tokenizers # transformers idna==3.7 @@ -122,12 +124,44 @@ numpy==2.1.0 # skrub # statsmodels # transformers +nvidia-cublas-cu12==12.1.3.1 + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.1.105 + # via torch +nvidia-cuda-nvrtc-cu12==12.1.105 + # via torch +nvidia-cuda-runtime-cu12==12.1.105 + # via torch +nvidia-cudnn-cu12==9.1.0.70 + # via torch +nvidia-cufft-cu12==11.0.2.54 + # via torch +nvidia-curand-cu12==10.3.2.106 + # via torch +nvidia-cusolver-cu12==11.4.5.107 + # via torch +nvidia-cusparse-cu12==12.1.0.106 + # via + # nvidia-cusolver-cu12 + # torch +nvidia-nccl-cu12==2.20.5 + # via torch +nvidia-nvjitlink-cu12==12.6.68 + # via + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 +nvidia-nvtx-cu12==12.1.105 + # via torch packaging==24.1 # via # altair # huggingface-hub # matplotlib # pydata-sphinx-theme + # skops # skrub # sphinx # statsmodels @@ -204,6 +238,7 @@ safetensors==0.4.4 scikit-learn==1.5.1 # via # sentence-transformers + # skops # skore (pyproject.toml) # skrub scipy==1.14.1 @@ -222,6 +257,8 @@ six==1.16.0 # via # patsy # python-dateutil +skops==0.10.0 + # via skore (pyproject.toml) skrub==0.3.0 # via skore (pyproject.toml) sniffio==1.3.1 @@ -255,6 +292,8 @@ statsmodels==0.14.2 # via skore (pyproject.toml) sympy==1.13.2 # via torch +tabulate==0.9.0 + # via skops threadpoolctl==3.5.0 # via scikit-learn tokenizers==0.19.1 @@ -269,6 +308,8 @@ tqdm==4.66.5 # transformers transformers==4.44.1 # via sentence-transformers +triton==3.0.0 + # via torch typeguard==4.3.0 # via inflect typing-extensions==4.12.2 @@ -287,6 +328,3 @@ urllib3==2.2.2 # via requests uvicorn==0.30.6 # via skore (pyproject.toml) - -# The following packages are considered to be unsafe in a requirements file: -# setuptools diff --git a/requirements-test.txt b/requirements-test.txt index 0afeb7082..01d9d850d 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.12 +# This file is autogenerated by pip-compile with Python 3.11 # by the following command: # # pip-compile --extra=test --output-file=requirements-test.txt pyproject.toml @@ -22,8 +22,11 @@ certifi==2024.7.4 # via # httpcore # httpx + # requests cfgv==3.4.0 # via pre-commit +charset-normalizer==3.3.2 + # via requests click==8.1.7 # via uvicorn compress-pickle[lz4]==2.1.0 @@ -41,11 +44,15 @@ distlib==0.3.8 fastapi==0.112.1 # via skore (pyproject.toml) filelock==3.15.4 - # via virtualenv + # via + # huggingface-hub + # virtualenv fonttools==4.53.1 # via matplotlib fqdn==1.5.1 # via jsonschema +fsspec==2024.9.0 + # via huggingface-hub h11==0.14.0 # via # httpcore @@ -54,6 +61,8 @@ httpcore==1.0.5 # via httpx httpx==0.27.0 # via skore (pyproject.toml) +huggingface-hub==0.24.6 + # via skops identify==2.6.0 # via pre-commit idna==3.7 @@ -61,6 +70,7 @@ idna==3.7 # anyio # httpx # jsonschema + # requests iniconfig==2.0.0 # via pytest isoduration==20.11.0 @@ -104,8 +114,10 @@ numpy==2.1.0 packaging==24.1 # via # altair + # huggingface-hub # matplotlib # pytest + # skops pandas==2.2.2 # via skore (pyproject.toml) pillow==10.4.0 @@ -148,11 +160,15 @@ python-dateutil==2.9.0.post0 pytz==2024.1 # via pandas pyyaml==6.0.2 - # via pre-commit + # via + # huggingface-hub + # pre-commit referencing==0.35.1 # via # jsonschema # jsonschema-specifications +requests==2.32.3 + # via huggingface-hub rfc3339-validator==0.1.4 # via jsonschema rfc3987==1.3.8 @@ -170,7 +186,9 @@ ruamel-yaml-clib==0.2.8 ruff==0.6.1 # via skore (pyproject.toml) scikit-learn==1.5.1 - # via skore (pyproject.toml) + # via + # skops + # skore (pyproject.toml) scipy==1.14.1 # via scikit-learn semver==3.0.2 @@ -179,26 +197,35 @@ six==1.16.0 # via # python-dateutil # rfc3339-validator +skops==0.10.0 + # via skore (pyproject.toml) sniffio==1.3.1 # via # anyio # httpx starlette==0.38.2 # via fastapi +tabulate==0.9.0 + # via skops threadpoolctl==3.5.0 # via scikit-learn +tqdm==4.66.5 + # via huggingface-hub types-python-dateutil==2.9.0.20240821 # via arrow typing-extensions==4.12.2 # via # altair # fastapi + # huggingface-hub # pydantic # pydantic-core tzdata==2024.1 # via pandas uri-template==1.3.0 # via jsonschema +urllib3==2.2.2 + # via requests uvicorn==0.30.6 # via skore (pyproject.toml) virtualenv==20.26.3 diff --git a/requirements-tools.txt b/requirements-tools.txt index 234610864..d12484aff 100644 --- a/requirements-tools.txt +++ b/requirements-tools.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.12 +# This file is autogenerated by pip-compile with Python 3.11 # by the following command: # # pip-compile --extra=tools --output-file=requirements-tools.txt pyproject.toml @@ -16,6 +16,10 @@ attrs==24.2.0 # referencing build==1.2.1 # via pip-tools +certifi==2024.8.30 + # via requests +charset-normalizer==3.3.2 + # via requests click==8.1.7 # via # pip-tools @@ -30,12 +34,20 @@ diskcache==5.6.3 # via skore (pyproject.toml) fastapi==0.112.1 # via skore (pyproject.toml) +filelock==3.16.0 + # via huggingface-hub fonttools==4.53.1 # via matplotlib +fsspec==2024.9.0 + # via huggingface-hub h11==0.14.0 # via uvicorn +huggingface-hub==0.24.6 + # via skops idna==3.7 - # via anyio + # via + # anyio + # requests jinja2==3.1.4 # via altair joblib==1.4.2 @@ -70,7 +82,9 @@ packaging==24.1 # via # altair # build + # huggingface-hub # matplotlib + # skops pandas==2.2.2 # via skore (pyproject.toml) pillow==10.4.0 @@ -102,10 +116,14 @@ python-dateutil==2.9.0.post0 # pandas pytz==2024.1 # via pandas +pyyaml==6.0.2 + # via huggingface-hub referencing==0.35.1 # via # jsonschema # jsonschema-specifications +requests==2.32.3 + # via huggingface-hub rich==13.7.1 # via skore (pyproject.toml) rpds-py==0.20.0 @@ -117,27 +135,38 @@ ruamel-yaml==0.18.6 ruamel-yaml-clib==0.2.8 # via ruamel-yaml scikit-learn==1.5.1 - # via skore (pyproject.toml) + # via + # skops + # skore (pyproject.toml) scipy==1.14.1 # via scikit-learn semver==3.0.2 # via pydantic-numpy six==1.16.0 # via python-dateutil +skops==0.10.0 + # via skore (pyproject.toml) sniffio==1.3.1 # via anyio starlette==0.38.2 # via fastapi +tabulate==0.9.0 + # via skops threadpoolctl==3.5.0 # via scikit-learn +tqdm==4.66.5 + # via huggingface-hub typing-extensions==4.12.2 # via # altair # fastapi + # huggingface-hub # pydantic # pydantic-core tzdata==2024.1 # via pandas +urllib3==2.2.2 + # via requests uvicorn==0.30.6 # via skore (pyproject.toml) wheel==0.44.0 diff --git a/requirements.txt b/requirements.txt index a468c0cb2..0af2cae0f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.12 +# This file is autogenerated by pip-compile with Python 3.11 # by the following command: # # pip-compile --output-file=requirements.txt pyproject.toml @@ -14,6 +14,10 @@ attrs==24.2.0 # via # jsonschema # referencing +certifi==2024.8.30 + # via requests +charset-normalizer==3.3.2 + # via requests click==8.1.7 # via uvicorn compress-pickle[lz4]==2.1.0 @@ -26,12 +30,20 @@ diskcache==5.6.3 # via skore (pyproject.toml) fastapi==0.112.1 # via skore (pyproject.toml) +filelock==3.16.0 + # via huggingface-hub fonttools==4.53.1 # via matplotlib +fsspec==2024.9.0 + # via huggingface-hub h11==0.14.0 # via uvicorn +huggingface-hub==0.24.6 + # via skops idna==3.7 - # via anyio + # via + # anyio + # requests jinja2==3.1.4 # via altair joblib==1.4.2 @@ -65,7 +77,9 @@ numpy==2.1.0 packaging==24.1 # via # altair + # huggingface-hub # matplotlib + # skops pandas==2.2.2 # via skore (pyproject.toml) pillow==10.4.0 @@ -91,10 +105,14 @@ python-dateutil==2.9.0.post0 # pandas pytz==2024.1 # via pandas +pyyaml==6.0.2 + # via huggingface-hub referencing==0.35.1 # via # jsonschema # jsonschema-specifications +requests==2.32.3 + # via huggingface-hub rich==13.7.1 # via skore (pyproject.toml) rpds-py==0.20.0 @@ -106,26 +124,37 @@ ruamel-yaml==0.18.6 ruamel-yaml-clib==0.2.8 # via ruamel-yaml scikit-learn==1.5.1 - # via skore (pyproject.toml) + # via + # skops + # skore (pyproject.toml) scipy==1.14.1 # via scikit-learn semver==3.0.2 # via pydantic-numpy six==1.16.0 # via python-dateutil +skops==0.10.0 + # via skore (pyproject.toml) sniffio==1.3.1 # via anyio starlette==0.38.2 # via fastapi +tabulate==0.9.0 + # via skops threadpoolctl==3.5.0 # via scikit-learn +tqdm==4.66.5 + # via huggingface-hub typing-extensions==4.12.2 # via # altair # fastapi + # huggingface-hub # pydantic # pydantic-core tzdata==2024.1 # via pandas +urllib3==2.2.2 + # via requests uvicorn==0.30.6 # via skore (pyproject.toml) diff --git a/src/skore/project.py b/src/skore/project.py index 32a50f307..1906f3351 100644 --- a/src/skore/project.py +++ b/src/skore/project.py @@ -52,6 +52,14 @@ def transform(o: Any) -> Item | None: return None +def untransform(serialized: str, raw_class_name: str) -> Any: + """Transform a serialized Item back to an object based on the given class name.""" + match raw_class_name: + case "primitive": + return json.loads(serialized) + # casec + + class Project: """A project is a collection of items that are stored in a storage.""" @@ -60,7 +68,7 @@ def __init__(self): def put(self, key: str, value: Any): """Put a value into the project.""" - i = to_storable(value) + i = value self.put_item(key, i) def put_item(self, key: str, item: Item): @@ -76,7 +84,7 @@ def get(self, key: str) -> Any: def get_item(self, key: str) -> Item: """Get an item from the project.""" - return from_storable(self.storage[key]) + return self.storage[key] def list_keys(self) -> List[str]: """List all keys in the project.""" diff --git a/tests/unit/test_project.py b/tests/unit/test_project.py index 84b815462..894b2e44f 100644 --- a/tests/unit/test_project.py +++ b/tests/unit/test_project.py @@ -3,7 +3,7 @@ import numpy import pandas import sklearn.svm -from skore.project import Item, transform +from skore.project import Item, transform, untransform def test_transform_primitive(): @@ -55,3 +55,14 @@ def test_transform_sklearn_base_baseestimator(monkeypatch): ) assert actual == expected + + +def test_untransform(): + o = 3 + transformed = transform(o) + assert ( + untransform( + serialized=transformed.serialized, raw_class_name=transformed.raw_class_name + ) + == o + )