From 674f9862b051dae1ac893b2dd0ca3eca50cc6435 Mon Sep 17 00:00:00 2001 From: Romain Cledat Date: Sat, 17 Feb 2024 00:51:18 -0800 Subject: [PATCH] Add support for __class__ in escape hatch --- metaflow/plugins/env_escape/stub.py | 22 +++++++++++++--------- test/env_escape/example.py | 6 ++++++ 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/metaflow/plugins/env_escape/stub.py b/metaflow/plugins/env_escape/stub.py index 12131694eed..bf101633f22 100644 --- a/metaflow/plugins/env_escape/stub.py +++ b/metaflow/plugins/env_escape/stub.py @@ -47,6 +47,7 @@ "__init__", "__metaclass__", "__module__", + "__name__", "__new__", "__reduce__", "__reduce_ex__", @@ -78,11 +79,11 @@ def fwd_request(stub, request_type, *args, **kwargs): class StubMetaClass(type): - def __repr__(self): - if self.__module__: - return "" % (self.__module__, self.__name__) + def __repr__(cls): + if cls.__module__: + return "" % (cls.__module__, cls.__name__) else: - return "" % (self.__name__,) + return "" % (cls.__name__,) def with_metaclass(meta, *bases): @@ -131,9 +132,7 @@ def __del__(self): def __getattribute__(self, name): if name in LOCAL_ATTRS: - if name == "__class__": - return None - elif name == "__doc__": + if name == "__doc__": return self.__getattr__("__doc__") elif name in DELETED_ATTRS: raise AttributeError() @@ -455,9 +454,14 @@ def _do_str(self): # it but not the case if we are not. class_dict["__slots__"].append("__weakref__") + class_module, class_name_only = class_name.rsplit(".", 1) class_dict["___local_overrides___"] = overriden_attrs + class_dict["__module__"] = class_module if parents: - return MetaExceptionWithConnection( + to_return = MetaExceptionWithConnection( class_name, (Stub, *parents), class_dict, connection ) - return MetaWithConnection(class_name, (Stub,), class_dict, connection) + else: + to_return = MetaWithConnection(class_name, (Stub,), class_dict, connection) + to_return.__name__ = class_name_only + return to_return diff --git a/test/env_escape/example.py b/test/env_escape/example.py index 50d6c75bad3..128556de4b8 100644 --- a/test/env_escape/example.py +++ b/test/env_escape/example.py @@ -83,6 +83,8 @@ def run_test(through_escape=False): print("-- Test chaining of exported classes --") o2 = o1.to_class2(5) assert o2.something("foo") == "Test2:Something:foo" + assert o2.__class__.__name__ == "TestClass2" + assert o2.__class__.__module__ == "test_lib" print("-- Test Iterating --") for idx, i in enumerate(o2): @@ -108,6 +110,8 @@ def run_test(through_escape=False): assert isinstance(ex_child, test.ExceptionAndClass) assert isinstance(ex_child, Exception) assert isinstance(ex_child, object) + assert ex_child.__class__.__name__ == "ExceptionAndClassChild" + assert ex_child.__class__.__module__ == "test_lib" assert issubclass(type(ex_child), test.ExceptionAndClass) assert issubclass(test.ExceptionAndClassChild, test.ExceptionAndClass) @@ -149,6 +153,8 @@ def run_test(through_escape=False): excclass = o1.raiseOrReturnSomeException() assert not through_escape, "Should have raised through escape" assert isinstance(excclass, test.SomeException) + assert excclass.__class__.__name__ == "SomeException" + assert excclass.__class__.__module__ == "test_lib" except RuntimeError as e: assert ( through_escape