diff --git a/jax/api.py b/jax/api.py index 485c381deade..7ef5dff07ca1 100644 --- a/jax/api.py +++ b/jax/api.py @@ -27,8 +27,8 @@ import collections import functools import itertools as it +import operator as op import os -import string import threading from warnings import warn @@ -51,7 +51,6 @@ from .lib.xla_bridge import (device_count, local_device_count, devices, local_devices, host_id, host_ids, host_count) from .abstract_arrays import ConcreteArray, ShapedArray, raise_to_shaped -from .interpreters.masking import eval_polymorphic_shape, Poly, Mon from .interpreters import partial_eval as pe from .interpreters import xla from .interpreters import pxla @@ -59,6 +58,7 @@ from .interpreters import batching from .interpreters import parallel from .interpreters import masking +from .interpreters.masking import shapecheck, ensure_poly from .config import flags, config, bool_env map = safe_map @@ -1053,23 +1053,24 @@ def wrapped_fun(args, logical_env): out_shapes = map(masking.finalize_spec, out_specs, map(onp.shape, outs)) if not out_shapes == list(out_shapes_): raise masking.ShapeError - if not all(onp.shape(out) == eval_polymorphic_shape(shape, padded_env) - for out, shape in zip(outs, out_shapes)): + if not all(onp.shape(out) == masking.eval_shape_expr(padded_env, expr) + for out, expr in zip(outs, out_shapes)): raise masking.ShapeError return tree_unflatten(out_tree(), outs) return wrapped_fun def _remap_ids(names, shape_spec): - return masking.ShapeSpec(Poly({Mon({names[id] : deg for id, deg in mon.items()}) + ShapeSpec, Poly, Mon = masking.ShapeSpec, masking.Poly, masking.Mon + mdim = masking.monomorphic_dim + return ShapeSpec(Poly({Mon({names[id] : deg for id, deg in mon.items()}) : coeff for mon, coeff in poly.items()}) - if poly is not masking._monomorphic_dim else - masking._monomorphic_dim for poly in shape_spec) + if poly is not mdim else mdim for poly in shape_spec) def _bind_shapes(shape_exprs, shapes): env = {} for shape_expr, shape in zip(shape_exprs, shapes): for poly, d in zip(shape_expr, shape): - if type(poly) is not Poly or poly.is_constant: + if ensure_poly(poly).is_constant: continue else: (binder,), = poly # TODO generalize to handle striding @@ -1084,13 +1085,16 @@ def shapecheck(in_shapes, out_shape, fun): out_shapes, out_tree = tree_flatten(out_shape) out_shapes = map(masking.parse_spec, out_shapes) flat_fun, out_tree_ = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) - avals = map(partial(ShapedArray, dtype=onp.float32), in_shapes) - out_shapes_ = [o.shape for o in pe.abstract_eval_fun(flat_fun.call_wrapped, *avals)] + out_shapes_ = masking.shapecheck(flat_fun, in_shapes) if out_tree != out_tree_(): raise TypeError("pytree mismatch") - if not all(map(masking._shape_spec_consistent, out_shapes, out_shapes_)): + if not all(map(_shape_spec_consistent, out_shapes, out_shapes_)): raise masking.ShapeError return fun +def _shape_spec_consistent(spec, expr): + return all(a == b for a, b in zip(spec, expr) if a is not masking.monomorphic_dim) + + def jvp(fun, primals, tangents): """Computes a (forward-mode) Jacobian-vector product of `fun`. diff --git a/jax/interpreters/masking.py b/jax/interpreters/masking.py index 323b7403711f..c1435608ee56 100644 --- a/jax/interpreters/masking.py +++ b/jax/interpreters/masking.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. + from contextlib import contextmanager -from collections import Counter, namedtuple -from functools import partial -from itertools import chain, product +from collections import defaultdict, Counter, namedtuple +import functools +from functools import partial, wraps +import itertools as it import operator as op import string @@ -24,30 +26,19 @@ from .. import abstract_arrays from .. import core from ..core import Trace, Tracer -from ..util import safe_map, safe_zip, unzip2, prod +from ..util import unzip2, safe_map, safe_zip, curry from ..abstract_arrays import ShapedArray from .. import linear_util as lu map = safe_map zip = safe_zip -shape_parameterized_primitive_rules = {} -masking_rules = {} +def prod(xs): + xs = list(xs) + return functools.reduce(op.mul, xs) if xs else 1 -def defvectorized(prim): - masking_rules[prim] = partial(vectorized_masking_rule, prim) -def defnaryop(prim): - masking_rules[prim] = partial(naryop_masking_rule, prim) - -def vectorized_masking_rule(prim, padded_vals, logical_shapes, **params): - del logical_shapes # Unused. - padded_val, = padded_vals - return prim.bind(padded_val, **params) - -def naryop_masking_rule(prim, padded_vals, logical_shapes): - del logical_shapes # Unused. - return prim.bind(*padded_vals) +### main transformation functions ShapeEnvs = namedtuple("ShapeEnvs", ["logical", "padded"]) shape_envs = ShapeEnvs({}, {}) # TODO(mattjj): make this a stack for efficiency @@ -55,17 +46,29 @@ def naryop_masking_rule(prim, padded_vals, logical_shapes): @contextmanager def extend_shape_envs(logical_env, padded_env): global shape_envs - new_logical = dict(chain(shape_envs.logical.items(), logical_env.items())) - new_padded = dict(chain(shape_envs.padded.items(), padded_env.items())) + new_logical = dict(it.chain(shape_envs.logical.items(), logical_env.items())) + new_padded = dict(it.chain(shape_envs.padded.items(), padded_env.items())) shape_envs, prev = ShapeEnvs(new_logical, new_padded), shape_envs yield shape_envs = prev -def shape_as_value(shape): - return eval_polymorphic_shape(shape, shape_envs.logical) +def is_polymorphic(shape): + return any(map(lambda d: isinstance(d, Poly), shape)) + +def shape_as_value(expr): + if type(expr) is tuple and is_polymorphic(expr): + return tuple(eval_dim_expr(shape_envs.logical, d) if type(d) is Poly else d + for d in expr) + else: + return expr + +def padded_shape_as_value(expr): + if type(expr) is tuple and is_polymorphic(expr): + return tuple(eval_dim_expr(shape_envs.padded, d) if type(d) is Poly else d + for d in expr) + else: + return expr -def padded_shape_as_value(shape): - return eval_polymorphic_shape(shape, shape_envs.padded) def mask_fun(fun, logical_env, padded_env, in_vals, shape_exprs): with core.new_master(MaskTrace) as master: @@ -85,39 +88,29 @@ def mask_subtrace(master, in_vals, shape_exprs): out_vals, out_shapes = unzip2((t.val, t.shape_expr) for t in out_tracers) yield out_vals, out_shapes -def to_index(x): - """Like operator.index, but allowing polymorphic dimensions. - Not implemented as `Poly.__index__`, since operator.index only allows ints.""" - return x if type(x) is Poly else op.index(x) - -def eval_polymorphic_shape(shape, values_dict): - return tuple(dim.evaluate(values_dict) if type(dim) is Poly else dim - for dim in shape) - -def _ensure_poly(p): - if type(p) is Poly: +def ensure_poly(p): + if isinstance(p, Poly): return p - return Poly({Mon(): p}) + return constant_poly(int(p)) -class Poly(dict): +class Poly(Counter): """Polynomial with integer coefficients, usable as element in a polymorphic shape. type Poly = Map Mon Int -- monomials to coeffs type Mon = Map Str Int """ - def __init__(self, coeffs): # Makes sure Polynomials are always in canonical form to simplify operators: - coeffs = {mon: op.index(coeff) for mon, coeff in coeffs.items() if coeff != 0} + coeffs = {mon: coeff for mon, coeff in coeffs.items() if coeff != 0} coeffs = {Mon(): 0} if len(coeffs) == 0 else coeffs super().__init__(coeffs) def __add__(self, other): coeffs = self.copy() - for mon, coeff in _ensure_poly(other).items(): + for mon, coeff in ensure_poly(other).items(): coeffs[mon] = coeffs.get(mon, 0) + coeff return Poly(coeffs) @@ -131,9 +124,10 @@ def __neg__(self): def __mul__(self, other): coeffs = dict() for (mon1, coeff1), (mon2, coeff2) \ - in product(self.items(), _ensure_poly(other).items()): - mon = mon1 * mon2 - coeffs[mon] = coeffs.get(mon, 0) + coeff1 * coeff2 + in it.product(self.items(), ensure_poly(other).items()): + mon = Mon(mon1 + mon2) # add monomials' id degrees + coeff = coeff1 * coeff2 # multiply integer coeffs + coeffs[mon] = coeffs.get(mon, 0) + coeff # accumulate coeffs return Poly(coeffs) @@ -156,7 +150,9 @@ def __mod__(self, divisor): def __divmod__(self, divisor): if self.is_constant: - return divmod(int(self), divisor) + q, r = divmod(int(self), divisor) + + return constant_poly(q), r def divided(count): q, r = divmod(count, divisor) @@ -167,29 +163,29 @@ def divided(count): return Poly( {k: coeff // divisor if k.degree == 0 else divided(coeff) - for k, coeff in self.items()}), self.get(Mon(), 0) % divisor + for k, coeff in self.items()}), self[Mon()] % divisor def __hash__(self): - return hash(tuple(sorted(self.items()))) + return hash(super()) def __eq__(self, other): - return dict.__eq__(self, _ensure_poly(other)) + return super().__eq__(ensure_poly(other)) def __ne__(self, other): return not self == other def __ge__(self, other): - other = _ensure_poly(other) + other = ensure_poly(other) if other.is_constant and self.is_constant: return int(self) >= int(other) if other.is_constant and int(other) <= 1: - # Assume polynomials > 0, allowing to use shape rules of binops, conv: - return True + # Assume polynomials > 0, allowing to use shape rules of binops, conv: + return True if self.is_constant and int(self) <= 0: - return False # See above. + return False # See above. if self == other: return True @@ -198,27 +194,22 @@ def __ge__(self, other): .format(self, other)) def __le__(self, other): - return _ensure_poly(other) >= self + return ensure_poly(other) >= self def __lt__(self, other): return not (self >= other) def __gt__(self, other): - return not (_ensure_poly(other) >= self) + return not (ensure_poly(other) >= self) def __str__(self): - return ' + '.join('{} {}'.format(v, k) - if (v != 1 or k.degree == 0) else str(k) + return ' + '.join('{} {}'.format(v, k) if (v != 1 or k.degree == 0) else str(k) for k, v in sorted(self.items())).strip() def __int__(self): assert self.is_constant - return op.index(next(iter(self.values()))) - - def evaluate(self, values_dict): - return sum(coeff * prod([values_dict[id] ** deg for id, deg in mon.items()]) - for mon, coeff in self.items()) + return int(next(iter(self.values()))) @property def is_constant(self): @@ -227,7 +218,7 @@ def is_constant(self): abstract_arrays._DIMENSION_TYPES.add(Poly) -class Mon(dict): # type Mon = Map Id Int -- ids to degrees +class Mon(Counter): # type Mon = Map Id Int -- ids to degrees def __hash__(self): return hash(tuple(self.items())) @@ -241,13 +232,34 @@ def __lt__(self, other): other_key = other.degree, tuple(sorted(other)) return self_key < other_key - def __mul__(self, other): - return Mon(Counter(self) + Counter(other)) - @property def degree(self): return sum(self.values()) +def eval_shape_expr(env, expr): + return tuple(eval_dim_expr(env, poly) for poly in expr) + +def eval_dim_expr(env, poly): + terms = [mul(coeff, prod([pow(env[id], deg) for id, deg in mon.items()])) + for mon, coeff in poly.items()] + return sum(terms) if len(terms) > 1 else terms[0] + +def pow(x, deg): + try: + deg = int(deg) + except: + return x ** deg + else: + return 1 if deg == 0 else x if deg == 1 else x ** deg + +def mul(coeff, mon): + try: + coeff = int(coeff) + except: + return coeff * mon + else: + return 0 if coeff == 0 else mon if coeff == 1 else coeff * mon + class ShapeError(Exception): pass class ShapeSyntaxError(Exception): pass @@ -267,7 +279,7 @@ class ShapeSyntaxError(Exception): pass # dims ::= dim ',' dims | '' # dim ::= str | int | dim '*' dim | dim '+' dim | '_' # -# ShapeSpecs can have some monomorphic dims inside them, +# ShapeSpecs encode ShapeExprs but can have some monomorphic dims inside them, # which must be replaced with concrete shapes when known. class ShapeSpec(tuple): @@ -275,7 +287,7 @@ def __str__(self): return 'ShapeSpec({})'.format(', '.join(map(str, self))) def finalize_spec(spec, shape): - return tuple(_parse_lit(d) if e is _monomorphic_dim else e + return tuple(parse_lit(d) if e is monomorphic_dim else e for e, d in zip(spec, shape)) def parse_spec(spec=''): @@ -284,33 +296,35 @@ def parse_spec(spec=''): if spec[0] == '(': if spec[-1] != ')': raise ShapeSyntaxError(spec) spec = spec[1:-1] - dims = map(_parse_dim, spec.replace(' ', '').strip(',').split(',')) + dims = map(parse_dim, spec.replace(' ', '').strip(',').split(',')) return ShapeSpec(dims) -def _parse_dim(spec): +def parse_dim(spec): if '+' in spec: - return onp.sum(map(_parse_dim, spec.split('+'))) + terms = map(parse_dim, spec.split('+')) + return functools.reduce(op.add, terms) elif '*' in spec: - return prod(map(_parse_dim, spec.split('*'))) + terms = map(parse_dim, spec.split('*')) + return functools.reduce(op.mul, terms) elif spec.isdigit() or spec.startswith('-') and spec[1:].isdigit(): - return _parse_lit(spec) - elif spec in _identifiers: - return _parse_id(spec) + return parse_lit(spec) + elif spec in identifiers: + return parse_id(spec) elif spec == '_': - return _monomorphic_dim + return monomorphic_dim else: raise ShapeSyntaxError(spec) +digits = frozenset(string.digits) +identifiers = frozenset(string.ascii_lowercase) -_identifiers = frozenset(string.ascii_lowercase) - -def _parse_id(name): return Poly({Mon({name: 1}): 1}) - -def _parse_lit(val_str): return Poly({Mon(): int(val_str)}) +def parse_id(name): return Poly({Mon({name: 1}): 1}) +def parse_lit(val_str): return constant_poly(int(val_str)) +def constant_poly(val): return Poly({Mon(): val}) class MonomorphicDim(object): def __str__(self): return '_' +monomorphic_dim = MonomorphicDim() -_monomorphic_dim = MonomorphicDim() # Two convenient ways to provide shape annotations: # 1. '(m, n)' @@ -318,13 +332,14 @@ def __str__(self): return '_' class S_(object): def __getitem__(self, idx): - return parse_spec(('(' + ','.join(map(str, idx)) + ')') - if type(idx) is tuple else str(idx)) - + if type(idx) is tuple: + return parse_spec('(' + ','.join(map(str, idx)) + ')') + else: + return parse_spec(str(idx)) s_ = S_() -def _shape_spec_consistent(spec, expr): - return all(a == b for a, b in zip(spec, expr) if a is not _monomorphic_dim) + +### automasking tracer machinery class MaskTracer(Tracer): __slots__ = ["val", "shape_expr"] @@ -339,7 +354,7 @@ def aval(self): return ShapedArray(self.shape_expr, self.val.dtype) def is_pure(self): - return all(type(poly) is not Poly or poly.is_constant for poly in self.shape_expr) + return all(ensure_poly(poly).is_constant for poly in self.shape_expr) def full_lower(self): if self.is_pure(): @@ -347,7 +362,6 @@ def full_lower(self): else: return self - class MaskTrace(Trace): def pure(self, val): return MaskTracer(self, val, onp.shape(val)) @@ -364,10 +378,8 @@ def process_primitive(self, primitive, tracers, params): rule = shape_parameterized_primitive_rules[primitive] out, out_shape = rule(shape_envs, vals, shape_exprs, **params) else: - avals = [t.aval for t in tracers] - out = primitive.abstract_eval(*avals, **params) - out_shape = [o.shape for o in out] if primitive.multiple_results else out.shape - logical_shapes = map(partial(eval_polymorphic_shape, values_dict=shape_envs.logical), shape_exprs) + out_shape = shape_rules[primitive](*(t.aval for t in tracers), **params) + logical_shapes = map(partial(eval_shape_expr, shape_envs.logical), shape_exprs) out = masking_rules[primitive](vals, logical_shapes, **params) if not primitive.multiple_results: return MaskTracer(self, out, out_shape) @@ -409,7 +421,6 @@ def naryop_shape_rule(shape_exprs): raise ShapeError - ### definition-time (import-time) shape checker tracer machinery def shapecheck(fun, in_shapes): diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 6fecfa834c18..e95e97195af5 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -965,6 +965,7 @@ def _device_put_impl(x, device=None): device_put_p.def_impl(_device_put_impl) pe.custom_partial_eval_rules[device_put_p] = lambda trace, x, **params: x ad.deflinear(device_put_p, lambda cotangent, **kwargs: [cotangent]) +masking.shape_rules[device_put_p] = lambda x, **_: x.shape masking.defvectorized(device_put_p) diff --git a/jax/lax/lax.py b/jax/lax/lax.py index 7f31d3fe41b5..89bb5c4485a5 100644 --- a/jax/lax/lax.py +++ b/jax/lax/lax.py @@ -38,7 +38,6 @@ from ..abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray, AbstractToken, array_types, make_shaped_array, raise_to_shaped, abstract_token, canonicalize_shape) -from ..interpreters.masking import to_index from ..interpreters import partial_eval as pe from ..interpreters import xla from ..interpreters import pxla @@ -1065,7 +1064,7 @@ def iota(dtype, size): `_ operator. """ - size = to_index(size) + size = int(size) dtype = dtypes.canonicalize_dtype(dtype) lazy_expr = lazy.iota(dtype, size) aval = ShapedArray((size,), dtype) @@ -1510,6 +1509,7 @@ def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None): prim.def_impl(partial(xla.apply_primitive, prim)) prim.def_abstract_eval(partial(standard_abstract_eval, prim, shape_rule, dtype_rule)) xla.translations[prim] = translation_rule or partial(standard_translate, name) + masking.shape_rules[prim] = shape_rule return prim @@ -4139,6 +4139,7 @@ def _tie_in_batch_rule(batched_args, batch_dims): xla.translations[tie_in_p] = lambda c, x, y: y ad.deflinear(tie_in_p, _tie_in_transpose_rule) batching.primitive_batchers[tie_in_p] = _tie_in_batch_rule +masking.shape_rules[tie_in_p] = lambda x, y: y.shape masking.masking_rules[tie_in_p] = lambda vals, logical_shapes: vals[1] @@ -4412,6 +4413,8 @@ def conv_transpose_shape_tuple(lhs_shape, rhs_shape, window_strides, padding, def _check_shapelike(fun_name, arg_name, obj): """Check that `obj` is a shape-like value (e.g. tuple of nonnegative ints).""" + if (type(obj) is tuple and masking.is_polymorphic(obj)): + return obj if not isinstance(obj, (tuple, list, onp.ndarray)): msg = "{} {} must be of type tuple/list/ndarray, got {}." raise TypeError(msg.format(fun_name, arg_name, type(obj))) @@ -4422,9 +4425,7 @@ def _check_shapelike(fun_name, arg_name, obj): if obj_arr.ndim != 1: msg = "{} {} must be rank 1, got {}." raise TypeError(msg.format(obj_arr.ndim)) - try: - canonicalize_shape(obj_arr) - except TypeError: + if not dtypes.issubdtype(obj_arr.dtype, onp.integer): msg = "{} {} must have every element be an integer type, got {}." raise TypeError(msg.format(fun_name, arg_name, tuple(map(type, obj)))) if not (obj_arr >= 0).all(): diff --git a/jax/lax/lax_control_flow.py b/jax/lax/lax_control_flow.py index 11298ba9170d..bed09a4129ea 100644 --- a/jax/lax/lax_control_flow.py +++ b/jax/lax/lax_control_flow.py @@ -1120,7 +1120,7 @@ def _scan_masking_rule(shape_envs, padded_vals, shape_exprs, forward, length, jaxpr, num_consts, num_carry, linear): out_shape = _scan_shape_rule(shape_exprs, forward, length, jaxpr, num_consts, num_carry, linear) - dynamic_length = length.evaluate(shape_envs.logical) + dynamic_length = masking.eval_dim_expr(shape_envs.logical, length) masked_jaxpr = _masked_scan_jaxpr(jaxpr, num_consts, num_carry) consts, init, xs = split_list(padded_vals, [num_consts, num_carry]) max_length, = {x.shape[0] for x in xs} diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index f7ca137c8f7b..7f05772cf250 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -44,7 +44,6 @@ from ..abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray from ..config import flags from ..interpreters.xla import DeviceArray -from ..interpreters.masking import Poly, to_index from .. import lax from ..util import partial, get_module_functions, unzip2, prod as _prod, subvals from ..lib import pytree @@ -1058,7 +1057,7 @@ def broadcast_arrays(*args): def broadcast_to(arr, shape): """Like Numpy's broadcast_to but doesn't necessarily return views.""" arr = arr if isinstance(arr, ndarray) else array(arr) - shape = tuple(map(to_index, shape)) # check that shape is concrete + shape = tuple(map(int, shape)) # check that shape is concrete arr_shape = _shape(arr) if arr_shape == shape: return arr @@ -1861,8 +1860,7 @@ def arange(start, stop=None, step=None, dtype=None): lax._check_user_dtype_supported(dtype, "arange") if stop is None and step is None: dtype = dtype or _dtype(start) - size = start if type(start) is Poly else lax.convert_element_type(start, onp.uint64) - return lax.iota(dtype, size) # avoids materializing + return lax.iota(dtype, start) # avoids materializing else: return array(onp.arange(start, stop=stop, step=step, dtype=dtype)) @@ -2644,9 +2642,6 @@ def take(a, indices, axis=None, out=None, mode=None): def _normalize_index(index, axis_size): """Normalizes an index value in the range [-N, N) to the range [0, N).""" - if type(axis_size) is Poly: - return index + axis_size if index < 0 else index - return lax.select( lax.lt(index, _constant_like(index, 0)), lax.add(index, _constant_like(index, axis_size)), @@ -2860,7 +2855,7 @@ def _index_to_gather(x_shape, idx): collapsed_slice_dims = [] start_index_map = [] - index_dtype = int64 if any([type(dim) is Poly or dim >= (1 << 31) for dim in x_shape]) else int32 + index_dtype = int64 if max(x_shape) >= (1 << 31) else int32 gather_indices = onp.zeros((0,), dtype=index_dtype) # use onp to save a compilation # We perform three transformations to y before the scatter op, in order: @@ -2916,8 +2911,7 @@ def _index_to_gather(x_shape, idx): if (isinstance(abstract_i, ConcreteArray) or isinstance(abstract_i, ShapedArray)) and _int(abstract_i): i = _normalize_index(i, x_shape[x_axis]) - # dummy index if is polynomial, doesn't matter for shape inference: - i = 0 if type(i) is Poly else lax.convert_element_type(i, index_dtype) + i = lax.convert_element_type(i, index_dtype) i = broadcast_to(i, tuple(gather_indices.shape[:-1]) + (1,)) gather_indices = concatenate((gather_indices, i), -1) collapsed_slice_dims.append(x_axis) @@ -2939,7 +2933,7 @@ def _index_to_gather(x_shape, idx): x_axis += 1 # Handle slice index (only static, otherwise an error is raised) elif isinstance(i, slice): - if not _all(elt is None or type(elt) is Poly or type(core.get_aval(elt)) is ConcreteArray + if not _all(elt is None or type(core.get_aval(elt)) is ConcreteArray for elt in (i.start, i.stop, i.step)): msg = ("Array slice indices must have static start/stop/step to be used " "with Numpy indexing syntax. Try lax.dynamic_slice/" @@ -3089,20 +3083,12 @@ def _canonicalize_tuple_index(arr_ndim, idx): idx = tuple(idx) + colons return idx -def _slice_indices(idx, size): - # like idx.indices(size), but allows for polymorphic slice and size - assert isinstance(idx, slice) - - step = 1 if idx.step is None else idx.step - start = (size - 1 if step < 0 else 0) if idx.start is None else idx.start + (size if idx.start < 0 else 0) - stop = (-1 if step < 0 else size) if idx.stop is None else idx.stop + (size if idx.stop < 0 else 0) - return start, stop, step def _static_idx(idx, size): """Helper function to compute the static slice start/limit/stride values.""" - start, stop, step = _slice_indices(idx, size) - if (type(start) is not Poly and type(stop) is not Poly and - ((step < 0 and stop >= start) or (step > 0 and start >= stop))): + assert isinstance(idx, slice) + start, stop, step = idx.indices(size) + if (step < 0 and stop >= start) or (step > 0 and start >= stop): return 0, 0, 1, False # sliced to size zero if step > 0: diff --git a/jax/random.py b/jax/random.py index 360aeaf45c28..2f64a088b38a 100644 --- a/jax/random.py +++ b/jax/random.py @@ -43,7 +43,6 @@ from jax.interpreters import partial_eval as pe from jax.interpreters import xla from jax.util import prod -from jax.interpreters.masking import to_index def PRNGKey(seed): @@ -277,7 +276,7 @@ def _random_bits(key, bit_width, shape): # TODO(mattjj): just split the key here raise TypeError("requesting more random bits than a single call provides.") - counts = lax.tie_in(key, lax.iota(onp.uint32, max_count.astype(onp.uint32))) + counts = lax.tie_in(key, lax.iota(onp.uint32, max_count)) bits = threefry_2x32(key, counts) if bit_width == 64: bits = [lax.convert_element_type(x, onp.uint64) for x in np.split(bits, 2)] @@ -290,7 +289,7 @@ def _random_bits(key, bit_width, shape): def _check_shape(name, shape, *param_shapes): try: - shape = tuple(map(to_index, shape)) + shape = tuple(map(int, shape)) except TypeError: msg = "{} requires a concrete tuple of integers as shape argument, got {}." raise ValueError(msg.format(name, shape)) diff --git a/tests/masking_test.py b/tests/masking_test.py index 8d30d58b823c..6ff0aec54e7d 100644 --- a/tests/masking_test.py +++ b/tests/masking_test.py @@ -17,24 +17,24 @@ from unittest import SkipTest import numpy as onp -from absl.testing import absltest, parameterized -from jax.interpreters.masking import shape_as_value, ShapeError, \ - parse_spec, Poly, Mon -from jax import numpy as np, test_util as jtu, mask, vmap, jit, grad, lax, \ - shapecheck, api -from jax.config import config -from jax.scipy.special import expit +from absl.testing import absltest +from absl.testing import parameterized + +from jax import test_util as jtu, core as jc, api +from jax.interpreters.masking import ShapeError, shape_as_value, parse_spec, \ + constant_poly, Mon, Poly, parse_id +from jax import mask, vmap, jit, grad, shapecheck +from jax import lax +import jax.numpy as np +from jax.config import config config.parse_flags_with_absl() -# These are 'manual' tests for masking. The more exhaustive, +# These are 'manual' tests for masking and shape checking. The more exhaustive, # more systematic tests should live in lax_test.py. -def constant_poly(c): - return Poly({Mon(): c}) - -class ShapesTest(jtu.JaxTestCase): +class MaskingTest(jtu.JaxTestCase): @parameterized.parameters([ ['(m, n)', 'ShapeSpec(m, n)'], @@ -53,10 +53,10 @@ class ShapesTest(jtu.JaxTestCase): ['', 'ShapeSpec()'], ['_', 'ShapeSpec(_)'], ]) - def test_parse_spec(self, spec, ans): + def test_shape_parsing(self, spec, ans): self.assertEqual(str(parse_spec(spec)), ans) - def test_Poly_equal(self): + def test_poly_equal(self): assert constant_poly(3) == 3 assert onp.array(3, onp.int64) == constant_poly(3) assert onp.array(3, onp.int64)[()] == constant_poly(3) @@ -70,11 +70,7 @@ def test_Poly_equal(self): assert Poly({Mon(): 3, Mon({'n': 1}): 4}) != Poly({Mon(): 3, Mon({'n': 2}): 4}) assert Poly({Mon(): 3, Mon({'m': 1}): 4}) != Poly({Mon(): 3, Mon({'n': 1}): 4}) - def test_Poly_hash(self): - assert not len(set(hash(Poly({Mon(): i})) for i in range(10))) == 1 - assert hash(Poly({Mon(): 3, Mon({'n': 1}): 4})) == hash(Poly({Mon({'n': 1}): 4, Mon(): 3})) - - def test_Poly_compare(self): + def test_poly_compare(self): poly = Poly({Mon(): 3, Mon({'n': 1}): 4}) # Assume poly > 0 to make various shape rules work with polymorphic shapes: assert poly >= 0 @@ -88,39 +84,39 @@ def test_Poly_compare(self): self.assertRaisesRegex(ValueError, "", lambda: poly >= 2) self.assertRaisesRegex(ValueError, "", lambda: poly > 1) - def test_Poly_divmod(self): - n = Poly({Mon({'n': 1}): 1}) + def test_poly_divmod(self): + n = parse_id('n') assert (n, 1) == divmod(2*n+1, 2) assert (2*n, 0) == divmod(10*n, 5) assert (2*n+4, 3) == divmod(10*n+23, 5) - def test_add_broadcast(self): - @shapecheck(['(m, n)', 'n'], '(m, n)') - @shapecheck(['n', ''], 'n') - def add(a, b): - return a + b + def test_shapecheck_add_broadcast(self): + @shapecheck(['(m, n)', 'n'], '(m, n)') + @shapecheck(['n', ''], 'n') + def add(a, b): + return a + b - def test_sum(self): + def test_shapecheck_sum(self): @shapecheck(['(m, n)'], '') def sum(x): return np.sum(x) - def test_prod(self): + def test_shapecheck_prod(self): @shapecheck(['(m, n)'], '') def prod(x): return np.prod(x) - def test_max(self): + def test_shapecheck_max(self): @shapecheck(['(m, n)'], '') def prod(x): return np.max(x) - def test_min(self): + def test_shapecheck_min(self): @shapecheck(['(m, n)'], '') def prod(x): return np.min(x) - def test_dot(self): + def test_shapecheck_dot(self): @shapecheck(['(m, n)', 'n'], 'm') def matvec(A, b): return np.dot(A, b) @@ -131,12 +127,12 @@ def matvec(A, b): return lax.dot_general(A, b, [((0,), (0,)), ((), ())]) self.assertRaisesRegex(TypeError, "", thunk) - def test_flatten(self): + def test_shapecheck_flatten(self): @shapecheck(['(m, n)'], 'm * n') def flatten(x): return lax.reshape(x, (x.shape[0] * x.shape[1],)) - def test_concatenate(self): + def test_shapecheck_concatenate(self): @shapecheck(['m', 'n', 'm'], '3*m + n') def cat(x, y, z): return lax.concatenate([x, y, x, z], 0) @@ -147,37 +143,30 @@ def cat(x, y, z): return lax.concatenate([x, y, x], 0) self.assertRaisesRegex(ShapeError, "", thunk) - def test_device_put(self): + def test_shapecheck_device_put(self): @shapecheck(['n'], 'n') def d_put(x): return api.device_put(x) - def test_broadcast_in_dim(self): + def test_shapecheck_broadcast_in_dim(self): x = np.zeros((7, 1)) lax.broadcast_in_dim(x, shape=(3, x.shape[0], 4), broadcast_dimensions=(1, 2)) @shapecheck(['(n, 1)'], '(3, n, 4)') def broadcast_in_dim(x): return lax.broadcast_in_dim(x, shape=(3, x.shape[0], 4), broadcast_dimensions=(1, 2)) - def test_jit(self): + def test_shapecheck_jit(self): @shapecheck(['n'], '2*n') @jit def concat(x): return lax.concatenate([x, x], 0) - # TODO: - # @shapecheck(['n'], 'n') - # @jit - # @grad - # def sum_square(x): - # return np.sum(x ** 2) - - def test_pad(self): + def test_shapecheck_pad(self): @shapecheck(['n'], '2*n+1') def p(x): - return lax.pad(x, np.array(0., x.dtype), [(1, 1, 1)]) + return lax.pad(x, 0, [(1, 1, 1)]) - def test_numpy_pad(self): + def test_shapecheck_numpy_pad(self): @shapecheck(['n'], 'n+1') def p(x): return np.pad(x, (0, 1)) @@ -202,8 +191,8 @@ def p(x): if (lhs_dilation is None or not isinstance(padding, str)) and # only test strides with same padding: (strides[0] == 1 or padding == 'SAME'))) - def test_conv(self, strides, padding, lhs_dilation, - dimension_numbers, lhs_perm, rhs_perm, out_perm): + def test_shapecheck_conv(self, strides, padding, lhs_dilation, + dimension_numbers, lhs_perm, rhs_perm, out_perm): valid = padding == 'VALID' is_strided = strides[0] != 1 lhs_shape = '({}, {}, {}, {})'.format(*onp.take(['n', 'i', '2*h' if is_strided else 'h', 'w'], lhs_perm)) @@ -218,22 +207,8 @@ def conv(lhs, rhs): lhs, rhs, strides, padding, lhs_dilation=lhs_dilation, dimension_numbers=dimension_numbers) - def test_indexing(self): - @shapecheck(['n'], '') - def first(x): - return x[0] - - @shapecheck(['n'], '') - def last(x): - return x[-1] - - @shapecheck(['(n,m,a)'], 'n,m') - @vmap - @shapecheck(['(n,a)'], 'n') - def last_column(x): - return x[..., -1] - - def test_slicing(self): + # TODO: + def DISABLED_shapecheck_slicing(self): @shapecheck(['n'], 'n+-1') def slice(x): return x[1:] @@ -242,34 +217,16 @@ def slice(x): def slice(x): return x[:-1] - def test_iota(self): - @shapecheck(['n'], 'n') - def range_like(x): - return lax.iota(np.int32, x.shape[0]) - - def test_arange(self): - @shapecheck(['n'], 'n') - def arange_like(x): - return np.arange(x.shape[0], dtype=np.int32) - - def test_expit(self): - @shapecheck(['n'], 'n') - def expit_(x): - return expit(x) - - def test_reshape(self): - @shapecheck(['n, a, b'], 'n, a*b') - def flatten(x): - return np.reshape(x, (x.shape[0], x.shape[1] * x.shape[2])) + def test_shapecheck_unsupported_op(self): + p = jc.Primitive('unsupported_op') + p.def_impl(lambda x: x) - def test_ravel(self): - a = np.array(1) - - @shapecheck(['n'], '') - def thunk(n): - return -(a + n.ravel()[0] * 0) + def thunk(): + @shapecheck(['n'], 'n') + def identity(x): + return p.bind(x) -class MaskingTest(jtu.JaxTestCase): + self.assertRaisesRegex(NotImplementedError, "Shape rule for unsupported_op not implemented yet.", thunk) def test_sum(self): @partial(mask, in_shapes=['n'], out_shape='')