From 30a57228c72e803eae084de24600d165eca4341d Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 15 Mar 2021 16:02:08 -0400 Subject: [PATCH 01/22] Refactor ops to have known arity --- funsor/gaussian.py | 32 +++--- funsor/jax/ops.py | 8 +- funsor/joint.py | 6 +- funsor/ops/array.py | 164 ++++++++++++++++------------- funsor/ops/builtin.py | 145 +++++++++++-------------- funsor/ops/op.py | 240 ++++++++++++++++++++++++++++-------------- funsor/tensor.py | 2 +- funsor/terms.py | 28 +++-- funsor/torch/ops.py | 8 +- test/test_tensor.py | 2 +- 10 files changed, 357 insertions(+), 278 deletions(-) diff --git a/funsor/gaussian.py b/funsor/gaussian.py index caa278c9c..35d8f329b 100644 --- a/funsor/gaussian.py +++ b/funsor/gaussian.py @@ -137,7 +137,7 @@ def as_tensor(self): # Concatenate parts. parts = [v for k, v in sorted(self.parts.items())] - result = ops.cat(-1, *parts) + result = ops.cat(parts, -1) if not get_tracing_state(): assert result.shape == self.shape return result @@ -182,10 +182,10 @@ def as_tensor(self): # TODO This could be optimized into a single .reshape().cat().reshape() if # all inputs are contiguous, thereby saving a memcopy. columns = { - i: ops.cat(-1, *[v for j, v in sorted(part.items())]) + i: ops.cat([v for j, v in sorted(part.items())], -1) for i, part in self.parts.items() } - result = ops.cat(-2, *[v for i, v in sorted(columns.items())]) + result = ops.cat([v for i, v in sorted(columns.items())], -2) if not get_tracing_state(): assert result.shape == self.shape return result @@ -468,32 +468,32 @@ def _eager_subs_real(self, subs, remaining_subs): k for k, d in self.inputs.items() if d.dtype == "real" and k not in b ) prec_aa = ops.cat( - -2, - *[ + [ ops.cat(-1, *[precision[..., i1, i2] for k2, i2 in slices if k2 in a]) for k1, i1 in slices if k1 in a - ] + ], + -2, ) prec_ab = ops.cat( - -2, *[ ops.cat(-1, *[precision[..., i1, i2] for k2, i2 in slices if k2 in b]) for k1, i1 in slices if k1 in a - ] + ], + -2 ) prec_bb = ops.cat( - -2, *[ - ops.cat(-1, *[precision[..., i1, i2] for k2, i2 in slices if k2 in b]) + ops.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in b], -1) for k1, i1 in slices if k1 in b - ] + ], + -2 ) - info_a = ops.cat(-1, *[info_vec[..., i] for k, i in slices if k in a]) - info_b = ops.cat(-1, *[info_vec[..., i] for k, i in slices if k in b]) - value_b = ops.cat(-1, *[values[k] for k, i in slices if k in b]) + info_a = ops.cat([info_vec[..., i] for k, i in slices if k in a], -1) + info_b = ops.cat([info_vec[..., i] for k, i in slices if k in b], -1) + value_b = ops.cat([values[k] for k, i in slices if k in b], -1) info_vec = info_a - _mv(prec_ab, value_b) log_scale = _vv(value_b, info_b - 0.5 * _mv(prec_bb, value_b)) precision = ops.expand(prec_aa, info_vec.shape + info_vec.shape[-1:]) @@ -637,8 +637,8 @@ def eager_reduce(self, op, reduced_vars): 1, ) (b if key in reduced_vars else a).append(block) - a = ops.cat(-1, *a) - b = ops.cat(-1, *b) + a = ops.cat(a, -1) + b = ops.cat(b, -1) prec_aa = self.precision[..., a[..., None], a] prec_ba = self.precision[..., b[..., None], a] prec_bb = self.precision[..., b[..., None], b] diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py index 420feb8ce..6c38531be 100644 --- a/funsor/jax/ops.py +++ b/funsor/jax/ops.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import numbers +import typing import jax.numpy as np import numpy as onp @@ -67,11 +68,8 @@ def _astype(x, dtype): return x.astype(np.result_type(dtype)) -@ops.cat.register(int, [array]) -def _cat(dim, *x): - if len(x) == 1: - 
return x[0] - return np.concatenate(x, axis=dim) +ops.cat.register(typing.Tuple[array, ...])(np.concatenate) +ops.cat.register(typing.List[array, ...])(np.concatenate) @ops.cholesky.register(array) diff --git a/funsor/joint.py b/funsor/joint.py index ecda93a9b..3f7a78b44 100644 --- a/funsor/joint.py +++ b/funsor/joint.py @@ -60,8 +60,8 @@ def eager_cat_homogeneous(name, part_name, *parts): del int_inputs[part_name] dim = 0 - info_vec = ops.cat(dim, *info_vecs) - precision = ops.cat(dim, *precisions) + info_vec = ops.cat(info_vecs, dim) + precision = ops.cat(precisions, dim) inputs[name] = Bint[info_vec.shape[dim]] int_inputs[name] = inputs[name] result = Gaussian(info_vec, precision, inputs) @@ -69,7 +69,7 @@ def eager_cat_homogeneous(name, part_name, *parts): for i, d in enumerate(discretes): if d is None: discretes[i] = ops.new_zeros(info_vecs[i], info_vecs[i].shape[:-1]) - discrete = ops.cat(dim, *discretes) + discrete = ops.cat(discretes, dim) result = result + Tensor(discrete, int_inputs) return result diff --git a/funsor/ops/array.py b/funsor/ops/array.py index d31f95668..6c4c0fb37 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -3,6 +3,7 @@ import math import numbers +import typing import numpy as np @@ -21,7 +22,16 @@ sqrt, tanh, ) -from .op import DISTRIBUTIVE_OPS, UNITS, CachedOpMeta, Op, declare_op_types, make_op +from .op import ( + DISTRIBUTIVE_OPS, + UNITS, + BinaryOp, + CachedOpMeta, + Op, + TernaryOp, + UnaryOp, + declare_op_types, +) _builtin_all = all _builtin_any = any @@ -29,21 +39,14 @@ # This is used only for pattern matching. array = (np.ndarray, np.generic) -all = make_op(np.all) -amax = make_op(np.amax) -amin = make_op(np.amin) -any = make_op(np.any) -astype = make_op("astype") -cat = make_op("cat") -clamp = make_op("clamp") -diagonal = make_op("diagonal") -einsum = make_op("einsum") -full_like = make_op(np.full_like) -isnan = make_op(np.isnan) -prod = make_op(np.prod) -stack = make_op("stack") -sum = make_op(np.sum) -transpose = make_op("transpose") +all = UnaryOp.make(np.all) +amax = UnaryOp.make(np.amax) +amin = UnaryOp.make(np.amin) +any = UnaryOp.make(np.any) +full_like = UnaryOp.make(np.full_like) +isnan = UnaryOp.make(np.isnan) +prod = UnaryOp.make(np.prod) +sum = UnaryOp.make(np.sum) sqrt.register(array)(np.sqrt) exp.register(array)(np.exp) @@ -60,31 +63,28 @@ def _log(x): return np.log(x) -def _logaddexp(x, y): +@AssociativeOp.make +def logaddexp(x, y): shift = max(detach(x), detach(y)) return log(exp(x - shift) + exp(y - shift)) + shift -logaddexp = make_op(_logaddexp, AssociativeOp, name="logaddexp") -sample = make_op(_logaddexp, type(logaddexp), name="sample") +sample = logaddexp.make(logaddexp.default, name="sample") class ReshapeMeta(CachedOpMeta): - def __call__(cls, shape): - shape = tuple(shape) # necessary to convert torch.Size to tuple - return super().__call__(shape) + def _hash_args_kwargs(cls, shape): + return tuple(shape) # necessary to convert torch.Size to tuple -class ReshapeOp(Op, metaclass=ReshapeMeta): - def __init__(self, shape): - self.shape = shape - super().__init__(self._default) +@UnaryOp.make(metaclass=ReshapeMeta) +def reshape(x, shape): + return x.reshape(shape) - def __reduce__(self): - return ReshapeOp, (self.shape,) - def _default(self, x): - return x.reshape(self.shape) +@UnaryOp.make +def astype(x, dtype): + raise NotImplementedError @astype.register(array, str) @@ -92,20 +92,24 @@ def _astype(x, dtype): return x.astype(dtype) -@cat.register(int, [array]) -def _cat(dim, *x): - return np.concatenate(x, 
axis=dim) +@UnaryOp.make +def cat(parts, axis=0): + raise NotImplementedError + + +cat.register(typing.Tuple[array, ...])(np.concatenate) +cat.register(typing.List[array, ...])(np.concatenate) + +@UnaryOp.make +def clamp(x, min=None, max=None): + return min(max(x, min), max) -@clamp.register(array, numbers.Number, numbers.Number) -@clamp.register(array, numbers.Number, type(None)) -@clamp.register(array, type(None), numbers.Number) -@clamp.register(array, type(None), type(None)) -def _clamp(x, min, max): - return np.clip(x, a_min=min, a_max=max) +clamp.register(array)(np.clip) -@Op + +@UnaryOp.make def cholesky(x): """ Like :func:`numpy.linalg.cholesky` but uses sqrt for scalar matrices. @@ -115,7 +119,7 @@ def cholesky(x): return np.linalg.cholesky(x) -@Op +@UnaryOp.make def cholesky_inverse(x): """ Like :func:`torch.cholesky_inverse` but supports batching and gradients. @@ -123,29 +127,40 @@ def cholesky_inverse(x): return cholesky_solve(new_eye(x, x.shape[:-1]), x) -@Op +@BinaryOp.make def cholesky_solve(x, y): y_inv = np.linalg.inv(y) A = np.swapaxes(y_inv, -2, -1) @ y_inv return A @ x -@Op +@UnaryOp.make def detach(x): return x +@UnaryOp.make +def diagonal(x, dim1, dim2): + raise NotImplementedError + + @diagonal.register(array, int, int) def _diagonal(x, dim1, dim2): return np.diagonal(x, axis1=dim1, axis2=dim2) -@einsum.register(str, [array]) -def _einsum(x, *operand): - return np.einsum(x, *operand) +@UnaryOp.make +def einsum(operands, equation): + raise NotImplementedError + +@einsum.register(typing.Tuple[array, ...]) +@einsum.register(typing.List[array, ...]) +def _einsum(operands, equation): + return np.einsum(equation, *operands) -@Op + +@UnaryOp.make def expand(x, shape): prepend_dim = len(shape) - np.ndim(x) assert prepend_dim >= 0 @@ -155,12 +170,12 @@ def expand(x, shape): return np.broadcast_to(x, shape) -@Op +@UnaryOp.make def finfo(x): return np.finfo(x.dtype) -@Op +@UnaryOp.make def is_numeric_array(x): return True if isinstance(x, array) else False @@ -184,7 +199,7 @@ def _safe_logaddexp_tensor_number(x, y): return _safe_logaddexp_number_tensor(y, x) -@Op +@UnaryOp.make def logsumexp(x, dim): amax = np.amax(x, axis=dim, keepdims=True) # treat the case x = -inf @@ -192,9 +207,8 @@ def logsumexp(x, dim): return log(np.sum(np.exp(x - amax), axis=dim)) + amax.squeeze(axis=dim) -@max.register(array, array) -def _max(x, y): - return np.maximum(x, y) +max.register(array, array)(np.maximum) +min.register(array, array)(np.minimum) @max.register((int, float), array) @@ -207,11 +221,6 @@ def _max(x, y): return np.clip(x, a_min=y, a_max=None) -@min.register(array, array) -def _min(x, y): - return np.minimum(x, y) - - @min.register((int, float), array) def _min(x, y): return np.clip(y, a_min=None, a_max=x) @@ -222,7 +231,7 @@ def _min(x, y): return np.clip(x, a_min=None, a_max=y) -@Op +@UnaryOp.make def argmax(x, dim): raise NotImplementedError @@ -232,7 +241,7 @@ def _argmax(x, dim): return np.argmax(x, dim) -@Op +@UnaryOp.make def new_arange(x, stop): return np.arange(stop) @@ -242,23 +251,23 @@ def _new_arange(x, start, stop, step): return np.arange(start, stop, step) -@Op +@UnaryOp.make def new_zeros(x, shape): return np.zeros(shape, dtype=x.dtype) -@Op +@UnaryOp.make def new_full(x, shape, value): return np.full(shape, value, dtype=x.dtype) -@Op +@UnaryOp.make def new_eye(x, shape): n = shape[-1] return np.broadcast_to(np.eye(n), shape + (n,)) -@Op +@UnaryOp.make def permute(x, dims): return np.transpose(x, axes=dims) @@ -289,7 +298,7 @@ def _safesub(x, y): return x + 
np.clip(-y, a_min=None, a_max=finfo.max) -@Op +@TernaryOp.make def scatter(destin, indices, source): raise NotImplementedError @@ -301,7 +310,7 @@ def _scatter(destin, indices, source): return result -@Op +@TernaryOp.make def scatter_add(destin, indices, source): raise NotImplementedError @@ -313,24 +322,31 @@ def _scatter_add(destin, indices, source): return result -@stack.register(int, [array]) -def _stack(dim, *x): - return np.stack(x, axis=dim) +@UnaryOp.make +def stack(parts, axis=0): + raise NotImplementedError + + +stack.register(typing.Tuple[array, ...])(np.stack) +stack.register(typing.List[array, ...])(np.stack) + + +@UnaryOp.make +def transpose(array, axis1, axis2): + raise NotImplementedError -@transpose.register(array, int, int) -def _transpose(x, dim1, dim2): - return np.swapaxes(x, dim1, dim2) +transpose.register(array, int, int)(np.swapaxes) -@Op +@BinaryOp.make def triangular_solve(x, y, upper=False, transpose=False): if transpose: y = np.swapaxes(y, -2, -1) return np.linalg.inv(y) @ x -@Op +@UnaryOp.make def unsqueeze(x, dim): return np.expand_dims(x, axis=dim) diff --git a/funsor/ops/builtin.py b/funsor/ops/builtin.py index d7e7fd81d..a9290b916 100644 --- a/funsor/ops/builtin.py +++ b/funsor/ops/builtin.py @@ -17,7 +17,6 @@ TransformOp, UnaryOp, declare_op_types, - make_op, ) _builtin_abs = abs @@ -25,116 +24,94 @@ _builtin_sum = sum -def sigmoid(x): - return 1 / (1 + exp(-x)) +# FIXME Most code assumes this is an AssociativeCommutativeOp. +class AssociativeOp(BinaryOp): + pass -def softplus(x): - return log(1.0 + exp(x)) +@AssociativeOp.make +def nullop(x, y): + """Placeholder associative op that unifies with any other op""" + raise ValueError("should never actually evaluate this!") +@BinaryOp.make(metaclass=CachedOpMeta) +def getitem(lhs, rhs, offset=0): + if offset == 0: + return lhs[rhs] + return lhs[(slice(None),) * offset + (rhs,)] + + +abs = UnaryOp.make(_builtin_abs) +eq = BinaryOp.make(operator.eq) +ge = BinaryOp.make(operator.ge) +gt = BinaryOp.make(operator.gt) +invert = UnaryOp.make(operator.invert) +le = BinaryOp.make(operator.le) +lt = BinaryOp.make(operator.lt) +ne = BinaryOp.make(operator.ne) +pos = UnaryOp.make(operator.pos) +neg = UnaryOp.make(operator.neg) +pow = BinaryOp.make(operator.pow) +sub = BinaryOp.make(operator.sub) +truediv = BinaryOp.make(operator.truediv) +floordiv = BinaryOp.make(operator.floordiv) +add = AssociativeOp.make(operator.add) +and_ = AssociativeOp.make(operator.and_) +mul = AssociativeOp.make(operator.mul) +matmul = BinaryOp.make(operator.matmul) +mod = BinaryOp.make(operator.mod) +lshift = BinaryOp.make(operator.lshift) +rshift = BinaryOp.make(operator.rshift) +or_ = AssociativeOp.make(operator.or_) +xor = AssociativeOp.make(operator.xor) +max = AssociativeOp.make(max) +min = AssociativeOp.make(min) + +lgamma = UnaryOp.make(math.lgamma) +log1p = UnaryOp.make(math.log1p) +sqrt = UnaryOp.make(math.sqrt) + + +@UnaryOp.make def reciprocal(x): if isinstance(x, Number): return 1.0 / x raise ValueError("No reciprocal for type {}".format(type(x))) -# FIXME Most code assumes this is an AssociativeCommutativeOp. 
-class AssociativeOp(Op): - pass - +@UnaryOp.make +def softplus(x): + return log(1.0 + exp(x)) -class NullOp(AssociativeOp): - """Placeholder associative op that unifies with any other op""" - pass +@TransformOp.make +def log(x): + return math.log(x) if x > 0 else -math.inf -@NullOp -def nullop(x, y): - raise ValueError("should never actually evaluate this!") +exp = TransformOp.make(math.exp) +tanh = TransformOp.make(math.tanh) +atanh = TransformOp.make(math.atanh) -class GetitemOp(Op, metaclass=CachedOpMeta): - """ - Op encoding an index into one dimension, e.g. ``x[:,:,y]`` for offset of 2. - """ - - def __init__(self, offset): - assert isinstance(offset, int) - assert offset >= 0 - self.offset = offset - self._prefix = (slice(None),) * offset - super(GetitemOp, self).__init__(self._default) - self.__name__ = "GetitemOp({})".format(offset) - - def __reduce__(self): - return GetitemOp, (self.offset,) - - def _default(self, x, y): - return x[self._prefix + (y,)] if self.offset else x[y] - - -getitem = GetitemOp(0) -abs = make_op(_builtin_abs, UnaryOp) -eq = make_op(operator.eq, BinaryOp) -ge = make_op(operator.ge, BinaryOp) -gt = make_op(operator.gt, BinaryOp) -invert = make_op(operator.invert, UnaryOp) -le = make_op(operator.le, BinaryOp) -lt = make_op(operator.lt, BinaryOp) -ne = make_op(operator.ne, BinaryOp) -pos = make_op(operator.pos, UnaryOp) -neg = make_op(operator.neg, UnaryOp) -pow = make_op(operator.pow, BinaryOp) -sub = make_op(operator.sub, BinaryOp) -truediv = make_op(operator.truediv, BinaryOp) -floordiv = make_op(operator.floordiv, BinaryOp) -add = make_op(operator.add, AssociativeOp) -and_ = make_op(operator.and_, AssociativeOp) -mul = make_op(operator.mul, AssociativeOp) -matmul = make_op(operator.matmul, BinaryOp) -mod = make_op(operator.mod, BinaryOp) -lshift = make_op(operator.lshift, BinaryOp) -rshift = make_op(operator.rshift, BinaryOp) -or_ = make_op(operator.or_, AssociativeOp) -xor = make_op(operator.xor, AssociativeOp) -max = make_op(max, AssociativeOp) -min = make_op(min, AssociativeOp) - -lgamma = make_op(math.lgamma, UnaryOp) -log1p = make_op(math.log1p, UnaryOp) -sqrt = make_op(math.sqrt, UnaryOp) - -reciprocal = make_op(reciprocal, UnaryOp) -softplus = make_op(softplus, UnaryOp) - -exp = make_op(math.exp, TransformOp) -log = make_op( - lambda x: math.log(x) if x > 0 else -math.inf, parent=TransformOp, name="log" -) -tanh = make_op(math.tanh, TransformOp) -atanh = make_op(math.atanh, TransformOp) -sigmoid = make_op(sigmoid, TransformOp) +@TransformOp.make +def sigmoid(x): + return 1 / (1 + exp(-x)) -@make_op(parent=type(sub)) +@sub.make def safesub(x, y): if isinstance(y, Number): return sub(x, y) -@make_op(parent=type(truediv)) +@truediv.make def safediv(x, y): if isinstance(y, Number): return operator.truediv(x, y) -@add.register(object) -def _unary_add(x): - return x.sum() - - @exp.set_log_abs_det_jacobian def log_abs_det_jacobian(x, y): return add(x) @@ -198,8 +175,6 @@ def sigmoid_log_abs_det_jacobian(x, y): __all__ = [ "AssociativeOp", - "GetitemOp", - "NullOp", "abs", "add", "and_", diff --git a/funsor/ops/op.py b/funsor/ops/op.py index 12d6f5769..ade11721d 100644 --- a/funsor/ops/op.py +++ b/funsor/ops/op.py @@ -8,89 +8,111 @@ from multipledispatch import Dispatcher +def apply(function, args, kwargs={}): + return function(*args, **kwargs) + + class WeakPartial: - """ - Like ``functools.partial(fn, arg)`` but weakly referencing ``arg``. - """ + # Like ``functools.partial(fn, arg)`` but weakly referencing ``arg``. 
def __init__(self, fn, arg): self.fn = fn self.weak_arg = weakref.ref(arg) functools.update_wrapper(self, fn) - def __call__(self, *args): + def __call__(self, *args, **kwargs): arg = self.weak_arg() - return self.fn(arg, *args) + return self.fn(arg, *args, **kwargs) -class CachedOpMeta(type): +class OpMeta(type): """ - Metaclass for caching op instance construction. - Caching strategy is to key on ``*args`` and retain values forever. + Metaclass for :class:`Op` classes. """ def __init__(cls, *args, **kwargs): super().__init__(*args, **kwargs) - cls._instance_cache = {} + cls._subclass_registry = [] - def __call__(cls, *args, **kwargs): - try: - return cls._instance_cache[args] - except KeyError: - instance = super(CachedOpMeta, cls).__call__(*args, **kwargs) - cls._instance_cache[args] = instance - return instance + # Register all existing patterns. + for supercls in reversed(inspect.getmro(cls)): + for pattern, fn in getattr(supercls, "_subclass_registry", ()): + cls.dispatcher.add(pattern, WeakPartial(fn, cls)) + @property + def register(cls): + return cls.dispatcher.register -class WrappedOpMeta(type): + +class CachedOpMeta(OpMeta): """ - Metaclass for ops that wrap temporary backend ops. - Caching strategy is to key on ``id(backend_op)`` and forget values asap. + Metaclass for caching op instance construction. + + Caching strategy is to key on ``args[arity:],kwargs`` and retain values + forever. This requires all non-funsor args to be hashable. """ def __init__(cls, *args, **kwargs): super().__init__(*args, **kwargs) - cls._instance_cache = weakref.WeakValueDictionary() + cls._instance_cache = {} - def __call__(cls, fn): - if inspect.ismethod(fn): - key = id(fn.__self__), fn.__func__ # e.g. t.log_abs_det_jacobian - else: - key = id(fn) # e.g. t.inv + def __call__(cls, *args, **kwargs): + bound = cls.signature.bind(*args, **kwargs) + bound.apply_defaults() + args = bound.args + fn = cls.dispatcher.dispatch(*args[: cls.arity]) + return fn(*args, **bound.kwargs) + + def bind_partial(cls, *args, **kwargs): + """ + Finds or constructs an instance ``op`` such that:: + + op(*args[:op.arity]) == cls()(*args, **kwargs) + + where ``cls()`` is the default op. + """ + bound = cls.signature.bind(*args, **kwargs) + args = bound.args + assert len(args) >= cls.arity, "missing required args" + args = args[cls.arity :] + kwargs = bound.kwargs + key = cls._hash_args_kwargs(args, tuple(kwargs.items())) try: return cls._instance_cache[key] except KeyError: - op = super().__call__(fn) - op.fn = fn # Ensures the key id(fn) is not reused. + op = cls(*args, **kwargs) cls._instance_cache[key] = op return op + @staticmethod + def _hash_args_kwargs(args, kwargs): + return args, tuple(kwargs.items()) -class Op(Dispatcher): - _all_instances = weakref.WeakSet() - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - cls._subclass_registry = [] +class Op(metaclass=OpMeta): + r""" + Abstract base class for all mathematical operations on ground terms. - def __init__(self, fn, *, name=None): - if isinstance(fn, str): - fn, name = None, fn - if name is None: - name = fn.__name__ - super(Op, self).__init__(name) - if fn is not None: - # register as default operation - for nargs in (1, 2, 3): - default_signature = (object,) * nargs - self.add(default_signature, fn) + Ops take ``arity``-many leftmost positional args that may be funsors, + followed by additional non-funsor args and kwargs. The additional args and + kwargs must have default values. - # Register all existing patterns. 
- for supercls in reversed(inspect.getmro(type(self))): - for pattern, fn in getattr(supercls, "_subclass_registry", ()): - self.add(pattern, WeakPartial(fn, self)) - # Save self for registering future patterns. - Op._all_instances.add(self) + :cvar int arity: The number of funsor arguments this op takes. Must be + defined by subclasses. + :param \*args: + :param \*\*kwargs: All extra arguments to this op, excluding the arguments + up to ``.arity``, + """ + + arity = NotImplemented # abstract + + def __init__(self, *args, **kwargs): + super().__init__() + cls = type(self) + args = (None,) * cls.arity + args + bound = cls.signature.bind(*args, **kwargs) + bound.apply_defaults() + self.defaults = tuple(bound.arguments.items())[cls.arity :] def __copy__(self): return self @@ -99,7 +121,8 @@ def __deepcopy__(self, memo): return self def __reduce__(self): - return self.__name__ + args = self.bound.args[self.arity :] + return apply, (type(self), args, self.bound.kwargs) def __repr__(self): return "ops." + self.__name__ @@ -107,41 +130,69 @@ def __repr__(self): def __str__(self): return self.__name__ + def __call__(self, *args, **kwargs): + # Normalize args, kwargs. + cls = type(self) + bound = cls.signature.bind(*args, **kwargs) + for key, value in self.defaults: + bound.arguments.setdefault(key, value) + args = bound.args + assert len(args) >= cls.arity + kwargs = bound.kwargs + + # Dispatch. + fn = cls.dispatcher.dispatch(*args[: cls.arity]) + return fn(*args, **kwargs) + @classmethod def subclass_register(cls, *pattern): def decorator(fn): - # Register with all existing instances. - for op in Op._all_instances: - if isinstance(op, cls): - op.add(pattern, WeakPartial(fn, op)) - # Ensure registration with all future instances. + # Register with all existing sublasses. + for subcls in [cls] + cls.__subclasses__(): + cls.dispatcher.add(pattern, WeakPartial(fn, subcls)) + # Ensure registration with all future subclasses. cls._subclass_registry.append((pattern, fn)) return fn return decorator + @classmethod + def make(cls, fn=None, *, name=None, metaclass=OpMeta, module_name="funsor.ops"): + """ + Factory to create a new :class:`Op` subclass together with a new + instance of that class. + """ + if not isinstance(cls.arity, int): + raise TypeError( + f"Can't instantiate abstract class {cls.__name__} with abstract arity" + ) -def make_op(fn=None, parent=None, *, name=None, module_name="funsor.ops"): - """ - Factory to create a new :class:`Op` subclass and a new instance of that class. - """ - # Support use as decorator. - if fn is None: - return lambda fn: make_op(fn, parent, name=name, module_name=module_name) - - if parent is None: - parent = Op - assert issubclass(parent, Op) - - if name is None: - name = fn if isinstance(fn, str) else fn.__name__ - assert isinstance(name, str) + # Support use as decorator. + if fn is None: + return lambda fn: cls.make(fn, name=name, module_name=module_name) + assert callable(fn) - classname = name.capitalize().rstrip("_") + "Op" # e.g. add -> AddOp - cls = type(classname, (parent,), {}) - cls.__module__ = module_name - op = cls(fn, name=name) - return op + if name is None: + name = fn.__name__ + assert isinstance(name, str) + + assert issubclass(metaclass, OpMeta) + classname = name.capitalize().rstrip("_") + "Op" # e.g. 
add -> AddOp + signature = inspect.Signature.from_callable(fn) + dispatcher = Dispatcher(name) + op_class = metaclass( + classname, + (cls,), + { + "default": fn, + "dispatcher": dispatcher, + "signature": signature, + }, + ) + op_class.__module__ = module_name + dispatcher.add((object,) * cls.arity, fn) + op = op_class() + return op def declare_op_types(locals_, all_, name_): @@ -161,12 +212,20 @@ def declare_op_types(locals_, all_, name_): all_.sort() +class NullaryOp(Op): + arity = 0 + + class UnaryOp(Op): - pass + arity = 1 class BinaryOp(Op): - pass + arity = 2 + + +class TernaryOp(Op): + arity = 3 class TransformOp(UnaryOp): @@ -197,6 +256,30 @@ def log_abs_det_jacobian(x, y): raise NotImplementedError +class WrappedOpMeta(type): + """ + Metaclass for ops that wrap temporary backend ops. + Caching strategy is to key on ``id(backend_op)`` and forget values asap. + """ + + def __init__(cls, *args, **kwargs): + super().__init__(*args, **kwargs) + cls._instance_cache = weakref.WeakValueDictionary() + + def __call__(cls, fn): + if inspect.ismethod(fn): + key = id(fn.__self__), fn.__func__ # e.g. t.log_abs_det_jacobian + else: + key = id(fn) # e.g. t.inv + try: + return cls._instance_cache[key] + except KeyError: + op = super().__call__(fn) + op.fn = fn # Ensures the key id(fn) is not reused. + cls._instance_cache[key] = op + return op + + class WrappedTransformOp(TransformOp, metaclass=WrappedOpMeta): """ Wrapper for a backend ``Transform`` object that provides ``.inv`` and @@ -265,13 +348,14 @@ class LogAbsDetJacobianOp(BinaryOp, metaclass=WrappedOpMeta): "CachedOpMeta", "DISTRIBUTIVE_OPS", "LogAbsDetJacobianOp", + "NullaryOp", "Op", "SAFE_BINARY_INVERSES", + "TernaryOp", "TransformOp", "UNARY_INVERSES", "UNITS", "UnaryOp", "WrappedTransformOp", "declare_op_types", - "make_op", ] diff --git a/funsor/tensor.py b/funsor/tensor.py index 63f5ffb1e..6d994472a 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -891,7 +891,7 @@ def eager_cat_homogeneous(name, part_name, *parts): del inputs[part_name] dim = 0 - tensor = ops.cat(dim, *tensors) + tensor = ops.cat(tensors, dim) inputs = OrderedDict([(name, Bint[tensor.shape[dim]])] + list(inputs.items())) return Tensor(tensor, inputs, dtype=output.dtype) diff --git a/funsor/terms.py b/funsor/terms.py index 9423e0266..557d8602b 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1916,22 +1916,30 @@ def quote_inplace_first_arg_on_first_line(arg, indent, out): out[-1] = i, line + ")" -ops.UnaryOp.subclass_register(Funsor)(Unary) -ops.BinaryOp.subclass_register(Funsor, Funsor)(Binary) -ops.AssociativeOp.subclass_register(Funsor, Funsor)(Binary) -ops.AssociativeOp.subclass_register(Funsor)(Unary) # Reductions. 
+@ops.UnaryOp.subclass_register(Funsor) +def unary_funsor(cls, arg, *args, **kwargs): + op = cls.bind_partial(arg, *args, **kwargs) + return Unary(op, arg) + + +@ops.BinaryOp.subclass_register(Funsor, Funsor) +def binary_funsor_funsor(cls, lhs, rhs, *args, **kwargs): + op = cls.bind_partial(lhs, rhs, *args, **kwargs) + return Binary(op, lhs, rhs) @ops.BinaryOp.subclass_register(object, Funsor) -@ops.AssociativeOp.subclass_register(object, Funsor) -def binary_object_funsor(op, x, y): - return Binary(op, to_funsor(x), y) +def binary_object_funsor(cls, lhs, rhs, *args, **kwargs): + lhs = to_funsor(lhs) + op = cls.bind_partial(lhs, rhs, *args, **kwargs) + return Binary(op, lhs, rhs) @ops.BinaryOp.subclass_register(Funsor, object) -@ops.AssociativeOp.subclass_register(Funsor, object) -def binary_funsor_object(op, x, y): - return Binary(op, x, to_funsor(y)) +def binary_funsor_object(cls, lhs, rhs, *args, **kwargs): + rhs = to_funsor(rhs) + op = cls.bind_partial(lhs, rhs, *args, **kwargs) + return Binary(op, lhs, rhs) __all__ = [ diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py index f1a2fd11e..2684015cf 100644 --- a/funsor/torch/ops.py +++ b/funsor/torch/ops.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import numbers +import typing import torch @@ -58,11 +59,8 @@ def _astype(x, dtype): return x.type(getattr(torch, dtype)) -@ops.cat.register(int, [torch.Tensor]) -def _cat(dim, *x): - if len(x) == 1: - return x[0] - return torch.cat(x, dim=dim) +ops.cat.register(typing.Tuple[torch.Tensor, ...])(torch.cat) +ops.cat.register(typing.List[torch.Tensor, ...])(torch.cat) @ops.cholesky.register(torch.Tensor) diff --git a/test/test_tensor.py b/test/test_tensor.py index 8be2d769c..50b8fa410 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -1087,7 +1087,7 @@ def test_scatter_number(op): proto = source.data.reshape((-1,))[:1].reshape(()) zero = ops.full_like(ops.expand(proto, (5, 2)), ops.UNITS[op]) - expected_data = ops.cat(1, source.data.reshape((5, 1)), zero) + expected_data = ops.cat([source.data.reshape((5, 1)), zero], 1) expected = Tensor(expected_data, OrderedDict(k=Bint[5], i=Bint[3])) assert_close(actual, expected) From 4d33515e5ed17c8eec27afaf64839c70c5b87462 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 15 Mar 2021 16:12:07 -0400 Subject: [PATCH 02/22] Add docs; merge CachedOpMeta into OpMeta --- funsor/ops/array.py | 4 ++-- funsor/ops/builtin.py | 3 +-- funsor/ops/op.py | 36 +++++++++++++++++++----------------- 3 files changed, 22 insertions(+), 21 deletions(-) diff --git a/funsor/ops/array.py b/funsor/ops/array.py index 6c4c0fb37..400d26746 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -26,8 +26,8 @@ DISTRIBUTIVE_OPS, UNITS, BinaryOp, - CachedOpMeta, Op, + OpMeta, TernaryOp, UnaryOp, declare_op_types, @@ -72,7 +72,7 @@ def logaddexp(x, y): sample = logaddexp.make(logaddexp.default, name="sample") -class ReshapeMeta(CachedOpMeta): +class ReshapeMeta(OpMeta): def _hash_args_kwargs(cls, shape): return tuple(shape) # necessary to convert torch.Size to tuple diff --git a/funsor/ops/builtin.py b/funsor/ops/builtin.py index a9290b916..856cf8fd5 100644 --- a/funsor/ops/builtin.py +++ b/funsor/ops/builtin.py @@ -12,7 +12,6 @@ UNARY_INVERSES, UNITS, BinaryOp, - CachedOpMeta, Op, TransformOp, UnaryOp, @@ -35,7 +34,7 @@ def nullop(x, y): raise ValueError("should never actually evaluate this!") -@BinaryOp.make(metaclass=CachedOpMeta) +@BinaryOp.make def getitem(lhs, rhs, offset=0): if offset == 0: return lhs[rhs] diff --git a/funsor/ops/op.py 
b/funsor/ops/op.py index ade11721d..cd7a6f948 100644 --- a/funsor/ops/op.py +++ b/funsor/ops/op.py @@ -28,10 +28,16 @@ def __call__(self, *args, **kwargs): class OpMeta(type): """ Metaclass for :class:`Op` classes. + + This weakly caches op instances. Caching strategy is to key on + ``args[arity:],kwargs`` and to weakly retain values. Caching requires all + non-funsor args to be hashable; for non-hashable args, implement a derived + metaclass with custom :meth:`hash_args_kwargs` method. """ def __init__(cls, *args, **kwargs): super().__init__(*args, **kwargs) + cls._instance_cache = weakref.WeakValueDictionary() cls._subclass_registry = [] # Register all existing patterns. @@ -43,19 +49,6 @@ def __init__(cls, *args, **kwargs): def register(cls): return cls.dispatcher.register - -class CachedOpMeta(OpMeta): - """ - Metaclass for caching op instance construction. - - Caching strategy is to key on ``args[arity:],kwargs`` and retain values - forever. This requires all non-funsor args to be hashable. - """ - - def __init__(cls, *args, **kwargs): - super().__init__(*args, **kwargs) - cls._instance_cache = {} - def __call__(cls, *args, **kwargs): bound = cls.signature.bind(*args, **kwargs) bound.apply_defaults() @@ -67,7 +60,7 @@ def bind_partial(cls, *args, **kwargs): """ Finds or constructs an instance ``op`` such that:: - op(*args[:op.arity]) == cls()(*args, **kwargs) + op(*args[:cls.arity]) == cls()(*args, **kwargs) where ``cls()`` is the default op. """ @@ -76,7 +69,7 @@ def bind_partial(cls, *args, **kwargs): assert len(args) >= cls.arity, "missing required args" args = args[cls.arity :] kwargs = bound.kwargs - key = cls._hash_args_kwargs(args, tuple(kwargs.items())) + key = cls.hash_args_kwargs(args, tuple(kwargs.items())) try: return cls._instance_cache[key] except KeyError: @@ -85,7 +78,7 @@ def bind_partial(cls, *args, **kwargs): return op @staticmethod - def _hash_args_kwargs(args, kwargs): + def hash_args_kwargs(args, kwargs): return args, tuple(kwargs.items()) @@ -97,6 +90,16 @@ class Op(metaclass=OpMeta): followed by additional non-funsor args and kwargs. The additional args and kwargs must have default values. + When wrapping new backend ops, keep in mind these restrictions: + + - Create new ops only by decoraing a default implementation with + ``@UnaryOp.make``, ``@BinaryOp.make``, etc. + - Register backend-specific implementations via ``@my_op.register(type1)``, + ``@my_op.register(type1, type2)`` etc for arity 1, 2, etc. Patterns may + include only the first ``arity``-many types. + - Only the first ``arity``-many arguments may be funsors. Remaining args + and kwargs must all be ground Python data. + :cvar int arity: The number of funsor arguments this op takes. Must be defined by subclasses. 
:param \*args: @@ -345,7 +348,6 @@ class LogAbsDetJacobianOp(BinaryOp, metaclass=WrappedOpMeta): __all__ = [ "BINARY_INVERSES", "BinaryOp", - "CachedOpMeta", "DISTRIBUTIVE_OPS", "LogAbsDetJacobianOp", "NullaryOp", From 54baf12257c38298c3925bc4ada14626c659eefb Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 15 Mar 2021 16:35:05 -0400 Subject: [PATCH 03/22] Fix some ops --- funsor/jax/ops.py | 2 +- funsor/ops/array.py | 25 +++++++++++++++++-------- funsor/ops/builtin.py | 15 +++++++++++++-- funsor/ops/op.py | 36 +++++++++++------------------------- funsor/terms.py | 8 ++++---- 5 files changed, 46 insertions(+), 40 deletions(-) diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py index 6c38531be..bc49d25d6 100644 --- a/funsor/jax/ops.py +++ b/funsor/jax/ops.py @@ -26,7 +26,7 @@ ops.clamp.register(array, type(None), numbers.Number)(np.clip) ops.clamp.register(array, type(None), type(None))(np.clip) ops.exp.register(array)(np.exp) -ops.full_like.register(array, numbers.Number)(np.full_like) +ops.new_full.register(array)(np.full_like) ops.log1p.register(array)(np.log1p) ops.max.register(array)(np.maximum) ops.min.register(array)(np.minimum) diff --git a/funsor/ops/array.py b/funsor/ops/array.py index 400d26746..3c2c28394 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -43,7 +43,6 @@ amax = UnaryOp.make(np.amax) amin = UnaryOp.make(np.amin) any = UnaryOp.make(np.any) -full_like = UnaryOp.make(np.full_like) isnan = UnaryOp.make(np.isnan) prod = UnaryOp.make(np.prod) sum = UnaryOp.make(np.sum) @@ -55,6 +54,11 @@ atanh.register(array)(np.arctanh) +@UnaryOp.make +def full_like(prototype, shape=(), fill_value=0): + return np.full_like(prototype, shape, fill_value) + + @log.register(array) def _log(x): if x.dtype == "bool": @@ -73,8 +77,13 @@ def logaddexp(x, y): class ReshapeMeta(OpMeta): - def _hash_args_kwargs(cls, shape): - return tuple(shape) # necessary to convert torch.Size to tuple + def hash_args_kwargs(cls, args, kwargs): + assert not kwargs + if args: + (shape,) = args + shape = tuple(shape) # necessary to convert torch.Size to tuple + args = (shape,) + return super().hash_args_kwargs(args, kwargs) @UnaryOp.make(metaclass=ReshapeMeta) @@ -236,7 +245,7 @@ def argmax(x, dim): raise NotImplementedError -@argmax.register(array, int) +@argmax.register(array) def _argmax(x, dim): return np.argmax(x, dim) @@ -246,7 +255,7 @@ def new_arange(x, stop): return np.arange(stop) -@new_arange.register(array, int, int, int) +@new_arange.register(array) def _new_arange(x, start, stop, step): return np.arange(start, stop, step) @@ -257,18 +266,18 @@ def new_zeros(x, shape): @UnaryOp.make -def new_full(x, shape, value): +def new_full(x, shape=(), value=math.nan): return np.full(shape, value, dtype=x.dtype) @UnaryOp.make -def new_eye(x, shape): +def new_eye(x, shape=()): n = shape[-1] return np.broadcast_to(np.eye(n), shape + (n,)) @UnaryOp.make -def permute(x, dims): +def permute(x, dims=()): return np.transpose(x, axes=dims) diff --git a/funsor/ops/builtin.py b/funsor/ops/builtin.py index 856cf8fd5..1ba80f939 100644 --- a/funsor/ops/builtin.py +++ b/funsor/ops/builtin.py @@ -19,6 +19,8 @@ ) _builtin_abs = abs +_builtin_max = max +_builtin_min = min _builtin_pow = pow _builtin_sum = sum @@ -64,8 +66,17 @@ def getitem(lhs, rhs, offset=0): rshift = BinaryOp.make(operator.rshift) or_ = AssociativeOp.make(operator.or_) xor = AssociativeOp.make(operator.xor) -max = AssociativeOp.make(max) -min = AssociativeOp.make(min) + + +@AssociativeOp.make +def max(lhs, rhs): + return _builtin_max(lhs, 
rhs) + + +@AssociativeOp.make +def min(lhs, rhs): + return _builtin_min(lhs, rhs) + lgamma = UnaryOp.make(math.lgamma) log1p = UnaryOp.make(math.log1p) diff --git a/funsor/ops/op.py b/funsor/ops/op.py index cd7a6f948..bc6260637 100644 --- a/funsor/ops/op.py +++ b/funsor/ops/op.py @@ -50,32 +50,16 @@ def register(cls): return cls.dispatcher.register def __call__(cls, *args, **kwargs): + args = (None,) * cls.arity + args bound = cls.signature.bind(*args, **kwargs) bound.apply_defaults() - args = bound.args - fn = cls.dispatcher.dispatch(*args[: cls.arity]) - return fn(*args, **bound.kwargs) - - def bind_partial(cls, *args, **kwargs): - """ - Finds or constructs an instance ``op`` such that:: - - op(*args[:cls.arity]) == cls()(*args, **kwargs) - - where ``cls()`` is the default op. - """ - bound = cls.signature.bind(*args, **kwargs) - args = bound.args - assert len(args) >= cls.arity, "missing required args" - args = args[cls.arity :] + args = bound.args[cls.arity :] kwargs = bound.kwargs - key = cls.hash_args_kwargs(args, tuple(kwargs.items())) - try: - return cls._instance_cache[key] - except KeyError: - op = cls(*args, **kwargs) - cls._instance_cache[key] = op - return op + key = cls.hash_args_kwargs(args, kwargs) + op = cls._instance_cache.get(key, None) + if op is None: + op = cls._instance_cache[key] = super().__call__(*args, **kwargs) + return op @staticmethod def hash_args_kwargs(args, kwargs): @@ -90,7 +74,8 @@ class Op(metaclass=OpMeta): followed by additional non-funsor args and kwargs. The additional args and kwargs must have default values. - When wrapping new backend ops, keep in mind these restrictions: + When wrapping new backend ops, keep in mind these restrictions, which may + require you to wrap backend functions before making them into ops: - Create new ops only by decoraing a default implementation with ``@UnaryOp.make``, ``@BinaryOp.make``, etc. @@ -99,6 +84,7 @@ class Op(metaclass=OpMeta): include only the first ``arity``-many types. - Only the first ``arity``-many arguments may be funsors. Remaining args and kwargs must all be ground Python data. + - All remaining non-funsor args and kwargs must define default values. :cvar int arity: The number of funsor arguments this op takes. Must be defined by subclasses. @@ -259,7 +245,7 @@ def log_abs_det_jacobian(x, y): raise NotImplementedError -class WrappedOpMeta(type): +class WrappedOpMeta(OpMeta): """ Metaclass for ops that wrap temporary backend ops. Caching strategy is to key on ``id(backend_op)`` and forget values asap. 
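# ----------------------------------------------------------------------------
# Illustrative usage sketch; not part of the patch. It shows how the
# refactored Op API above is meant to be used: ops have a fixed ``arity``,
# dispatch only on their first ``arity`` arguments, and keep any remaining
# args/kwargs as ground data. ``clip01`` and ``_clip01_array`` are
# hypothetical names invented for this example; ``ops.UnaryOp.make`` and
# ``.register`` are the interfaces introduced above.
import numpy as np
from funsor import ops


@ops.UnaryOp.make
def clip01(x, low=0.0, high=1.0):
    """Default scalar implementation."""
    return min(max(x, low), high)


@clip01.register(np.ndarray)  # patterns name only the first arity-many types
def _clip01_array(x, low=0.0, high=1.0):
    return np.clip(x, low, high)


clip01(1.7)                    # -> 1.0, falls back to the scalar default
clip01(np.linspace(-1, 2, 5))  # -> array clipped to [0, 1]
clip01(np.ones(3), high=0.5)   # extra kwargs are ground data, not dispatched on
# ----------------------------------------------------------------------------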
diff --git a/funsor/terms.py b/funsor/terms.py index 557d8602b..b913fc470 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1918,27 +1918,27 @@ def quote_inplace_first_arg_on_first_line(arg, indent, out): @ops.UnaryOp.subclass_register(Funsor) def unary_funsor(cls, arg, *args, **kwargs): - op = cls.bind_partial(arg, *args, **kwargs) + op = cls(*args, **kwargs) return Unary(op, arg) @ops.BinaryOp.subclass_register(Funsor, Funsor) def binary_funsor_funsor(cls, lhs, rhs, *args, **kwargs): - op = cls.bind_partial(lhs, rhs, *args, **kwargs) + op = cls(*args, **kwargs) return Binary(op, lhs, rhs) @ops.BinaryOp.subclass_register(object, Funsor) def binary_object_funsor(cls, lhs, rhs, *args, **kwargs): + op = cls(*args, **kwargs) lhs = to_funsor(lhs) - op = cls.bind_partial(lhs, rhs, *args, **kwargs) return Binary(op, lhs, rhs) @ops.BinaryOp.subclass_register(Funsor, object) def binary_funsor_object(cls, lhs, rhs, *args, **kwargs): + op = cls(*args, **kwargs) rhs = to_funsor(rhs) - op = cls.bind_partial(lhs, rhs, *args, **kwargs) return Binary(op, lhs, rhs) From 6a4ae0e91b42b11fa16be106a224a56a800e00cf Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 15 Mar 2021 16:53:02 -0400 Subject: [PATCH 04/22] Add Ternary -> Finitary pattern --- funsor/ops/array.py | 2 +- funsor/terms.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/funsor/ops/array.py b/funsor/ops/array.py index 3c2c28394..f93c7ea11 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -153,7 +153,7 @@ def diagonal(x, dim1, dim2): raise NotImplementedError -@diagonal.register(array, int, int) +@diagonal.register(array) def _diagonal(x, dim1, dim2): return np.diagonal(x, axis1=dim1, axis2=dim2) diff --git a/funsor/terms.py b/funsor/terms.py index b913fc470..d079fa806 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1942,6 +1942,19 @@ def binary_funsor_object(cls, lhs, rhs, *args, **kwargs): return Binary(op, lhs, rhs) +@ops.TernaryOp.subclass_register(Funsor, Funsor, Funsor) +@ops.TernaryOp.subclass_register(Funsor, Funsor, object) +@ops.TernaryOp.subclass_register(Funsor, object, object) +@ops.TernaryOp.subclass_register(object, Funsor, object) +@ops.TernaryOp.subclass_register(object, object, Funsor) +def binary_funsor_object(cls, x, y, z, *args, **kwargs): + op = cls(*args, **kwargs) + x = to_funsor(x) + y = to_funsor(y) + z = to_funsor(z) + return Finitary(op, (x, y, z)) + + __all__ = [ "Approximate", "Binary", From 0c14b6236597989a0f4018259dd98b5bae07848d Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 15 Mar 2021 21:37:06 -0400 Subject: [PATCH 05/22] Fix some dispatch logic; fix patterns --- funsor/adjoint.py | 4 +- funsor/affine.py | 4 +- funsor/cnf.py | 36 +++++++------- funsor/domains.py | 25 ++++++++-- funsor/jax/ops.py | 75 +++++++++++++---------------- funsor/ops/array.py | 40 ++++++++------- funsor/ops/builtin.py | 4 +- funsor/ops/op.py | 66 ++++++++++++++++++------- funsor/optimizer.py | 16 +++--- funsor/registry.py | 4 +- funsor/tensor.py | 49 ++++++------------- funsor/terms.py | 14 ++++-- funsor/testing.py | 6 ++- funsor/torch/ops.py | 64 +++++++++++------------- test/examples/test_bart.py | 6 +-- test/examples/test_sensor_fusion.py | 2 +- 16 files changed, 225 insertions(+), 190 deletions(-) diff --git a/funsor/adjoint.py b/funsor/adjoint.py index b80cb2534..1df57a406 100644 --- a/funsor/adjoint.py +++ b/funsor/adjoint.py @@ -4,7 +4,7 @@ from collections import defaultdict from collections.abc import Hashable -from funsor.cnf import Contraction, nullop 
+from funsor.cnf import Contraction, null from funsor.interpretations import Interpretation, reflect from funsor.interpreter import stack_reinterpret from funsor.ops import AssociativeOp @@ -233,7 +233,7 @@ def adjoint_contract_generic( def adjoint_contract( adj_sum_op, adj_prod_op, out_adj, sum_op, prod_op, reduced_vars, lhs, rhs ): - if prod_op is adj_prod_op and sum_op in (nullop, adj_sum_op): + if prod_op is adj_prod_op and sum_op in (null, adj_sum_op): # the only change is here: out_adj = Approximate( diff --git a/funsor/affine.py b/funsor/affine.py index cf8f695ab..5c6b31535 100644 --- a/funsor/affine.py +++ b/funsor/affine.py @@ -8,7 +8,7 @@ from funsor.domains import Bint from funsor.interpreter import gensym -from funsor.tensor import EinsumOp, Tensor, get_default_prototype +from funsor.tensor import Tensor, get_default_prototype from funsor.terms import Binary, Finitary, Funsor, Lambda, Reduce, Unary, Variable from . import ops @@ -92,7 +92,7 @@ def _(fn): return affine_inputs(fn.arg) - fn.reduced_vars -@affine_inputs.register(Finitary[EinsumOp, tuple]) +@affine_inputs.register(Finitary[ops.EinsumOp, tuple]) def _(fn): # This is simply a multiary version of the above Binary(ops.mul, ...) case. results = [] diff --git a/funsor/cnf.py b/funsor/cnf.py index d6a11d524..b42963871 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -17,7 +17,7 @@ from funsor.gaussian import Gaussian from funsor.interpretations import eager, normalize, reflect from funsor.interpreter import recursion_reinterpret -from funsor.ops import DISTRIBUTIVE_OPS, AssociativeOp, NullOp, nullop +from funsor.ops import DISTRIBUTIVE_OPS, AssociativeOp, NullOp, null from funsor.tensor import Tensor from funsor.terms import ( _INFIX, @@ -69,7 +69,7 @@ def __init__(self, red_op, bin_op, reduced_vars, terms): for v in terms: inputs.update((k, d) for k, d in v.inputs.items() if k not in bound) - if bin_op is nullop: + if bin_op is null: output = terms[0].output else: output = reduce( @@ -107,8 +107,8 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): if not sampled_vars: return self - if self.red_op in (ops.logaddexp, nullop): - if self.bin_op in (ops.nullop, ops.logaddexp): + if self.red_op in (ops.logaddexp, null): + if self.bin_op in (ops.null, ops.logaddexp): if rng_key is not None and get_backend() == "jax": import jax @@ -272,7 +272,7 @@ def eager_contraction_generic_recursive(red_op, bin_op, reduced_vars, terms): if unique_vars: result = term.reduce(red_op, unique_vars) if result is not normalize.interpret( - Contraction, red_op, nullop, unique_vars, (term,) + Contraction, red_op, null, unique_vars, (term,) ): terms[i] = result reduced_vars -= unique_vars @@ -427,7 +427,7 @@ def normalize_contraction_commutative_canonical_order( ) def normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, mixture, other): return Contraction( - mixture.red_op if red_op is nullop else red_op, + mixture.red_op if red_op is null else red_op, bin_op, reduced_vars | mixture.reduced_vars, *(mixture.terms + (other,)) @@ -439,7 +439,7 @@ def normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, mixture, o ) def normalize_contraction_commute_joint(red_op, bin_op, reduced_vars, other, mixture): return Contraction( - mixture.red_op if red_op is nullop else red_op, + mixture.red_op if red_op is null else red_op, bin_op, reduced_vars | mixture.reduced_vars, *(mixture.terms + (other,)) @@ -462,13 +462,13 @@ def normalize_trivial(red_op, bin_op, reduced_vars, term): @normalize.register(Contraction, 
AssociativeOp, AssociativeOp, frozenset, tuple) def normalize_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms): - if not reduced_vars and red_op is not nullop: - return Contraction(nullop, bin_op, reduced_vars, *terms) + if not reduced_vars and red_op is not null: + return Contraction(null, bin_op, reduced_vars, *terms) - if len(terms) == 1 and bin_op is not nullop: - return Contraction(red_op, nullop, reduced_vars, *terms) + if len(terms) == 1 and bin_op is not null: + return Contraction(red_op, null, reduced_vars, *terms) - if red_op is nullop and bin_op is nullop: + if red_op is null and bin_op is null: return terms[0] if red_op is bin_op: @@ -493,11 +493,11 @@ def normalize_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms): continue # fuse operations without distributing - if (v.red_op is nullop and bin_op is v.bin_op) or ( - bin_op is nullop and v.red_op in (red_op, nullop) + if (v.red_op is null and bin_op is v.bin_op) or ( + bin_op is null and v.red_op in (red_op, null) ): - red_op = v.red_op if red_op is nullop else red_op - bin_op = v.bin_op if bin_op is nullop else bin_op + red_op = v.red_op if red_op is null else red_op + bin_op = v.bin_op if bin_op is null else bin_op new_terms = terms[:i] + v.terms + terms[i + 1 :] return Contraction( red_op, bin_op, reduced_vars | v.reduced_vars, *new_terms @@ -514,12 +514,12 @@ def normalize_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms): @normalize.register(Binary, AssociativeOp, Funsor, Funsor) def binary_to_contract(op, lhs, rhs): - return Contraction(nullop, op, frozenset(), lhs, rhs) + return Contraction(null, op, frozenset(), lhs, rhs) @normalize.register(Reduce, AssociativeOp, Funsor, frozenset) def reduce_funsor(op, arg, reduced_vars): - return Contraction(op, nullop, reduced_vars, arg) + return Contraction(op, null, reduced_vars, arg) @normalize.register( diff --git a/funsor/domains.py b/funsor/domains.py index 617e965c5..8008759c3 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -250,14 +250,15 @@ def _find_domain_log_exp(op, domain): @find_domain.register(ops.ReshapeOp) def _find_domain_reshape(op, domain): - return Array[domain.dtype, op.shape] + return Array[domain.dtype, op.defaults["shape"]] @find_domain.register(ops.GetitemOp) def _find_domain_getitem(op, lhs_domain, rhs_domain): if isinstance(lhs_domain, ArrayType): + offset = op.defaults["offset"] dtype = lhs_domain.dtype - shape = lhs_domain.shape[: op.offset] + lhs_domain.shape[1 + op.offset :] + shape = lhs_domain.shape[:offset] + lhs_domain.shape[1 + offset :] return Array[dtype, shape] elif isinstance(lhs_domain, ProductDomain): # XXX should this return a Union? 
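# ----------------------------------------------------------------------------
# Illustrative note; not part of the patch. The hunk above reads
# ``op.defaults["offset"]``: after this refactor the non-funsor arguments of
# an op live in its ``defaults`` mapping. A minimal sketch, assuming the
# ``GetitemOp`` defined in funsor/ops/builtin.py (``x`` and ``i`` are
# placeholder arrays):
#
#     from funsor import ops
#     op = ops.GetitemOp(offset=2)
#     op.defaults["offset"]   # -> 2
#     op(x, i)                # same as x[:, :, i], i.e. getitem with offset 2
# ----------------------------------------------------------------------------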
@@ -342,7 +343,7 @@ def _find_domain_associative_generic(op, *domains): @find_domain.register(ops.WrappedTransformOp) def _transform_find_domain(op, domain): - fn = op.dispatch(object) + fn = op.default shape = fn.forward_shape(domain.shape) return Array[domain.dtype, shape] @@ -353,6 +354,24 @@ def _transform_log_abs_det_jacobian(op, domain, codomain): return Real +@find_domain.register(ops.EinsumOp) +def _find_domain_einsum(op, operands): + equation = op.defaults["equation"] + ein_inputs, ein_output = equation.split("->") + ein_inputs = ein_inputs.split(",") + size_dict = {} + for ein_input, x in zip(ein_inputs, operands): + assert x.dtype == "real" + assert len(ein_input) == len(x.shape) + for name, size in zip(ein_input, x.shape): + other_size = size_dict.setdefault(name, size) + if other_size != size: + raise ValueError( + "Size mismatch at {}: {} vs {}".format(name, size, other_size) + ) + return Reals[tuple(size_dict[d] for d in ein_output)] + + __all__ = [ "Bint", "BintType", diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py index bc49d25d6..013aa32bb 100644 --- a/funsor/jax/ops.py +++ b/funsor/jax/ops.py @@ -21,55 +21,52 @@ array = (onp.generic, onp.ndarray, DeviceArray, Tracer) ops.atanh.register(array)(np.arctanh) -ops.clamp.register(array, numbers.Number, numbers.Number)(np.clip) -ops.clamp.register(array, numbers.Number, type(None))(np.clip) -ops.clamp.register(array, type(None), numbers.Number)(np.clip) -ops.clamp.register(array, type(None), type(None))(np.clip) +ops.clamp.register(array)(np.clip) ops.exp.register(array)(np.exp) ops.new_full.register(array)(np.full_like) ops.log1p.register(array)(np.log1p) -ops.max.register(array)(np.maximum) -ops.min.register(array)(np.minimum) -ops.permute.register(array, (tuple, list))(np.transpose) +ops.max.register(array, array)(np.maximum) +ops.min.register(array, array)(np.minimum) +ops.permute.register(array)(np.transpose) ops.sigmoid.register(array)(expit) ops.sqrt.register(array)(np.sqrt) ops.tanh.register(array)(np.tanh) -ops.transpose.register(array, int, int)(np.swapaxes) -ops.unsqueeze.register(array, int)(np.expand_dims) +ops.transpose.register(array)(np.swapaxes) +ops.unsqueeze.register(array)(np.expand_dims) -@ops.all.register(array, (int, type(None))) +@ops.all.register(array) def _all(x, dim): return np.all(x, axis=dim) -@ops.amax.register(array, (int, type(None))) +@ops.amax.register(array) def _amax(x, dim, keepdims=False): return np.amax(x, axis=dim, keepdims=keepdims) -@ops.amin.register(array, (int, type(None))) +@ops.amin.register(array) def _amin(x, dim, keepdims=False): return np.amin(x, axis=dim, keepdims=keepdims) -@ops.argmax.register(array, int) +@ops.argmax.register(array) def _argmax(x, dim): return np.argmax(x, dim) -@ops.any.register(array, (int, type(None))) +@ops.any.register(array) def _any(x, dim): return np.any(x, axis=dim) -@ops.astype.register(array, str) +@ops.astype.register(array) def _astype(x, dtype): return x.astype(np.result_type(dtype)) -ops.cat.register(typing.Tuple[array, ...])(np.concatenate) -ops.cat.register(typing.List[array, ...])(np.concatenate) +ops.cat.register(typing.Tuple[typing.Union[array], ...])(np.concatenate) +ops.cat.register(typing.List[typing.Union[array]])(np.concatenate) @ops.cholesky.register(array) @@ -100,17 +97,18 @@ def _detach(x): return lax.stop_gradient(x) -@ops.diagonal.register(array, int, int) +@ops.diagonal.register(array) def _diagonal(x, dim1, dim2): return np.diagonal(x, axis1=dim1, axis2=dim2) -@ops.einsum.register(str, [array]) -def _einsum(equation, 
*operands): +@ops.einsum.register(typing.Tuple[typing.Union[array], ...]) +@ops.einsum.register(typing.List[typing.Union[array]]) +def _einsum(operands, equation): return np.einsum(equation, *operands) -@ops.expand.register(array, tuple) +@ops.expand.register(array) def _expand(x, shape): prepend_dim = len(shape) - np.ndim(x) assert prepend_dim >= 0 @@ -164,14 +162,13 @@ def _safe_logaddexp_tensor_number(x, y): return _safe_logaddexp_number_tensor(y, x) -@ops.logsumexp.register(array, (int, type(None))) +@ops.logsumexp.register(array) def _logsumexp(x, dim): return logsumexp(x, axis=dim) -@ops.max.register(array, array) -def _max(x, y): - return np.maximum(x, y) +ops.max.register(array, array)(np.maximum) +ops.min.register(array, array)(np.minimum) @ops.max.register((int, float), array) @@ -184,11 +181,6 @@ def _max(x, y): return np.clip(x, a_min=y, a_max=None) -@ops.min.register(array, array) -def _min(x, y): - return np.minimum(x, y) - - # TODO: replace (int, float) by numbers.Number @ops.min.register((int, float), array) def _min(x, y): @@ -200,28 +192,27 @@ def _min(x, y): return np.clip(x, a_min=None, a_max=y) -@ops.new_arange.register(array, int, int, int) +@ops.new_arange.register(array) def _new_arange(x, start, stop, step): - return np.arange(start, stop, step) - - -@ops.new_arange.register(array, int) -def _new_arange(x, stop): - return np.arange(stop) + if step is not None: + return np.arange(start, stop, step) + if stop is not None: + return np.arange(start, stop) + return np.arange(start) -@ops.new_eye.register(array, tuple) +@ops.new_eye.register(array) def _new_eye(x, shape): n = shape[-1] return np.broadcast_to(np.eye(n), shape + (n,)) -@ops.new_zeros.register(array, tuple) +@ops.new_zeros.register(array) def _new_zeros(x, shape): return onp.zeros(shape, dtype=x.dtype) -@ops.prod.register(array, (int, type(None))) +@ops.prod.register(array) def _prod(x, dim): return np.prod(x, axis=dim) @@ -257,12 +248,12 @@ def _scatter(dest, indices, src): return index_update(dest, indices, src) -@ops.stack.register(int, [array + (int, float)]) +@ops.stack.register(typing.Tuple[typing.Union[array + (int, float)], ...]) def _stack(dim, *x): return np.stack(x, axis=dim) -@ops.sum.register(array, (int, type(None))) +@ops.sum.register(array) def _sum(x, dim): return np.sum(x, axis=dim) diff --git a/funsor/ops/array.py b/funsor/ops/array.py index f93c7ea11..04542ebe2 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -38,12 +38,12 @@ # This is used only for pattern matching. 
array = (np.ndarray, np.generic) +arraylist = sum([(typing.Tuple[t, ...], typing.List[t]) for t in array], ()) all = UnaryOp.make(np.all) amax = UnaryOp.make(np.amax) amin = UnaryOp.make(np.amin) any = UnaryOp.make(np.any) -isnan = UnaryOp.make(np.isnan) prod = UnaryOp.make(np.prod) sum = UnaryOp.make(np.sum) @@ -55,7 +55,12 @@ @UnaryOp.make -def full_like(prototype, shape=(), fill_value=0): +def isnan(x): + return np.isnan(x) + + +@UnaryOp.make +def full_like(prototype, shape, fill_value): return np.full_like(prototype, shape, fill_value) @@ -96,18 +101,17 @@ def astype(x, dtype): raise NotImplementedError -@astype.register(array, str) +@astype.register(array) def _astype(x, dtype): return x.astype(dtype) @UnaryOp.make -def cat(parts, axis=0): +def cat(parts, axis): raise NotImplementedError -cat.register(typing.Tuple[array, ...])(np.concatenate) -cat.register(typing.List[array, ...])(np.concatenate) +cat.register(arraylist)(np.concatenate) @UnaryOp.make @@ -163,8 +167,7 @@ def einsum(operands, equation): raise NotImplementedError -@einsum.register(typing.Tuple[array, ...]) -@einsum.register(typing.List[array, ...]) +@einsum.register(arraylist) def _einsum(operands, equation): return np.einsum(equation, *operands) @@ -251,13 +254,17 @@ def _argmax(x, dim): @UnaryOp.make -def new_arange(x, stop): - return np.arange(stop) +def new_arange(x, start=None, stop=None, step=None): + raise NotImplementedError @new_arange.register(array) def _new_arange(x, start, stop, step): - return np.arange(start, stop, step) + if step is not None: + return np.arange(start, stop, step) + if stop is not None: + return np.arange(start, stop) + return np.arange(start) @UnaryOp.make @@ -266,18 +273,18 @@ def new_zeros(x, shape): @UnaryOp.make -def new_full(x, shape=(), value=math.nan): +def new_full(x, shape, value): return np.full(shape, value, dtype=x.dtype) @UnaryOp.make -def new_eye(x, shape=()): +def new_eye(x, shape): n = shape[-1] return np.broadcast_to(np.eye(n), shape + (n,)) @UnaryOp.make -def permute(x, dims=()): +def permute(x, dims): return np.transpose(x, axes=dims) @@ -336,8 +343,7 @@ def stack(parts, axis=0): raise NotImplementedError -stack.register(typing.Tuple[array, ...])(np.stack) -stack.register(typing.List[array, ...])(np.stack) +stack.register(arraylist)(np.stack) @UnaryOp.make @@ -345,7 +351,7 @@ def transpose(array, axis1, axis2): raise NotImplementedError -transpose.register(array, int, int)(np.swapaxes) +transpose.register(array)(np.swapaxes) @BinaryOp.make diff --git a/funsor/ops/builtin.py b/funsor/ops/builtin.py index 1ba80f939..3dff980de 100644 --- a/funsor/ops/builtin.py +++ b/funsor/ops/builtin.py @@ -31,7 +31,7 @@ class AssociativeOp(BinaryOp): @AssociativeOp.make -def nullop(x, y): +def null(x, y): """Placeholder associative op that unifies with any other op""" raise ValueError("should never actually evaluate this!") @@ -209,7 +209,7 @@ def sigmoid_log_abs_det_jacobian(x, y): "mul", "ne", "neg", - "nullop", + "null", "or_", "pos", "pow", diff --git a/funsor/ops/op.py b/funsor/ops/op.py index bc6260637..4574864f8 100644 --- a/funsor/ops/op.py +++ b/funsor/ops/op.py @@ -5,13 +5,23 @@ import inspect import weakref -from multipledispatch import Dispatcher +from funsor.registry import PartialDispatcher def apply(function, args, kwargs={}): return function(*args, **kwargs) +def _iter_subclasses(cls): + yield cls + for subcls in cls.__subclasses__(): + yield from _iter_subclasses(subcls) + + +def _snake_to_camel(name): + return "".join(part.capitalize() for part in name.split("_") if 
part) + + class WeakPartial: # Like ``functools.partial(fn, arg)`` but weakly referencing ``arg``. @@ -51,7 +61,7 @@ def register(cls): def __call__(cls, *args, **kwargs): args = (None,) * cls.arity + args - bound = cls.signature.bind(*args, **kwargs) + bound = cls.signature.bind_partial(*args, **kwargs) bound.apply_defaults() args = bound.args[cls.arity :] kwargs = bound.kwargs @@ -84,7 +94,6 @@ class Op(metaclass=OpMeta): include only the first ``arity``-many types. - Only the first ``arity``-many arguments may be funsors. Remaining args and kwargs must all be ground Python data. - - All remaining non-funsor args and kwargs must define default values. :cvar int arity: The number of funsor arguments this op takes. Must be defined by subclasses. @@ -99,9 +108,15 @@ def __init__(self, *args, **kwargs): super().__init__() cls = type(self) args = (None,) * cls.arity + args - bound = cls.signature.bind(*args, **kwargs) + bound = cls.signature.bind_partial(*args, **kwargs) bound.apply_defaults() - self.defaults = tuple(bound.arguments.items())[cls.arity :] + self.defaults = bound.arguments + for key in list(self.defaults)[: cls.arity]: + del self.defaults[key] + + @property + def __name__(self): + return self.name def __copy__(self): return self @@ -110,8 +125,7 @@ def __deepcopy__(self, memo): return self def __reduce__(self): - args = self.bound.args[self.arity :] - return apply, (type(self), args, self.bound.kwargs) + return apply, (type(self), (), self.defaults) def __repr__(self): return "ops." + self.__name__ @@ -123,22 +137,32 @@ def __call__(self, *args, **kwargs): # Normalize args, kwargs. cls = type(self) bound = cls.signature.bind(*args, **kwargs) - for key, value in self.defaults: + for key, value in self.defaults.items(): bound.arguments.setdefault(key, value) args = bound.args assert len(args) >= cls.arity kwargs = bound.kwargs # Dispatch. - fn = cls.dispatcher.dispatch(*args[: cls.arity]) + fn = cls.dispatcher.partial_call(*args[: cls.arity]) return fn(*args, **kwargs) + def register(self, *pattern): + if len(pattern) != self.arity: + raise ValueError( + f"Invalid pattern for {self}, " + f"expected {self.arity} types but got {len(pattern)}." + ) + return type(self).dispatcher.register(*pattern) + @classmethod def subclass_register(cls, *pattern): def decorator(fn): # Register with all existing sublasses. - for subcls in [cls] + cls.__subclasses__(): - cls.dispatcher.add(pattern, WeakPartial(fn, subcls)) + for subcls in _iter_subclasses(cls): + dispatcher = getattr(subcls, "dispatcher", None) + if dispatcher is not None: + dispatcher.add(pattern, WeakPartial(fn, subcls)) # Ensure registration with all future subclasses. cls._subclass_registry.append((pattern, fn)) return fn @@ -149,7 +173,11 @@ def decorator(fn): def make(cls, fn=None, *, name=None, metaclass=OpMeta, module_name="funsor.ops"): """ Factory to create a new :class:`Op` subclass together with a new - instance of that class. + default instance of that class. + + :param callable fn: A function whose signature can be inspected. + :returns: The new default instance. + :rtype: Op """ if not isinstance(cls.arity, int): raise TypeError( @@ -166,20 +194,19 @@ def make(cls, fn=None, *, name=None, metaclass=OpMeta, module_name="funsor.ops") assert isinstance(name, str) assert issubclass(metaclass, OpMeta) - classname = name.capitalize().rstrip("_") + "Op" # e.g. add -> AddOp + classname = _snake_to_camel(name) + "Op" # e.g. 
scatter_add -> ScatterAddOp signature = inspect.Signature.from_callable(fn) - dispatcher = Dispatcher(name) op_class = metaclass( classname, (cls,), { - "default": fn, - "dispatcher": dispatcher, + "name": name, "signature": signature, + "default": staticmethod(fn), + "dispatcher": PartialDispatcher(fn, name), }, ) op_class.__module__ = module_name - dispatcher.add((object,) * cls.arity, fn) op = op_class() return op @@ -217,6 +244,10 @@ class TernaryOp(Op): arity = 3 +class FinitaryOp(Op): + arity = 1 # encoded as a tuple + + class TransformOp(UnaryOp): def set_inv(self, fn): """ @@ -335,6 +366,7 @@ class LogAbsDetJacobianOp(BinaryOp, metaclass=WrappedOpMeta): "BINARY_INVERSES", "BinaryOp", "DISTRIBUTIVE_OPS", + "FinitaryOp", "LogAbsDetJacobianOp", "NullaryOp", "Op", diff --git a/funsor/optimizer.py b/funsor/optimizer.py index 91cc06763..b401a1cdd 100644 --- a/funsor/optimizer.py +++ b/funsor/optimizer.py @@ -6,7 +6,7 @@ from opt_einsum.paths import greedy import funsor.interpreter as interpreter -from funsor.cnf import Contraction, nullop +from funsor.cnf import Contraction, null from funsor.interpretations import ( DispatchedInterpretation, PrioritizedInterpretation, @@ -31,7 +31,7 @@ def unfold_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms): if not isinstance(v, Contraction): continue - if v.red_op is nullop and (v.bin_op, bin_op) in DISTRIBUTIVE_OPS: + if v.red_op is null and (v.bin_op, bin_op) in DISTRIBUTIVE_OPS: # a * e * (b + c + d) -> (a * e * b) + (a * e * c) + (a * e * d) new_terms = tuple( Contraction( @@ -44,7 +44,7 @@ def unfold_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms): ) return Contraction(red_op, v.bin_op, reduced_vars, *new_terms) - if red_op in (v.red_op, nullop) and (v.red_op, bin_op) in DISTRIBUTIVE_OPS: + if red_op in (v.red_op, null) and (v.red_op, bin_op) in DISTRIBUTIVE_OPS: new_terms = ( terms[:i] + (Contraction(v.red_op, v.bin_op, frozenset(), *v.terms),) @@ -54,9 +54,9 @@ def unfold_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms): red_op, reduced_vars ) - if v.red_op in (red_op, nullop) and bin_op in (v.bin_op, nullop): - red_op = v.red_op if red_op is nullop else red_op - bin_op = v.bin_op if bin_op is nullop else bin_op + if v.red_op in (red_op, null) and bin_op in (v.bin_op, null): + red_op = v.red_op if red_op is null else red_op + bin_op = v.bin_op if bin_op is null else bin_op new_terms = terms[:i] + v.terms + terms[i + 1 :] return Contraction( red_op, bin_op, reduced_vars | v.reduced_vars, *new_terms @@ -94,7 +94,7 @@ def eager_contract_base(red_op, bin_op, reduced_vars, *terms): @optimize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, tuple) def optimize_contract_finitary_funsor(red_op, bin_op, reduced_vars, terms): - if red_op is nullop or bin_op is nullop or not (red_op, bin_op) in DISTRIBUTIVE_OPS: + if red_op is null or bin_op is null or not (red_op, bin_op) in DISTRIBUTIVE_OPS: return None # build opt_einsum optimizer IR @@ -140,7 +140,7 @@ def optimize_contract_finitary_funsor(red_op, bin_op, reduced_vars, terms): ) path_end = Contraction( - red_op if path_end_reduced_vars else nullop, + red_op if path_end_reduced_vars else null, bin_op, path_end_reduced_vars, ta, diff --git a/funsor/registry.py b/funsor/registry.py index 389c5dadd..78d747bdc 100644 --- a/funsor/registry.py +++ b/funsor/registry.py @@ -13,9 +13,9 @@ class PartialDispatcher(Dispatcher): Wrapper to avoid appearance in stack traces. 
""" - def __init__(self, default=None): + def __init__(self, default=None, name="PartialDispatcher"): self.default = default if default is None else PartialDefault(default) - super().__init__("PartialDispatcher") + super().__init__(name) if default is not None: self.add(([object],), self.default) diff --git a/funsor/tensor.py b/funsor/tensor.py index 6d994472a..92f4e6dff 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -764,16 +764,18 @@ def eager_binary_tensor_tensor(op, lhs, rhs): @eager.register(Unary, ReshapeOp, Tensor) def eager_reshape_tensor(op, arg): - if arg.shape == op.shape: + shape = op.defaults["shape"] + if arg.shape == shape: return arg batch_shape = arg.data.shape[: len(arg.data.shape) - len(arg.shape)] - data = arg.data.reshape(batch_shape + op.shape) + data = arg.data.reshape(batch_shape + shape) return Tensor(data, arg.inputs, arg.dtype) @eager.register(Binary, GetitemOp, Tensor, Number) def eager_getitem_tensor_number(op, lhs, rhs): - index = [slice(None)] * (len(lhs.inputs) + op.offset) + offset = op.defaults["offset"] + index = [slice(None)] * (len(lhs.inputs) + offset) index.append(rhs.data) index = tuple(index) data = lhs.data[index] @@ -782,8 +784,9 @@ def eager_getitem_tensor_number(op, lhs, rhs): @eager.register(Binary, GetitemOp, Tensor, Variable) def eager_getitem_tensor_variable(op, lhs, rhs): - assert op.offset < len(lhs.output.shape) - assert rhs.output == Bint[lhs.output.shape[op.offset]] + offset = op.defaults["offset"] + assert offset < len(lhs.output.shape) + assert rhs.output == Bint[lhs.output.shape[offset]] assert rhs.name not in lhs.inputs # Convert a positional event dimension to a named batch dimension. @@ -791,7 +794,7 @@ def eager_getitem_tensor_variable(op, lhs, rhs): inputs[rhs.name] = rhs.output data = lhs.data target_dim = len(lhs.inputs) - source_dim = target_dim + op.offset + source_dim = target_dim + offset if target_dim != source_dim: perm = list(range(len(data.shape))) del perm[source_dim] @@ -802,8 +805,9 @@ def eager_getitem_tensor_variable(op, lhs, rhs): @eager.register(Binary, GetitemOp, Tensor, Tensor) def eager_getitem_tensor_tensor(op, lhs, rhs): - assert op.offset < len(lhs.output.shape) - assert rhs.output == Bint[lhs.output.shape[op.offset]] + offset = op.defaults["offset"] + assert offset < len(lhs.output.shape) + assert rhs.output == Bint[lhs.output.shape[offset]] # Compute inputs and outputs. if lhs.inputs == rhs.inputs: @@ -815,7 +819,7 @@ def eager_getitem_tensor_tensor(op, lhs, rhs): # Perform advanced indexing. 
lhs_data_dim = len(lhs_data.shape) - target_dim = lhs_data_dim - len(lhs.output.shape) + op.offset + target_dim = lhs_data_dim - len(lhs.output.shape) + offset index = [None] * lhs_data_dim for i in range(target_dim): index[i] = ops.new_arange(lhs_data, lhs_data.shape[i]).reshape( @@ -1086,29 +1090,6 @@ def max_and_argmax(x: Reals[8]) -> Tuple[Real, Bint[8]]: return functools.partial(_function, inputs, output) -class EinsumOp(ops.Op, metaclass=ops.CachedOpMeta): - def __init__(self, equation): - self.equation = equation - - -@find_domain.register(EinsumOp) -def _find_domain_einsum(op, *operands): - equation = op.equation - ein_inputs, ein_output = equation.split("->") - ein_inputs = ein_inputs.split(",") - size_dict = {} - for ein_input, x in zip(ein_inputs, operands): - assert x.dtype == "real" - assert len(ein_input) == len(x.shape) - for name, size in zip(ein_input, x.shape): - other_size = size_dict.setdefault(name, size) - if other_size != size: - raise ValueError( - "Size mismatch at {}: {} vs {}".format(name, size, other_size) - ) - return Reals[tuple(size_dict[d] for d in ein_output)] - - def Einsum(equation, *operands): """ Wrapper around :func:`torch.einsum` or :func:`np.einsum` to operate on real-valued Funsors. @@ -1120,10 +1101,10 @@ def Einsum(equation, *operands): :param str equation: An :func:`torch.einsum` or :func:`np.einsum` equation. :param tuple operands: A tuple of input funsors. """ - return Finitary(EinsumOp(equation), tuple(operands)) + return ops.einsum(operands, equation) -@eager.register(Finitary, EinsumOp, typing.Tuple[Tensor, ...]) +@eager.register(Finitary, ops.EinsumOp, typing.Tuple[Tensor, ...]) def eager_einsum(op, operands): # Make new symbols for inputs of operands. equation = op.equation diff --git a/funsor/terms.py b/funsor/terms.py index d079fa806..d23890304 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1670,9 +1670,10 @@ def _alpha_convert(self, alpha_subs): @eager.register(Binary, GetitemOp, Lambda, (Funsor, Align)) def eager_getitem_lambda(op, lhs, rhs): - if op.offset == 0: + offset = op.defaults["offset"] + if offset == 0: return Subs(lhs.expr, ((lhs.var.name, rhs),)) - expr = GetitemOp(op.offset - 1)(lhs.expr, rhs) + expr = GetitemOp(offset - 1)(lhs.expr, rhs) return Lambda(lhs.var, expr) @@ -1947,7 +1948,7 @@ def binary_funsor_object(cls, lhs, rhs, *args, **kwargs): @ops.TernaryOp.subclass_register(Funsor, object, object) @ops.TernaryOp.subclass_register(object, Funsor, object) @ops.TernaryOp.subclass_register(object, object, Funsor) -def binary_funsor_object(cls, x, y, z, *args, **kwargs): +def ternary_funsor_object(cls, x, y, z, *args, **kwargs): op = cls(*args, **kwargs) x = to_funsor(x) y = to_funsor(y) @@ -1955,6 +1956,13 @@ def binary_funsor_object(cls, x, y, z, *args, **kwargs): return Finitary(op, (x, y, z)) +# FIXME allow some non-funsors +@ops.FinitaryOp.subclass_register(typing.Tuple[Funsor, ...]) +def finitary_funsor(cls, arg, *args, **kwargs): + op = cls(*args, **kwargs) + return Finitary(op, arg) + + __all__ = [ "Approximate", "Binary", diff --git a/funsor/testing.py b/funsor/testing.py index 2eb2cf7db..109744c10 100644 --- a/funsor/testing.py +++ b/funsor/testing.py @@ -81,7 +81,11 @@ def allclose(a, b, rtol=1e-05, atol=1e-08): def is_array(x): - return get_backend() != "torch" and ops.is_numeric_array(x) + if isinstance(x, Funsor): + return False + if get_backend() == "torch": + return False + return ops.is_numeric_array(x) def assert_close(actual, expected, atol=1e-6, rtol=1e-6): diff --git a/funsor/torch/ops.py 
b/funsor/torch/ops.py index 2684015cf..4ff3e7389 100644 --- a/funsor/torch/ops.py +++ b/funsor/torch/ops.py @@ -15,52 +15,48 @@ ops.abs.register(torch.Tensor)(torch.abs) ops.atanh.register(torch.Tensor)(torch.atanh) ops.cholesky_solve.register(torch.Tensor, torch.Tensor)(torch.cholesky_solve) -ops.clamp.register(torch.Tensor, numbers.Number, numbers.Number)(torch.clamp) -ops.clamp.register(torch.Tensor, numbers.Number, type(None))(torch.clamp) -ops.clamp.register(torch.Tensor, type(None), numbers.Number)(torch.clamp) -ops.clamp.register(torch.Tensor, type(None), type(None))(torch.clamp) +ops.clamp.register(torch.Tensor)(torch.clamp) ops.exp.register(torch.Tensor)(torch.exp) -ops.full_like.register(torch.Tensor, numbers.Number)(torch.full_like) +ops.full_like.register(torch.Tensor)(torch.full_like) ops.log1p.register(torch.Tensor)(torch.log1p) ops.sigmoid.register(torch.Tensor)(torch.sigmoid) ops.sqrt.register(torch.Tensor)(torch.sqrt) ops.tanh.register(torch.Tensor)(torch.tanh) -ops.transpose.register(torch.Tensor, int, int)(torch.transpose) -ops.unsqueeze.register(torch.Tensor, int)(torch.unsqueeze) +ops.transpose.register(torch.Tensor)(torch.transpose) -@ops.all.register(torch.Tensor, (int, type(None))) +@ops.all.register(torch.Tensor) def _all(x, dim): return x.all() if dim is None else x.all(dim=dim) -@ops.amax.register(torch.Tensor, (int, type(None))) +@ops.amax.register(torch.Tensor) def _amax(x, dim, keepdims=False): return x.max() if dim is None else x.max(dim, keepdims)[0] -@ops.amin.register(torch.Tensor, (int, type(None))) +@ops.amin.register(torch.Tensor) def _amin(x, dim, keepdims=False): return x.min() if dim is None else x.min(dim, keepdims)[0] -@ops.argmax.register(torch.Tensor, int) +@ops.argmax.register(torch.Tensor) def _argmax(x, dim): return x.max(dim).indices -@ops.any.register(torch.Tensor, (int, type(None))) +@ops.any.register(torch.Tensor) def _any(x, dim): return x.any() if dim is None else x.any(dim=dim) -@ops.astype.register(torch.Tensor, str) +@ops.astype.register(torch.Tensor) def _astype(x, dtype): return x.type(getattr(torch, dtype)) ops.cat.register(typing.Tuple[torch.Tensor, ...])(torch.cat) -ops.cat.register(typing.List[torch.Tensor, ...])(torch.cat) +ops.cat.register(typing.List[torch.Tensor])(torch.cat) @ops.cholesky.register(torch.Tensor) @@ -89,17 +85,17 @@ def _detach(x): return x.detach() -@ops.diagonal.register(torch.Tensor, int, int) +@ops.diagonal.register(torch.Tensor) def _diagonal(x, dim1, dim2): return x.diagonal(dim1=dim1, dim2=dim2) -@ops.einsum.register(str, [torch.Tensor]) -def _einsum(equation, *operands): +@ops.einsum.register(typing.Tuple[torch.Tensor, ...]) +def _einsum(operands, equation): return torch.einsum(equation, *operands) -@ops.expand.register(torch.Tensor, tuple) +@ops.expand.register(torch.Tensor) def _expand(x, shape): return x.expand(shape) @@ -150,7 +146,7 @@ def _safe_logaddexp_tensor_number(x, y): return _safe_logaddexp_number_tensor(y, x) -@ops.logsumexp.register(torch.Tensor, (int, type(None))) +@ops.logsumexp.register(torch.Tensor) def _logsumexp(x, dim): return x.reshape(-1).logsumexp(0) if dim is None else x.logsumexp(dim) @@ -185,32 +181,31 @@ def _min(x, y): return x.clamp(max=y) -@ops.new_arange.register(torch.Tensor, int, int, int) +@ops.new_arange.register(torch.Tensor) def _new_arange(x, start, stop, step): - return torch.arange(start, stop, step) + if step is not None: + return torch.arange(start, stop, step) + if stop is not None: + return torch.arange(start, stop) + return torch.arange(start) 
-@ops.new_arange.register(torch.Tensor, (int, torch.Tensor)) -def _new_arange(x, stop): - return torch.arange(stop) - - -@ops.new_eye.register(torch.Tensor, tuple) +@ops.new_eye.register(torch.Tensor) def _new_eye(x, shape): return torch.eye(shape[-1]).expand(shape + (-1,)) -@ops.new_zeros.register(torch.Tensor, tuple) +@ops.new_zeros.register(torch.Tensor) def _new_zeros(x, shape): return x.new_zeros(shape) -@ops.new_full.register(torch.Tensor, tuple, numbers.Number) +@ops.new_full.register(torch.Tensor) def _new_full(x, shape, value): return x.new_full(shape, value) -@ops.permute.register(torch.Tensor, (tuple, list)) +@ops.permute.register(torch.Tensor) def _permute(x, dims): return x.permute(dims) @@ -228,7 +223,7 @@ def _pow(x, y): return x ** y -@ops.prod.register(torch.Tensor, (int, type(None))) +@ops.prod.register(torch.Tensor) def _prod(x, dim): return x.prod() if dim is None else x.prod(dim=dim) @@ -272,12 +267,11 @@ def _scatter_add(destin, indices, source): return result.index_put(indices, source, accumulate=True) -@ops.stack.register(int, [torch.Tensor]) -def _stack(dim, *x): - return torch.stack(x, dim=dim) +ops.stack.register(typing.Tuple[torch.Tensor, ...])(torch.stack) +ops.stack.register(typing.List[torch.Tensor])(torch.stack) -@ops.sum.register(torch.Tensor, (int, type(None))) +@ops.sum.register(torch.Tensor) def _sum(x, dim): return x.sum() if dim is None else x.sum(dim) diff --git a/test/examples/test_bart.py b/test/examples/test_bart.py index 7f1940f10..c5ffbc45b 100644 --- a/test/examples/test_bart.py +++ b/test/examples/test_bart.py @@ -63,7 +63,7 @@ def test_bart(analytic_kl): q = Independent( Independent( Contraction( - ops.nullop, + ops.null, ops.add, frozenset(), ( @@ -177,7 +177,7 @@ def test_bart(analytic_kl): ops.logaddexp, ops.add, Contraction( - ops.nullop, + ops.null, ops.add, frozenset(), ( @@ -358,7 +358,7 @@ def test_bart(analytic_kl): ) p_likelihood = Contraction( ops.add, - ops.nullop, + ops.null, frozenset( { Variable("time_b17", Bint[2]), diff --git a/test/examples/test_sensor_fusion.py b/test/examples/test_sensor_fusion.py index f8fc8b77f..84c8d8d30 100644 --- a/test/examples/test_sensor_fusion.py +++ b/test/examples/test_sensor_fusion.py @@ -157,7 +157,7 @@ def test_affine_subs(): ( "obs_b2", Contraction( - ops.nullop, + ops.null, ops.add, frozenset(), ( From 43bae4a48c40a1a27751079deada555cb97e8a60 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 15 Mar 2021 21:52:03 -0400 Subject: [PATCH 06/22] Fix ops.einsum usage --- funsor/einsum/numpy_log.py | 2 +- funsor/tensor.py | 2 +- funsor/terms.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/funsor/einsum/numpy_log.py b/funsor/einsum/numpy_log.py index af3a78f2a..33d9b8b11 100644 --- a/funsor/einsum/numpy_log.py +++ b/funsor/einsum/numpy_log.py @@ -43,7 +43,7 @@ def einsum(equation, *operands): shift = ops.permute(shift, [dims.index(dim) for dim in output]) shifts.append(shift) - result = ops.log(ops.einsum(equation, *exp_operands)) + result = ops.log(ops.einsum(exp_operands, equation)) return sum(shifts + [result]) diff --git a/funsor/tensor.py b/funsor/tensor.py index 92f4e6dff..f19b719b6 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -1132,7 +1132,7 @@ def eager_einsum(op, operands): out = "".join(new_symbols[k] for k in inputs) + out equation = ",".join(ins) + "->" + out - data = ops.einsum(equation, *[x.data for x in operands]) + data = ops.einsum([x.data for x in operands], equation) return Tensor(data, inputs) diff --git a/funsor/terms.py 
b/funsor/terms.py index d23890304..0bf122e85 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1460,7 +1460,7 @@ def __init__(self, op, args): inputs = OrderedDict() for arg in args: inputs.update(arg.inputs) - output = find_domain(op, *(arg.output for arg in args)) + output = find_domain(op, tuple(arg.output for arg in args)) super().__init__(inputs, output) self.op = op self.args = args From e66f441df0cbe0251463cdc54a24e5f588dcf7df Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 15 Mar 2021 22:58:57 -0400 Subject: [PATCH 07/22] Fix test_tensors.py --- funsor/distribution.py | 9 +++---- funsor/jax/ops.py | 12 +++++---- funsor/ops/array.py | 52 ++++++++++++++++++++++++++++----------- funsor/ops/op.py | 10 +++++++- funsor/tensor.py | 42 ++++++------------------------- funsor/torch/ops.py | 3 +-- test/test_distribution.py | 4 +-- test/test_tensor.py | 9 +++---- 8 files changed, 73 insertions(+), 68 deletions(-) diff --git a/funsor/distribution.py b/funsor/distribution.py index ec05e6d2b..b5e221276 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -27,7 +27,6 @@ get_default_prototype, ignore_jit_warnings, numeric_array, - stack, ) from funsor.terms import ( Funsor, @@ -768,15 +767,15 @@ def LogNormal(loc, scale, value="value"): def eager_beta(concentration1, concentration0, value): - concentration = stack((concentration0, concentration1)) - value = stack((1 - value, value)) + concentration = ops.stack((concentration0, concentration1)) + value = ops.stack((1 - value, value)) backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]) return backend_dist.Dirichlet(concentration, value=value) # noqa: F821 def eager_binomial(total_count, probs, value): - probs = stack((1 - probs, probs)) - value = stack((total_count - value, value)) + probs = ops.stack((1 - probs, probs)) + value = ops.stack((total_count - value, value)) backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]) return backend_dist.Multinomial(total_count, probs, value=value) # noqa: F821 diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py index 013aa32bb..c6d16ee13 100644 --- a/funsor/jax/ops.py +++ b/funsor/jax/ops.py @@ -23,7 +23,6 @@ ops.atanh.register(array)(np.arctanh) ops.clamp.register(array)(np.clip) ops.exp.register(array)(np.exp) -ops.new_full.register(array)(np.full_like) ops.log1p.register(array)(np.log1p) ops.max.register(array, array)(np.maximum) ops.min.register(array, array)(np.minimum) @@ -66,7 +65,6 @@ def _astype(x, dtype): ops.cat.register(typing.Tuple[typing.Union[array], ...])(np.concatenate) -ops.cat.register(typing.List[typing.Union[array]])(np.concatenate) @ops.cholesky.register(array) @@ -103,7 +101,6 @@ def _diagonal(x, dim1, dim2): @ops.einsum.register(typing.Tuple[typing.Union[array], ...]) -@ops.einsum.register(typing.List[typing.Union[array]]) def _einsum(operands, equation): return np.einsum(equation, *operands) @@ -192,6 +189,11 @@ def _min(x, y): return np.clip(x, a_min=None, a_max=y) +@ops.new_full.register(array) +def _new_full(x, shape, value): + return np.full(shape, value, dtype=x.dtype) + + @ops.new_arange.register(array) def _new_arange(x, start, stop, step): if step is not None: @@ -249,8 +251,8 @@ def _scatter(dest, indices, src): @ops.stack.register(typing.Tuple[typing.Union[array + (int, float)], ...]) -def _stack(dim, *x): - return np.stack(x, axis=dim) +def _stack(parts, dim=0): + return np.stack(parts, axis=dim) @ops.sum.register(array) diff --git a/funsor/ops/array.py b/funsor/ops/array.py index 
04542ebe2..2286f4bdc 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -26,6 +26,7 @@ DISTRIBUTIVE_OPS, UNITS, BinaryOp, + FinitaryOp, Op, OpMeta, TernaryOp, @@ -38,14 +39,7 @@ # This is used only for pattern matching. array = (np.ndarray, np.generic) -arraylist = sum([(typing.Tuple[t, ...], typing.List[t]) for t in array], ()) - -all = UnaryOp.make(np.all) -amax = UnaryOp.make(np.amax) -amin = UnaryOp.make(np.amin) -any = UnaryOp.make(np.any) -prod = UnaryOp.make(np.prod) -sum = UnaryOp.make(np.sum) +arraylist = typing.Tuple[typing.Union[array], ...] sqrt.register(array)(np.sqrt) exp.register(array)(np.exp) @@ -54,14 +48,44 @@ atanh.register(array)(np.arctanh) +@UnaryOp.make +def all(x, dim=None): + return np.all(x, dim) + + +@UnaryOp.make +def any(x, dim=None): + return np.any(x, dim) + + +@UnaryOp.make +def amax(x, dim=None): + return np.amax(x, dim) + + +@UnaryOp.make +def amin(x, dim=None): + return np.amax(x, dim) + + +@UnaryOp.make +def sum(x, dim=None): + return np.sum(x, dim) + + +@UnaryOp.make +def prod(x, dim=None): + return np.prod(x, dim) + + @UnaryOp.make def isnan(x): return np.isnan(x) @UnaryOp.make -def full_like(prototype, shape, fill_value): - return np.full_like(prototype, shape, fill_value) +def full_like(prototype, fill_value): + return np.full_like(prototype, fill_value) @log.register(array) @@ -106,7 +130,7 @@ def _astype(x, dtype): return x.astype(dtype) -@UnaryOp.make +@FinitaryOp.make def cat(parts, axis): raise NotImplementedError @@ -162,7 +186,7 @@ def _diagonal(x, dim1, dim2): return np.diagonal(x, axis1=dim1, axis2=dim2) -@UnaryOp.make +@FinitaryOp.make def einsum(operands, equation): raise NotImplementedError @@ -338,8 +362,8 @@ def _scatter_add(destin, indices, source): return result -@UnaryOp.make -def stack(parts, axis=0): +@FinitaryOp.make +def stack(parts, dim=0): raise NotImplementedError diff --git a/funsor/ops/op.py b/funsor/ops/op.py index 4574864f8..abd4a9901 100644 --- a/funsor/ops/op.py +++ b/funsor/ops/op.py @@ -136,7 +136,7 @@ def __str__(self): def __call__(self, *args, **kwargs): # Normalize args, kwargs. cls = type(self) - bound = cls.signature.bind(*args, **kwargs) + bound = cls.signature.bind_partial(*args, **kwargs) for key, value in self.defaults.items(): bound.arguments.setdefault(key, value) args = bound.args @@ -248,6 +248,14 @@ class FinitaryOp(Op): arity = 1 # encoded as a tuple +# Convert list to tuple for easier typing. +@FinitaryOp.subclass_register(list) +def _list_to_tuple(cls, arg, *args, **kwargs): + arg = tuple(arg) + op = cls(*args, **kwargs) + return op(arg) + + class TransformOp(UnaryOp): def set_inv(self, fn): """ diff --git a/funsor/tensor.py b/funsor/tensor.py index f19b719b6..0805049d0 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -18,7 +18,7 @@ from . 
import ops from .delta import Delta from .domains import Array, ArrayType, Bint, Product, Real, Reals, find_domain -from .ops import GetitemOp, MatmulOp, Op, ReshapeOp +from .ops import BinaryOp, FinitaryOp, GetitemOp, MatmulOp, Op, ReshapeOp from .terms import ( Binary, Finitary, @@ -682,21 +682,21 @@ def eager_scatter_tensor(op, subs, source, reduced_vars): return Tensor(data, destin_inputs, output.dtype) -@eager.register(Binary, Op, Tensor, Number) +@eager.register(Binary, BinaryOp, Tensor, Number) def eager_binary_tensor_number(op, lhs, rhs): dtype = find_domain(op, lhs.output, rhs.output).dtype data = op(lhs.data, rhs.data) return Tensor(data, lhs.inputs, dtype) -@eager.register(Binary, Op, Number, Tensor) +@eager.register(Binary, BinaryOp, Number, Tensor) def eager_binary_number_tensor(op, lhs, rhs): dtype = find_domain(op, lhs.output, rhs.output).dtype data = op(lhs.data, rhs.data) return Tensor(data, rhs.inputs, dtype) -@eager.register(Binary, Op, Tensor, Tensor) +@eager.register(Binary, BinaryOp, Tensor, Tensor) def eager_binary_tensor_tensor(op, lhs, rhs): # Compute inputs and outputs. dtype = find_domain(op, lhs.output, rhs.output).dtype @@ -834,10 +834,10 @@ def eager_getitem_tensor_tensor(op, lhs, rhs): return Tensor(data, inputs, lhs.dtype) -@eager.register(Finitary, Op, typing.Tuple[Tensor, ...]) +@eager.register(Finitary, FinitaryOp, typing.Tuple[Tensor, ...]) def eager_finitary_generic_tensors(op, args): inputs, raw_args = align_tensors(*args) - raw_result = op(*raw_args) + raw_result = op(raw_args) return Tensor(raw_result, inputs, args[0].dtype) @@ -870,7 +870,7 @@ def eager_stack_homogeneous(name, *parts): shape = tuple(d.size for d in part_inputs.values()) + output.shape data = ops.stack( - 0, *[ops.expand(align_tensor(part_inputs, part), shape) for part in parts] + [ops.expand(align_tensor(part_inputs, part), shape) for part in parts] ) inputs = OrderedDict([(name, Bint[len(parts)])]) inputs.update(part_inputs) @@ -1107,7 +1107,7 @@ def Einsum(equation, *operands): @eager.register(Finitary, ops.EinsumOp, typing.Tuple[Tensor, ...]) def eager_einsum(op, operands): # Make new symbols for inputs of operands. - equation = op.equation + equation = op.defaults["equation"] inputs = OrderedDict() for x in operands: inputs.update(x.inputs) @@ -1172,31 +1172,6 @@ def tensordot(x, y, dims): return Einsum(equation, x, y) -def stack(parts, dim=0): - """ - Wrapper around :func:`torch.stack` or :func:`np.stack` to operate on real-valued Funsors. - - Note this operates only on the ``output`` tensor. To stack funsors in a - new named dim, instead use :class:`~funsor.terms.Stack`. - - :param tuple parts: A tuple of funsors. - :param int dim: A torch dim along which to stack. 
- :rtype: Funsor - """ - assert isinstance(dim, int) - assert isinstance(parts, tuple) - assert len(set(x.output for x in parts)) == 1 - shape = parts[0].output.shape - if dim >= 0: - dim = dim - len(shape) - 1 - assert dim < 0 - split = dim + len(shape) + 1 - shape = shape[:split] + (len(parts),) + shape[split:] - output = Array[parts[0].dtype, shape] - fn = functools.partial(ops.stack, dim) - return Function(fn, output, parts) - - REDUCE_OP_TO_NUMERIC = { ops.add: ops.sum, ops.mul: ops.prod, @@ -1218,6 +1193,5 @@ def stack(parts, dim=0): "align_tensors", "function", "ignore_jit_warnings", - "stack", "tensordot", ] diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py index 4ff3e7389..f99b9bd25 100644 --- a/funsor/torch/ops.py +++ b/funsor/torch/ops.py @@ -23,6 +23,7 @@ ops.sqrt.register(torch.Tensor)(torch.sqrt) ops.tanh.register(torch.Tensor)(torch.tanh) ops.transpose.register(torch.Tensor)(torch.transpose) +ops.unsqueeze.register(torch.Tensor)(torch.unsqueeze) @ops.all.register(torch.Tensor) @@ -56,7 +57,6 @@ def _astype(x, dtype): ops.cat.register(typing.Tuple[torch.Tensor, ...])(torch.cat) -ops.cat.register(typing.List[torch.Tensor])(torch.cat) @ops.cholesky.register(torch.Tensor) @@ -268,7 +268,6 @@ def _scatter_add(destin, indices, source): ops.stack.register(typing.Tuple[torch.Tensor, ...])(torch.stack) -ops.stack.register(typing.List[torch.Tensor])(torch.stack) @ops.sum.register(torch.Tensor) diff --git a/test/test_distribution.py b/test/test_distribution.py index f64a1d7eb..0e036e2eb 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -18,7 +18,7 @@ from funsor.integrate import Integrate from funsor.interpretations import eager, lazy from funsor.interpreter import reinterpret -from funsor.tensor import Einsum, Tensor, numeric_array, stack +from funsor.tensor import Einsum, Tensor, numeric_array from funsor.terms import Independent, Variable, to_funsor from funsor.testing import ( assert_close, @@ -1205,7 +1205,7 @@ def test_beta_bernoulli_conjugate(batch_shape): conditional = dist.Bernoulli(probs=prior) reduced = (latent + conditional).reduce(ops.logaddexp, set(["prior"])) assert isinstance(reduced, dist.DirichletMultinomial) - concentration = stack((concentration0, concentration1), dim=-1) + concentration = ops.stack((concentration0, concentration1), dim=-1) assert_close(reduced.concentration, concentration) assert_close(reduced.total_count, Tensor(numeric_array(1.0))) diff --git a/test/test_tensor.py b/test/test_tensor.py index 50b8fa410..bacc5cdc0 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -22,7 +22,6 @@ Tensor, align_tensors, numeric_array, - stack, tensordot, ) from funsor.terms import Cat, Lambda, Number, Scatter, Slice, Stack, Variable @@ -870,7 +869,7 @@ def test_einsum(equation): inputs = inputs.split(",") tensors = [randn(tuple(sizes[d] for d in dims)) for dims in inputs] funsors = [Tensor(x) for x in tensors] - expected = Tensor(ops.einsum(equation, *tensors)) + expected = Tensor(ops.einsum(tensors, equation)) actual = Einsum(equation, *funsors) assert_close(actual, expected, atol=1e-5, rtol=None) @@ -895,7 +894,7 @@ def test_batched_einsum(equation, batch1, batch2): inputs, tensors = align_tensors(*funsors) batch = tuple(v.size for v in inputs.values()) tensors = [ops.expand(x, batch + f.shape) for (x, f) in zip(tensors, funsors)] - expected = Tensor(ops.einsum(_equation, *tensors), inputs) + expected = Tensor(ops.einsum(tensors, _equation), inputs) assert_close(actual, expected, atol=1e-5, rtol=None) @@ -941,8 +940,8 @@ def 
test_tensor_tensordot(x_shape, xy_shape, y_shape): ) def test_tensor_stack(n, shape, dim): tensors = [randn(shape) for _ in range(n)] - actual = stack(tuple(Tensor(t) for t in tensors), dim=dim) - expected = Tensor(ops.stack(dim, *tensors)) + actual = ops.stack(tuple(Tensor(t) for t in tensors), dim=dim) + expected = Tensor(ops.stack(tensors, dim)) assert_close(actual, expected) From 244a315f244618c00e92fab5ab7bbe2153d1f0b5 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 15 Mar 2021 23:16:32 -0400 Subject: [PATCH 08/22] Fix more ops --- funsor/domains.py | 13 +++++++++++++ funsor/ops/builtin.py | 4 ++-- funsor/tensor.py | 12 ++++++++++++ 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/funsor/domains.py b/funsor/domains.py index 8008759c3..4ccd3f6b0 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -354,6 +354,19 @@ def _transform_log_abs_det_jacobian(op, domain, codomain): return Real +@find_domain.register(ops.StackOp) +def _find_domain_stack(op, parts): + shape = broadcast_shape(*(x.shape for x in parts)) + dim = op.defaults["dim"] + if dim >= 0: + dim = dim - len(shape) - 1 + assert dim < 0 + split = dim + len(shape) + 1 + shape = shape[:split] + (len(parts),) + shape[split:] + output = Array[parts[0].dtype, shape] + return output + + @find_domain.register(ops.EinsumOp) def _find_domain_einsum(op, operands): equation = op.defaults["equation"] diff --git a/funsor/ops/builtin.py b/funsor/ops/builtin.py index 3dff980de..103cf3e1f 100644 --- a/funsor/ops/builtin.py +++ b/funsor/ops/builtin.py @@ -124,12 +124,12 @@ def safediv(x, y): @exp.set_log_abs_det_jacobian def log_abs_det_jacobian(x, y): - return add(x) + return x.sum() @log.set_log_abs_det_jacobian def log_abs_det_jacobian(x, y): - return -add(y) + return -y.sum() exp.set_inv(log) diff --git a/funsor/tensor.py b/funsor/tensor.py index 0805049d0..c34bb4ba9 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -834,6 +834,18 @@ def eager_getitem_tensor_tensor(op, lhs, rhs): return Tensor(data, inputs, lhs.dtype) +@eager.register(Finitary, ops.StackOp, typing.Tuple[Tensor, ...]) +def eager_finitary_stack(op, parts): + dim = op.defaults["dim"] + if dim >= 0: + event_dim = max(len(part.output.shape) for part in parts) + dim = dim - event_dim - 1 + assert dim < 0 + inputs, raw_parts = align_tensors(*parts) + raw_result = ops.stack(raw_parts, dim) + return Tensor(raw_result, inputs, parts[0].dtype) + + @eager.register(Finitary, FinitaryOp, typing.Tuple[Tensor, ...]) def eager_finitary_generic_tensors(op, args): inputs, raw_args = align_tensors(*args) From b6fc5c83230a134ca65e6124eeb63874194ff552 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 16 Mar 2021 16:02:20 -0400 Subject: [PATCH 09/22] Refactor TransformOp and WrappedTransformOp --- funsor/domains.py | 2 +- funsor/jax/distributions.py | 2 +- funsor/ops/op.py | 130 ++++++++++++++++++---------------- funsor/torch/distributions.py | 4 +- funsor/util.py | 5 +- test/test_ops.py | 14 ++-- 6 files changed, 84 insertions(+), 73 deletions(-) diff --git a/funsor/domains.py b/funsor/domains.py index 4ccd3f6b0..bfc358564 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -343,7 +343,7 @@ def _find_domain_associative_generic(op, *domains): @find_domain.register(ops.WrappedTransformOp) def _transform_find_domain(op, domain): - fn = op.default + fn = op.defaults["fn"] shape = fn.forward_shape(domain.shape) return Array[domain.dtype, shape] diff --git a/funsor/jax/distributions.py b/funsor/jax/distributions.py index 765d7a3af..3e4fe3990 100644 --- 
a/funsor/jax/distributions.py +++ b/funsor/jax/distributions.py @@ -243,7 +243,7 @@ def deltadist_to_data(funsor_dist, name_to_dim=None): @to_funsor.register(dist.transforms.Transform) def transform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None): - op = ops.WrappedTransformOp(tfm) + op = ops.WrappedTransformOp(fn=tfm) name = next(real_inputs.keys()) if real_inputs else "value" return op(Variable(name, output)) diff --git a/funsor/ops/op.py b/funsor/ops/op.py index abd4a9901..df8b7be3b 100644 --- a/funsor/ops/op.py +++ b/funsor/ops/op.py @@ -6,6 +6,7 @@ import weakref from funsor.registry import PartialDispatcher +from funsor.util import methodof def apply(function, args, kwargs={}): @@ -170,7 +171,7 @@ def decorator(fn): return decorator @classmethod - def make(cls, fn=None, *, name=None, metaclass=OpMeta, module_name="funsor.ops"): + def make(cls, fn=None, *, name=None, metaclass=None, module_name="funsor.ops"): """ Factory to create a new :class:`Op` subclass together with a new default instance of that class. @@ -186,20 +187,26 @@ def make(cls, fn=None, *, name=None, metaclass=OpMeta, module_name="funsor.ops") # Support use as decorator. if fn is None: - return lambda fn: cls.make(fn, name=name, module_name=module_name) + return lambda fn: cls.make( + fn, name=name, metaclass=metaclass, module_name=module_name + ) assert callable(fn) if name is None: name = fn.__name__ assert isinstance(name, str) + if metaclass is None: + metaclass = type(cls) assert issubclass(metaclass, OpMeta) + classname = _snake_to_camel(name) + "Op" # e.g. scatter_add -> ScatterAddOp signature = inspect.Signature.from_callable(fn) op_class = metaclass( classname, (cls,), { + "__doc__": fn.__doc__, "name": name, "signature": signature, "default": staticmethod(fn), @@ -290,77 +297,76 @@ class WrappedOpMeta(OpMeta): Caching strategy is to key on ``id(backend_op)`` and forget values asap. """ - def __init__(cls, *args, **kwargs): - super().__init__(*args, **kwargs) - cls._instance_cache = weakref.WeakValueDictionary() + def hash_args_kwargs(self, args, kwargs): + if args: + (fn,) = args + if inspect.ismethod(fn): + args = id(fn.__self__), fn.__func__ # e.g. t.log_abs_det_jacobian + else: + args = (id(fn),) # e.g. t.inv + return super().hash_args_kwargs(args, kwargs) - def __call__(cls, fn): - if inspect.ismethod(fn): - key = id(fn.__self__), fn.__func__ # e.g. t.log_abs_det_jacobian - else: - key = id(fn) # e.g. t.inv - try: - return cls._instance_cache[key] - except KeyError: - op = super().__call__(fn) - op.fn = fn # Ensures the key id(fn) is not reused. - cls._instance_cache[key] = op - return op - - -class WrappedTransformOp(TransformOp, metaclass=WrappedOpMeta): + +@TransformOp.make(metaclass=WrappedOpMeta) +def wrapped_transform(x, fn, *, validate_args=True): """ Wrapper for a backend ``Transform`` object that provides ``.inv`` and ``.log_abs_det_jacobian``. This additionally validates shapes on the first :meth:`__call__`. 
""" + if not validate_args: + return fn(x) + + try: + # Check for shape metadata available after + # https://github.com/pytorch/pytorch/pull/50547 + # https://github.com/pytorch/pytorch/pull/50581 + # https://github.com/pyro-ppl/pyro/pull/2739 + # https://github.com/pyro-ppl/numpyro/pull/876 + fn.domain.event_dim + fn.codomain.event_dim + fn.forward_shape + except AttributeError: + backend = fn.__module__.split(".")[0] + raise NotImplementedError( + f"{fn} is missing shape metadata; try upgrading backend {backend}" + ) - def __init__(self, fn): - super().__init__(fn, name=type(fn).__name__) - self._is_validated = False - - def __call__(self, x): - if self._is_validated: - return super().__call__(x) - - try: - # Check for shape metadata available after - # https://github.com/pytorch/pytorch/pull/50547 - # https://github.com/pytorch/pytorch/pull/50581 - # https://github.com/pyro-ppl/pyro/pull/2739 - # https://github.com/pyro-ppl/numpyro/pull/876 - self.fn.domain.event_dim - self.fn.codomain.event_dim - self.fn.forward_shape - except AttributeError: - backend = self.fn.__module__.split(".")[0] - raise NotImplementedError( - f"{self.fn} is missing shape metadata; " - f"try upgrading backend {backend}" - ) + if len(x.shape) < fn.domain.event_dim: + raise ValueError(f"Too few dimensions for input, in {fn.__name_}") + event_shape = x.shape[len(x.shape) - fn.domain.event_dim :] + shape = fn.forward_shape(event_shape) + if len(shape) > fn.codomain.event_dim: + raise ValueError( + f"Cannot treat transform {fn.__name__} as an Op because it is batched" + ) - if len(x.shape) < self.fn.domain.event_dim: - raise ValueError(f"Too few dimensions for input, in {self.name}") - event_shape = x.shape[len(x.shape) - self.fn.domain.event_dim :] - shape = self.fn.forward_shape(event_shape) - if len(shape) > self.fn.codomain.event_dim: - raise ValueError( - f"Cannot treat transform {self.name} as an Op " "because it is batched" - ) - self._is_validated = True - return super().__call__(x) + return fn(x) - @property - def inv(self): - return WrappedTransformOp(self.fn.inv) - @property - def log_abs_det_jacobian(self): - return LogAbsDetJacobianOp(self.fn.log_abs_det_jacobian) +WrappedTransformOp = type(wrapped_transform) + + +@methodof(WrappedTransformOp) +@property +def inv(self): + fn = self.defaults["fn"] + return WrappedTransformOp(fn=fn.inv) + + +@methodof(WrappedTransformOp) +@property +def log_abs_det_jacobian(self): + fn = self.defaults["fn"] + return LogAbsDetJacobianOp(fn=fn.log_abs_det_jacobian) + + +@BinaryOp.make(metaclass=WrappedOpMeta) +def log_abs_det_jacobian(x, y, fn): + return fn(x, y) -class LogAbsDetJacobianOp(BinaryOp, metaclass=WrappedOpMeta): - pass +LogAbsDetJacobianOp = type(log_abs_det_jacobian) # Op registration tables. 
@@ -386,4 +392,6 @@ class LogAbsDetJacobianOp(BinaryOp, metaclass=WrappedOpMeta): "UnaryOp", "WrappedTransformOp", "declare_op_types", + "log_abs_det_jacobian", + "wrapped_transform", ] diff --git a/funsor/torch/distributions.py b/funsor/torch/distributions.py index 063317e09..0b9e202a2 100644 --- a/funsor/torch/distributions.py +++ b/funsor/torch/distributions.py @@ -234,7 +234,7 @@ def transform_to_torch_transform(op, name_to_dim=None): @op_to_torch_transform.register(ops.WrappedTransformOp) def transform_to_torch_transform(op, name_to_dim=None): - return op.fn + return op.defaults["fn"] @op_to_torch_transform.register(ops.ExpOp) @@ -281,7 +281,7 @@ def transform_to_data(expr, name_to_dim=None): @to_funsor.register(torch.distributions.Transform) def transform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None): - op = ops.WrappedTransformOp(tfm) + op = ops.WrappedTransformOp(fn=tfm) name = next(real_inputs.keys()) if real_inputs else "value" return op(Variable(name, output)) diff --git a/funsor/util.py b/funsor/util.py index 738992a73..3de9c0e4f 100644 --- a/funsor/util.py +++ b/funsor/util.py @@ -237,7 +237,10 @@ def decorator(fn): if name_ is None: fn_ = fn while not hasattr(fn_, "__name__"): - fn_ = fn_.__func__ + if isinstance(fn_, property): + fn_ = fn_.fget + else: + fn_ = fn_.__func__ name_ = fn_.__name__ setattr(cls, name_, fn) return fn diff --git a/test/test_ops.py b/test/test_ops.py index 6a26a1621..f70b8d804 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -5,8 +5,8 @@ import pytest +from funsor import ops from funsor.distribution import BACKEND_TO_DISTRIBUTIONS_BACKEND -from funsor.ops import WrappedTransformOp from funsor.util import get_backend @@ -21,16 +21,16 @@ def dist(): def test_transform_op_cache(dist): t = dist.transforms.PowerTransform(0.5) - W = WrappedTransformOp - assert W(t) is W(t) - assert W(t).inv is W(t).inv - assert W(t.inv) is W(t).inv - assert W(t).log_abs_det_jacobian is W(t).log_abs_det_jacobian + W = ops.WrappedTransformOp + assert W(fn=t) is W(fn=t) + assert W(fn=t).inv is W(fn=t).inv + assert W(fn=t.inv) is W(fn=t).inv + assert W(fn=t).log_abs_det_jacobian is W(fn=t).log_abs_det_jacobian def test_transform_op_gc(dist): t = dist.transforms.PowerTransform(0.5) - op = WrappedTransformOp(t) + op = ops.WrappedTransformOp(fn=t) op_set = weakref.WeakSet() op_set.add(op) assert len(op_set) == 1 From a97231eb40298bea555d3535f5a3319b0e68e6b3 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 16 Mar 2021 16:49:54 -0400 Subject: [PATCH 10/22] Fix misc ops --- funsor/adjoint.py | 10 +--------- funsor/gaussian.py | 12 ++++++------ funsor/jax/ops.py | 8 +++++--- funsor/ops/array.py | 13 +++++++++++-- funsor/tensor.py | 5 +++++ funsor/terms.py | 5 +++++ 6 files changed, 33 insertions(+), 20 deletions(-) diff --git a/funsor/adjoint.py b/funsor/adjoint.py index 1df57a406..83152f97b 100644 --- a/funsor/adjoint.py +++ b/funsor/adjoint.py @@ -49,15 +49,7 @@ def interpret(self, cls, *args): self.tape.append((result, cls, args)) else: result = self._old_interpretation.interpret(cls, *args) - lazy_args = [ - self._eager_to_lazy.get( - id(arg) - if ops.is_numeric_array(arg) or not isinstance(arg, Hashable) - else arg, - arg, - ) - for arg in args - ] + lazy_args = [self._eager_to_lazy.get(arg, arg) for arg in args] self._eager_to_lazy[result] = reflect.interpret(cls, *lazy_args) return result diff --git a/funsor/gaussian.py b/funsor/gaussian.py index 35d8f329b..4cfb7b258 100644 --- a/funsor/gaussian.py +++ b/funsor/gaussian.py @@ -469,27 +469,27 
@@ def _eager_subs_real(self, subs, remaining_subs): ) prec_aa = ops.cat( [ - ops.cat(-1, *[precision[..., i1, i2] for k2, i2 in slices if k2 in a]) + ops.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in a], -1) for k1, i1 in slices if k1 in a ], -2, ) prec_ab = ops.cat( - *[ - ops.cat(-1, *[precision[..., i1, i2] for k2, i2 in slices if k2 in b]) + [ + ops.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in b], -1) for k1, i1 in slices if k1 in a ], - -2 + -2, ) prec_bb = ops.cat( - *[ + [ ops.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in b], -1) for k1, i1 in slices if k1 in b ], - -2 + -2, ) info_a = ops.cat([info_vec[..., i] for k, i in slices if k in a], -1) info_b = ops.cat([info_vec[..., i] for k, i in slices if k in b], -1) diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py index c6d16ee13..be784071d 100644 --- a/funsor/jax/ops.py +++ b/funsor/jax/ops.py @@ -120,9 +120,11 @@ def _finfo(x): return np.finfo(x.dtype) -@ops.is_numeric_array.register(array) -def _is_numeric_array(x): - return True +for typ in array: + + @ops.is_numeric_array.register(typ) + def _is_numeric_array(x): + return True @ops.isnan.register(array) diff --git a/funsor/ops/array.py b/funsor/ops/array.py index 2286f4bdc..78b44ed50 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -4,6 +4,7 @@ import math import numbers import typing +from functools import singledispatch import numpy as np @@ -211,9 +212,17 @@ def finfo(x): return np.finfo(x.dtype) -@UnaryOp.make +# this isn't really a mathematical op +@singledispatch def is_numeric_array(x): - return True if isinstance(x, array) else False + return False + + +for typ in array: + + @is_numeric_array.register(typ) + def _is_numeric_array(x): + return True @logaddexp.register(array, array) diff --git a/funsor/tensor.py b/funsor/tensor.py index c34bb4ba9..7ed020bdd 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -1184,6 +1184,11 @@ def tensordot(x, y, dims): return Einsum(equation, x, y) +@ops.is_numeric_array.register(Tensor) +def _is_numeric_array(x): + return True + + REDUCE_OP_TO_NUMERIC = { ops.add: ops.sum, ops.mul: ops.prod, diff --git a/funsor/terms.py b/funsor/terms.py index 0bf122e85..7699ff385 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1963,6 +1963,11 @@ def finitary_funsor(cls, arg, *args, **kwargs): return Finitary(op, arg) +@ops.is_numeric_array.register(Funsor) +def _is_numeric_array(x): + return False + + __all__ = [ "Approximate", "Binary", From 0ba0c31385defeceb77e6cdd8d1b4e0cbb903bb9 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 16 Mar 2021 17:03:33 -0400 Subject: [PATCH 11/22] lint --- funsor/adjoint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/funsor/adjoint.py b/funsor/adjoint.py index 83152f97b..bb003ec86 100644 --- a/funsor/adjoint.py +++ b/funsor/adjoint.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 from collections import defaultdict -from collections.abc import Hashable from funsor.cnf import Contraction, null from funsor.interpretations import Interpretation, reflect From 744634a6e2e126b8efa4a787bbad52c8408f85d1 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 16 Mar 2021 17:31:04 -0400 Subject: [PATCH 12/22] Work around signature parsing in Python 3.6 --- funsor/ops/op.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/funsor/ops/op.py b/funsor/ops/op.py index df8b7be3b..34bfe20b7 100644 --- a/funsor/ops/op.py +++ b/funsor/ops/op.py @@ -3,6 +3,8 @@ import functools import inspect +import math +import 
operator import weakref from funsor.registry import PartialDispatcher @@ -13,6 +15,23 @@ def apply(function, args, kwargs={}): return function(*args, **kwargs) +def _get_signature(fn): + try: + return inspect.Signature.from_callable(fn) + except ValueError as e: + # In Python <=3.6, attempt to parse docstring of builtins. + if any(fn is getattr(lib, fn.__name__, None) for lib in (math, operator)): + if fn.__doc__.startswith(f"{fn.__name__}(x)"): + return inspect.Signature.from_callable(lambda x: None) + if fn.__doc__.startswith(f"{fn.__name__}(a)"): + return inspect.Signature.from_callable(lambda a: None) + if fn.__doc__.startswith(f"{fn.__name__}(obj)"): + return inspect.Signature.from_callable(lambda obj: None) + if fn.__doc__.startswith(f"{fn.__name__}(a, b)"): + return inspect.Signature.from_callable(lambda a, b: None) + raise e from None + + def _iter_subclasses(cls): yield cls for subcls in cls.__subclasses__(): @@ -201,7 +220,7 @@ def make(cls, fn=None, *, name=None, metaclass=None, module_name="funsor.ops"): assert issubclass(metaclass, OpMeta) classname = _snake_to_camel(name) + "Op" # e.g. scatter_add -> ScatterAddOp - signature = inspect.Signature.from_callable(fn) + signature = _get_signature(fn) op_class = metaclass( classname, (cls,), From b30ac7f324e7cd247dca3c6029c586fad5454ad1 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 16 Mar 2021 17:51:55 -0400 Subject: [PATCH 13/22] Fix test_cnf.py --- funsor/ops/array.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/funsor/ops/array.py b/funsor/ops/array.py index 78b44ed50..58551d205 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -60,13 +60,13 @@ def any(x, dim=None): @UnaryOp.make -def amax(x, dim=None): - return np.amax(x, dim) +def amax(x, dim=None, keepdims=False): + return np.amax(x, dim, keepdims=keepdims) @UnaryOp.make -def amin(x, dim=None): - return np.amax(x, dim) +def amin(x, dim=None, keepdims=False): + return np.amax(x, dim, keepdims=keepdims) @UnaryOp.make From 1ce8e89bab30e4c46eb808f44bffb8fb6504c54a Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 17 Mar 2021 11:12:40 -0400 Subject: [PATCH 14/22] Disable obsolete/questionable test --- test/test_distribution_generic.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index 5ffa99bf6..ebef850e8 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -657,13 +657,12 @@ def test_generic_distribution_to_funsor(case): @pytest.mark.parametrize("case", TEST_CASES, ids=str) -@pytest.mark.parametrize("use_lazy", [True, False]) -def test_generic_log_prob(case, use_lazy): +def test_generic_log_prob(case): raw_dist = case.get_dist() expected_value_domain = case.expected_value_domain dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape) - with (eager_no_dists if use_lazy else eager): + with eager_no_dists: with xfail_if_not_implemented(match="try upgrading backend"): # some distributions have nontrivial eager patterns funsor_dist = to_funsor( From 920f345bd62e123f3549c5b04a4323c46b1afc16 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 17 Mar 2021 12:54:29 -0400 Subject: [PATCH 15/22] Add info to assertions failing only on ci --- funsor/terms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/funsor/terms.py b/funsor/terms.py index 7699ff385..a0fc1ad0d 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1555,7 +1555,7 @@ def __init__(self, name, 
parts, part_name=None): assert isinstance(parts, tuple) assert isinstance(part_name, str) assert parts - assert all(part_name in x.inputs for x in parts) + assert all(part_name in x.inputs for x in parts), (part_name, x.inputs) if part_name != name: assert not any(name in x.inputs for x in parts) assert len(set(x.output for x in parts)) == 1 @@ -1705,7 +1705,7 @@ def __init__(self, fn, reals_var, bint_var, diag_var): assert isinstance(fn, Funsor) assert isinstance(reals_var, str) assert isinstance(bint_var, str) - assert bint_var in fn.inputs + assert bint_var in fn.inputs, (bint_var, fn.inputs) assert isinstance(fn.inputs[bint_var].dtype, int) assert isinstance(diag_var, str) assert diag_var in fn.inputs From cc750a4af3f56552abccb881006f64f9e41da89b Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 17 Mar 2021 13:02:38 -0400 Subject: [PATCH 16/22] Clean up error printing --- funsor/terms.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/funsor/terms.py b/funsor/terms.py index a0fc1ad0d..216702045 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -269,9 +269,12 @@ def __annotations__(self): return type_hints def __repr__(self): - return "{}({})".format( - type(self).__name__, ", ".join(map(repr, self._ast_values)) - ) + try: + ast_values = self._ast_values + except AttributeError: + # E.g. when printing errors during __init__, before ._ast_va. + return f"{type(self).__name__}(...)" + return "{}({})".format(type(self).__name__, ", ".join(map(repr, ast_values))) def __str__(self): return "{}({})".format( From d7d728178b0e9812a93efb8a67f45e429f9e7108 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 17 Mar 2021 13:36:12 -0400 Subject: [PATCH 17/22] Fix typo --- funsor/terms.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/funsor/terms.py b/funsor/terms.py index 216702045..5ea2a6dd6 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -272,7 +272,7 @@ def __repr__(self): try: ast_values = self._ast_values except AttributeError: - # E.g. when printing errors during __init__, before ._ast_va. + # E.g. when printing errors during __init__, before ._ast_values is set. 
return f"{type(self).__name__}(...)" return "{}({})".format(type(self).__name__, ", ".join(map(repr, ast_values))) @@ -1558,7 +1558,8 @@ def __init__(self, name, parts, part_name=None): assert isinstance(parts, tuple) assert isinstance(part_name, str) assert parts - assert all(part_name in x.inputs for x in parts), (part_name, x.inputs) + for part in parts: + assert part_name in part.inputs, (part_name, part.inputs) if part_name != name: assert not any(name in x.inputs for x in parts) assert len(set(x.output for x in parts)) == 1 From 7002fd1b1ec1af1540f9ba66253d2218e6e83401 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 17 Mar 2021 15:47:16 -0400 Subject: [PATCH 18/22] Fix is_numeric_array(funsor.Tensor) --- funsor/tensor.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/funsor/tensor.py b/funsor/tensor.py index 7ed020bdd..c34bb4ba9 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -1184,11 +1184,6 @@ def tensordot(x, y, dims): return Einsum(equation, x, y) -@ops.is_numeric_array.register(Tensor) -def _is_numeric_array(x): - return True - - REDUCE_OP_TO_NUMERIC = { ops.add: ops.sum, ops.mul: ops.prod, From 11aa7d7968acc9feaa6e2a69eaac6f2ad9878f01 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 17 Mar 2021 18:01:00 -0400 Subject: [PATCH 19/22] Revert some changes to is_numeric_array and funsor.adjoint --- funsor/adjoint.py | 11 ++++++++++- funsor/ops/array.py | 3 +++ funsor/terms.py | 5 ----- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/funsor/adjoint.py b/funsor/adjoint.py index bb003ec86..1df57a406 100644 --- a/funsor/adjoint.py +++ b/funsor/adjoint.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from collections import defaultdict +from collections.abc import Hashable from funsor.cnf import Contraction, null from funsor.interpretations import Interpretation, reflect @@ -48,7 +49,15 @@ def interpret(self, cls, *args): self.tape.append((result, cls, args)) else: result = self._old_interpretation.interpret(cls, *args) - lazy_args = [self._eager_to_lazy.get(arg, arg) for arg in args] + lazy_args = [ + self._eager_to_lazy.get( + id(arg) + if ops.is_numeric_array(arg) or not isinstance(arg, Hashable) + else arg, + arg, + ) + for arg in args + ] self._eager_to_lazy[result] = reflect.interpret(cls, *lazy_args) return result diff --git a/funsor/ops/array.py b/funsor/ops/array.py index 58551d205..3e6bc9289 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -215,6 +215,9 @@ def finfo(x): # this isn't really a mathematical op @singledispatch def is_numeric_array(x): + """ + Returns whether an object is a ground numeric array. 
+ """ return False diff --git a/funsor/terms.py b/funsor/terms.py index 5ea2a6dd6..d778f0d33 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1967,11 +1967,6 @@ def finitary_funsor(cls, arg, *args, **kwargs): return Finitary(op, arg) -@ops.is_numeric_array.register(Funsor) -def _is_numeric_array(x): - return False - - __all__ = [ "Approximate", "Binary", From 8233bc501677e9455ad21eb69858152428f6ce6e Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 17 Mar 2021 19:09:00 -0400 Subject: [PATCH 20/22] Set JAX_ENABLE_X64=1 on ci --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2b5b0e12b..c22fc20fd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -86,4 +86,4 @@ jobs: pip freeze - name: Run test run: | - CI=1 FUNSOR_BACKEND=jax make test + CI=1 JAX_ENABLE_X64=1 FUNSOR_BACKEND=jax make test From 0b0e274dfb570a26f119fa7693298d36bf85f75e Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 17 Mar 2021 19:25:05 -0400 Subject: [PATCH 21/22] Increase number of samples in test_gaussian_mixture_distribution --- test/test_samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_samplers.py b/test/test_samplers.py index 46d9c6f41..3f96cc99a 100644 --- a/test/test_samplers.py +++ b/test/test_samplers.py @@ -339,7 +339,7 @@ def test_gaussian_distribution(event_inputs, batch_inputs): ids=id_from_inputs, ) def test_gaussian_mixture_distribution(batch_inputs, event_inputs): - num_samples = 100000 + num_samples = 200000 sample_inputs = OrderedDict(particle=Bint[num_samples]) be_inputs = OrderedDict(batch_inputs + event_inputs) int_inputs = OrderedDict((k, d) for k, d in be_inputs.items() if d.dtype != "real") From ff6fa187e675cfa2d24ad8b33efdb875997e05d6 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Wed, 17 Mar 2021 20:15:14 -0400 Subject: [PATCH 22/22] Increase number of samples in test_dirichlet_sample --- test/test_distribution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_distribution.py b/test/test_distribution.py index 0e036e2eb..5fa75b463 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -972,6 +972,7 @@ def test_dirichlet_sample(batch_shape, sample_inputs, event_shape, reparametrize params, sample_inputs, inputs, + num_samples=200000, atol=1e-2 if reparametrized else 1e-1, )
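
Taken together, the series replaces variadic ops (a leading dim or equation plus
*args) with fixed-arity ops whose remaining arguments are bound as defaults.
Below is a minimal numpy-backend sketch of the conventions the patches converge
on; the names and values are illustrative only and are not part of any patch.

    import numpy as np

    from funsor import ops
    from funsor.ops import UnaryOp

    # Finitary ops now take a tuple (or list) of operands first, then ground args.
    parts = (np.zeros((2, 3)), np.ones((2, 3)))
    ops.cat(parts, -1)                              # concatenate along the last dim -> (2, 6)
    ops.stack(parts, 0)                             # stack along a new leading dim  -> (2, 2, 3)
    ops.einsum((np.eye(3), np.ones(3)), "ab,b->a")  # operands tuple first, then the equation
    ops.new_arange(np.zeros(1), 5)                  # start/stop/step folded into one op

    # Defining a new op: a plain function becomes a generated Op subclass, and a
    # dispatch pattern lists only the first arity-many (array) arguments.
    # "clamp10" is a hypothetical example, not an op added by these patches.
    @UnaryOp.make
    def clamp10(x, lower=0.0, upper=10.0):
        raise NotImplementedError

    @clamp10.register(np.ndarray)
    def _clamp10_array(x, lower, upper):
        return np.clip(x, lower, upper)

    clamp10(np.array([-1.0, 5.0, 20.0]))            # -> array([ 0.,  5., 10.])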