Skip to content

Commit

Permalink
ENH: replace task based lss moment iteration with Iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
sglyon committed Feb 7, 2017
1 parent e61ecc0 commit fd303c4
Showing 1 changed file with 36 additions and 27 deletions.
63 changes: 36 additions & 27 deletions src/lss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,27 @@ end
replicate(lss::LSS; t::Integer=10, num_reps::Integer=100) =
replicate(lss, t, num_reps)

immutable LSSMoments
lss::LSS
end

Base.start(L::LSSMoments) = (copy(L.lss.mu_0), copy(L.lss.Sigma_0))
Base.done(L::LSSMoments, _) = false
function Base.next(L::LSSMoments, moms)
A, C, G = L.lss.A, L.lss.C, L.lss.G
mu_x, Sigma_x = moms

mu_y, Sigma_y = G * mu_x, G * Sigma_x * G'

# Update moments of x
mu_x2 = A * mu_x
Sigma_x2 = A * Sigma_x * A' + C * C'

(mu_x, mu_y, Sigma_x, Sigma_y), (mu_x2, Sigma_x2)
end

"""
Create a generator to calculate the population mean and
Create an iterator to calculate the population mean and
variance-convariance matrix for both x_t and y_t, starting at
the initial condition (self.mu_0, self.Sigma_0). Each iteration
produces a 4-tuple of items (mu_x, mu_y, Sigma_x, Sigma_y) for
Expand All @@ -165,18 +184,8 @@ the next period.
- `lss::LSS` An instance of the Gaussian linear state space model
"""
function moment_sequence(lss::LSS)
A, C, G = lss.A, lss.C, lss.G
mu_x, Sigma_x = copy(lss.mu_0), copy(lss.Sigma_0)
while true
mu_y, Sigma_y = G * mu_x, G * Sigma_x * G'
produce((mu_x, mu_y, Sigma_x, Sigma_y))

# Update moments of x
mu_x = A * mu_x
Sigma_x = A * Sigma_x * A' + C * C'
end
end
moment_sequence(lss::LSS) = LSSMoments(lss)


"""
Compute the moments of the stationary distributions of x_t and
Expand All @@ -199,22 +208,22 @@ initial conditions lss.mu_0 and lss.Sigma_0
"""
function stationary_distributions(lss::LSS; max_iter=200, tol=1e-5)
# Initialize iteration
m = @task moment_sequence(lss)
mu_x, mu_y, Sigma_x, Sigma_y = consume(m)
m = moment_sequence(lss)
mu_x, mu_y, Sigma_x, Sigma_y = first(m)

i = 0
err = tol + 1.

while err > tol
if i > max_iter
error("Convergence failed after $i iterations")
else
i += 1
mu_x1, mu_y, Sigma_x1, Sigma_y = consume(m)
err_mu = Base.maxabs(mu_x1 - mu_x)
err_Sigma = Base.maxabs(Sigma_x1 - Sigma_x)
err = max(err_Sigma, err_mu)
mu_x, Sigma_x = mu_x1, Sigma_x1
err = tol + 1.0

for (mu_x1, mu_y, Sigma_x1, Sigma_y) in m
i > max_iter && error("Convergence failed after $i iterations")
i += 1
err_mu = maximum(abs, mu_x1 - mu_x)
err_Sigma = maximum(abs, Sigma_x1 - Sigma_x)
err = max(err_Sigma, err_mu)
mu_x, Sigma_x = mu_x1, Sigma_x1

if err < tol && i > 1
break
end
end

Expand Down

0 comments on commit fd303c4

Please sign in to comment.