From 21c65e0846867efd1515536b839edadf2c700ec9 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Fri, 2 Apr 2021 18:18:32 -0500 Subject: [PATCH 1/3] add failing test for scatter --- test/test_tensor.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/test/test_tensor.py b/test/test_tensor.py index 6b4e5323..724edc72 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -14,6 +14,8 @@ import funsor import funsor.ops as ops +from funsor.cnf import Contraction +from funsor.delta import Delta from funsor.domains import Array, Bint, Product, Real, Reals, find_domain from funsor.interpretations import eager, lazy from funsor.tensor import ( @@ -1468,3 +1470,32 @@ def raw_reduction(x, dim=None, keepdims=False, batch_ndims=len(batch_shape)): actual = op(Tensor(data, inputs), dim, keepdims=keepdims) expected = Tensor(raw_reduction(data, dim, keepdims), inputs, dtype) assert_close(actual, expected, rtol=rtol) + + +def test_scatter_substitute(): + expr = Scatter( + ops.logaddexp, + (('_time_states_38', + Number(0, 1),),), + Contraction(ops.null, ops.add, + frozenset(), + (Delta( + (('states', + (Tensor( + np.array(5, dtype=np.int32), + (), + 10), + Number(0.0),),), + ('_PREV_states', + (Tensor( + np.array(4, dtype=np.int32), + (), + 10), + Number(0.0),),),)), + Tensor( + np.array(0.3386716842651367, dtype=np.float32), + (), + 'real'),)), + frozenset()) + + expr(_time_states_38="_time_states") From 16fb14931107583e721f666785f2862369fedc92 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Fri, 2 Apr 2021 20:59:33 -0500 Subject: [PATCH 2/3] fix scatter subs test --- funsor/terms.py | 11 +++++++++ test/test_tensor.py | 54 +++++++++++++++++++++++++++------------------ 2 files changed, 44 insertions(+), 21 deletions(-) diff --git a/funsor/terms.py b/funsor/terms.py index d5d795ff..916756da 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1218,6 +1218,17 @@ def _alpha_convert(self, alpha_subs): reduced_vars = frozenset(alpha_subs.get(var.name, var) for var in reduced_vars) return op, subs, source, reduced_vars + def eager_subs(self, subs): + subs = OrderedDict(subs) + new_subs = [] + for name, sub in self.subs: + if name in subs: + assert isinstance(subs[name], Variable) + new_subs.append((subs[name].name, sub)) + else: + new_subs.append((name, sub)) + return Scatter(self.op, tuple(new_subs), self.source, self.reduced_vars) + class Approximate(Funsor): """ diff --git a/test/test_tensor.py b/test/test_tensor.py index 724edc72..a2a733f6 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -1475,27 +1475,39 @@ def raw_reduction(x, dim=None, keepdims=False, batch_ndims=len(batch_shape)): def test_scatter_substitute(): expr = Scatter( ops.logaddexp, - (('_time_states_38', - Number(0, 1),),), - Contraction(ops.null, ops.add, + ( + ( + "_time_states_38", + Number(0, 1), + ), + ), + Contraction( + ops.null, + ops.add, + frozenset(), + ( + Delta( + ( + ( + "states", + ( + Tensor(np.array(5, dtype=np.int32), (), 10), + Number(0.0), + ), + ), + ( + "_PREV_states", + ( + Tensor(np.array(4, dtype=np.int32), (), 10), + Number(0.0), + ), + ), + ) + ), + Tensor(np.array(0.3386716842651367, dtype=np.float32), (), "real"), + ), + ), frozenset(), - (Delta( - (('states', - (Tensor( - np.array(5, dtype=np.int32), - (), - 10), - Number(0.0),),), - ('_PREV_states', - (Tensor( - np.array(4, dtype=np.int32), - (), - 10), - Number(0.0),),),)), - Tensor( - np.array(0.3386716842651367, dtype=np.float32), - (), - 'real'),)), - frozenset()) + ) expr(_time_states_38="_time_states") From baec1a18ba536f1c1a8724f1f36f163b0726c188 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sun, 4 Apr 2021 23:48:08 -0500 Subject: [PATCH 3/3] by pass Slice issue --- funsor/terms.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/funsor/terms.py b/funsor/terms.py index 916756da..49611b2c 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1222,8 +1222,7 @@ def eager_subs(self, subs): subs = OrderedDict(subs) new_subs = [] for name, sub in self.subs: - if name in subs: - assert isinstance(subs[name], Variable) + if name in subs and isinstance(subs[name], Variable): new_subs.append((subs[name].name, sub)) else: new_subs.append((name, sub))