Skip to content

Commit

Permalink
FOBS enhancements: 1) Added auto-registration of data class decompose…
Browse files Browse the repository at this point in the history
…rs. 2). No need to register decomposers on deserializing
  • Loading branch information
nvidianz committed May 22, 2024
1 parent 98289c4 commit 210ea08
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 63 deletions.
168 changes: 107 additions & 61 deletions nvflare/fuel/utils/fobs/fobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os
from enum import Enum
from os.path import dirname, join
from types import NoneType
from typing import Any, BinaryIO, Dict, Type, TypeVar, Union

import msgpack
Expand All @@ -42,63 +43,29 @@

FOBS_TYPE = "__fobs_type__"
FOBS_DATA = "__fobs_data__"
FOBS_DECOMPOSER = "__fobs_dc__"

MAX_CONTENT_LEN = 128
MSGPACK_TYPES = (None, bool, int, float, str, bytes, bytearray, memoryview, list, dict)
MSGPACK_TYPES = (NoneType, bool, int, float, str, bytes, bytearray, memoryview, list, dict)
T = TypeVar("T")

log = logging.getLogger(__name__)
_decomposers: Dict[str, Decomposer] = {}
_decomposers_registered = False
_enum_auto_register = True


class Packer:
def __init__(self, manager: DatumManager):
self.manager = manager

def pack(self, obj: Any) -> dict:

if type(obj) in MSGPACK_TYPES:
return obj

type_name = _get_type_name(obj.__class__)
if type_name not in _decomposers:
if _enum_auto_register and isinstance(obj, Enum):
register_enum_types(type(obj))
else:
return obj

decomposed = _decomposers[type_name].decompose(obj, self.manager)
if self.manager:
decomposed = self.manager.externalize(decomposed)

return {FOBS_TYPE: type_name, FOBS_DATA: decomposed}

def unpack(self, obj: Any) -> Any:

if type(obj) is not dict or FOBS_TYPE not in obj:
return obj
# If this is enabled, FOBS will try to register generic decomposers automatically
_enum_auto_registration = True
_data_auto_registration = True

type_name = obj[FOBS_TYPE]
if type_name not in _decomposers:
error = True
if _enum_auto_register:
cls = self._load_class(type_name)
if issubclass(cls, Enum):
register_enum_types(cls)
error = False
if error:
raise TypeError(f"Unknown type {type_name}, caused by mismatching decomposers")

data = obj[FOBS_DATA]
if self.manager:
data = self.manager.internalize(data)
def _get_type_name(cls: Type) -> str:
module = cls.__module__
if module == "builtins":
return cls.__qualname__
return module + "." + cls.__qualname__

decomposer = _decomposers[type_name]
return decomposer.recompose(data, self.manager)

@staticmethod
def _load_class(type_name: str):
def _load_class(type_name: str):
try:
parts = type_name.split(".")
if len(parts) == 1:
parts = ["builtins", type_name]
Expand All @@ -108,13 +75,8 @@ def _load_class(type_name: str):
mod = getattr(mod, comp)

return mod


def _get_type_name(cls: Type) -> str:
module = cls.__module__
if module == "builtins":
return cls.__qualname__
return module + "." + cls.__qualname__
except Exception as ex:
raise TypeError(f"Can't load class {type_name}: {ex}")


def register(decomposer: Union[Decomposer, Type[Decomposer]]) -> None:
Expand Down Expand Up @@ -142,6 +104,80 @@ def register(decomposer: Union[Decomposer, Type[Decomposer]]) -> None:
_decomposers[name] = instance


class Packer:
def __init__(self, manager: DatumManager):
self.manager = manager
self.enum_decomposer_name = _get_type_name(EnumTypeDecomposer)
self.data_decomposer_name = _get_type_name(DataClassDecomposer)

def pack(self, obj: Any) -> dict:

if type(obj) in MSGPACK_TYPES:
return obj

type_name = _get_type_name(obj.__class__)
if type_name not in _decomposers:
registered = False
if isinstance(obj, Enum):
if _enum_auto_registration:
register_enum_types(type(obj))
registered = True
else:
if callable(obj) or (not hasattr(obj, "__dict__")):
raise TypeError(f"{type(obj)} can't be serialized by FOBS without a decomposer")
if _data_auto_registration:
register_data_classes(type(obj))
registered = True

if not registered:
return obj

decomposer = _decomposers[type_name]

decomposed = decomposer.decompose(obj, self.manager)
if self.manager:
decomposed = self.manager.externalize(decomposed)

