Skip to content

Commit

Permalink
Optimize timesteppers to avoid early RHS call and extra transforms, c…
Browse files Browse the repository at this point in the history
…aught by Daniel
  • Loading branch information
kburns committed Aug 11, 2020
1 parent cc0373e commit 0ce1482
Showing 1 changed file with 35 additions and 28 deletions.
63 changes: 35 additions & 28 deletions dedalus/core/timesteppers.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,6 @@ def step(self, solver, dt):
a, b, c = self.compute_coefficients(self.dt, self._iteration)
self._iteration += 1

# Run evaluator
state.scatter()
evaluator.evaluate_scheduled(**evaluator_kw)

# Update RHS components and LHS matrices
MX.rotate()
LX.rotate()
Expand All @@ -124,34 +120,40 @@ def step(self, solver, dt):
update_LHS = ((a0, b0) != self._LHS_params)
self._LHS_params = (a0, b0)

# Update MX0, LX0, F0
# Evaluate M.X0 and L.X0
MX0.data.fill(0)
LX0.data.fill(0)
F0.data.fill(0)
for p in pencils:
x = state.get_pencil(p)
fast_csr_matvec(p.M, x, MX0.get_pencil(p))
fast_csr_matvec(p.L, x, LX0.get_pencil(p))
if update_LHS:
# Remove old solver reference
p.LHS_solver = None

# Run evaluator and compute F0
# No need to scatter since gather occurred just before step was called
evaluator.evaluate_scheduled(**evaluator_kw)
F0.data.fill(0)
for p in pencils:
fast_csr_matvec(p.pre_left, solver.F.get_pencil(p), F0.get_pencil(p))

# Build RHS
RHS.data.fill(0)
for j in range(1, len(c)):
RHS.data += c[j] * F[j-1].data
np.multiply(c[1], F0.data, out=RHS.data)
for j in range(2, len(c)):
RHS.data += c[j] * F[j-1].data # CREATES TEMPORARY
for j in range(1, len(a)):
RHS.data -= a[j] * MX[j-1].data
RHS.data -= a[j] * MX[j-1].data # CREATES TEMPORARY
for j in range(1, len(b)):
RHS.data -= b[j] * LX[j-1].data
RHS.data -= b[j] * LX[j-1].data # CREATES TEMPORARY

# Solve
state.data.fill(0)
for p in pencils:
pRHS = RHS.get_pencil(p)
if update_LHS:
np.copyto(p.LHS.data, a0*p.M_exp.data + b0*p.L_exp.data)
# Remove old solver reference before building new solver
p.LHS_solver = None
np.copyto(p.LHS.data, a0*p.M_exp.data + b0*p.L_exp.data) # CREATES TEMPORARY
p.LHS_solver = solver.matsolver(p.LHS, solver)
pRHS = RHS.get_pencil(p)
pX = p.LHS_solver.solve(pRHS)
if p.pre_right is None:
state.set_pencil(p, pX)
Expand Down Expand Up @@ -546,41 +548,46 @@ def step(self, solver, dt):
for p in pencils:
fast_csr_matvec(p.M, state.get_pencil(p), MX0.get_pencil(p))
if update_LHS:
# Remove old solver references
p.LHS_solvers = [None] * (self.stages+1)

# Compute stages
# (M + k Hii L).X(n,i) = M.X(n,0) + k Aij F(n,j) - k Hij L.X(n,j)
for i in range(1, self.stages+1):
# Compute L.X(n,i-1)
LXi = LX[i-1]
LXi.data.fill(0)
for p in pencils:
fast_csr_matvec(p.L, state.get_pencil(p), LXi.get_pencil(p))

# Compute F(n,i-1), L.X(n,i-1)
state.scatter()
# Compute F(n,i-1)
evaluator_kw['sim_time'] = solver.sim_time
if i == 1:
# No need to scatter since gather occurred just before step was called
evaluator.evaluate_scheduled(**evaluator_kw)
else:
state.scatter()
evaluator.evaluate_group('F', **evaluator_kw)
LX[i-1].data.fill(0)
F[i-1].data.fill(0)
Fi = F[i-1]
Fi.data.fill(0)
for p in pencils:
fast_csr_matvec(p.L, state.get_pencil(p), LX[i-1].get_pencil(p))
fast_csr_matvec(p.pre_left, solver.F.get_pencil(p), F[i-1].get_pencil(p))
fast_csr_matvec(p.pre_left, solver.F.get_pencil(p), Fi.get_pencil(p))

# Construct RHS(n,i)
np.copyto(RHS.data, MX0.data)
for j in range(0, i):
RHS.data += (k * A[i,j]) * F[j].data
RHS.data -= (k * H[i,j]) * LX[j].data
RHS.data += (k * A[i,j]) * F[j].data # CREATES TEMPORARY
RHS.data -= (k * H[i,j]) * LX[j].data # CREATES TEMPORARY

# Solve for stage
state.data.fill(0)
for p in pencils:
pRHS = RHS.get_pencil(p)
# Construct LHS(n,i)
if update_LHS:
np.copyto(p.LHS.data, p.M_exp.data + (k*H[i,i])*p.L_exp.data)
# Remove old solver reference before building new solver
p.LHS_solvers[i] = None
np.copyto(p.LHS.data, p.M_exp.data + (k*H[i,i])*p.L_exp.data) # CREATES TEMPORARY
p.LHS_solvers[i] = solver.matsolver(p.LHS, solver)
pX = p.LHS_solvers[i].solve(pRHS)
pRHS = RHS.get_pencil(p)
pX = p.LHS_solvers[i].solve(pRHS) # CREATES TEMPORARY
if p.pre_right is None:
state.set_pencil(p, pX)
else:
Expand Down

0 comments on commit 0ce1482

Please sign in to comment.