forked from skops-dev/skops
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor: get_instance method saved in state
Resolves skops-dev#197 Description Currently, during the dispatch of the get_instance functions, the class stored in the state is being loaded to determine which function to dispatch to. This is bad because module loading can be dangerous. We will add auditing but it is intended to be on the level of get_instance itself, not for the dispatch mechanism. In this PR, the state returned by get_state functions is augmented with the name of the get_instance method required to load the object. This way, we can look up the correct method based on the state and don't need to use the modified singledispatch mechanism, thus avoiding loading modules during dispatching. Implementation Whereas for get_state, we still rely in singledispatch, for get_instance we now use a simple dictionary that looks up the function based on its name (which in turn is stored in the state). The dictionary, going by the name of GET_INSTANCE_MAPPING, is populated similarly to how the get_instance functions were registered previously with singledispatch. There was an issue with circular imports (e.g. get_instance > GET_INSTANCE_MAPPING > ndarray_get_instance > get_instance), hence the get_instance function was moved to its own module, _dispatch.py (other name suggestions welcome). For some types, we now need extra get_state functions because they have specific get_instance methods. So e.g. sgd_loss_get_state just wraps reduce_get_state and adds sgd_loss_get_instance as its loader. Coincidental changes Since we no longer have to inspect the contents of state to determine the function to dispatch to for get_instance, we can fall back to the Python implementation of singledispatch instead of rolling our own. This side effect is a big win. The function Tree_get_instance was renamed to tree_get_instance for consistency. In the debug_dispatch_functions, there was some code from a time when the state was allowed not to be a dict (json-serializable objects). Now we always have a dict, so this dead code was removed. Also, this fixture was elevated to module-level scope, since it only needs to run once.
- Loading branch information
1 parent
bf8c2c1
commit abe490d
Showing
9 changed files
with
113 additions
and
212 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import json | ||
from typing import Any, Callable | ||
from zipfile import ZipFile | ||
|
||
GET_INSTANCE_MAPPING: dict[str, Callable[[dict[str, Any], ZipFile], Any]] = {} | ||
|
||
|
||
def get_instance(state, src): | ||
"""Create instance based on the state, using json if possible""" | ||
if state.get("is_json"): | ||
return json.loads(state["content"]) | ||
|
||
try: | ||
get_instance_func = GET_INSTANCE_MAPPING[state["__loader__"]] | ||
except KeyError: | ||
raise TypeError( | ||
f"Creating an instance of type {type(state)} is not supported yet" | ||
) | ||
return get_instance_func(state, src) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.