From e684f8e901aeeeba7e93aeb91d37c15074eacff2 Mon Sep 17 00:00:00 2001 From: Thomas S Date: Tue, 1 Oct 2024 15:59:05 +0200 Subject: [PATCH 1/2] fix: Temporarily disable security feature of skops used in `SklearnBaseEstimatorItem` --- src/skore/item/sklearn_base_estimator_item.py | 29 ++++++++++++++----- .../item/test_sklearn_base_estimator_item.py | 28 ++++++++++++------ 2 files changed, 40 insertions(+), 17 deletions(-) diff --git a/src/skore/item/sklearn_base_estimator_item.py b/src/skore/item/sklearn_base_estimator_item.py index b81959b0..6100fb45 100644 --- a/src/skore/item/sklearn_base_estimator_item.py +++ b/src/skore/item/sklearn_base_estimator_item.py @@ -25,8 +25,9 @@ class SklearnBaseEstimatorItem(Item): def __init__( self, - estimator_skops, - estimator_html_repr, + estimator_html_repr: str, + estimator_skops: bytes, + estimator_skops_untrusted_types: list[str], created_at: str | None = None, updated_at: str | None = None, ): @@ -35,10 +36,12 @@ def __init__( Parameters ---------- - estimator_skops : Any - The skops representation of the scikit-learn estimator. estimator_html_repr : str The HTML representation of the scikit-learn estimator. + estimator_skops : bytes + The skops representation of the scikit-learn estimator. + estimator_skops_untrusted_types : list[str] + The list of untrusted types in the skops representation. created_at : str, optional The creation timestamp in ISO format. updated_at : str, optional @@ -46,8 +49,9 @@ def __init__( """ super().__init__(created_at, updated_at) - self.estimator_skops = estimator_skops self.estimator_html_repr = estimator_html_repr + self.estimator_skops = estimator_skops + self.estimator_skops_untrusted_types = estimator_skops_untrusted_types @cached_property def estimator(self) -> sklearn.base.BaseEstimator: @@ -61,7 +65,9 @@ def estimator(self) -> sklearn.base.BaseEstimator: """ import skops.io - return skops.io.loads(self.estimator_skops) + return skops.io.loads( + self.estimator_skops, trusted=self.estimator_skops_untrusted_types + ) @classmethod def factory(cls, estimator: sklearn.base.BaseEstimator) -> SklearnBaseEstimatorItem: @@ -85,9 +91,16 @@ def factory(cls, estimator: sklearn.base.BaseEstimator) -> SklearnBaseEstimatorI if not isinstance(estimator, sklearn.base.BaseEstimator): raise TypeError(f"Type '{estimator.__class__}' is not supported.") + estimator_html_repr = sklearn.utils.estimator_html_repr(estimator) + estimator_skops = skops.io.dumps(estimator) + estimator_skops_untrusted_types = skops.io.get_untrusted_types( + data=estimator_skops + ) + instance = cls( - estimator_skops=skops.io.dumps(estimator), - estimator_html_repr=sklearn.utils.estimator_html_repr(estimator), + estimator_html_repr=estimator_html_repr, + estimator_skops=estimator_skops, + estimator_skops_untrusted_types=estimator_skops_untrusted_types, ) # add estimator as cached property diff --git a/tests/unit/item/test_sklearn_base_estimator_item.py b/tests/unit/item/test_sklearn_base_estimator_item.py index 8003430b..6e79f1d3 100644 --- a/tests/unit/item/test_sklearn_base_estimator_item.py +++ b/tests/unit/item/test_sklearn_base_estimator_item.py @@ -11,19 +11,26 @@ def monkeypatch_datetime(self, monkeypatch, MockDatetime): @pytest.mark.order(0) def test_factory(self, monkeypatch, mock_nowstr): - monkeypatch.setattr("skops.io.dumps", lambda _: "") - monkeypatch.setattr( - "sklearn.utils.estimator_html_repr", lambda _: "" - ) - estimator = sklearn.svm.SVC() - estimator_skops = "" estimator_html_repr = "" + estimator_skops = "" + estimator_skops_untrusted_types = "" + + monkeypatch.setattr( + "sklearn.utils.estimator_html_repr", + lambda *args, **kwargs: estimator_html_repr, + ) + monkeypatch.setattr("skops.io.dumps", lambda *args, **kwargs: estimator_skops) + monkeypatch.setattr( + "skops.io.get_untrusted_types", + lambda *args, **kwargs: estimator_skops_untrusted_types, + ) item = SklearnBaseEstimatorItem.factory(estimator) - assert item.estimator_skops == estimator_skops assert item.estimator_html_repr == estimator_html_repr + assert item.estimator_skops == estimator_skops + assert item.estimator_skops_untrusted_types == estimator_skops_untrusted_types assert item.created_at == mock_nowstr assert item.updated_at == mock_nowstr @@ -31,12 +38,15 @@ def test_factory(self, monkeypatch, mock_nowstr): def test_estimator(self, mock_nowstr): estimator = sklearn.svm.SVC() estimator_skops = skops.io.dumps(estimator) - estimator_html_repr = "" + estimator_skops_untrusted_types = skops.io.get_untrusted_types( + data=estimator_skops + ) item1 = SklearnBaseEstimatorItem.factory(estimator) item2 = SklearnBaseEstimatorItem( + estimator_html_repr=None, estimator_skops=estimator_skops, - estimator_html_repr=estimator_html_repr, + estimator_skops_untrusted_types=estimator_skops_untrusted_types, created_at=mock_nowstr, updated_at=mock_nowstr, ) From cd0a8ced31edf474c941d9811cc63b18a3efa3cd Mon Sep 17 00:00:00 2001 From: Thomas S Date: Wed, 2 Oct 2024 10:03:05 +0200 Subject: [PATCH 2/2] Add untrusted estimator test --- .../item/test_sklearn_base_estimator_item.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/unit/item/test_sklearn_base_estimator_item.py b/tests/unit/item/test_sklearn_base_estimator_item.py index 6e79f1d3..1eca7623 100644 --- a/tests/unit/item/test_sklearn_base_estimator_item.py +++ b/tests/unit/item/test_sklearn_base_estimator_item.py @@ -4,6 +4,10 @@ from skore.item import SklearnBaseEstimatorItem +class Estimator(sklearn.svm.SVC): + pass + + class TestSklearnBaseEstimatorItem: @pytest.fixture(autouse=True) def monkeypatch_datetime(self, monkeypatch, MockDatetime): @@ -53,3 +57,32 @@ def test_estimator(self, mock_nowstr): assert isinstance(item1.estimator, sklearn.svm.SVC) assert isinstance(item2.estimator, sklearn.svm.SVC) + + @pytest.mark.order(1) + def test_estimator_untrusted(self, mock_nowstr): + estimator = Estimator() + estimator_skops = skops.io.dumps(estimator) + estimator_skops_untrusted_types = skops.io.get_untrusted_types( + data=estimator_skops + ) + + if not estimator_skops_untrusted_types: + pytest.skip( + """ + This test is only intended to exhaustively test an untrusted estimator. + The untrusted Estimator class seems to be trusted by default. + Something changed in `skops`. + """ + ) + + item1 = SklearnBaseEstimatorItem.factory(estimator) + item2 = SklearnBaseEstimatorItem( + estimator_html_repr=None, + estimator_skops=estimator_skops, + estimator_skops_untrusted_types=estimator_skops_untrusted_types, + created_at=mock_nowstr, + updated_at=mock_nowstr, + ) + + assert isinstance(item1.estimator, Estimator) + assert isinstance(item2.estimator, Estimator)