Skip to content

Commit

Permalink
fix(SklearnBaseEstimatorItem): Temporarily disable skops's security f…
Browse files Browse the repository at this point in the history
  • Loading branch information
thomass-dev authored Oct 2, 2024
1 parent 188b502 commit 4ca60a6
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 17 deletions.
29 changes: 21 additions & 8 deletions src/skore/item/sklearn_base_estimator_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ class SklearnBaseEstimatorItem(Item):

def __init__(
self,
estimator_skops,
estimator_html_repr,
estimator_html_repr: str,
estimator_skops: bytes,
estimator_skops_untrusted_types: list[str],
created_at: str | None = None,
updated_at: str | None = None,
):
Expand All @@ -35,19 +36,22 @@ def __init__(
Parameters
----------
estimator_skops : Any
The skops representation of the scikit-learn estimator.
estimator_html_repr : str
The HTML representation of the scikit-learn estimator.
estimator_skops : bytes
The skops representation of the scikit-learn estimator.
estimator_skops_untrusted_types : list[str]
The list of untrusted types in the skops representation.
created_at : str, optional
The creation timestamp in ISO format.
updated_at : str, optional
The last update timestamp in ISO format.
"""
super().__init__(created_at, updated_at)

self.estimator_skops = estimator_skops
self.estimator_html_repr = estimator_html_repr
self.estimator_skops = estimator_skops
self.estimator_skops_untrusted_types = estimator_skops_untrusted_types

@cached_property
def estimator(self) -> sklearn.base.BaseEstimator:
Expand All @@ -61,7 +65,9 @@ def estimator(self) -> sklearn.base.BaseEstimator:
"""
import skops.io

return skops.io.loads(self.estimator_skops)
return skops.io.loads(
self.estimator_skops, trusted=self.estimator_skops_untrusted_types
)

@classmethod
def factory(cls, estimator: sklearn.base.BaseEstimator) -> SklearnBaseEstimatorItem:
Expand All @@ -85,9 +91,16 @@ def factory(cls, estimator: sklearn.base.BaseEstimator) -> SklearnBaseEstimatorI
if not isinstance(estimator, sklearn.base.BaseEstimator):
raise TypeError(f"Type '{estimator.__class__}' is not supported.")

estimator_html_repr = sklearn.utils.estimator_html_repr(estimator)
estimator_skops = skops.io.dumps(estimator)
estimator_skops_untrusted_types = skops.io.get_untrusted_types(
data=estimator_skops
)

instance = cls(
estimator_skops=skops.io.dumps(estimator),
estimator_html_repr=sklearn.utils.estimator_html_repr(estimator),
estimator_html_repr=estimator_html_repr,
estimator_skops=estimator_skops,
estimator_skops_untrusted_types=estimator_skops_untrusted_types,
)

# add estimator as cached property
Expand Down
61 changes: 52 additions & 9 deletions tests/unit/item/test_sklearn_base_estimator_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,42 +4,85 @@
from skore.item import SklearnBaseEstimatorItem


class Estimator(sklearn.svm.SVC):
pass


class TestSklearnBaseEstimatorItem:
@pytest.fixture(autouse=True)
def monkeypatch_datetime(self, monkeypatch, MockDatetime):
monkeypatch.setattr("skore.item.item.datetime", MockDatetime)

@pytest.mark.order(0)
def test_factory(self, monkeypatch, mock_nowstr):
monkeypatch.setattr("skops.io.dumps", lambda _: "<estimator_skops>")
monkeypatch.setattr(
"sklearn.utils.estimator_html_repr", lambda _: "<estimator_html_repr>"
)

estimator = sklearn.svm.SVC()
estimator_skops = "<estimator_skops>"
estimator_html_repr = "<estimator_html_repr>"
estimator_skops = "<estimator_skops>"
estimator_skops_untrusted_types = "<estimator_skops_untrusted_types>"

monkeypatch.setattr(
"sklearn.utils.estimator_html_repr",
lambda *args, **kwargs: estimator_html_repr,
)
monkeypatch.setattr("skops.io.dumps", lambda *args, **kwargs: estimator_skops)
monkeypatch.setattr(
"skops.io.get_untrusted_types",
lambda *args, **kwargs: estimator_skops_untrusted_types,
)

item = SklearnBaseEstimatorItem.factory(estimator)

assert item.estimator_skops == estimator_skops
assert item.estimator_html_repr == estimator_html_repr
assert item.estimator_skops == estimator_skops
assert item.estimator_skops_untrusted_types == estimator_skops_untrusted_types
assert item.created_at == mock_nowstr
assert item.updated_at == mock_nowstr

@pytest.mark.order(1)
def test_estimator(self, mock_nowstr):
estimator = sklearn.svm.SVC()
estimator_skops = skops.io.dumps(estimator)
estimator_html_repr = "<estimator_html_repr>"
estimator_skops_untrusted_types = skops.io.get_untrusted_types(
data=estimator_skops
)

item1 = SklearnBaseEstimatorItem.factory(estimator)
item2 = SklearnBaseEstimatorItem(
estimator_html_repr=None,
estimator_skops=estimator_skops,
estimator_html_repr=estimator_html_repr,
estimator_skops_untrusted_types=estimator_skops_untrusted_types,
created_at=mock_nowstr,
updated_at=mock_nowstr,
)

assert isinstance(item1.estimator, sklearn.svm.SVC)
assert isinstance(item2.estimator, sklearn.svm.SVC)

@pytest.mark.order(1)
def test_estimator_untrusted(self, mock_nowstr):
estimator = Estimator()
estimator_skops = skops.io.dumps(estimator)
estimator_skops_untrusted_types = skops.io.get_untrusted_types(
data=estimator_skops
)

if not estimator_skops_untrusted_types:
pytest.skip(
"""
This test is only intended to exhaustively test an untrusted estimator.
The untrusted Estimator class seems to be trusted by default.
Something changed in `skops`.
"""
)

item1 = SklearnBaseEstimatorItem.factory(estimator)
item2 = SklearnBaseEstimatorItem(
estimator_html_repr=None,
estimator_skops=estimator_skops,
estimator_skops_untrusted_types=estimator_skops_untrusted_types,
created_at=mock_nowstr,
updated_at=mock_nowstr,
)

assert isinstance(item1.estimator, Estimator)
assert isinstance(item2.estimator, Estimator)

0 comments on commit 4ca60a6

Please sign in to comment.