Skip to content

Commit

Permalink
Refactor: get_instance method saved in state
Browse files Browse the repository at this point in the history
Resolves skops-dev#197

Description

Currently, during the dispatch of the get_instance functions, the class
stored in the state is being loaded to determine which function to
dispatch to. This is bad because module loading can be dangerous. We
will add auditing but it is intended to be on the level of
get_instance itself, not for the dispatch mechanism.

In this PR, the state returned by get_state functions is augmented with
the name of the get_instance method required to load the object. This
way, we can look up the correct method based on the state and don't need
to use the modified singledispatch mechanism, thus avoiding loading
modules during dispatching.

Implementation

Whereas for get_state, we still rely on singledispatch, for get_instance
we now use a simple dictionary that looks up the function based on its
name (which in turn is stored in the state). The dictionary, going by
the name of GET_INSTANCE_MAPPING, is populated similarly to how the
get_instance functions were registered previously with singledispatch.

There was an issue with circular imports (e.g. get_instance >
GET_INSTANCE_MAPPING > ndarray_get_instance > get_instance), hence the
get_instance function was moved to its own module, _dispatch.py (other
name suggestions welcome).

For some types, we now need extra get_state functions because they
have specific get_instance methods. So e.g. sgd_loss_get_state just
wraps reduce_get_state and adds sgd_loss_get_instance as its loader.

Coincidental changes

Since we no longer have to inspect the contents of state to determine
the function to dispatch to for get_instance, we can fall back to the
Python implementation of singledispatch instead of rolling our own. This
side effect is a big win.

The function Tree_get_instance was renamed to tree_get_instance for
consistency.

In the debug_dispatch_functions fixture, there was some code from a time when
the state was allowed not to be a dict (json-serializable objects). Now
we always have a dict, so this dead code was removed.

Also, this fixture was elevated to module-level scope, since it only
needs to run once.
  • Loading branch information
BenjaminBossan committed Oct 21, 2022
1 parent bf8c2c1 commit abe490d
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 212 deletions.
21 changes: 21 additions & 0 deletions skops/io/_dispatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#!/usr/bin/env python3

import json
from typing import Any, Callable
from zipfile import ZipFile

# Registry that maps a loader name (the string stored under "__loader__" in a
# state dict) to the function that rebuilds the object from that state. It
# starts out empty and is filled by the code that registers the loaders.
GET_INSTANCE_MAPPING: dict[str, Callable[[dict[str, Any], ZipFile], Any]] = {}


def get_instance(state, src):
    """Create instance based on the state, using json if possible.

    Parameters
    ----------
    state : dict
        State describing the object. If ``state["is_json"]`` is truthy, the
        object is decoded from ``state["content"]`` with ``json.loads``.
        Otherwise ``state["__loader__"]`` names the function in
        ``GET_INSTANCE_MAPPING`` that is used to build the instance.

    src : ZipFile
        Passed through unchanged to the loader function.

    Returns
    -------
    instance : Any
        The decoded json value, or whatever the looked-up loader returns.

    Raises
    ------
    TypeError
        If the state has no ``"__loader__"`` entry or the loader name is not
        registered in ``GET_INSTANCE_MAPPING``.
    """
    if state.get("is_json"):
        return json.loads(state["content"])

    loader = state.get("__loader__")
    try:
        get_instance_func = GET_INSTANCE_MAPPING[loader]
    except KeyError:
        # ``state`` is always a dict, so ``type(state)`` carries no
        # information; report the stored type and the loader that was
        # requested instead.
        type_str = f"{state.get('__module__')}.{state.get('__class__')}"
        raise TypeError(
            f"Creating an instance of type {type_str} is not supported yet, "
            f"no loader named {loader!r} is registered"
        ) from None
    # Called outside the try block so that a KeyError raised by the loader
    # itself is not mistaken for a missing loader.
    return get_instance_func(state, src)
34 changes: 22 additions & 12 deletions skops/io/_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@

