Allow prior on gpu #519

Merged (5 commits, Jul 8, 2021)
9 changes: 7 additions & 2 deletions docs/docs/faq/question_04.md
@@ -1,18 +1,23 @@

# Can I use the GPU for training the density estimator?

TLDR; Yes, by passing `device="cuda"`. But no speed-ups for default density estimators.
TLDR; Yes, by passing `device="cuda"` and by passing a prior that lives on the device
you passed. But no speed-ups for default density estimators.

Yes. When creating the inference object in the flexible interface, you can pass the
`device` as an argument, e.g.,

```python
inference = SNPE(simulator, prior, device="cuda", density_estimator="maf")
inference = SNPE(prior, device="cuda", density_estimator="maf")
```

The device is set to `"cpu"` by default, and it can be set to anything, as long as it
maps to an existing PyTorch CUDA device. `sbi` will take care of copying the `net` and
the training data to and from the `device`.
Note that the prior must already be on the training device: when passing `device="cuda:0"`,
make sure to pass a prior object that was created on that device, e.g.,
`prior = torch.distributions.MultivariateNormal(loc=torch.zeros(2, device="cuda:0"),
covariance_matrix=torch.eye(2, device="cuda:0"))`, as in the sketch below.
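
Putting both pieces together, a minimal sketch of the full setup (the `SNPE` import path below is an assumption; adjust it to your `sbi` version):

```python
import torch
from sbi.inference import SNPE  # assumed import path

# Create the prior directly on the CUDA device used for training.
prior = torch.distributions.MultivariateNormal(
    loc=torch.zeros(2, device="cuda:0"),
    covariance_matrix=torch.eye(2, device="cuda:0"),
)

# Pass the matching device string to the inference object.
inference = SNPE(prior, device="cuda:0", density_estimator="maf")
```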

## Performance

4 changes: 2 additions & 2 deletions sbi/analysis/conditional_density.py
@@ -46,8 +46,8 @@ def eval_conditional_density(
eps_margins2: We will evaluate the posterior along `dim2` at
`limits[0]+eps_margins` until `limits[1]-eps_margins`. This avoids
evaluations potentially exactly at the prior bounds.
return_raw_log_prob: If `True`, return the log-probability evaluated on the·
grid. If `False`, return the probability, scaled down by the maximum value·
return_raw_log_prob: If `True`, return the log-probability evaluated on the
grid. If `False`, return the probability, scaled down by the maximum value
on the grid for numerical stability (i.e. exp(log_prob - max_log_prob)).

Returns: Conditional probabilities. If `dim1 == dim2`, this will have shape
25 changes: 5 additions & 20 deletions sbi/inference/base.py
@@ -115,7 +115,7 @@ def __init__(
0.14.0 is more mature, we will remove this argument.
"""

self._device = process_device(device)
self._device = process_device(device, prior=prior)

if unused_args:
warn(
@@ -381,9 +381,7 @@ def _default_summary_writer(self) -> SummaryWriter:

method = self.__class__.__name__
logdir = Path(
get_log_root(),
method,
datetime.now().isoformat().replace(":", "_"),
get_log_root(), method, datetime.now().isoformat().replace(":", "_")
)
return SummaryWriter(logdir)

@@ -437,11 +435,7 @@ def _report_convergence_at_end(
)

def _summarize(
self,
round_: int,
x_o: Union[Tensor, None],
theta_bank: Tensor,
x_bank: Tensor,
self, round_: int, x_o: Union[Tensor, None], theta_bank: Tensor, x_bank: Tensor
) -> None:
"""Update the summary_writer with statistics for a given round.

@@ -456,12 +450,7 @@
# Median |x - x0| for most recent round.
if x_o is not None:
median_observation_distance = torch.median(
torch.sqrt(
torch.sum(
(x_bank - x_o.reshape(1, -1)) ** 2,
dim=-1,
)
)
torch.sqrt(torch.sum((x_bank - x_o.reshape(1, -1)) ** 2, dim=-1))
)
self._summary["median_observation_distances"].append(
median_observation_distance.item()
@@ -558,11 +547,7 @@ def simulate_for_sbi(
theta = proposal.sample((num_simulations,))

x = simulate_in_batches(
simulator,
theta,
simulation_batch_size,
num_workers,
show_progress_bar,
simulator, theta, simulation_batch_size, num_workers, show_progress_bar
)

return theta, x
15 changes: 8 additions & 7 deletions sbi/inference/posteriors/base_posterior.py
@@ -149,7 +149,9 @@ def set_default_x(self, x: Tensor) -> "NeuralPosterior":
Returns:
`NeuralPosterior` that will use a default `x` when not explicitly passed.
"""
self._x = process_x(x, self._x_shape, allow_iid_x=self._allow_iid_x)
self._x = process_x(x, self._x_shape, allow_iid_x=self._allow_iid_x).to(
self._device
)
self._num_iid_trials = self._x.shape[0]

return self
@@ -358,12 +360,10 @@ def _prepare_theta_and_x_for_log_prob_(
self._ensure_single_x(x)
self._ensure_x_consistent_with_default_x(x)

return theta, x
return theta.to(self._device), x.to(self._device)

def _prepare_for_sample(
self,
x: Tensor,
sample_shape: Optional[Tensor],
self, x: Tensor, sample_shape: Optional[Tensor]
) -> Tuple[Tensor, int]:
r"""
Return checked, reshaped, potentially default values for `x` and `sample_shape`.
@@ -835,7 +835,8 @@ def map(
def potential_fn(theta):
return self.log_prob(theta, x=x, track_gradients=True, **log_prob_kwargs)

interruption_note = "The last estimate of the MAP can be accessed via the `posterior.map_` attribute."
interruption_note = """The last estimate of the MAP can be accessed via the
`posterior.map_` attribute."""

self.map_, _ = optimize_potential_fn(
potential_fn=potential_fn,
@@ -1174,7 +1175,7 @@ def np_potential(self, theta: np.ndarray) -> ScalarFloat:
theta_condition = deepcopy(self.condition)
theta_condition[:, self.dims_to_sample] = theta

return self.potential_fn_provider.np_potential(
return self.potential_fn_provider.posterior_potential(
utils.tensor2numpy(theta_condition)
)

79 changes: 29 additions & 50 deletions sbi/inference/posteriors/direct_posterior.py
@@ -82,10 +82,7 @@ def __init__(
device: Training device, e.g., cpu or cuda:0
"""

kwargs = del_entries(
locals(),
entries=("self", "__class__"),
)
kwargs = del_entries(locals(), entries=("self", "__class__"))
super().__init__(**kwargs)

self._purpose = (
@@ -183,17 +180,15 @@ def log_prob(
with torch.set_grad_enabled(track_gradients):

# Evaluate on device, move back to cpu for comparison with prior.
unnorm_log_prob = self.net.log_prob(
theta_repeated.to(self._device), x_repeated.to(self._device)
).cpu()
unnorm_log_prob = self.net.log_prob(theta_repeated, x_repeated)

# Force probability to be zero outside prior support.
in_prior_support = within_support(self._prior, theta)

masked_log_prob = torch.where(
in_prior_support,
unnorm_log_prob,
torch.tensor(float("-inf"), dtype=torch.float32),
torch.tensor(float("-inf"), dtype=torch.float32, device=self._device),
)

if leakage_correction_params is None:
@@ -556,11 +551,7 @@ class PotentialFunctionProvider:
"""

def __call__(
self,
prior,
posterior_nn: nn.Module,
x: Tensor,
method: str,
self, prior, posterior_nn: nn.Module, x: Tensor, method: str
) -> Callable:
"""Return potential function.

@@ -583,59 +574,47 @@ def __call__(
NotImplementedError

def posterior_potential(
self, theta: np.ndarray, track_gradients: bool = False
) -> ScalarFloat:
r"""Return posterior theta log prob. $p(\theta|x)$, $-\infty$ if outside prior."

Args:
theta: Parameters $\theta$, batch dimension 1.
self, theta: Union[Tensor, np.array], track_gradients: bool = False
) -> Tensor:
"Return posterior theta log prob. $p(\theta|x)$, $-\infty$ if outside prior."

Returns:
Posterior log probability $\log(p(\theta|x))$.
"""
theta = torch.as_tensor(theta, dtype=torch.float32)
theta = ensure_theta_batched(theta)
num_batch = theta.shape[0]
# Device is the same for net and prior.
theta = ensure_theta_batched(torch.as_tensor(theta, dtype=torch.float32)).to(
self.device
)

# Repeat x over batch dim to match theta batch, accounting for multi-D x.
x_repeated = self.x.repeat(num_batch, *(1 for _ in range(self.x.ndim - 1)))
theta_repeated, x_repeated = DirectPosterior._match_theta_and_x_batch_shapes(
theta, self.x
)

with torch.set_grad_enabled(track_gradients):
target_log_prob = self.posterior_nn.log_prob(
inputs=theta.to(self.device),
context=x_repeated,
)

# Evaluate on device, move back to cpu for comparison with prior.
posterior_log_prob = self.posterior_nn.log_prob(theta_repeated, x_repeated)

# Force probability to be zero outside prior support.
in_prior_support = within_support(self.prior, theta)
target_log_prob[~in_prior_support] = -float("Inf")

return target_log_prob
posterior_log_prob = torch.where(
in_prior_support,
posterior_log_prob,
torch.tensor(float("-inf"), dtype=torch.float32, device=self.device),
)

return posterior_log_prob

def pyro_potential(
self, theta: Dict[str, Tensor], track_gradients: bool = False
) -> Tensor:
r"""Return posterior log prob. of theta $p(\theta|x)$, -inf where outside prior.
r"""Return posterior theta log prob. $p(\theta|x)$, $-\infty$ if outside prior."

Args:
theta: Parameters $\theta$ (from pyro sampler).

Returns:
Posterior log probability $p(\theta|x)$, masked outside of prior.
Negative posterior log probability $p(\theta|x)$, masked outside of prior.
"""

theta = next(iter(theta.values()))

with torch.set_grad_enabled(track_gradients):
# Notice opposite sign to `posterior_potential`.
# Move theta to device for evaluation.
log_prob_posterior = -self.posterior_nn.log_prob(
inputs=theta.to(self.device),
context=self.x,
).cpu()

in_prior_support = within_support(self.prior, theta)

return torch.where(
in_prior_support,
log_prob_posterior,
float("-inf") * torch.ones_like(log_prob_posterior),
)
return -self.posterior_potential(theta, track_gradients=track_gradients)
51 changes: 15 additions & 36 deletions sbi/inference/posteriors/likelihood_based_posterior.py
@@ -120,7 +120,7 @@ def log_prob(
)

# Move to cpu for comparison with prior.
return log_likelihood_trial_sum.cpu() + self._prior.log_prob(theta)
return log_likelihood_trial_sum + self._prior.log_prob(theta)

def sample(
self,
@@ -355,10 +355,7 @@ def map(

@staticmethod
def _log_likelihoods_over_trials(
x: Tensor,
theta: Tensor,
net: nn.Module,
track_gradients: bool = False,
x: Tensor, theta: Tensor, net: nn.Module, track_gradients: bool = False
) -> Tensor:
r"""Return log likelihoods summed over iid trials of `x`.

@@ -423,11 +420,7 @@ class PotentialFunctionProvider:
"""

def __call__(
self,
prior,
likelihood_nn: nn.Module,
x: Tensor,
method: str,
self, prior, likelihood_nn: nn.Module, x: Tensor, method: str
) -> Callable:
r"""Return potential function for posterior $p(\theta|x)$.

@@ -459,35 +452,23 @@ def __call__(
else:
NotImplementedError

def log_likelihood(self, theta: Tensor, track_gradients: bool = False) -> Tensor:
def posterior_potential(
self, theta: Union[Tensor, np.array], track_gradients: bool = False
) -> Tensor:
"""Return log likelihood of fixed data given a batch of parameters."""

# Device is the same for net and prior.
theta = ensure_theta_batched(torch.as_tensor(theta, dtype=torch.float32)).to(
self.device
)

log_likelihoods = LikelihoodBasedPosterior._log_likelihoods_over_trials(
x=self.x,
theta=ensure_theta_batched(theta).to(self.device),
theta=theta,
net=self.likelihood_nn,
track_gradients=track_gradients,
)

return log_likelihoods

def posterior_potential(
self, theta: np.array, track_gradients: bool = False
) -> ScalarFloat:
r"""Return posterior log prob. of theta $p(\theta|x)$"

Args:
theta: Parameters $\theta$, batch dimension 1.

Returns:
Posterior log probability of the theta, $-\infty$ if impossible under prior.
"""
theta = torch.as_tensor(theta, dtype=torch.float32)

# Notice opposite sign to pyro potential.
return self.log_likelihood(
theta, track_gradients=track_gradients
).cpu() + self.prior.log_prob(theta)
return log_likelihoods + self.prior.log_prob(theta)

def pyro_potential(
self, theta: Dict[str, Tensor], track_gradients: bool = False
@@ -505,7 +486,5 @@ def pyro_potential(

theta = next(iter(theta.values()))

return -(
self.log_likelihood(theta, track_gradients=track_gradients).cpu()
+ self.prior.log_prob(theta)
)
# Note the minus to match the pyro potential function requirements.
return -self.posterior_potential(theta, track_gradients=track_gradients)
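
Both refactors end on the same sign convention: the pyro-facing potential returns the negative of `posterior_potential`, because pyro samplers expect a potential energy (negative log density). A standalone toy sketch of that convention, using hypothetical names not taken from the diff:

```python
import torch

def posterior_potential(theta: torch.Tensor) -> torch.Tensor:
    # Stand-in for the network log prob plus prior-support masking:
    # unnormalized log posterior of a standard-normal toy model.
    return -0.5 * (theta ** 2).sum(dim=-1)

def pyro_potential(theta_dict: dict) -> torch.Tensor:
    # Pyro passes parameters as a dict and expects a potential energy,
    # i.e. the negative log probability, hence the minus sign.
    theta = next(iter(theta_dict.values()))
    return -posterior_potential(theta)

# Usage: a batch of two 3-dimensional parameter sets.
theta = torch.randn(2, 3)
print(pyro_potential({"theta": theta}))
```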