-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added summary_networks folder in experimental for future additions. Added __init__ file for exposed summary networks. Added deepset implementation to summary network class hierarchy.
- Loading branch information
1 parent
cc69a38
commit 04045b8
Showing
2 changed files
with
98 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .deepset import DeepSet |
97 changes: 97 additions & 0 deletions
97
bayesflow/experimental/networks/summary_networks/deepset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
from ..summary_network import SummaryNetwork | ||
# from ..helper_networks import InvariantModule, EquivariantModule # use this when InvariantModule is implemented | ||
from bayesflow.helper_networks import InvariantModule, EquivariantModule # remove this when above is satisfied | ||
from bayesflow import default_settings as defaults | ||
from keras.api.layers import Dense | ||
from keras import Sequential | ||
|
||
class DeepSet(SummaryNetwork): | ||
def __init__( | ||
self, | ||
summary_dim: int = 10, | ||
num_dense_s1: int = 2, | ||
num_dense_s2: int = 2, | ||
num_dense_s3: int = 3, | ||
num_equiv: int = 2, | ||
dense_s1_args=None, | ||
dense_s2_args=None, | ||
dense_s3_args=None, | ||
pooling_fun: str = "mean", | ||
**kwargs | ||
): | ||
"""Creates a stack of 'num_equiv' equivariant layers followed by a final invariant layer. | ||
Parameters | ||
---------- | ||
summary_dim : int, optional, default: 10 | ||
The number of learned summary statistics. | ||
num_dense_s1 : int, optional, default: 2 | ||
The number of dense layers in the inner function of a deep set. | ||
num_dense_s2 : int, optional, default: 2 | ||
The number of dense layers in the outer function of a deep set. | ||
num_dense_s3 : int, optional, default: 2 | ||
The number of dense layers in an equivariant layer. | ||
num_equiv : int, optional, default: 2 | ||
The number of equivariant layers in the network. | ||
dense_s1_args : dict or None, optional, default: None | ||
The arguments for the dense layers of s1 (inner, pre-pooling function). If `None`, | ||
defaults will be used (see `default_settings`). Otherwise, all arguments for a | ||
tf.keras.layers.Dense layer are supported. | ||
dense_s2_args : dict or None, optional, default: None | ||
The arguments for the dense layers of s2 (outer, post-pooling function). If `None`, | ||
defaults will be used (see `default_settings`). Otherwise, all arguments for a | ||
tf.keras.layers.Dense layer are supported. | ||
dense_s3_args : dict or None, optional, default: None | ||
The arguments for the dense layers of s3 (equivariant function). If `None`, | ||
defaults will be used (see `default_settings`). Otherwise, all arguments for a | ||
tf.keras.layers.Dense layer are supported. | ||
pooling_fun : str of callable, optional, default: 'mean' | ||
If string argument provided, should be one in ['mean', 'max']. In addition, ac actual | ||
neural network can be passed for learnable pooling. | ||
**kwargs : dict, optional, default: {} | ||
Optional keyword arguments passed to the __init__() method of tf.keras.Model. | ||
""" | ||
|
||
super().__init__(**kwargs) | ||
|
||
# Prepare settings dictionary | ||
settings = dict( | ||
num_dense_s1=num_dense_s1, | ||
num_dense_s2=num_dense_s2, | ||
num_dense_s3=num_dense_s3, | ||
dense_s1_args=defaults.DEFAULT_SETTING_DENSE_DEEP_SET if dense_s1_args is None else dense_s1_args, | ||
dense_s2_args=defaults.DEFAULT_SETTING_DENSE_DEEP_SET if dense_s2_args is None else dense_s2_args, | ||
dense_s3_args=defaults.DEFAULT_SETTING_DENSE_DEEP_SET if dense_s3_args is None else dense_s3_args, | ||
pooling_fun=pooling_fun, | ||
) | ||
|
||
# Create equivariant layers and final invariant layer | ||
self.equiv_layers = Sequential([EquivariantModule(settings) for _ in range(num_equiv)]) | ||
self.inv = InvariantModule(settings) | ||
|
||
# Output layer to output "summary_dim" learned summary statistics | ||
self.out_layer = Dense(summary_dim, activation="linear") | ||
self.summary_dim = summary_dim | ||
|
||
def call(self, x, **kwargs): | ||
"""Performs the forward pass of a learnable deep invariant transformation consisting of | ||
a sequence of equivariant transforms followed by an invariant transform. | ||
Parameters | ||
---------- | ||
x : tf.Tensor | ||
Input of shape (batch_size, n_obs, data_dim) | ||
Returns | ||
------- | ||
out : tf.Tensor | ||
Output of shape (batch_size, out_dim) | ||
""" | ||
|
||
# Pass through series of augmented equivariant transforms | ||
out_equiv = self.equiv_layers(x, **kwargs) | ||
|
||
# Pass through final invariant layer | ||
out = self.out_layer(self.inv(out_equiv, **kwargs), **kwargs) | ||
|
||
return out |