Skip to content

Commit

Permalink
ensure within_support returns correct batch shape (#1038)
Browse files Browse the repository at this point in the history
  • Loading branch information
gmoss13 authored Mar 19, 2024
1 parent b3dbc94 commit 1d25046
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions sbi/utils/sbiutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,8 +569,16 @@ def within_support(distribution: Any, samples: Tensor) -> Tensor:
Tensor of bools indicating whether each sample was within the support.
"""
# Try to check using the support property, use log prob method otherwise.
# Before torch v1.7.0, `support.check()` returned bools for every element.
# From v1.8.0 on, it directly considers all dimensions of a sample. E.g.,
# for a single sample in 3D, v1.7.0 would return [[True, True, True]] and
# v1.8.0 would return [True]. However, Pyro distributions would still
# return [[True, True, True]]. This is relevant for `ImproperEmpirical`
# distributions in SBI, which are used in SNPE.
try:
sample_check = distribution.support.check(samples)
if sample_check.shape == samples.shape:
sample_check = torch.all(sample_check, dim=-1)
return sample_check

# Falling back to log prob method of either the NeuralPosterior's net, or of a
Expand Down

0 comments on commit 1d25046

Please sign in to comment.