bugfix for device handling in SNRE
michaeldeistler authored and janfb committed May 18, 2021
1 parent 9ea4b93 commit 1fd78b5
Showing 2 changed files with 8 additions and 5 deletions.
2 changes: 1 addition & 1 deletion sbi/inference/posteriors/likelihood_based_posterior.py
@@ -322,7 +322,7 @@ def _log_likelihoods_over_trials(
), "x and theta must match in batch shape."
assert (
next(net.parameters()).device == x.device and x.device == theta.device
), f"device mismatch: net, x, theta: {net.device}, {x.decive}, {theta.device}."
), f"device mismatch: net, x, theta: {next(net.parameters()).device}, {x.decive}, {theta.device}."

# Calculate likelihood in one batch.
with torch.set_grad_enabled(track_gradients):
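The fix above replaces `net.device` in the assertion message: `torch.nn.Module` has no `.device` attribute, so the old f-string would itself raise an `AttributeError` whenever the devices actually mismatched. A minimal standalone sketch of the corrected idiom (toy module and tensors, not the repository's code):

    import torch
    from torch import nn

    net = nn.Linear(3, 1)    # parameters live on some device (CPU here)
    x = torch.zeros(2, 3)
    theta = torch.zeros(2, 3)

    # nn.Module has no .device attribute; query a parameter instead.
    net_device = next(net.parameters()).device

    assert (
        net_device == x.device and x.device == theta.device
    ), f"device mismatch: net, x, theta: {net_device}, {x.device}, {theta.device}."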
11 changes: 7 additions & 4 deletions sbi/inference/posteriors/ratio_based_posterior.py
@@ -92,7 +92,10 @@ def log_prob(

     # Sum log ratios over x batch of iid trials.
     log_ratio = self._log_ratios_over_trials(
-        x, theta, self.net, track_gradients=track_gradients
+        x.to(self._device),
+        theta.to(self._device),
+        self.net,
+        track_gradients=track_gradients,
     )

     return log_ratio.cpu() + self._prior.log_prob(theta)
@@ -352,7 +355,7 @@ def _log_ratios_over_trials(
), "x and theta must match in batch shape."
assert (
next(net.parameters()).device == x.device and x.device == theta.device
), f"device mismatch: net, x, theta: {net.device}, {x.decive}, {theta.device}."
), f"device mismatch: net, x, theta: {next(net.parameters()).device}, {x.device}, {theta.device}."

# Calculate ratios in one batch.
with torch.set_grad_enabled(track_gradients):
@@ -429,7 +432,7 @@ def np_potential(self, theta: np.array) -> ScalarFloat:
     theta = ensure_theta_batched(theta)

     log_ratio = RatioBasedPosterior._log_ratios_over_trials(
-        self.x, theta, self.classifier, track_gradients=False
+        self.x, theta.to(self.device), self.classifier, track_gradients=False
     )

     # Notice opposite sign to pyro potential.
@@ -455,7 +458,7 @@ def pyro_potential(self, theta: Dict[str, Tensor]) -> Tensor:
     theta = ensure_theta_batched(theta)

     log_ratio = RatioBasedPosterior._log_ratios_over_trials(
-        self.x, theta, self.classifier, track_gradients=False
+        self.x, theta.to(self.device), self.classifier, track_gradients=False
     )

     return -(log_ratio.cpu() + self.prior.log_prob(theta))
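The ratio-based changes all follow one pattern: move the inputs onto the network's device before the forward pass, then bring the result back to the CPU before combining it with CPU-side quantities such as the prior's log-prob. A minimal standalone sketch of that pattern (illustrative classifier and shapes, not the repository's API):

    import torch
    from torch import nn

    device = "cuda" if torch.cuda.is_available() else "cpu"
    classifier = nn.Linear(4, 1).to(device)  # stand-in for the SNRE classifier

    theta = torch.randn(10, 3)               # parameter batch, possibly on CPU
    x = torch.randn(10, 1)                    # matching observations

    # Move inputs to the network's device before evaluating ...
    logits = classifier(torch.cat([theta, x], dim=1).to(device))

    # ... and return to CPU before adding CPU-based terms (e.g. a prior log-prob).
    prior = torch.distributions.MultivariateNormal(torch.zeros(3), torch.eye(3))
    unnormalized_log_posterior = logits.squeeze(-1).cpu() + prior.log_prob(theta)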
