Make Einsum an Op
eb8680 committed Jan 14, 2021
1 parent 69d6e14 commit f23d7f8
Showing 2 changed files with 33 additions and 36 deletions.
10 changes: 5 additions & 5 deletions funsor/affine.py
@@ -7,8 +7,8 @@
 import opt_einsum

 from funsor.interpreter import gensym
-from funsor.tensor import Einsum, Tensor, get_default_prototype
-from funsor.terms import Binary, Funsor, Lambda, Reduce, Unary, Variable, Bint
+from funsor.tensor import EinsumOp, Tensor, get_default_prototype
+from funsor.terms import Binary, Finitary, Funsor, Lambda, Reduce, Unary, Variable, Bint

 from . import ops

@@ -91,12 +91,12 @@ def _(fn):
     return affine_inputs(fn.arg) - fn.reduced_vars


-@affine_inputs.register(Einsum)
+@affine_inputs.register(Finitary[EinsumOp, tuple])
 def _(fn):
     # This is simply a multiary version of the above Binary(ops.mul, ...) case.
     results = []
-    for i, x in enumerate(fn.operands):
-        others = fn.operands[:i] + fn.operands[i+1:]
+    for i, x in enumerate(fn.args):
+        others = fn.args[:i] + fn.args[i+1:]
         other_inputs = reduce(ops.or_, map(_real_inputs, others), frozenset())
         results.append(affine_inputs(x) - other_inputs)
     # This multilinear case introduces incompleteness, since some vars
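
For context, here is a minimal sketch of how the updated rule is exercised (assuming the torch backend is selected with funsor.set_backend("torch"); the names, shapes, and equation below are illustrative, not part of the commit). A lazy einsum term is now a Finitary[EinsumOp, tuple] node, so affine_inputs walks fn.args rather than fn.operands:

    import torch
    import funsor
    from funsor.affine import affine_inputs
    from funsor.domains import Reals
    from funsor.tensor import Einsum, Tensor
    from funsor.terms import Variable

    funsor.set_backend("torch")  # assumed backend setup

    c = Tensor(torch.randn(3, 3))     # constant coefficient matrix
    x = Variable("x", Reals[3])       # free real-valued input
    # Stays lazy as Finitary(EinsumOp("ab,b->a"), (c, x)) since x is not a Tensor.
    term = Einsum("ab,b->a", (c, x))

    # The registration above analyzes term.args operand by operand and
    # should report that the term is affine in "x".
    print(affine_inputs(term))        # expected: frozenset({'x'})
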
59 changes: 28 additions & 31 deletions funsor/tensor.py
@@ -21,6 +21,7 @@
 from funsor.ops import GetitemOp, MatmulOp, Op, ReshapeOp
 from funsor.terms import (
     Binary,
+    Finitary,
     Funsor,
     FunsorMeta,
     Lambda,
@@ -949,7 +950,29 @@ def max_and_argmax(x: Reals[8]) -> Tuple[Real, Bint[8]]:
     return functools.partial(_function, inputs, output)


-class Einsum(Funsor):
+class EinsumOp(ops.Op, metaclass=ops.OpCacheMeta):
+    def __init__(self, equation):
+        self.equation = equation
+
+
+@find_domain.register(EinsumOp)
+def _find_domain_einsum(op, *operands):
+    equation = op.equation
+    ein_inputs, ein_output = equation.split('->')
+    ein_inputs = ein_inputs.split(',')
+    size_dict = {}
+    for ein_input, x in zip(ein_inputs, operands):
+        assert x.dtype == 'real'
+        assert len(ein_input) == len(x.output.shape)
+        for name, size in zip(ein_input, x.output.shape):
+            other_size = size_dict.setdefault(name, size)
+            if other_size != size:
+                raise ValueError("Size mismatch at {}: {} vs {}"
+                                 .format(name, size, other_size))
+    return Reals[tuple(size_dict[d] for d in ein_output)]
+
+
+def Einsum(equation, operands):
     """
     Wrapper around :func:`torch.einsum` or :func:`np.einsum` to operate on real-valued Funsors.
@@ -960,40 +983,14 @@ class Einsum(Funsor):
     :param str equation: An :func:`torch.einsum` or :func:`np.einsum` equation.
     :param tuple operands: A tuple of input funsors.
     """
-    def __init__(self, equation, operands):
-        assert isinstance(equation, str)
-        assert isinstance(operands, tuple)
-        assert all(isinstance(x, Funsor) for x in operands)
-        ein_inputs, ein_output = equation.split('->')
-        ein_inputs = ein_inputs.split(',')
-        size_dict = {}
-        inputs = OrderedDict()
-        assert len(ein_inputs) == len(operands)
-        for ein_input, x in zip(ein_inputs, operands):
-            assert x.dtype == 'real'
-            inputs.update(x.inputs)
-            assert len(ein_input) == len(x.output.shape)
-            for name, size in zip(ein_input, x.output.shape):
-                other_size = size_dict.setdefault(name, size)
-                if other_size != size:
-                    raise ValueError("Size mismatch at {}: {} vs {}"
-                                     .format(name, size, other_size))
-        output = Reals[tuple(size_dict[d] for d in ein_output)]
-        super(Einsum, self).__init__(inputs, output)
-        self.equation = equation
-        self.operands = operands
-
-    def __repr__(self):
-        return 'Einsum({}, {})'.format(repr(self.equation), repr(self.operands))
-
-    def __str__(self):
-        return 'Einsum({}, {})'.format(repr(self.equation), str(self.operands))
+    return Finitary(EinsumOp(equation), tuple(operands))


-@eager.register(Einsum, str, tuple)
-def eager_einsum(equation, operands):
+@eager.register(Finitary, EinsumOp, tuple)
+def eager_einsum(op, operands):
     if all(isinstance(x, Tensor) for x in operands):
         # Make new symbols for inputs of operands.
+        equation = op.equation
         inputs = OrderedDict()
         for x in operands:
             inputs.update(x.inputs)
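
For reference, a minimal usage sketch of the refactored pieces (again assuming funsor.set_backend("torch"); shapes and the equation are illustrative). Einsum is now a plain factory returning Finitary(EinsumOp(equation), operands), the output domain comes from the new find_domain registration, and the eager rule still collapses all-Tensor operands to a single Tensor:

    import torch
    import funsor
    from funsor.domains import find_domain
    from funsor.tensor import Einsum, EinsumOp, Tensor

    funsor.set_backend("torch")  # assumed backend setup

    x = Tensor(torch.randn(3, 4))
    y = Tensor(torch.randn(4, 5))

    # The find_domain registration unifies sizes per index label (a=3, b=4, c=5).
    print(find_domain(EinsumOp("ab,bc->ac"), x, y))  # expected: Reals[3, 5]

    # Both operands are Tensors, so the eager Finitary(EinsumOp, tuple)
    # pattern evaluates immediately to a Tensor.
    z = Einsum("ab,bc->ac", (x, y))
    print(type(z).__name__, z.output)  # expected: Tensor Reals[3, 5]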
