Skip to content

Commit

Permalink
Use unit normal as default init_dist in GaussianRandomWalk and AR
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed May 18, 2022
1 parent 862bd05 commit b678bc4
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
9 changes: 4 additions & 5 deletions pymc/distributions/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ class GaussianRandomWalk(distribution.Continuous):
sigma > 0, innovation standard deviation, defaults to 1.0
init : unnamed distribution
Univariate distribution of the initial value, created with the `.dist()` API.
Defaults to Normal with same `mu` and `sigma` as the GaussianRandomWalk
Defaults to a unit Normal.
.. warning:: init will be cloned, rendering them independent of the ones passed as input.
Expand Down Expand Up @@ -265,7 +265,7 @@ def dist(

# If no scalar distribution is passed then initialize with a Normal of same mu and sigma
if init is None:
init = Normal.dist(mu, sigma)
init = Normal.dist(0, 1)
else:
if not (
isinstance(init, at.TensorVariable)
Expand Down Expand Up @@ -361,7 +361,7 @@ class AR(SymbolicDistribution):
Whether the first element of rho should be used as a constant term in the AR
process. Defaults to False
init_dist: unnamed distribution, optional
Scalar or vector distribution for initial values. Defaults to Normal(0, sigma).
Scalar or vector distribution for initial values. Defaults to a unit Normal.
Distribution should be created via the `.dist()` API, and have dimension
(*size, ar_order). If not, it will be automatically resized.
Expand Down Expand Up @@ -452,8 +452,7 @@ def dist(
f"got ndim_supp={init_dist.owner.op.ndim_supp}.",
)
else:
# Sigma must broadcast with ar_order
init_dist = Normal.dist(sigma=at.shape_padright(sigma), size=(*sigma.shape, ar_order))
init_dist = Normal.dist(0, 1, size=(*sigma.shape, ar_order))

# Tell Aeppl to ignore init_dist, as it will be accounted for in the logp term
init_dist = ignore_logprob(init_dist)
Expand Down
4 changes: 3 additions & 1 deletion pymc/tests/test_distributions_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ def test_batched_sigma(self):
"y",
beta_tp,
sigma=sigma,
init_dist=Normal.dist(0, sigma[..., None]),
size=batch_size,
steps=steps,
initval=y_tp,
Expand All @@ -346,6 +347,7 @@ def test_batched_sigma(self):
f"y_{i}{j}",
beta_tp,
sigma=sigma[i][j],
init_dist=Normal.dist(0, sigma[i][j]),
shape=steps,
initval=y_tp[i, j],
ar_order=ar_order,
Expand All @@ -371,7 +373,7 @@ def test_batched_init_dist(self):
beta_tp = aesara.shared(np.random.randn(ar_order), shape=(3,))
y_tp = np.random.randn(batch_size, steps)
with Model() as t0:
init_dist = Normal.dist(0.0, 0.01, size=(batch_size, ar_order))
init_dist = Normal.dist(0.0, 1.0, size=(batch_size, ar_order))
AR("y", beta_tp, sigma=0.01, init_dist=init_dist, steps=steps, initval=y_tp)
with Model() as t1:
for i in range(batch_size):
Expand Down

0 comments on commit b678bc4

Please sign in to comment.