Skip to content

Commit

Permalink
Merge pull request #1869 from IntelPython/bugfix/gh-1857-roll-with-la…
Browse files Browse the repository at this point in the history
…rge-shift

Roll must reduce shift steps by size along axis
  • Loading branch information
oleksandr-pavlyk authored Oct 21, 2024
2 parents 85e4121 + bd0c9b2 commit 9eb8f03
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 14 deletions.
30 changes: 17 additions & 13 deletions dpctl/tensor/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def flip(X, /, *, axis=None):
return X[indexer]


def roll(X, /, shift, *, axis=None):
def roll(x, /, shift, *, axis=None):
"""
roll(x, shift, axis)
Expand Down Expand Up @@ -343,41 +343,45 @@ def roll(X, /, shift, *, axis=None):
`device` attributes as `x` and whose elements are shifted relative
to `x`.
"""
if not isinstance(X, dpt.usm_ndarray):
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
exec_q = X.sycl_queue
if not isinstance(x, dpt.usm_ndarray):
raise TypeError(f"Expected usm_ndarray type, got {type(x)}.")
exec_q = x.sycl_queue
_manager = dputils.SequentialOrderManager[exec_q]
if axis is None:
shift = operator.index(shift)
dep_evs = _manager.submitted_events
res = dpt.empty(
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=exec_q
x.shape, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=exec_q
)
sz = operator.index(x.size)
shift = (shift % sz) if sz > 0 else 0
dep_evs = _manager.submitted_events
hev, roll_ev = ti._copy_usm_ndarray_for_roll_1d(
src=X,
src=x,
dst=res,
shift=shift,
sycl_queue=exec_q,
depends=dep_evs,
)
_manager.add_event_pair(hev, roll_ev)
return res
axis = normalize_axis_tuple(axis, X.ndim, allow_duplicate=True)
axis = normalize_axis_tuple(axis, x.ndim, allow_duplicate=True)
broadcasted = np.broadcast(shift, axis)
if broadcasted.ndim > 1:
raise ValueError("'shift' and 'axis' should be scalars or 1D sequences")
shifts = [
0,
] * X.ndim
] * x.ndim
shape = x.shape
for sh, ax in broadcasted:
shifts[ax] += sh

n_i = operator.index(shape[ax])
shifted = shifts[ax] + operator.index(sh)
shifts[ax] = (shifted % n_i) if n_i > 0 else 0
res = dpt.empty(
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=exec_q
x.shape, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=exec_q
)
dep_evs = _manager.submitted_events
ht_e, roll_ev = ti._copy_usm_ndarray_for_roll_nd(
src=X, dst=res, shifts=shifts, sycl_queue=exec_q, depends=dep_evs
src=x, dst=res, shifts=shifts, sycl_queue=exec_q, depends=dep_evs
)
_manager.add_event_pair(ht_e, roll_ev)
return res
Expand Down
2 changes: 1 addition & 1 deletion dpctl/tensor/libtensor/source/copy_for_roll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ copy_usm_ndarray_for_roll_nd(const dpctl::tensor::usm_ndarray &src,
// normalize shift parameter to be 0 <= offset < dim
py::ssize_t dim = src_shape_ptr[i];
size_t offset =
(shifts[i] > 0) ? (shifts[i] % dim) : dim + (shifts[i] % dim);
(shifts[i] >= 0) ? (shifts[i] % dim) : dim + (shifts[i] % dim);

normalized_shifts.push_back(offset);
}
Expand Down
24 changes: 24 additions & 0 deletions dpctl/tests/test_usm_ndarray_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,30 @@ def test_roll_2d(data):
assert_array_equal(Ynp, dpt.asnumpy(Y))


def test_roll_out_bounds_shifts():
"See gh-1857"
get_queue_or_skip()

x = dpt.arange(4)
y = dpt.roll(x, np.uint64(2**63 + 2))
expected = dpt.roll(x, 2)
assert dpt.all(y == expected)

x_empty = x[1:1]
y = dpt.roll(x_empty, 11)
assert y.size == 0

x_2d = dpt.reshape(x, (2, 2))
y = dpt.roll(x_2d, np.uint64(2**63 + 1), axis=1)
expected = dpt.roll(x_2d, 1, axis=1)
assert dpt.all(y == expected)

x_2d_empty = x_2d[:, 1:1]
y = dpt.roll(x_2d_empty, 3, axis=1)
expected = dpt.empty_like(x_2d_empty)
assert dpt.all(y == expected)


def test_roll_validation():
get_queue_or_skip()

Expand Down

0 comments on commit 9eb8f03

Please sign in to comment.