
Commit 525b646
Reverts 2075b09
PiperOrigin-RevId: 698152759
hawkinsp authored and Google-ML-Automation committed Nov 19, 2024
1 parent 42fbd30 commit 525b646
Showing 12 changed files with 30 additions and 197 deletions.
3 changes: 0 additions & 3 deletions CHANGELOG.md
@@ -59,9 +59,6 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* {func}`jax.lax.linalg.eig` and the related `jax.numpy` functions
({func}`jax.numpy.linalg.eig` and {func}`jax.numpy.linalg.eigvals`) are now
supported on GPU. See {jax-issue}`#24663` for more details.
* Added {func}`jax.lax.split`. This is a primitive version of
{func}`jax.numpy.split`, added because it yields a more compact
transpose in automatic differentiation.

* Bug fixes
* Fixed a bug where the GPU implementations of LU and QR decomposition would
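For context, the CHANGELOG entry above describes the `jax.lax.split` API that this commit reverts; after the revert, only the index-based `jax.numpy.split` remains. A minimal sketch of the difference, using a hypothetical 3×4 input:

```python
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(3, 4)

# jax.numpy.split cuts at indices along the axis.
a, b = jnp.split(x, [1], axis=0)
print(a.shape, b.shape)            # (1, 4) (2, 4)

# The reverted jax.lax.split took per-chunk sizes instead, e.g.
#   a, b = jax.lax.split(x, sizes=(1, 2), axis=0)
# (removed by this commit).
```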
1 change: 0 additions & 1 deletion docs/jax.lax.rst
@@ -154,7 +154,6 @@ Operators
slice_in_dim
sort
sort_key_val
split
sqrt
square
squeeze
96 changes: 12 additions & 84 deletions jax/_src/lax/lax.py
@@ -654,26 +654,6 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array:
return concatenate_p.bind(*operands, dimension=dimension)


def split(operand: ArrayLike, sizes: Sequence[int],
axis: int = 0) -> Sequence[Array]:
"""Splits an array along ``axis``.
Args:
operand: an array to split
sizes: the sizes of the split arrays. The sum of the sizes must be equal
to the size of the ``axis`` dimension of ``operand``.
axis: the axis along which to split the array.
Returns:
A sequence of ``len(sizes)`` arrays. If ``sizes`` is
``[s1, s2, ...]``, this function returns chunks of sizes ``s1``, ``s2``,
taken along ``axis``.
"""
operand = asarray(operand)
return split_p.bind(operand, sizes=tuple(sizes),
axis=canonicalize_axis(axis, operand.ndim))


_precision_strings: dict[Any, Precision] = {}

class Precision(enum.Enum):
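The hunk above removes the `lax.split` wrapper itself. As a hedged sketch (not part of JAX), an equivalent helper can be built from `lax.slice_in_dim`, which is roughly what callers are left with after the revert; `split_like` and its shapes are illustrative assumptions:

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

def split_like(operand, sizes, axis=0):
    # Hypothetical helper reproducing the removed lax.split semantics:
    # chunk i covers [offsets[i], offsets[i] + sizes[i]) along `axis`.
    offsets = np.cumsum([0, *sizes])
    return [lax.slice_in_dim(operand, int(start), int(start + size), axis=axis)
            for start, size in zip(offsets[:-1], sizes)]

x = jnp.arange(12.0).reshape(3, 4)
print([p.shape for p in split_like(x, (1, 3), axis=1)])   # [(3, 1), (3, 3)]
```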
@@ -4393,8 +4373,18 @@ def _concatenate_transpose_rule(t, *operands, dimension):
return [ad_util.Zero(o.aval) if ad.is_undefined_primal(o) else None
for o in operands]
else:
return split(t, tuple(shape[dimension] for shape in operand_shapes),
axis=dimension)
limit_points = np.cumsum(
[shape[dimension] for shape in operand_shapes]).tolist()
starts = np.zeros((len(operands), t.ndim), dtype=int).tolist()
limits = np.tile(t.shape, (len(operands), 1)).tolist()

for i, s in enumerate(starts[1:]):
s[dimension] = limit_points[:-1][i]
for i, l in enumerate(limits):
l[dimension] = limit_points[i]

return [slicing.slice(t, start, limit) if ad.is_undefined_primal(o)
else None for o, start, limit in zip(operands, starts, limits)]
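To make the slice bookkeeping in the new `_concatenate_transpose_rule` above concrete, here is the same arithmetic run standalone in plain NumPy for two hypothetical operands of sizes 2 and 3 along axis 0: the limit points are [2, 5], so the cotangent is cut at rows [0:2] and [2:5].

```python
import numpy as np

operand_shapes = [(2, 4), (3, 4)]   # shapes concatenated along axis 0
dimension = 0
t = np.ones((5, 4))                 # cotangent of the concatenated result

limit_points = np.cumsum([s[dimension] for s in operand_shapes]).tolist()  # [2, 5]
starts = np.zeros((len(operand_shapes), t.ndim), dtype=int).tolist()
limits = np.tile(t.shape, (len(operand_shapes), 1)).tolist()
for i, s in enumerate(starts[1:]):
    s[dimension] = limit_points[:-1][i]
for i, l in enumerate(limits):
    l[dimension] = limit_points[i]

print(starts)   # [[0, 0], [2, 0]]
print(limits)   # [[2, 4], [5, 4]]
```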

def _concatenate_batch_rule(batched_args, batch_dims, *, dimension):
size = next(op.shape[bdim] for op, bdim in zip(batched_args, batch_dims)
@@ -4423,68 +4413,6 @@ def _concatenate_lower(ctx, *xs, dimension):
mlir.register_lowering(concatenate_p, _concatenate_lower)


def _split_shape_rule(operand, *, sizes, axis):
offset = 0
shapes = []
shape = list(operand.shape)
if any(s < 0 for s in sizes):
raise ValueError(
f"Sizes passed to split must be nonnegative, got {list(sizes)}")
if operand.shape[axis] != np.sum(sizes):
raise ValueError(
f"Sum of sizes {np.sum(sizes)} must be equal to dimension {axis} of the "
f"operand shape {list(operand.shape)}")
for size in sizes:
shape[axis] = size
shapes.append(tuple(shape))
return shapes

def _split_dtype_rule(operand, *, sizes, axis):
return (operand.dtype,) * len(sizes)

def _split_weak_type_rule(operand, *, sizes, axis):
return (operand.weak_type,) * len(sizes)

def _split_transpose_rule(cotangents, operand, *, sizes, axis):
assert ad.is_undefined_primal(operand)
if all(type(t) is ad_util.Zero for t in cotangents):
return ad_util.Zero(operand.aval),
cotangents = [
_zeros(t.aval) if type(t) is ad_util.Zero else t
for t in cotangents
]
return concatenate(cotangents, dimension=axis),

def _split_batch_rule(batched_args, batch_dims, *, sizes, axis):
operand, = batched_args
bdim, = batch_dims
new_bdims = (bdim,) * len(sizes)
out = split(operand, sizes=sizes, axis=axis + 1 if axis >= bdim else axis)
return out, new_bdims

def _split_lower(ctx, x, *, sizes, axis):
x_aval, = ctx.avals_in
start_indices = [0] * x_aval.ndim
limit_indices = list(x_aval.shape)
strides = (1,) * x_aval.ndim
outs = []
for aval_out in ctx.avals_out:
limit_indices[axis] = start_indices[axis] + aval_out.shape[axis]
outs.append(mlir.slice_op(ctx, x, aval_out, start_indices=start_indices,
limit_indices=limit_indices, strides=strides))
start_indices[axis] = limit_indices[axis]
return outs

split_p = core.Primitive('split')
split_p.multiple_results = True
split_p.def_abstract_eval(
partial(standard_multi_result_abstract_eval, split_p, _split_shape_rule,
_split_dtype_rule, _split_weak_type_rule))
split_p.def_impl(partial(dispatch.apply_primitive, split_p))
ad.deflinear2(split_p, _split_transpose_rule)
batching.primitive_batchers[split_p] = _split_batch_rule
mlir.register_lowering(split_p, _split_lower)
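A concrete illustration of the removed shape rule above (a standalone re-statement of the same checks and arithmetic, not JAX code): splitting a (2, 7, 3) operand with sizes (2, 5) along axis 1 requires 2 + 5 to equal the axis extent and yields one output shape per size.

```python
import numpy as np

def split_shapes(operand_shape, sizes, axis):
    # Same validation and shape arithmetic as the removed _split_shape_rule.
    if any(s < 0 for s in sizes):
        raise ValueError(f"Sizes passed to split must be nonnegative, got {list(sizes)}")
    if operand_shape[axis] != np.sum(sizes):
        raise ValueError(f"Sum of sizes {np.sum(sizes)} must be equal to dimension "
                         f"{axis} of the operand shape {list(operand_shape)}")
    shape, shapes = list(operand_shape), []
    for size in sizes:
        shape[axis] = size
        shapes.append(tuple(shape))
    return shapes

print(split_shapes((2, 7, 3), (2, 5), axis=1))   # [(2, 2, 3), (2, 5, 3)]
```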

def _pad_dtype_rule(operand, padding_value, *, padding_config):
if operand.dtype != padding_value.dtype:
msg = "pad operand and padding_value must be same dtype: got {} and {}."
3 changes: 1 addition & 2 deletions jax/_src/numpy/array_methods.py
@@ -629,8 +629,7 @@ def _multi_slice(self: Array,
# avoid circular imports.
@jax.jit
def _unstack(x: Array) -> list[Array]:
dims = (0,)
return [lax.squeeze(t, dims) for t in lax.split(x, (1,) * x.shape[0])]
return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]
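A quick check of the rewritten `_unstack` above: `lax.index_in_dim(..., keepdims=False)` pulls out each leading-axis entry and drops that axis, matching what the removed `lax.split` + `lax.squeeze` pair produced (minimal sketch with a hypothetical 3×2 input):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.0).reshape(3, 2)

# One array per leading-axis entry, with that axis dropped: (3, 2) -> 3 x (2,).
rows = [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]
print([r.shape for r in rows])   # [(2,), (2,), (2,)]
```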

def _chunk_iter(x, size):
if size > x.shape[0]:
31 changes: 17 additions & 14 deletions jax/_src/numpy/lax_numpy.py
@@ -68,7 +68,7 @@
)
from jax._src.util import (
NumpyComplexWarning, canonicalize_axis as _canonicalize_axis,
ceil_of_ratio, partition_list, safe_zip, set_module, unzip2,
ceil_of_ratio, partition_list, safe_zip, set_module, subvals, unzip2,
tuple_replace)
from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding,
PartitionSpec as P)
@@ -3280,10 +3280,10 @@ def _split(op: str, ary: ArrayLike,
if (isinstance(indices_or_sections, (tuple, list)) or
isinstance(indices_or_sections, (np.ndarray, Array)) and
indices_or_sections.ndim > 0):
split_indices = np.asarray([0] + [
indices_or_sections = [
core.concrete_dim_or_error(i_s, f"in jax.numpy.{op} argument 1")
for i_s in indices_or_sections] + [size])
sizes = list(np.diff(split_indices))
for i_s in indices_or_sections]
split_indices = [0] + list(indices_or_sections) + [size]
else:
if core.is_symbolic_dim(indices_or_sections):
raise ValueError(f"jax.numpy.{op} with a symbolic number of sections is "
@@ -3292,14 +3292,21 @@
f"in jax.numpy.{op} argument 1")
part_size, r = divmod(size, num_sections)
if r == 0:
sizes = [part_size] * num_sections
split_indices = [i * part_size
for i in range(num_sections + 1)]
elif op == "array_split":
sizes = [(part_size + 1)] * r + [part_size] * (num_sections - r)
split_indices = (
[i * (part_size + 1) for i in range(r + 1)] +
[i * part_size + ((r + 1) * (part_size + 1) - 1)
for i in range(num_sections - r)])
else:
raise ValueError(f"array split does not result in an equal division: rest is {r}")
sizes = [i if core.is_symbolic_dim(i) else np.int64(i) # type: ignore[misc]
for i in sizes]
return list(lax.split(ary, sizes, axis=axis))
split_indices = [i if core.is_symbolic_dim(i) else np.int64(i) # type: ignore[misc]
for i in split_indices]
starts, ends = [0] * ndim(ary), shape(ary)
_subval = lambda x, i, v: subvals(x, [(i, v)])
return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end))
for start, end in zip(split_indices[:-1], split_indices[1:])]
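To make the `array_split` branch above concrete, here is the same index arithmetic run standalone for a hypothetical axis of size 10 split into 3 sections:

```python
# part_size=3, r=1: the first r chunks get part_size + 1 elements each.
size, num_sections = 10, 3
part_size, r = divmod(size, num_sections)

split_indices = ([i * (part_size + 1) for i in range(r + 1)] +
                 [i * part_size + ((r + 1) * (part_size + 1) - 1)
                  for i in range(num_sections - r)])
print(split_indices)   # [0, 4, 7, 10]

chunk_sizes = [b - a for a, b in zip(split_indices[:-1], split_indices[1:])]
print(chunk_sizes)     # [4, 3, 3], matching np.array_split on a size-10 axis
```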


@export
@@ -4662,11 +4669,7 @@ def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]:
"Unstack requires arrays with rank > 0, however a scalar array was "
"passed."
)
dimensions = (axis,)
return tuple(
lax.squeeze(t, dimensions)
for t in lax.split(x, (1,) * x.shape[axis], axis=axis)
)
return tuple(moveaxis(x, axis, 0))
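The rewritten `unstack` above relies on iteration over a JAX array yielding its leading-axis slices, so moving the requested axis to the front and wrapping the result in `tuple()` unstacks along that axis. A minimal sketch with a hypothetical (2, 3, 4) input:

```python
import jax.numpy as jnp

x = jnp.arange(24.0).reshape(2, 3, 4)

# Unstacking along axis 1: three arrays of shape (2, 4).
pieces = tuple(jnp.moveaxis(x, 1, 0))
print(len(pieces), pieces[0].shape)   # 3 (2, 4)
```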


@export
21 changes: 0 additions & 21 deletions jax/_src/pallas/mosaic/lowering.py
@@ -1871,27 +1871,6 @@ def _concatenate_lowering_rule(ctx: LoweringRuleContext, *xs, dimension):
lowering_rules[lax.concatenate_p] = _concatenate_lowering_rule


def _split_lowering_rule(
ctx: LoweringRuleContext, x, *, sizes, axis
):
(x_aval,) = ctx.avals_in
slice_size = np.array(x_aval.shape, dtype=np.int64)
starts = np.zeros_like(slice_size)
strides = np.ones_like(slice_size)
outs = []
for size, aval_out in zip(sizes, ctx.avals_out):
slice_size[axis] = size
outs.append(
vector.extract_strided_slice(
aval_to_ir_type(aval_out), x, starts, slice_size, strides
)
)
starts[axis] += size
return outs

lowering_rules[lax.split_p] = _split_lowering_rule
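The removed Mosaic lowering above walks a start offset along the split axis, emitting one strided slice per requested size. The same offset bookkeeping in plain NumPy (hypothetical shapes, no MLIR involved):

```python
import numpy as np

x = np.arange(2 * 7).reshape(2, 7)
sizes, axis = (3, 4), 1

slice_size = np.array(x.shape, dtype=np.int64)
starts = np.zeros_like(slice_size)
outs = []
for size in sizes:
    slice_size[axis] = size
    idx = tuple(slice(int(s), int(s + n)) for s, n in zip(starts, slice_size))
    outs.append(x[idx])
    starts[axis] += size          # advance the offset past the emitted chunk

print([o.shape for o in outs])    # [(2, 3), (2, 4)]
```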


def _iota_lowering_rule(ctx: LoweringRuleContext, dtype, shape, dimension,
sharding):
out_type = aval_to_ir_type(ctx.avals_out[0])
6 changes: 0 additions & 6 deletions jax/experimental/jax2tf/jax2tf.py
@@ -2087,12 +2087,6 @@ def _concatenate(*operands, dimension):
tf_impl[lax.concatenate_p] = _concatenate


def _split(operand, *, sizes, axis):
return tf.split(operand, sizes, axis=axis)

tf_impl[lax.split_p] = _split
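For reference, the deleted jax2tf rule above delegated to `tf.split`, which also takes per-chunk sizes along an axis. A minimal standalone sketch (hypothetical shapes; requires TensorFlow):

```python
import tensorflow as tf

x = tf.reshape(tf.range(10), (2, 5))

# A list of sizes splits the axis into chunks of those extents.
a, b = tf.split(x, [2, 3], axis=1)
print(a.shape, b.shape)   # (2, 2) (2, 3)
```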


def _conv_general_dimension_numbers_proto(dimension_numbers):
"""Converts a ConvDimensionNumbers to an XLA ConvolutionDimensionNumbers."""
assert isinstance(dimension_numbers, lax.ConvDimensionNumbers)
1 change: 0 additions & 1 deletion jax/experimental/jet.py
@@ -323,7 +323,6 @@ def linear_prop(prim, primals_in, series_in, **params):
deflinear(lax.convert_element_type_p)
deflinear(lax.broadcast_in_dim_p)
deflinear(lax.concatenate_p)
deflinear(lax.split_p)
deflinear(lax.pad_p)
deflinear(lax.reshape_p)
deflinear(lax.squeeze_p)
2 changes: 0 additions & 2 deletions jax/lax/__init__.py
@@ -203,8 +203,6 @@
sort as sort,
sort_key_val as sort_key_val,
sort_p as sort_p,
split as split,
split_p as split_p,
sqrt as sqrt,
sqrt_p as sqrt_p,
square as square,
18 changes: 0 additions & 18 deletions tests/lax_autodiff_test.py
@@ -273,24 +273,6 @@ def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs):
concatenate = lambda *args: lax.concatenate(args, dim)
check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.)

@jtu.sample_product(
[dict(base_shape=base_shape, axis=axis)
for base_shape in [(4,), (3, 4), (2, 3, 4)]
for axis in range(len(base_shape))
],
num_pieces=range(3),
dtype=float_dtypes,
)
def testSplitGrad(self, axis, base_shape, dtype, num_pieces):
sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64)
shape = list(base_shape)
shape[axis] = np.sum(sizes)
rng = jtu.rand_default(self.rng())
operands = (rng(shape, dtype),)
split = lambda x: lax.split(x, sizes, axis)
check_grads(split, operands, 2, ["fwd", "rev"], eps=1.)


@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, strides=strides)
for lhs_shape, rhs_shape, all_strides in itertools.chain(
27 changes: 0 additions & 27 deletions tests/lax_test.py
@@ -283,33 +283,6 @@ def testConcatenateAgainstNumpy(self, dim, base_shape, dtype, num_arrs):
numpy_op = lambda *args: lax_reference.concatenate(args, dim)
self._CheckAgainstNumpy(numpy_op, op, args_maker)

@jtu.sample_product(
[dict(base_shape=shape, axis=axis) for shape in [(4,), (3, 4), (2, 3, 4)]
for axis in range(len(shape))],
num_pieces=range(3),
dtype=lax_test_util.default_dtypes,
)
def testSplit(self, axis, base_shape, dtype, num_pieces):
sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64)
shape = list(base_shape)
shape[axis] = np.sum(sizes)
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
op = lambda x: lax.split(x, sizes, axis=axis)
def numpy_op(x):
return np.split(x, np.cumsum(sizes[:-1]), axis=axis)
self._CompileAndCheck(op, args_maker)
self._CheckAgainstNumpy(numpy_op, op, args_maker)

def testSplitErrors(self):
with self.assertRaisesRegex(ValueError,
"Sizes passed to split must be nonnegative"):
lax.split(np.arange(5), [-1])
with self.assertRaisesRegex(ValueError, "Sum of sizes 6 must be equal"):
lax.split(np.arange(5), [6])
with self.assertRaisesRegex(ValueError, "axis 1 is out of bounds"):
lax.split(np.arange(5), sizes=(), axis=1)
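The NumPy reference in the removed `testSplit` above converts chunk sizes into cumulative cut indices for `np.split`; a standalone check with hypothetical sizes, including a zero-length chunk such as the randomized test sizes could produce:

```python
import numpy as np

sizes = np.array([2, 0, 3])
x = np.arange(np.sum(sizes))

# np.split takes cut points, so sizes become a cumulative sum (last one dropped).
chunks = np.split(x, np.cumsum(sizes[:-1]))
print([c.shape for c in chunks])   # [(2,), (0,), (3,)]
```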

@jtu.sample_product(
[
dict(lhs_shape=(b, i, 9, 10), rhs_shape=(j, i, 4, 5))
18 changes: 0 additions & 18 deletions tests/lax_vmap_test.py
@@ -344,24 +344,6 @@ def testSlice(self, shape, dtype, starts, limits, strides, bdims):
op = lambda x: lax.slice(x, starts, limits, strides)
self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)

@jtu.sample_product(
[dict(base_shape=base_shape, axis=axis, bdims=bdims)
for base_shape in [(4,), (3, 4), (2, 3, 4)]
for axis in range(len(base_shape))
for bdims in lax_test_util.all_bdims(base_shape)
],
num_pieces=range(3),
dtype=lax_test_util.default_dtypes,
)
def testSplit(self, base_shape, dtype, num_pieces, axis, bdims):
sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64)
shape = list(base_shape)
shape[axis] = np.sum(sizes)
rng = jtu.rand_default(self.rng())
op = lambda x: lax.split(x, sizes, axis)
self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng,
multiple_results=True)

@jtu.sample_product(
[dict(shape=shape, perm=perm, bdims=bdims)
for shape, perm in [
