Skip to content

Commit

Permalink
catch NaN standardizing transforms, add err msg.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Mar 1, 2023
1 parent 06f864b commit fbbaf8f
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions sbi/utils/sbiutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,12 @@ def standardizing_net(
not the case, please be sure to use a larger batch."""
)

assert not (
torch.isnan(t_mean).any() or torch.isnan(torch.tensor(t_std)).any(),
), """Training data mean and std for standardizing net must not contain NaNs.
In case you are encoding missing trials with NaNs, consider setting
z_score_x = 'none' to disable z-scoring."""

return Standardize(t_mean, t_std)


Expand Down Expand Up @@ -255,6 +261,12 @@ def handle_invalid_x(
else:
is_valid_x = ones(batch_size, dtype=torch.bool)

assert (
is_valid_x.sum() > 0
), """No valid data entries left after excluding NaNs and Infs. In case you are
encoding missing trials with NaNs consider setting exclude_invalid_x=False and
z_score_x = 'none' to disable z-scoring."""

return is_valid_x, num_nans, num_infs


Expand Down

0 comments on commit fbbaf8f

Please sign in to comment.