Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a Delta -> Scatter pattern #531

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions funsor/jax/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,9 @@ def _safesub(x, y):

@ops.scatter.register(array, tuple, array)
def _scatter(dest, indices, src):
missing = len(indices) - len(dest.shape)
if missing > 0:
dest = dest[(None,) * missing]
return index_update(dest, indices, src)


Expand Down
22 changes: 21 additions & 1 deletion funsor/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from funsor.interpretations import eager, moment_matching, normalize
from funsor.ops import AssociativeOp
from funsor.tensor import Tensor, align_tensor
from funsor.terms import Funsor, Independent, Number, Reduce, Unary
from funsor.terms import Funsor, Independent, Number, Reduce, Scatter, Unary
from funsor.typing import Variadic


Expand Down Expand Up @@ -74,6 +74,26 @@ def eager_cat_homogeneous(name, part_name, *parts):
return result


# FIXME this is too aggressive, but does fix some numpyro tests in
# motivated by https://github.com/pyro-ppl/numpyro/pull/991
@eager.register(
Contraction,
AssociativeOp,
AssociativeOp,
frozenset,
Delta[Tuple[Tuple[str, Tuple[Tensor, Number]], ...]],
Tensor,
)
def eager_delta_to_scatter(red_op, bin_op, reduced_vars, delta, tensor):
source = tensor
subs = {}
for name, (point, log_density) in delta.terms:
subs[name] = point
source = bin_op(source, log_density)
subs = tuple(subs.items())
return Scatter(red_op, subs, source, reduced_vars)


#################################
# patterns for moment-matching
#################################
Expand Down
4 changes: 3 additions & 1 deletion funsor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import functools
import itertools
import math
import typing
import warnings
from collections import Counter, OrderedDict
Expand Down Expand Up @@ -636,6 +637,7 @@ def eager_scatter_number(op, subs, source, reduced_vars):
return eager_scatter_tensor(op, subs, source, reduced_vars)


# FIXME Does this blow out one-hot tensors, using unnecessarily much memory?
@eager.register(Scatter, Op, tuple, Tensor, frozenset)
def eager_scatter_tensor(op, subs, source, reduced_vars):
if not all(isinstance(v, (Variable, Number, Slice, Tensor)) for k, v in subs):
Expand Down Expand Up @@ -676,7 +678,7 @@ def eager_scatter_tensor(op, subs, source, reduced_vars):
# Construct a destination backend tensor.
output = source.output
shape = tuple(d.size for d in destin_inputs.values()) + output.shape
destin = ops.new_full(source.data, shape, ops.UNITS[op])
destin = ops.new_full(source.data, shape, ops.UNITS.get(op, math.nan))

