FEAT Audit before loading a skops file #204
Conversation
@skops-dev/maintainers working on some tests and doing a self review, which will add a bunch to the PR. But how much should we include in this PR? There is a ton of work which can be done later. This PR as is allows the tests to pass as long as the user trusts the input. Otherwise there are issues which I'm fixing, and there is also work to actually fill a bunch of security holes. I'm not sure how far we want to go in the same PR.
I don't think we need to be too strict for the initial release, as long as it's clear to the user what they can and what they can't expect. Regarding this, could you please summarize what you want to have in this PR? E.g. the
but at least in the tests, I don't see the list-of-strings option being used. Does it work? How would I specify the types? And if
@BenjaminBossan I completely agree with all you said, I was gonna add those to this PR anyway. So I think we're on the same page. I'll get this PR to a reviewable state and ping back.
@BenjaminBossan API question. Which one would you prefer?

```python
load("file.skops", trusted=["mymodule.myclass"])
# or
load("file.skops", trusted=["mymodule.myclass", "builtins.int", "builtins.dict", ...])
```

As in, should we always trust the types we usually trust, or should we require the user to pass even those if they want to customize trusted types? I personally lean towards the first option.
I think this type of question pops up quite often in different disguises, like each time you have some allow list and disallow list. In this case, the second option would be far too much work and would be error prone, so option 1 looks better. I wonder, however, if we can make it so that if needed, a user could still choose not to trust all the defaults. |
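The first option could be sketched roughly like this; `DEFAULT_TRUSTED` and `resolve_trusted` are hypothetical names for illustration, not the actual skops implementation:

```python
# Hypothetical sketch of option 1: user-supplied trusted types are merged
# with a default allow list, so common builtins never need to be re-listed.
DEFAULT_TRUSTED = frozenset({"builtins.int", "builtins.dict", "builtins.list"})

def resolve_trusted(trusted):
    """Return the effective allow list for a load() call.

    trusted=False -> only the defaults are trusted
    trusted=True  -> everything is trusted (auditing disabled)
    trusted=list  -> defaults plus the user's additions
    """
    if trusted is True:
        return True
    if not trusted:
        return set(DEFAULT_TRUSTED)
    return set(DEFAULT_TRUSTED) | set(trusted)

effective = resolve_trusted(["mymodule.myclass"])
```

An opt-out of the defaults, as suggested above, could then be a separate flag rather than requiring users to re-enumerate builtins.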
- Some lines appear not to be covered, even though it looks like they should be. We should probably understand why that is before merging.
I've left comments on those lines now.
- I wonder if we can somehow ensure that during a node's __init__, no object is constructed or module imported. Yes, of course we could just look hard at the code and make sure that nothing ever slips through, but I would like some automatic checking. I have no concrete idea honestly, but maybe we can somehow use the LoadContext and have separate modes on it, audit mode and construct mode, and during construction, we check that construct mode is set or else fail. Then, we only set construct mode temporarily inside of construct, so if some code accidentally constructs an object during the auditing part, it would fail.
We certainly should work on that, but I've been thinking of a series of follow-up "hardening" PRs which work on improving the security of this implementation. We could think of a few ways, for instance, patching import_module to prevent imports during __init__ maybe.
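The audit-mode/construct-mode idea above could look roughly like this; all names here are hypothetical, simplified stand-ins for the real skops classes:

```python
from contextlib import contextmanager

class LoadContext:
    """Hypothetical context whose flag is only set during construction."""
    def __init__(self):
        self.construct_mode = False

    @contextmanager
    def constructing(self):
        self.construct_mode = True
        try:
            yield
        finally:
            self.construct_mode = False

class Node:
    def __init__(self, load_context):
        self.load_context = load_context

    def _construct(self):
        # guard: fail loudly if anything tries to build an instance
        # while we are still in the audit phase
        if not self.load_context.construct_mode:
            raise RuntimeError("object constructed during audit phase")
        return object()

    def construct(self):
        # construct mode is only ever set temporarily, in here
        with self.load_context.constructing():
            return self._construct()
```

Any code path that accidentally reaches `_construct` while auditing would then raise immediately instead of silently instantiating something.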
- Is there something we can do to ensure that the children of a node are set correctly? It seems to me that if we make a mistake there, it would leave a backdoor open. IIUC, we only have 3 possibilities for child nodes: a list, a dict, or something else. Can we maybe infer the right way to check just from looking at state? I think that would be a huge simplification. Then we don't need a custom __init__ in the node subclasses, we just store self.state = state in Node, thus we don't need to think about what all the children are, which can be error-prone. For this to work, we might have to ensure that our states always have the same structure, and during loading verify that this structure is found, but it would be a worthwhile trade-off IMO (if it works).
Another thing we can do is to pass allowed_types to get_tree to limit the type of node which can be created. I don't think we can/should leave state as is, we should parse the data during __init__ and load those values; e.g. we read the numpy data in memory from the zipfile in __init__, and construct the array later.
- API-wise, as a user, should I be able to say load(..., trusted=["sklearn.*"]) or something like that? I.e. have a way to blanket-trust a module as a whole. Maybe as a future addition?
I don't think users should ever do trusted=["sklearn.*"] because one can always find things like os through that path. But API-wise, adding it later would be easy.
- I think the tests should be extended a bit to cover a handful of realistic use cases we could encounter in the wild. As an example, I would like to see a Pipeline containing a FeatureUnion, consisting of a couple of estimators in total. The test should check that if I only trust a subset of those estimators, the whole pipeline is not trusted. Or if I have a FunctionTransformer with a custom function or numpy function, it is not trusted unless that function is allowed.
I agree, but I feel like doing that while working on adding sklearn safe types would be easier. Trying to limit this PR as much as we can. If you think we should add a specific test, we can do that though. Also, I feel a bit exhausted working on this PR; if you add those tests here, I definitely wouldn't complain :P
skops/io/_audit.py (outdated):

```
type_name : str
    The class name of the type.

trusted : list of str
```
I'm happier w/o the types :D
```python
def get_unsafe_set(self):
    if self.is_safe:
        return set()
```
same here, once we pass trusted down or when we start trusting a few functions from numpy/scipy, this would be covered.
```python
        ]
        self.children = {"shape": Node, "content": list}
    else:
        raise ValueError(f"Unknown type {self.type}.")
```
this should only happen if the given file is malformed, which we're not testing for now.
```python
self.type = state["type"]
self.trusted = self._get_trusted(trusted, ["scipy.sparse.spmatrix"])
if self.type != "scipy":
    raise TypeError(
```
this should also only happen when input is malformed.
```python
    return instance

if isinstance(args, tuple) and not hasattr(instance, "__setstate__"):
    raise UnsupportedTypeException(
```
this would only be the case for some odd types, which we're not testing.
```python
if hasattr(instance, "__setstate__"):
    instance.__setstate__(attrs)
else:
    instance.__dict__.update(attrs)
```
same here about odd types.
We certainly should work on that, but I've been thinking of a series of follow-up "hardening" PRs which work on improving the security of this implementation.
Okay, we can do that as a follow up.
I don't think we can/should leave state as is, we should parse the data during __init__ and load those values; e.g. we read the numpy data in memory from the zipfile in __init__, and construct the array later.
Could you please elaborate on that, I don't understand it. Probably it also relates to this comment elsewhere:
we need to do the traversal in __init__, and this traversal is node specific. We almost always call get_tree on children in __init__.
I checked each self.children = ... and none of them are calling get_tree.
I tried to modify the code as follows to avoid having a bunch of different attributes for every node type, and it worked (only tried for DictNode):
```python
class DictNode(Node):
    def __init__(self, state, load_context: LoadContext, trusted=False):
        super().__init__(state, load_context, trusted)
        self.trusted = self._get_trusted(trusted, ["builtins.dict"])
        # ideally, we could automatically infer the children from the state...
        self.children = {
            "key_types": get_tree(state["key_types"], load_context),
            "content": {
                key: get_tree(value, load_context)
                for key, value in state["content"].items()
            },
        }
        # no other custom attributes stored

    def _construct(self):
        content = gettype(self.module_name, self.class_name)()
        key_types = self.children["key_types"].construct()
        for k_type, (key, val) in zip(key_types, self.children["content"].items()):
            content[k_type(key)] = val.construct()
        return content
```

Inside of get_unsafe_set, the loop is replaced with:

```python
def get_unsafe_set(self):
    ...
    for child in self.children.values():
        if child is None:
            continue
        # Get the safety set based on the type of the child. In most cases
        # other than ListNode and DictNode, children are all of type Node.
        if isinstance(child, list):
            for value in child:
                res.update(value.get_unsafe_set())
        elif isinstance(child, dict):
            for value in child.values():
                res.update(value.get_unsafe_set())
        elif isinstance(child, Node):
            res.update(child.get_unsafe_set())
        else:
            raise ValueError(f"Unknown type {type(child)}.")
    ...
```
It's not exactly my initial proposal with saving state, but it still addresses my concern. WDYT?
I don't think users should ever do trusted=["sklearn.*"] because one can always find things like os through that path.
Even if a module in sklearn imports os, the module name would still not be sklearn, so there is no match, right?
I agree, but I feel like doing that while working on adding sklearn safe types would be easier.
About that, what's the plan? Enumerating each and every sklearn class and function that's safe? And keeping it up-to-date with each release, across multiple versions? Sounds infeasible.
Had a few more thoughts :)
```python
# TODO: This should help with fixing recursive references.
# if id(value) in save_context.memo:
#     return {
#         "__module__": None,
#         "__class__": None,
#         "__id__": id(value),
#         "__loader__": "CachedNode",
#     }
```
I think this should work ok, but we would need to restructure how things get memoized.
Right now, things only get memoized once they've been fully constructed, so if an object has a recursive attribute, it never actually gets into the memo.
I thought about this originally, and started exploring if there's a nice way to hold a reference to an object that isn't initialized yet, so we could do something like:
```python
_obj = Object()  # pseudocode, not the right way to do this
save_context.memoize(_obj, id)
res = _get_state(value, save_context)
_obj.assign(res)
return _obj
```
Still haven't fully thought this through, but I think we might need to do something like this to allow circular self-references to work the right way.
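One way the "memoize before recursing" idea could work, sketched with plain dicts standing in for the real state/node objects (hypothetical, not the skops implementation): register a placeholder in the memo first, then fill it in after recursing, so a circular reference resolves to the cached entry instead of recursing forever.

```python
class SaveContext:
    def __init__(self):
        self.memo = {}  # id(obj) -> state dict

def get_state(value, save_context):
    key = id(value)
    if key in save_context.memo:
        # second visit: emit a reference to the already-memoized object
        return {"__loader__": "CachedNode", "__id__": key}
    state = {"__id__": key}
    save_context.memo[key] = state  # memoize *before* recursing
    if isinstance(value, dict):
        state["content"] = {
            k: get_state(v, save_context) for k, v in value.items()
        }
    else:
        state["value"] = repr(value)
    return state

d = {}
d["self"] = d  # circular self-reference
state = get_state(d, SaveContext())
```

A real version would still need to deal with id reuse (e.g. interned small integers), but the placeholder pattern itself terminates on cycles.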
yeah I've also thought exactly along those lines, I still don't know how exactly to do it though. Worst case scenario we end up doing the DAG work lol. But even that shouldn't require much refactoring.
The reason I commented out this code is that small integers share the same id in Python (and I don't know what else does), and somehow when saving, we don't save in the right order, so the cached object gets loaded before the actual object does.
```python
def gettype(module_name, cls_or_func):
    if module_name and cls_or_func:
        return _import_obj(module_name, cls_or_func)
```
I wanted to check something here:
As it stands, is there any way loading a .skops file could redefine something in the global namespace?
My understanding (correct me if I'm wrong) is that this could only happen if:
- code is defined that does that before calling loads, or
- something does this in a user's imported module.

In either of those cases, this isn't a vulnerability with skops itself, so I think it's ok, but I wanted to make sure I've not missed somewhere that global namespaces could be changed during load, as that could lead to a vulnerability.
In other words, there's not currently a way someone could structure a .skops file that redefines a type we deem "trusted", like np.random.Generator, right?
Yes, that's exactly the idea, unless an import statement would do that, in which case the user is already compromised anyway.
I don't see a way with the current format for anybody to be able to modify globals the way one can do with pickle.
Even if a module in sklearn imports os, the module name would still not be sklearn, so there is no match, right?
When we persist, yes, but somebody could curate a .skops file where they use sklearn.datasets._base as the module and os as the class name, and we'd import it.
About that, what's the plan? Enumerating each and every sklearn class and function that's safe? And keeping it up-to-date with each release, across multiple versions? Sounds infeasible.
It sounds feasible to me, but would require some automation to streamline it. By enumerating, I would think of getting all estimators through sklearn's API rather than adding manually though.
It's not exactly my initial proposal with saving state but still addresses my concern. WDYT?
This doesn't look neat to me, but while doing it I found one bug, so it's a good pattern I'd say. Also, I had to modify get_unsafe_set quite a bit as a result.
I think this is ready for another review @BenjaminBossan
```python
# this means we're already computing this node's unsafe set, so we
# return an empty set and let the computation of the parent node
# continue. This is to avoid infinite recursion.
return set()
```
this isn't tested since we haven't figured out the whole recursive-pointers thing.
```python
# conditions about BytesIO, etc should be ignored.
if not check_type(get_module(child), child.__name__, self.trusted):
    # if the child is a type, we check its safety
    res.add(get_module(child) + "." + child.__name__)
```
For all cases where the child is a type, we have it as trusted. This is only used in reduce.
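For reference, the check used here boils down to something like the following simplified sketch of check_type (not the exact skops implementation):

```python
def check_type(module_name, type_name, trusted):
    # trusted is either True (trust everything) or a collection of
    # "module.name" strings; simplified sketch of the real helper
    if trusted is True:
        return True
    return f"{module_name}.{type_name}" in trusted
```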
```python
super().__init__(state, load_context, trusted, memoize=False)
self.trusted = True
self.cached = load_context.get_object(state.get("__id__"))
self.children = {}  # type: ignore
```
cache node is used for recursive pointers as well.
Pretty good work, I'm happy with the result. There are some small things left to do, please check my comments, but overall exceptional work.
when we persist yes, but somebody could curate a .skops file where they do sklearn.datasets._base as module and os as class name and we'd import it.
Yes, good point. But I wonder how much of a problem that is. If sklearn imports foo somewhere, we cannot guarantee that foo will not be imported at some point, even if users pass fine-grained types to allow. Therefore, we already need to assume that importing foo is safe.

Now, using foo could still be dangerous. However, when we create the instances, we could check if the path of foo corresponds to sklearn.* and if not, raise an error. That way, we could still prevent its usage, if I'm not missing something.
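That check could look roughly like the following; import_checked is a hypothetical helper, illustrated with the stdlib os module rather than sklearn so the example stays self-contained:

```python
import importlib

def import_checked(module_name, obj_name, allowed_prefix):
    """Import obj_name from module_name, but refuse objects actually
    defined outside the allowed package (hypothetical helper)."""
    obj = getattr(importlib.import_module(module_name), obj_name)
    # classes/functions carry __module__; re-exported modules carry __name__
    defining = getattr(obj, "__module__", None) or getattr(obj, "__name__", "")
    if not (defining + ".").startswith(allowed_prefix):
        raise TypeError(
            f"{module_name}.{obj_name} is defined in {defining!r}, "
            f"outside of {allowed_prefix!r}"
        )
    return obj
```

With allowed_prefix="sklearn.", looking up os through sklearn.datasets._base (the trick mentioned above) would be rejected, because the resolved object's defining module is os, not anything under sklearn.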
It sounds feasible to me, but would require some automation to streamline it. By enumerating, I would think of getting all estimators through sklearn's API rather than adding manually though.
Okay, let's see how it'll work out in practice. Do you plan to include that before next release?
This doesn't look neat to me, but while doing it I found one bug, so it's a good pattern I'd say. Also, had to modify the get_unsafe_set quite a bit as a result.
Not sure what part didn't look neat, but the way you refactored corresponds to my intent and is cleaner IMO, so I'm happy with the outcome.
```python
    return state


def tree_get_instance(state, load_context):
    return reduce_get_instance(state, load_context, constructor=Tree)


class TreeNode(ReduceNode):
```
A bit of an unfortunate name now that "Tree" can have another meaning in skops.io.
we could call it SklearnTreeTreeNode 😁
Yes, good point. But I wonder how much of a problem that is. If sklearn imports foo somewhere, we cannot guarantee that foo will not be imported at some point, even if users pass fine grained types to allow. Therefore, we already need to assume that importing foo is safe.
Now, using foo could still be dangerous. However, when we create the instances, we could check if the path of foo corresponds to sklearn.* and if not, raise an error. That way, we could still prevent its usage, if I'm not missing something.
Yes, that'd be interesting, but we should do the check before creating the instance. We can do that in a follow-up PR, and add extra checks for it.
It sounds feasible to me, but would require some automation to streamline it. By enumerating, I would think of getting all estimators through sklearn's API rather than adding manually though.
Okay, let's see how it'll work out in practice. Do you plan to include that before next release?
No, I think we can release a first version w/o trusting much from sklearn.
```python
    return content_type([item.construct() for item in self.children["content"]])


def set_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
```
good question, but I saw it being an issue during my tests.
I think I'm happy with this now. It could be merged and we could release early and get feedback, and in the meantime work on the lot that's left to improve. cc @skops-dev/maintainers
Great work, I think this is now ready to be merged. Everything that's still open, we can work on later, e.g. I'd like to add some light typing (-:
@E-Aho Do you want to give this another pass too?
Sure! I can give it a final look tonight and hopefully we can merge this in!
I caught a tiny typo in a docstring, but overall LGTM!
Feel free to hit the "Squash and merge" button then @E-Aho :) (we almost always squash and merge)
We usually try to craft a nice commit message, not just using the GH suggestion.
All set :) Side note, is there a list anywhere of the commit prefixes? [FEAT, FIX, DOC, etc]
I don't have a list, but I use:
I was also searching for such a list. How about adding it to the contribution guide?
We should add it to the maintainers guide instead of the contributing guide (we don't have the separation now). Maintainers can/should fix commit messages/titles before merging, and I don't think we should burden first-time contributors with such details.
Created #217
This adds auditing before load for the skops file format.

It creates a tree of nodes by traversing the state json stored in the .skops file, and loads the information in memory w/o loading any modules or constructing any instances. Then we can check this tree for existing types/functions and report things which are not trusted. The user then has to pass this list to a load/loads function to successfully load the .skops file: