-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'streamlined-backend' of https://github.com/stefanradev9…
…3/BayesFlow into streamlined-backend
- Loading branch information
Showing
6 changed files
with
100 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
|
||
from .dictutils import nested_getitem, keras_kwargs | ||
from .finders import find_distribution, find_network, find_pooling | ||
from .dispatch import ( | ||
find_distribution, | ||
find_network, | ||
find_pooling, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
|
||
from .find_network import find_network |
24 changes: 24 additions & 0 deletions
24
bayesflow/experimental/utils/dispatch/find_distribution.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
|
||
from functools import singledispatch | ||
|
||
|
||
@singledispatch | ||
def find_distribution(arg, **kwargs): | ||
raise TypeError(f"Cannot infer distribution from {arg!r}.") | ||
|
||
|
||
@find_distribution.register | ||
def _(name: str, **kwargs): | ||
match name.lower(): | ||
case "normal": | ||
from bayesflow.experimental.distributions import DiagonalNormal | ||
distribution = DiagonalNormal(**kwargs) | ||
case other: | ||
raise ValueError(f"Unsupported distribution name '{other}'.") | ||
|
||
return distribution | ||
|
||
|
||
@find_distribution.register | ||
def _(constructor: type, **kwargs): | ||
return constructor(**kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
|
||
import keras | ||
|
||
from functools import singledispatch | ||
|
||
|
||
@singledispatch | ||
def find_network(arg, **kwargs): | ||
raise TypeError(f"Cannot infer network from {arg!r}.") | ||
|
||
|
||
@find_network.register | ||
def _(name: str, **kwargs): | ||
match name.lower(): | ||
case "resnet": | ||
from bayesflow.experimental.networks import ResNet | ||
network = ResNet(**kwargs) | ||
case other: | ||
raise ValueError(f"Unsupported network name: '{other}'.") | ||
|
||
return network | ||
|
||
|
||
@find_network.register | ||
def _(network: keras.Layer): | ||
return network | ||
|
||
|
||
@find_network.register | ||
def _(constructor: type, **kwargs): | ||
return constructor(**kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
|
||
import keras | ||
|
||
from functools import singledispatch | ||
|
||
|
||
@singledispatch | ||
def find_pooling(arg, **kwargs): | ||
raise TypeError(f"Cannot infer pooling from {arg!r}.") | ||
|
||
|
||
@find_pooling.register | ||
def _(name: str, **kwargs): | ||
match name.lower(): | ||
case "mean" | "avg" | "average": | ||
pooling = keras.layers.Lambda(lambda inp: keras.ops.mean(inp, axis=-2)) | ||
case "max": | ||
pooling = keras.layers.Lambda(lambda inp: keras.ops.max(inp, axis=-2)) | ||
case "min": | ||
pooling = keras.layers.Lambda(lambda inp: keras.ops.min(inp, axis=-2)) | ||
case "learnable" | "pma" | "attention": | ||
from bayesflow.experimental.networks.set_transformer.pma import PoolingByMultiheadAttention | ||
pooling = PoolingByMultiheadAttention(**kwargs) | ||
case other: | ||
raise ValueError(f"Unsupported pooling name: '{other}'.") | ||
|
||
return pooling | ||
|
||
|
||
@find_pooling.register | ||
def _(constructor: type, **kwargs): | ||
return constructor(**kwargs) | ||
|
||
|
||
@find_pooling.register | ||
def _(pooling: keras.Layer): | ||
return pooling | ||
|
This file was deleted.
Oops, something went wrong.