Skip to content

Commit

Permalink
[Python] Refine py register class method (#1218)
Browse files Browse the repository at this point in the history
* merge register class id with class_tag

* fix create schema from struct with meta

* lint code

* fix register_class in doc
  • Loading branch information
chaokunyang authored Dec 8, 2023
1 parent fa7c7a1 commit bd7e234
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 52 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ class SomeClass:
f3: Dict[str, str]

fury = pyfury.Fury(ref_tracking=True)
fury.register_class(SomeClass, "example.SomeClass")
fury.register_class(SomeClass, type_tag="example.SomeClass")
obj = SomeClass()
obj.f2 = {"k1": "v1", "k2": "v2"}
obj.f1, obj.f3 = obj, obj.f2
Expand Down
6 changes: 3 additions & 3 deletions docs/guide/xlang_object_graph_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,8 @@ class SomeClass2:

if __name__ == "__main__":
f = pyfury.Fury()
f.register_class(SomeClass1, "example.SomeClass1")
f.register_class(SomeClass2, "example.SomeClass2")
f.register_class(SomeClass1, type_tag="example.SomeClass1")
f.register_class(SomeClass2, type_tag="example.SomeClass2")
obj1 = SomeClass1(f1=True, f2={-1: 2})
obj = SomeClass2(
f1=obj1,
Expand Down Expand Up @@ -443,7 +443,7 @@ class SomeClass:
f3: Dict[str, str]

fury = pyfury.Fury(ref_tracking=True)
fury.register_class(SomeClass, "example.SomeClass")
fury.register_class(SomeClass, type_tag="example.SomeClass")
obj = SomeClass()
obj.f2 = {"k1": "v1", "k2": "v2"}
obj.f1, obj.f3 = obj, obj.f2
Expand Down
45 changes: 25 additions & 20 deletions python/pyfury/_fury.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,15 @@ def __init__(self, fury):
self._dynamic_written_enum_string = []

def initialize(self):
self.register_class(int, PYINT_CLASS_ID)
self.register_class(float, PYFLOAT_CLASS_ID)
self.register_class(bool, PYBOOL_CLASS_ID)
self.register_class(str, STRING_CLASS_ID)
self.register_class(_PickleStub, PICKLE_CLASS_ID)
self.register_class(PickleStrongCacheStub, PICKLE_STRONG_CACHE_CLASS_ID)
self.register_class(PickleCacheStub, PICKLE_CACHE_CLASS_ID)
self.register_class(int, class_id=PYINT_CLASS_ID)
self.register_class(float, class_id=PYFLOAT_CLASS_ID)
self.register_class(bool, class_id=PYBOOL_CLASS_ID)
self.register_class(str, class_id=STRING_CLASS_ID)
self.register_class(_PickleStub, class_id=PICKLE_CLASS_ID)
self.register_class(
PickleStrongCacheStub, class_id=PICKLE_STRONG_CACHE_CLASS_ID
)
self.register_class(PickleCacheStub, class_id=PICKLE_CACHE_CLASS_ID)
self._add_default_serializers()

# `Union[type, TypeVar]` is not supported in py3.6
Expand All @@ -212,7 +214,20 @@ def register_serializer(self, cls, serializer):
self._classes_info[cls].serializer = serializer

# `Union[type, TypeVar]` is not supported in py3.6
def register_class(self, cls, class_id: int = None):
def register_class(self, cls, *, class_id: int = None, type_tag: str = None):
"""Register class with given type id or tag, if tag is not None, it will be used for
cross-language serialization."""
if type_tag is not None:
assert class_id is None, (
f"Type tag {type_tag} has been set already, "
f"set class id at the same time is not allowed."
)
from pyfury._struct import ComplexObjectSerializer

self.register_serializer(
cls, ComplexObjectSerializer(self.fury, cls, type_tag)
)
return
classinfo = self._classes_info.get(cls)
if classinfo is None:
if isinstance(cls, TypeVar):
Expand Down Expand Up @@ -261,13 +276,6 @@ def _next_class_id(self):
class_id = self._class_id_counter = self._class_id_counter + 1
return class_id

def register_class_tag(self, cls: type, type_tag: str = None):
"""Register class with given type tag which will be used for cross-language
serialization."""
from pyfury._struct import ComplexObjectSerializer

self.register_serializer(cls, ComplexObjectSerializer(self.fury, cls, type_tag))

def _add_serializer(self, cls: type, serializer=None, serializer_cls=None):
if serializer_cls:
serializer = serializer_cls(self.fury, cls)
Expand Down Expand Up @@ -653,11 +661,8 @@ def register_serializer(self, cls: type, serializer):
self.class_resolver.register_serializer(cls, serializer)

# `Union[type, TypeVar]` is not supported in py3.6
def register_class(self, cls, class_id: int = None):
self.class_resolver.register_class(cls, class_id=class_id)

def register_class_tag(self, cls: type, type_tag: str = None):
self.class_resolver.register_class_tag(cls, type_tag)
def register_class(self, cls, *, class_id: int = None, type_tag: str = None):
self.class_resolver.register_class(cls, class_id=class_id, type_tag=type_tag)

def serialize(
self,
Expand Down
40 changes: 20 additions & 20 deletions python/pyfury/_serialization.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -326,13 +326,13 @@ cdef class ClassResolver:
self._enum_str_set = set()

def initialize(self):
self.register_class(int, PYINT_CLASS_ID)
self.register_class(float, PYFLOAT_CLASS_ID)
self.register_class(bool, PYBOOL_CLASS_ID)
self.register_class(str, STRING_CLASS_ID)
self.register_class(_PickleStub, PICKLE_CLASS_ID)
self.register_class(PickleStrongCacheStub, PICKLE_STRONG_CACHE_CLASS_ID)
self.register_class(PickleCacheStub, PICKLE_CACHE_CLASS_ID)
self.register_class(int, class_id=PYINT_CLASS_ID)
self.register_class(float, class_id=PYFLOAT_CLASS_ID)
self.register_class(bool, class_id=PYBOOL_CLASS_ID)
self.register_class(str, class_id=STRING_CLASS_ID)
self.register_class(_PickleStub, class_id=PICKLE_CLASS_ID)
self.register_class(PickleStrongCacheStub, class_id=PICKLE_STRONG_CACHE_CLASS_ID)
self.register_class(PickleCacheStub, class_id=PICKLE_CACHE_CLASS_ID)
self._add_default_serializers()

def register_serializer(self, cls: Union[type, TypeVar], serializer):
Expand All @@ -344,7 +344,16 @@ cdef class ClassResolver:
self.register_class(cls)
self._classes_info[cls].serializer = serializer

def register_class(self, cls: Union[type, TypeVar], class_id: int = None):
def register_class(self, cls: Union[type, TypeVar], *, class_id: int = None, type_tag: str = None):
"""Register class with given type id or tag, if tag is not None, it will be used for
cross-language serialization."""
if type_tag is not None:
assert class_id is None, (f"Type tag {type_tag} has been set already, "
f"set class id at the same time is not allowed.")
self.register_serializer(
cls, ComplexObjectSerializer(self.fury, cls, type_tag)
)
return
classinfo = self._classes_info.get(cls)
if classinfo is None:
if isinstance(cls, TypeVar):
Expand Down Expand Up @@ -397,13 +406,6 @@ cdef class ClassResolver:
class_id = self._class_id_counter = self._class_id_counter + 1
return class_id

def register_class_tag(self, cls: Union[type, TypeVar], type_tag: str = None):
"""Register class with given type tag which will be used for cross-language
serialization."""
self.register_serializer(
cls, ComplexObjectSerializer(self.fury, cls, type_tag)
)

def _add_serializer(
self,
cls: Union[type, TypeVar],
Expand Down Expand Up @@ -846,11 +848,9 @@ cdef class Fury:
def register_serializer(self, cls: Union[type, TypeVar], Serializer serializer):
self.class_resolver.register_serializer(cls, serializer)

def register_class(self, cls: Union[type, TypeVar], class_id: int = None):
self.class_resolver.register_class(cls, class_id=class_id)

def register_class_tag(self, cls: Union[type, TypeVar], type_tag: str = None):
self.class_resolver.register_class_tag(cls, type_tag)
def register_class(self, cls: Union[type, TypeVar], *,
class_id: int = None, type_tag: str = None):
self.class_resolver.register_class(cls, class_id=class_id, type_tag=type_tag)

def serialize(
self, obj,
Expand Down
2 changes: 1 addition & 1 deletion python/pyfury/format/encoder.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ cdef create_converter(Field field, CWriter* writer):
return create_atomic_encoder(StrWriter, writer)
elif types.is_struct(data_type):
row_encoder = RowEncoder.create(pa.schema(
data_type, metadata=field.metadata), writer)
list(data_type), metadata=field.metadata), writer)
return row_encoder
elif types.is_list(data_type):
array_encoder = ArrayWriter.create(data_type, writer)
Expand Down
10 changes: 5 additions & 5 deletions python/pyfury/tests/test_cross_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ class ComplexObject2:

def test_serialize_simple_struct_local():
fury = pyfury.Fury(language=pyfury.Language.XLANG, ref_tracking=True)
fury.register_class_tag(ComplexObject2, "test.ComplexObject2")
fury.register_class(ComplexObject2, type_tag="test.ComplexObject2")
obj = ComplexObject2(f1=True, f2={-1: 2})
new_buf = fury.serialize(obj)
assert fury.deserialize(new_buf) == obj
Expand All @@ -450,16 +450,16 @@ def test_serialize_simple_struct_local():
@cross_language_test
def test_serialize_simple_struct(data_file_path):
fury = pyfury.Fury(language=pyfury.Language.XLANG, ref_tracking=True)
fury.register_class_tag(ComplexObject2, "test.ComplexObject2")
fury.register_class(ComplexObject2, type_tag="test.ComplexObject2")
obj = ComplexObject2(f1=True, f2={-1: 2})
struct_round_back(data_file_path, fury, obj)


@cross_language_test
def test_serialize_complex_struct(data_file_path):
fury = pyfury.Fury(language=pyfury.Language.XLANG, ref_tracking=True)
fury.register_class_tag(ComplexObject1, "test.ComplexObject1")
fury.register_class_tag(ComplexObject2, "test.ComplexObject2")
fury.register_class(ComplexObject1, type_tag="test.ComplexObject1")
fury.register_class(ComplexObject2, type_tag="test.ComplexObject2")

obj2 = ComplexObject2(f1=True, f2={-1: 2})
obj1 = ComplexObject1(
Expand Down Expand Up @@ -501,7 +501,7 @@ def test_serialize_opaque_object(data_file_path):
data_bytes = f.read()
debug_print(f"len {len(data_bytes)}")
fury = pyfury.Fury(language=pyfury.Language.XLANG, ref_tracking=True)
fury.register_class_tag(ComplexObject1, "test.ComplexObject1")
fury.register_class(ComplexObject1, type_tag="test.ComplexObject1")
new_obj = fury.deserialize(data_bytes)
debug_print(new_obj)
assert new_obj.f2 == "abc"
Expand Down
4 changes: 2 additions & 2 deletions python/pyfury/tests/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ class ComplexObject:

def test_struct():
fury = Fury(language=Language.XLANG, ref_tracking=True)
fury.register_class_tag(SimpleObject, "example.SimpleObject")
fury.register_class_tag(ComplexObject, "example.ComplexObject")
fury.register_class(SimpleObject, type_tag="example.SimpleObject")
fury.register_class(ComplexObject, type_tag="example.ComplexObject")
o = SimpleObject(f1={1: 1.0 / 3})
# assert ser_de(fury, o) == o

Expand Down

0 comments on commit bd7e234

Please sign in to comment.