Support mask of broadcast, consts (arange, zeros, ...), where
juliuskunze committed Jun 19, 2020
1 parent 3b4a123 commit 8945e77
Showing 6 changed files with 210 additions and 49 deletions.
2 changes: 1 addition & 1 deletion jax/api.py
@@ -1326,7 +1326,7 @@ def shapecheck(in_shapes, out_shape, fun: Callable):
out_specs = map(masking.parse_spec, out_specs)
flat_fun, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
avals = map(partial(ShapedArray, dtype=onp.float32), in_shapes)
out_shapes = [o.shape for o in pe.abstract_eval_fun(flat_fun.call_wrapped, *avals)]
out_shapes = [o.shape for o in pe.abstract_eval_fun(flat_fun.call_wrapped, *avals, trace_type=masking.PolymorphicJaxprTrace)]
masking.check_shapes(map(tuple, out_specs), out_spec_tree,
map(tuple, out_shapes), out_tree_thunk())
return fun
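
With the polymorphic trace type threaded through, shapecheck can abstractly evaluate functions whose intermediate shapes contain symbolic sizes. An illustrative sketch of the intended usage, assuming the shapecheck API of this era (not code from this commit):

    import jax.numpy as np
    from jax import shapecheck

    # The output spec '2*n' is checked against the symbolic input size 'n'.
    shapecheck(['n'], '2*n', lambda x: np.concatenate([x, x]))

    # With this commit, constant-creating ops of polymorphic shape
    # (np.zeros, np.arange, np.eye, ...) inside the checked function also work.
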
27 changes: 26 additions & 1 deletion jax/interpreters/masking.py
@@ -18,7 +18,7 @@
from itertools import chain, product
import operator as op
import string
from typing import Callable, Dict, Sequence, Union
from typing import Callable, Dict, Sequence, Union, Tuple

import numpy as onp

@@ -29,6 +29,7 @@
from ..util import safe_map, safe_zip, unzip2, prod, wrap_name
from ..abstract_arrays import ShapedArray
from .. import linear_util as lu
from ..interpreters.partial_eval import JaxprTrace

map = safe_map
zip = safe_zip
@@ -75,6 +76,9 @@ def padded_shape_as_value(shape):
assert is_tracing() or not is_polymorphic(shape)
return eval_polymorphic_shape(shape, shape_envs.padded)

def ensure_padded_shape(shape):
return padded_shape_as_value(shape) if is_tracing() else shape

def mask_fun(fun, logical_env, padded_env, in_vals, polymorphic_shapes):
env_keys, padded_env_vals = unzip2(sorted(padded_env.items()))
logical_env_vals = [logical_env[k] for k in env_keys]
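
For orientation: masking runs a function on padded buffers while carrying a logical-size environment, and padded_shape_as_value resolves symbolic sizes against the padded environment; the new ensure_padded_shape simply returns the given shape unchanged when no mask trace is active. A rough sketch of the surrounding jax.mask API of this era (illustrative, not code from this commit):

    import jax.numpy as np
    from jax import mask

    # Sum only the logical prefix of a padded length-5 buffer.
    padded = np.array([1., 2., 3., 0., 0.])
    mask(np.sum, ['n'], '')([padded], dict(n=3))   # == 6.0
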
@@ -391,6 +395,9 @@ def lift(self, val):
def sublift(self, val):
return MaskTracer(self, val.val, val.polymorphic_shape)

def new_instantiated_const(self, val):
return MaskTracer(self, val, onp.shape(val))

def process_primitive(self, primitive, tracers, params):
masking_rule = masking_rules.get(primitive)
if masking_rule is None:
@@ -476,3 +483,21 @@ def check_shapes(specs, spec_tree, shapes, tree, message_prefix="Output"):
specs = tree_unflatten(spec_tree, specs)
shapes = tree_unflatten(tree, shapes)
raise ShapeError(f"{message_prefix} shapes should be {specs} but are {shapes}.")

class PolymorphicJaxprTrace(JaxprTrace):
pass

polymorphic_trace_types: Tuple = (PolymorphicJaxprTrace, MaskTrace)

# TODO(mattjj): We should be able to remove this using omnistaging:
def ensure_traced(operand):
if isinstance(operand, Tracer):
return operand

def has_poly_trace(master):
return issubclass(master.trace_type, polymorphic_trace_types)

masters = reversed(core.trace_state.trace_stack.upward)
master = next(filter(has_poly_trace, masters))
trace = master.trace_type(master, core.cur_sublevel())
return trace.new_instantiated_const(operand)
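
ensure_traced lifts a concrete constant onto the innermost polymorphic trace (a MaskTrace, or the new PolymorphicJaxprTrace used by shapecheck), so that binding a primitive with a polymorphic shape parameter goes through abstract evaluation or a masking rule rather than the concrete impl. A sketch of why this matters, again assuming the jax.mask API of this era (not code from this commit):

    import jax.numpy as np
    from jax import mask

    # np.zeros(x.shape) broadcasts a concrete scalar to a polymorphic shape;
    # ensure_traced lifts that scalar onto the mask trace so broadcast_in_dim
    # is handled by the masking machinery instead of its concrete impl.
    f = mask(lambda x: x + np.zeros(x.shape), ['n'], 'n')
    f([np.array([1., 2., 0.])], dict(n=2))
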
4 changes: 2 additions & 2 deletions jax/interpreters/partial_eval.py
@@ -346,10 +346,10 @@ def partial_eval_wrapper(avals: Sequence[Optional[AbstractValue]], *consts):
staged_out_calls: Set[core.Primitive] = set()


def abstract_eval_fun(fun, *avals, **params):
def abstract_eval_fun(fun, *avals, trace_type: Optional[Type[Trace]] = None, **params):
pvals_in = [PartialVal.unknown(a) for a in avals]
_, pvals_out, _ = trace_to_jaxpr(lu.wrap_init(fun, params), pvals_in,
instantiate=True, stage_out=True)
instantiate=True, stage_out=True, trace_type=trace_type)
avals_out, _ = unzip2(pvals_out)
for aval_out in avals_out:
assert isinstance(aval_out, AbstractValue) # instantiate=True
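
The new trace_type keyword is only threaded through to trace_to_jaxpr, so callers such as shapecheck can stage out under a trace class that the masking machinery recognizes as polymorphic. A minimal call sketch (illustrative only; the traced function and avals are made up):

    import numpy as onp
    from jax.abstract_arrays import ShapedArray
    from jax.interpreters import masking, partial_eval as pe

    out_avals = pe.abstract_eval_fun(lambda x: [x + 1.],
                                     ShapedArray((3, 4), onp.float32),
                                     trace_type=masking.PolymorphicJaxprTrace)
    # out_avals[0].shape == (3, 4)
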
148 changes: 121 additions & 27 deletions jax/lax/lax.py
@@ -22,14 +22,15 @@

import numpy as onp

from .. import source_info_util
from .. import core
from .. import ad_util
from .. import api
from .. import linear_util as lu
from .. import dtypes
from .. import lazy
from ..config import flags
from ..core import Primitive, _canonicalize_dimension
from ..core import Primitive, unit, _canonicalize_dimension
from ..abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray, array_types,
raise_to_shaped, abstract_token, canonicalize_shape)
from ..interpreters import partial_eval as pe
@@ -39,7 +40,7 @@
from ..interpreters import invertible_ad as iad
from ..interpreters import batching
from ..interpreters import masking
from ..util import cache, safe_zip, partial, prod, safe_map
from ..util import cache, safe_zip, partial, prod, safe_map, curry
from ..tree_util import tree_map
from ..lib import pytree
from ..lib import xla_bridge
@@ -671,6 +672,8 @@ def broadcast_in_dim(operand: Array, shape: Shape,
operand, shape=shape, broadcast_dimensions=broadcast_dimensions)
if onp.ndim(operand) == len(shape) and not len(broadcast_dimensions):
return operand
if masking.is_polymorphic(shape):
operand = masking.ensure_traced(operand)
return broadcast_in_dim_p.bind(
operand, shape=tuple(shape),
broadcast_dimensions=tuple(broadcast_dimensions))
@@ -1258,17 +1261,65 @@ def full(shape: Shape, fill_value: Array, dtype: Optional[DType] = None) -> Array:
fill_value = xla.device_put_p.bind(convert_element_type(fill_value, dtype))
return broadcast(fill_value, shape)

@curry
def _jaxpr_process_primitive_without_lowering(prim, trace, *tracers, **params):
# Like JaxprTrace.process_primitive but without lowering out of the trace.
avals = [t.aval for t in tracers]
out_aval = prim.abstract_eval(*avals, **params)
out_tracer = pe.JaxprTracer(trace, pe.PartialVal((out_aval, unit)), None)
out_tracer.recipe = pe.new_eqn_recipe(tracers, [out_tracer], prim, params,
source_info_util.current())
return out_tracer

def lazy_primitive_bind(name: str, fun: Callable[..., lazy.LazyExpr]) -> Callable:
def shape(kwargs):
return kwargs['shape'] if 'shape' in kwargs else (kwargs['size'],)

def impl(**kwargs):
aval = ShapedArray(shape(kwargs), kwargs['dtype'])
lazy_expr = fun(**kwargs)
return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())

def abstract_eval(**kwargs):
if masking.is_polymorphic(shape(kwargs)):
return ShapedArray(shape(kwargs), kwargs['dtype'])
return ConcreteArray(impl(**kwargs))

def masking_rule(padded_vals, logical_shapes, **kwargs):
padded_shape = masking.padded_shape_as_value(shape(kwargs))
if 'shape' in kwargs:
kwargs['shape'] = padded_shape
else:
padded_size, = padded_shape
kwargs['size'] = padded_size
return impl(**kwargs)

p = Primitive(name)
# For constant primitives of polymorphic shape, bind cannot fall back to impl.
# Instead, we need to trigger abstract_eval, returning an abstract value.
# This is achieved by providing a dummy arg wrapped in a masking trace and
# not lowering out of the trace:
def bind(**kwargs):
dummy = onp.array(0)
if masking.is_polymorphic(shape(kwargs)):
dummy = masking.ensure_traced(dummy)
return p.bind(dummy, **kwargs)
p.def_impl(lambda dummy, **kwargs: impl(**kwargs))
p.def_abstract_eval(lambda dummy, **kwargs: abstract_eval(**kwargs))
pe.custom_partial_eval_rules[p] = _jaxpr_process_primitive_without_lowering(p)
masking.masking_rules[p] = masking_rule
return bind
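
lazy_primitive_bind wraps a lazy-constant constructor in a primitive with three behaviors: a concrete shape yields a lazy device constant via impl, a polymorphic shape stays abstract (the dummy tracer routes bind through abstract_eval or the custom partial-eval rule), and under mask the masking rule rebuilds the constant at the padded shape. An illustrative contrast (not code from this commit):

    import numpy as onp
    from jax import lax

    lax.iota(onp.int32, 5)   # concrete size: a lazy DeviceArray constant, as before

    # With a polymorphic size (e.g. x.shape[0] inside mask or shapecheck), bind
    # wraps a dummy argument in the enclosing polymorphic trace, so abstract_eval
    # returns a ShapedArray instead of falling back to the concrete impl.
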

iota_bind = lazy_primitive_bind('iota', lazy.iota)

def iota(dtype: DType, size: int) -> Array:
"""Wraps XLA's `Iota
<https://www.tensorflow.org/xla/operation_semantics#iota>`_
operator.
"""
size = size if type(size) is masking.Poly else int(size)
shape = canonicalize_shape((size,))
dtype = dtypes.canonicalize_dtype(dtype)
lazy_expr = lazy.iota(dtype, shape[0])
aval = ShapedArray(shape, dtype)
return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
return iota_bind(dtype=dtypes.canonicalize_dtype(dtype),
size=_canonicalize_dimension(size))

def broadcasted_iota(dtype: DType, shape: Shape, dimension: int) -> Array:
"""Convenience wrapper around ``iota``."""
@@ -1277,41 +1328,39 @@ def broadcasted_iota(dtype: DType, shape: Shape, dimension: int) -> Array:
dimension = int(dimension)
return broadcast_in_dim(iota(dtype, shape[dimension]), shape, [dimension])

eye_bind = lazy_primitive_bind('eye', lazy.eye)

def _eye(dtype: DType, shape: Shape, offset: int) -> Array:
"""Like numpy.eye, create a 2D array with ones on a diagonal.
This function exists for creating lazy identity matrices; that is,
materialization of the array is delayed and it may be fused into consumers to
avoid materialization at all."""
N, M = tuple(map(int, shape))
offset = int(offset)
N, M = canonicalize_shape(shape)
dtype = dtypes.canonicalize_dtype(dtype)
lazy_expr = lazy.eye(dtype, (N, M), offset)
aval = ShapedArray((N, M), dtype)
return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
return eye_bind(dtype=dtype, shape=(N, M), offset=int(offset))

delta_bind = lazy_primitive_bind('delta', lambda dtype, shape, axes:
lazy.broadcast(lazy.delta(dtype, tuple(onp.take(shape, axes))), shape, axes))

def _delta(dtype: DType, shape: Shape, axes: Sequence[int]) -> Array:
"""This function exists for creating lazy Kronecker delta arrays, particularly
for use in jax.numpy.einsum to express traces. It differs from ``eye`` in that
it can create arrays of any rank, but doesn't allow offsets."""
shape = tuple(map(int, shape))
shape = canonicalize_shape(shape)
axes = tuple(map(int, axes))
dtype = dtypes.canonicalize_dtype(dtype)
base_shape = tuple(onp.take(shape, axes))
lazy_expr = lazy.broadcast(lazy.delta(dtype, base_shape), shape, axes)
aval = ShapedArray(shape, dtype)
return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
return delta_bind(dtype=dtype, shape=shape, axes=axes)

tri_bind = lazy_primitive_bind('tri', lazy.tri)

def _tri(dtype: DType, shape: Shape, offset: int) -> Array:
"""Like numpy.tri, create a 2D array with ones below a diagonal.
This function exists for creating lazy triangular matrices, particularly for
use in jax.numpy.tri."""
N, M = tuple(map(int, shape))
offset = int(offset)
N, M = canonicalize_shape(shape)
dtype = dtypes.canonicalize_dtype(dtype)
lazy_expr = lazy.tri(dtype, (N, M), offset)
aval = ShapedArray((N, M), dtype)
return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
return tri_bind(dtype=dtype, shape=(N, M), offset=int(offset))
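
The same pattern makes the other lazy constants (eye, delta, tri) usable at polymorphic sizes. A sketch of the intent with shapecheck (illustrative, not a test from this commit):

    import jax.numpy as np
    from jax import shapecheck

    # np.eye(x.shape[0]) stays abstract under the polymorphic trace.
    shapecheck(['(n, n)'], '(n, n)', lambda x: x * np.eye(x.shape[0]))
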

def stop_gradient(x):
"""Stops gradient computation.
@@ -2795,7 +2844,9 @@ def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions):
msg = ('broadcast_in_dim broadcast_dimensions must be a subset of output '
'dimensions, got {} for operand ndim {} and shape {}.')
raise TypeError(msg.format(broadcast_dimensions, operand_ndim, shape))
if any(operand.shape[i] != 1 and operand.shape[i] != shape[broadcast_dimensions[i]]
op_shape = masking.ensure_padded_shape(onp.shape(operand))
shape_ = masking.ensure_padded_shape(shape)
if any(op_shape[i] != 1 and op_shape[i] != shape_[broadcast_dimensions[i]]
for i in range(operand_ndim)):
msg = ('broadcast_in_dim operand dimension sizes must either be 1, or be '
'equal to their corresponding dimensions in the target broadcast shape; '
Expand Down Expand Up @@ -2824,11 +2875,48 @@ def _broadcast_in_dim_batch_rule(batched_args, batch_dims, *, shape,
return broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions), 0


def _broadcast_abstract_eval(operand, shape, broadcast_dimensions):
if masking.is_polymorphic(shape) and type(operand) is ConcreteArray:
operand = ShapedArray(operand.shape, operand.dtype)

return standard_abstract_eval(
broadcast_in_dim_p, _broadcast_in_dim_shape_rule, _input_dtype,
operand, shape=shape, broadcast_dimensions=broadcast_dimensions)

def _broadcast_in_dim_masking_rule(padded_vals, logical_shapes, shape, broadcast_dimensions):
operand, = padded_vals
return broadcast_in_dim(operand,
shape=masking.padded_shape_as_value(shape),
broadcast_dimensions=broadcast_dimensions)

def _broadcast_in_dim_translation_rule(c, operand, shape, broadcast_dimensions):
if masking.is_polymorphic(shape):
# TODO Unlike many other primitives, broadcast_in_dim bakes the shape
# parameter into its XlaOp, causing failure on polymorphic shapes.
# A fix requires a mechanism to translate polymorphic jaxprs; for now, fail:
raise NotImplementedError(
"mask(jit(broadcast_in_dim)) is not supported yet. "
"Consider using jit(mask(broadcast_in_dim)) instead. "
"If you are using np.where, consider disabling jit on jax.lax._where or "
"manually broadcasting arguments to the same shape.")
return standard_translate(
'broadcast_in_dim', c, operand, shape=shape,
broadcast_dimensions=broadcast_dimensions)

broadcast_in_dim_p = standard_primitive(
_broadcast_in_dim_shape_rule, _input_dtype, 'broadcast_in_dim')
broadcast_in_dim_p.def_impl(_broadcast_in_dim_impl)
ad.deflinear(broadcast_in_dim_p, _broadcast_in_dim_transpose_rule)
batching.primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule
broadcast_in_dim_p.def_abstract_eval(_broadcast_abstract_eval)
masking.masking_rules[broadcast_in_dim_p] = _broadcast_in_dim_masking_rule
xla.translations[broadcast_in_dim_p] = _broadcast_in_dim_translation_rule

def _broadcast_in_dim_process_primitive(trace, *tracers, **params):
if masking.is_polymorphic(params['shape']):
return _jaxpr_process_primitive_without_lowering(broadcast_in_dim_p)(trace, *tracers, **params)
return trace.default_process_primitive(broadcast_in_dim_p, tracers, params)
pe.custom_partial_eval_rules[broadcast_in_dim_p] = _broadcast_in_dim_process_primitive
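
Because broadcast_in_dim's XLA translation bakes the target shape into the op, a polymorphic broadcast can be masked but not yet staged through jit inside the mask; the custom partial-eval rule above keeps such equations abstract only when the shape is polymorphic. An illustrative contrast matching the error message, assuming the jax.mask API of this era:

    import jax.numpy as np
    from jax import jit, mask

    f = lambda x: x + np.zeros(x.shape)
    jit(mask(f, ['n'], 'n'))([np.array([1., 2., 0.])], dict(n=2))   # supported
    # mask(jit(f), ['n'], 'n')(...)   # raises NotImplementedError for now
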


def _clamp_shape_rule(min, operand, max):
@@ -3252,13 +3340,16 @@ def _transpose_masking_rule(padded_vals, logical_shapes, permutation):


def _select_shape_rule(pred, on_true, on_false):
if on_true.shape != on_false.shape:
pred_shape = masking.ensure_padded_shape(pred.shape)
t_shape = masking.ensure_padded_shape(on_true.shape)
f_shape = masking.ensure_padded_shape(on_false.shape)
if t_shape != f_shape:
msg = "select on_true and on_false must have the same shape, got {} and {}."
raise TypeError(msg.format(on_true.shape, on_false.shape))
if pred.shape and pred.shape != on_true.shape:
raise TypeError(msg.format(t_shape, f_shape))
if pred_shape and pred_shape != t_shape:
msg = ("select pred must be scalar or have the same shape as on_true and "
"on_false, got pred shape {} for on_true and on_false of shape {}.")
raise TypeError(msg.format(pred.shape, on_true.shape))
raise TypeError(msg.format(pred_shape, t_shape))
return on_true.shape

def _select_dtype_rule(pred, on_true, on_false):
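
With select's shape rule comparing padded shapes under masking, np.where now composes with polymorphic operands (the scalar branch is broadcast through the machinery above). A short usage sketch, assuming the jax.mask API of this era (not code from this commit):

    import jax.numpy as np
    from jax import mask

    relu = mask(lambda x: np.where(x > 0, x, 0.), ['n'], 'n')
    relu([np.array([1., -2., 3., 0., 0.])], dict(n=3))
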
@@ -5238,6 +5329,9 @@ def _dilate_shape(shape, dilation):
def _ceil_divide(x1, x2):
return -onp.floor_divide(onp.negative(x1), x2)

def _ceil(x):
return _ceil_divide(x, 1)

def padtype_to_pads(in_shape, window_shape, window_strides, padding):
"""Convert padding string to list of pairs of pad values."""
PaddingType = xla_client.PaddingType
4 changes: 2 additions & 2 deletions jax/numpy/lax_numpy.py
@@ -2252,8 +2252,8 @@ def identity(n, dtype=None):
def arange(start, stop=None, step=None, dtype=None):
lax._check_user_dtype_supported(dtype, "arange")
if stop is None and step is None:
dtype = dtype or _dtype(start)
return lax.iota(dtype, np.ceil(start)) # avoids materializing
dtype = dtype or (int64 if type(start) is Poly else _dtype(start))
return lax.iota(dtype, lax.lax._ceil(start)) # avoids materializing
else:
return array(np.arange(start, stop=stop, step=step, dtype=dtype))
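
Together with the iota changes in lax.py, np.arange over a symbolic size no longer materializes a concrete array. An illustrative use under mask (not code from this commit):

    import jax.numpy as np
    from jax import mask

    # Indices 0..n-1 of the logical prefix, computed on a padded buffer.
    idx = mask(lambda x: np.arange(x.shape[0]), ['n'], 'n')
    idx([np.zeros(5)], dict(n=3))
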

