Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zuko density estimators (#1088) #1116

Merged
merged 16 commits into from
Apr 5, 2024
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ dependencies = [
"tensorboard",
"torch>=1.8.0",
"tqdm",
"zuko>=1.0.0",
"pymc>=5.0.0",
"zuko>=1.1.0",
]

[project.optional-dependencies]
Expand Down
8 changes: 4 additions & 4 deletions sbi/neural_nets/density_estimators/zuko_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch
from torch import Tensor, nn
from zuko.flows import Flow
from zuko.flows.core import Flow

from sbi.neural_nets.density_estimators.base import DensityEstimator
from sbi.sbi_types import Shape
Expand Down Expand Up @@ -125,6 +125,7 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
emb_cond = emb_cond.expand(batch_shape + (emb_cond.shape[-1],))

dists = self.net(emb_cond)

log_probs = dists.log_prob(input)

return log_probs
Expand Down Expand Up @@ -166,7 +167,7 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor:

emb_cond = self._embedding_net(condition)
dists = self.net(emb_cond)
# zuko.sample() returns (*sample_shape, *batch_shape, input_size).

samples = dists.sample(sample_shape).reshape(*batch_shape, *sample_shape, -1)

return samples
Expand All @@ -190,9 +191,8 @@ def sample_and_log_prob(

emb_cond = self._embedding_net(condition)
dists = self.net(emb_cond)
samples, log_probs = dists.rsample_and_log_prob(sample_shape)
# zuko.sample_and_log_prob() returns (*sample_shape, *batch_shape, ...).

samples, log_probs = dists.rsample_and_log_prob(sample_shape)
samples = samples.reshape(*batch_shape, *sample_shape, -1)
log_probs = log_probs.reshape(*batch_shape, *sample_shape)

Expand Down
67 changes: 37 additions & 30 deletions sbi/neural_nets/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,37 @@
build_maf,
build_maf_rqs,
build_nsf,
build_zuko_bpf,
build_zuko_gf,
build_zuko_maf,
build_zuko_naf,
build_zuko_ncsf,
build_zuko_nice,
build_zuko_nsf,
build_zuko_sospf,
build_zuko_unaf,
)
from sbi.neural_nets.mdn import build_mdn
from sbi.neural_nets.mnle import build_mnle

model_builders = {
"mdn": build_mdn,
"made": build_made,
"maf": build_maf,
"maf_rqs": build_maf_rqs,
"nsf": build_nsf,
"mnle": build_mnle,
"zuko_nice": build_zuko_nice,
"zuko_maf": build_zuko_maf,
"zuko_nsf": build_zuko_nsf,
"zuko_ncsf": build_zuko_ncsf,
"zuko_sospf": build_zuko_sospf,
"zuko_naf": build_zuko_naf,
"zuko_unaf": build_zuko_unaf,
"zuko_gf": build_zuko_gf,
"zuko_bpf": build_zuko_bpf,
}


def classifier_nn(
model: str,
Expand Down Expand Up @@ -162,22 +188,10 @@
)

def build_fn(batch_theta, batch_x):
if model == "mdn":
return build_mdn(batch_x=batch_x, batch_y=batch_theta, **kwargs)
elif model == "made":
return build_made(batch_x=batch_x, batch_y=batch_theta, **kwargs)
elif model == "maf":
return build_maf(batch_x=batch_x, batch_y=batch_theta, **kwargs)
elif model == "maf_rqs":
return build_maf_rqs(batch_x=batch_x, batch_y=batch_theta, **kwargs)
elif model == "nsf":
return build_nsf(batch_x=batch_x, batch_y=batch_theta, **kwargs)
elif model == "mnle":
return build_mnle(batch_x=batch_x, batch_y=batch_theta, **kwargs)
elif model == "zuko_maf":
return build_zuko_maf(batch_x=batch_x, batch_y=batch_theta, **kwargs)
else:
raise NotImplementedError
if model not in model_builders:
raise NotImplementedError(f"Model {model} is not implemented")

Check warning on line 192 in sbi/neural_nets/factory.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/factory.py#L192

Added line #L192 was not covered by tests

return model_builders[model](batch_x=batch_x, batch_y=batch_theta, **kwargs)

return build_fn

Expand Down Expand Up @@ -265,20 +279,13 @@
)

def build_fn(batch_theta, batch_x):
if model == "mdn":
return build_mdn(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "made":
return build_made(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "maf":
return build_maf(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "maf_rqs":
return build_maf_rqs(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "nsf":
return build_nsf(batch_x=batch_theta, batch_y=batch_x, **kwargs)
elif model == "zuko_maf":
return build_zuko_maf(batch_x=batch_theta, batch_y=batch_x, **kwargs)
else:
raise NotImplementedError
if model not in model_builders:
anastasiakrouglova marked this conversation as resolved.
Show resolved Hide resolved
raise NotImplementedError(f"Model {model} is not implemented")

Check warning on line 283 in sbi/neural_nets/factory.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/factory.py#L283

Added line #L283 was not covered by tests

# The naming might be a bit confusing.
# batch_x are the latent variables, batch_y the conditioned variables.
# batch_theta are the parameters and batch_x the observable variables.
return model_builders[model](batch_x=batch_theta, batch_y=batch_x, **kwargs)

if model == "mdn_snpe_a":
if num_components != 10:
Expand Down
Loading