Allow reshapes for 'F' ordered arrays #455

Merged
merged 31 commits into main from implement-f-ordered-reshapes
Sep 9, 2024
Commits (31)
efcae65
First shot at implementing 'F' ordered array reshapes
a-alveyblanc Sep 11, 2023
cc8f07f
Resolve merge conflicts
a-alveyblanc Sep 11, 2023
35c6d1f
Remove restriction on reshape order
a-alveyblanc Sep 11, 2023
a8ae5e2
Merge branch 'inducer:main' into implement-f-ordered-reshapes
a-alveyblanc Oct 8, 2023
1d6562f
Merge branch 'inducer:main' into implement-f-ordered-reshapes
a-alveyblanc Feb 11, 2024
614d98a
Merge branch 'inducer:main' into implement-f-ordered-reshapes
a-alveyblanc Mar 22, 2024
4cad54d
Refactor to unify paths
a-alveyblanc Mar 29, 2024
6ef8ca0
Slight adjustment
a-alveyblanc Mar 29, 2024
4268c7f
Merge branch 'implement-f-ordered-reshapes' of github.com:a-alveyblan…
a-alveyblanc Mar 29, 2024
4977c76
Update comment wording
a-alveyblanc Mar 29, 2024
48da7d0
Rough draft of old -> new axis mapping
a-alveyblanc Mar 31, 2024
4daa100
Absorb new tag propagation
a-alveyblanc Apr 22, 2024
e9a07c3
Revert changes
a-alveyblanc Apr 22, 2024
e4a3a42
Merge branch 'main' of https://github.com/inducer/pytato into impleme…
a-alveyblanc May 25, 2024
e9f6525
don't linearize everything
a-alveyblanc May 27, 2024
c1553b6
pass tests with new reshaped indices getter
a-alveyblanc May 27, 2024
db2bfff
revert change that belongs to another branch
a-alveyblanc May 27, 2024
896eac0
fix mypy issues
a-alveyblanc May 27, 2024
43c671d
fix failing arraycontext test
a-alveyblanc May 27, 2024
97e1314
fix failing flake8 test
a-alveyblanc May 27, 2024
6366f6f
fixes
a-alveyblanc Sep 3, 2024
e2ea820
resolve merge conflicts
a-alveyblanc Sep 3, 2024
2d2dd27
remove unnecessary import
a-alveyblanc Sep 3, 2024
bb794d8
ruff fixes + small change
a-alveyblanc Sep 3, 2024
e3c2173
Stopped in the middle
inducer Sep 4, 2024
969c03a
improved index expression generation
a-alveyblanc Sep 5, 2024
f6b583b
add test for F ordered reshapes
a-alveyblanc Sep 5, 2024
19e2d50
remove redundant scalar check
a-alveyblanc Sep 5, 2024
537ad3b
Improve reshape docs
inducer Sep 9, 2024
61cedbd
Goodbye unnecessary variable
inducer Sep 9, 2024
57ae8b2
Way more reshape test coverage
inducer Sep 9, 2024
18 changes: 12 additions & 6 deletions pytato/array.py
@@ -1618,8 +1618,6 @@ class Reshape(_SuppliedAxesAndTagsMixin, IndexRemappingBase):

if __debug__:
def __attrs_post_init__(self) -> None:
# FIXME: Get rid of this restriction
assert self.order == "C"
super().__attrs_post_init__()

@property
@@ -2123,8 +2121,16 @@ def reshape(array: Array, newshape: int | Sequence[int],
"""
:param array: array to be reshaped
:param newshape: shape of the resulting array
:param order: ``"C"`` or ``"F"``. Layout order of the result array. Only
``"C"`` allowed for now.
:param order: ``"C"`` or ``"F"``. For each group of
non-matching shape axes, indices in the input
*and* the output array are linearized according to this order
and 'matched up'.

Groups are found by multiplying axis lengths on the input and output
side; a matching input/output group is complete once adding an input or
output axis to the group makes the two products equal.

The semantics are identical to :func:`numpy.reshape`.

.. note::

@@ -2144,8 +2150,8 @@ def reshape(array: Array, newshape: int | Sequence[int],
if not all(isinstance(axis_len, INT_CLASSES) for axis_len in array.shape):
raise ValueError("reshape of arrays with symbolic lengths not allowed")

if order != "C":
raise NotImplementedError("Reshapes to a 'F'-ordered arrays")
if order.upper() not in ["F", "C"]:
raise ValueError("order must be one of F or C")

newshape_explicit = []

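As a quick illustration of the order semantics documented above (this mirrors numpy.reshape, which the new docstring says the semantics match; it is illustrative only, not part of the diff):

import numpy as np

# order="F": indices on both sides of the reshape are linearized
# column-major and matched up; order="C": row-major.
a = np.arange(12).reshape(3, 4)          # values 0..11, C-laid-out
f = np.reshape(a, (4, 3), order="F")
c = np.reshape(a, (4, 3), order="C")
assert f[1, 0] == a[1, 0] == 4           # column-major matching
assert c[0, 1] == a[0, 1] == 1           # row-major matching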
164 changes: 142 additions & 22 deletions pytato/transform/lower_to_index_lambda.py
@@ -28,6 +28,7 @@
THE SOFTWARE.
"""

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, TypeVar

from immutabledict import immutabledict
@@ -48,48 +49,167 @@
NormalizedSlice,
Reshape,
Roll,
ShapeComponent,
ShapeType,
Stack,
)
from pytato.diagnostic import CannotBeLoweredToIndexLambda
from pytato.scalar_expr import INT_CLASSES, IntegralT, ScalarExpression
from pytato.scalar_expr import INT_CLASSES, ScalarExpression
from pytato.tags import AssumeNonNegative
from pytato.transform import Mapper


ToIndexLambdaT = TypeVar("ToIndexLambdaT", Array, AbstractResultWithNamedArrays)


@dataclass(frozen=True)
class _ReshapeIndexGroup:
old_ax_indices: tuple[ShapeComponent, ...]
new_ax_indices: tuple[ShapeComponent, ...]


def _generate_index_expressions(
old_shape: ShapeType,
new_shape: ShapeType,
order: str,
index_vars: list[prim.Variable]) -> tuple[ScalarExpression, ...]:

old_strides = [1]
new_strides = [1]
old_size_tills = [old_shape[-1] if order == "C" else old_shape[0]]

old_stride_axs = (old_shape[::-1][:-1] if order == "C" else
old_shape[:-1])
for old_ax in old_stride_axs:
old_strides.append(old_strides[-1]*old_ax)

new_stride_axs = (new_shape[::-1][:-1] if order == "C" else
new_shape[:-1])
for new_ax in new_stride_axs:
new_strides.append(new_strides[-1]*new_ax)

old_size_till_axs = (old_shape[:-1][::-1] if order == "C" else
old_shape[1:])
for old_ax in old_size_till_axs:
old_size_tills.append(old_size_tills[-1]*old_ax)

if order == "C":
old_strides = old_strides[::-1]
new_strides = new_strides[::-1]
old_size_tills = old_size_tills[::-1]

flattened_index_expn = sum(
index_var*new_stride
for index_var, new_stride in zip(index_vars, new_strides))

return tuple(
(flattened_index_expn % old_size_till) // old_stride
for old_size_till, old_stride in zip(old_size_tills, old_strides))


def _get_reshaped_indices(expr: Reshape) -> tuple[ScalarExpression, ...]:
if expr.array.shape == ():
# RHS must be a scalar i.e. RHS' indices are empty

if expr.order.upper() not in ["C", "F"]:
raise NotImplementedError("Order expected to be 'C' or 'F'",
" (case insensitive). Found order = ",
f"{expr.order}")

order = expr.order
old_shape = expr.array.shape
new_shape = expr.shape

# index variables need to be unique and depend on the new shape length
index_vars = [prim.Variable(f"_{i}") for i in range(len(new_shape))]

# {{{ check for scalars

if old_shape == ():
assert expr.size == 1
return ()

if expr.order != "C":
raise NotImplementedError(expr.order)
if new_shape == ():
return _generate_index_expressions(old_shape, new_shape, order,
index_vars)

if 0 in old_shape and 0 in new_shape:
return _generate_index_expressions(old_shape, new_shape, order,
index_vars)

# }}}

# {{{ generate subsets of old axes mapped to subsets of new axes

axis_mapping: list[_ReshapeIndexGroup] = []

old_index = 0
new_index = 0

while old_index < len(old_shape) and new_index < len(new_shape):
old_ax_len_product = old_shape[old_index]
new_ax_len_product = new_shape[new_index]

old_product_end = old_index + 1
new_product_end = new_index + 1

while old_ax_len_product != new_ax_len_product:
if not isinstance(old_ax_len_product, INT_CLASSES) or \
not isinstance(new_ax_len_product, INT_CLASSES):
raise TypeError("Cannot determine which axes were expanded or "
"collapsed symbolically")

if new_ax_len_product < old_ax_len_product:
new_ax_len_product *= new_shape[new_product_end]
new_product_end += 1
else:
old_ax_len_product *= old_shape[old_product_end]
old_product_end += 1

old_ax_indices = old_shape[old_index:old_product_end]
new_ax_indices = new_shape[new_index:new_product_end]

axis_mapping.append(_ReshapeIndexGroup(
old_ax_indices=old_ax_indices,
new_ax_indices=new_ax_indices))

old_index = old_product_end
new_index = new_product_end

# handle trailing 1s
final_reshaped_indices = axis_mapping.pop(-1)
old_ax_indices = final_reshaped_indices.old_ax_indices
new_ax_indices = final_reshaped_indices.new_ax_indices

while old_index < len(old_shape):
old_ax_indices += tuple([old_shape[old_index]]) # noqa: C409
old_index += 1

while new_index < len(new_shape):
new_ax_indices += tuple([new_shape[new_index]]) # noqa: C409
new_index += 1

axis_mapping.append(_ReshapeIndexGroup(old_ax_indices=old_ax_indices,
new_ax_indices=new_ax_indices))

# }}}

# {{{ compute index expressions for sub shapes

newstrides: list[IntegralT] = [1] # reshaped array strides
for new_axis_len in reversed(expr.shape[1:]):
assert isinstance(new_axis_len, INT_CLASSES)
newstrides.insert(0, newstrides[0]*new_axis_len)
index_vars_begin = 0
index_expressions = []
for reshaped_indices in axis_mapping:
sub_old_shape = reshaped_indices.old_ax_indices
sub_new_shape = reshaped_indices.new_ax_indices

flattened_idx = sum(prim.Variable(f"_{i}")*stride
for i, stride in enumerate(newstrides))
index_vars_end = index_vars_begin + len(sub_new_shape)
sub_index_vars = index_vars[index_vars_begin:index_vars_end]
index_vars_begin = index_vars_end

oldstrides: list[IntegralT] = [1] # input array strides
for axis_len in reversed(expr.array.shape[1:]):
assert isinstance(axis_len, INT_CLASSES)
oldstrides.insert(0, oldstrides[0]*axis_len)
index_expressions.append(_generate_index_expressions(
sub_old_shape, sub_new_shape, order, sub_index_vars))

assert isinstance(expr.array.shape[-1], INT_CLASSES)
oldsizetills = [expr.array.shape[-1]] # input array size till for axes idx
for old_axis_len in reversed(expr.array.shape[:-1]):
assert isinstance(old_axis_len, INT_CLASSES)
oldsizetills.insert(0, oldsizetills[0]*old_axis_len)
# }}}

return tuple(((flattened_idx % sizetill) // stride)
for stride, sizetill in zip(oldstrides, oldsizetills))
return sum(index_expressions, ())


class ToIndexLambdaMixin:
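To make the axis grouping and the stride/"size till" arithmetic in _get_reshaped_indices and _generate_index_expressions concrete, here is a small numeric sketch that mirrors the same construction in plain Python and checks it against numpy. It is an illustration under the shapes chosen here, not code from the PR; only the names old_strides, old_size_tills and new_strides follow the diff.

import numpy as np

# Reshape old_shape (3, 4) -> new_shape (2, 6) with order="F". The grouping
# loop forms a single group covering all axes, since the running products
# only match once every axis on both sides has been absorbed (3*4 == 2*6).
old_shape, new_shape = (3, 4), (2, 6)
a = np.arange(12).reshape(old_shape)
b = np.reshape(a, new_shape, order="F")

old_strides = [1, old_shape[0]]                             # F-order strides of old shape: (1, 3)
old_size_tills = [old_shape[0], old_shape[0]*old_shape[1]]  # running products: (3, 12)
new_strides = [1, new_shape[0]]                             # F-order strides of new shape: (1, 2)

for i0 in range(new_shape[0]):
    for i1 in range(new_shape[1]):
        # linearize the new index, then recover each old-axis index
        flat = i0*new_strides[0] + i1*new_strides[1]
        old_idx = tuple((flat % size_till) // stride
                        for size_till, stride in zip(old_size_tills, old_strides))
        assert b[i0, i1] == a[old_idx]

For a case with several groups, old shape (2, 2, 3, 3, 1) reshaped to (4, 9) splits into (2, 2) <-> (4,) and (3, 3) <-> (9,); the "handle trailing 1s" step then absorbs the trailing 1 into the last group, giving (3, 3, 1) <-> (9,).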
26 changes: 12 additions & 14 deletions test/test_codegen.py
@@ -523,17 +523,14 @@ def test_concatenate(ctx_factory):
assert_allclose_to_numpy(pt.concatenate((x0, x1, x2), axis=1), queue)


@pytest.mark.parametrize("oldshape", [(36,),
(3, 3, 4),
(12, 3),
(2, 2, 3, 3, 1)])
@pytest.mark.parametrize("newshape", [(-1,),
(-1, 6),
(4, 9),
(9, -1),
(36, -1),
36])
def test_reshape(ctx_factory, oldshape, newshape):
_SHAPES = [(36,), (3, 3, 4), (12, 3), (2, 2, 3, 3, 1), (4, 9), (9, 4)]


@pytest.mark.parametrize("oldshape", _SHAPES)
@pytest.mark.parametrize("newshape", [
*_SHAPES, (-1,), (-1, 6), (4, 9), (9, -1), (36, -1), 36])
@pytest.mark.parametrize("order", ["C", "F"])
def test_reshape(ctx_factory, oldshape, newshape, order):
cl_ctx = ctx_factory()
queue = cl.CommandQueue(cl_ctx)

@@ -543,10 +540,11 @@ def test_reshape(ctx_factory, oldshape, newshape):

x = pt.make_data_wrapper(x_in)

assert_allclose_to_numpy(pt.reshape(x, newshape=newshape), queue)
assert_allclose_to_numpy(x.reshape(newshape), queue)
assert_allclose_to_numpy(pt.reshape(x, newshape=newshape, order=order),
queue)
assert_allclose_to_numpy(x.reshape(newshape, order=order), queue)
if isinstance(newshape, tuple):
assert_allclose_to_numpy(x.reshape(*newshape), queue)
assert_allclose_to_numpy(x.reshape(*newshape, order=order), queue)


def test_dict_of_named_array_codegen_avoids_recomputation():
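Finally, a brief end-user sketch of what the change enables. The helper names follow the test file above (pt.make_data_wrapper, pt.reshape); evaluating and comparing the expression is elided, since the tests use their own assert_allclose_to_numpy harness.

import numpy as np
import pytato as pt

x_np = np.random.default_rng(0).random((12, 3))
x = pt.make_data_wrapper(x_np)

# Rejected with NotImplementedError before this PR, accepted after it:
y = pt.reshape(x, newshape=(4, 9), order="F")

# When y is evaluated, its value should agree with numpy's F-ordered reshape:
expected = np.reshape(x_np, (4, 9), order="F")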