From c2010db5dcfa9304985d76782ec796de660234aa Mon Sep 17 00:00:00 2001 From: = Date: Mon, 17 Oct 2022 22:17:59 +0100 Subject: [PATCH 01/18] Add Method type in dispatch calls --- skops/io/_general.py | 25 ++++++++++++++++++++++++- skops/io/_utils.py | 4 +++- skops/io/tests/test_persist.py | 10 ++++++++++ 3 files changed, 37 insertions(+), 2 deletions(-) diff --git a/skops/io/_general.py b/skops/io/_general.py index 88413472..7eca552e 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -2,7 +2,7 @@ import json from functools import partial -from types import FunctionType +from types import FunctionType, MethodType from typing import Any import numpy as np @@ -98,6 +98,7 @@ def function_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: "function": obj.__name__, }, } + return res @@ -232,6 +233,26 @@ def object_get_instance(state, src): return instance +def method_get_state(obj: Any, save_state: SaveState): + res = { + "__class__": obj.__class__.__name__, + "__module__": get_module(obj), + "content": { + "func": obj.__func__.__name__, + "obj": obj.__self__.__class__.__name__, + "module_path": get_module(obj.__self__.__class__), + }, + } + return res + + +def method_get_instance(state, src): + # TODO: init with attrs + loaded_obj = _import_obj(state["content"]["module_path"], state["content"]["obj"])() + method = getattr(loaded_obj, "_entropy") + return method + + def unsupported_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: raise UnsupportedTypeException(obj) @@ -243,6 +264,7 @@ def unsupported_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: (tuple, tuple_get_state), (slice, slice_get_state), (FunctionType, function_get_state), + (MethodType, method_get_state), (partial, partial_get_state), (type, type_get_state), (object, object_get_state), @@ -254,6 +276,7 @@ def unsupported_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: (tuple, tuple_get_instance), (slice, slice_get_instance), (FunctionType, function_get_instance), + (MethodType, method_get_instance), (partial, partial_get_instance), (type, type_get_instance), (object, object_get_instance), diff --git a/skops/io/_utils.py b/skops/io/_utils.py index 61437894..f9be11a4 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -6,7 +6,7 @@ from dataclasses import dataclass, field from functools import _find_impl, get_cache_token, update_wrapper # type: ignore from pathlib import Path -from types import FunctionType +from types import FunctionType, MethodType from typing import Any from skops.utils.fixes import GenericAlias @@ -192,6 +192,8 @@ def gettype(state): # This special case is due to how functions are serialized. We # could try to change it. return FunctionType + if state["__class__"] == "method": + return MethodType return _import_obj(state["__module__"], state["__class__"]) return None diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 068f6af6..60bf87a5 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -782,3 +782,13 @@ def test_numpy_dtype_object_does_not_store_broken_file(tmp_path): # this estimator should not have any numpy file assert not any(file.endswith(".npy") for file in files) + + +def test_for_serialized_bound_method_works_as_expected(tmp_path): + from scipy import stats + + estimator = FunctionTransformer(func=stats.zipf) + loaded_estimator = save_load_round(estimator, tmp_path / "file.skops") + + assert estimator.func._entropy(2) == loaded_estimator.func._entropy(2) + # This is the bounded method that was causing trouble. As seen, it works now! From 0534d1b9f26fc76fec8f2c0f50816b60910114d7 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 18 Oct 2022 13:28:08 +0100 Subject: [PATCH 02/18] Change from debug string to actual string --- skops/io/_general.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skops/io/_general.py b/skops/io/_general.py index 7eca552e..1c6676a3 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -249,7 +249,7 @@ def method_get_state(obj: Any, save_state: SaveState): def method_get_instance(state, src): # TODO: init with attrs loaded_obj = _import_obj(state["content"]["module_path"], state["content"]["obj"])() - method = getattr(loaded_obj, "_entropy") + method = getattr(loaded_obj, state["content"]["func"]) return method From 839e2cdc4f13193af74e85ce46eb1413563fc7ea Mon Sep 17 00:00:00 2001 From: = Date: Wed, 19 Oct 2022 14:57:00 +0100 Subject: [PATCH 03/18] Load object of bound method with attributes --- skops/io/_general.py | 7 +++---- skops/io/tests/test_persist.py | 27 ++++++++++++++++++++++----- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/skops/io/_general.py b/skops/io/_general.py index 1c6676a3..a7261d54 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -239,16 +239,15 @@ def method_get_state(obj: Any, save_state: SaveState): "__module__": get_module(obj), "content": { "func": obj.__func__.__name__, - "obj": obj.__self__.__class__.__name__, - "module_path": get_module(obj.__self__.__class__), + "obj": object_get_state(obj.__self__, save_state), }, } + return res def method_get_instance(state, src): - # TODO: init with attrs - loaded_obj = _import_obj(state["content"]["module_path"], state["content"]["obj"])() + loaded_obj = object_get_instance(state["content"]["obj"], src) method = getattr(loaded_obj, state["content"]["func"]) return method diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 60bf87a5..662064d5 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -784,11 +784,28 @@ def test_numpy_dtype_object_does_not_store_broken_file(tmp_path): assert not any(file.endswith(".npy") for file in files) +class DummyMethodHolder: + """Used to test the ability to serialise and deserialize""" + + def __init__(self, variant: str): + if variant == "sqrt": + self.chosen_function = np.sqrt + elif variant == "log": + self.chosen_function = np.log + else: + self.chosen_function = np.exp + + def apply(self, x): + return self.chosen_function(x) + + def test_for_serialized_bound_method_works_as_expected(tmp_path): - from scipy import stats + obj = DummyMethodHolder(variant="log") + bound_function = obj.apply + transformer = FunctionTransformer(func=bound_function) - estimator = FunctionTransformer(func=stats.zipf) - loaded_estimator = save_load_round(estimator, tmp_path / "file.skops") + loaded_transformer = save_load_round(transformer, tmp_path / "file.skops") + loaded_bound_function = loaded_transformer.func.__self__.chosen_function - assert estimator.func._entropy(2) == loaded_estimator.func._entropy(2) - # This is the bounded method that was causing trouble. As seen, it works now! + assert loaded_transformer.func.__name__ == bound_function.__name__ + assert loaded_bound_function == obj.chosen_function From b37a7d4f8a8f9dc5dd763da860ebd9ef002aff9e Mon Sep 17 00:00:00 2001 From: = Date: Wed, 19 Oct 2022 15:11:13 +0100 Subject: [PATCH 04/18] Rename _BoundMethodHolder --- skops/io/tests/test_persist.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 662064d5..b3a29a90 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -784,8 +784,8 @@ def test_numpy_dtype_object_does_not_store_broken_file(tmp_path): assert not any(file.endswith(".npy") for file in files) -class DummyMethodHolder: - """Used to test the ability to serialise and deserialize""" +class _BoundMethodHolder: + """Used to test the ability to serialise and deserialize bound methods""" def __init__(self, variant: str): if variant == "sqrt": @@ -800,7 +800,7 @@ def apply(self, x): def test_for_serialized_bound_method_works_as_expected(tmp_path): - obj = DummyMethodHolder(variant="log") + obj = _BoundMethodHolder(variant="log") bound_function = obj.apply transformer = FunctionTransformer(func=bound_function) From 1a36c9fc6e5fa27a91594d8ac93cd36d348a4675 Mon Sep 17 00:00:00 2001 From: Erin Date: Fri, 21 Oct 2022 13:01:24 +0100 Subject: [PATCH 05/18] Serialise -> serialize Co-authored-by: Benjamin Bossan --- skops/io/tests/test_persist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index d9a54259..7d73c51f 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -807,7 +807,7 @@ def test_loads_from_str(): class _BoundMethodHolder: - """Used to test the ability to serialise and deserialize bound methods""" + """Used to test the ability to serialize and deserialize bound methods""" def __init__(self, variant: str): if variant == "sqrt": From 9e2405dc3a17ec70560f904d4c28e3808dbc88a6 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 22 Oct 2022 16:05:45 +0100 Subject: [PATCH 06/18] Address PR comments --- skops/io/_general.py | 5 ++- skops/io/tests/test_persist.py | 77 +++++++++++++++++++++++++++------- 2 files changed, 65 insertions(+), 17 deletions(-) diff --git a/skops/io/_general.py b/skops/io/_general.py index a7261d54..194adb6c 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -234,12 +234,15 @@ def object_get_instance(state, src): def method_get_state(obj: Any, save_state: SaveState): + # This method is used to persist methods bound to a different instance + # It stores the state of the object the method is bound to, + # and prepares both to be persisted. res = { "__class__": obj.__class__.__name__, "__module__": get_module(obj), "content": { "func": obj.__func__.__name__, - "obj": object_get_state(obj.__self__, save_state), + "obj": get_state(obj.__self__, save_state), }, } diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 7d73c51f..e659ec3a 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -809,25 +809,70 @@ def test_loads_from_str(): class _BoundMethodHolder: """Used to test the ability to serialize and deserialize bound methods""" - def __init__(self, variant: str): - if variant == "sqrt": - self.chosen_function = np.sqrt - elif variant == "log": - self.chosen_function = np.log - else: - self.chosen_function = np.exp + def __init__(self, object_state: str): + self.object_state = ( + object_state # Initialize with some state to make sure state is persisted + ) + self.chosen_function = ( + np.log + ) # Bound a method to this object (can be any method) - def apply(self, x): + def bound_method(self, x): return self.chosen_function(x) -def test_for_serialized_bound_method_works_as_expected(tmp_path): - obj = _BoundMethodHolder(variant="log") - bound_function = obj.apply - transformer = FunctionTransformer(func=bound_function) +class TestPersistingBoundMethods: + @staticmethod + def assert_transformer_persisted_correctly( + loaded_transformer: FunctionTransformer, + original_transformer: FunctionTransformer, + ): + """Checks that a persisted and original transformer are equivalent, including the func passed to it + """ + assert loaded_transformer.func.__name__ == original_transformer.func.__name__ + + assert_params_equal( + loaded_transformer.func.__self__.__dict__, + original_transformer.func.__self__.__dict__, + ) + assert_params_equal(loaded_transformer.__dict__, original_transformer.__dict__) + + @staticmethod + def assert_bound_method_holder_persisted_correctly( + original_obj: _BoundMethodHolder, loaded_obj: _BoundMethodHolder + ): + """Checks that the persisted and original instances of _BoundMethodHolder are equivalent + """ + assert original_obj.bound_method.__name__ == loaded_obj.bound_method.__name__ + assert original_obj.chosen_function == loaded_obj.chosen_function + + assert_params_equal(original_obj.__dict__, loaded_obj.__dict__) + + def test_for_base_case_returns_as_expected(self, tmp_path): + initial_state = "This is an arbitrary state" + obj = _BoundMethodHolder(object_state=initial_state) + bound_function = obj.bound_method + transformer = FunctionTransformer(func=bound_function) + + loaded_transformer = save_load_round(transformer, tmp_path / "file.skops") + loaded_obj = loaded_transformer.func.__self__ + + self.assert_transformer_persisted_correctly(loaded_transformer, transformer) + self.assert_bound_method_holder_persisted_correctly(obj, loaded_obj) + + def test_when_method_is_not_set_during_init_works_as_expected(self, tmp_path): + # given change to object with bound method after initialisation, + # make sure still persists correctly + + initial_state = "This is an arbitrary state" + obj = _BoundMethodHolder(object_state=initial_state) + obj.chosen_function = np.sqrt + bound_function = obj.bound_method + + transformer = FunctionTransformer(func=bound_function) - loaded_transformer = save_load_round(transformer, tmp_path / "file.skops") - loaded_bound_function = loaded_transformer.func.__self__.chosen_function + loaded_transformer = save_load_round(transformer, tmp_path / "file.skops") + loaded_obj = loaded_transformer.func.__self__ - assert loaded_transformer.func.__name__ == bound_function.__name__ - assert loaded_bound_function == obj.chosen_function + self.assert_transformer_persisted_correctly(loaded_transformer, transformer) + self.assert_bound_method_holder_persisted_correctly(obj, loaded_obj) From 8127aa693c21c4a6c7ad41c6688064ef0fbd652c Mon Sep 17 00:00:00 2001 From: = Date: Sun, 23 Oct 2022 15:05:19 +0100 Subject: [PATCH 07/18] Reword comment --- skops/io/_utils.py | 2 +- skops/io/tests/test_persist.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/skops/io/_utils.py b/skops/io/_utils.py index d8e40a81..a17d1d08 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -5,7 +5,6 @@ import sys from dataclasses import dataclass, field from functools import _find_impl, get_cache_token, update_wrapper # type: ignore -from pathlib import Path from types import FunctionType, MethodType from typing import Any from zipfile import ZipFile @@ -52,6 +51,7 @@ def dispatch(instance): # CHANGED: variable name cls->instance # CHANGED: check if we deal with a state dict, in which case we use it # to resolve the correct class. Otherwise, just use the class of the # instance. + if ( isinstance(instance, dict) and "__module__" in instance diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index e659ec3a..6d7df449 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -814,8 +814,8 @@ def __init__(self, object_state: str): object_state # Initialize with some state to make sure state is persisted ) self.chosen_function = ( - np.log - ) # Bound a method to this object (can be any method) + np.log # bind some method to this object, could be any persistable function + ) def bound_method(self, x): return self.chosen_function(x) From aec304a974729a38bb8720bbd142f48a197f1262 Mon Sep 17 00:00:00 2001 From: Erin Date: Mon, 24 Oct 2022 13:24:36 +0100 Subject: [PATCH 08/18] Update skops/io/_utils.py Co-authored-by: Benjamin Bossan --- skops/io/_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/skops/io/_utils.py b/skops/io/_utils.py index a17d1d08..f61c6b2a 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -51,7 +51,6 @@ def dispatch(instance): # CHANGED: variable name cls->instance # CHANGED: check if we deal with a state dict, in which case we use it # to resolve the correct class. Otherwise, just use the class of the # instance. - if ( isinstance(instance, dict) and "__module__" in instance From 05f7ce9f91e82fdb60959e8c09c5b683e58579c7 Mon Sep 17 00:00:00 2001 From: Erin Date: Mon, 24 Oct 2022 13:27:24 +0100 Subject: [PATCH 09/18] Change from inline comments Co-authored-by: Benjamin Bossan --- skops/io/tests/test_persist.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 6547bd73..e90dd893 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -810,12 +810,10 @@ class _BoundMethodHolder: """Used to test the ability to serialize and deserialize bound methods""" def __init__(self, object_state: str): - self.object_state = ( - object_state # Initialize with some state to make sure state is persisted - ) - self.chosen_function = ( - np.log # bind some method to this object, could be any persistable function - ) + # Initialize with some state to make sure state is persisted + self.object_state = object_state + # bind some method to this object, could be any persistable function + self.chosen_function = np.log def bound_method(self, x): return self.chosen_function(x) From 5f8e102f240f5a6536a71af2da4f15178cdca86f Mon Sep 17 00:00:00 2001 From: = Date: Mon, 24 Oct 2022 13:44:25 +0100 Subject: [PATCH 10/18] Implement code review comments --- skops/io/_general.py | 3 ++- skops/io/tests/test_persist.py | 14 ++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/skops/io/_general.py b/skops/io/_general.py index 194adb6c..5f384737 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -234,7 +234,8 @@ def object_get_instance(state, src): def method_get_state(obj: Any, save_state: SaveState): - # This method is used to persist methods bound to a different instance + # This method is used to persist bound methods, which are + # dependent on a specific instance of an object. # It stores the state of the object the method is bound to, # and prepares both to be persisted. res = { diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index e90dd893..fafccad9 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -825,7 +825,8 @@ def assert_transformer_persisted_correctly( loaded_transformer: FunctionTransformer, original_transformer: FunctionTransformer, ): - """Checks that a persisted and original transformer are equivalent, including the func passed to it + """Checks that a persisted and original transformer are equivalent, including + the func passed to it """ assert loaded_transformer.func.__name__ == original_transformer.func.__name__ @@ -839,26 +840,27 @@ def assert_transformer_persisted_correctly( def assert_bound_method_holder_persisted_correctly( original_obj: _BoundMethodHolder, loaded_obj: _BoundMethodHolder ): - """Checks that the persisted and original instances of _BoundMethodHolder are equivalent + """Checks that the persisted and original instances of _BoundMethodHolder are + equivalent """ assert original_obj.bound_method.__name__ == loaded_obj.bound_method.__name__ assert original_obj.chosen_function == loaded_obj.chosen_function assert_params_equal(original_obj.__dict__, loaded_obj.__dict__) - def test_for_base_case_returns_as_expected(self, tmp_path): + def test_for_base_case_returns_as_expected(self): initial_state = "This is an arbitrary state" obj = _BoundMethodHolder(object_state=initial_state) bound_function = obj.bound_method transformer = FunctionTransformer(func=bound_function) - loaded_transformer = save_load_round(transformer, tmp_path / "file.skops") + loaded_transformer = loads(dumps(transformer)) loaded_obj = loaded_transformer.func.__self__ self.assert_transformer_persisted_correctly(loaded_transformer, transformer) self.assert_bound_method_holder_persisted_correctly(obj, loaded_obj) - def test_when_method_is_not_set_during_init_works_as_expected(self, tmp_path): + def test_when_object_is_changed_after_init_works_as_expected(self): # given change to object with bound method after initialisation, # make sure still persists correctly @@ -869,7 +871,7 @@ def test_when_method_is_not_set_during_init_works_as_expected(self, tmp_path): transformer = FunctionTransformer(func=bound_function) - loaded_transformer = save_load_round(transformer, tmp_path / "file.skops") + loaded_transformer = loads(dumps(transformer)) loaded_obj = loaded_transformer.func.__self__ self.assert_transformer_persisted_correctly(loaded_transformer, transformer) From 1dda3748c6f733283b62634ae637960e3f6c2bd0 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 25 Oct 2022 09:49:47 +0100 Subject: [PATCH 11/18] Fix gettype to not return FunctionType or MethodType --- skops/io/_utils.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/skops/io/_utils.py b/skops/io/_utils.py index 95230c9a..2b6f9749 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -5,7 +5,6 @@ import sys from dataclasses import dataclass, field from functools import singledispatch -from types import FunctionType, MethodType from typing import Any from zipfile import ZipFile @@ -61,12 +60,6 @@ def _import_obj(module, cls_or_func, package=None): def gettype(state): if "__module__" in state and "__class__" in state: - if state["__class__"] == "function": - # This special case is due to how functions are serialized. We - # could try to change it. - return FunctionType - if state["__class__"] == "method": - return MethodType return _import_obj(state["__module__"], state["__class__"]) return None From 7ccf3fd8c6b21eec1293119030ab984d82a34d55 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 25 Oct 2022 14:49:18 +0100 Subject: [PATCH 12/18] whitespace for linter --- skops/io/tests/test_persist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 38f26b2d..f6198cb1 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -863,7 +863,7 @@ def test_when_object_is_changed_after_init_works_as_expected(self): self.assert_transformer_persisted_correctly(loaded_transformer, transformer) self.assert_bound_method_holder_persisted_correctly(obj, loaded_obj) - + class CustomEstimator(BaseEstimator): """Estimator with np array, np scalar, and sparse matrix attribute""" From 78db91abd0f7b7c0f2789d2baa8a4d0ec253002f Mon Sep 17 00:00:00 2001 From: Erin Date: Tue, 25 Oct 2022 18:55:12 +0100 Subject: [PATCH 13/18] Update skops/io/_general.py Co-authored-by: Benjamin Bossan --- skops/io/_general.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skops/io/_general.py b/skops/io/_general.py index c0594319..a2e712e9 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -261,7 +261,7 @@ def method_get_state(obj: Any, save_state: SaveState): def method_get_instance(state, src): - loaded_obj = object_get_instance(state["content"]["obj"], src) + loaded_obj = get_instance(state["content"]["obj"], src) method = getattr(loaded_obj, state["content"]["func"]) return method From b6aa8fd96f7762af44e5382ce539ba14ad616c2a Mon Sep 17 00:00:00 2001 From: = Date: Thu, 27 Oct 2022 20:01:31 +0100 Subject: [PATCH 14/18] Add LoadState to get_instance functions --- skops/io/_dispatch.py | 14 ++++++++-- skops/io/_general.py | 49 +++++++++++++++++++--------------- skops/io/_numpy.py | 22 +++++++-------- skops/io/_persist.py | 6 ++--- skops/io/_scipy.py | 2 +- skops/io/_sklearn.py | 20 +++++++------- skops/io/_utils.py | 31 ++++++++++++++++++++- skops/io/tests/test_persist.py | 42 ++++++++++++++++++++++++++--- 8 files changed, 133 insertions(+), 53 deletions(-) diff --git a/skops/io/_dispatch.py b/skops/io/_dispatch.py index e0ae9a96..b204b26c 100644 --- a/skops/io/_dispatch.py +++ b/skops/io/_dispatch.py @@ -2,14 +2,20 @@ import json +from skops.io._utils import LoadState + GET_INSTANCE_MAPPING = {} # type: ignore -def get_instance(state, src): +def get_instance(state, src, load_state: LoadState): """Create instance based on the state, using json if possible""" if state.get("is_json"): return json.loads(state["content"]) + saved_id = state.get("__id__") + if saved_id and saved_id in load_state.memo: + return load_state.get_instance(saved_id) + try: get_instance_func = GET_INSTANCE_MAPPING[state["__loader__"]] except KeyError: @@ -17,4 +23,8 @@ def get_instance(state, src): raise TypeError( f" Can't find loader {state['__loader__']} for type {type_name}." ) - return get_instance_func(state, src) + + loaded_obj = get_instance_func(state, src, load_state) + if saved_id: + load_state.memoize(loaded_obj, saved_id) + return loaded_obj diff --git a/skops/io/_general.py b/skops/io/_general.py index c0594319..4f74e13e 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -33,11 +33,11 @@ def dict_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def dict_get_instance(state, src): +def dict_get_instance(state, src, load_state): content = gettype(state)() - key_types = get_instance(state["key_types"], src) + key_types = get_instance(state["key_types"], src, load_state) for k_type, item in zip(key_types, state["content"].items()): - content[k_type(item[0])] = get_instance(item[1], src) + content[k_type(item[0])] = get_instance(item[1], src, load_state) return content @@ -54,10 +54,10 @@ def list_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def list_get_instance(state, src): +def list_get_instance(state, src, load_state): content = gettype(state)() for value in state["content"]: - content.append(get_instance(value, src)) + content.append(get_instance(value, src, load_state)) return content @@ -72,7 +72,7 @@ def tuple_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def tuple_get_instance(state, src): +def tuple_get_instance(state, src, load_state): # Returns a tuple or a namedtuple instance. def isnamedtuple(t): # This is needed since namedtuples need to have the args when @@ -86,7 +86,7 @@ def isnamedtuple(t): return all(type(n) == str for n in f) cls = gettype(state) - content = tuple(get_instance(value, src) for value in state["content"]) + content = tuple(get_instance(value, src, load_state) for value in state["content"]) if isnamedtuple(cls): return cls(*content) @@ -106,7 +106,7 @@ def function_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def function_get_instance(state, src): +def function_get_instance(state, src, load_state): loaded = _import_obj(state["content"]["module_path"], state["content"]["function"]) return loaded @@ -127,12 +127,12 @@ def partial_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def partial_get_instance(state, src): +def partial_get_instance(state, src, load_state): content = state["content"] - func = get_instance(content["func"], src) - args = get_instance(content["args"], src) - kwds = get_instance(content["kwds"], src) - namespace = get_instance(content["namespace"], src) + func = get_instance(content["func"], src, load_state) + args = get_instance(content["args"], src, load_state) + kwds = get_instance(content["kwds"], src, load_state) + namespace = get_instance(content["namespace"], src, load_state) instance = partial(func, *args, **kwds) # always use partial, not a subclass instance.__setstate__((func, args, kwds, namespace)) return instance @@ -153,7 +153,7 @@ def type_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def type_get_instance(state, src): +def type_get_instance(state, src, load_state): loaded = _import_obj(state["content"]["__module__"], state["content"]["__class__"]) return loaded @@ -172,7 +172,7 @@ def slice_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def slice_get_instance(state, src): +def slice_get_instance(state, src, load_state): start = state["content"]["start"] stop = state["content"]["stop"] step = state["content"]["step"] @@ -190,6 +190,7 @@ def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: "__class__": "str", "__module__": "builtins", "__loader__": "none", + "__id__": id(obj), "content": obj_str, "is_json": True, } @@ -200,6 +201,7 @@ def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), "__loader__": "object_get_instance", + "__id__": id(obj), } # __getstate__ takes priority over __dict__, and if non exist, we only save @@ -218,7 +220,7 @@ def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def object_get_instance(state, src): +def object_get_instance(state, src, load_state): if state.get("is_json", False): return json.loads(state["content"]) @@ -233,7 +235,7 @@ def object_get_instance(state, src): if not content: # nothing more to do return instance - attrs = get_instance(content, src) + attrs = get_instance(content, src, load_state) if hasattr(instance, "__setstate__"): instance.__setstate__(attrs) else: @@ -253,15 +255,20 @@ def method_get_state(obj: Any, save_state: SaveState): "__loader__": "method_get_instance", "content": { "func": obj.__func__.__name__, - "obj": get_state(obj.__self__, save_state), + "__id__": id(obj.__self__), }, } - + root_obj = obj.__self__ + obj_state = save_state.get_memoized_state(root_obj) + if obj_state is None: + obj_state = get_state(root_obj, save_state) + save_state.store_state(root_obj, obj_state) + res["content"]["obj"] = obj_state return res -def method_get_instance(state, src): - loaded_obj = object_get_instance(state["content"]["obj"], src) +def method_get_instance(state, src, load_state): + loaded_obj = get_instance(state["content"]["obj"], src, load_state) method = getattr(loaded_obj, state["content"]["func"]) return method diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index 4d1d5b98..bfe4ac47 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -50,7 +50,7 @@ def ndarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def ndarray_get_instance(state, src): +def ndarray_get_instance(state, src, load_state): # Dealing with a regular numpy array, where dtype != object if state["type"] == "numpy": val = np.load(io.BytesIO(src.read(state["file"])), allow_pickle=False) @@ -63,8 +63,8 @@ def ndarray_get_instance(state, src): # We explicitly set the dtype to "O" since we only save object arrays in # json. - shape = get_instance(state["shape"], src) - tmp = [get_instance(s, src) for s in state["content"]] + shape = get_instance(state["shape"], src, load_state) + tmp = [get_instance(s, src, load_state) for s in state["content"]] # TODO: this is a hack to get the correct shape of the array. We should # find _a better way_ to do this. if len(shape) == 1: @@ -89,9 +89,9 @@ def maskedarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def maskedarray_get_instance(state, src): - data = get_instance(state["content"]["data"], src) - mask = get_instance(state["content"]["mask"], src) +def maskedarray_get_instance(state, src, load_state): + data = get_instance(state["content"]["data"], src, load_state) + mask = get_instance(state["content"]["mask"], src, load_state) return np.ma.MaskedArray(data, mask) @@ -106,10 +106,10 @@ def random_state_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def random_state_get_instance(state, src): +def random_state_get_instance(state, src, load_state): cls = _import_obj(state["__module__"], state["__class__"]) random_state = cls() - content = get_instance(state["content"], src) + content = get_instance(state["content"], src, load_state) random_state.set_state(content) return random_state @@ -125,7 +125,7 @@ def random_generator_get_state(obj: Any, save_state: SaveState) -> dict[str, Any return res -def random_generator_get_instance(state, src): +def random_generator_get_instance(state, src, load_state): # first restore the state of the bit generator bit_generator_state = state["content"]["bit_generator"] bit_generator = _import_obj("numpy.random", bit_generator_state["bit_generator"])() @@ -166,10 +166,10 @@ def dtype_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def dtype_get_instance(state, src): +def dtype_get_instance(state, src, load_state): # we use numpy's internal save mechanism to store the dtype by # saving/loading an empty array with that dtype. - tmp = ndarray_get_instance(state["content"], src) + tmp = ndarray_get_instance(state["content"], src, load_state) return tmp.dtype diff --git a/skops/io/_persist.py b/skops/io/_persist.py index 9144da08..81571670 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -8,7 +8,7 @@ import skops from ._dispatch import GET_INSTANCE_MAPPING, get_instance -from ._utils import SaveState, _get_state, get_state +from ._utils import LoadState, SaveState, _get_state, get_state # We load the dispatch functions from the corresponding modules and register # them. @@ -115,7 +115,7 @@ def load(file): """ with ZipFile(file, "r") as input_zip: schema = input_zip.read("schema.json") - instance = get_instance(json.loads(schema), input_zip) + instance = get_instance(json.loads(schema), input_zip, LoadState()) return instance @@ -141,5 +141,5 @@ def loads(data): with ZipFile(io.BytesIO(data), "r") as zip_file: schema = json.loads(zip_file.read("schema.json")) - instance = get_instance(schema, src=zip_file) + instance = get_instance(schema, src=zip_file, load_state=LoadState()) return instance diff --git a/skops/io/_scipy.py b/skops/io/_scipy.py index 305d30dc..e8679afc 100644 --- a/skops/io/_scipy.py +++ b/skops/io/_scipy.py @@ -31,7 +31,7 @@ def sparse_matrix_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def sparse_matrix_get_instance(state, src): +def sparse_matrix_get_instance(state, src, load_state): if state["type"] != "scipy": raise TypeError( f"Cannot load object of type {state['__module__']}.{state['__class__']}" diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index 6cde5463..2961ba5f 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -87,12 +87,12 @@ def reduce_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def reduce_get_instance(state, src, constructor): +def reduce_get_instance(state, src, load_state, constructor): reduce = state["__reduce__"] - args = get_instance(reduce["args"], src) + args = get_instance(reduce["args"], src, load_state) instance = constructor(*args) - attrs = get_instance(state["content"], src) + attrs = get_instance(state["content"], src, load_state) if not attrs: # nothing more to do return instance @@ -116,8 +116,8 @@ def tree_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return state -def tree_get_instance(state, src): - return reduce_get_instance(state, src, constructor=Tree) +def tree_get_instance(state, src, load_state): + return reduce_get_instance(state, src, load_state, constructor=Tree) def sgd_loss_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: @@ -126,11 +126,11 @@ def sgd_loss_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return state -def sgd_loss_get_instance(state, src): +def sgd_loss_get_instance(state, src, load_state): cls = gettype(state) if cls not in ALLOWED_SGD_LOSSES: raise UnsupportedTypeException(f"Expected LossFunction, got {cls}") - return reduce_get_instance(state, src, constructor=cls) + return reduce_get_instance(state, src, load_state, constructor=cls) # TODO: remove once support for sklearn<1.2 is dropped. @@ -152,11 +152,11 @@ def _DictWithDeprecatedKeys_get_state( # TODO: remove once support for sklearn<1.2 is dropped. -def _DictWithDeprecatedKeys_get_instance(state, src): +def _DictWithDeprecatedKeys_get_instance(state, src, load_state): # _DictWithDeprecatedKeys is just a wrapper for dict - content = dict_get_instance(state["content"]["main"], src) + content = dict_get_instance(state["content"]["main"], src, load_state) deprecated_key_to_new_key = dict_get_instance( - state["content"]["_deprecated_key_to_new_key"], src + state["content"]["_deprecated_key_to_new_key"], src, load_state ) res = _DictWithDeprecatedKeys(**content) res._deprecated_key_to_new_key = deprecated_key_to_new_key diff --git a/skops/io/_utils.py b/skops/io/_utils.py index 2b6f9749..4409ad7a 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -5,7 +5,7 @@ import sys from dataclasses import dataclass, field from functools import singledispatch -from typing import Any +from typing import Any, Optional from zipfile import ZipFile @@ -121,10 +121,39 @@ def memoize(self, obj: Any) -> int: self.memo[obj_id] = obj return obj_id + def get_memoized_state(self, obj: Any) -> Optional[dict]: + # Used in persisting a single object that is referenced in multiple places + obj_id = id(obj) + return self.memo.get(obj_id) + + def store_state(self, obj: Any, state: dict) -> None: + obj_id = id(obj) + if obj_id not in self.memo: + self.memo[obj_id] = state + def clear_memo(self) -> None: self.memo.clear() +@dataclass(frozen=True) +class LoadState: + + """ + State required for loading objects. + This state is passed to each ``get_instance_*`` function. + This is primarily used to hold references to objects which exist in multiple places + in the state tree. + """ + + memo: dict[int, Any] = field(default_factory=dict) + + def memoize(self, obj: Any, id: int) -> None: + self.memo[id] = obj + + def get_instance(self, id: int) -> Any: + return self.memo.get(id) + + @singledispatch def _get_state(obj, save_state): # This function should never be called directly. Instead, it is used to diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index f6198cb1..47bcdb27 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -96,16 +96,16 @@ def wrapper(obj, save_state): def debug_get_instance(func): # check consistency of argument names and input type signature = inspect.signature(func) - assert list(signature.parameters.keys()) == ["state", "src"] + assert list(signature.parameters.keys()) == ["state", "src", "load_state"] @wraps(func) - def wrapper(state, src): + def wrapper(state, src, load_state): assert "__class__" in state assert "__module__" in state assert "__loader__" in state assert isinstance(src, ZipFile) - result = func(state, src) + result = func(state, src, load_state) return result return wrapper @@ -789,7 +789,7 @@ def test_get_instance_unknown_type_error_msg(): state["__loader__"] = "this_get_instance_does_not_exist" msg = "Can't find loader this_get_instance_does_not_exist for type builtins.tuple." with pytest.raises(TypeError, match=msg): - get_instance(state, None) + get_instance(state, None, None) class _BoundMethodHolder: @@ -804,6 +804,21 @@ def __init__(self, object_state: str): def bound_method(self, x): return self.chosen_function(x) + def other_bound_method(self, x): + return self.chosen_function(x) + + +class _MultipleMethodHolder: + """Used to test the ability to serialize and deserialize""" + + def __init__(self, func1, func2): + self.func1 = func1 + self.func2 = func2 + + def bound_method(self, x): + # arbitrary function that uses both funcs + return self.func1(x) + self.func2(x) + class TestPersistingBoundMethods: @staticmethod @@ -863,6 +878,25 @@ def test_when_object_is_changed_after_init_works_as_expected(self): self.assert_transformer_persisted_correctly(loaded_transformer, transformer) self.assert_bound_method_holder_persisted_correctly(obj, loaded_obj) + def test_works_when_given_multiple_bound_methods_attached_to_single_instance(self): + initial_state = "Any arbitrary state" + original_obj = _BoundMethodHolder(object_state=initial_state) + multiple_bound_method = _MultipleMethodHolder( + original_obj.bound_method, original_obj.other_bound_method + ) + + bound_function = multiple_bound_method.bound_method + + transformer = FunctionTransformer(func=bound_function) + + loaded_transformer = loads(dumps(transformer)) + loaded_obj = loaded_transformer.func.__self__ + + self.assert_transformer_persisted_correctly(loaded_transformer, transformer) + + # check that both func1 and func2 are from the same object + assert loaded_obj.func1.__self__ == loaded_obj.func2.__self__ + class CustomEstimator(BaseEstimator): """Estimator with np array, np scalar, and sparse matrix attribute""" From 30b014738d3729fbfab36062b28d23df35369653 Mon Sep 17 00:00:00 2001 From: = Date: Thu, 27 Oct 2022 20:03:12 +0100 Subject: [PATCH 15/18] Add minor comments to additions to get_instance --- skops/io/_dispatch.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/skops/io/_dispatch.py b/skops/io/_dispatch.py index b204b26c..6c010e3b 100644 --- a/skops/io/_dispatch.py +++ b/skops/io/_dispatch.py @@ -13,7 +13,9 @@ def get_instance(state, src, load_state: LoadState): return json.loads(state["content"]) saved_id = state.get("__id__") + if saved_id and saved_id in load_state.memo: + # same instance already loaded elsewhere in tree return load_state.get_instance(saved_id) try: @@ -25,6 +27,8 @@ def get_instance(state, src, load_state: LoadState): ) loaded_obj = get_instance_func(state, src, load_state) + if saved_id: load_state.memoize(loaded_obj, saved_id) + return loaded_obj From ddd94be77d9357bf9506fe65da49d9efa7637160 Mon Sep 17 00:00:00 2001 From: = Date: Fri, 28 Oct 2022 20:01:26 +0100 Subject: [PATCH 16/18] Revert "Add LoadState to get_instance functions" --- skops/io/_dispatch.py | 18 ++----------- skops/io/_general.py | 49 +++++++++++++++------------------- skops/io/_numpy.py | 22 +++++++-------- skops/io/_persist.py | 6 ++--- skops/io/_scipy.py | 2 +- skops/io/_sklearn.py | 20 +++++++------- skops/io/_utils.py | 31 +-------------------- skops/io/tests/test_persist.py | 42 +++-------------------------- 8 files changed, 53 insertions(+), 137 deletions(-) diff --git a/skops/io/_dispatch.py b/skops/io/_dispatch.py index 6c010e3b..e0ae9a96 100644 --- a/skops/io/_dispatch.py +++ b/skops/io/_dispatch.py @@ -2,22 +2,14 @@ import json -from skops.io._utils import LoadState - GET_INSTANCE_MAPPING = {} # type: ignore -def get_instance(state, src, load_state: LoadState): +def get_instance(state, src): """Create instance based on the state, using json if possible""" if state.get("is_json"): return json.loads(state["content"]) - saved_id = state.get("__id__") - - if saved_id and saved_id in load_state.memo: - # same instance already loaded elsewhere in tree - return load_state.get_instance(saved_id) - try: get_instance_func = GET_INSTANCE_MAPPING[state["__loader__"]] except KeyError: @@ -25,10 +17,4 @@ def get_instance(state, src, load_state: LoadState): raise TypeError( f" Can't find loader {state['__loader__']} for type {type_name}." ) - - loaded_obj = get_instance_func(state, src, load_state) - - if saved_id: - load_state.memoize(loaded_obj, saved_id) - - return loaded_obj + return get_instance_func(state, src) diff --git a/skops/io/_general.py b/skops/io/_general.py index 4f74e13e..c0594319 100644 --- a/skops/io/_general.py +++ b/skops/io/_general.py @@ -33,11 +33,11 @@ def dict_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def dict_get_instance(state, src, load_state): +def dict_get_instance(state, src): content = gettype(state)() - key_types = get_instance(state["key_types"], src, load_state) + key_types = get_instance(state["key_types"], src) for k_type, item in zip(key_types, state["content"].items()): - content[k_type(item[0])] = get_instance(item[1], src, load_state) + content[k_type(item[0])] = get_instance(item[1], src) return content @@ -54,10 +54,10 @@ def list_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def list_get_instance(state, src, load_state): +def list_get_instance(state, src): content = gettype(state)() for value in state["content"]: - content.append(get_instance(value, src, load_state)) + content.append(get_instance(value, src)) return content @@ -72,7 +72,7 @@ def tuple_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def tuple_get_instance(state, src, load_state): +def tuple_get_instance(state, src): # Returns a tuple or a namedtuple instance. def isnamedtuple(t): # This is needed since namedtuples need to have the args when @@ -86,7 +86,7 @@ def isnamedtuple(t): return all(type(n) == str for n in f) cls = gettype(state) - content = tuple(get_instance(value, src, load_state) for value in state["content"]) + content = tuple(get_instance(value, src) for value in state["content"]) if isnamedtuple(cls): return cls(*content) @@ -106,7 +106,7 @@ def function_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def function_get_instance(state, src, load_state): +def function_get_instance(state, src): loaded = _import_obj(state["content"]["module_path"], state["content"]["function"]) return loaded @@ -127,12 +127,12 @@ def partial_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def partial_get_instance(state, src, load_state): +def partial_get_instance(state, src): content = state["content"] - func = get_instance(content["func"], src, load_state) - args = get_instance(content["args"], src, load_state) - kwds = get_instance(content["kwds"], src, load_state) - namespace = get_instance(content["namespace"], src, load_state) + func = get_instance(content["func"], src) + args = get_instance(content["args"], src) + kwds = get_instance(content["kwds"], src) + namespace = get_instance(content["namespace"], src) instance = partial(func, *args, **kwds) # always use partial, not a subclass instance.__setstate__((func, args, kwds, namespace)) return instance @@ -153,7 +153,7 @@ def type_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def type_get_instance(state, src, load_state): +def type_get_instance(state, src): loaded = _import_obj(state["content"]["__module__"], state["content"]["__class__"]) return loaded @@ -172,7 +172,7 @@ def slice_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def slice_get_instance(state, src, load_state): +def slice_get_instance(state, src): start = state["content"]["start"] stop = state["content"]["stop"] step = state["content"]["step"] @@ -190,7 +190,6 @@ def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: "__class__": "str", "__module__": "builtins", "__loader__": "none", - "__id__": id(obj), "content": obj_str, "is_json": True, } @@ -201,7 +200,6 @@ def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: "__class__": obj.__class__.__name__, "__module__": get_module(type(obj)), "__loader__": "object_get_instance", - "__id__": id(obj), } # __getstate__ takes priority over __dict__, and if non exist, we only save @@ -220,7 +218,7 @@ def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def object_get_instance(state, src, load_state): +def object_get_instance(state, src): if state.get("is_json", False): return json.loads(state["content"]) @@ -235,7 +233,7 @@ def object_get_instance(state, src, load_state): if not content: # nothing more to do return instance - attrs = get_instance(content, src, load_state) + attrs = get_instance(content, src) if hasattr(instance, "__setstate__"): instance.__setstate__(attrs) else: @@ -255,20 +253,15 @@ def method_get_state(obj: Any, save_state: SaveState): "__loader__": "method_get_instance", "content": { "func": obj.__func__.__name__, - "__id__": id(obj.__self__), + "obj": get_state(obj.__self__, save_state), }, } - root_obj = obj.__self__ - obj_state = save_state.get_memoized_state(root_obj) - if obj_state is None: - obj_state = get_state(root_obj, save_state) - save_state.store_state(root_obj, obj_state) - res["content"]["obj"] = obj_state + return res -def method_get_instance(state, src, load_state): - loaded_obj = get_instance(state["content"]["obj"], src, load_state) +def method_get_instance(state, src): + loaded_obj = object_get_instance(state["content"]["obj"], src) method = getattr(loaded_obj, state["content"]["func"]) return method diff --git a/skops/io/_numpy.py b/skops/io/_numpy.py index bfe4ac47..4d1d5b98 100644 --- a/skops/io/_numpy.py +++ b/skops/io/_numpy.py @@ -50,7 +50,7 @@ def ndarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def ndarray_get_instance(state, src, load_state): +def ndarray_get_instance(state, src): # Dealing with a regular numpy array, where dtype != object if state["type"] == "numpy": val = np.load(io.BytesIO(src.read(state["file"])), allow_pickle=False) @@ -63,8 +63,8 @@ def ndarray_get_instance(state, src, load_state): # We explicitly set the dtype to "O" since we only save object arrays in # json. - shape = get_instance(state["shape"], src, load_state) - tmp = [get_instance(s, src, load_state) for s in state["content"]] + shape = get_instance(state["shape"], src) + tmp = [get_instance(s, src) for s in state["content"]] # TODO: this is a hack to get the correct shape of the array. We should # find _a better way_ to do this. if len(shape) == 1: @@ -89,9 +89,9 @@ def maskedarray_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def maskedarray_get_instance(state, src, load_state): - data = get_instance(state["content"]["data"], src, load_state) - mask = get_instance(state["content"]["mask"], src, load_state) +def maskedarray_get_instance(state, src): + data = get_instance(state["content"]["data"], src) + mask = get_instance(state["content"]["mask"], src) return np.ma.MaskedArray(data, mask) @@ -106,10 +106,10 @@ def random_state_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def random_state_get_instance(state, src, load_state): +def random_state_get_instance(state, src): cls = _import_obj(state["__module__"], state["__class__"]) random_state = cls() - content = get_instance(state["content"], src, load_state) + content = get_instance(state["content"], src) random_state.set_state(content) return random_state @@ -125,7 +125,7 @@ def random_generator_get_state(obj: Any, save_state: SaveState) -> dict[str, Any return res -def random_generator_get_instance(state, src, load_state): +def random_generator_get_instance(state, src): # first restore the state of the bit generator bit_generator_state = state["content"]["bit_generator"] bit_generator = _import_obj("numpy.random", bit_generator_state["bit_generator"])() @@ -166,10 +166,10 @@ def dtype_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def dtype_get_instance(state, src, load_state): +def dtype_get_instance(state, src): # we use numpy's internal save mechanism to store the dtype by # saving/loading an empty array with that dtype. - tmp = ndarray_get_instance(state["content"], src, load_state) + tmp = ndarray_get_instance(state["content"], src) return tmp.dtype diff --git a/skops/io/_persist.py b/skops/io/_persist.py index 81571670..9144da08 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -8,7 +8,7 @@ import skops from ._dispatch import GET_INSTANCE_MAPPING, get_instance -from ._utils import LoadState, SaveState, _get_state, get_state +from ._utils import SaveState, _get_state, get_state # We load the dispatch functions from the corresponding modules and register # them. @@ -115,7 +115,7 @@ def load(file): """ with ZipFile(file, "r") as input_zip: schema = input_zip.read("schema.json") - instance = get_instance(json.loads(schema), input_zip, LoadState()) + instance = get_instance(json.loads(schema), input_zip) return instance @@ -141,5 +141,5 @@ def loads(data): with ZipFile(io.BytesIO(data), "r") as zip_file: schema = json.loads(zip_file.read("schema.json")) - instance = get_instance(schema, src=zip_file, load_state=LoadState()) + instance = get_instance(schema, src=zip_file) return instance diff --git a/skops/io/_scipy.py b/skops/io/_scipy.py index e8679afc..305d30dc 100644 --- a/skops/io/_scipy.py +++ b/skops/io/_scipy.py @@ -31,7 +31,7 @@ def sparse_matrix_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def sparse_matrix_get_instance(state, src, load_state): +def sparse_matrix_get_instance(state, src): if state["type"] != "scipy": raise TypeError( f"Cannot load object of type {state['__module__']}.{state['__class__']}" diff --git a/skops/io/_sklearn.py b/skops/io/_sklearn.py index 2961ba5f..6cde5463 100644 --- a/skops/io/_sklearn.py +++ b/skops/io/_sklearn.py @@ -87,12 +87,12 @@ def reduce_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return res -def reduce_get_instance(state, src, load_state, constructor): +def reduce_get_instance(state, src, constructor): reduce = state["__reduce__"] - args = get_instance(reduce["args"], src, load_state) + args = get_instance(reduce["args"], src) instance = constructor(*args) - attrs = get_instance(state["content"], src, load_state) + attrs = get_instance(state["content"], src) if not attrs: # nothing more to do return instance @@ -116,8 +116,8 @@ def tree_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return state -def tree_get_instance(state, src, load_state): - return reduce_get_instance(state, src, load_state, constructor=Tree) +def tree_get_instance(state, src): + return reduce_get_instance(state, src, constructor=Tree) def sgd_loss_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: @@ -126,11 +126,11 @@ def sgd_loss_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]: return state -def sgd_loss_get_instance(state, src, load_state): +def sgd_loss_get_instance(state, src): cls = gettype(state) if cls not in ALLOWED_SGD_LOSSES: raise UnsupportedTypeException(f"Expected LossFunction, got {cls}") - return reduce_get_instance(state, src, load_state, constructor=cls) + return reduce_get_instance(state, src, constructor=cls) # TODO: remove once support for sklearn<1.2 is dropped. @@ -152,11 +152,11 @@ def _DictWithDeprecatedKeys_get_state( # TODO: remove once support for sklearn<1.2 is dropped. -def _DictWithDeprecatedKeys_get_instance(state, src, load_state): +def _DictWithDeprecatedKeys_get_instance(state, src): # _DictWithDeprecatedKeys is just a wrapper for dict - content = dict_get_instance(state["content"]["main"], src, load_state) + content = dict_get_instance(state["content"]["main"], src) deprecated_key_to_new_key = dict_get_instance( - state["content"]["_deprecated_key_to_new_key"], src, load_state + state["content"]["_deprecated_key_to_new_key"], src ) res = _DictWithDeprecatedKeys(**content) res._deprecated_key_to_new_key = deprecated_key_to_new_key diff --git a/skops/io/_utils.py b/skops/io/_utils.py index 4409ad7a..2b6f9749 100644 --- a/skops/io/_utils.py +++ b/skops/io/_utils.py @@ -5,7 +5,7 @@ import sys from dataclasses import dataclass, field from functools import singledispatch -from typing import Any, Optional +from typing import Any from zipfile import ZipFile @@ -121,39 +121,10 @@ def memoize(self, obj: Any) -> int: self.memo[obj_id] = obj return obj_id - def get_memoized_state(self, obj: Any) -> Optional[dict]: - # Used in persisting a single object that is referenced in multiple places - obj_id = id(obj) - return self.memo.get(obj_id) - - def store_state(self, obj: Any, state: dict) -> None: - obj_id = id(obj) - if obj_id not in self.memo: - self.memo[obj_id] = state - def clear_memo(self) -> None: self.memo.clear() -@dataclass(frozen=True) -class LoadState: - - """ - State required for loading objects. - This state is passed to each ``get_instance_*`` function. - This is primarily used to hold references to objects which exist in multiple places - in the state tree. - """ - - memo: dict[int, Any] = field(default_factory=dict) - - def memoize(self, obj: Any, id: int) -> None: - self.memo[id] = obj - - def get_instance(self, id: int) -> Any: - return self.memo.get(id) - - @singledispatch def _get_state(obj, save_state): # This function should never be called directly. Instead, it is used to diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 47bcdb27..f6198cb1 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -96,16 +96,16 @@ def wrapper(obj, save_state): def debug_get_instance(func): # check consistency of argument names and input type signature = inspect.signature(func) - assert list(signature.parameters.keys()) == ["state", "src", "load_state"] + assert list(signature.parameters.keys()) == ["state", "src"] @wraps(func) - def wrapper(state, src, load_state): + def wrapper(state, src): assert "__class__" in state assert "__module__" in state assert "__loader__" in state assert isinstance(src, ZipFile) - result = func(state, src, load_state) + result = func(state, src) return result return wrapper @@ -789,7 +789,7 @@ def test_get_instance_unknown_type_error_msg(): state["__loader__"] = "this_get_instance_does_not_exist" msg = "Can't find loader this_get_instance_does_not_exist for type builtins.tuple." with pytest.raises(TypeError, match=msg): - get_instance(state, None, None) + get_instance(state, None) class _BoundMethodHolder: @@ -804,21 +804,6 @@ def __init__(self, object_state: str): def bound_method(self, x): return self.chosen_function(x) - def other_bound_method(self, x): - return self.chosen_function(x) - - -class _MultipleMethodHolder: - """Used to test the ability to serialize and deserialize""" - - def __init__(self, func1, func2): - self.func1 = func1 - self.func2 = func2 - - def bound_method(self, x): - # arbitrary function that uses both funcs - return self.func1(x) + self.func2(x) - class TestPersistingBoundMethods: @staticmethod @@ -878,25 +863,6 @@ def test_when_object_is_changed_after_init_works_as_expected(self): self.assert_transformer_persisted_correctly(loaded_transformer, transformer) self.assert_bound_method_holder_persisted_correctly(obj, loaded_obj) - def test_works_when_given_multiple_bound_methods_attached_to_single_instance(self): - initial_state = "Any arbitrary state" - original_obj = _BoundMethodHolder(object_state=initial_state) - multiple_bound_method = _MultipleMethodHolder( - original_obj.bound_method, original_obj.other_bound_method - ) - - bound_function = multiple_bound_method.bound_method - - transformer = FunctionTransformer(func=bound_function) - - loaded_transformer = loads(dumps(transformer)) - loaded_obj = loaded_transformer.func.__self__ - - self.assert_transformer_persisted_correctly(loaded_transformer, transformer) - - # check that both func1 and func2 are from the same object - assert loaded_obj.func1.__self__ == loaded_obj.func2.__self__ - class CustomEstimator(BaseEstimator): """Estimator with np array, np scalar, and sparse matrix attribute""" From eb4b87d12dc3b57d5495df6ad572cff16a1076fb Mon Sep 17 00:00:00 2001 From: = Date: Fri, 28 Oct 2022 20:16:36 +0100 Subject: [PATCH 17/18] Add xfail tests for outstanding issues --- skops/io/tests/test_persist.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index f6198cb1..1ce20ea9 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -804,6 +804,10 @@ def __init__(self, object_state: str): def bound_method(self, x): return self.chosen_function(x) + def other_bound_method(self, x): + # arbitrary other function, used for checking single instance loaded + return self.chosen_function(x) + class TestPersistingBoundMethods: @staticmethod @@ -863,6 +867,28 @@ def test_when_object_is_changed_after_init_works_as_expected(self): self.assert_transformer_persisted_correctly(loaded_transformer, transformer) self.assert_bound_method_holder_persisted_correctly(obj, loaded_obj) + @pytest.mark.xfail(reason="Can't load one obj referenced multiple times") + def test_works_when_given_multiple_bound_methods_attached_to_single_instance(self): + obj = _BoundMethodHolder(object_state="") + + transformer = FunctionTransformer( + func=obj.bound_method, inverse_func=obj.other_bound_method + ) + + loaded_transformer = loads(dumps(transformer)) + + # check that both func and inverse_func are from the same object instance + loaded_0 = loaded_transformer.func.__self__ + loaded_1 = loaded_transformer.inverse_func.__self__ + assert loaded_0 is loaded_1 + + @pytest.mark.xfail(reason="Failing due to circular self reference") + def test_scipy_stats(self, tmp_path): + from scipy import stats + + estimator = FunctionTransformer(func=stats.zipf) + loads(dumps(estimator)) + class CustomEstimator(BaseEstimator): """Estimator with np array, np scalar, and sparse matrix attribute""" From a2056c9d83004d4f68ed95dc2dde2520c5b05ade Mon Sep 17 00:00:00 2001 From: = Date: Mon, 31 Oct 2022 10:40:16 +0000 Subject: [PATCH 18/18] Reword XFail reason --- skops/io/tests/test_persist.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/skops/io/tests/test_persist.py b/skops/io/tests/test_persist.py index 1ce20ea9..b999487b 100644 --- a/skops/io/tests/test_persist.py +++ b/skops/io/tests/test_persist.py @@ -867,7 +867,9 @@ def test_when_object_is_changed_after_init_works_as_expected(self): self.assert_transformer_persisted_correctly(loaded_transformer, transformer) self.assert_bound_method_holder_persisted_correctly(obj, loaded_obj) - @pytest.mark.xfail(reason="Can't load one obj referenced multiple times") + @pytest.mark.xfail( + reason="Can't load an object as a single instance if referenced multiple times" + ) def test_works_when_given_multiple_bound_methods_attached_to_single_instance(self): obj = _BoundMethodHolder(object_state="")