import numpy as np

from ._utils import SaveState, _import_obj, get_instance, get_module, get_state, gettype
from ._dispatch import get_instance
from ._utils import SaveState, _import_obj, get_module, get_state, gettype
from .exceptions import UnsupportedTypeException


def dict_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "dict_get_instance",
}

key_types = get_state([type(key) for key in obj.keys()], save_state)
Expand Down Expand Up @@ -43,6 +45,7 @@ def list_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "list_get_instance",
}
content = []
for value in obj:
Expand All @@ -62,6 +65,7 @@ def tuple_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "tuple_get_instance",
}
content = tuple(get_state(value, save_state) for value in obj)
res["content"] = content
Expand Down Expand Up @@ -93,6 +97,7 @@ def function_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(obj),
"__loader__": "function_get_instance",
"content": {
"module_path": get_module(obj),
"function": obj.__name__,
Expand All @@ -111,6 +116,7 @@ def partial_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": "partial", # don't allow any subclass
"__module__": get_module(type(obj)),
"__loader__": "partial_get_instance",
"content": {
"func": get_state(func, save_state),
"args": get_state(args, save_state),
Expand Down Expand Up @@ -138,6 +144,7 @@ def type_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "type_get_instance",
"content": {
"__class__": obj.__name__,
"__module__": get_module(obj),
Expand All @@ -155,6 +162,7 @@ def slice_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "slice_get_instance",
"content": {
"start": obj.start,
"stop": obj.stop,
Expand All @@ -181,6 +189,7 @@ def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
return {
"__class__": "str",
"__module__": "builtins",
"__loader__": "none",
"content": obj_str,
"is_json": True,
}
Expand All @@ -190,6 +199,7 @@ def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "object_get_instance",
}

# __getstate__ takes priority over __dict__, and if non exist, we only save
Expand Down Expand Up @@ -247,14 +257,14 @@ def unsupported_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
(type, type_get_state),
(object, object_get_state),
]
# tuples of type and function that creates the instance of that type
GET_INSTANCE_DISPATCH_FUNCTIONS = [
(dict, dict_get_instance),
(list, list_get_instance),
(tuple, tuple_get_instance),
(slice, slice_get_instance),
(FunctionType, function_get_instance),
(partial, partial_get_instance),
(type, type_get_instance),
(object, object_get_instance),
]

GET_INSTANCE_DISPATCH_MAPPING = {
"dict_get_instance": dict_get_instance,
"list_get_instance": list_get_instance,
"tuple_get_instance": tuple_get_instance,
"slice_get_instance": slice_get_instance,
"function_get_instance": function_get_instance,
"partial_get_instance": partial_get_instance,
"type_get_instance": type_get_instance,
"object_get_instance": object_get_instance,
}
26 changes: 16 additions & 10 deletions skops/io/_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@

import numpy as np

from ._dispatch import get_instance
from ._general import function_get_instance
from ._utils import SaveState, _import_obj, get_instance, get_module, get_state
from ._utils import SaveState, _import_obj, get_module, get_state
from .exceptions import UnsupportedTypeException


def ndarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "ndarray_get_instance",
}

try:
Expand Down Expand Up @@ -78,6 +80,7 @@ def maskedarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "maskedarray_get_instance",
"content": {
"data": get_state(obj.data, save_state),
"mask": get_state(obj.mask, save_state),
Expand All @@ -97,6 +100,7 @@ def random_state_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "random_state_get_instance",
"content": content,
}
return res
Expand All @@ -115,6 +119,7 @@ def random_generator_get_state(obj: Any, save_state: SaveState) -> dict[str, Any
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "random_generator_get_instance",
"content": {"bit_generator": bit_generator_state},
}
return res
Expand All @@ -139,6 +144,7 @@ def ufunc_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__, # ufunc
"__module__": get_module(type(obj)), # numpy
"__loader__": "function_get_instance",
"content": {
"module_path": get_module(obj),
"function": obj.__name__,
Expand All @@ -154,6 +160,7 @@ def dtype_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": "dtype",
"__module__": "numpy",
"__loader__": "dtype_get_instance",
"content": ndarray_get_state(tmp, save_state),
}
return res
Expand All @@ -177,12 +184,11 @@ def dtype_get_instance(state, src):
(np.random.Generator, random_generator_get_state),
]
# mapping of loader name to the function that creates the instance
GET_INSTANCE_DISPATCH_FUNCTIONS = [
(np.generic, ndarray_get_instance),
(np.ndarray, ndarray_get_instance),
(np.ma.MaskedArray, maskedarray_get_instance),
(np.ufunc, function_get_instance),
(np.dtype, dtype_get_instance),
(np.random.RandomState, random_state_get_instance),
(np.random.Generator, random_generator_get_instance),
]
GET_INSTANCE_DISPATCH_MAPPING = {
"ndarray_get_instance": ndarray_get_instance,
"maskedarray_get_instance": maskedarray_get_instance,
"function_get_instance": function_get_instance,
"dtype_get_instance": dtype_get_instance,
"random_state_get_instance": random_state_get_instance,
"random_generator_get_instance": random_generator_get_instance,
}
7 changes: 4 additions & 3 deletions skops/io/_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

