Skip to content

Commit

Permalink
Initial changes
Browse files Browse the repository at this point in the history
Added summary_networks folder in experimental for future additions.
Added __init__ file for exposed summary networks.
Added deepset implementation to summary network class hierarchy.
  • Loading branch information
Chase-Grajeda committed Jun 3, 2024
1 parent cc69a38 commit 04045b8
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .deepset import DeepSet
97 changes: 97 additions & 0 deletions bayesflow/experimental/networks/summary_networks/deepset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from ..summary_network import SummaryNetwork
# from ..helper_networks import InvariantModule, EquivariantModule # use this when InvariantModule is implemented
from bayesflow.helper_networks import InvariantModule, EquivariantModule # remove this when above is satisfied
from bayesflow import default_settings as defaults
from keras.api.layers import Dense
from keras import Sequential

class DeepSet(SummaryNetwork):
def __init__(
self,
summary_dim: int = 10,
num_dense_s1: int = 2,
num_dense_s2: int = 2,
num_dense_s3: int = 3,
num_equiv: int = 2,
dense_s1_args=None,
dense_s2_args=None,
dense_s3_args=None,
pooling_fun: str = "mean",
**kwargs
):
"""Creates a stack of 'num_equiv' equivariant layers followed by a final invariant layer.
Parameters
----------
summary_dim : int, optional, default: 10
The number of learned summary statistics.
num_dense_s1 : int, optional, default: 2
The number of dense layers in the inner function of a deep set.
num_dense_s2 : int, optional, default: 2
The number of dense layers in the outer function of a deep set.
num_dense_s3 : int, optional, default: 2
The number of dense layers in an equivariant layer.
num_equiv : int, optional, default: 2
The number of equivariant layers in the network.
dense_s1_args : dict or None, optional, default: None
The arguments for the dense layers of s1 (inner, pre-pooling function). If `None`,
defaults will be used (see `default_settings`). Otherwise, all arguments for a
tf.keras.layers.Dense layer are supported.
dense_s2_args : dict or None, optional, default: None
The arguments for the dense layers of s2 (outer, post-pooling function). If `None`,
defaults will be used (see `default_settings`). Otherwise, all arguments for a
tf.keras.layers.Dense layer are supported.
dense_s3_args : dict or None, optional, default: None
The arguments for the dense layers of s3 (equivariant function). If `None`,
defaults will be used (see `default_settings`). Otherwise, all arguments for a
tf.keras.layers.Dense layer are supported.
pooling_fun : str of callable, optional, default: 'mean'
If string argument provided, should be one in ['mean', 'max']. In addition, ac actual
neural network can be passed for learnable pooling.
**kwargs : dict, optional, default: {}
Optional keyword arguments passed to the __init__() method of tf.keras.Model.
"""

super().__init__(**kwargs)

# Prepare settings dictionary
settings = dict(
num_dense_s1=num_dense_s1,
num_dense_s2=num_dense_s2,
num_dense_s3=num_dense_s3,
dense_s1_args=defaults.DEFAULT_SETTING_DENSE_DEEP_SET if dense_s1_args is None else dense_s1_args,
dense_s2_args=defaults.DEFAULT_SETTING_DENSE_DEEP_SET if dense_s2_args is None else dense_s2_args,
dense_s3_args=defaults.DEFAULT_SETTING_DENSE_DEEP_SET if dense_s3_args is None else dense_s3_args,
pooling_fun=pooling_fun,
)

# Create equivariant layers and final invariant layer
self.equiv_layers = Sequential([EquivariantModule(settings) for _ in range(num_equiv)])
self.inv = InvariantModule(settings)

# Output layer to output "summary_dim" learned summary statistics
self.out_layer = Dense(summary_dim, activation="linear")
self.summary_dim = summary_dim

def call(self, x, **kwargs):
"""Performs the forward pass of a learnable deep invariant transformation consisting of
a sequence of equivariant transforms followed by an invariant transform.
Parameters
----------
x : tf.Tensor
Input of shape (batch_size, n_obs, data_dim)
Returns
-------
out : tf.Tensor
Output of shape (batch_size, out_dim)
"""

# Pass through series of augmented equivariant transforms
out_equiv = self.equiv_layers(x, **kwargs)

# Pass through final invariant layer
out = self.out_layer(self.inv(out_equiv, **kwargs), **kwargs)

return out

0 comments on commit 04045b8

Please sign in to comment.