Skip to content

Commit

Permalink
fix #1260: include points in plotting limits
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Sep 6, 2024
1 parent becc93c commit 15bf836
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 25 deletions.
68 changes: 46 additions & 22 deletions sbi/analysis/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,42 +554,66 @@ def handle_nan_infs(samples: List[np.ndarray]) -> List[np.ndarray]:
return samples


def convert_to_list_of_numpy(
arr: Union[List[np.ndarray], List[torch.Tensor], np.ndarray, torch.Tensor],
) -> List[np.ndarray]:
"""Converts a list of torch.Tensor to a list of np.ndarray."""
if not isinstance(arr, list):
arr = ensure_numpy(arr)
return [arr]
return [ensure_numpy(a) for a in arr]


def infer_limits(
samples: List[np.ndarray],
dim: int,
points: Optional[List[np.ndarray]] = None,
eps: float = 0.1,
) -> List:
"""Infer limits for the plot.
Args:
samples: List of samples.
dim: Dimension of the samples.
points: List of points.
eps: Relative margin for the limits.
"""
limits = []
for d in range(dim):
min_val = min(np.min(sample[:, d]) for sample in samples)
max_val = max(np.max(sample[:, d]) for sample in samples)
if points is not None:
min_val = min(min_val, min(np.min(point[:, d]) for point in points))
max_val = max(max_val, max(np.max(point[:, d]) for point in points))
limits.append([min_val * (1 + eps), max_val * (1 + eps)])
return limits


def prepare_for_plot(
samples: Union[List[np.ndarray], List[torch.Tensor], np.ndarray, torch.Tensor],
limits: Optional[Union[List, torch.Tensor, np.ndarray]],
limits: Optional[Union[List, torch.Tensor, np.ndarray]] = None,
points: Optional[
Union[List[np.ndarray], List[torch.Tensor], np.ndarray, torch.Tensor]
] = None,
) -> Tuple[List[np.ndarray], int, torch.Tensor]:
"""
Ensures correct formatting for samples and limits, and returns dimension
of the samples.
"""

# Prepare samples
if not isinstance(samples, list):
samples = ensure_numpy(samples)
samples = [samples]
else:
samples = [ensure_numpy(sample) for sample in samples]
samples = convert_to_list_of_numpy(samples)
if points is not None:
points = convert_to_list_of_numpy(points)

# check if nans and infs
samples = handle_nan_infs(samples)

# Dimensionality of the problem.
dim = samples[0].shape[1]

# Prepare limits. Infer them from samples if they had not been passed.
if limits == [] or limits is None:
limits = []
for d in range(dim):
min = +np.inf
max = -np.inf
for sample in samples:
min_ = np.min(sample[:, d])
min = min_ if min_ < min else min
max_ = np.max(sample[:, d])
max = max_ if max_ > max else max
limits.append([min, max])
if limits is None or limits == []:
limits = infer_limits(samples, dim, points)
else:
limits = [limits[0] for _ in range(dim)] if len(limits) == 1 else limits

limits = torch.as_tensor(limits)
return samples, dim, limits

Expand Down Expand Up @@ -737,7 +761,7 @@ def pairplot(
)
return fig, axes

samples, dim, limits = prepare_for_plot(samples, limits)
samples, dim, limits = prepare_for_plot(samples, limits, points)

# prepate figure kwargs
fig_kwargs_filled = _get_default_fig_kwargs()
Expand Down
6 changes: 3 additions & 3 deletions tests/plot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


@pytest.mark.parametrize("samples", (torch.randn(100, 1),))
@pytest.mark.parametrize("limits", ([(-1, 1)],))
@pytest.mark.parametrize("limits", ([(-1, 1)], None))
def test_pairplot1D(samples, limits):
fig, axs = pairplot(**{k: v for k, v in locals().items() if v is not None})
assert isinstance(fig, Figure)
Expand All @@ -24,7 +24,7 @@ def test_pairplot1D(samples, limits):


@pytest.mark.parametrize("samples", (torch.randn(100, 2),))
@pytest.mark.parametrize("limits", ([(-1, 1)],))
@pytest.mark.parametrize("limits", ([(-1, 1)], None))
def test_nan_inf(samples, limits):
samples[0, 0] = np.nan
samples[5, 1] = np.inf
Expand All @@ -37,7 +37,7 @@ def test_nan_inf(samples, limits):

@pytest.mark.parametrize("samples", (torch.randn(100, 2), [torch.randn(100, 3)] * 2))
@pytest.mark.parametrize("points", (torch.ones(1, 3),))
@pytest.mark.parametrize("limits", ([(-3, 3)],))
@pytest.mark.parametrize("limits", ([(-3, 3)], None))
@pytest.mark.parametrize("subset", (None, [0, 1]))
@pytest.mark.parametrize("upper", ("scatter",))
@pytest.mark.parametrize(
Expand Down

0 comments on commit 15bf836

Please sign in to comment.