Skip to content

Commit

Permalink
Better support for subclasses and fix issues with >1 exception depth
Browse files Browse the repository at this point in the history
  • Loading branch information
romain-intel committed Mar 12, 2024
1 parent dfce3b7 commit 79bb1d4
Show file tree
Hide file tree
Showing 9 changed files with 289 additions and 20 deletions.
69 changes: 55 additions & 14 deletions metaflow/plugins/env_escape/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,9 @@ def decode(self, json_obj):
# this connection will be converted to a local stub.
return self._datatransferer.load(json_obj)

def get_local_class(self, name, obj_id=None, is_returned_exception=False):
def get_local_class(
self, name, obj_id=None, is_returned_exception=False, is_parent=False
):
# Gets (and creates if needed), the class mapping to the remote
# class of name 'name'.

Expand All @@ -367,6 +369,15 @@ def get_local_class(self, name, obj_id=None, is_returned_exception=False):
# - classes that are proxied regular classes AND NOT proxied exceptions
# - classes that are NOT proxied regular classes AND are proxied exceptions
name = get_canonical_name(name, self._aliases)

def name_to_parent_name(name):
return "parent:%s" % name

if is_parent:
lookup_name = name_to_parent_name(name)
else:
lookup_name = name

if name == "function":
# Special handling of pickled functions. We create a new class that
# simply has a __call__ method that will forward things back to
Expand All @@ -378,15 +389,15 @@ def get_local_class(self, name, obj_id=None, is_returned_exception=False):
self, "__function_%s" % obj_id, {}, {}, {}, {"__call__": ""}, []
)
return self._proxied_standalone_functions[obj_id]
local_class = self._proxied_classes.get(name, None)
local_class = self._proxied_classes.get(lookup_name, None)
if local_class is not None:
return local_class

is_proxied_exception = name in self._exception_hierarchy
is_proxied_non_exception = name in self._proxied_classnames

if not is_proxied_exception and not is_proxied_non_exception:
if is_returned_exception:
if is_returned_exception or is_parent:
# In this case, it may be a local exception that we need to
# recreate
try:
Expand All @@ -405,7 +416,7 @@ def get_local_class(self, name, obj_id=None, is_returned_exception=False):
dict(getattr(local_exception, "__dict__", {})),
)
wrapped_exception.__module__ = ex_module
self._proxied_classes[name] = wrapped_exception
self._proxied_classes[lookup_name] = wrapped_exception
return wrapped_exception

raise ValueError("Class '%s' is not known" % name)
Expand All @@ -422,31 +433,61 @@ def get_local_class(self, name, obj_id=None, is_returned_exception=False):
for parent in ex_parents:
# We always consider it to be an exception so that we wrap even non
# proxied builtins exceptions
parents.append(self.get_local_class(parent, is_returned_exception=True))
parents.append(self.get_local_class(parent, is_parent=True))
# For regular classes, we get what it exposes from the server
if is_proxied_non_exception:
remote_methods = self.stub_request(None, OP_GETMETHODS, name)
else:
remote_methods = {}

if is_proxied_exception and not is_proxied_non_exception:
# This is a pure exception
parent_local_class = None
local_class = None
if is_proxied_exception:
# If we are a proxied exception AND a proxied class, we actually create
# two classes:
# - the class itself (which is a stub)
# - the class in the capacity of a parent class (to another exception
# presumably). The reason for this is that if we have an exception/proxied
# class A and another B, and B inherits from A, the MRO order would be all
# wrong since both A and B would also inherit from `Stub`. Here what we
# do is:
# - A_parent inherits from the actual parents of A (let's assume a
# builtin exception)
# - A inherits from (Stub, A_parent)
# - B_parent inherits from A_parent and the builtin Exception
# - B inherits from (Stub, B_parent)
ex_module, ex_name = name.rsplit(".", 1)
local_class = ExceptionMetaClass(ex_name, (*parents,), {})
local_class.__module__ = ex_module
else:
# This method creates either a pure stub or a stub that is also an exception
parent_local_class = ExceptionMetaClass(ex_name, (*parents,), {})
parent_local_class.__module__ = ex_module

if is_proxied_non_exception:
local_class = create_class(
self,
name,
self._overrides.get(name, {}),
self._getattr_overrides.get(name, {}),
self._setattr_overrides.get(name, {}),
remote_methods,
parents,
(parent_local_class,) if parent_local_class else None,
)
self._proxied_classes[name] = local_class
return local_class
if parent_local_class:
self._proxied_classes[name_to_parent_name(name)] = parent_local_class
if local_class:
self._proxied_classes[name] = local_class
else:
# This is for the case of pure proxied exceptions -- we want the lookup of
# foo.MyException to be the same class as looking up foo.MyException as a parent
# of another exception so `isinstance` works properly
self._proxied_classes[name] = parent_local_class

