From 210ea0818df246458aca2a852a663329069b55c4 Mon Sep 17 00:00:00 2001
From: Zhihong Zhang
Date: Tue, 14 May 2024 11:34:58 -0400
Subject: [PATCH] FOBS enhancements: 1) Added auto-registration of data class decomposers. 2). No need to register decomposers on deserializing

---
 nvflare/fuel/utils/fobs/fobs.py              | 168 ++++++++++++-------
 tests/unit_test/fuel/utils/fobs/fobs_test.py |  31 +++-
 2 files changed, 136 insertions(+), 63 deletions(-)

diff --git a/nvflare/fuel/utils/fobs/fobs.py b/nvflare/fuel/utils/fobs/fobs.py
index 57526c5ecf..b0bc15187c 100644
--- a/nvflare/fuel/utils/fobs/fobs.py
+++ b/nvflare/fuel/utils/fobs/fobs.py
@@ -17,6 +17,7 @@
 import os
 from enum import Enum
 from os.path import dirname, join
+from types import NoneType  # NOTE(review): types.NoneType needs Python >= 3.10; confirm minimum supported version
 from typing import Any, BinaryIO, Dict, Type, TypeVar, Union
 
 import msgpack
@@ -42,63 +43,29 @@
 
 FOBS_TYPE = "__fobs_type__"
 FOBS_DATA = "__fobs_data__"
+FOBS_DECOMPOSER = "__fobs_dc__"
+
 MAX_CONTENT_LEN = 128
-MSGPACK_TYPES = (None, bool, int, float, str, bytes, bytearray, memoryview, list, dict)
+MSGPACK_TYPES = (NoneType, bool, int, float, str, bytes, bytearray, memoryview, list, dict)
 T = TypeVar("T")
 
 log = logging.getLogger(__name__)
 _decomposers: Dict[str, Decomposer] = {}
 _decomposers_registered = False
-_enum_auto_register = True
-
-
-class Packer:
-    def __init__(self, manager: DatumManager):
-        self.manager = manager
-
-    def pack(self, obj: Any) -> dict:
-
-        if type(obj) in MSGPACK_TYPES:
-            return obj
-
-        type_name = _get_type_name(obj.__class__)
-        if type_name not in _decomposers:
-            if _enum_auto_register and isinstance(obj, Enum):
-                register_enum_types(type(obj))
-            else:
-                return obj
-
-        decomposed = _decomposers[type_name].decompose(obj, self.manager)
-        if self.manager:
-            decomposed = self.manager.externalize(decomposed)
-
-        return {FOBS_TYPE: type_name, FOBS_DATA: decomposed}
-
-    def unpack(self, obj: Any) -> Any:
-
-        if type(obj) is not dict or FOBS_TYPE not in obj:
-            return obj
+# If this is enabled, FOBS will try to register generic decomposers automatically
+_enum_auto_registration = True
+_data_auto_registration = True
 
-        type_name = obj[FOBS_TYPE]
-        if type_name not in _decomposers:
-            error = True
-            if _enum_auto_register:
-                cls = self._load_class(type_name)
-                if issubclass(cls, Enum):
-                    register_enum_types(cls)
-                    error = False
-            if error:
-                raise TypeError(f"Unknown type {type_name}, caused by mismatching decomposers")
 
-        data = obj[FOBS_DATA]
-        if self.manager:
-            data = self.manager.internalize(data)
+def _get_type_name(cls: Type) -> str:
+    module = cls.__module__
+    if module == "builtins":
+        return cls.__qualname__
+    return module + "." + cls.__qualname__
 
-        decomposer = _decomposers[type_name]
-        return decomposer.recompose(data, self.manager)
 
-    @staticmethod
-    def _load_class(type_name: str):
+def _load_class(type_name: str):
+    try:
         parts = type_name.split(".")
         if len(parts) == 1:
             parts = ["builtins", type_name]
@@ -108,13 +75,8 @@ def _load_class(type_name: str):
             mod = getattr(mod, comp)
 
         return mod
-
-
-def _get_type_name(cls: Type) -> str:
-    module = cls.__module__
-    if module == "builtins":
-        return cls.__qualname__
-    return module + "." + cls.__qualname__
+    except Exception as ex:
+        raise TypeError(f"Can't load class {type_name}: {ex}")
 
 
 def register(decomposer: Union[Decomposer, Type[Decomposer]]) -> None:
@@ -142,6 +104,80 @@ def register(decomposer: Union[Decomposer, Type[Decomposer]]) -> None:
     _decomposers[name] = instance
 
 
