Skip to content

Commit

Permalink
Add options to docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelgloeckler authored and michaeldeistler committed Aug 27, 2024
1 parent 84bbf85 commit cb6adff
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 5 deletions.
3 changes: 2 additions & 1 deletion sbi/inference/npse/npse.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def __init__(
Args:
prior: Prior distribution.
score_estimator: Neural network architecture for the score estimator. Can be
a string (e.g. 'mlp') or a callable that returns a neural network.
a string (e.g. 'mlp' or 'ada_mlp') or a callable that returns a neural
network.
sde_type: Type of SDE to use. Must be one of ['vp', 've', 'subvp'].
device: Device to run the training on.
logging_level: Logging level for the training. Can be an integer or a
Expand Down
3 changes: 1 addition & 2 deletions sbi/inference/potentials/score_based_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def gradient(
raise NotImplementedError(
"Score accumulation for IID data is not yet implemented."
)

return score

def get_continuous_normalizing_flow(
Expand Down Expand Up @@ -229,4 +229,3 @@ def f(t, x):
exact=exact,
)
return transform

1 change: 0 additions & 1 deletion sbi/samplers/score/predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,3 @@ def predict(self, theta: Tensor, t1: Tensor, t0: Tensor):
f_backward = f - (1 + self.eta**2) / 2 * g**2 * score
g_backward = self.eta * g
return theta - f_backward * dt + g_backward * torch.randn_like(theta) * dt_sqrt

1 change: 0 additions & 1 deletion tests/score_samplers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def _build_gaussian_score_estimator(
# Note the precondition predicts a correct Gaussian score by default if the neural
# net predicts 0!
class DummyNet(torch.nn.Module):

def __init__(self):
super().__init__()
self.dummy_param_for_device_detection = torch.nn.Linear(1, 1)
Expand Down

0 comments on commit cb6adff

Please sign in to comment.