Skip to content

Commit

Permalink
Zuko density estimators (#1088)
Browse files Browse the repository at this point in the history
* update zuko to 1.1.0

* test zuko_gmm commit

* build_zuko_nsf added

* add build_zuko_naf, update test

* add license change to pr template.

* CLN pyproject.toml (#1009)

* CLN pyproject.toml

* CLN optional deps comment

* CLN alphabetical order

* fix x_o and broken link tutorial 7 (#1003)

* fix x_o and broken link tutorial 7

* typo in title

* suppress plotting output

---------

Co-authored-by: Matthijs <matthijs@example.com>

* replace prepare_for_sbi in tutorials (#1013)

* add zuko density estimators

* not working gmm

* update tests for PR

* update PR for pyright

* resolve pyright

* add reportArgumentType

* resolve pyright issue

* resolve all issues pyright

* resolve pyright

* add typing and docstring

* add functions from factory to test

* remove comment mdn file

* add docstrings flow file

* add docstring in density_estimator_test.py

* Update sbi/neural_nets/flow.py

Co-authored-by: Sebastian Bischoff <sebastian@salzreute.de>

* Update sbi/neural_nets/flow.py

Co-authored-by: Sebastian Bischoff <sebastian@salzreute.de>

* Update sbi/neural_nets/flow.py

Co-authored-by: Sebastian Bischoff <sebastian@salzreute.de>

* removed pyright

---------

Co-authored-by: bkmi <12955549+bkmi@users.noreply.github.com>
Co-authored-by: Nastya Krouglova <nastyakrouglova@Nastyas-MacBook-Pro.local>
Co-authored-by: Jan Boelts <jan.boelts@mailbox.org>
Co-authored-by: Thomas Moreau <thomas.moreau.2010@gmail.com>
Co-authored-by: Matthijs Pals <34062419+Matthijspals@users.noreply.github.com>
Co-authored-by: Matthijs <matthijs@example.com>
Co-authored-by: zinaStef <49067201+zinaStef@users.noreply.github.com>
Co-authored-by: Sebastian Bischoff <sebastian@salzreute.de>
  • Loading branch information
9 people committed Apr 5, 2024
1 parent b2d7d21 commit dea309c
Show file tree
Hide file tree
Showing 8 changed files with 879 additions and 130 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ dependencies = [
"tensorboard",
"torch>=1.8.0",
"tqdm",
"zuko>=1.0.0",
"pymc>=5.0.0",
"zuko>=1.1.0",
]

[project.optional-dependencies]
Expand Down
8 changes: 4 additions & 4 deletions sbi/neural_nets/density_estimators/zuko_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch
from torch import Tensor, nn
from zuko.flows import Flow
from zuko.flows.core import Flow

from sbi.neural_nets.density_estimators.base import DensityEstimator
from sbi.sbi_types import Shape
Expand Down Expand Up @@ -125,6 +125,7 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
emb_cond = emb_cond.expand(batch_shape + (emb_cond.shape[-1],))

dists = self.net(emb_cond)

log_probs = dists.log_prob(input)

return log_probs
Expand Down Expand Up @@ -166,7 +167,7 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor:

emb_cond = self._embedding_net(condition)
dists = self.net(emb_cond)
# zuko.sample() returns (*sample_shape, *batch_shape, input_size).

samples = dists.sample(sample_shape).reshape(*batch_shape, *sample_shape, -1)

return samples
Expand All @@ -190,9 +191,8 @@ def sample_and_log_prob(

emb_cond = self._embedding_net(condition)
dists = self.net(emb_cond)
samples, log_probs = dists.rsample_and_log_prob(sample_shape)
# zuko.sample_and_log_prob() returns (*sample_shape, *batch_shape, ...).

samples, log_probs = dists.rsample_and_log_prob(sample_shape)
samples = samples.reshape(*batch_shape, *sample_shape, -1)
log_probs = log_probs.reshape(*batch_shape, *sample_shape)

Expand Down
50 changes: 50 additions & 0 deletions sbi/neural_nets/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,16 @@
build_maf,
build_maf_rqs,
build_nsf,
build_zuko_bpf,
build_zuko_cnf,
build_zuko_gf,
build_zuko_maf,
build_zuko_naf,
build_zuko_ncsf,
build_zuko_nice,
build_zuko_nsf,
build_zuko_sospf,
build_zuko_unaf,
)
from sbi.neural_nets.mdn import build_mdn
from sbi.neural_nets.mnle import build_mnle
Expand Down Expand Up @@ -174,8 +183,26 @@ def build_fn(batch_theta, batch_x):
return build_nsf(batch_x=batch_x, batch_y=batch_theta, **kwargs)
elif model == "mnle":
return build_mnle(batch_x=batch_x, batch_y=batch_theta, **kwargs)
elif model == "zuko_nice":
return build_zuko_nice(batch_x=batch_x, batch_y=batch_theta, **kwargs)
elif model == "zuko_maf":
return build_zuko_maf(batch_x=batch_x, batch_y=batch_theta, **kwargs)
elif model == "zuko_nsf":
return build_zuko_nsf(batch_x=batch_x, batch_y=batch_theta, **kwargs)
elif model == "zuko_ncsf":
return build_zuko_ncsf(batch_x=batch_x, batch_y=batch_theta, **kwargs)
elif model == "zuko_sospf":
return build_zuko_sospf(batch_x=batch_x, batch_y=batch_theta, **kwargs)
elif model == "zuko_naf":
return build_zuko_naf(batch_x=batch_x, batch_y=batch_theta, **kwargs)
elif model == "zuko_unaf":
return build_zuko_unaf(batch_x=batch_x, batch_y=batch_theta, **kwargs)
elif model == "zuko_cnf":
return build_zuko_cnf(batch_x=batch_x, batch_y=batch_theta, **kwargs)
elif model == "zuko_gf":
return build_zuko_gf(batch_x=batch_x, batch_y=batch_theta, **kwargs)
elif model == "zuko_bpf":
return build_zuko_bpf(batch_x=batch_x, batch_y=batch_theta, **kwargs)
else:
raise NotImplementedError

Expand Down Expand Up @@ -266,6 +293,9 @@ def build_fn_snpe_a(batch_theta, batch_x, num_components):

def build_fn(batch_theta, batch_x):
if model == "mdn":
# The naming might be a bit confusing.
# batch_x are the latent variables, batch_y the conditioned variables.
# batch_theta are the parameters and batch_x the observable variables.
return build_mdn(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "made":
return build_made(batch_x=batch_theta, batch_y=batch_x, **kwargs)
Expand All @@ -275,8 +305,28 @@ def build_fn(batch_theta, batch_x):
return build_maf_rqs(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "nsf":
return build_nsf(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "mnle":
return build_mnle(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "zuko_nice":
return build_zuko_nice(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "zuko_maf":
return build_zuko_maf(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "zuko_nsf":
return build_zuko_nsf(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "zuko_ncsf":
return build_zuko_ncsf(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "zuko_sospf":
return build_zuko_sospf(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "zuko_naf":
return build_zuko_naf(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "zuko_unaf":
return build_zuko_unaf(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "zuko_cnf":
return build_zuko_cnf(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "zuko_gf":
return build_zuko_gf(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "zuko_bpf":
return build_zuko_bpf(batch_x=batch_theta, batch_y=batch_x, **kwargs)
else:
raise NotImplementedError

Expand Down
Loading

0 comments on commit dea309c

Please sign in to comment.