+class Packer:
+    def __init__(self, manager: DatumManager):
+        self.manager = manager
+        self.enum_decomposer_name = _get_type_name(EnumTypeDecomposer)
+        self.data_decomposer_name = _get_type_name(DataClassDecomposer)
+
+    def pack(self, obj: Any) -> dict:
+
+        if type(obj) in MSGPACK_TYPES:
+            return obj
+
+        type_name = _get_type_name(obj.__class__)
+        if type_name not in _decomposers:
+            registered = False
+            if isinstance(obj, Enum):
+                if _enum_auto_registration:
+                    register_enum_types(type(obj))
+                    registered = True
+            else:
+                if callable(obj) or (not hasattr(obj, "__dict__")):
+                    raise TypeError(f"{type(obj)} can't be serialized by FOBS without a decomposer")
+                if _data_auto_registration:
+                    register_data_classes(type(obj))
+                    registered = True
+
+            if not registered:
+                return obj
+
+        decomposer = _decomposers[type_name]
+
+        decomposed = decomposer.decompose(obj, self.manager)
+        if self.manager:
+            decomposed = self.manager.externalize(decomposed)
+
+        return {FOBS_TYPE: type_name, FOBS_DATA: decomposed, FOBS_DECOMPOSER: _get_type_name(type(decomposer))}
+
+    def unpack(self, obj: Any) -> Any:
+
+        if type(obj) is not dict or FOBS_TYPE not in obj:
+            return obj
+
+        type_name = obj[FOBS_TYPE]
+        if type_name not in _decomposers:
+            registered = False
+            decomposer_name = obj.get(FOBS_DECOMPOSER)
+            cls = _load_class(type_name)
+            if not decomposer_name:
+                # Maintaining backward compatibility with auto enum registration
+                if _enum_auto_registration:
+                    if issubclass(cls, Enum):
+                        register_enum_types(cls)
+                        registered = True
+            else:
+                decomposer_class = _load_class(decomposer_name)  # NOTE(review): name comes from payload; confirm trusted
+                if decomposer_name == self.enum_decomposer_name or decomposer_name == self.data_decomposer_name:
+                    # Generic decomposer's __init__ takes the target class as argument
+                    decomposer = decomposer_class(cls)
+                else:
+                    decomposer = decomposer_class()
+
+                register(decomposer)
+                registered = True
+
+            if not registered:
+                raise TypeError(f"Type {type_name} has no decomposer registered")
+
+        data = obj[FOBS_DATA]
+        if self.manager:
+            data = self.manager.internalize(data)
+
+        decomposer = _decomposers[type_name]
+        return decomposer.recompose(data, self.manager)
+
+
 def register_data_classes(*data_classes: Type[T]) -> None:
     """Register generic decomposers for data classes
 
@@ -169,14 +205,25 @@ def register_enum_types(*enum_types: Type[Enum]) -> None:
 
 
 def auto_register_enum_types(enabled=True) -> None:
-    """Enable or disable auto registering of enum classes
+    """Enable or disable auto registering of enum types
 
     Args:
         enabled: Auto-registering of enum classes is enabled if True
     """
-    global _enum_auto_register
+    global _enum_auto_registration
+
+    _enum_auto_registration = enabled
+
+
+def auto_register_data_classes(enabled=True) -> None:
+    """Enable or disable auto registering of data classes
+
+    Args:
+        enabled: Auto-registering of data classes is enabled if True
+    """
+    global _data_auto_registration
 
-    _enum_auto_register = enabled
+    _data_auto_registration = enabled
 
 
 def register_folder(folder: str, package: str):
@@ -281,7 +328,6 @@ def deserialize_stream(stream: BinaryIO, manager: DatumManager = None, **kwargs)
 
 def reset():
     """Reset FOBS to initial state. Used for unit test"""
-    # global _decomposers, _decomposers_registered
-    # _decomposers.clear()
-    # _decomposers_registered = False
-    pass
+    global _decomposers, _decomposers_registered
+    _decomposers.clear()
+    _decomposers_registered = False
diff --git a/tests/unit_test/fuel/utils/fobs/fobs_test.py b/tests/unit_test/fuel/utils/fobs/fobs_test.py
index 976e0e0d11..360fe415b4 100644
--- a/tests/unit_test/fuel/utils/fobs/fobs_test.py
+++ b/tests/unit_test/fuel/utils/fobs/fobs_test.py
@@ -27,6 +27,7 @@ class TestFobs:
 
     NUMBER = 123456
     FLOAT = 123.456
+    NAME = "FOBS Test"
     NOW = datetime.now()
 
     test_data = {
@@ -51,9 +52,13 @@ def test_aliases(self):
 
     def test_unsupported_classes(self):
         with pytest.raises(TypeError):
-            # Queue is just a random built-in class not supported by FOBS
+            # Queue contains collections.deque, which has no __dict__, can't be handled as Data Class
             unsupported_class = queue.Queue()
-            fobs.dumps(unsupported_class)
+            try:
+                fobs.dumps(unsupported_class)
+            except Exception as ex:
+                print(ex)
+                raise ex
 
     def test_decomposers(self):
         test_class = ExampleClass(TestFobs.NUMBER)
@@ -62,6 +67,23 @@ def test_decomposers(self):
         new_class = fobs.loads(buf)
         assert new_class.number == TestFobs.NUMBER
 
+    def test_no_registration(self):
+        test_class = ExampleClass(TestFobs.NUMBER)
+        fobs.register(ExampleClassDecomposer)
+        buf = fobs.dumps(test_class)
+        # Clear registration before deserializing
+        fobs.reset()
+        new_class = fobs.loads(buf)
+        assert new_class.number == TestFobs.NUMBER
+
+    def test_auto_registration(self):
+        fobs.reset()
+        test_class = ExampleDataClass(TestFobs.NAME)
+        # No decomposer is registered for ExampleDataClass
+        buf = fobs.dumps(test_class)
+        new_class = fobs.loads(buf)
+        assert new_class.name == TestFobs.NAME
+
     def test_buffer_list(self):
         buf = fobs.dumps(TestFobs.test_data, buffer_list=True)
         data = fobs.loads(buf)
@@ -73,6 +95,11 @@ def __init__(self, number):
         self.number = number
 
 
+class ExampleDataClass:
+    def __init__(self, name):
+        self.name = name
+
+
 class ExampleClassDecomposer(Decomposer):
     def supported_type(self):
         return ExampleClass