diff --git a/devito/propagator.py b/devito/propagator.py index 3c325608cd..d7d7acdda9 100644 --- a/devito/propagator.py +++ b/devito/propagator.py @@ -6,6 +6,7 @@ import numpy as np from sympy import Indexed, IndexedBase, symbols from sympy.abc import t, x, y, z +from sympy.utilities.iterables import postorder_traversal import cgen_wrapper as cgen from codeprinter import ccode @@ -560,22 +561,22 @@ def time_substitutions(self, sympy_expr): :param sympy_expr: The Sympy expression to process :returns: The expression after the substitutions """ - if isinstance(sympy_expr, Indexed): - array_term = sympy_expr + subs_dict = {} - if not str(array_term.base.label) in self.save_vars: - raise ValueError("Invalid variable '%s' in sympy expression. Did you add it to the operator's params?" % - str(array_term.base.label)) + for arg in postorder_traversal(sympy_expr): + if isinstance(arg, Indexed): + array_term = arg - if not self.save_vars[str(array_term.base.label)]: - array_term = array_term.xreplace(self.t_replace) + if not str(array_term.base.label) in self.save_vars: + raise ValueError( + "Invalid variable '%s' in sympy expression. Did you add it to the operator's params?" + % str(array_term.base.label) + ) - return array_term - else: - for arg in sympy_expr.args: - sympy_expr = sympy_expr.subs(arg, self.time_substitutions(arg)) + if not self.save_vars[str(array_term.base.label)]: + subs_dict[arg] = array_term.xreplace(self.t_replace) - return sympy_expr + return sympy_expr.xreplace(subs_dict) def add_time_loop_stencil(self, stencil, before=False): """Add a statement either before or after the main spatial loop, but still inside the time loop.