Skip to content

Commit

Permalink
Working kwargs propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanradev93 committed Jun 4, 2024
1 parent 4f3bcb9 commit 4919a57
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 53 deletions.
20 changes: 10 additions & 10 deletions bayesflow/experimental/networks/coupling_flow/coupling_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
from typing import Tuple, Union

import keras
from keras.saving import (
register_keras_serializable,
)
from keras.saving import register_keras_serializable

from bayesflow.experimental.types import Tensor
from bayesflow.experimental.utils import keras_kwargs
from .actnorm import ActNorm
from .couplings import DualCoupling
from .permutations import OrthogonalPermutation, RandomPermutation, Swap
Expand Down Expand Up @@ -46,20 +45,21 @@ def __init__(
use_actnorm: bool = True,
**kwargs
):
# TODO - propagate optional keyword arguments to find_network and ResNet respectively
super().__init__(**kwargs)
"""TODO"""

super().__init__(**keras_kwargs(kwargs))

self._layers = []
for i in range(depth):
if use_actnorm:
self._layers.append(ActNorm(name=f"ActNorm{i}"))
self._layers.append(DualCoupling(subnet, transform, name=f"DualCoupling{i}"))
self._layers.append(ActNorm())
self._layers.append(DualCoupling(subnet, transform, **kwargs))
if permutation.lower() == "random":
self._layers.append(RandomPermutation(name=f"RandomPermutation{i}"))
self._layers.append(RandomPermutation())
elif permutation.lower() == "swap":
self._layers.append(Swap(name=f"Swap{i}"))
self._layers.append(Swap())
elif permutation.lower() == "learnable":
self._layers.append(OrthogonalPermutation(name=f"OrthogonalPermutation{i}"))
self._layers.append(OrthogonalPermutation())

# noinspection PyMethodOverriding
def build(self, xz_shape, conditions_shape=None):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@

import keras
from keras.saving import (
register_keras_serializable,
)
from keras.saving import register_keras_serializable

from bayesflow.experimental.utils import keras_kwargs
from bayesflow.experimental.types import Tensor
from .single_coupling import SingleCoupling
from ..invertible_layer import InvertibleLayer
Expand All @@ -12,9 +11,9 @@
@register_keras_serializable(package="bayesflow.networks.coupling_flow")
class DualCoupling(InvertibleLayer):
def __init__(self, subnet: str = "resnet", transform: str = "affine", **kwargs):
super().__init__(**kwargs)
self.coupling1 = SingleCoupling(subnet, transform, name=f"CouplingA")
self.coupling2 = SingleCoupling(subnet, transform, name=f"CouplingB")
super().__init__(**keras_kwargs(kwargs))
self.coupling1 = SingleCoupling(subnet, transform, **kwargs)
self.coupling2 = SingleCoupling(subnet, transform, **kwargs)
self.pivot = None

def build(self, input_shape):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@

import keras
from keras.saving import (
register_keras_serializable
)
from keras.saving import register_keras_serializable

from bayesflow.experimental.types import Tensor
from bayesflow.experimental.utils import find_network
from bayesflow.experimental.utils import find_network, keras_kwargs
from ..invertible_layer import InvertibleLayer
from ..transforms import find_transform