return {FOBS_TYPE: type_name, FOBS_DATA: decomposed, FOBS_DECOMPOSER: _get_type_name(type(decomposer))}

def unpack(self, obj: Any) -> Any:

if type(obj) is not dict or FOBS_TYPE not in obj:
return obj

type_name = obj[FOBS_TYPE]
if type_name not in _decomposers:
registered = False
decomposer_name = obj.get(FOBS_DECOMPOSER)
cls = _load_class(type_name)
if not decomposer_name:
# Maintaining backward compatibility with auto enum registration
if _enum_auto_registration:
if issubclass(cls, Enum):
register_enum_types(cls)
registered = True
else:
decomposer_class = _load_class(decomposer_name)
if decomposer_name == self.enum_decomposer_name or decomposer_name == self.data_decomposer_name:
# Generic decomposer's __init__ takes the target class as argument
decomposer = decomposer_class(cls)
else:
decomposer = decomposer_class()

register(decomposer)
registered = True

if not registered:
raise TypeError(f"Type {type_name} has no decomposer registered")

data = obj[FOBS_DATA]
if self.manager:
data = self.manager.internalize(data)

decomposer = _decomposers[type_name]
return decomposer.recompose(data, self.manager)


def register_data_classes(*data_classes: Type[T]) -> None:
"""Register generic decomposers for data classes
Expand Down Expand Up @@ -169,14 +205,25 @@ def register_enum_types(*enum_types: Type[Enum]) -> None:


def auto_register_enum_types(enabled=True) -> None:
"""Enable or disable auto registering of enum classes
"""Enable or disable auto registering of enum types
Args:
enabled: Auto-registering of enum classes is enabled if True
"""
global _enum_auto_register
global _enum_auto_registration

_enum_auto_registration = enabled


def auto_register_data_classes(enabled=True) -> None:
"""Enable or disable auto registering of data classes
Args:
enabled: Auto-registering of data classes is enabled if True
"""
global _data_auto_registration

_enum_auto_register = enabled
_enum_data_registration = enabled


def register_folder(folder: str, package: str):
Expand Down Expand Up @@ -281,7 +328,6 @@ def deserialize_stream(stream: BinaryIO, manager: DatumManager = None, **kwargs)

def reset():
"""Reset FOBS to initial state. Used for unit test"""
# global _decomposers, _decomposers_registered
# _decomposers.clear()
# _decomposers_registered = False
pass
global _decomposers, _decomposers_registered
_decomposers.clear()
_decomposers_registered = False
31 changes: 29 additions & 2 deletions tests/unit_test/fuel/utils/fobs/fobs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class TestFobs:

NUMBER = 123456
FLOAT = 123.456
NAME = "FOBS Test"
NOW = datetime.now()

test_data = {
Expand All @@ -51,9 +52,13 @@ def test_aliases(self):

def test_unsupported_classes(self):
with pytest.raises(TypeError):
# Queue is just a random built-in class not supported by FOBS
# Queue contains collections.deque, which has no __dict__, can't be handled as Data Class
unsupported_class = queue.Queue()
fobs.dumps(unsupported_class)
try:
fobs.dumps(unsupported_class)
except Exception as ex:
print(ex)
raise ex

def test_decomposers(self):
test_class = ExampleClass(TestFobs.NUMBER)
Expand All @@ -62,6 +67,23 @@ def test_decomposers(self):
new_class = fobs.loads(buf)
assert new_class.number == TestFobs.NUMBER

def test_no_registration(self):
test_class = ExampleClass(TestFobs.NUMBER)
fobs.register(ExampleClassDecomposer)
buf = fobs.dumps(test_class)
# Clear registration before deserializing
fobs.reset()
new_class = fobs.loads(buf)
assert new_class.number == TestFobs.NUMBER

def test_auto_registration(self):
fobs.reset()
test_class = ExampleDataClass(TestFobs.NAME)
# No decomposer is registered for ExampleDataClass
buf = fobs.dumps(test_class)
new_class = fobs.loads(buf)
assert new_class.name == TestFobs.NAME

def test_buffer_list(self):
buf = fobs.dumps(TestFobs.test_data, buffer_list=True)
data = fobs.loads(buf)
Expand All @@ -73,6 +95,11 @@ def __init__(self, number):
self.number = number


class ExampleDataClass:
def __init__(self, name):
self.name = name


class ExampleClassDecomposer(Decomposer):
def supported_type(self):
return ExampleClass
Expand Down

0 comments on commit 210ea08

Please sign in to comment.