Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Method type in dispatch calls #195

Merged
merged 26 commits into from
Oct 31, 2022
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
c2010db
Add Method type in dispatch calls
E-Aho Oct 17, 2022
0534d1b
Change from debug string to actual string
E-Aho Oct 18, 2022
839e2cd
Load object of bound method with attributes
E-Aho Oct 19, 2022
b37a7d4
Rename _BoundMethodHolder
E-Aho Oct 19, 2022
deb17ad
Merge branch 'main' into FIX-bound-method-serialization
E-Aho Oct 20, 2022
1a36c9f
Serialise -> serialize
E-Aho Oct 21, 2022
9e2405d
Address PR comments
E-Aho Oct 22, 2022
1575bd5
Merge branch 'main' into FIX-bound-method-serialization
E-Aho Oct 23, 2022
8127aa6
Reword comment
E-Aho Oct 23, 2022
9fe2e40
Merge branch 'FIX-bound-method-serialization' of github.com:E-Aho/sko…
E-Aho Oct 23, 2022
aec304a
Update skops/io/_utils.py
E-Aho Oct 24, 2022
05f7ce9
Change from inline comments
E-Aho Oct 24, 2022
5f8e102
Implement code review comments
E-Aho Oct 24, 2022
8e03ffc
Merge-conflict-resolve
E-Aho Oct 24, 2022
6684a75
Merge branch 'main' into FIX-bound-method-serialization
E-Aho Oct 24, 2022
87990aa
Merge branch 'main' into FIX-bound-method-serialization
E-Aho Oct 25, 2022
1dda374
Fix gettype to not return FunctionType or MethodType
E-Aho Oct 25, 2022
ad6122a
Merge branch 'main' into FIX-bound-method-serialization
E-Aho Oct 25, 2022
7ccf3fd
whitespace for linter
E-Aho Oct 25, 2022
78db91a
Update skops/io/_general.py
E-Aho Oct 25, 2022
b6aa8fd
Add LoadState to get_instance functions
E-Aho Oct 27, 2022
30b0147
Add minor comments to additions to get_instance
E-Aho Oct 27, 2022
b613c4e
Merge resolve
E-Aho Oct 27, 2022
ddd94be
Revert "Add LoadState to get_instance functions"
E-Aho Oct 28, 2022
eb4b87d
Add xfail tests for outstanding issues
E-Aho Oct 28, 2022
a2056c9
Reword XFail reason
E-Aho Oct 31, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion skops/io/_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json
from functools import partial
from types import FunctionType
from types import FunctionType, MethodType
from typing import Any

import numpy as np
Expand Down Expand Up @@ -242,6 +242,30 @@ def object_get_instance(state, src):
return instance


def method_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
    """Build the state dict used to persist a bound method.

    A bound method depends on the specific instance it is bound to, so the
    state records the name of the underlying function together with the
    state of that instance, and prepares both to be persisted.

    Parameters
    ----------
    obj : Any
        The bound method; its ``__func__`` and ``__self__`` attributes are
        read.
    save_state : SaveState
        Passed through unchanged to ``get_state`` when the state of the bound
        instance is computed.

    Returns
    -------
    dict
        The state, with ``"__loader__"`` set to ``"method_get_instance"`` and
        a ``"content"`` entry holding the function name (``"func"``) and the
        state of the bound instance (``"obj"``).
    """
    res = {
        "__class__": obj.__class__.__name__,
        "__module__": get_module(obj),
        "__loader__": "method_get_instance",
        "content": {
            "func": obj.__func__.__name__,
            "obj": get_state(obj.__self__, save_state),
        },
    }

    return res


def method_get_instance(state, src):
    """Rebuild a bound method from a state dict.

    The instance stored under ``state["content"]["obj"]`` is loaded first;
    the method named by ``state["content"]["func"]`` is then looked up on that
    loaded instance and returned.

    Parameters
    ----------
    state : dict
        State with a ``"content"`` entry holding ``"obj"`` and ``"func"``.
    src
        Passed through unchanged to ``object_get_instance``.

    Returns
    -------
    The attribute of the loaded instance named by ``"func"``.
    """
    # NOTE(review): the bound instance is always loaded with
    # object_get_instance, regardless of which loader its own state names --
    # confirm this is intended for instances persisted by another loader.
    loaded_obj = object_get_instance(state["content"]["obj"], src)
    method = getattr(loaded_obj, state["content"]["func"])
    return method


def unsupported_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
    """Raise ``UnsupportedTypeException`` for ``obj``.

    This function never returns; ``save_state`` is accepted but not used.

    Raises
    ------
    UnsupportedTypeException
        Always, constructed with ``obj``.
    """
    raise UnsupportedTypeException(obj)

