Support mask of broadcast, consts (arange, zeros, ...), squeeze
juliuskunze committed Sep 17, 2020
1 parent b81c246 commit 90753d7
Showing 4 changed files with 226 additions and 50 deletions.
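As a sketch of what this enables (assuming the experimental jax.mask API of this JAX version; the function and values below are illustrative, not taken from the commit):

  import jax.numpy as jnp
  from jax import mask

  # A constant created from a polymorphic size (arange -> iota), broadcast
  # against the masked input -- both newly supported by this change.
  def weighted_sum(x):
    weights = jnp.arange(x.shape[0])  # polymorphic-size constant
    return jnp.sum(weights * x)       # elementwise multiply, masked sum

  masked = mask(weighted_sum, in_shapes=['n'], out_shape='')
  masked([jnp.ones(8)], dict(n=3))    # padded length 8, logical length 3 -> expect 0+1+2 = 3.0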
27 changes: 21 additions & 6 deletions jax/interpreters/masking.py
@@ -18,14 +18,14 @@
from itertools import chain, product
import operator as op
import string
-from typing import Callable, Dict, Sequence, Union
+from typing import Callable, Dict, Sequence, Union, Generator

import numpy as np

from .. import abstract_arrays
from .. import core, dtypes
from ..tree_util import tree_unflatten
-from ..core import Trace, Tracer
+from ..core import Trace, Tracer, MainTrace, thread_local_state
from ..util import safe_map, safe_zip, unzip2, prod, wrap_name
from ..abstract_arrays import ShapedArray
from .. import linear_util as lu
@@ -75,12 +75,15 @@ def padded_shape_as_value(shape):
assert is_tracing() or not is_polymorphic(shape)
return eval_poly_shape(shape, shape_envs.padded)

+def ensure_padded_shape(shape):
+  return padded_shape_as_value(shape) if is_tracing() else shape

def mask_fun(fun, logical_env, padded_env, in_vals, polymorphic_shapes):
env_keys, padded_env_vals = unzip2(sorted(padded_env.items()))
logical_env_vals = [logical_env[k] for k in env_keys]
# Make padded_env hashable
padded_env = (env_keys, padded_env_vals)
-  with core.new_main(MaskTrace) as main:
+  with core.new_main(MaskTrace, dynamic=True) as main:
fun, out_shapes = mask_subtrace(fun, main, polymorphic_shapes, padded_env)
out_vals = fun.call_wrapped(*(logical_env_vals + in_vals))
del main
@@ -384,6 +387,15 @@ def full_lower(self):
else:
return self

+@contextmanager
+def _suspend_masking() -> Generator[MainTrace, None, None]:
+  stack = thread_local_state.trace_state.trace_stack
+  main = stack.stack[0]
+  prev_dynamic, stack.dynamic = stack.dynamic, main
+  try:
+    yield main
+  finally:
+    stack.dynamic = prev_dynamic

class MaskTrace(Trace):
def pure(self, val):
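A reading of the new helper above (not text from the commit): _suspend_masking temporarily re-installs the bottom-of-stack eval trace as the dynamic trace, so primitives bound while a masking rule or an inner call runs are evaluated directly instead of re-entering MaskTrace and recursing, as in the hunks below:

  with _suspend_masking():
    out = masking_rule(vals, logical_shapes, **params)  # inner binds skip MaskTrace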
@@ -405,7 +417,8 @@ def process_primitive(self, primitive, tracers, params):
logical_shapes = map(shape_as_value, polymorphic_shapes)
# TODO(mattjj): generalize mask rule signature
if primitive.name == 'reshape': params['polymorphic_shapes'] = polymorphic_shapes
-    out = masking_rule(vals, logical_shapes, **params)
+    with _suspend_masking():
+      out = masking_rule(vals, logical_shapes, **params)
if primitive.multiple_results:
return map(partial(MaskTracer, self), out, (o.shape for o in out_aval))
else:
@@ -416,7 +429,8 @@ def process_call(self, call_primitive, f, tracers, params):
params = dict(params, name=wrap_name(params.get('name', f.__name__), 'mask'))
vals, shapes = unzip2((t.val, t.polymorphic_shape) for t in tracers)
if not any(is_polymorphic(s) for s in shapes):
-      return call_primitive.bind(f, *vals, **params)
+      with _suspend_masking():
+        return call_primitive.bind(f, *vals, **params)
else:
logical_env, padded_env = shape_envs
env_keys, padded_env_vals = unzip2(sorted(padded_env.items()))
@@ -427,7 +441,8 @@
if 'donated_invars' in params:
params = dict(params, donated_invars=((False,) * len(logical_env_vals) +
params['donated_invars']))
-      vals_out = call_primitive.bind(f, *(logical_env_vals + vals), **params)
+      with _suspend_masking():
+        vals_out = call_primitive.bind(f, *(logical_env_vals + vals), **params)
return [MaskTracer(self, v, s) for v, s in zip(vals_out, shapes_out())]

def post_process_call(self, call_primitive, out_tracers, params):
155 changes: 130 additions & 25 deletions jax/lax/lax.py
@@ -1412,11 +1412,8 @@ def iota(dtype: DType, size: int) -> Array:
operator.
"""
size = size if type(size) is masking.Poly else int(size)
-  shape = canonicalize_shape((size,))
-  dtype = dtypes.canonicalize_dtype(dtype)
-  lazy_expr = lazy.iota(dtype, shape[0])
-  aval = ShapedArray(shape, dtype)
-  return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
+  return iota_p.bind(dtype=dtypes.canonicalize_dtype(dtype),
+                     size=_canonicalize_dimension(size))

def broadcasted_iota(dtype: DType, shape: Shape, dimension: int) -> Array:
"""Convenience wrapper around ``iota``."""
@@ -1431,35 +1428,26 @@ def _eye(dtype: DType, shape: Shape, offset: int) -> Array:
This function exists for creating lazy identity matrices; that is,
materialization of the array is delayed and it may be fused into consumers to
avoid materialization at all."""
-  N, M = tuple(map(int, shape))
-  offset = int(offset)
+  N, M = canonicalize_shape(shape)
  dtype = dtypes.canonicalize_dtype(dtype)
-  lazy_expr = lazy.eye(dtype, (N, M), offset)
-  aval = ShapedArray((N, M), dtype)
-  return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
+  return eye_p.bind(dtype=dtype, shape=(N, M), offset=int(offset))

def _delta(dtype: DType, shape: Shape, axes: Sequence[int]) -> Array:
"""This function exists for creating lazy Kronecker delta arrays, particularly
for use in jax.numpy.einsum to express traces. It differs from ``eye`` in that
it can create arrays of any rank, but doesn't allow offsets."""
-  shape = tuple(map(int, shape))
+  shape = canonicalize_shape(shape)
  axes = tuple(map(int, axes))
  dtype = dtypes.canonicalize_dtype(dtype)
-  base_shape = tuple(np.take(shape, axes))
-  lazy_expr = lazy.broadcast(lazy.delta(dtype, base_shape), shape, axes)
-  aval = ShapedArray(shape, dtype)
-  return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
+  return delta_p.bind(dtype=dtype, shape=shape, axes=axes)

def _tri(dtype: DType, shape: Shape, offset: int) -> Array:
"""Like numpy.tri, create a 2D array with ones below a diagonal.
This function exists for creating lazy triangular matrices, particularly for
use in jax.numpy.tri."""
-  N, M = tuple(map(int, shape))
-  offset = int(offset)
+  N, M = canonicalize_shape(shape)
  dtype = dtypes.canonicalize_dtype(dtype)
-  lazy_expr = lazy.tri(dtype, (N, M), offset)
-  aval = ShapedArray((N, M), dtype)
-  return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
+  return tri_p.bind(dtype=dtype, shape=(N, M), offset=int(offset))

def stop_gradient(x):
"""Stops gradient computation.
@@ -3055,7 +3043,9 @@ def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions):
msg = ('broadcast_in_dim broadcast_dimensions must be a subset of output '
'dimensions, got {} for operand ndim {} and shape {}.')
raise TypeError(msg.format(broadcast_dimensions, operand_ndim, shape))
-  if any(operand.shape[i] != 1 and operand.shape[i] != shape[broadcast_dimensions[i]]
+  op_shape = masking.ensure_padded_shape(np.shape(operand))
+  shape_ = masking.ensure_padded_shape(shape)
+  if any(op_shape[i] != 1 and op_shape[i] != shape_[broadcast_dimensions[i]]
for i in range(operand_ndim)):
msg = ('broadcast_in_dim operand dimension sizes must either be 1, or be '
'equal to their corresponding dimensions in the target broadcast shape; '
@@ -3083,12 +3073,43 @@ def _broadcast_in_dim_batch_rule(batched_args, batch_dims, *, shape,
new_broadcast_dimensions = (0,) + tuple(np.add(1, broadcast_dimensions))
return broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions), 0

+def _broadcast_abstract_eval(operand, shape, broadcast_dimensions):
+  if masking.is_polymorphic(shape) and type(operand) is ConcreteArray:
+    operand = ShapedArray(operand.shape, operand.dtype)

+  return standard_abstract_eval(
+    broadcast_in_dim_p, _broadcast_in_dim_shape_rule, _input_dtype,
+    operand, shape=shape, broadcast_dimensions=broadcast_dimensions)

+def _broadcast_in_dim_masking_rule(padded_vals, logical_shapes,
+                                   shape, broadcast_dimensions):
+  padded_val, = padded_vals
+  return broadcast_in_dim(padded_val,
+                          shape=masking.padded_shape_as_value(shape),
+                          broadcast_dimensions=broadcast_dimensions)

+def _broadcast_in_dim_translation_rule(c, operand, shape, broadcast_dimensions):
+  if masking.is_polymorphic(shape):
+    # TODO Unlike many other primitives, broadcast_in_dim bakes the shape
+    # parameter into its XlaOp, causing failure on polymorphic shapes.
+    # The fix requires a mechanism to translate polymorphic jaxprs; for now, fail:
+    raise NotImplementedError(
+      "mask(jit(broadcast_in_dim)) is not supported yet. "
+      "Consider using jit(mask(broadcast_in_dim)) instead. "
+      "If you are using np.where, consider disabling jit on jax.lax._where or "
+      "manually broadcasting arguments to the same shape.")
+  return standard_translate(
+    'broadcast_in_dim', c, operand, shape=shape,
+    broadcast_dimensions=broadcast_dimensions)

broadcast_in_dim_p = standard_primitive(
_broadcast_in_dim_shape_rule, _input_dtype, 'broadcast_in_dim')
broadcast_in_dim_p.def_impl(_broadcast_in_dim_impl)
ad.deflinear(broadcast_in_dim_p, _broadcast_in_dim_transpose_rule)
batching.primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule
+broadcast_in_dim_p.def_abstract_eval(_broadcast_abstract_eval)
+masking.masking_rules[broadcast_in_dim_p] = _broadcast_in_dim_masking_rule
+xla.translations[broadcast_in_dim_p] = _broadcast_in_dim_translation_rule


def _clamp_shape_rule(min, operand, max):
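Two notes on the broadcast_in_dim changes above, plus an illustrative sketch (the function f and the shape specs are hypothetical, assuming the experimental jax.mask API): the abstract-eval rule demotes a ConcreteArray operand because an output with polymorphic shape cannot be constant-folded, and the translation rule rejects polymorphic shapes because the shape is baked into the XlaOp, so masking has to happen outside of jit:

  f = lambda x: lax.broadcast_in_dim(x, shape=(2, x.shape[0]), broadcast_dimensions=(1,))
  mask(jit(f), in_shapes=['n'], out_shape='(2, n)')   # NotImplementedError at lowering
  jit(mask(f, in_shapes=['n'], out_shape='(2, n)'))   # works: mask runs before jit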
@@ -3305,10 +3326,15 @@ def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions):
dimensions = tuple(np.add(1, dimensions))
return squeeze(operand, dimensions=dimensions), 0

+def _squeeze_masking_rule(padded_vals, logical_shapes, *, dimensions):
+  padded_val, = padded_vals
+  return squeeze(padded_val, dimensions=dimensions)

squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule,
'squeeze', _squeeze_translation_rule)
ad.deflinear2(squeeze_p, _squeeze_transpose_rule)
batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule
+masking.masking_rules[squeeze_p] = _squeeze_masking_rule


def expand_dims(array: Array, dimensions: Tuple[int, ...]) -> Array:
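Why simply forwarding the padded value suffices for the squeeze rule above (a reading of the change): squeeze only removes dimensions of size 1, and a size-1 dimension is the same in the padded and logical shapes, so re-applying squeeze with the same dimensions on the padded value is all the rule needs to do.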
@@ -3477,13 +3503,16 @@ def _transpose_masking_rule(padded_vals, logical_shapes, permutation):


def _select_shape_rule(pred, on_true, on_false):
-  if on_true.shape != on_false.shape:
+  pred_shape = masking.ensure_padded_shape(pred.shape)
+  t_shape = masking.ensure_padded_shape(on_true.shape)
+  f_shape = masking.ensure_padded_shape(on_false.shape)
+  if t_shape != f_shape:
    msg = "select on_true and on_false must have the same shape, got {} and {}."
-    raise TypeError(msg.format(on_true.shape, on_false.shape))
-  if pred.shape and pred.shape != on_true.shape:
+    raise TypeError(msg.format(t_shape, f_shape))
+  if pred_shape and pred_shape != t_shape:
    msg = ("select pred must be scalar or have the same shape as on_true and "
           "on_false, got pred shape {} for on_true and on_false of shape {}.")
-    raise TypeError(msg.format(pred.shape, on_true.shape))
+    raise TypeError(msg.format(pred_shape, t_shape))
return on_true.shape

def _select_dtype_rule(pred, on_true, on_false):
@@ -5817,6 +5846,79 @@ def _rng_uniform_translation_rule(c, a, b, *, shape):
rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval)
xla.translations[rng_uniform_p] = _rng_uniform_translation_rule

+def _iota_impl(dtype: DType, size: int) -> Array:
+  aval = ShapedArray((size,), dtype)
+  lazy_expr = lazy.iota(dtype, size)
+  return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())

+def _iota_abstract_eval(dtype: DType, size: int) -> ShapedArray:
+  if masking.is_polymorphic((size,)): return ShapedArray((size,), dtype)
+  return ConcreteArray(_iota_impl(dtype, size))

+def _iota_masking_rule(_, __, dtype: DType, size: int) -> Array:
+  return _iota_impl(dtype, masking.padded_shape_as_value((size,))[0])

+iota_p = Primitive('iota')
+iota_p.def_impl(_iota_impl)
+iota_p.def_abstract_eval(_iota_abstract_eval)
+xla.translations[iota_p] = xla.lower_fun(_iota_impl, multiple_results=False)
+masking.masking_rules[iota_p] = _iota_masking_rule

+def _eye_impl(dtype: DType, shape: Shape, offset: int) -> Array:
+  aval = ShapedArray(shape, dtype)
+  lazy_expr = lazy.eye(dtype, shape, offset)
+  return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())

+def _eye_abstract_eval(dtype: DType, shape: Shape, offset: int) -> ShapedArray:
+  if masking.is_polymorphic(shape): return ShapedArray(shape, dtype)
+  return ConcreteArray(_eye_impl(dtype, shape, offset))

+def _eye_masking_rule(_, __, dtype: DType, shape: Shape, offset: int) -> Array:
+  return _eye_impl(dtype, masking.padded_shape_as_value(shape), offset)

+eye_p = Primitive('eye')
+eye_p.def_impl(_eye_impl)
+eye_p.def_abstract_eval(_eye_abstract_eval)
+xla.translations[eye_p] = xla.lower_fun(_eye_impl, multiple_results=False)
+masking.masking_rules[eye_p] = _eye_masking_rule

+def _delta_impl(dtype: DType, shape: Shape, axes: Sequence[int]) -> Array:
+  aval = ShapedArray(shape, dtype)
+  lazy_delta = lazy.delta(dtype, tuple(np.take(shape, axes)))
+  lazy_expr = lazy.broadcast(lazy_delta, shape, axes)
+  return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())

+def _delta_abstract_eval(dtype: DType, shape: Shape, axes: Sequence[int]) -> ShapedArray:
+  if masking.is_polymorphic(shape): return ShapedArray(shape, dtype)
+  return ConcreteArray(_delta_impl(dtype, shape, axes))

+def _delta_masking_rule(_, __, dtype: DType, shape: Shape, axes: Sequence[int]) -> Array:
+  return _delta_impl(dtype, masking.padded_shape_as_value(shape), axes)

+delta_p = Primitive('delta')
+delta_p.def_impl(_delta_impl)
+delta_p.def_abstract_eval(_delta_abstract_eval)
+xla.translations[delta_p] = xla.lower_fun(_delta_impl, multiple_results=False)
+masking.masking_rules[delta_p] = _delta_masking_rule

+def _tri_impl(dtype: DType, shape: Shape, offset: int) -> Array:
+  aval = ShapedArray(shape, dtype)
+  lazy_expr = lazy.tri(dtype, shape, offset)
+  return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())

+def _tri_abstract_eval(dtype: DType, shape: Shape, offset: int) -> ShapedArray:
+  if masking.is_polymorphic(shape): return ShapedArray(shape, dtype)
+  return ConcreteArray(_tri_impl(dtype, shape, offset))

+def _tri_masking_rule(_, __, dtype: DType, shape: Shape, offset: int) -> Array:
+  return _tri_impl(dtype, masking.padded_shape_as_value(shape), offset)

+tri_p = Primitive('tri')
+tri_p.def_impl(_tri_impl)
+tri_p.def_abstract_eval(_tri_abstract_eval)
+xla.translations[tri_p] = xla.lower_fun(_tri_impl, multiple_results=False)
+masking.masking_rules[tri_p] = _tri_masking_rule

### util

_ndim = np.ndim
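All four constant primitives above follow one pattern: for static shapes the abstract-eval rule folds to a lazy ConcreteArray (a DeviceConstant consumers can still fuse away), while polymorphic shapes stay abstract as a ShapedArray until mask() supplies padded sizes through the masking rule. A sketch of the two paths for iota (illustrative only):

  lax.iota(np.int32, 5)   # static: ConcreteArray wrapping a lazy DeviceConstant
  # under mask with polymorphic size n: _iota_abstract_eval -> ShapedArray((n,), int32),
  # then _iota_masking_rule builds the iota at the padded size.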
@@ -5834,6 +5936,9 @@ def _dilate_shape(shape, dilation):
def _ceil_divide(x1, x2):
return -np.floor_divide(np.negative(x1), x2)

+def _ceil(x):
+  return _ceil_divide(x, 1)

def padtype_to_pads(in_shape, window_shape, window_strides, padding):
"""Convert padding string to list of pairs of pad values."""
PaddingType = xla_client.PaddingType
6 changes: 3 additions & 3 deletions jax/numpy/lax_numpy.py
@@ -2417,9 +2417,9 @@ def arange(start, stop=None, step=None, dtype=None):
require = partial(core.concrete_or_error, _np_asarray)
msg = "It arose in jax.numpy.arange argument `{}`.".format
if stop is None and step is None:
-    start = require(start, msg("stop"))
-    dtype = dtype or _dtype(start)
-    return lax.iota(dtype, np.ceil(start))  # avoids materializing
+    start_ = require(start, msg("stop"))
+    dtype = dtype or (int64 if type(start) is Poly else _dtype(start_))
+    return lax.iota(dtype, lax.lax._ceil(start_))  # avoids materializing
else:
start = require(start, msg("start"))
stop = None if stop is None else require(stop, msg("stop"))
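A sketch of the effect (assuming the experimental jax.mask API): on a polymorphic size, jax.numpy.arange now takes the Poly branch above -- the dtype defaults to int64 because _dtype cannot inspect a Poly, and lax.lax._ceil (defined via _ceil_divide so it also accepts Poly values, unlike np.ceil) keeps the size symbolic for the lazy lax.iota:

  def positions(x):
    return jnp.arange(x.shape[0])  # symbolic size n, nothing materialized

  mask(positions, in_shapes=['n'], out_shape='n')([jnp.ones(8)], dict(n=3))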