
Specialize away Split #1334

Closed · wants to merge 3 commits

5 changes: 1 addition & 4 deletions pytensor/link/numba/dispatch/tensor_basic.py
@@ -136,10 +136,7 @@ def join(axis, *tensors):
 def numba_funcify_Split(op, **kwargs):
     @numba_basic.numba_njit
     def split(tensor, axis, indices):
-        # Work around for https://github.com/numba/numba/issues/8257
-        axis = axis % tensor.ndim
-        axis = numba_basic.to_scalar(axis)
-        return np.split(tensor, np.cumsum(indices)[:-1], axis=axis)
+        return np.split(tensor, np.cumsum(indices)[:-1], axis=axis.item())
 
     return split

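For reference, here is a minimal NumPy-only sketch (not part of the PR) of what the dispatched function computes. The interior cut points are the cumulative sum of the split sizes, and `.item()` turns the 0-d axis value into a plain Python integer for `np.split`; the `tensor`, `axis`, and `indices` values below are hypothetical stand-ins.

```python
import numpy as np

tensor = np.arange(12).reshape(3, 4)
axis = np.asarray(1)           # 0-d array standing in for the scalar axis input
indices = np.asarray([2, 2])   # split sizes along that axis

# np.split only needs the interior boundaries, so the last cumulative size is dropped.
out = np.split(tensor, np.cumsum(indices)[:-1], axis=axis.item())
print([o.shape for o in out])  # [(3, 2), (3, 2)]
```
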
22 changes: 20 additions & 2 deletions pytensor/tensor/basic.py
@@ -2201,8 +2201,26 @@ def make_node(self, x, axis, splits):
             raise TypeError("`axis` parameter must be an integer scalar")
 
         inputs = [x, axis, splits]
-        out_type = TensorType(dtype=x.dtype, shape=(None,) * x.type.ndim)
-        outputs = [out_type() for i in range(self.len_splits)]
+
+        x_dtype = x.type.dtype
+        if isinstance(axis, Constant):
+            # In this case we can preserve more static shape info
+            static_axis = axis.data.item()
+            outputs = []
+            x_static_shape = list(x.type.shape)
+            for i in range(self.len_splits):
+                try:
+                    static_split_size = int(get_scalar_constant_value(splits[i]))
+                except NotScalarConstantError:
+                    static_split_size = None
+                static_out_shape = x_static_shape.copy()
+                static_out_shape[static_axis] = static_split_size
+                outputs.append(tensor(shape=tuple(static_out_shape), dtype=x_dtype))
+        else:
+            outputs = [
+                tensor(shape=(None,) * x.type.ndim, dtype=x_dtype)
+                for i in range(self.len_splits)
+            ]
 
         return Apply(self, inputs, outputs)

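A rough usage sketch of the static-shape information this change is meant to preserve; the variable constructor and the `pt.split` call signature shown here are assumptions, not taken from the PR.

```python
import pytensor.tensor as pt

# With a constant axis and constant split sizes, each output can now carry
# the exact length along the split axis instead of a fully unknown shape.
x = pt.tensor("x", shape=(None, 7), dtype="float64")
a, b = pt.split(x, splits_size=[3, 4], n_splits=2, axis=1)

# Expected static shapes with this PR: (None, 3) and (None, 4).
print(a.type.shape, b.type.shape)
```
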
34 changes: 34 additions & 0 deletions pytensor/tensor/rewriting/subtensor.py
@@ -21,6 +21,7 @@
     Join,
     MakeVector,
     ScalarFromTensor,
+    Split,
     TensorFromScalar,
     alloc,
     as_tensor,

@@ -616,6 +617,39 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
     return [node.inputs[0].dimshuffle(tuple(remain_dim))]
 
 
+@register_specialize("shape_unsafe")
+@node_rewriter(tracks=[Split])
+def split_to_subtensor(fgraph, node):
+    """Rewrite Split(2)(x, 0, split_sizes) -> (x[:split_sizes[0]], x[split_sizes[0]:]).
+
+    This allows lifting the underlying split closer to the inputs and increases
+    fusion opportunities. It automatically handles unused split outputs as well.
+
+    It only works for a constant axis.
+    """
+    x, axis, split_sizes = node.inputs
+
+    n_splits = node.op.len_splits
+    if n_splits <= 1:
+        return [x]
+
+    if not isinstance(axis, Constant):
+        return None
+
+    empty_slices = (slice(None),) * int(axis.data)
+    ys = []
+
+    end = split_sizes[0]
+    ys.append(x[(*empty_slices, slice(None, end))])
+    prev_start = end
+    for i in range(1, n_splits - 1):
+        end = prev_start + split_sizes[i]
+        ys.append(x[(*empty_slices, slice(prev_start, end))])
+        prev_start = end
+    ys.append(x[(*empty_slices, slice(prev_start, None))])
+    return ys
+
+
 @register_infer_shape
 @register_useless
 @register_canonicalize
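To make the intent of the rewrite concrete, here is a hypothetical NumPy check (not from the PR) that the chain of consecutive slices built above reproduces `np.split` when the split sizes cover the whole axis.

```python
import numpy as np

x = np.arange(24).reshape(4, 6)
split_sizes = [1, 3, 2]   # sizes along the split axis
axis = 1

# Reference result: np.split cuts at the cumulative sizes.
expected = np.split(x, np.cumsum(split_sizes)[:-1], axis=axis)

# Slicing pattern mirroring split_to_subtensor: full slices before the split axis,
# then consecutive [start:end) windows, with the last output taking the remainder.
empty_slices = (slice(None),) * axis
ys, start = [], 0
for size in split_sizes[:-1]:
    ys.append(x[(*empty_slices, slice(start, start + size))])
    start += size
ys.append(x[(*empty_slices, slice(start, None))])

assert all(np.array_equal(a, b) for a, b in zip(ys, expected))
```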