Skip to content

Commit

Permalink
Merge pull request #1043 from IntelPython/print-dtype-bug-fix
Browse files Browse the repository at this point in the history
Fixed spacing of dtype string in array printing
  • Loading branch information
ndgrigorian authored Jan 23, 2023
2 parents 1916370 + da9b756 commit 3c512ca
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
6 changes: 5 additions & 1 deletion dpctl/tensor/_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,11 @@ def _usm_ndarray_repr(x, line_width=None, precision=None, suppress=None):
dtype_str = "dtype={}".format(x.dtype.name)
bottom_len = len(s) - (s.rfind("\n") + 1)
next_line = bottom_len + len(dtype_str) + 1 > line_width
dtype_str = ",\n" + dtype_str if next_line else ", " + dtype_str
dtype_str = (
",\n" + " " * len(prefix) + dtype_str
if next_line
else ", " + dtype_str
)
else:
dtype_str = ""

Expand Down
23 changes: 23 additions & 0 deletions dpctl/tests/test_usm_ndarray_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,16 @@ def test_print_repr(self):
x = dpt.arange(4, dtype="i4", sycl_queue=q)
assert repr(x) == "usm_ndarray([0, 1, 2, 3], dtype=int32)"

dpt.set_print_options(linewidth=1)
np.testing.assert_equal(
repr(x),
"usm_ndarray([0,"
"\n 1,"
"\n 2,"
"\n 3],"
"\n dtype=int32)",
)

def test_print_repr_abbreviated(self):
q = get_queue_or_skip()

Expand All @@ -237,6 +247,19 @@ def test_print_repr_abbreviated(self):
"\n [6, ..., 8]], dtype=int32)",
)

dpt.set_print_options(linewidth=1)
np.testing.assert_equal(
repr(y),
"usm_ndarray([[0,"
"\n ...,"
"\n 2],"
"\n ...,"
"\n [6,"
"\n ...,"
"\n 8]],"
"\n dtype=int32)",
)

@pytest.mark.parametrize(
"dtype",
[
Expand Down

0 comments on commit 3c512ca

Please sign in to comment.