Skip to content

Commit

Permalink
Add Finitary funsor for representing lazy op application (#423)
Browse files Browse the repository at this point in the history
* Add Finitary funsor for lazy op application

* format

* variadic tuple

* fix tests

* fix mvn_affine test

* update Einsum interface
  • Loading branch information
eb8680 authored Feb 23, 2021
1 parent 16c713e commit 0344765
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 76 deletions.
13 changes: 7 additions & 6 deletions funsor/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

import opt_einsum

from funsor.domains import Bint
from funsor.interpreter import gensym
from funsor.tensor import Einsum, Tensor, get_default_prototype
from funsor.terms import Binary, Bint, Funsor, Lambda, Reduce, Unary, Variable
from funsor.tensor import EinsumOp, Tensor, get_default_prototype
from funsor.terms import Binary, Finitary, Funsor, Lambda, Reduce, Unary, Variable

from . import ops

Expand Down Expand Up @@ -91,12 +92,12 @@ def _(fn):
return affine_inputs(fn.arg) - fn.reduced_vars


@affine_inputs.register(Einsum)
@affine_inputs.register(Finitary[EinsumOp, tuple])
def _(fn):
# This is simply a multiary version of the above Binary(ops.mul, ...) case.
results = []
for i, x in enumerate(fn.operands):
others = fn.operands[:i] + fn.operands[i + 1 :]
for i, x in enumerate(fn.args):
others = fn.args[:i] + fn.args[i + 1 :]
other_inputs = reduce(ops.or_, map(_real_inputs, others), frozenset())
results.append(affine_inputs(x) - other_inputs)
# This multilinear case introduces incompleteness, since some vars
Expand All @@ -114,7 +115,7 @@ def extract_affine(fn):
x = ...
const, coeffs = extract_affine(x)
y = sum(Einsum(eqn, (coeff, Variable(var, coeff.output)))
y = sum(Einsum(eqn, coeff, Variable(var, coeff.output))
for var, (coeff, eqn) in coeffs.items())
assert_close(y, x)
assert frozenset(coeffs) == affine_inputs(x)
Expand Down
122 changes: 61 additions & 61 deletions funsor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .ops import GetitemOp, MatmulOp, Op, ReshapeOp
from .terms import (
Binary,
Finitary,
Funsor,
FunsorMeta,
Lambda,
Expand Down Expand Up @@ -770,6 +771,13 @@ def eager_getitem_tensor_tensor(op, lhs, rhs):
return Tensor(data, inputs, lhs.dtype)


@eager.register(Finitary, Op, typing.Tuple[Tensor, ...])
def eager_finitary_generic_tensors(op, args):
    """
    Eagerly apply ``op`` to a tuple of :class:`Tensor` arguments.

    The arguments are passed through ``align_tensors``, ``op`` is called on
    the aligned values it returns, and the result is wrapped in a new
    :class:`Tensor` over the aligned inputs.

    :param Op op: The operation to apply.
    :param tuple args: The :class:`Tensor` operands.
    :return: A :class:`Tensor` wrapping ``op`` applied to the aligned values.
    :rtype: Tensor
    """
    aligned_inputs, aligned_data = align_tensors(*args)
    # NOTE(review): the result dtype is copied from the first operand, which
    # presumes the op does not change dtype -- confirm when adding new ops.
    return Tensor(op(*aligned_data), aligned_inputs, args[0].dtype)


@eager.register(Lambda, Variable, Tensor)
def eager_lambda(var, expr):
inputs = expr.inputs.copy()
Expand Down Expand Up @@ -1019,7 +1027,30 @@ def max_and_argmax(x: Reals[8]) -> Tuple[Real, Bint[8]]:
return functools.partial(_function, inputs, output)


class Einsum(Funsor):
class EinsumOp(ops.Op, metaclass=ops.CachedOpMeta):
    """
    Op that carries an einsum equation.

    :param str equation: The einsum equation; stored as ``self.equation``.
    """

    def __init__(self, equation):
        self.equation = equation


@find_domain.register(EinsumOp)
def _find_domain_einsum(op, *operands):
    """
    Compute the output domain of an einsum applied to the given operands.

    :param EinsumOp op: The einsum op. Its ``equation`` must be in explicit
        form, i.e. comma-separated input terms, ``->``, then the output term.
    :param operands: One object per input term of the equation, each exposing
        ``dtype`` (which must be ``"real"``) and ``shape``.
    :return: ``Reals[...]`` whose shape is given by the sizes bound to the
        symbols of the output term.
    :raises ValueError: If the number of operands differs from the number of
        input terms in the equation, or if one symbol is bound to two
        different sizes.
    """
    ein_inputs, ein_output = op.equation.split("->")
    ein_inputs = ein_inputs.split(",")
    # zip() below silently truncates to the shorter sequence, so a count
    # mismatch must be rejected up front rather than yielding a wrong domain.
    if len(ein_inputs) != len(operands):
        raise ValueError(
            "Expected {} operands for equation {!r} but got {}".format(
                len(ein_inputs), op.equation, len(operands)
            )
        )
    size_dict = {}
    for ein_input, x in zip(ein_inputs, operands):
        assert x.dtype == "real"
        assert len(ein_input) == len(x.shape)
        for name, size in zip(ein_input, x.shape):
            # The first occurrence of a symbol fixes its size; later
            # occurrences must agree with it.
            other_size = size_dict.setdefault(name, size)
            if other_size != size:
                raise ValueError(
                    "Size mismatch at {}: {} vs {}".format(name, size, other_size)
                )
    return Reals[tuple(size_dict[d] for d in ein_output)]


def Einsum(equation, *operands):
"""
Wrapper around :func:`torch.einsum` or :func:`np.einsum` to operate on real-valued Funsors.
Expand All @@ -1030,70 +1061,39 @@ class Einsum(Funsor):
:param str equation: An :func:`torch.einsum` or :func:`np.einsum` equation.
:param tuple operands: A tuple of input funsors.
"""

def __init__(self, equation, operands):
assert isinstance(equation, str)
assert isinstance(operands, tuple)
assert all(isinstance(x, Funsor) for x in operands)
ein_inputs, ein_output = equation.split("->")
ein_inputs = ein_inputs.split(",")
size_dict = {}
inputs = OrderedDict()
assert len(ein_inputs) == len(operands)
for ein_input, x in zip(ein_inputs, operands):
assert x.dtype == "real"
inputs.update(x.inputs)
assert len(ein_input) == len(x.output.shape)
for name, size in zip(ein_input, x.output.shape):
other_size = size_dict.setdefault(name, size)
if other_size != size:
raise ValueError(
"Size mismatch at {}: {} vs {}".format(name, size, other_size)
)
output = Reals[tuple(size_dict[d] for d in ein_output)]
super(Einsum, self).__init__(inputs, output)
self.equation = equation
self.operands = operands

def __repr__(self):
return "Einsum({}, {})".format(repr(self.equation), repr(self.operands))

def __str__(self):
return "Einsum({}, {})".format(repr(self.equation), str(self.operands))
return Finitary(EinsumOp(equation), tuple(operands))


@eager.register(Einsum, str, tuple)
def eager_einsum(equation, operands):
if all(isinstance(x, Tensor) for x in operands):
# Make new symbols for inputs of operands.
inputs = OrderedDict()
for x in operands:
inputs.update(x.inputs)
symbols = set(equation)
get_symbol = iter(map(opt_einsum.get_symbol, itertools.count()))
new_symbols = {}
for k in inputs:
@eager.register(Finitary, EinsumOp, typing.Tuple[Tensor, ...])
def eager_einsum(op, operands):
# Make new symbols for inputs of operands.
equation = op.equation
inputs = OrderedDict()
for x in operands:
inputs.update(x.inputs)
symbols = set(equation)
get_symbol = iter(map(opt_einsum.get_symbol, itertools.count()))
new_symbols = {}
for k in inputs:
symbol = next(get_symbol)
while symbol in symbols:
symbol = next(get_symbol)
while symbol in symbols:
symbol = next(get_symbol)
symbols.add(symbol)
new_symbols[k] = symbol

# Manually broadcast using einsum symbols.
assert "." not in equation
ins, out = equation.split("->")
ins = ins.split(",")
ins = [
"".join(new_symbols[k] for k in x.inputs) + x_out
for x, x_out in zip(operands, ins)
]
out = "".join(new_symbols[k] for k in inputs) + out
equation = ",".join(ins) + "->" + out
symbols.add(symbol)
new_symbols[k] = symbol

data = ops.einsum(equation, *[x.data for x in operands])
return Tensor(data, inputs)
# Manually broadcast using einsum symbols.
assert "." not in equation
ins, out = equation.split("->")
ins = ins.split(",")
ins = [
"".join(new_symbols[k] for k in x.inputs) + x_out
for x, x_out in zip(operands, ins)
]
out = "".join(new_symbols[k] for k in inputs) + out
equation = ",".join(ins) + "->" + out

return None # defer to default implementation
data = ops.einsum(equation, *[x.data for x in operands])
return Tensor(data, inputs)


def tensordot(x, y, dims):
Expand Down Expand Up @@ -1129,7 +1129,7 @@ def tensordot(x, y, dims):
symbols[y_start:y_end],
symbols[x_start:y_start] + symbols[x_end:y_end],
)
return Einsum(equation, (x, y))
return Einsum(equation, x, y)


def stack(parts, dim=0):
Expand Down
14 changes: 14 additions & 0 deletions funsor/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,6 +1289,20 @@ def eager_binary_align_align(op, lhs, rhs):
return Binary(op, lhs.arg, rhs.arg)


class Finitary(Funsor):
    """
    Funsor representing the lazy application of an op to a tuple of funsors.

    The inputs are the union of the arguments' inputs, merged in argument
    order; the output domain is computed by ``find_domain`` from the op and
    the arguments' output domains.

    :param ops.Op op: The operation being applied.
    :param tuple args: The funsor operands.
    """

    def __init__(self, op, args):
        assert isinstance(op, ops.Op)
        assert isinstance(args, tuple)
        assert all(isinstance(v, Funsor) for v in args)
        merged_inputs = OrderedDict()
        for operand in args:
            merged_inputs.update(operand.inputs)
        operand_domains = [operand.output for operand in args]
        output_domain = find_domain(op, *operand_domains)
        super().__init__(merged_inputs, output_domain)
        self.op = op
        self.args = args


class Stack(Funsor):
"""
Stack of funsors along a new input dimension.
Expand Down
12 changes: 6 additions & 6 deletions test/test_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from funsor.cnf import Contraction
from funsor.domains import Bint, Real, Reals # noqa: F401
from funsor.tensor import Einsum, Tensor
from funsor.terms import Number, Unary, Variable
from funsor.terms import Finitary, Number, Unary, Variable
from funsor.testing import (
assert_close,
check_funsor,
Expand Down Expand Up @@ -107,17 +107,17 @@ def test_affine_subs(expr, expected_type, expected_inputs):
"Variable('x', Reals[2]) * randn(2) + ones(2)",
"Variable('x', Reals[2]) + Tensor(randn(3, 2), OrderedDict(i=Bint[3]))",
"Einsum('abcd,ac->bd',"
" (Tensor(randn(2, 3, 4, 5)), Variable('x', Reals[2, 4])))",
" Tensor(randn(2, 3, 4, 5)), Variable('x', Reals[2, 4]))",
"Tensor(randn(3, 5)) + Einsum('abcd,ac->bd',"
" (Tensor(randn(2, 3, 4, 5)), Variable('x', Reals[2, 4])))",
" Tensor(randn(2, 3, 4, 5)), Variable('x', Reals[2, 4]))",
"Variable('x', Reals[2, 8])[0] + randn(8)",
"Variable('x', Reals[2, 8])[Variable('i', Bint[2])] / 4 - 3.5",
],
)
def test_extract_affine(expr):
x = eval(expr)
assert is_affine(x)
assert isinstance(x, (Unary, Contraction, Einsum))
assert isinstance(x, (Unary, Contraction, Finitary))
real_inputs = OrderedDict((k, d) for k, d in x.inputs.items() if d.dtype == "real")

const, coeffs = extract_affine(x)
Expand All @@ -134,7 +134,7 @@ def test_extract_affine(expr):
assert isinstance(expected, Tensor)

actual = const + sum(
Einsum(eqn, (coeff, subs[k])) for k, (coeff, eqn) in coeffs.items()
Einsum(eqn, coeff, subs[k]) for k, (coeff, eqn) in coeffs.items()
)
assert isinstance(actual, Tensor)
assert_close(actual, expected)
Expand All @@ -157,7 +157,7 @@ def test_extract_affine(expr):
"Variable('x', Reals[2,3]) @ Variable('y', Reals[3,4])",
"random_gaussian(OrderedDict(x=Real))",
"Einsum('abcd,ac->bd',"
" (Variable('y', Reals[2, 3, 4, 5]), Variable('x', Reals[2, 4])))",
" Variable('y', Reals[2, 3, 4, 5]), Variable('x', Reals[2, 4]))",
],
)
def test_not_is_affine(expr):
Expand Down
2 changes: 1 addition & 1 deletion test/test_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def test_eager_subs_variable():
(
(
"y",
'Einsum("abc,bc->a", (Tensor(randn((4, 3, 5))), Variable("v", Reals[3, 5])))',
'Einsum("abc,bc->a", Tensor(randn((4, 3, 5))), Variable("v", Reals[3, 5]))',
),
),
],
Expand Down
4 changes: 2 additions & 2 deletions test/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,7 @@ def test_einsum(equation):
tensors = [randn(tuple(sizes[d] for d in dims)) for dims in inputs]
funsors = [Tensor(x) for x in tensors]
expected = Tensor(ops.einsum(equation, *tensors))
actual = Einsum(equation, tuple(funsors))
actual = Einsum(equation, *funsors)
assert_close(actual, expected, atol=1e-5, rtol=None)


Expand All @@ -968,7 +968,7 @@ def test_batched_einsum(equation, batch1, batch2):
random_tensor(batch, Reals[tuple(sizes[d] for d in dims)])
for batch, dims in zip([batch1, batch2], inputs)
]
actual = Einsum(equation, tuple(funsors))
actual = Einsum(equation, *funsors)

_equation = ",".join("..." + i for i in inputs) + "->..." + output
inputs, tensors = align_tensors(*funsors)
Expand Down

0 comments on commit 0344765

Please sign in to comment.