import skops

from ._utils import SaveState, _get_instance, _get_state, get_instance, get_state
from ._dispatch import GET_INSTANCE_MAPPING, get_instance
from ._utils import SaveState, _get_state, get_state

# We load the dispatch functions from the corresponding modules and register
# them.
Expand All @@ -17,8 +18,8 @@
module = importlib.import_module(module_name, package="skops.io")
for cls, method in getattr(module, "GET_STATE_DISPATCH_FUNCTIONS", []):
_get_state.register(cls)(method)
for cls, method in getattr(module, "GET_INSTANCE_DISPATCH_FUNCTIONS", []):
_get_instance.register(cls)(method)
# populate the dict used for dispatching get_instance functions
GET_INSTANCE_MAPPING.update(module.GET_INSTANCE_DISPATCH_MAPPING)


def _save(obj):
Expand Down
7 changes: 4 additions & 3 deletions skops/io/_scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def sparse_matrix_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "sparse_matrix_get_instance",
}

data_buffer = io.BytesIO()
Expand Down Expand Up @@ -49,8 +50,8 @@ def sparse_matrix_get_instance(state, src):
(spmatrix, sparse_matrix_get_state),
]
# mapping of loader name to the function that creates the instance
GET_INSTANCE_DISPATCH_FUNCTIONS = [
GET_INSTANCE_DISPATCH_MAPPING = {
# use 'spmatrix' to check if a matrix is a sparse matrix because that is
# what scipy.sparse.issparse checks
(spmatrix, sparse_matrix_get_instance),
]
"sparse_matrix_get_instance": sparse_matrix_get_instance,
}
44 changes: 32 additions & 12 deletions skops/io/_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
from sklearn.tree._tree import Tree
from sklearn.utils import Bunch

from ._dispatch import get_instance
from ._general import dict_get_instance, dict_get_state, unsupported_get_state
from ._utils import SaveState, get_instance, get_module, get_state, gettype
from ._utils import SaveState, get_module, get_state, gettype
from .exceptions import UnsupportedTypeException

