No default validation of prior args (#31)
jan-matthis authored Nov 12, 2021
1 parent b9781c6 commit e8f9d2b
Showing 8 changed files with 8 additions and 0 deletions.
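
The added line is identical in all eight task files: right after the prior is constructed, it calls set_default_validate_args(False). That method is a static method on torch's Distribution base class, which the Pyro distributions behind pdist inherit, so the call flips the library-wide default for argument and sample validation rather than a per-instance flag. A minimal, self-contained sketch of the effect (not part of the commit), using plain torch.distributions rather than the sbibm tasks themselves:

# Minimal sketch (not from the commit) of what the added call changes;
# pyro.distributions (pdist) inherits set_default_validate_args from
# torch.distributions.Distribution.
import torch
import torch.distributions as dist

# Stand-in for a task prior: a 2-d uniform box treated as one event
# dimension (the torch analogue of Pyro's .to_event(1)).
prior = dist.Independent(
    dist.Uniform(torch.zeros(2), torch.ones(2), validate_args=True), 1
)

# With validation on, evaluating the prior outside its support raises.
try:
    prior.log_prob(torch.tensor([2.0, 0.5]))
except ValueError as err:
    print("validation error:", err)

# The added line flips the class-wide default, so distributions that do
# not pass validate_args explicitly skip the check from now on.
dist.Distribution.set_default_validate_args(False)
prior = dist.Independent(dist.Uniform(torch.zeros(2), torch.ones(2)), 1)
print(prior.log_prob(torch.tensor([2.0, 0.5])))  # -inf, no exception
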
1 change: 1 addition & 0 deletions sbibm/tasks/bernoulli_glm/task.py
@@ -57,6 +57,7 @@ def __init__(self, summary="sufficient"):
 
         self.prior_params = {"loc": torch.zeros((M + 1,)), "precision_matrix": Binv}
         self.prior_dist = pdist.MultivariateNormal(**self.prior_params)
+        self.prior_dist.set_default_validate_args(False)
 
     def get_prior(self) -> Callable:
         def prior(num_samples=1):
1 change: 1 addition & 0 deletions sbibm/tasks/gaussian_linear/task.py
@@ -42,6 +42,7 @@ def __init__(
         }
 
         self.prior_dist = pdist.MultivariateNormal(**self.prior_params)
+        self.prior_dist.set_default_validate_args(False)
 
         self.simulator_params = {
             "precision_matrix": torch.inverse(
1 change: 1 addition & 0 deletions sbibm/tasks/gaussian_linear_uniform/task.py
@@ -42,6 +42,7 @@ def __init__(
         }
 
         self.prior_dist = pdist.Uniform(**self.prior_params).to_event(1)
+        self.prior_dist.set_default_validate_args(False)
 
         self.simulator_params = {
             "precision_matrix": torch.inverse(
1 change: 1 addition & 0 deletions sbibm/tasks/gaussian_mixture/task.py
@@ -42,6 +42,7 @@ def __init__(
         }
 
         self.prior_dist = pdist.Uniform(**self.prior_params).to_event(1)
+        self.prior_dist.set_default_validate_args(False)
 
         self.simulator_params = {
             "mixture_locs_factor": torch.tensor([1.0, 1.0]),
1 change: 1 addition & 0 deletions sbibm/tasks/lotka_volterra/task.py
@@ -81,6 +81,7 @@ def __init__(
             "scale": torch.tensor([sigma_p, sigma_p, sigma_p, sigma_p]),
         }
         self.prior_dist = pdist.LogNormal(**self.prior_params).to_event(1)
+        self.prior_dist.set_default_validate_args(False)
 
         self.u0 = torch.tensor([30.0, 1.0])
         self.tspan = torch.tensor([0.0, days])
1 change: 1 addition & 0 deletions sbibm/tasks/sir/task.py
@@ -85,6 +85,7 @@ def __init__(
             "scale": torch.tensor([0.5, 0.2]),
         }
         self.prior_dist = pdist.LogNormal(**self.prior_params).to_event(1)
+        self.prior_dist.set_default_validate_args(False)
 
         self.u0 = torch.tensor([N - I0 - R0, I0, R0])
         self.tspan = torch.tensor([0.0, days])
1 change: 1 addition & 0 deletions sbibm/tasks/slcp/task.py
@@ -58,6 +58,7 @@ def __init__(self, distractors: bool = False):
             "high": torch.tensor([+3.0 for _ in range(self.dim_parameters)]),
         }
         self.prior_dist = pdist.Uniform(**self.prior_params).to_event(1)
+        self.prior_dist.set_default_validate_args(False)
 
     def get_prior(self) -> Callable:
         def prior(num_samples=1):
1 change: 1 addition & 0 deletions sbibm/tasks/two_moons/task.py
@@ -50,6 +50,7 @@ def __init__(self):
             "high": +prior_bound * torch.ones((self.dim_parameters,)),
         }
         self.prior_dist = pdist.Uniform(**self.prior_params).to_event(1)
+        self.prior_dist.set_default_validate_args(False)
 
         self.simulator_params = {
             "a_low": -math.pi / 2.0,
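
Within sbibm itself, the practical effect is that a task prior can now be evaluated on parameter values outside its support without torch raising. A hedged usage sketch; it assumes the get_task entry point documented in the sbibm README, and the printed value is illustrative:

# Usage sketch under the assumptions above (get_task / prior_dist attribute);
# the specific numbers are illustrative and not taken from the commit.
import torch
import sbibm

task = sbibm.get_task("two_moons")   # prior_dist is a uniform box over theta
theta = torch.tensor([[10.0, 0.0]])  # far outside the prior support

# With default validation disabled in the task's __init__, this evaluates to
# -inf for the out-of-support sample instead of raising a ValueError.
print(task.prior_dist.log_prob(theta))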