Expand All @@ -17,11 +15,20 @@ class SingleCoupling(InvertibleLayer):
Subnet output tensors are linearly mapped to the correct dimension.
"""
def __init__(self, network: str = "resnet", transform: str = "affine", **kwargs):
super().__init__(**kwargs)
self.output_projector = keras.layers.Dense(None, kernel_initializer="zeros", bias_initializer="zeros")
self.network = find_network(network)
self.transform = find_transform(transform)
def __init__(
self,
network: str = "resnet",
transform: str = "affine",
output_layer_kernel_init: str = "zeros",
**kwargs
):
super().__init__(**keras_kwargs(kwargs))
self.output_projector = keras.layers.Dense(
units=None,
kernel_initializer=output_layer_kernel_init,
)
self.network = find_network(network, **kwargs.get("subnet_kwargs", {}))
self.transform = find_transform(transform, **kwargs.get("transform_kwargs", {}))

# noinspection PyMethodOverriding
def build(self, x1_shape, x2_shape):
Expand Down
14 changes: 4 additions & 10 deletions bayesflow/experimental/networks/deep_set/deep_set.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

import keras
from keras import layers, regularizers
from keras import layers
from keras.saving import (
register_keras_serializable,
serialize_keras_object
Expand All @@ -18,18 +18,16 @@ def __init__(
self,
summary_dim: int = 10,
depth: int = 2,
inner_pooling: str = "mean",
output_pooling: str = "mean",
inner_pooling: str | keras.Layer = "mean",
output_pooling: str | keras.Layer = "mean",
num_dense_equivariant: int = 2,
num_dense_invariant_inner: int = 2,
num_dense_invariant_outer: int = 2,
units_equivariant: int = 128,
units_invariant_inner: int = 128,
units_invariant_outer: int = 128,
activation: str | callable = "gelu",
kernel_regularizer: regularizers.Regularizer | None = None,
activation: str = "gelu",
kernel_initializer: str = "he_uniform",
bias_regularizer: regularizers.Regularizer | None = None,
dropout: float = 0.05,
spectral_normalization: bool = False,
**kwargs
Expand All @@ -51,9 +49,7 @@ def __init__(
units_invariant_inner=units_invariant_inner,
units_invariant_outer=units_invariant_outer,
activation=activation,
kernel_regularizer=kernel_regularizer,
kernel_initializer=kernel_initializer,
bias_regularizer=bias_regularizer,
spectral_normalization=spectral_normalization,
dropout=dropout,
pooling=inner_pooling,
Expand All @@ -68,9 +64,7 @@ def __init__(
units_inner=units_invariant_inner,
units_outer=units_invariant_outer,
activation=activation,
kernel_regularizer=kernel_regularizer,
kernel_initializer=kernel_initializer,
bias_regularizer=bias_regularizer,
dropout=dropout,
pooling=output_pooling,
spectral_normalization=spectral_normalization,
Expand Down
12 changes: 3 additions & 9 deletions bayesflow/experimental/networks/deep_set/equivariant_module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

import keras
from keras import ops, layers, regularizers
from keras import ops, layers
from keras.saving import register_keras_serializable

from bayesflow.experimental.types import Tensor
Expand All @@ -26,11 +26,9 @@ def __init__(
units_equivariant: int = 128,
units_invariant_inner: int = 128,
units_invariant_outer: int = 128,
pooling: str = "mean",
activation: str | callable = "gelu",
kernel_regularizer: regularizers.Regularizer | None = None,
pooling: str | keras.Layer = "mean",
activation: str = "gelu",
kernel_initializer: str = "he_uniform",
bias_regularizer: regularizers.Regularizer | None = None,
dropout: float = 0.05,
spectral_normalization: bool = False,
**kwargs
Expand All @@ -51,9 +49,7 @@ def __init__(
units_inner=units_invariant_inner,
units_outer=units_invariant_outer,
activation=activation,
kernel_regularizer=kernel_regularizer,
kernel_initializer=kernel_initializer,
bias_regularizer=bias_regularizer,
dropout=dropout,
pooling=pooling,
spectral_normalization=spectral_normalization,
Expand All @@ -65,9 +61,7 @@ def __init__(
layer = layers.Dense(
units=units_equivariant,
activation=activation,
kernel_regularizer=kernel_regularizer,
kernel_initializer=kernel_initializer,
bias_regularizer=bias_regularizer
)
if spectral_normalization:
layer = layers.SpectralNormalization(layer)
Expand Down
12 changes: 3 additions & 9 deletions bayesflow/experimental/networks/deep_set/invariant_module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

import keras
from keras import layers, regularizers
from keras import layers
from keras.saving import register_keras_serializable

from bayesflow.experimental.types import Tensor
Expand All @@ -24,12 +24,10 @@ def __init__(
num_dense_outer: int = 2,
units_inner: int = 128,
units_outer: int = 128,
activation: str | callable = "gelu",
kernel_regularizer: regularizers.Regularizer | None = None,
activation: str = "gelu",
kernel_initializer: str = "he_uniform",
bias_regularizer: regularizers.Regularizer | None = None,
dropout: float = 0.05,
pooling: str = "mean",
pooling: str | keras.Layer = "mean",
spectral_normalization: bool = False,
**kwargs
):
Expand All @@ -52,9 +50,7 @@ def __init__(
layer = layers.Dense(
units=units_inner,
activation=activation,
kernel_regularizer=kernel_regularizer,
kernel_initializer=kernel_initializer,
bias_regularizer=bias_regularizer
)
if spectral_normalization:
layer = layers.SpectralNormalization(layer)
Expand All @@ -69,9 +65,7 @@ def __init__(
layer = layers.Dense(
units=units_outer,
activation=activation,
kernel_regularizer=kernel_regularizer,
kernel_initializer=kernel_initializer,
bias_regularizer=bias_regularizer
)
if spectral_normalization:
layer = layers.SpectralNormalization(layer)
Expand Down

0 comments on commit 4919a57

Please sign in to comment.