From bd7e234f4e639e6e6b46eb528e06c3643f9d9e36 Mon Sep 17 00:00:00 2001 From: Shawn Yang Date: Sat, 9 Dec 2023 00:00:08 +0800 Subject: [PATCH] [Python] Refine py register class method (#1218) * merge register class id with class_tag * fix create schema from struct with meta * lint code * fix register_class in doc --- README.md | 2 +- docs/guide/xlang_object_graph_guide.md | 6 +-- python/pyfury/_fury.py | 45 ++++++++++++---------- python/pyfury/_serialization.pyx | 40 +++++++++---------- python/pyfury/format/encoder.pxi | 2 +- python/pyfury/tests/test_cross_language.py | 10 ++--- python/pyfury/tests/test_struct.py | 4 +- 7 files changed, 57 insertions(+), 52 deletions(-) diff --git a/README.md b/README.md index be193fe022..ccdd425a4c 100644 --- a/README.md +++ b/README.md @@ -239,7 +239,7 @@ class SomeClass: f3: Dict[str, str] fury = pyfury.Fury(ref_tracking=True) -fury.register_class(SomeClass, "example.SomeClass") +fury.register_class(SomeClass, type_tag="example.SomeClass") obj = SomeClass() obj.f2 = {"k1": "v1", "k2": "v2"} obj.f1, obj.f3 = obj, obj.f2 diff --git a/docs/guide/xlang_object_graph_guide.md b/docs/guide/xlang_object_graph_guide.md index 0a0b5ef60d..2c5c61966a 100644 --- a/docs/guide/xlang_object_graph_guide.md +++ b/docs/guide/xlang_object_graph_guide.md @@ -224,8 +224,8 @@ class SomeClass2: if __name__ == "__main__": f = pyfury.Fury() - f.register_class(SomeClass1, "example.SomeClass1") - f.register_class(SomeClass2, "example.SomeClass2") + f.register_class(SomeClass1, type_tag="example.SomeClass1") + f.register_class(SomeClass2, type_tag="example.SomeClass2") obj1 = SomeClass1(f1=True, f2={-1: 2}) obj = SomeClass2( f1=obj1, @@ -443,7 +443,7 @@ class SomeClass: f3: Dict[str, str] fury = pyfury.Fury(ref_tracking=True) -fury.register_class(SomeClass, "example.SomeClass") +fury.register_class(SomeClass, type_tag="example.SomeClass") obj = SomeClass() obj.f2 = {"k1": "v1", "k2": "v2"} obj.f1, obj.f3 = obj, obj.f2 diff --git a/python/pyfury/_fury.py b/python/pyfury/_fury.py index 0c8f3288fd..b2378e7268 100644 --- a/python/pyfury/_fury.py +++ b/python/pyfury/_fury.py @@ -192,13 +192,15 @@ def __init__(self, fury): self._dynamic_written_enum_string = [] def initialize(self): - self.register_class(int, PYINT_CLASS_ID) - self.register_class(float, PYFLOAT_CLASS_ID) - self.register_class(bool, PYBOOL_CLASS_ID) - self.register_class(str, STRING_CLASS_ID) - self.register_class(_PickleStub, PICKLE_CLASS_ID) - self.register_class(PickleStrongCacheStub, PICKLE_STRONG_CACHE_CLASS_ID) - self.register_class(PickleCacheStub, PICKLE_CACHE_CLASS_ID) + self.register_class(int, class_id=PYINT_CLASS_ID) + self.register_class(float, class_id=PYFLOAT_CLASS_ID) + self.register_class(bool, class_id=PYBOOL_CLASS_ID) + self.register_class(str, class_id=STRING_CLASS_ID) + self.register_class(_PickleStub, class_id=PICKLE_CLASS_ID) + self.register_class( + PickleStrongCacheStub, class_id=PICKLE_STRONG_CACHE_CLASS_ID + ) + self.register_class(PickleCacheStub, class_id=PICKLE_CACHE_CLASS_ID) self._add_default_serializers() # `Union[type, TypeVar]` is not supported in py3.6 @@ -212,7 +214,20 @@ def register_serializer(self, cls, serializer): self._classes_info[cls].serializer = serializer # `Union[type, TypeVar]` is not supported in py3.6 - def register_class(self, cls, class_id: int = None): + def register_class(self, cls, *, class_id: int = None, type_tag: str = None): + """Register class with given type id or tag, if tag is not None, it will be used for + cross-language serialization.""" + if type_tag is not None: + assert class_id is None, ( + f"Type tag {type_tag} has been set already, " + f"set class id at the same time is not allowed." + ) + from pyfury._struct import ComplexObjectSerializer + + self.register_serializer( + cls, ComplexObjectSerializer(self.fury, cls, type_tag) + ) + return classinfo = self._classes_info.get(cls) if classinfo is None: if isinstance(cls, TypeVar): @@ -261,13 +276,6 @@ def _next_class_id(self): class_id = self._class_id_counter = self._class_id_counter + 1 return class_id - def register_class_tag(self, cls: type, type_tag: str = None): - """Register class with given type tag which will be used for cross-language - serialization.""" - from pyfury._struct import ComplexObjectSerializer - - self.register_serializer(cls, ComplexObjectSerializer(self.fury, cls, type_tag)) - def _add_serializer(self, cls: type, serializer=None, serializer_cls=None): if serializer_cls: serializer = serializer_cls(self.fury, cls) @@ -653,11 +661,8 @@ def register_serializer(self, cls: type, serializer): self.class_resolver.register_serializer(cls, serializer) # `Union[type, TypeVar]` is not supported in py3.6 - def register_class(self, cls, class_id: int = None): - self.class_resolver.register_class(cls, class_id=class_id) - - def register_class_tag(self, cls: type, type_tag: str = None): - self.class_resolver.register_class_tag(cls, type_tag) + def register_class(self, cls, *, class_id: int = None, type_tag: str = None): + self.class_resolver.register_class(cls, class_id=class_id, type_tag=type_tag) def serialize( self, diff --git a/python/pyfury/_serialization.pyx b/python/pyfury/_serialization.pyx index 832970dd46..7360fbb4b3 100644 --- a/python/pyfury/_serialization.pyx +++ b/python/pyfury/_serialization.pyx @@ -326,13 +326,13 @@ cdef class ClassResolver: self._enum_str_set = set() def initialize(self): - self.register_class(int, PYINT_CLASS_ID) - self.register_class(float, PYFLOAT_CLASS_ID) - self.register_class(bool, PYBOOL_CLASS_ID) - self.register_class(str, STRING_CLASS_ID) - self.register_class(_PickleStub, PICKLE_CLASS_ID) - self.register_class(PickleStrongCacheStub, PICKLE_STRONG_CACHE_CLASS_ID) - self.register_class(PickleCacheStub, PICKLE_CACHE_CLASS_ID) + self.register_class(int, class_id=PYINT_CLASS_ID) + self.register_class(float, class_id=PYFLOAT_CLASS_ID) + self.register_class(bool, class_id=PYBOOL_CLASS_ID) + self.register_class(str, class_id=STRING_CLASS_ID) + self.register_class(_PickleStub, class_id=PICKLE_CLASS_ID) + self.register_class(PickleStrongCacheStub, class_id=PICKLE_STRONG_CACHE_CLASS_ID) + self.register_class(PickleCacheStub, class_id=PICKLE_CACHE_CLASS_ID) self._add_default_serializers() def register_serializer(self, cls: Union[type, TypeVar], serializer): @@ -344,7 +344,16 @@ cdef class ClassResolver: self.register_class(cls) self._classes_info[cls].serializer = serializer - def register_class(self, cls: Union[type, TypeVar], class_id: int = None): + def register_class(self, cls: Union[type, TypeVar], *, class_id: int = None, type_tag: str = None): + """Register class with given type id or tag, if tag is not None, it will be used for + cross-language serialization.""" + if type_tag is not None: + assert class_id is None, (f"Type tag {type_tag} has been set already, " + f"set class id at the same time is not allowed.") + self.register_serializer( + cls, ComplexObjectSerializer(self.fury, cls, type_tag) + ) + return classinfo = self._classes_info.get(cls) if classinfo is None: if isinstance(cls, TypeVar): @@ -397,13 +406,6 @@ cdef class ClassResolver: class_id = self._class_id_counter = self._class_id_counter + 1 return class_id - def register_class_tag(self, cls: Union[type, TypeVar], type_tag: str = None): - """Register class with given type tag which will be used for cross-language - serialization.""" - self.register_serializer( - cls, ComplexObjectSerializer(self.fury, cls, type_tag) - ) - def _add_serializer( self, cls: Union[type, TypeVar], @@ -846,11 +848,9 @@ cdef class Fury: def register_serializer(self, cls: Union[type, TypeVar], Serializer serializer): self.class_resolver.register_serializer(cls, serializer) - def register_class(self, cls: Union[type, TypeVar], class_id: int = None): - self.class_resolver.register_class(cls, class_id=class_id) - - def register_class_tag(self, cls: Union[type, TypeVar], type_tag: str = None): - self.class_resolver.register_class_tag(cls, type_tag) + def register_class(self, cls: Union[type, TypeVar], *, + class_id: int = None, type_tag: str = None): + self.class_resolver.register_class(cls, class_id=class_id, type_tag=type_tag) def serialize( self, obj, diff --git a/python/pyfury/format/encoder.pxi b/python/pyfury/format/encoder.pxi index 531bc2d868..3ea019aee6 100644 --- a/python/pyfury/format/encoder.pxi +++ b/python/pyfury/format/encoder.pxi @@ -452,7 +452,7 @@ cdef create_converter(Field field, CWriter* writer): return create_atomic_encoder(StrWriter, writer) elif types.is_struct(data_type): row_encoder = RowEncoder.create(pa.schema( - data_type, metadata=field.metadata), writer) + list(data_type), metadata=field.metadata), writer) return row_encoder elif types.is_list(data_type): array_encoder = ArrayWriter.create(data_type, writer) diff --git a/python/pyfury/tests/test_cross_language.py b/python/pyfury/tests/test_cross_language.py index c5a101889c..aeecb1bac0 100644 --- a/python/pyfury/tests/test_cross_language.py +++ b/python/pyfury/tests/test_cross_language.py @@ -441,7 +441,7 @@ class ComplexObject2: def test_serialize_simple_struct_local(): fury = pyfury.Fury(language=pyfury.Language.XLANG, ref_tracking=True) - fury.register_class_tag(ComplexObject2, "test.ComplexObject2") + fury.register_class(ComplexObject2, type_tag="test.ComplexObject2") obj = ComplexObject2(f1=True, f2={-1: 2}) new_buf = fury.serialize(obj) assert fury.deserialize(new_buf) == obj @@ -450,7 +450,7 @@ def test_serialize_simple_struct_local(): @cross_language_test def test_serialize_simple_struct(data_file_path): fury = pyfury.Fury(language=pyfury.Language.XLANG, ref_tracking=True) - fury.register_class_tag(ComplexObject2, "test.ComplexObject2") + fury.register_class(ComplexObject2, type_tag="test.ComplexObject2") obj = ComplexObject2(f1=True, f2={-1: 2}) struct_round_back(data_file_path, fury, obj) @@ -458,8 +458,8 @@ def test_serialize_simple_struct(data_file_path): @cross_language_test def test_serialize_complex_struct(data_file_path): fury = pyfury.Fury(language=pyfury.Language.XLANG, ref_tracking=True) - fury.register_class_tag(ComplexObject1, "test.ComplexObject1") - fury.register_class_tag(ComplexObject2, "test.ComplexObject2") + fury.register_class(ComplexObject1, type_tag="test.ComplexObject1") + fury.register_class(ComplexObject2, type_tag="test.ComplexObject2") obj2 = ComplexObject2(f1=True, f2={-1: 2}) obj1 = ComplexObject1( @@ -501,7 +501,7 @@ def test_serialize_opaque_object(data_file_path): data_bytes = f.read() debug_print(f"len {len(data_bytes)}") fury = pyfury.Fury(language=pyfury.Language.XLANG, ref_tracking=True) - fury.register_class_tag(ComplexObject1, "test.ComplexObject1") + fury.register_class(ComplexObject1, type_tag="test.ComplexObject1") new_obj = fury.deserialize(data_bytes) debug_print(new_obj) assert new_obj.f2 == "abc" diff --git a/python/pyfury/tests/test_struct.py b/python/pyfury/tests/test_struct.py index f37ea61510..5a71dc2ab9 100644 --- a/python/pyfury/tests/test_struct.py +++ b/python/pyfury/tests/test_struct.py @@ -48,8 +48,8 @@ class ComplexObject: def test_struct(): fury = Fury(language=Language.XLANG, ref_tracking=True) - fury.register_class_tag(SimpleObject, "example.SimpleObject") - fury.register_class_tag(ComplexObject, "example.ComplexObject") + fury.register_class(SimpleObject, type_tag="example.SimpleObject") + fury.register_class(ComplexObject, type_tag="example.ComplexObject") o = SimpleObject(f1={1: 1.0 / 3}) # assert ser_de(fury, o) == o