Refactor ops to allow non-funsor parameters #491

Merged 23 commits on Mar 18, 2021
Changes from 10 commits
14 changes: 3 additions & 11 deletions funsor/adjoint.py
@@ -4,7 +4,7 @@
from collections import defaultdict
from collections.abc import Hashable

from funsor.cnf import Contraction, nullop
from funsor.cnf import Contraction, null
from funsor.interpretations import Interpretation, reflect
from funsor.interpreter import stack_reinterpret
from funsor.ops import AssociativeOp
@@ -49,15 +49,7 @@ def interpret(self, cls, *args):
self.tape.append((result, cls, args))
else:
result = self._old_interpretation.interpret(cls, *args)
lazy_args = [
self._eager_to_lazy.get(
id(arg)
if ops.is_numeric_array(arg) or not isinstance(arg, Hashable)
else arg,
arg,
)
for arg in args
]
lazy_args = [self._eager_to_lazy.get(arg, arg) for arg in args]
self._eager_to_lazy[result] = reflect.interpret(cls, *lazy_args)
return result

@@ -233,7 +225,7 @@ def adjoint_contract_generic(
def adjoint_contract(
adj_sum_op, adj_prod_op, out_adj, sum_op, prod_op, reduced_vars, lhs, rhs
):
if prod_op is adj_prod_op and sum_op in (nullop, adj_sum_op):
if prod_op is adj_prod_op and sum_op in (null, adj_sum_op):

# the only change is here:
out_adj = Approximate(
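
The adjoint changes are mechanical: the identity op is now spelled `null` rather than `nullop`, and the eager-to-lazy lookup drops its `id()`-based fallback for unhashable arguments. A minimal sketch of the rename, assuming `null` is still the `NullOp` instance that `nullop` was:

```python
# Minimal sketch of the rename (assumes funsor at this PR's revision).
from funsor.ops import NullOp, null

# `null` plays the same role `nullop` did: an identity placeholder meaning
# "no reduction op" / "no binary op" inside a Contraction.
assert isinstance(null, NullOp)
```
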
4 changes: 2 additions & 2 deletions funsor/affine.py
@@ -8,7 +8,7 @@

from funsor.domains import Bint
from funsor.interpreter import gensym
from funsor.tensor import EinsumOp, Tensor, get_default_prototype
from funsor.tensor import Tensor, get_default_prototype
from funsor.terms import Binary, Finitary, Funsor, Lambda, Reduce, Unary, Variable

from . import ops
@@ -92,7 +92,7 @@ def _(fn):
return affine_inputs(fn.arg) - fn.reduced_vars


@affine_inputs.register(Finitary[EinsumOp, tuple])
@affine_inputs.register(Finitary[ops.EinsumOp, tuple])
def _(fn):
# This is simply a multiary version of the above Binary(ops.mul, ...) case.
results = []
36 changes: 18 additions & 18 deletions funsor/cnf.py
@@ -17,7 +17,7 @@
from funsor.gaussian import Gaussian
from funsor.interpretations import eager, normalize, reflect
from funsor.interpreter import recursion_reinterpret
from funsor.ops import DISTRIBUTIVE_OPS, AssociativeOp, NullOp, nullop
from funsor.ops import DISTRIBUTIVE_OPS, AssociativeOp, NullOp, null
from funsor.tensor import Tensor
from funsor.terms import (
_INFIX,
@@ -69,7 +69,7 @@ def __init__(self, red_op, bin_op, reduced_vars, terms):
for v in terms:
inputs.update((k, d) for k, d in v.inputs.items() if k not in bound)

if bin_op is nullop:
if bin_op is null:
output = terms[0].output
else:
output = reduce(
@@ -107,8 +107,8 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
if not sampled_vars:
return self

if self.red_op in (ops.logaddexp, nullop):
if self.bin_op in (ops.nullop, ops.logaddexp):
if self.red_op in (ops.logaddexp, null):
if self.bin_op in (ops.null, ops.logaddexp):
if rng_key is not None and get_backend() == "jax":
import jax

@@ -272,7 +272,7 @@ def eager_contraction_generic_recursive(red_op, bin_op, reduced_vars, terms):
if unique_vars:
result = term.reduce(red_op, unique_vars)
if result is not normalize.interpret(
Contraction, red_op, nullop, unique_vars, (term,)
Contraction, red_op, null, unique_vars, (term,)
):
terms[i] = result
reduced_vars -= unique_vars
@@ -427,7 +427,7 @@ def normalize_contraction_commutative_canonical_order(
)
def normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, mixture, other):
return Contraction(
mixture.red_op if red_op is nullop else red_op,
mixture.red_op if red_op is null else red_op,
bin_op,
reduced_vars | mixture.reduced_vars,
*(mixture.terms + (other,))
@@ -439,7 +439,7 @@
)
def normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, other, mixture):
return Contraction(
mixture.red_op if red_op is nullop else red_op,
mixture.red_op if red_op is null else red_op,
bin_op,
reduced_vars | mixture.reduced_vars,
*(mixture.terms + (other,))
@@ -462,13 +462,13 @@ def normalize_trivial(red_op, bin_op, reduced_vars, term):
@normalize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, tuple)
def normalize_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):

if not reduced_vars and red_op is not nullop:
return Contraction(nullop, bin_op, reduced_vars, *terms)
if not reduced_vars and red_op is not null:
return Contraction(null, bin_op, reduced_vars, *terms)

if len(terms) == 1 and bin_op is not nullop:
return Contraction(red_op, nullop, reduced_vars, *terms)
if len(terms) == 1 and bin_op is not null:
return Contraction(red_op, null, reduced_vars, *terms)

if red_op is nullop and bin_op is nullop:
if red_op is null and bin_op is null:
return terms[0]

if red_op is bin_op:
@@ -493,11 +493,11 @@ def normalize_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):
continue

# fuse operations without distributing
if (v.red_op is nullop and bin_op is v.bin_op) or (
bin_op is nullop and v.red_op in (red_op, nullop)
if (v.red_op is null and bin_op is v.bin_op) or (
bin_op is null and v.red_op in (red_op, null)
):
red_op = v.red_op if red_op is nullop else red_op
bin_op = v.bin_op if bin_op is nullop else bin_op
red_op = v.red_op if red_op is null else red_op
bin_op = v.bin_op if bin_op is null else bin_op
new_terms = terms[:i] + v.terms + terms[i + 1 :]
return Contraction(
red_op, bin_op, reduced_vars | v.reduced_vars, *new_terms
@@ -514,12 +514,12 @@ def normalize_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):

@normalize.register(Binary, AssociativeOp, Funsor, Funsor)
def binary_to_contract(op, lhs, rhs):
return Contraction(nullop, op, frozenset(), lhs, rhs)
return Contraction(null, op, frozenset(), lhs, rhs)


@normalize.register(Reduce, AssociativeOp, Funsor, frozenset)
def reduce_funsor(op, arg, reduced_vars):
return Contraction(op, nullop, reduced_vars, arg)
return Contraction(op, null, reduced_vars, arg)


@normalize.register(
9 changes: 4 additions & 5 deletions funsor/distribution.py
@@ -27,7 +27,6 @@
get_default_prototype,
ignore_jit_warnings,
numeric_array,
stack,
)
from funsor.terms import (
Funsor,
@@ -768,15 +767,15 @@ def LogNormal(loc, scale, value="value"):


def eager_beta(concentration1, concentration0, value):
concentration = stack((concentration0, concentration1))
value = stack((1 - value, value))
concentration = ops.stack((concentration0, concentration1))
value = ops.stack((1 - value, value))
backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
return backend_dist.Dirichlet(concentration, value=value) # noqa: F821


def eager_binomial(total_count, probs, value):
probs = stack((1 - probs, probs))
value = stack((total_count - value, value))
probs = ops.stack((1 - probs, probs))
value = ops.stack((total_count - value, value))
backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
return backend_dist.Multinomial(total_count, probs, value=value) # noqa: F821

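
`stack` is no longer imported from `funsor.tensor`; the distribution shims call the module-level `ops.stack`, which takes the parts as a single tuple (the stacking `dim` is an op parameter, cf. `op.defaults["dim"]` in the new `StackOp` domain rule in funsor/domains.py below). A small usage sketch, assuming the NumPy backend dispatches eagerly on raw ndarrays and that `dim` defaults to 0 (the diff itself only shows the one-tuple-of-parts call):

```python
import numpy as np
from funsor import ops

probs = np.array([0.2, 0.7])
parts = (1 - probs, probs)      # same construction as eager_binomial above
stacked = ops.stack(parts)      # all parts passed as one argument
assert stacked.shape == (2, 2)  # new leading axis of size len(parts)
```
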
38 changes: 35 additions & 3 deletions funsor/domains.py
@@ -250,14 +250,15 @@ def _find_domain_log_exp(op, domain):

@find_domain.register(ops.ReshapeOp)
def _find_domain_reshape(op, domain):
return Array[domain.dtype, op.shape]
return Array[domain.dtype, op.defaults["shape"]]


@find_domain.register(ops.GetitemOp)
def _find_domain_getitem(op, lhs_domain, rhs_domain):
if isinstance(lhs_domain, ArrayType):
offset = op.defaults["offset"]
dtype = lhs_domain.dtype
shape = lhs_domain.shape[: op.offset] + lhs_domain.shape[1 + op.offset :]
shape = lhs_domain.shape[:offset] + lhs_domain.shape[1 + offset :]
return Array[dtype, shape]
elif isinstance(lhs_domain, ProductDomain):
# XXX should this return a Union?
@@ -342,7 +343,7 @@ def _find_domain_associative_generic(op, *domains):

@find_domain.register(ops.WrappedTransformOp)
def _transform_find_domain(op, domain):
fn = op.dispatch(object)
fn = op.defaults["fn"]
shape = fn.forward_shape(domain.shape)
return Array[domain.dtype, shape]

@@ -353,6 +354,37 @@ def _transform_log_abs_det_jacobian(op, domain, codomain):
return Real


@find_domain.register(ops.StackOp)
def _find_domain_stack(op, parts):
shape = broadcast_shape(*(x.shape for x in parts))
dim = op.defaults["dim"]
if dim >= 0:
dim = dim - len(shape) - 1
assert dim < 0
split = dim + len(shape) + 1
shape = shape[:split] + (len(parts),) + shape[split:]
output = Array[parts[0].dtype, shape]
return output


@find_domain.register(ops.EinsumOp)
def _find_domain_einsum(op, operands):
equation = op.defaults["equation"]
ein_inputs, ein_output = equation.split("->")
ein_inputs = ein_inputs.split(",")
size_dict = {}
for ein_input, x in zip(ein_inputs, operands):
assert x.dtype == "real"
assert len(ein_input) == len(x.shape)
for name, size in zip(ein_input, x.shape):
other_size = size_dict.setdefault(name, size)
if other_size != size:
raise ValueError(
"Size mismatch at {}: {} vs {}".format(name, size, other_size)
)
return Reals[tuple(size_dict[d] for d in ein_output)]


__all__ = [
"Bint",
"BintType",
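
These two new rules are the user-visible payoff of storing non-funsor parameters on the op: `find_domain` can read `op.defaults["dim"]` or `op.defaults["equation"]` and compute the output shape without evaluating anything. The einsum bookkeeping is easy to check in isolation; here is a standalone sketch of the same logic in plain Python (the helper name is ours, not the library's):

```python
def einsum_output_shape(equation, operand_shapes):
    """Mirror of _find_domain_einsum's size bookkeeping, for illustration."""
    ein_inputs, ein_output = equation.split("->")
    size_dict = {}
    for spec, shape in zip(ein_inputs.split(","), operand_shapes):
        assert len(spec) == len(shape)
        for name, size in zip(spec, shape):
            if size_dict.setdefault(name, size) != size:
                raise ValueError(f"Size mismatch at {name}")
    return tuple(size_dict[d] for d in ein_output)

assert einsum_output_shape("ab,bc->ac", [(2, 3), (3, 4)]) == (2, 4)
```
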
2 changes: 1 addition & 1 deletion funsor/einsum/numpy_log.py
@@ -43,7 +43,7 @@ def einsum(equation, *operands):
shift = ops.permute(shift, [dims.index(dim) for dim in output])
shifts.append(shift)

result = ops.log(ops.einsum(equation, *exp_operands))
result = ops.log(ops.einsum(exp_operands, equation))
return sum(shifts + [result])


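
The call site adapts to the refactored signature: the operands come first as a single sequence, and the equation string is now an op parameter rather than the leading positional argument. A usage sketch with raw arrays (NumPy backend assumed):

```python
import numpy as np
from funsor import ops

x, y = np.ones((2, 3)), np.ones((3, 4))
z = ops.einsum([x, y], "ab,bc->ac")   # was: ops.einsum("ab,bc->ac", x, y)
assert z.shape == (2, 4)
```
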
40 changes: 20 additions & 20 deletions funsor/gaussian.py
@@ -137,7 +137,7 @@ def as_tensor(self):

# Concatenate parts.
parts = [v for k, v in sorted(self.parts.items())]
result = ops.cat(-1, *parts)
result = ops.cat(parts, -1)
if not get_tracing_state():
assert result.shape == self.shape
return result
@@ -182,10 +182,10 @@ def as_tensor(self):
# TODO This could be optimized into a single .reshape().cat().reshape() if
# all inputs are contiguous, thereby saving a memcopy.
columns = {
i: ops.cat(-1, *[v for j, v in sorted(part.items())])
i: ops.cat([v for j, v in sorted(part.items())], -1)
for i, part in self.parts.items()
}
result = ops.cat(-2, *[v for i, v in sorted(columns.items())])
result = ops.cat([v for i, v in sorted(columns.items())], -2)
if not get_tracing_state():
assert result.shape == self.shape
return result
@@ -468,32 +468,32 @@ def _eager_subs_real(self, subs, remaining_subs):
k for k, d in self.inputs.items() if d.dtype == "real" and k not in b
)
prec_aa = ops.cat(
-2,
*[
ops.cat(-1, *[precision[..., i1, i2] for k2, i2 in slices if k2 in a])
[
ops.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in a], -1)
for k1, i1 in slices
if k1 in a
]
],
-2,
)
prec_ab = ops.cat(
-2,
*[
ops.cat(-1, *[precision[..., i1, i2] for k2, i2 in slices if k2 in b])
[
ops.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in b], -1)
for k1, i1 in slices
if k1 in a
]
],
-2,
)
prec_bb = ops.cat(
-2,
*[
ops.cat(-1, *[precision[..., i1, i2] for k2, i2 in slices if k2 in b])
[
ops.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in b], -1)
for k1, i1 in slices
if k1 in b
]
],
-2,
)
info_a = ops.cat(-1, *[info_vec[..., i] for k, i in slices if k in a])
info_b = ops.cat(-1, *[info_vec[..., i] for k, i in slices if k in b])
value_b = ops.cat(-1, *[values[k] for k, i in slices if k in b])
info_a = ops.cat([info_vec[..., i] for k, i in slices if k in a], -1)
info_b = ops.cat([info_vec[..., i] for k, i in slices if k in b], -1)
value_b = ops.cat([values[k] for k, i in slices if k in b], -1)
info_vec = info_a - _mv(prec_ab, value_b)
log_scale = _vv(value_b, info_b - 0.5 * _mv(prec_bb, value_b))
precision = ops.expand(prec_aa, info_vec.shape + info_vec.shape[-1:])
@@ -637,8 +637,8 @@ def eager_reduce(self, op, reduced_vars):
1,
)
(b if key in reduced_vars else a).append(block)
a = ops.cat(-1, *a)
b = ops.cat(-1, *b)
a = ops.cat(a, -1)
b = ops.cat(b, -1)
prec_aa = self.precision[..., a[..., None], a]
prec_ba = self.precision[..., b[..., None], a]
prec_bb = self.precision[..., b[..., None], b]
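
All of the Gaussian block assembly switches to the new `ops.cat(parts, axis)` convention: the parts are one sequence argument and the axis is passed last (it becomes an op parameter), replacing the old axis-first varargs form. A minimal sketch with raw arrays, mirroring the precision-block concatenation above (NumPy backend assumed):

```python
import numpy as np
from funsor import ops

left, right = np.zeros((2, 3)), np.ones((2, 1))
block = ops.cat([left, right], -1)   # was: ops.cat(-1, left, right)
assert block.shape == (2, 4)
```
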
2 changes: 1 addition & 1 deletion funsor/jax/distributions.py
@@ -243,7 +243,7 @@ def deltadist_to_data(funsor_dist, name_to_dim=None):

@to_funsor.register(dist.transforms.Transform)
def transform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
op = ops.WrappedTransformOp(tfm)
op = ops.WrappedTransformOp(fn=tfm)
name = next(real_inputs.keys()) if real_inputs else "value"
return op(Variable(name, output))

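
This one-line change is the PR's theme in miniature: the transform is a non-funsor parameter, so it is passed to the op constructor by keyword and is recoverable later from `op.defaults`, which is exactly what the `WrappedTransformOp` rule in funsor/domains.py reads. A sketch, assuming numpyro is installed for the JAX backend:

```python
from numpyro.distributions.transforms import ExpTransform
from funsor import ops

tfm = ExpTransform()
op = ops.WrappedTransformOp(fn=tfm)   # non-funsor parameter bound at construction
assert op.defaults["fn"] is tfm       # what find_domain reads to get forward_shape
```
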