diff --git a/funsor/affine.py b/funsor/affine.py index 630502ec9..86e2fabf2 100644 --- a/funsor/affine.py +++ b/funsor/affine.py @@ -7,8 +7,8 @@ import opt_einsum from funsor.interpreter import gensym -from funsor.tensor import Einsum, Tensor, get_default_prototype -from funsor.terms import Binary, Funsor, Lambda, Reduce, Unary, Variable, Bint +from funsor.tensor import EinsumOp, Tensor, get_default_prototype +from funsor.terms import Binary, Finitary, Funsor, Lambda, Reduce, Unary, Variable, Bint from . import ops @@ -91,12 +91,12 @@ def _(fn): return affine_inputs(fn.arg) - fn.reduced_vars -@affine_inputs.register(Einsum) +@affine_inputs.register(Finitary[EinsumOp, tuple]) def _(fn): # This is simply a multiary version of the above Binary(ops.mul, ...) case. results = [] - for i, x in enumerate(fn.operands): - others = fn.operands[:i] + fn.operands[i+1:] + for i, x in enumerate(fn.args): + others = fn.args[:i] + fn.args[i+1:] other_inputs = reduce(ops.or_, map(_real_inputs, others), frozenset()) results.append(affine_inputs(x) - other_inputs) # This multilinear case introduces incompleteness, since some vars diff --git a/funsor/tensor.py b/funsor/tensor.py index 0dc1a9542..5f420f4f0 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -21,6 +21,7 @@ from funsor.ops import GetitemOp, MatmulOp, Op, ReshapeOp from funsor.terms import ( Binary, + Finitary, Funsor, FunsorMeta, Lambda, @@ -949,7 +950,29 @@ def max_and_argmax(x: Reals[8]) -> Tuple[Real, Bint[8]]: return functools.partial(_function, inputs, output) -class Einsum(Funsor): +class EinsumOp(ops.Op, metaclass=ops.OpCacheMeta): + def __init__(self, equation): + self.equation = equation + + +@find_domain.register(EinsumOp) +def _find_domain_einsum(op, *operands): + equation = op.equation + ein_inputs, ein_output = equation.split('->') + ein_inputs = ein_inputs.split(',') + size_dict = {} + for ein_input, x in zip(ein_inputs, operands): + assert x.dtype == 'real' + assert len(ein_input) == len(x.output.shape) + for name, size in zip(ein_input, x.output.shape): + other_size = size_dict.setdefault(name, size) + if other_size != size: + raise ValueError("Size mismatch at {}: {} vs {}" + .format(name, size, other_size)) + return Reals[tuple(size_dict[d] for d in ein_output)] + + +def Einsum(equation, operands): """ Wrapper around :func:`torch.einsum` or :func:`np.einsum` to operate on real-valued Funsors. @@ -960,40 +983,14 @@ class Einsum(Funsor): :param str equation: An :func:`torch.einsum` or :func:`np.einsum` equation. :param tuple operands: A tuple of input funsors. """ - def __init__(self, equation, operands): - assert isinstance(equation, str) - assert isinstance(operands, tuple) - assert all(isinstance(x, Funsor) for x in operands) - ein_inputs, ein_output = equation.split('->') - ein_inputs = ein_inputs.split(',') - size_dict = {} - inputs = OrderedDict() - assert len(ein_inputs) == len(operands) - for ein_input, x in zip(ein_inputs, operands): - assert x.dtype == 'real' - inputs.update(x.inputs) - assert len(ein_input) == len(x.output.shape) - for name, size in zip(ein_input, x.output.shape): - other_size = size_dict.setdefault(name, size) - if other_size != size: - raise ValueError("Size mismatch at {}: {} vs {}" - .format(name, size, other_size)) - output = Reals[tuple(size_dict[d] for d in ein_output)] - super(Einsum, self).__init__(inputs, output) - self.equation = equation - self.operands = operands - - def __repr__(self): - return 'Einsum({}, {})'.format(repr(self.equation), repr(self.operands)) - - def __str__(self): - return 'Einsum({}, {})'.format(repr(self.equation), str(self.operands)) + return Finitary(EinsumOp(equation), tuple(operands)) -@eager.register(Einsum, str, tuple) -def eager_einsum(equation, operands): +@eager.register(Finitary, EinsumOp, tuple) +def eager_einsum(op, operands): if all(isinstance(x, Tensor) for x in operands): # Make new symbols for inputs of operands. + equation = op.equation inputs = OrderedDict() for x in operands: inputs.update(x.inputs)