Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH Implement LoadContext to handle multiple instances #209

Merged
merged 18 commits into from
Nov 14, 2022
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions skops/io/_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,34 @@

import json

from skops.io._utils import LoadContext

# Registry mapping a loader name (the "__loader__" entry of a saved state) to
# the function that rebuilds the object from that state.
GET_INSTANCE_MAPPING = {}  # type: ignore


def get_instance(state, load_context: LoadContext):
    """Create instance based on the state, using json if possible.

    Parameters
    ----------
    state : dict
        Saved state of a single object. The keys read here are ``__id__``
        (optional), ``is_json`` (optional), ``content`` and, for states that
        are not json, ``__loader__``, ``__module__`` and ``__class__``.

    load_context : LoadContext
        Context shared by one load operation. Its ``memo`` is consulted so
        that a state whose ``__id__`` was already loaded yields the very same
        object again instead of a fresh copy.

    Returns
    -------
    loaded_obj : object
        The reconstructed (or previously memoized) object.

    Raises
    ------
    TypeError
        If no loader is registered under ``state["__loader__"]``.
    """
    saved_id = state.get("__id__")
    if saved_id in load_context.memo:
        # an instance has already been loaded, just return the loaded instance
        return load_context.get_instance(saved_id)

    if state.get("is_json"):
        loaded_obj = json.loads(state["content"])
    else:
        try:
            get_instance_func = GET_INSTANCE_MAPPING[state["__loader__"]]
        except KeyError:
            type_name = f"{state['__module__']}.{state['__class__']}"
            raise TypeError(
                f" Can't find loader {state['__loader__']} for type {type_name}."
            )

        loaded_obj = get_instance_func(state, load_context)

    # hold reference to obj in case same instance encountered again in save state
    if saved_id:
        load_context.memoize(loaded_obj, saved_id)

    return loaded_obj
90 changes: 48 additions & 42 deletions skops/io/_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,71 +8,78 @@
import numpy as np

from ._dispatch import get_instance
from ._utils import SaveState, _import_obj, get_module, get_state, gettype
from ._utils import (
LoadContext,
SaveContext,
_import_obj,
get_module,
get_state,
gettype,
)
from .exceptions import UnsupportedTypeException


def dict_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
    """Build the saved state of a dict-like object.

    Entries whose value is a ``property`` are left out. The types of the
    remaining keys are stored under ``"key_types"`` so the keys can be cast
    back when loading; the entry states are stored under ``"content"``.
    """
    res = {
        "__class__": obj.__class__.__name__,
        "__module__": get_module(type(obj)),
        "__loader__": "dict_get_instance",
    }

    # Only record key types for entries that are actually persisted. The
    # loader zips key_types with content, so a type recorded for a skipped
    # entry would shift every following key onto the wrong type.
    key_types = get_state(
        [type(key) for key, value in obj.items() if not isinstance(value, property)],
        save_context,
    )
    content = {}
    for key, value in obj.items():
        if isinstance(value, property):
            continue
        if np.isscalar(key) and hasattr(key, "item"):
            # convert numpy value to python object
            key = key.item()  # type: ignore
        content[key] = get_state(value, save_context)
    res["content"] = content
    res["key_types"] = key_types
    return res


def dict_get_instance(state, load_context: LoadContext):
    """Rebuild a dict-like object from its saved state.

    An empty container of the saved type is created, then each stored key is
    cast with its recorded key type and each stored value is loaded through
    ``get_instance``.
    """
    content = gettype(state)()
    key_types = get_instance(state["key_types"], load_context)
    # key_types and the stored items are paired positionally
    for k_type, item in zip(key_types, state["content"].items()):
        content[k_type(item[0])] = get_instance(item[1], load_context)
    return content


def list_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
    """Build the saved state of a list-like object.

    The state of every element, in iteration order, is stored as a list
    under ``"content"``.
    """
    res = {
        "__class__": obj.__class__.__name__,
        "__module__": get_module(type(obj)),
        "__loader__": "list_get_instance",
    }
    content = []
    for value in obj:
        content.append(get_state(value, save_context))
    res["content"] = content
    return res


def list_get_instance(state, load_context: LoadContext):
    """Rebuild a list-like object from its saved state.

    An empty container of the saved type is created and every loaded element
    is appended to it in the stored order.
    """
    content = gettype(state)()
    for value in state["content"]:
        content.append(get_instance(value, load_context))
    return content


def tuple_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
    """Build the saved state of a tuple.

    The state of every element, in order, is stored as a tuple under
    ``"content"``.
    """
    res = {
        "__class__": obj.__class__.__name__,
        "__module__": get_module(type(obj)),
        "__loader__": "tuple_get_instance",
    }
    content = tuple(get_state(value, save_context) for value in obj)
    res["content"] = content
    return res


def tuple_get_instance(state, src):
def tuple_get_instance(state, load_context: LoadContext):
# Returns a tuple or a namedtuple instance.
def isnamedtuple(t):
# This is needed since namedtuples need to have the args when
Expand All @@ -86,14 +93,14 @@ def isnamedtuple(t):
return all(type(n) == str for n in f)

cls = gettype(state)
content = tuple(get_instance(value, src) for value in state["content"])
content = tuple(get_instance(value, load_context) for value in state["content"])

if isnamedtuple(cls):
return cls(*content)
return content


