Skip to content

Commit

Permalink
bugfix for device handling in SNRE
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed May 18, 2021
1 parent e06fa36 commit 11a0c1f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion sbi/inference/posteriors/likelihood_based_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def _log_likelihoods_over_trials(
), "x and theta must match in batch shape."
assert (
next(net.parameters()).device == x.device and x.device == theta.device
), f"device mismatch: net, x, theta: {net.device}, {x.decive}, {theta.device}."
), f"device mismatch: net, x, theta: {next(net.parameters()).device}, {x.decive}, {theta.device}."

# Calculate likelihood in one batch.
with torch.set_grad_enabled(track_gradients):
Expand Down
13 changes: 8 additions & 5 deletions sbi/inference/posteriors/ratio_based_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ def log_prob(

# Sum log ratios over x batch of iid trials.
log_ratio = self._log_ratios_over_trials(
x, theta, self.net, track_gradients=track_gradients
x.to(self._device),
theta.to(self._device),
self.net,
track_gradients=track_gradients,
)

return log_ratio.cpu() + self._prior.log_prob(theta)
Expand Down Expand Up @@ -416,7 +419,7 @@ def _log_ratios_over_trials(
), "x and theta must match in batch shape."
assert (
next(net.parameters()).device == x.device and x.device == theta.device
), f"device mismatch: net, x, theta: {net.device}, {x.decive}, {theta.device}."
), f"device mismatch: net, x, theta: {next(net.parameters()).device}, {x.device}, {theta.device}."

# Calculate ratios in one batch.
with torch.set_grad_enabled(track_gradients):
Expand Down Expand Up @@ -497,7 +500,7 @@ def rejection_potential(self, theta: np.array) -> ScalarFloat:
theta = ensure_theta_batched(theta)

log_ratio = RatioBasedPosterior._log_ratios_over_trials(
self.x, theta, self.classifier, track_gradients=True
self.x, theta.to(self.device), self.classifier, track_gradients=True
)

# Notice opposite sign to pyro potential.
Expand All @@ -518,7 +521,7 @@ def np_potential(self, theta: np.array) -> ScalarFloat:
theta = ensure_theta_batched(theta)

log_ratio = RatioBasedPosterior._log_ratios_over_trials(
self.x, theta, self.classifier, track_gradients=False
self.x, theta.to(self.device), self.classifier, track_gradients=False
)

# Notice opposite sign to pyro potential.
Expand All @@ -544,7 +547,7 @@ def pyro_potential(self, theta: Dict[str, Tensor]) -> Tensor:
theta = ensure_theta_batched(theta)

log_ratio = RatioBasedPosterior._log_ratios_over_trials(
self.x, theta, self.classifier, track_gradients=True
self.x, theta.to(self.device), self.classifier, track_gradients=True
)

return -(log_ratio.cpu() + self.prior.log_prob(theta))

0 comments on commit 11a0c1f

Please sign in to comment.