Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MNT: A little hardening of the auditing of Nodes #340

Merged
merged 2 commits into from
Apr 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ omit = [
]

[tool.mypy]
exclude = "(\\w+/)*test_\\w+\\.py$"
exclude = "(\\w+/)*test_\\w+\\.py$|old"
ignore_missing_imports = true
no_implicit_optional = true
20 changes: 9 additions & 11 deletions skops/io/_audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@

import io
from contextlib import contextmanager
from typing import Any, Generator, Literal, Sequence, Type, Union
from typing import Any, Dict, Generator, List, Literal, Optional, Sequence, Type, Union

from ._protocol import PROTOCOL
from ._trusted_types import PRIMITIVE_TYPE_NAMES
from ._utils import LoadContext, get_module, get_type_paths
from .exceptions import UntrustedTypesFoundException

NODE_TYPE_MAPPING: dict[tuple[str, int], Node] = {}
VALID_NODE_CHILD_TYPES = Optional[
Union["Node", List["Node"], Dict[str, "Node"], Type, str, io.BytesIO]
]


def check_type(
Expand Down Expand Up @@ -168,7 +170,7 @@ def __init__(
# 3. set self.children, where children are states of child nodes; do not
# construct the children objects yet
self.trusted = self._get_trusted(trusted, [])
self.children: dict[str, Any] = {}
self.children: dict[str, VALID_NODE_CHILD_TYPES] = {}

def construct(self):
"""Construct the object.
Expand Down Expand Up @@ -269,15 +271,11 @@ def get_unsafe_set(self) -> set[str]:
if not check_type(get_module(child), child.__name__, self.trusted):
# if the child is a type, we check its safety
res.add(get_module(child) + "." + child.__name__)
elif isinstance(child, io.BytesIO):
elif isinstance(child, (io.BytesIO, str)):
# We trust BytesIO objects, which are read by other
# libraries such as numpy, scipy.
continue
elif check_type(
get_module(child), child.__class__.__name__, PRIMITIVE_TYPE_NAMES
):
Comment on lines -276 to -278
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are we removing them because primitives are now trusted by default?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what I was referring to here:

Another "victim" of this PR is that the so far dead code of checking for primitive types inside of get_unsafe_set has been removed. This code was supposed to check if the type is a primitive type but it was defective. get_module(child) would raise an error if an instance of the type would be passed. We could theoretically fix that code, but it would still be dead code because primitive types are stored as json.

# if the child is a primitive type, we don't need to check its
# safety.
# libraries such as numpy, scipy. We trust str but have to
# be careful that anything with str is dealt with
# appropriately.
continue
else:
raise ValueError(
Expand Down
4 changes: 2 additions & 2 deletions skops/io/_sklearn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Callable, Sequence, Type
from typing import Any, Sequence, Type

from sklearn.cluster import Birch

Expand Down Expand Up @@ -96,7 +96,7 @@ def __init__(
self,
state: dict[str, Any],
load_context: LoadContext,
constructor: Type[Any] | Callable[..., Any],
constructor: Type[Any],
trusted: bool | Sequence[str] = False,
) -> None:
super().__init__(state, load_context, trusted)
Expand Down
4 changes: 2 additions & 2 deletions skops/io/_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Callable, Iterator, Literal
from zipfile import ZipFile

from ._audit import Node, get_tree
from ._audit import VALID_NODE_CHILD_TYPES, Node, get_tree
from ._general import FunctionNode, JsonNode, ListNode
from ._numpy import NdArrayNode
from ._scipy import SparseMatrixNode
Expand Down Expand Up @@ -168,7 +168,7 @@ def pretty_print_tree(


def walk_tree(
node: Node | dict[str, Node] | list[Node],
node: VALID_NODE_CHILD_TYPES | dict[str, VALID_NODE_CHILD_TYPES],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't dict[str, Node] included in VALID_NODE_CHILD_TYPES?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but this is dict[str, VALID_NODE_CHILD_TYPES], so it could be something where the key is not a Node, like {"foo": None}

node_name: str = "root",
level: int = 0,
is_last: bool = False,
Expand Down
52 changes: 38 additions & 14 deletions skops/io/tests/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,9 @@ def _unsupported_estimators(type_filter=None):
)
def test_can_persist_non_fitted(estimator):
"""Check that non-fitted estimators can be persisted."""
loaded = loads(dumps(estimator), trusted=True)
dumped = dumps(estimator)
untrusted_types = get_untrusted_types(data=dumped)
loaded = loads(dumped, trusted=untrusted_types)
assert_params_equal(estimator.get_params(), loaded.get_params())


Expand Down Expand Up @@ -458,7 +460,9 @@ def split(self, X, **kwargs):
)
def test_cross_validator(cv):
est = CVEstimator(cv=cv).fit(None, None)
loaded = loads(dumps(est), trusted=True)
dumped = dumps(est)
untrusted_types = get_untrusted_types(data=dumped)
loaded = loads(dumped, trusted=untrusted_types)
X, y = make_classification(
n_samples=N_SAMPLES, n_features=N_FEATURES, random_state=0
)
Expand Down Expand Up @@ -500,7 +504,9 @@ def test_numpy_object_dtype_2d_array(transpose):
if transpose:
est.obj_array_ = est.obj_array_.T

loaded = loads(dumps(est), trusted=True)
dumped = dumps(est)
untrusted_types = get_untrusted_types(data=dumped)
loaded = loads(dumped, trusted=untrusted_types)
assert_params_equal(est.__dict__, loaded.__dict__)


Expand Down Expand Up @@ -615,7 +621,8 @@ def test_identical_numpy_arrays_not_duplicated():
X = np.random.random((10, 5))
estimator = EstimatorIdenticalArrays().fit(X)
dumped = dumps(estimator)
loaded = loads(dumped, trusted=True)
untrusted_types = get_untrusted_types(data=dumped)
loaded = loads(dumped, trusted=untrusted_types)
assert_params_equal(estimator.__dict__, loaded.__dict__)

# check number of numpy arrays stored on disk
Expand Down Expand Up @@ -719,7 +726,9 @@ def test_for_base_case_returns_as_expected(self):
bound_function = obj.bound_method
transformer = FunctionTransformer(func=bound_function)

loaded_transformer = loads(dumps(transformer), trusted=True)
dumped = dumps(transformer)
untrusted_types = get_untrusted_types(data=dumped)
loaded_transformer = loads(dumped, trusted=untrusted_types)
loaded_obj = loaded_transformer.func.__self__

self.assert_transformer_persisted_correctly(loaded_transformer, transformer)
Expand All @@ -736,7 +745,9 @@ def test_when_object_is_changed_after_init_works_as_expected(self):

transformer = FunctionTransformer(func=bound_function)

loaded_transformer = loads(dumps(transformer), trusted=True)
dumped = dumps(transformer)
untrusted_types = get_untrusted_types(data=dumped)
loaded_transformer = loads(dumped, trusted=untrusted_types)
loaded_obj = loaded_transformer.func.__self__

self.assert_transformer_persisted_correctly(loaded_transformer, transformer)
Expand All @@ -749,19 +760,23 @@ def test_works_when_given_multiple_bound_methods_attached_to_single_instance(sel
func=obj.bound_method, inverse_func=obj.other_bound_method
)

loaded_transformer = loads(dumps(transformer), trusted=True)
dumped = dumps(transformer)
untrusted_types = get_untrusted_types(data=dumped)
loaded_transformer = loads(dumped, trusted=untrusted_types)

# check that both func and inverse_func are from the same object instance
loaded_0 = loaded_transformer.func.__self__
loaded_1 = loaded_transformer.inverse_func.__self__
assert loaded_0 is loaded_1

@pytest.mark.xfail(reason="Failing due to circular self reference")
@pytest.mark.xfail(reason="Failing due to circular self reference", strict=True)
def test_scipy_stats(self, tmp_path):
from scipy import stats

estimator = FunctionTransformer(func=stats.zipf)
loads(dumps(estimator), trusted=True)
dumped = dumps(estimator)
untrusted_types = get_untrusted_types(data=dumped)
loads(dumped, trusted=untrusted_types)


class CustomEstimator(BaseEstimator):
Expand Down Expand Up @@ -862,7 +877,9 @@ def test_dump_and_load_with_file_wrapper(tmp_path):
)
def test_when_given_object_referenced_twice_loads_as_one_object(obj):
an_object = {"obj_1": obj, "obj_2": obj}
persisted_object = loads(dumps(an_object), trusted=True)
dumped = dumps(an_object)
untrusted_types = get_untrusted_types(data=dumped)
persisted_object = loads(dumped, trusted=untrusted_types)

assert persisted_object["obj_1"] is persisted_object["obj_2"]

Expand All @@ -876,7 +893,9 @@ def fit(self, X, y, **fit_params):

def test_estimator_with_bytes():
est = EstimatorWithBytes().fit(None, None)
loaded = loads(dumps(est), trusted=True)
dumped = dumps(est)
untrusted_types = get_untrusted_types(data=dumped)
loaded = loads(dumped, trusted=untrusted_types)
assert_params_equal(est.__dict__, loaded.__dict__)


Expand Down Expand Up @@ -934,13 +953,17 @@ def test_persist_operator(op):
_, func = op
# unfitted
est = FunctionTransformer(func)
loaded = loads(dumps(est), trusted=True)
dumped = dumps(est)
untrusted_types = get_untrusted_types(data=dumped)
loaded = loads(dumped, trusted=untrusted_types)
assert_params_equal(est.__dict__, loaded.__dict__)

# fitted
X, y = get_input(est)
est.fit(X, y)
loaded = loads(dumps(est), trusted=True)
dumped = dumps(est)
untrusted_types = get_untrusted_types(data=dumped)
loaded = loads(dumped, trusted=untrusted_types)
assert_params_equal(est.__dict__, loaded.__dict__)

# Technically, we don't need to call transform. However, if this is skipped,
Expand Down Expand Up @@ -973,7 +996,8 @@ def test_persist_function(func):
estimator.fit(X, y)

dumped = dumps(estimator)
loaded = loads(dumped, trusted=True)
untrusted_types = get_untrusted_types(data=dumped)
loaded = loads(dumped, trusted=untrusted_types)

# check that loaded estimator is identical
assert_params_equal(estimator.__dict__, loaded.__dict__)
Expand Down