Skip to content

Commit

Permalink
Feature: Allow serialization of custom networks (#284)
Browse files Browse the repository at this point in the history
* feat: allow serialization of custom networks

This commit adds utility functions and extends existing networks to
enable serialization of complete networks when custom network types are
passed as arguments (e.g., for sub-networks in coupling flows).

The main complications were:

* Objects of type `type` (uninstantiated classes) cannot be serialized
  using `keras.saving.serialize_keras_object`, as they have no
  `get_config` function.

* We want to support both strings and types as parameters, leading to
  the need to distinguish those during manual
  serialization/deserialization.

* Auto-discovery of __init__ parameters is only active when `get_config`
  is not overridden, so the configuration has to be stored manually
  for serialization.

For storing the types, we use the name returned by
`keras.saving.get_registered_name`, from which the type can be
reconstructed at deserialization using `keras.saving.get_registered_object`.

Handling the different cases is moved to the utility functions
`(de)serialize_val_or_type`, which use a naming scheme to determine
which deserialization method to use.

The same setup can be extended to other custom types, e.g.
distributions.

* rename (de)serialize_val_or_type to (de)serialize_value_or_type
  • Loading branch information
vpratz authored Dec 20, 2024
1 parent 8a870c0 commit 2068b5e
Show file tree
Hide file tree
Showing 11 changed files with 290 additions and 17 deletions.
23 changes: 22 additions & 1 deletion bayesflow/networks/consistency_models/consistency_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np

from bayesflow.types import Tensor
from bayesflow.utils import find_network, keras_kwargs
from bayesflow.utils import find_network, keras_kwargs, serialize_value_or_type, deserialize_value_or_type


from ..inference_network import InferenceNetwork
Expand Down Expand Up @@ -88,6 +88,27 @@ def __init__(

self.seed_generator = keras.random.SeedGenerator()

# serialization: store all parameters necessary to call __init__
self.config = {
"total_steps": total_steps,
"max_time": max_time,
"sigma2": sigma2,
"eps": eps,
"s0": s0,
"s1": s1,
**kwargs,
}
self.config = serialize_value_or_type(self.config, "subnet", subnet)

def get_config(self):
    """Return the parent configuration merged with the stored constructor arguments.

    Entries of ``self.config`` take precedence over equally named entries
    of the parent configuration.
    """
    return super().get_config() | self.config

@classmethod
def from_config(cls, config):
    """Create an instance from ``config`` after restoring its ``subnet`` entry."""
    restored = deserialize_value_or_type(config, "subnet")
    return cls(**restored)

def _schedule_discretization(self, step) -> float:
"""Schedule function for adjusting the discretization level `N` during
the course of training.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,16 @@
import numpy as np

from bayesflow.types import Tensor
from bayesflow.utils import jvp, concatenate, find_network, keras_kwargs, expand_right_as, expand_right_to
from bayesflow.utils import (
jvp,
concatenate,
find_network,
keras_kwargs,
expand_right_as,
expand_right_to,
serialize_value_or_type,
deserialize_value_or_type,
)


from ..inference_network import InferenceNetwork
Expand Down Expand Up @@ -62,6 +71,22 @@ def __init__(

self.seed_generator = keras.random.SeedGenerator()

# serialization: store all parameters necessary to call __init__
self.config = {
"sigma_data": sigma_data,
**kwargs,
}
self.config = serialize_value_or_type(self.config, "subnet", subnet)

def get_config(self):
    """Return the parent configuration merged with the stored constructor arguments.

    Entries of ``self.config`` take precedence over equally named entries
    of the parent configuration.
    """
    return super().get_config() | self.config

@classmethod
def from_config(cls, config):
    """Create an instance from ``config`` after restoring its ``subnet`` entry."""
    restored = deserialize_value_or_type(config, "subnet")
    return cls(**restored)

def _discretize_time(self, num_steps: int, rho: float = 3.5, **kwargs):
t = np.linspace(0.0, np.pi / 2, num_steps)
times = np.exp((t - np.pi / 2) * rho) * np.pi / 2
Expand Down
22 changes: 21 additions & 1 deletion bayesflow/networks/coupling_flow/coupling_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import find_permutation, keras_kwargs
from bayesflow.utils import find_permutation, keras_kwargs, serialize_value_or_type, deserialize_value_or_type

from .actnorm import ActNorm
from .couplings import DualCoupling
Expand Down Expand Up @@ -58,13 +58,33 @@ def __init__(

self.invertible_layers.append(DualCoupling(subnet, transform, **kwargs.get("coupling_kwargs", {})))

# serialization: store all parameters necessary to call __init__
self.config = {
"depth": depth,
"transform": transform,
"permutation": permutation,
"use_actnorm": use_actnorm,
"base_distribution": base_distribution,
**kwargs,
}
self.config = serialize_value_or_type(self.config, "subnet", subnet)

# noinspection PyMethodOverriding
def build(self, xz_shape, conditions_shape=None):
    """Build this network and every layer in ``self.invertible_layers``.

    The parent ``build`` receives only ``xz_shape``; each invertible layer
    is built with both ``xz_shape`` and ``conditions_shape``.
    """
    super().build(xz_shape)

    for invertible_layer in self.invertible_layers:
        invertible_layer.build(xz_shape=xz_shape, conditions_shape=conditions_shape)

def get_config(self):
    """Return the parent configuration merged with the stored constructor arguments.

    Entries of ``self.config`` take precedence over equally named entries
    of the parent configuration.
    """
    return super().get_config() | self.config

@classmethod
def from_config(cls, config):
    """Create an instance from ``config`` after restoring its ``subnet`` entry."""
    restored = deserialize_value_or_type(config, "subnet")
    return cls(**restored)

def _forward(
self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
) -> Tensor | tuple[Tensor, Tensor]:
Expand Down
18 changes: 17 additions & 1 deletion bayesflow/networks/coupling_flow/couplings/dual_coupling.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.utils import keras_kwargs
from bayesflow.utils import keras_kwargs, serialize_value_or_type, deserialize_value_or_type
from bayesflow.types import Tensor
from .single_coupling import SingleCoupling
from ..invertible_layer import InvertibleLayer
Expand All @@ -15,6 +15,22 @@ def __init__(self, subnet: str | type = "mlp", transform: str = "affine", **kwar
self.coupling2 = SingleCoupling(subnet, transform, **kwargs)
self.pivot = None

# serialization: store all parameters necessary to call __init__
self.config = {
"transform": transform,
**kwargs,
}
self.config = serialize_value_or_type(self.config, "subnet", subnet)

def get_config(self):
    """Return the parent configuration merged with the stored constructor arguments.

    Entries of ``self.config`` take precedence over equally named entries
    of the parent configuration.
    """
    return super().get_config() | self.config

@classmethod
def from_config(cls, config):
    """Create an instance from ``config`` after restoring its ``subnet`` entry."""
    restored = deserialize_value_or_type(config, "subnet")
    return cls(**restored)

# noinspection PyMethodOverriding
def build(self, xz_shape, conditions_shape=None):
self.pivot = xz_shape[-1] // 2
Expand Down
18 changes: 17 additions & 1 deletion bayesflow/networks/coupling_flow/couplings/single_coupling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import find_network, keras_kwargs
from bayesflow.utils import find_network, keras_kwargs, serialize_value_or_type, deserialize_value_or_type
from ..invertible_layer import InvertibleLayer
from ..transforms import find_transform

Expand All @@ -26,6 +26,22 @@ def __init__(self, subnet: str | type = "mlp", transform: str = "affine", **kwar
output_projector_kwargs.setdefault("kernel_initializer", "zeros")
self.output_projector = keras.layers.Dense(units=None, **output_projector_kwargs)

# serialization: store all parameters necessary to call __init__
self.config = {
"transform": transform,
**kwargs,
}
self.config = serialize_value_or_type(self.config, "subnet", subnet)

def get_config(self):
    """Return the parent configuration merged with the stored constructor arguments.

    Entries of ``self.config`` take precedence over equally named entries
    of the parent configuration.
    """
    return super().get_config() | self.config

@classmethod
def from_config(cls, config):
    """Create an instance from ``config`` after restoring its ``subnet`` entry."""
    restored = deserialize_value_or_type(config, "subnet")
    return cls(**restored)

# noinspection PyMethodOverriding
def build(self, x1_shape, x2_shape, conditions_shape=None):
self.output_projector.units = self.transform.params_per_dim * x2_shape[-1]
Expand Down
27 changes: 26 additions & 1 deletion bayesflow/networks/flow_matching/flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Shape, Tensor
from bayesflow.utils import expand_right_as, keras_kwargs, optimal_transport
from bayesflow.utils import (
expand_right_as,
keras_kwargs,
optimal_transport,
serialize_value_or_type,
deserialize_value_or_type,
)
from ..inference_network import InferenceNetwork
from .integrators import EulerIntegrator
from .integrators import RK2Integrator
Expand Down Expand Up @@ -52,10 +58,29 @@ def __init__(
case _:
raise NotImplementedError(f"No support for {integrator} integration")

# serialization: store all parameters necessary to call __init__
self.config = {
"base_distribution": base_distribution,
"integrator": integrator,
"use_optimal_transport": use_optimal_transport,
"optimal_transport_kwargs": optimal_transport_kwargs,
**kwargs,
}
self.config = serialize_value_or_type(self.config, "subnet", subnet)

def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
    """Build this network, then its integrator.

    The parent ``build`` receives only ``xz_shape``; ``self.integrator`` is
    built with both ``xz_shape`` and ``conditions_shape``.
    """
    super().build(xz_shape)

    integrator = self.integrator
    integrator.build(xz_shape, conditions_shape)

def get_config(self):
    """Return the parent configuration merged with the stored constructor arguments.

    Entries of ``self.config`` take precedence over equally named entries
    of the parent configuration.
    """
    return super().get_config() | self.config

@classmethod
def from_config(cls, config):
    """Create an instance from ``config`` after restoring its ``subnet`` entry."""
    restored = deserialize_value_or_type(config, "subnet")
    return cls(**restored)

def _forward(
self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
) -> Tensor | tuple[Tensor, Tensor]:
Expand Down
31 changes: 30 additions & 1 deletion bayesflow/networks/free_form_flow/free_form_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import find_network, keras_kwargs, concatenate, log_jacobian_determinant, jvp, vjp
from bayesflow.utils import (
find_network,
keras_kwargs,
concatenate,
log_jacobian_determinant,
jvp,
vjp,
serialize_value_or_type,
deserialize_value_or_type,
)

from ..inference_network import InferenceNetwork

Expand Down Expand Up @@ -63,6 +72,26 @@ def __init__(

self.seed_generator = keras.random.SeedGenerator()

# serialization: store all parameters necessary to call __init__
self.config = {
"beta": beta,
"base_distribution": base_distribution,
"hutchinson_sampling": hutchinson_sampling,
**kwargs,
}
self.config = serialize_value_or_type(self.config, "encoder_subnet", encoder_subnet)
self.config = serialize_value_or_type(self.config, "decoder_subnet", decoder_subnet)

def get_config(self):
    """Return the parent configuration merged with the stored constructor arguments.

    Entries of ``self.config`` take precedence over equally named entries
    of the parent configuration.
    """
    return super().get_config() | self.config

@classmethod
def from_config(cls, config):
    """Create an instance from ``config`` after restoring both subnet entries.

    The ``encoder_subnet`` entry is restored first, then ``decoder_subnet``.
    """
    for subnet_name in ("encoder_subnet", "decoder_subnet"):
        config = deserialize_value_or_type(config, subnet_name)
    return cls(**config)

# noinspection PyMethodOverriding
def build(self, xz_shape, conditions_shape=None):
super().build(xz_shape)
Expand Down
1 change: 1 addition & 0 deletions bayesflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
format_bytes,
parse_bytes,
)
from .serialization import serialize_value_or_type, deserialize_value_or_type
from .jacobian_trace import jacobian_trace
from .jacobian import compute_jacobian, log_jacobian_determinant
from .jvp import jvp
Expand Down
78 changes: 78 additions & 0 deletions bayesflow/utils/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import keras


PREFIX = "_bayesflow_"


def serialize_value_or_type(config, name, obj):
    """Add a serialized value or type to a copy of a configuration dictionary.

    Parameters
    ----------
    config : dict
        Dictionary the serialized object is added to. It is not modified in
        place; a modified copy is returned instead.
    name : str
        Name under which ``obj`` is stored. The same name is needed again
        for deserialization.
    obj : object or type
        The object to serialize. If ``obj`` is a ``type``, its registered
        name is obtained via ``keras.saving.get_registered_name``. Anything
        else is passed to ``keras.saving.serialize_keras_object``.

    Returns
    -------
    updated_config : dict
        Copy of ``config`` with one additional key, either
        ``"_bayesflow_<name>_type"`` or ``"_bayesflow_<name>_val"``. The
        prefix avoids name collisions, the suffix records how the stored
        entry has to be deserialized.

    Notes
    -----
    Several constructors accept either a string or a ``type`` to select what
    to instantiate (e.g., ``subnet`` in ``CouplingFlow``). ``type`` objects
    cannot be serialized as Keras objects, so the two cases are kept apart
    here and in the matching deserialization helper.
    """
    result = config.copy()

    # Non-type values go through the regular Keras object serialization.
    if not isinstance(obj, type):
        result[f"{PREFIX}{name}_val"] = keras.saving.serialize_keras_object(obj)
        return result

    # Types are stored by their registered name only.
    result[f"{PREFIX}{name}_type"] = keras.saving.get_registered_name(obj)
    return result


def deserialize_value_or_type(config, name):
    """Restore a serialized value or type in a copy of a configuration dictionary.

    Parameters
    ----------
    config : dict
        Dictionary containing the entry to deserialize. A serialized type is
        expected under the key ``"_bayesflow_<name>_type"``, a serialized
        object under ``"_bayesflow_<name>_val"``. The dictionary itself is
        not modified.
    name : str
        Name of the entry to deserialize.

    Returns
    -------
    updated_config : dict
        Copy of ``config`` in which the prefixed key is replaced by the key
        ``name``, holding either a type or a deserialized object. If neither
        prefixed key is present, an unchanged copy is returned.

    Raises
    ------
    ValueError
        If a type name was stored but ``keras.saving.get_registered_object``
        cannot resolve it (it returns ``None`` in that case).

    See Also
    --------
    serialize_value_or_type
    """
    updated_config = config.copy()
    type_key = f"{PREFIX}{name}_type"
    value_key = f"{PREFIX}{name}_val"

    if type_key in updated_config:
        registered_name = updated_config.pop(type_key)
        resolved = keras.saving.get_registered_object(registered_name)
        # An unknown name yields None; fail here with a clear message instead
        # of handing None to the constructor as if it were a valid type.
        if resolved is None:
            raise ValueError(
                f"Could not resolve the type {registered_name!r} stored for {name!r}. "
                "Make sure the class is registered with "
                "`keras.saving.register_keras_serializable` or available in a "
                "custom object scope before deserializing."
            )
        updated_config[name] = resolved
    elif value_key in updated_config:
        updated_config[name] = keras.saving.deserialize_keras_object(updated_config.pop(value_key))

    return updated_config
Loading

0 comments on commit 2068b5e

Please sign in to comment.