# TODO Add a check for injectivity and dispatch to scatter_add etc.
data = ops.scatter(destin, indices, source_data)
Expand Down
292 changes: 286 additions & 6 deletions test/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,12 +1475,7 @@ def raw_reduction(x, dim=None, keepdims=False, batch_ndims=len(batch_shape)):
def test_scatter_substitute():
expr = Scatter(
ops.logaddexp,
(
(
"_time_states_38",
Number(0, 1),
),
),
(("_time_states_38", Number(0, 1))),
Contraction(
ops.null,
ops.add,
Expand Down Expand Up @@ -1511,3 +1506,288 @@ def test_scatter_substitute():
)

expr(_time_states_38="_time_states")


# motivated by https://github.com/pyro-ppl/numpyro/pull/991
def test_scatter_dims_error():
op = ops.null
subs = (('_drop_0',
Tensor(
np.array([[[0, 0], [0, 0]]], dtype=np.int32),
(('_time_states', Bint[1]),
('_PREV_states', Bint[2]),
('states', Bint[2])),
2),),)
# This source is invalid because the input _drop_0 should have been
# substituted away by the above subs.
source = Tensor(
np.array([[[-1.7461166381835938, -3.480717658996582],
[-5.133678436279297, -3.3990774154663086]]], dtype=np.float32),
(('_time_states', Bint[1]),
('states', Bint[2]),
('_drop_0', Bint[2])),
'real')
reduced_vars = frozenset()

with pytest.raises(Exception):
Scatter(op, subs, source, reduced_vars)


# motivated by https://github.com/pyro-ppl/numpyro/pull/991
def test_infer_discrete_hmm_scan_1():
from math import inf
actual = Scatter(
ops.max,
(('_time_states', Slice('_time_states__BOUND_21', 1, 2, 2, 2)),
('_PREV_states', Variable('_drop_0__BOUND_20', Bint[2]))),
Contraction(ops.max, ops.add,
frozenset({Variable('_PREV_states__BOUND_14', Bint[2])}),
(Delta(
(('_drop_0__BOUND_20',
(Tensor(
np.array([[[0, 0], [0, 0]]], dtype=np.int32),
(('_time_states__BOUND_21', Bint[1]),
('_PREV_states__BOUND_14', Bint[2]),
('states', Bint[2])),
2),
Number(0.0),),),)),
Tensor(
np.array([[[-inf, -inf], [-inf, -inf]]], dtype=np.float32),
(('_time_states__BOUND_21', Bint[1]),
('_PREV_states__BOUND_14', Bint[2]),
('states', Bint[2])),
'real'),)),
frozenset({Variable('_drop_0__BOUND_20', Bint[2]), Variable('_time_states__BOUND_21', Bint[1])}))
assert isinstance(actual, Tensor), actual.pretty()


def test_infer_discrete_hmm_scan_2():
from math import inf, nan
actual = Contraction(ops.max, ops.add,
frozenset({Variable('states__BOUND_27', Bint[2])}),
(Delta(
(('_drop_0__BOUND_46',
(Tensor(
np.array([[[0, 1], [1, 1]]], dtype=np.int32),
(('_time_states__BOUND_45',
Bint[1],),
('_PREV_states',
Bint[2],),
('states__BOUND_27',
Bint[2],),),
2),
Number(0.0),),),)),
Tensor(
np.array([[[-2.827766537666321], [nan]], [[nan], [nan]]], dtype=np.float64), # noqa
(('_PREV_states',
Bint[2],),
('states__BOUND_27',
Bint[2],),
('_time_states__BOUND_45',
Bint[1],),),
'real'),))
assert isinstance(actual, Tensor), actual.pretty()


def test_infer_discrete_hmm_scan_3():
from math import inf
actual = Contraction(ops.null, ops.max,
frozenset(),
(Contraction(ops.null, ops.add,
frozenset(),
(Tensor(
np.array([[-2.110203742980957, -1.0666589736938477], [-1.5835977792739868, -3.236558437347412]], dtype=np.float32), # noqa
(('_time_states',
Bint[2],),
('states',
Bint[2],),),
'real'),
Scatter(
ops.max,
(('_time_states',
Slice('_time_states__BOUND_21', 1, 2, 2, 2),),
('_PREV_states',
Variable('_drop_0__BOUND_20', Bint[2]),),),
Contraction(ops.max, ops.add,
frozenset({Variable('_PREV_states__BOUND_14', Bint[2])}),
(Delta(
(('_drop_0__BOUND_20',
(Tensor(
np.array([[[0, 0], [0, 0]]], dtype=np.int32),
(('_time_states__BOUND_21',
Bint[1],),
('_PREV_states__BOUND_14',
Bint[2],),
('states',
Bint[2],),),
2),
Number(0.0),),),)),
Tensor(
np.array([[[-inf, -inf], [-inf, -inf]]], dtype=np.float32),
(('_time_states__BOUND_21',
Bint[1],),
('_PREV_states__BOUND_14',
Bint[2],),
('states',
Bint[2],),),
'real'),)),
frozenset({Variable('_drop_0__BOUND_20', Bint[2]), Variable('_time_states__BOUND_21', Bint[1])})),)), # noqa
Contraction(ops.null, ops.add,
frozenset(),
(Tensor(
np.array([[-2.110203742980957, -1.0666589736938477], [-1.5835977792739868, -3.236558437347412]], dtype=np.float32), # noqa
(('_time_states',
Bint[2],),
('states',
Bint[2],),),
'real'),
Scatter(
ops.max,
(('_time_states',
Slice('_time_states__BOUND_25', 0, 2, 2, 2),),
('states',
Variable('_drop_0__BOUND_24', Bint[2]),),),
Contraction(ops.max, ops.add,
frozenset({Variable('states__BOUND_13', Bint[2])}),
(Delta(
(('_drop_0__BOUND_24',
(Tensor(
np.array([[[0, 0], [0, 0]]], dtype=np.int32),
(('_time_states__BOUND_25',
Bint[1],),
('_PREV_states',
Bint[2],),
('states__BOUND_13',
Bint[2],),),
2),
Number(0.0),),),)),
Tensor(
np.array([[[-inf, -inf], [-inf, -inf]]], dtype=np.float32),
(('_time_states__BOUND_25',
Bint[1],),
('states__BOUND_13',
Bint[2],),
('_PREV_states',
Bint[2],),),
'real'),)),
frozenset({Variable('_drop_0__BOUND_24', Bint[2]), Variable('_time_states__BOUND_25', Bint[1])})),)), # noqa
Contraction(ops.null, ops.add,
frozenset(),
(Tensor(
np.array([[-2.110203742980957, -1.0666589736938477], [-1.5835977792739868, -3.236558437347412]], dtype=np.float32), # noqa
(('_time_states',
Bint[2],),
('states',
Bint[2],),),
'real'),
Contraction(ops.null, ops.max,
frozenset(),
(Scatter(
ops.max,
(('_time_states',
Slice('_time_states__BOUND_37', 1, 2, 2, 2),),
('_PREV_states',
Variable('_drop_0__BOUND_36', Bint[2]),),),
Contraction(ops.max, ops.add,
frozenset({Variable('_PREV_states__BOUND_28', Bint[2])}),
(Delta(
(('_drop_0__BOUND_36',
(Tensor(
np.array([[[0, 1], [1, 1]]], dtype=np.int32),
(('_time_states__BOUND_37',
Bint[1],),
('_PREV_states__BOUND_28',
Bint[2],),
('states',
Bint[2],),),
2),
Number(0.0),),),)),
Tensor(
np.array([[[-2.2727227210998535, -2.9637789726257324], [-1.2291778326034546, -1.2291778326034546]]], dtype=np.float32), # noqa
(('_time_states__BOUND_37',
Bint[1],),
('_PREV_states__BOUND_28',
Bint[2],),
('states',
Bint[2],),),
'real'),
Scatter(
ops.max,
(('_time_states__BOUND_37',
Number(0, 1),),),
Contraction(ops.null, ops.add,
frozenset(),
(Delta(
(('states',
(Tensor(
np.array(0, dtype=np.int32),
(),
2),
Number(0.0),),),
('_PREV_states__BOUND_28',
(Tensor(
np.array(0, dtype=np.int32),
(),
2),
Number(0.0),),),)),
Tensor(
np.array(-1.081649899482727, dtype=np.float32),
(),
'real'),)),
frozenset()),)),
frozenset({Variable('_drop_0__BOUND_36', Bint[2]), Variable('_time_states__BOUND_37', Bint[1])})), # noqa
Scatter(
ops.max,
(('_time_states',
Slice('_time_states__BOUND_45', 0, 2, 2, 2),),
('states',
Variable('_drop_0__BOUND_46', Bint[2]),),),
Contraction(ops.max, ops.add,
frozenset({Variable('states__BOUND_27', Bint[2])}),
(Scatter(
ops.max,
(('_time_states__BOUND_45',
Number(0, 1),),),
Contraction(ops.null, ops.add,
frozenset(),
(Delta(
(('_PREV_states',
(Tensor(
np.array(0, dtype=np.int32),
(),
2),
Number(0.0),),),
('states__BOUND_27',
(Tensor(
np.array(0, dtype=np.int32),
(),
2),
Number(0.0),),),)),
Tensor(
np.array(-1.081649899482727, dtype=np.float32),
(),
'real'),)),
frozenset()),
Delta(
(('_drop_0__BOUND_46',
(Tensor(
np.array([[[0, 1], [1, 1]]], dtype=np.int32),
(('_time_states__BOUND_45',
Bint[1],),
('_PREV_states',
Bint[2],),
('states__BOUND_27',
Bint[2],),),
2),
Number(0.0),),),)),
Tensor(
np.array([[[-1.7461166381835938, -3.480717658996582], [-3.3990774154663086, -3.3990774154663086]]], dtype=np.float32), # noqa
(('_time_states__BOUND_45',
Bint[1],),
('states__BOUND_27',
Bint[2],),
('_PREV_states',
Bint[2],),),
'real'),)),
frozenset({Variable('_time_states__BOUND_45', Bint[1]), Variable('_drop_0__BOUND_46', Bint[2])})),)),)),)) # noqa
assert isinstance(actual, Tensor), actual.quote()