From beaa234e9f72527aaad238332bd4b477095565b6 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 5 Apr 2023 15:03:13 +0200 Subject: [PATCH 1/2] A little hardening of the auditing of Nodes Two measures to harden the auditing (a little bit): - Type annotate the Node's children to prevent setting invalid types. - Change all the tests that use loads to only load trusted types instead of using trusted=True. The latter is important because when setting trusted=True, the whole machinery of checking types is not executed, so any bugs that may be contained there will not be revealed. In particular, this shows that for persisting methods, we had a child with a str type and that would raise an error, i.e. loading method types was not possible for users who passed trusted!=True. Additional changes As a consequence of the last point, the auditing code has been changed to accept str as a type. Alternatively, we can make the change explained here: https://github.com/skops-dev/skops/pull/338#discussion_r1156151738 i.e. not storing the method name in children. Another "victim" of this change is that the so-far dead code of checking for primitive types inside of get_unsafe_set has been removed. This code was supposed to check if the type is a primitive type but it was defective. get_module(child) would raise an error if an instance of the type were passed. We could theoretically fix that code, but it would still be dead code because primitive types are stored as JSON. Another small change is to exclude the code in skops/io/old from mypy checks. Otherwise, we would have to update its type signatures if signatures in the persistence code change. 
--- pyproject.toml | 2 +- skops/io/_audit.py | 20 ++++++------- skops/io/_sklearn.py | 4 +-- skops/io/_visualize.py | 4 +-- skops/io/tests/test_persist.py | 52 +++++++++++++++++++++++++--------- 5 files changed, 52 insertions(+), 30 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c957920e..1ac0b580 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,6 @@ omit = [ ] [tool.mypy] -exclude = "(\\w+/)*test_\\w+\\.py$" +exclude = "(\\w+/)*test_\\w+\\.py$|old" ignore_missing_imports = true no_implicit_optional = true diff --git a/skops/io/_audit.py b/skops/io/_audit.py index 067b13c5..ea7253e5 100644 --- a/skops/io/_audit.py +++ b/skops/io/_audit.py @@ -2,14 +2,16 @@ import io from contextlib import contextmanager -from typing import Any, Generator, Literal, Sequence, Type, Union +from typing import Any, Generator, Literal, Optional, Sequence, Type, Union from ._protocol import PROTOCOL -from ._trusted_types import PRIMITIVE_TYPE_NAMES from ._utils import LoadContext, get_module, get_type_paths from .exceptions import UntrustedTypesFoundException NODE_TYPE_MAPPING: dict[tuple[str, int], Node] = {} +VALID_NODE_CHILD_TYPES = Optional[ + Union["Node", list["Node"], dict[str, "Node"], Type, str, io.BytesIO] +] def check_type( @@ -168,7 +170,7 @@ def __init__( # 3. set self.children, where children are states of child nodes; do not # construct the children objects yet self.trusted = self._get_trusted(trusted, []) - self.children: dict[str, Any] = {} + self.children: dict[str, VALID_NODE_CHILD_TYPES] = {} def construct(self): """Construct the object. @@ -269,15 +271,11 @@ def get_unsafe_set(self) -> set[str]: if not check_type(get_module(child), child.__name__, self.trusted): # if the child is a type, we check its safety res.add(get_module(child) + "." + child.__name__) - elif isinstance(child, io.BytesIO): + elif isinstance(child, (io.BytesIO, str)): # We trust BytesIO objects, which are read by other - # libraries such as numpy, scipy. 
- continue - elif check_type( - get_module(child), child.__class__.__name__, PRIMITIVE_TYPE_NAMES - ): - # if the child is a primitive type, we don't need to check its - # safety. + # libraries such as numpy, scipy. We trust str but have to + # be careful that anything with str is dealt with + # appropriately. continue else: raise ValueError( diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index 4d302267..ce9c3969 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Callable, Sequence, Type +from typing import Any, Sequence, Type from sklearn.cluster import Birch @@ -96,7 +96,7 @@ def __init__( self, state: dict[str, Any], load_context: LoadContext, - constructor: Type[Any] | Callable[..., Any], + constructor: Type[Any], trusted: bool | Sequence[str] = False, ) -> None: super().__init__(state, load_context, trusted) diff --git a/skops/io/_visualize.py b/skops/io/_visualize.py index 75269bf0..8e31c2fc 100644 --- a/skops/io/_visualize.py +++ b/skops/io/_visualize.py @@ -7,7 +7,7 @@ from typing import Any, Callable, Iterator, Literal from zipfile import ZipFile -from ._audit import Node, get_tree +from ._audit import VALID_NODE_CHILD_TYPES, Node, get_tree from ._general import FunctionNode, JsonNode, ListNode from ._numpy import NdArrayNode from ._scipy import SparseMatrixNode @@ -168,7 +168,7 @@ def pretty_print_tree( def walk_tree( - node: Node | dict[str, Node] | list[Node], + node: VALID_NODE_CHILD_TYPES | dict[str, VALID_NODE_CHILD_TYPES], node_name: str = "root", level: int = 0, is_last: bool = False, diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 1724a4c1..eb7c5107 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -263,7 +263,9 @@ def _unsupported_estimators(type_filter=None): ) def test_can_persist_non_fitted(estimator): """Check that non-fitted estimators can be persisted.""" - loaded = 
loads(dumps(estimator), trusted=True) + dumped = dumps(estimator) + untrusted_types = get_untrusted_types(data=dumped) + loaded = loads(dumped, trusted=untrusted_types) assert_params_equal(estimator.get_params(), loaded.get_params()) @@ -458,7 +460,9 @@ def split(self, X, **kwargs): ) def test_cross_validator(cv): est = CVEstimator(cv=cv).fit(None, None) - loaded = loads(dumps(est), trusted=True) + dumped = dumps(est) + untrusted_types = get_untrusted_types(data=dumped) + loaded = loads(dumped, trusted=untrusted_types) X, y = make_classification( n_samples=N_SAMPLES, n_features=N_FEATURES, random_state=0 ) @@ -500,7 +504,9 @@ def test_numpy_object_dtype_2d_array(transpose): if transpose: est.obj_array_ = est.obj_array_.T - loaded = loads(dumps(est), trusted=True) + dumped = dumps(est) + untrusted_types = get_untrusted_types(data=dumped) + loaded = loads(dumped, trusted=untrusted_types) assert_params_equal(est.__dict__, loaded.__dict__) @@ -615,7 +621,8 @@ def test_identical_numpy_arrays_not_duplicated(): X = np.random.random((10, 5)) estimator = EstimatorIdenticalArrays().fit(X) dumped = dumps(estimator) - loaded = loads(dumped, trusted=True) + untrusted_types = get_untrusted_types(data=dumped) + loaded = loads(dumped, trusted=untrusted_types) assert_params_equal(estimator.__dict__, loaded.__dict__) # check number of numpy arrays stored on disk @@ -719,7 +726,9 @@ def test_for_base_case_returns_as_expected(self): bound_function = obj.bound_method transformer = FunctionTransformer(func=bound_function) - loaded_transformer = loads(dumps(transformer), trusted=True) + dumped = dumps(transformer) + untrusted_types = get_untrusted_types(data=dumped) + loaded_transformer = loads(dumped, trusted=untrusted_types) loaded_obj = loaded_transformer.func.__self__ self.assert_transformer_persisted_correctly(loaded_transformer, transformer) @@ -736,7 +745,9 @@ def test_when_object_is_changed_after_init_works_as_expected(self): transformer = FunctionTransformer(func=bound_function) 
- loaded_transformer = loads(dumps(transformer), trusted=True) + dumped = dumps(transformer) + untrusted_types = get_untrusted_types(data=dumped) + loaded_transformer = loads(dumped, trusted=untrusted_types) loaded_obj = loaded_transformer.func.__self__ self.assert_transformer_persisted_correctly(loaded_transformer, transformer) @@ -749,19 +760,23 @@ def test_works_when_given_multiple_bound_methods_attached_to_single_instance(sel func=obj.bound_method, inverse_func=obj.other_bound_method ) - loaded_transformer = loads(dumps(transformer), trusted=True) + dumped = dumps(transformer) + untrusted_types = get_untrusted_types(data=dumped) + loaded_transformer = loads(dumped, trusted=untrusted_types) # check that both func and inverse_func are from the same object instance loaded_0 = loaded_transformer.func.__self__ loaded_1 = loaded_transformer.inverse_func.__self__ assert loaded_0 is loaded_1 - @pytest.mark.xfail(reason="Failing due to circular self reference") + @pytest.mark.xfail(reason="Failing due to circular self reference", strict=True) def test_scipy_stats(self, tmp_path): from scipy import stats estimator = FunctionTransformer(func=stats.zipf) - loads(dumps(estimator), trusted=True) + dumped = dumps(estimator) + untrusted_types = get_untrusted_types(data=dumped) + loads(dumped, trusted=untrusted_types) class CustomEstimator(BaseEstimator): @@ -862,7 +877,9 @@ def test_dump_and_load_with_file_wrapper(tmp_path): ) def test_when_given_object_referenced_twice_loads_as_one_object(obj): an_object = {"obj_1": obj, "obj_2": obj} - persisted_object = loads(dumps(an_object), trusted=True) + dumped = dumps(an_object) + untrusted_types = get_untrusted_types(data=dumped) + persisted_object = loads(dumped, trusted=untrusted_types) assert persisted_object["obj_1"] is persisted_object["obj_2"] @@ -876,7 +893,9 @@ def fit(self, X, y, **fit_params): def test_estimator_with_bytes(): est = EstimatorWithBytes().fit(None, None) - loaded = loads(dumps(est), trusted=True) + dumped = 
dumps(est) + untrusted_types = get_untrusted_types(data=dumped) + loaded = loads(dumped, trusted=untrusted_types) assert_params_equal(est.__dict__, loaded.__dict__) @@ -934,13 +953,17 @@ def test_persist_operator(op): _, func = op # unfitted est = FunctionTransformer(func) - loaded = loads(dumps(est), trusted=True) + dumped = dumps(est) + untrusted_types = get_untrusted_types(data=dumped) + loaded = loads(dumped, trusted=untrusted_types) assert_params_equal(est.__dict__, loaded.__dict__) # fitted X, y = get_input(est) est.fit(X, y) - loaded = loads(dumps(est), trusted=True) + dumped = dumps(est) + untrusted_types = get_untrusted_types(data=dumped) + loaded = loads(dumped, trusted=untrusted_types) assert_params_equal(est.__dict__, loaded.__dict__) # Technically, we don't need to call transform. However, if this is skipped, @@ -973,7 +996,8 @@ def test_persist_function(func): estimator.fit(X, y) dumped = dumps(estimator) - loaded = loads(dumped, trusted=True) + untrusted_types = get_untrusted_types(data=dumped) + loaded = loads(dumped, trusted=untrusted_types) # check that loaded estimator is identical assert_params_equal(estimator.__dict__, loaded.__dict__) From e4ff1a8f19489382b1ca869cf9f8c4b42089cb59 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 5 Apr 2023 15:30:30 +0200 Subject: [PATCH 2/2] Make type annotations work with Python 3.8 --- skops/io/_audit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skops/io/_audit.py b/skops/io/_audit.py index ea7253e5..d2426473 100644 --- a/skops/io/_audit.py +++ b/skops/io/_audit.py @@ -2,7 +2,7 @@ import io from contextlib import contextmanager -from typing import Any, Generator, Literal, Optional, Sequence, Type, Union +from typing import Any, Dict, Generator, List, Literal, Optional, Sequence, Type, Union from ._protocol import PROTOCOL from ._utils import LoadContext, get_module, get_type_paths @@ -10,7 +10,7 @@ NODE_TYPE_MAPPING: dict[tuple[str, int], Node] = {} VALID_NODE_CHILD_TYPES 
= Optional[ - Union["Node", list["Node"], dict[str, "Node"], Type, str, io.BytesIO] + Union["Node", List["Node"], Dict[str, "Node"], Type, str, io.BytesIO] ]