Skip to content

Commit

Permalink
Resolves gh-1512 (#1513)
Browse files Browse the repository at this point in the history
dtype is passed to np.asarray when dpt.asarray is called with a Python scalar as input

This guarantees the expected OverflowError is thrown
  • Loading branch information
ndgrigorian authored Jan 27, 2024
1 parent 1cd0b96 commit 9b0a3cd
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
6 changes: 1 addition & 5 deletions dpctl/tensor/_ctors.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,17 +632,13 @@ def asarray(
usm_type=usm_type,
order=order,
)

raise NotImplementedError(
"Converting Python sequences is not implemented"
)
if copy is False:
raise ValueError(
f"Converting {type(obj)} to usm_ndarray requires a copy"
)
# obj is a scalar, create 0d array
return _asarray_from_numpy_ndarray(
np.asarray(obj),
np.asarray(obj, dtype=dtype),
dtype=dtype,
usm_type=usm_type,
sycl_queue=sycl_queue,
Expand Down
15 changes: 15 additions & 0 deletions dpctl/tests/test_tensor_asarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,21 @@ def test_asarray_input_validation():
with pytest.raises(ValueError):
# sequence is not rectangular
dpt.asarray([[1], 2])
with pytest.raises(OverflowError):
# Python int too large for type
dpt.asarray(-9223372036854775809, dtype="i4")
with pytest.raises(ValueError):
# buffer to usm_ndarray requires a copy
dpt.asarray(memoryview(np.arange(5)), copy=False)
with pytest.raises(ValueError):
# Numpy array to usm_ndarray requires a copy
dpt.asarray(np.arange(5), copy=False)
with pytest.raises(ValueError):
# Python sequence to usm_ndarray requires a copy
dpt.asarray([1, 2, 3], copy=False)
with pytest.raises(ValueError):
# Python scalar to usm_ndarray requires a copy
dpt.asarray(5, copy=False)


def test_asarray_input_validation2():
Expand Down

0 comments on commit 9b0a3cd

Please sign in to comment.