
feat: batched sampling and log prob methods. #1153

Merged
merged 76 commits into from Jun 18, 2024
17c5343
Base estimator class
manuelgloeckler Apr 29, 2024
705e9df
intermediate commit
michaeldeistler May 3, 2024
07b53cd
make autoreload work
michaeldeistler May 3, 2024
dd02e22
`amortized_sample` works for MCMCPosterior
michaeldeistler May 5, 2024
663185b
fixes current bug!
manuelgloeckler May 7, 2024
df8899a
Added tests
manuelgloeckler May 7, 2024
aa82aab
batched_rejection_sampling
manuelgloeckler May 7, 2024
00cdade
intermediate commit
michaeldeistler May 3, 2024
cb8e4d8
make autoreload work
michaeldeistler May 3, 2024
d64557f
`amortized_sample` works for MCMCPosterior
michaeldeistler May 5, 2024
f16622d
Merge branch 'amortizedsample' of https://github.com/sbi-dev/sbi into…
manuelgloeckler May 7, 2024
07084e2
Merge branch '990-add-sample_batched-and-log_prob_batched-to-posterio…
manuelgloeckler May 7, 2024
e54a2fb
Revert "Merge branch '990-add-sample_batched-and-log_prob_batched-to-…
manuelgloeckler May 7, 2024
52d0e7e
Merge branch '1154-density-estimator-batched-sample-mixes-up-samples-…
manuelgloeckler May 7, 2024
cd808d5
sample works, try log_prob_batched
manuelgloeckler May 7, 2024
f542224
log_prob_batched works
manuelgloeckler May 7, 2024
48a1a28
abstract method implement for other methods
manuelgloeckler May 7, 2024
5a37330
temp fix mcmcposterior
manuelgloeckler May 7, 2024
2b23e42
meh for general use i.e. in the restriction prior we have to add some…
manuelgloeckler May 7, 2024
6362051
... test class
manuelgloeckler May 7, 2024
294609d
Revert "Base estimator class"
manuelgloeckler May 8, 2024
99abbb1
removing previous change
manuelgloeckler May 8, 2024
ef9e99c
removing some artifacts
manuelgloeckler May 8, 2024
5eb1007
revert wierd change
manuelgloeckler May 8, 2024
82127ab
docs and tests
manuelgloeckler May 8, 2024
41617a8
MCMC sample_batched works but not log_prob batched
manuelgloeckler May 14, 2024
82951db
adding some docs
manuelgloeckler May 14, 2024
c5fac1d
batch_log_prob for MCMC requires at best changes for potential -> rem…
manuelgloeckler May 14, 2024
0d82422
intermediate commit
michaeldeistler May 3, 2024
57cfde3
make autoreload work
michaeldeistler May 3, 2024
de5d647
`amortized_sample` works for MCMCPosterior
michaeldeistler May 5, 2024
f8b6604
intermediate commit
michaeldeistler May 3, 2024
1dcf882
make autoreload work
michaeldeistler May 3, 2024
5a31970
`amortized_sample` works for MCMCPosterior
michaeldeistler May 5, 2024
871c4de
Base estimator class
manuelgloeckler Apr 29, 2024
f87d6b6
Revert "Merge branch '990-add-sample_batched-and-log_prob_batched-to-…
manuelgloeckler May 7, 2024
dbd0109
fixes current bug!
manuelgloeckler May 7, 2024
264b6c4
Added tests
manuelgloeckler May 7, 2024
339b57b
batched_rejection_sampling
manuelgloeckler May 7, 2024
676c271
sample works, try log_prob_batched
manuelgloeckler May 7, 2024
7a8a84d
log_prob_batched works
manuelgloeckler May 7, 2024
5daab92
abstract method implement for other methods
manuelgloeckler May 7, 2024
40897a0
temp fix mcmcposterior
manuelgloeckler May 7, 2024
a2b7e32
meh for general use i.e. in the restriction prior we have to add some…
manuelgloeckler May 7, 2024
cb4d8ae
... test class
manuelgloeckler May 7, 2024
ab9b1e1
Revert "Base estimator class"
manuelgloeckler May 8, 2024
d2b1a62
removing previous change
manuelgloeckler May 8, 2024
a0c0c97
removing some artifacts
manuelgloeckler May 8, 2024
8fc5a46
revert wierd change
manuelgloeckler May 8, 2024
18c7d36
docs and tests
manuelgloeckler May 8, 2024
6ad6cb7
MCMC sample_batched works but not log_prob batched
manuelgloeckler May 14, 2024
03c10f3
adding some docs
manuelgloeckler May 14, 2024
24c4821
batch_log_prob for MCMC requires at best changes for potential -> rem…
manuelgloeckler May 14, 2024
1769d6e
Merge branch 'amortizedsample' of https://github.com/sbi-dev/sbi into…
manuelgloeckler Jun 11, 2024
a445a6c
Fixing bug from rebase...
manuelgloeckler Jun 11, 2024
86767a1
tracking all acceptance rates
manuelgloeckler Jun 11, 2024
9502af3
Comment on NFlows
manuelgloeckler Jun 11, 2024
c80e6ff
Also testing SNRE batched sampling, Need to test ensemble implementation
manuelgloeckler Jun 11, 2024
7aac84c
fig bug
manuelgloeckler Jun 11, 2024
7d4eb55
Ensemble sample_batched is working (with tests)
manuelgloeckler Jun 11, 2024
f53e1ec
GPU compatibility
manuelgloeckler Jun 11, 2024
2dc6ebd
restriction priopr requires float as output of accept_reject
manuelgloeckler Jun 11, 2024
7dfda13
Adding a few comments
manuelgloeckler Jun 11, 2024
89b6e8f
2d sample_shape tests
manuelgloeckler Jun 11, 2024
35dcf40
Merge branch 'main' into amortizedsample
janfb Jun 13, 2024
93ca374
Apply suggestions from code review
manuelgloeckler Jun 14, 2024
86f3531
Adding comment about squeeze
manuelgloeckler Jun 14, 2024
2a5f357
Update sbi/inference/posteriors/direct_posterior.py
manuelgloeckler Jun 18, 2024
79273a2
fixing formating
manuelgloeckler Jun 18, 2024
7b23d60
reverting MCM posterior changes
manuelgloeckler Jun 18, 2024
d4f9e46
xfail mcmc tests
manuelgloeckler Jun 18, 2024
6798c97
Exclude MCMC from ensamble batched_sample test
manuelgloeckler Jun 18, 2024
b1724a5
SNPE_A Bug fix
manuelgloeckler Jun 18, 2024
a6f4845
typo fix
manuelgloeckler Jun 18, 2024
2aac705
preamtive main fix
manuelgloeckler Jun 18, 2024
26444f7
Revert "preamtive main fix"
manuelgloeckler Jun 18, 2024
11 changes: 11 additions & 0 deletions sbi/inference/posteriors/base_posterior.py
@@ -121,6 +121,17 @@
"""See child classes for docstring."""
pass

@abstractmethod
def sample_batched(
self,
sample_shape: Shape,
x: Tensor,
max_sampling_batch_size: int = 10_000,
show_progress_bars: bool = True,
) -> Tensor:
"""See child classes for docstring."""
pass


@property
def default_x(self) -> Optional[Tensor]:
"""Return default x used by `.sample(), .log_prob` as conditioning context."""
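The abstract `sample_batched` added above defines the contract every posterior class must fulfil: subclasses either implement batched sampling or raise. A stripped-down, torch-free sketch of that pattern (class names and the echoed return value are illustrative, not sbi's actual implementation):

```python
from abc import ABC, abstractmethod


class NeuralPosterior(ABC):
    """Stand-in for sbi's base posterior class."""

    @abstractmethod
    def sample_batched(self, sample_shape, x, max_sampling_batch_size=10_000,
                       show_progress_bars=True):
        """See child classes for docstring."""


class LoopPosterior(NeuralPosterior):
    """Hypothetical subclass: one independent draw per observation."""

    def sample_batched(self, sample_shape, x, **kwargs):
        # Here we just echo the inputs; a real subclass would draw samples.
        return [(sample_shape, x_o) for x_o in x]


post = LoopPosterior()
print(len(post.sample_batched((2,), [0.1, 0.2, 0.3])))  # one entry per observation
```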
125 changes: 122 additions & 3 deletions sbi/inference/posteriors/direct_posterior.py
@@ -16,7 +16,7 @@
reshape_to_batch_event,
reshape_to_sample_batch_event,
)
from sbi.samplers.rejection.rejection import accept_reject_sample
from sbi.samplers.rejection import rejection
from sbi.sbi_types import Shape
from sbi.utils import check_prior, within_support
from sbi.utils.torchutils import ensure_theta_batched
@@ -123,7 +123,51 @@ def sample(
f"`.build_posterior(sample_with={sample_with}).`"
)

samples = accept_reject_sample(
samples = rejection.accept_reject_sample(
proposal=self.posterior_estimator,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
show_progress_bars=show_progress_bars,
max_sampling_batch_size=max_sampling_batch_size,
proposal_sampling_kwargs={"condition": x},
alternative_method="build_posterior(..., sample_with='mcmc')",
)[0]

return samples[:, 0] # Remove batch dimension.

def sample_batched(
self,
sample_shape: Shape,
x: Tensor,
max_sampling_batch_size: int = 10_000,
show_progress_bars: bool = True,
) -> Tensor:
r"""Given a batch of observations [x_1, ..., x_B] this function samples from
posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized)
manner.

Args:
sample_shape: Desired shape of samples that are drawn from the posterior
given every observation.
x: A batch of observations, of shape `(batch_dim, event_shape_x)`.
`batch_dim` corresponds to the number of observations to be drawn.
max_sampling_batch_size: Maximum batch size for rejection sampling.
show_progress_bars: Whether to show sampling progress monitor.

Returns:
Samples from the posteriors of shape (*sample_shape, B, *input_shape)
"""
num_samples = torch.Size(sample_shape).numel()
condition_shape = self.posterior_estimator.condition_shape
x = reshape_to_batch_event(x, event_shape=condition_shape)

max_sampling_batch_size = (
self.max_sampling_batch_size
if max_sampling_batch_size is None
else max_sampling_batch_size
)

samples = rejection.accept_reject_sample(
proposal=self.posterior_estimator,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
@@ -210,6 +254,81 @@ def log_prob(

return masked_log_prob - log_factor

def log_prob_batched(
self,
theta: Tensor,
x: Tensor,
norm_posterior: bool = True,
track_gradients: bool = False,
leakage_correction_params: Optional[dict] = None,
) -> Tensor:
"""Given a batch of observations [x_1, ..., x_B] and a batch of parameters \
[$\theta_1$,..., $\theta_B$] this function evaluates the log-probabilities \
of the posteriors $p(\theta_1|x_1)$, ..., $p(\theta_B|x_B)$ in a batched \
(i.e. vectorized) manner.

Args:
theta: Batch of parameters $\theta$ of shape \
`(*sample_shape, batch_dim, *theta_shape)`.
x: Batch of observations $x$ of shape \
`(batch_dim, *condition_shape)`.
norm_posterior: Whether to enforce a normalized posterior density.
Renormalization of the posterior is useful when some
probability falls out or leaks out of the prescribed prior support.
The normalizing factor is calculated via rejection sampling, so if you
need speedier but unnormalized log posterior estimates set here
`norm_posterior=False`. The returned log posterior is set to
-∞ outside of the prior support regardless of this setting.
track_gradients: Whether the returned tensor supports tracking gradients.
This can be helpful for e.g. sensitivity analysis, but increases memory
consumption.
leakage_correction_params: A `dict` of keyword arguments to override the
default values of `leakage_correction()`. Possible options are:
`num_rejection_samples`, `force_update`, `show_progress_bars`, and
`rejection_sampling_batch_size`.
These parameters only have an effect if `norm_posterior=True`.

Returns:
`(len(θ), B)`-shaped log posterior probability $\\log p(\theta|x)$\\ for θ \
in the support of the prior, -∞ (corresponding to 0 probability) outside.
"""

theta = ensure_theta_batched(torch.as_tensor(theta))
event_shape = self.posterior_estimator.input_shape
theta_density_estimator = reshape_to_sample_batch_event(
theta, event_shape, leading_is_sample=True
)
x_density_estimator = reshape_to_batch_event(
x, event_shape=self.posterior_estimator.condition_shape
)

self.posterior_estimator.eval()

with torch.set_grad_enabled(track_gradients):
# Evaluate on device, move back to cpu for comparison with prior.
unnorm_log_prob = self.posterior_estimator.log_prob(
theta_density_estimator, condition=x_density_estimator
)

# Force probability to be zero outside prior support.
in_prior_support = within_support(self.prior, theta)

masked_log_prob = torch.where(
in_prior_support,
unnorm_log_prob,
torch.tensor(float("-inf"), dtype=torch.float32, device=self._device),
)

if leakage_correction_params is None:
leakage_correction_params = dict() # use defaults
log_factor = (
log(self.leakage_correction(x=x, **leakage_correction_params))
if norm_posterior
else 0
)

return masked_log_prob - log_factor
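The masking and leakage correction above force zero probability (a log-probability of -inf) outside the prior support, then renormalize by the estimated leakage. The same logic can be mirrored without torch; the helper name and all values below are illustrative, not sbi's API:

```python
from math import log


def masked_log_prob(unnorm_log_probs, in_support, leakage=1.0, norm_posterior=True):
    """Set log-prob to -inf outside support, then subtract the leakage log-factor."""
    log_factor = log(leakage) if norm_posterior else 0.0
    return [
        (lp - log_factor) if ok else float("-inf")
        for lp, ok in zip(unnorm_log_probs, in_support)
    ]


# Second parameter falls outside the prior support; 80% of mass is inside it.
out = masked_log_prob([-1.2, 0.3, -0.5], [True, False, True], leakage=0.8)
print(out[1])  # -inf for the out-of-support parameter
```

Note that the correction shifts all in-support log-probabilities by the same constant, so it does not change their relative ordering.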

@torch.no_grad()
def leakage_correction(
self,
@@ -240,7 +359,7 @@

def acceptance_at(x: Tensor) -> Tensor:
# [1:] to remove batch-dimension for `reshape_to_batch_event`.
return accept_reject_sample(
return rejection.accept_reject_sample(
proposal=self.posterior_estimator,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_rejection_samples,
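`rejection.accept_reject_sample`, used throughout this file, draws proposals in chunks and keeps only those inside the prior support until enough survive. A torch-free sketch of that loop (function names, signature, and the uniform proposal are illustrative stand-ins, not sbi's actual implementation):

```python
import random


def accept_reject_sample(propose, accept_fn, num_samples, max_sampling_batch_size=10_000):
    """Draw proposals in batches, keep accepted ones, track the acceptance rate."""
    accepted, num_proposed = [], 0
    while len(accepted) < num_samples:
        batch = [propose() for _ in range(max_sampling_batch_size)]
        num_proposed += len(batch)
        accepted.extend(theta for theta in batch if accept_fn(theta))
    acceptance_rate = len(accepted) / num_proposed
    return accepted[:num_samples], acceptance_rate


random.seed(0)
samples, rate = accept_reject_sample(
    propose=lambda: random.uniform(-2.0, 2.0),     # stand-in for the density estimator
    accept_fn=lambda t: -1.0 <= t <= 1.0,          # stand-in for within_support(prior, theta)
    num_samples=100,
    max_sampling_batch_size=500,
)
print(len(samples), 0.0 < rate <= 1.0)
```

In the batched setting, the real implementation additionally tracks one acceptance rate per observation in the condition batch.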
23 changes: 23 additions & 0 deletions sbi/inference/posteriors/ensemble_posterior.py
@@ -179,6 +179,29 @@
)
return torch.vstack(samples).reshape(*sample_shape, -1)

def sample_batched(
self,
sample_shape: Shape,
x: Tensor,
**kwargs,
) -> Tensor:
num_samples = torch.Size(sample_shape).numel()
posterior_indices = torch.multinomial(
self._weights, num_samples, replacement=True
)
samples = []
for posterior_index, sample_size in torch.vstack(
posterior_indices.unique(return_counts=True)
).T:
sample_shape_c = torch.Size((int(sample_size),))
samples.append(
self.posteriors[posterior_index].sample_batched(
sample_shape_c, x=x, **kwargs
)
)
samples = torch.vstack(samples)
return samples.reshape(sample_shape + samples.shape[1:])

def log_prob(
self,
theta: Tensor,
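The ensemble method above draws one posterior index per requested sample via `torch.multinomial`, then issues a single `sample_batched` call per unique index with the corresponding count. The allocation step can be sketched without torch (weights and seed are illustrative):

```python
import random
from collections import Counter


def allocate_samples(weights, num_samples, seed=0):
    """Distribute num_samples across ensemble members, proportional to weights."""
    rng = random.Random(seed)
    indices = rng.choices(range(len(weights)), weights=weights, k=num_samples)
    return Counter(indices)  # posterior index -> number of samples to draw


counts = allocate_samples(weights=[0.5, 0.3, 0.2], num_samples=1000)
print(sum(counts.values()))  # always equals num_samples
```

Because draws are made with replacement, every call allocates exactly `num_samples` in total, and the per-member counts concentrate around `weight * num_samples`.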
13 changes: 13 additions & 0 deletions sbi/inference/posteriors/importance_posterior.py
@@ -194,6 +194,19 @@ def sample(
else:
raise NameError

def sample_batched(
self,
sample_shape: Shape,
x: Tensor,
max_sampling_batch_size: int = 10000,
show_progress_bars: bool = True,
) -> Tensor:
raise NotImplementedError(
"Batched sampling is not implemented for ImportanceSamplingPosterior. \
Alternatively you can use `sample` in a loop \
[posterior.sample(sample_shape, x=x_o) for x_o in x]."
)

def _importance_sample(
self,
sample_shape: Shape = torch.Size(),
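For posteriors without a batched implementation, the error message suggests looping over the observation batch with sequential `sample` calls. A sketch of that fallback (the posterior object and shapes are hypothetical stand-ins):

```python
def sample_per_observation(posterior, sample_shape, xs):
    """Fallback: one sequential `sample` call per observation in the batch."""
    return [posterior.sample(sample_shape, x=x_o) for x_o in xs]


class FakePosterior:
    """Stand-in that records which observation each call conditioned on."""

    def sample(self, sample_shape, x):
        return {"shape": sample_shape, "x": x}


draws = sample_per_observation(FakePosterior(), (10,), [1.0, 2.0])
print(len(draws))  # one result per observation
```

Unlike `sample_batched`, this loop runs one (potentially expensive) sampling pass per observation, which is why the vectorized path added in this PR matters.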
45 changes: 44 additions & 1 deletion sbi/inference/posteriors/mcmc_posterior.py
@@ -352,8 +352,51 @@ def sample(
raise NameError(f"The sampling method {method} is not implemented!")

samples = self.theta_transform.inv(transformed_samples)
# NOTE: Currently MCMCPosteriors will require a single dimension for the
# parameter dimension. With recent ConditionalDensity(Ratio) estimators, we
# can have multiple dimensions for the parameter dimension.
samples = samples.reshape((*sample_shape, -1)) # type: ignore

return samples.reshape((*sample_shape, -1)) # type: ignore
return samples

def sample_batched(
self,
sample_shape: Shape,
x: Tensor,
method: Optional[str] = None,
thin: Optional[int] = None,
warmup_steps: Optional[int] = None,
num_chains: Optional[int] = None,
init_strategy: Optional[str] = None,
init_strategy_parameters: Optional[Dict[str, Any]] = None,
num_workers: Optional[int] = None,
mp_context: Optional[str] = None,
show_progress_bars: bool = True,
) -> Tensor:
r"""Given a batch of observations [x_1, ..., x_B] this function samples from
posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized)
manner.

Check the `__init__()` method for a description of all arguments as well as
their default values.

Args:
sample_shape: Desired shape of samples that are drawn from the posterior
given every observation.
x: A batch of observations, of shape `(batch_dim, event_shape_x)`.
`batch_dim` corresponds to the number of observations to be drawn.
show_progress_bars: Whether to show sampling progress monitor.

Returns:
Samples from the posteriors of shape (*sample_shape, B, *input_shape)
"""

# See #1176 for a discussion on the implementation of batched sampling.
raise NotImplementedError(
"Batched sampling is not implemented for MCMC posterior. \
Alternatively you can use `sample` in a loop \
[posterior.sample(sample_shape, x=x_o) for x_o in x]."
)

def _build_mcmc_init_fn(
self,
13 changes: 13 additions & 0 deletions sbi/inference/posteriors/rejection_posterior.py
@@ -167,6 +167,19 @@

return samples.reshape((*sample_shape, -1))

def sample_batched(
self,
sample_shape: Shape,
x: Tensor,
max_sampling_batch_size: int = 10000,
show_progress_bars: bool = True,
) -> Tensor:
raise NotImplementedError(
"Batched sampling is not implemented for RejectionPosterior. \
Alternatively you can use `sample` in a loop \
[posterior.sample(sample_shape, x=x_o) for x_o in x]."
)

def map(
self,
x: Optional[Tensor] = None,
13 changes: 13 additions & 0 deletions sbi/inference/posteriors/vi_posterior.py
@@ -296,6 +296,19 @@
samples = self.q.sample(torch.Size(sample_shape))
return samples.reshape((*sample_shape, samples.shape[-1]))

def sample_batched(
self,
sample_shape: Shape,
x: Tensor,
max_sampling_batch_size: int = 10000,
show_progress_bars: bool = True,
) -> Tensor:
raise NotImplementedError(
"Batched sampling is not implemented for VIPosterior. \
Alternatively you can use `sample` in a loop \
[posterior.sample(sample_shape, x=x_o) for x_o in x]."
)

def log_prob(
self,
theta: Tensor,
10 changes: 8 additions & 2 deletions sbi/inference/snpe/snpe_a.py
@@ -474,7 +474,8 @@
condition = condition.to(self._device)

if not self._apply_correction:
return self._neural_net.sample(sample_shape, condition=condition)
samples = self._neural_net.sample(sample_shape, condition=condition)
return samples
else:
# When we want to sample from the approx. posterior, a proposal prior
# \tilde{p} has already been observed. To analytically calculate the
@@ -483,7 +484,12 @@
condition_ndim = len(self.condition_shape)
batch_size = condition.shape[:-condition_ndim]
batch_size = torch.Size(batch_size).numel()
return self._sample_approx_posterior_mog(num_samples, condition, batch_size)
samples = self._sample_approx_posterior_mog(
num_samples, condition, batch_size
)
# NOTE: New batching convention: (batch_dim, sample_dim, *event_shape)
samples = samples.transpose(0, 1)
return samples

def _sample_approx_posterior_mog(
self, num_samples, x: Tensor, batch_size: int
Expand Down
2 changes: 2 additions & 0 deletions sbi/neural_nets/density_estimators/nflows_flow.py
@@ -135,6 +135,8 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor:
num_samples = torch.Size(sample_shape).numel()

samples = self.net.sample(num_samples, context=condition)
# Change from Nflows' convention of (batch_dim, sample_dim, *event_shape) to
# (sample_dim, batch_dim, *event_shape) (PyTorch + SBI).
samples = samples.transpose(0, 1)
return samples.reshape((*sample_shape, condition_batch_dim, *self.input_shape))

Expand Down
Loading