Allow reshapes for 'F' ordered arrays (#455)
* First shot at implementing 'F' ordered array reshapes

* Remove restriction on reshape order

* Refactor to unify paths

* Slight adjustment

* Update comment wording

* Rough draft of old -> new axis mapping

* Absorb new tag propagation

* Revert changes

* don't linearize everything

* pass tests with new reshaped indices getter

* revert change that belongs to another branch

* fix mypy issues

* fix failing arraycontext test

* fix failing flake8 test

* fixes

* remove unnecessary import

* ruff fixes + small change

* Stopped in the middle

* improved index expression generation

* add test for F ordered reshapes

* remove redundant scalar check

* Improve reshape docs

* Goodbye unnecessary variable

* Way more reshape test coverage

---------

Co-authored-by: Andreas Kloeckner <inform@tiker.net>
a-alveyblanc and inducer authored Sep 9, 2024
1 parent a92a0d1 commit 2de3bc9
Showing 3 changed files with 166 additions and 42 deletions.
18 changes: 12 additions & 6 deletions pytato/array.py
@@ -1618,8 +1618,6 @@ class Reshape(_SuppliedAxesAndTagsMixin, IndexRemappingBase):

if __debug__:
def __attrs_post_init__(self) -> None:
# FIXME: Get rid of this restriction
assert self.order == "C"
super().__attrs_post_init__()

@property
@@ -2123,8 +2121,16 @@ def reshape(array: Array, newshape: int | Sequence[int],
"""
:param array: array to be reshaped
:param newshape: shape of the resulting array
:param order: ``"C"`` or ``"F"``. Layout order of the result array. Only
``"C"`` allowed for now.
:param order: ``"C"`` or ``"F"``. For each group of
non-matching shape axes, indices in the input
*and* the output array are linearized according to this order
and 'matched up'.
Groups are found by multiplying axis lengths on the input and output side;
a matching input/output group is found once adding an input or output axis
to the group makes the two products match.
The semantics are identical to :func:`numpy.reshape`.
.. note::
@@ -2144,8 +2150,8 @@ def reshape(array: Array, newshape: int | Sequence[int],
if not all(isinstance(axis_len, INT_CLASSES) for axis_len in array.shape):
raise ValueError("reshape of arrays with symbolic lengths not allowed")

if order != "C":
raise NotImplementedError("Reshapes to a 'F'-ordered arrays")
if order.upper() not in ["F", "C"]:
raise ValueError("order must be one of F or C")

newshape_explicit = []
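
To make the grouping described in the new docstring concrete, here is a small numpy-only illustration (an editorial sketch, not part of the commit): reshaping (2, 3, 4) to (6, 4) pairs the input axes (2, 3) with the output axis (6) and matches their indices up in the requested order, which is why the order argument changes where elements land.

import numpy as np

# Illustrative sketch only (assumes numpy semantics, which the docstring says
# pt.reshape follows): groups are (2, 3) <-> (6,) and (4,) <-> (4,).
a = np.arange(24).reshape(2, 3, 4)
f = np.reshape(a, (6, 4), order="F")   # indices within each group linearized Fortran-style
c = np.reshape(a, (6, 4), order="C")   # same shape, different element placement
assert f.shape == c.shape == (6, 4)
assert not np.array_equal(f, c)        # the order argument matters here
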

164 changes: 142 additions & 22 deletions pytato/transform/lower_to_index_lambda.py
@@ -28,6 +28,7 @@
THE SOFTWARE.
"""

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, TypeVar

from immutabledict import immutabledict
@@ -48,48 +49,167 @@
NormalizedSlice,
Reshape,
Roll,
ShapeComponent,
ShapeType,
Stack,
)
from pytato.diagnostic import CannotBeLoweredToIndexLambda
from pytato.scalar_expr import INT_CLASSES, IntegralT, ScalarExpression
from pytato.scalar_expr import INT_CLASSES, ScalarExpression
from pytato.tags import AssumeNonNegative
from pytato.transform import Mapper


ToIndexLambdaT = TypeVar("ToIndexLambdaT", Array, AbstractResultWithNamedArrays)


@dataclass(frozen=True)
class _ReshapeIndexGroup:
old_ax_indices: tuple[ShapeComponent, ...]
new_ax_indices: tuple[ShapeComponent, ...]


def _generate_index_expressions(
old_shape: ShapeType,
new_shape: ShapeType,
order: str,
index_vars: list[prim.Variable]) -> tuple[ScalarExpression, ...]:

old_strides = [1]
new_strides = [1]
old_size_tills = [old_shape[-1] if order == "C" else old_shape[0]]

old_stride_axs = (old_shape[::-1][:-1] if order == "C" else
old_shape[:-1])
for old_ax in old_stride_axs:
old_strides.append(old_strides[-1]*old_ax)

new_stride_axs = (new_shape[::-1][:-1] if order == "C" else
new_shape[:-1])
for new_ax in new_stride_axs:
new_strides.append(new_strides[-1]*new_ax)

old_size_till_axs = (old_shape[:-1][::-1] if order == "C" else
old_shape[1:])
for old_ax in old_size_till_axs:
old_size_tills.append(old_size_tills[-1]*old_ax)

if order == "C":
old_strides = old_strides[::-1]
new_strides = new_strides[::-1]
old_size_tills = old_size_tills[::-1]

flattened_index_expn = sum(
index_var*new_stride
for index_var, new_stride in zip(index_vars, new_strides))

return tuple(
(flattened_index_expn % old_size_till) // old_stride
for old_size_till, old_stride in zip(old_size_tills, old_strides))
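
For reference, the expressions returned above are a symbolic form of a flatten/unflatten round trip. The following standalone sketch (editorial, using numpy index helpers that are not part of this module) checks that raveling the output index and unraveling it with the input shape reproduces an order="F" reshape for concrete shapes.

import numpy as np

# Editorial sketch: the symbolic expressions above amount to
#   flat = sum(new_index * new_stride)                        (ravel in the given order)
#   old_index_k = (flat % old_size_till_k) // old_stride_k    (unravel)
# which numpy's index helpers can reproduce for concrete shapes.
def reference_index_map(old_shape, new_shape, new_idx, order="F"):
    flat = np.ravel_multi_index(new_idx, new_shape, order=order)
    return np.unravel_index(flat, old_shape, order=order)

old_shape, new_shape = (3, 4), (2, 6)
a = np.arange(12).reshape(old_shape)
b = a.reshape(new_shape, order="F")
for i in range(2):
    for j in range(6):
        assert b[i, j] == a[reference_index_map(old_shape, new_shape, (i, j))]
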


def _get_reshaped_indices(expr: Reshape) -> tuple[ScalarExpression, ...]:
if expr.array.shape == ():
# RHS must be a scalar i.e. RHS' indices are empty

if expr.order.upper() not in ["C", "F"]:
raise NotImplementedError("Order expected to be 'C' or 'F'",
" (case insensitive). Found order = ",
f"{expr.order}")

order = expr.order
old_shape = expr.array.shape
new_shape = expr.shape

# index variables need to be unique and depend on the new shape length
index_vars = [prim.Variable(f"_{i}") for i in range(len(new_shape))]

# {{{ check for scalars

if old_shape == ():
assert expr.size == 1
return ()

if expr.order != "C":
raise NotImplementedError(expr.order)
if new_shape == ():
return _generate_index_expressions(old_shape, new_shape, order,
index_vars)

if 0 in old_shape and 0 in new_shape:
return _generate_index_expressions(old_shape, new_shape, order,
index_vars)

# }}}

# {{{ generate subsets of old axes mapped to subsets of new axes

axis_mapping: list[_ReshapeIndexGroup] = []

old_index = 0
new_index = 0

while old_index < len(old_shape) and new_index < len(new_shape):
old_ax_len_product = old_shape[old_index]
new_ax_len_product = new_shape[new_index]

old_product_end = old_index + 1
new_product_end = new_index + 1

while old_ax_len_product != new_ax_len_product:
if not isinstance(old_ax_len_product, INT_CLASSES) or \
not isinstance(new_ax_len_product, INT_CLASSES):
raise TypeError("Cannot determine which axes were expanded or "
"collapsed symbolically")

if new_ax_len_product < old_ax_len_product:
new_ax_len_product *= new_shape[new_product_end]
new_product_end += 1
else:
old_ax_len_product *= old_shape[old_product_end]
old_product_end += 1

old_ax_indices = old_shape[old_index:old_product_end]
new_ax_indices = new_shape[new_index:new_product_end]

axis_mapping.append(_ReshapeIndexGroup(
old_ax_indices=old_ax_indices,
new_ax_indices=new_ax_indices))

old_index = old_product_end
new_index = new_product_end

# handle trailing 1s
final_reshaped_indices = axis_mapping.pop(-1)
old_ax_indices = final_reshaped_indices.old_ax_indices
new_ax_indices = final_reshaped_indices.new_ax_indices

while old_index < len(old_shape):
old_ax_indices += tuple([old_shape[old_index]]) # noqa: C409
old_index += 1

while new_index < len(new_shape):
new_ax_indices += tuple([new_shape[new_index]]) # noqa: C409
new_index += 1

axis_mapping.append(_ReshapeIndexGroup(old_ax_indices=old_ax_indices,
new_ax_indices=new_ax_indices))

# }}}

# {{{ compute index expressions for sub shapes

newstrides: list[IntegralT] = [1] # reshaped array strides
for new_axis_len in reversed(expr.shape[1:]):
assert isinstance(new_axis_len, INT_CLASSES)
newstrides.insert(0, newstrides[0]*new_axis_len)
index_vars_begin = 0
index_expressions = []
for reshaped_indices in axis_mapping:
sub_old_shape = reshaped_indices.old_ax_indices
sub_new_shape = reshaped_indices.new_ax_indices

flattened_idx = sum(prim.Variable(f"_{i}")*stride
for i, stride in enumerate(newstrides))
index_vars_end = index_vars_begin + len(sub_new_shape)
sub_index_vars = index_vars[index_vars_begin:index_vars_end]
index_vars_begin = index_vars_end

oldstrides: list[IntegralT] = [1] # input array strides
for axis_len in reversed(expr.array.shape[1:]):
assert isinstance(axis_len, INT_CLASSES)
oldstrides.insert(0, oldstrides[0]*axis_len)
index_expressions.append(_generate_index_expressions(
sub_old_shape, sub_new_shape, order, sub_index_vars))

assert isinstance(expr.array.shape[-1], INT_CLASSES)
oldsizetills = [expr.array.shape[-1]] # input array size till for axes idx
for old_axis_len in reversed(expr.array.shape[:-1]):
assert isinstance(old_axis_len, INT_CLASSES)
oldsizetills.insert(0, oldsizetills[0]*old_axis_len)
# }}}

return tuple(((flattened_idx % sizetill) // stride)
for stride, sizetill in zip(oldstrides, oldsizetills))
return sum(index_expressions, ())
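
As a worked example of the grouping loop in _get_reshaped_indices (spelled out editorially, not taken from the diff): for old shape (2, 2, 3, 3, 1) and new shape (4, 9), the loop pairs old axes (2, 2) with new axis (4) and old axes (3, 3) with new axis (9), and the trailing length-1 old axis is then absorbed into the last group by the fixup above. A minimal standalone sketch of that logic, for concrete integer shapes only:

def group_axes(old_shape, new_shape):
    # Hypothetical helper mirroring the grouping loop above; no symbolic lengths.
    groups, oi, ni = [], 0, 0
    while oi < len(old_shape) and ni < len(new_shape):
        old_prod, new_prod = old_shape[oi], new_shape[ni]
        oe, ne = oi + 1, ni + 1
        while old_prod != new_prod:
            if new_prod < old_prod:
                new_prod *= new_shape[ne]
                ne += 1
            else:
                old_prod *= old_shape[oe]
                oe += 1
        groups.append((old_shape[oi:oe], new_shape[ni:ne]))
        oi, ni = oe, ne
    # trailing length-1 axes are absorbed into the last group
    last_old, last_new = groups.pop()
    groups.append((last_old + old_shape[oi:], last_new + new_shape[ni:]))
    return groups

assert group_axes((2, 2, 3, 3, 1), (4, 9)) == [((2, 2), (4,)), ((3, 3, 1), (9,))]
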


class ToIndexLambdaMixin:
26 changes: 12 additions & 14 deletions test/test_codegen.py
@@ -523,17 +523,14 @@ def test_concatenate(ctx_factory):
assert_allclose_to_numpy(pt.concatenate((x0, x1, x2), axis=1), queue)


@pytest.mark.parametrize("oldshape", [(36,),
(3, 3, 4),
(12, 3),
(2, 2, 3, 3, 1)])
@pytest.mark.parametrize("newshape", [(-1,),
(-1, 6),
(4, 9),
(9, -1),
(36, -1),
36])
def test_reshape(ctx_factory, oldshape, newshape):
_SHAPES = [(36,), (3, 3, 4), (12, 3), (2, 2, 3, 3, 1), (4, 9), (9, 4)]


@pytest.mark.parametrize("oldshape", _SHAPES)
@pytest.mark.parametrize("newshape", [
*_SHAPES, (-1,), (-1, 6), (4, 9), (9, -1), (36, -1), 36])
@pytest.mark.parametrize("order", ["C", "F"])
def test_reshape(ctx_factory, oldshape, newshape, order):
cl_ctx = ctx_factory()
queue = cl.CommandQueue(cl_ctx)

@@ -543,10 +540,11 @@ def test_reshape(ctx_factory, oldshape, newshape):

x = pt.make_data_wrapper(x_in)

assert_allclose_to_numpy(pt.reshape(x, newshape=newshape), queue)
assert_allclose_to_numpy(x.reshape(newshape), queue)
assert_allclose_to_numpy(pt.reshape(x, newshape=newshape, order=order),
queue)
assert_allclose_to_numpy(x.reshape(newshape, order=order), queue)
if isinstance(newshape, tuple):
assert_allclose_to_numpy(x.reshape(*newshape), queue)
assert_allclose_to_numpy(x.reshape(*newshape, order=order), queue)


def test_dict_of_named_array_codegen_avoids_recomputation():