diff --git a/funsor/terms.py b/funsor/terms.py index d5d795ff..49611b2c 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1218,6 +1218,16 @@ def _alpha_convert(self, alpha_subs): reduced_vars = frozenset(alpha_subs.get(var.name, var) for var in reduced_vars) return op, subs, source, reduced_vars + def eager_subs(self, subs): + subs = OrderedDict(subs) + new_subs = [] + for name, sub in self.subs: + if name in subs and isinstance(subs[name], Variable): + new_subs.append((subs[name].name, sub)) + else: + new_subs.append((name, sub)) + return Scatter(self.op, tuple(new_subs), self.source, self.reduced_vars) + class Approximate(Funsor): """ diff --git a/test/test_tensor.py b/test/test_tensor.py index 6b4e5323..a2a733f6 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -14,6 +14,8 @@ import funsor import funsor.ops as ops +from funsor.cnf import Contraction +from funsor.delta import Delta from funsor.domains import Array, Bint, Product, Real, Reals, find_domain from funsor.interpretations import eager, lazy from funsor.tensor import ( @@ -1468,3 +1470,44 @@ def raw_reduction(x, dim=None, keepdims=False, batch_ndims=len(batch_shape)): actual = op(Tensor(data, inputs), dim, keepdims=keepdims) expected = Tensor(raw_reduction(data, dim, keepdims), inputs, dtype) assert_close(actual, expected, rtol=rtol) + + +def test_scatter_substitute(): + expr = Scatter( + ops.logaddexp, + ( + ( + "_time_states_38", + Number(0, 1), + ), + ), + Contraction( + ops.null, + ops.add, + frozenset(), + ( + Delta( + ( + ( + "states", + ( + Tensor(np.array(5, dtype=np.int32), (), 10), + Number(0.0), + ), + ), + ( + "_PREV_states", + ( + Tensor(np.array(4, dtype=np.int32), (), 10), + Number(0.0), + ), + ), + ) + ), + Tensor(np.array(0.3386716842651367, dtype=np.float32), (), "real"), + ), + ), + frozenset(), + ) + + expr(_time_states_38="_time_states")