Skip to content

Commit

Permalink
Merge branch 'streamlined-backend' of https://github.com/stefanradev9…
Browse files Browse the repository at this point in the history
…3/BayesFlow into streamlined-backend
  • Loading branch information
stefanradev93 committed Jun 6, 2024
2 parents 797e1e7 + ed07fb3 commit 4fbda46
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 66 deletions.
6 changes: 5 additions & 1 deletion bayesflow/experimental/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@

from .dictutils import nested_getitem, keras_kwargs
from .finders import find_distribution, find_network, find_pooling
from .dispatch import (
find_distribution,
find_network,
find_pooling,
)
2 changes: 2 additions & 0 deletions bayesflow/experimental/utils/dispatch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

from .find_network import find_network
24 changes: 24 additions & 0 deletions bayesflow/experimental/utils/dispatch/find_distribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@

from functools import singledispatch


@singledispatch
def find_distribution(arg, **kwargs):
raise TypeError(f"Cannot infer distribution from {arg!r}.")


@find_distribution.register
def _(name: str, **kwargs):
match name.lower():
case "normal":
from bayesflow.experimental.distributions import DiagonalNormal
distribution = DiagonalNormal(**kwargs)
case other:
raise ValueError(f"Unsupported distribution name '{other}'.")

return distribution


@find_distribution.register
def _(constructor: type, **kwargs):
return constructor(**kwargs)
31 changes: 31 additions & 0 deletions bayesflow/experimental/utils/dispatch/find_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

import keras

from functools import singledispatch


@singledispatch
def find_network(arg, **kwargs):
raise TypeError(f"Cannot infer network from {arg!r}.")


@find_network.register
def _(name: str, **kwargs):
match name.lower():
case "resnet":
from bayesflow.experimental.networks import ResNet
network = ResNet(**kwargs)
case other:
raise ValueError(f"Unsupported network name: '{other}'.")

return network


@find_network.register
def _(network: keras.Layer):
return network


@find_network.register
def _(constructor: type, **kwargs):
return constructor(**kwargs)
38 changes: 38 additions & 0 deletions bayesflow/experimental/utils/dispatch/find_pooling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@

import keras

from functools import singledispatch


@singledispatch
def find_pooling(arg, **kwargs):
raise TypeError(f"Cannot infer pooling from {arg!r}.")


@find_pooling.register
def _(name: str, **kwargs):
match name.lower():
case "mean" | "avg" | "average":
pooling = keras.layers.Lambda(lambda inp: keras.ops.mean(inp, axis=-2))
case "max":
pooling = keras.layers.Lambda(lambda inp: keras.ops.max(inp, axis=-2))
case "min":
pooling = keras.layers.Lambda(lambda inp: keras.ops.min(inp, axis=-2))
case "learnable" | "pma" | "attention":
from bayesflow.experimental.networks.set_transformer.pma import PoolingByMultiheadAttention
pooling = PoolingByMultiheadAttention(**kwargs)
case other:
raise ValueError(f"Unsupported pooling name: '{other}'.")

return pooling


@find_pooling.register
def _(constructor: type, **kwargs):
return constructor(**kwargs)


@find_pooling.register
def _(pooling: keras.Layer):
return pooling

65 changes: 0 additions & 65 deletions bayesflow/experimental/utils/finders.py

This file was deleted.

0 comments on commit 4fbda46

Please sign in to comment.