def function_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
def function_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(obj),
Expand All @@ -106,39 +113,39 @@ def function_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
return res


def function_get_instance(state, load_context: LoadContext):
    """Return the function named in the saved state.

    The function is looked up via ``_import_obj`` from the ``module_path``
    and ``function`` entries of ``state["content"]``; ``load_context`` is
    accepted for signature parity with the other loaders and is not used.
    """
    loaded = _import_obj(state["content"]["module_path"], state["content"]["function"])
    return loaded


def partial_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
    """Build the saved state of a ``functools.partial`` object.

    The wrapped function, positional arguments, keyword arguments and
    namespace are taken from the state tuple returned by ``__reduce__`` and
    each is stored through ``get_state``.
    """
    # __reduce__ returns a 3-tuple; only its last item (the state) is needed
    _, _, (func, args, kwds, namespace) = obj.__reduce__()
    res = {
        "__class__": "partial",  # don't allow any subclass
        "__module__": get_module(type(obj)),
        "__loader__": "partial_get_instance",
        "content": {
            "func": get_state(func, save_context),
            "args": get_state(args, save_context),
            "kwds": get_state(kwds, save_context),
            "namespace": get_state(namespace, save_context),
        },
    }
    return res


def partial_get_instance(state, load_context: LoadContext):
    """Rebuild a ``functools.partial`` object from its saved state.

    The four stored parts (``func``, ``args``, ``kwds``, ``namespace``) are
    loaded through ``get_instance``; a plain ``partial`` is created and the
    loaded parts are then applied with ``__setstate__``.
    """
    content = state["content"]
    func = get_instance(content["func"], load_context)
    args = get_instance(content["args"], load_context)
    kwds = get_instance(content["kwds"], load_context)
    namespace = get_instance(content["namespace"], load_context)
    instance = partial(func, *args, **kwds)  # always use partial, not a subclass
    instance.__setstate__((func, args, kwds, namespace))  # type: ignore
    return instance


def type_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
def type_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
# To serialize a type, we first need to set the metadata to tell that it's
# a type, then store the type's info itself in the content field.
res = {
Expand All @@ -153,12 +160,12 @@ def type_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
return res


def type_get_instance(state, load_context: LoadContext):
    """Return the type named in the saved state.

    The type is looked up via ``_import_obj`` from the ``__module__`` and
    ``__class__`` entries of ``state["content"]``; ``load_context`` is
    accepted for signature parity with the other loaders and is not used.
    """
    loaded = _import_obj(state["content"]["__module__"], state["content"]["__class__"])
    return loaded


def slice_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
def slice_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
Expand All @@ -172,14 +179,14 @@ def slice_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
return res


def slice_get_instance(state, load_context: LoadContext):
    """Rebuild a ``slice`` from its saved state.

    ``start``, ``stop`` and ``step`` are read directly from
    ``state["content"]``; ``load_context`` is accepted for signature parity
    with the other loaders and is not used.
    """
    start = state["content"]["start"]
    stop = state["content"]["stop"]
    step = state["content"]["step"]
    return slice(start, stop, step)


def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
def object_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
# This method is for objects which can either be persisted with json, or
# the ones for which we can get/set attributes through
# __getstate__/__setstate__ or reading/writing to __dict__.
Expand Down Expand Up @@ -211,14 +218,14 @@ def object_get_state(obj: Any, save_state: SaveState) -> dict[str, Any]:
else:
return res

content = get_state(attrs, save_state)
content = get_state(attrs, save_context)
# it's sufficient to store the "content" because we know that this dict can
# only have str type keys
res["content"] = content
return res


def object_get_instance(state, src):
def object_get_instance(state, load_context: LoadContext):
if state.get("is_json", False):
return json.loads(state["content"])

Expand All @@ -233,7 +240,7 @@ def object_get_instance(state, src):
if not content: # nothing more to do
return instance

attrs = get_instance(content, src)
attrs = get_instance(content, load_context)
if hasattr(instance, "__setstate__"):
instance.__setstate__(attrs)
else:
Expand All @@ -242,7 +249,7 @@ def object_get_instance(state, src):
return instance


def method_get_state(obj: Any, save_state: SaveState):
def method_get_state(obj: Any, save_context: SaveContext):
# This method is used to persist bound methods, which are
# dependent on a specific instance of an object.
# It stores the state of the object the method is bound to,
Expand All @@ -253,20 +260,19 @@ def method_get_state(obj: Any, save_state: SaveState):
"__loader__": "method_get_instance",
"content": {
"func": obj.__func__.__name__,
"obj": get_state(obj.__self__, save_state),
"obj": get_state(obj.__self__, save_context),
},
}

return res


def method_get_instance(state, load_context: LoadContext):
    """Rebuild a bound method from its saved state.

    The object the method was bound to is loaded from
    ``state["content"]["obj"]`` through the generic ``get_instance``, and the
    attribute named by ``state["content"]["func"]`` is then fetched from it.
    """
    loaded_obj = get_instance(state["content"]["obj"], load_context)
    method = getattr(loaded_obj, state["content"]["func"])
    return method


def unsupported_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
    """Always raise ``UnsupportedTypeException`` for ``obj``.

    ``save_context`` is accepted for signature parity with the other
    ``*_get_state`` functions and is not used.
    """
    raise UnsupportedTypeException(obj)


Expand Down
Loading