Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions devito/propagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
from sympy import Indexed, IndexedBase, symbols
from sympy.abc import t, x, y, z
from sympy.utilities.iterables import postorder_traversal

import cgen_wrapper as cgen
from codeprinter import ccode
Expand Down Expand Up @@ -560,22 +561,22 @@ def time_substitutions(self, sympy_expr):
:param sympy_expr: The Sympy expression to process
:returns: The expression after the substitutions
"""
if isinstance(sympy_expr, Indexed):
array_term = sympy_expr
subs_dict = {}

if not str(array_term.base.label) in self.save_vars:
raise ValueError("Invalid variable '%s' in sympy expression. Did you add it to the operator's params?" %
str(array_term.base.label))
for arg in postorder_traversal(sympy_expr):
if isinstance(arg, Indexed):
array_term = arg

if not self.save_vars[str(array_term.base.label)]:
array_term = array_term.xreplace(self.t_replace)
if not str(array_term.base.label) in self.save_vars:
raise ValueError(
"Invalid variable '%s' in sympy expression. Did you add it to the operator's params?"
% str(array_term.base.label)
)

return array_term
else:
for arg in sympy_expr.args:
sympy_expr = sympy_expr.subs(arg, self.time_substitutions(arg))
if not self.save_vars[str(array_term.base.label)]:
subs_dict[arg] = array_term.xreplace(self.t_replace)

return sympy_expr
return sympy_expr.xreplace(subs_dict)

def add_time_loop_stencil(self, stencil, before=False):
"""Add a statement either before or after the main spatial loop, but still inside the time loop.
Expand Down