Expand All @@ -253,6 +277,7 @@ def unsupported_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
(tuple, tuple_get_state),
(slice, slice_get_state),
(FunctionType, function_get_state),
(MethodType, method_get_state),
(partial, partial_get_state),
(type, type_get_state),
(object, object_get_state),
Expand All @@ -264,6 +289,7 @@ def unsupported_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
"tuple_get_instance": tuple_get_instance,
"slice_get_instance": slice_get_instance,
"function_get_instance": function_get_instance,
"method_get_instance": method_get_instance,
"partial_get_instance": partial_get_instance,
"type_get_instance": type_get_instance,
"object_get_instance": object_get_instance,
Expand Down
5 changes: 0 additions & 5 deletions skops/io/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import sys
from dataclasses import dataclass, field
from functools import singledispatch
from types import FunctionType
from typing import Any
from zipfile import ZipFile

Expand Down Expand Up @@ -61,10 +60,6 @@ def _import_obj(module, cls_or_func, package=None):

def gettype(state):
    """Return the type named by ``state``, or ``None`` if it names none.

    Parameters
    ----------
    state : dict
        A state mapping; only its ``"__module__"`` and ``"__class__"``
        entries are read.

    Returns
    -------
    ``None`` when either key is missing, ``FunctionType`` when
    ``"__class__"`` is ``"function"``, and otherwise whatever
    ``_import_obj`` returns for the given module and class name.
    """
    if "__module__" in state and "__class__" in state:
        if state["__class__"] == "function":
            # This special case is due to how functions are serialized. We
            # could try to change it.
            return FunctionType
        return _import_obj(state["__module__"], state["__class__"])
    return None

Expand Down
72 changes: 72 additions & 0 deletions skops/io/tests/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,78 @@ def test_get_instance_unknown_type_error_msg():
get_instance(state, None)


class _BoundMethodHolder:
"""Used to test the ability to serialize and deserialize bound methods"""

def __init__(self, object_state: str):
# Initialize with some state to make sure state is persisted
self.object_state = object_state
# bind some method to this object, could be any persistable function
self.chosen_function = np.log

def bound_method(self, x):
return self.chosen_function(x)


class TestPersistingBoundMethods:
    """Round-trip tests for transformers whose ``func`` is a bound method."""

    @staticmethod
    def assert_transformer_persisted_correctly(
        loaded_transformer: FunctionTransformer,
        original_transformer: FunctionTransformer,
    ):
        """Checks that a persisted and original transformer are equivalent, including
        the func passed to it
        """
        assert loaded_transformer.func.__name__ == original_transformer.func.__name__

        # The objects the two methods are bound to must carry the same state.
        assert_params_equal(
            loaded_transformer.func.__self__.__dict__,
            original_transformer.func.__self__.__dict__,
        )
        assert_params_equal(loaded_transformer.__dict__, original_transformer.__dict__)

    @staticmethod
    def assert_bound_method_holder_persisted_correctly(
        original_obj: _BoundMethodHolder, loaded_obj: _BoundMethodHolder
    ):
        """Checks that the persisted and original instances of _BoundMethodHolder are
        equivalent
        """
        assert original_obj.bound_method.__name__ == loaded_obj.bound_method.__name__
        assert original_obj.chosen_function == loaded_obj.chosen_function

        assert_params_equal(original_obj.__dict__, loaded_obj.__dict__)

    def test_for_base_case_returns_as_expected(self):
        """A transformer wrapping a bound method survives a dumps/loads round trip."""
        initial_state = "This is an arbitrary state"
        obj = _BoundMethodHolder(object_state=initial_state)
        bound_function = obj.bound_method
        transformer = FunctionTransformer(func=bound_function)

        loaded_transformer = loads(dumps(transformer))
        loaded_obj = loaded_transformer.func.__self__

        self.assert_transformer_persisted_correctly(loaded_transformer, transformer)
        self.assert_bound_method_holder_persisted_correctly(obj, loaded_obj)

    def test_when_object_is_changed_after_init_works_as_expected(self):
        """Changes made to the bound object after init are persisted as well."""
        # given change to object with bound method after initialisation,
        # make sure still persists correctly

        initial_state = "This is an arbitrary state"
        obj = _BoundMethodHolder(object_state=initial_state)
        obj.chosen_function = np.sqrt
        bound_function = obj.bound_method

        transformer = FunctionTransformer(func=bound_function)

        loaded_transformer = loads(dumps(transformer))
        loaded_obj = loaded_transformer.func.__self__

        self.assert_transformer_persisted_correctly(loaded_transformer, transformer)
        self.assert_bound_method_holder_persisted_correctly(obj, loaded_obj)


class CustomEstimator(BaseEstimator):
"""Estimator with np array, np scalar, and sparse matrix attribute"""

Expand Down