ALLOWED_SGD_LOSSES = {
Expand Down Expand Up @@ -110,17 +111,35 @@ def reduce_get_instance(state, src, constructor):
return instance


def Tree_get_instance(state, src):
def tree_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
    """Return the ``reduce_get_state`` state of ``obj``, tagged for trees.

    The state produced by ``reduce_get_state`` is reused as is; only its
    ``"__loader__"`` entry is set to ``"tree_get_instance"``.
    """
    tree_state = reduce_get_state(obj, save_state)
    # record which get_instance function has to be used when loading
    tree_state.update(__loader__="tree_get_instance")
    return tree_state


def tree_get_instance(state, src):
    """Build the instance via ``reduce_get_instance`` with ``Tree`` as constructor."""
    tree = reduce_get_instance(state, src, constructor=Tree)
    return tree


def sgd_loss_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
    """Return the ``reduce_get_state`` state of ``obj``, tagged for SGD losses.

    The state produced by ``reduce_get_state`` is reused as is; only its
    ``"__loader__"`` entry is set to ``"sgd_loss_get_instance"``.
    """
    loss_state = reduce_get_state(obj, save_state)
    # record which get_instance function has to be used when loading
    loss_state.update(__loader__="sgd_loss_get_instance")
    return loss_state


def sgd_loss_get_instance(state, src):
    """Build a loss instance, but only for types in ``ALLOWED_SGD_LOSSES``.

    The type is obtained from ``state`` through ``gettype``. If it is one of
    the allowed losses, the instance is created by ``reduce_get_instance``
    with that type as constructor; otherwise ``UnsupportedTypeException`` is
    raised.
    """
    cls = gettype(state)
    if cls in ALLOWED_SGD_LOSSES:
        return reduce_get_instance(state, src, constructor=cls)
    # anything outside the allow-list is rejected
    raise UnsupportedTypeException(f"Expected LossFunction, got {cls}")


def bunch_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
    """Return the ``dict_get_state`` state of ``obj``, tagged for bunches.

    The state produced by ``dict_get_state`` is reused as is; only its
    ``"__loader__"`` entry is set to ``"bunch_get_instance"``.
    """
    bunch_state = dict_get_state(obj, save_state)
    # record which get_instance function has to be used when loading
    bunch_state.update(__loader__="bunch_get_instance")
    return bunch_state


def bunch_get_instance(state, src):
# Bunch is just a wrapper for dict
content = dict_get_instance(state, src)
Expand All @@ -134,6 +153,7 @@ def _DictWithDeprecatedKeys_get_state(
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "_DictWithDeprecatedKeys_get_instance",
}
content = {}
content["main"] = dict_get_state(obj, save_state)
Expand All @@ -158,18 +178,18 @@ def _DictWithDeprecatedKeys_get_instance(state, src):

# tuples of type and function that gets the state of that type
GET_STATE_DISPATCH_FUNCTIONS = [
(LossFunction, reduce_get_state),
(Tree, reduce_get_state),
(LossFunction, sgd_loss_get_state),
(Tree, tree_get_state),
]
for type_ in UNSUPPORTED_TYPES:
GET_STATE_DISPATCH_FUNCTIONS.append((type_, unsupported_get_state))

# mapping of loader name to the function that creates the instance
GET_INSTANCE_DISPATCH_FUNCTIONS = [
(LossFunction, sgd_loss_get_instance),
(Tree, Tree_get_instance),
(Bunch, bunch_get_instance),
]
GET_INSTANCE_DISPATCH_MAPPING = {
"sgd_loss_get_instance": sgd_loss_get_instance,
"tree_get_instance": tree_get_instance,
"bunch_get_instance": bunch_get_instance,
}

# TODO: remove once support for sklearn<1.2 is dropped.
# Starting from sklearn 1.2, _DictWithDeprecatedKeys is removed as it's no
Expand All @@ -178,6 +198,6 @@ def _DictWithDeprecatedKeys_get_instance(state, src):
GET_STATE_DISPATCH_FUNCTIONS.append(
(_DictWithDeprecatedKeys, _DictWithDeprecatedKeys_get_state)
)
GET_INSTANCE_DISPATCH_FUNCTIONS.append(
(_DictWithDeprecatedKeys, _DictWithDeprecatedKeys_get_instance)
)
GET_INSTANCE_DISPATCH_MAPPING[
"_DictWithDeprecatedKeys_get_instance"
] = _DictWithDeprecatedKeys_get_instance
Loading

0 comments on commit abe490d

Please sign in to comment.