if is_parent:
# This should never happen but making sure
if not parent_local_class:
raise RuntimeError(
"Exception parent class %s is not a proxied exception" % name
)
return parent_local_class
return self._proxied_classes[name]

def can_pickle(self, obj):
return getattr(obj, "___connection___", None) == self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,30 @@ def unsupported_method(stub, func, *args, **kwargs):


@local_exception_deserialize("test_lib.SomeException")
def deserialize_user(ex, json_obj):
def some_exception_deserialize(ex, json_obj):
ex.user_value = json_obj


@remote_exception_serialize("test_lib.SomeException")
def some_exception_serialize(ex):
    """Return the fixed payload 42 for test_lib.SomeException; `ex` is not inspected."""
    return 42


@local_exception_deserialize("test_lib.ExceptionAndClass")
def exception_and_class_deserialize(ex, json_obj):
    """Store the received payload on the exception as `user_value`."""
    setattr(ex, "user_value", json_obj)


@remote_exception_serialize("test_lib.ExceptionAndClass")
def exception_and_class_serialize(ex):
    """Return the fixed payload 43 for test_lib.ExceptionAndClass; `ex` is not inspected."""
    return 43


@local_exception_deserialize("test_lib.ExceptionAndClassChild")
def exception_and_class_child_deserialize(ex, json_obj):
    """Store the received payload on the exception as `user_value`."""
    setattr(ex, "user_value", json_obj)


@remote_exception_serialize("test_lib.ExceptionAndClassChild")
def exception_and_class_child_serialize(ex):
    """Return the fixed payload 44 for test_lib.ExceptionAndClassChild; `ex` is not inspected."""
    return 44
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
("test_lib", "test_lib.alias"): {
"TestClass1": lib.TestClass1,
"TestClass2": lib.TestClass2,
"BaseClass": lib.BaseClass,
"ChildClass": lib.ChildClass,
"ExceptionAndClass": lib.ExceptionAndClass,
"ExceptionAndClassChild": lib.ExceptionAndClassChild,
}
}

Expand All @@ -22,6 +25,7 @@
"SomeException": lib.SomeException,
"MyBaseException": lib.MyBaseException,
"ExceptionAndClass": lib.ExceptionAndClass,
"ExceptionAndClassChild": lib.ExceptionAndClassChild,
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
from html.parser import HTMLParser


class MyBaseException(Exception):
Expand All @@ -20,6 +21,36 @@ def __str__(self):
return "ExceptionAndClass Str: %s" % super().__str__()


class ExceptionAndClassChild(ExceptionAndClass):
    """Child of ExceptionAndClass adding one method and its own `__str__` prefix."""

    def __init__(self, *args):
        # Positional arguments are forwarded unchanged to the parent class.
        super().__init__(*args)

    def method_on_child_exception(self):
        """Return a fixed marker string identifying this method."""
        return "method_on_child_exception"

    def __str__(self):
        # Prefix the parent's string form with this class's own tag.
        parent_str = super().__str__()
        return "ExceptionAndClassChild Str: %s" % parent_str


class BaseClass(HTMLParser):
    """HTMLParser subclass that records the name of every start tag it handles."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Tag names, appended in the order the handlers are invoked.
        self._output = []

    def handle_starttag(self, tag, attrs):
        """Record `tag`, then defer to the parent implementation."""
        self._output.append(tag)
        parent_result = super().handle_starttag(tag, attrs)
        return parent_result

    def get_output(self):
        """Return the list of recorded tag names (the live list, not a copy)."""
        return self._output


class ChildClass(BaseClass):
    """BaseClass variant that also records end tags in `_output`."""

    def handle_endtag(self, tag):
        """Record `tag`, then defer to the parent implementation."""
        self._output.append(tag)
        parent_result = super().handle_endtag(tag)
        return parent_result


class TestClass1(object):
cls_object = 25

Expand Down Expand Up @@ -72,6 +103,9 @@ def __hidden(self, name, value):
def weird_indirection(self, name):
    """Return a callable that invokes `self.__hidden` with `name` already bound."""
    bound_hidden = self.__hidden
    return functools.partial(bound_hidden, name)

def returnChild(self):
    """Return a newly constructed ChildClass instance."""
    child = ChildClass()
    return child

def raiseOrReturnValueError(self, doRaise=False):
if doRaise:
raise ValueError("I raised")
Expand All @@ -87,6 +121,11 @@ def raiseOrReturnExceptionAndClass(self, doRaise=False):
raise ExceptionAndClass("I raised")
return ExceptionAndClass("I returned")

def raiseOrReturnExceptionAndClassChild(self, doRaise=False):
    """Raise an ExceptionAndClassChild when `doRaise` is truthy, else return one.

    The raised instance carries the message "I raised"; the returned one
    carries "I returned".
    """
    if not doRaise:
        return ExceptionAndClassChild("I returned")
    raise ExceptionAndClassChild("I raised")


class TestClass2(object):
def __init__(self, value, stride, count):
Expand Down
1 change: 1 addition & 0 deletions metaflow/plugins/env_escape/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
OP_SETVAL = 16
OP_INIT = 17
OP_CALLONCLASS = 18
OP_SUBCLASSCHECK = 19

# Control messages
CONTROL_SHUTDOWN = 1
Expand Down
9 changes: 8 additions & 1 deletion metaflow/plugins/env_escape/exception_transferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ def dump_exception(data_transferer, exception_type, exception_val, tb, user_data


def load_exception(client, json_obj):
from .stub import Stub

json_obj = client.decode(json_obj)

if json_obj.get(FIELD_EXC_SI) is not None:
Expand All @@ -93,11 +95,16 @@ def load_exception(client, json_obj):
exception_module = json_obj.get(FIELD_EXC_MODULE)
exception_name = json_obj.get(FIELD_EXC_NAME)
exception_class = None
# This name is already canonical since we canonicalize it on the server side
full_name = "%s.%s" % (exception_module, exception_name)

exception_class = client.get_local_class(full_name, is_returned_exception=True)

raised_exception = exception_class(*json_obj.get(FIELD_EXC_ARGS))
if issubclass(exception_class, Stub):
raised_exception = exception_class(_is_returned_exception=True)
raised_exception.args = tuple(json_obj.get(FIELD_EXC_ARGS))
else:
raised_exception = exception_class(*json_obj.get(FIELD_EXC_ARGS))
raised_exception._exception_str = json_obj.get(FIELD_EXC_STR, None)
raised_exception._exception_repr = json_obj.get(FIELD_EXC_REPR, None)
raised_exception._exception_tb = json_obj.get(FIELD_EXC_TB, None)
Expand Down
18 changes: 18 additions & 0 deletions metaflow/plugins/env_escape/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
OP_GETVAL,
OP_SETVAL,
OP_INIT,
OP_SUBCLASSCHECK,
VALUE_LOCAL,
VALUE_REMOTE,
CONTROL_GETEXPORTS,
Expand Down Expand Up @@ -255,6 +256,7 @@ def __init__(self, config_dir, max_pickle_version):
OP_GETVAL: self._handle_getval,
OP_SETVAL: self._handle_setval,
OP_INIT: self._handle_init,
OP_SUBCLASSCHECK: self._handle_subclasscheck,
}

self._local_objects = {}
Expand Down Expand Up @@ -292,6 +294,7 @@ def encode(self, obj):
def encode_exception(self, ex_type, ex, trace_back):
try:
full_name = "%s.%s" % (ex_type.__module__, ex_type.__name__)
get_canonical_name(full_name, self._aliases)
serializer = self._exception_serializers.get(full_name)
except AttributeError:
# Ignore if no __module__ for example -- definitely not something we built
Expand Down Expand Up @@ -502,6 +505,21 @@ def _handle_init(self, target, class_name, *args, **kwargs):
raise ValueError("Unknown class %s" % class_name)
return class_type(*args, **kwargs)

def _handle_subclasscheck(self, target, class_name, otherclass_name, reverse=False):
    """Check a subclass relationship between a known class and a named class.

    Args:
        target: Not used by this handler; accepted so the signature lines up
            with the other ``_handle_*`` operations.
        class_name: Key looked up in ``self._known_classes``.
        otherclass_name: Dotted ``module.ClassName`` path of the other class;
            the module part is imported on demand.
        reverse: When false (default), test whether the other class is a
            subclass of the known class. When true, test whether the known
            class is a subclass of the other class.

    Returns:
        The ``issubclass`` result, or ``False`` when the other class cannot be
        resolved to a class (no dot in the name, module not importable, name
        missing from the module, or the name is bound to a non-class).

    Raises:
        ValueError: If ``class_name`` is not in ``self._known_classes``.
    """
    class_type = self._known_classes.get(class_name)
    if class_type is None:
        raise ValueError("Unknown class %s" % class_name)
    try:
        sub_module, sub_name = otherclass_name.rsplit(".", 1)
        __import__(sub_module, None, None, "*")
        other_class = getattr(sys.modules[sub_module], sub_name)
    except Exception:
        # Deliberately broad: importing an arbitrary module can fail in many
        # ways, and any failure to resolve the name means "not a subclass".
        # This now also covers a module that imports fine but lacks the
        # attribute, which previously escaped as an AttributeError.
        return False
    if not isinstance(other_class, type):
        # issubclass() raises TypeError on non-class arguments.
        return False
    if reverse:
        return issubclass(class_type, other_class)
    return issubclass(other_class, class_type)


if __name__ == "__main__":
max_pickle_version = int(sys.argv[1])
Expand Down
Loading

0 comments on commit 79bb1d4

Please sign in to comment.