Skip to content

Commit

Permalink
Improve repr for empty jax.Array
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Feb 5, 2024
1 parent e224c3d commit d9cbd7b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
7 changes: 5 additions & 2 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,11 @@ def __repr__(self):

if self.is_fully_addressable or self.is_fully_replicated:
line_width = np.get_printoptions()["linewidth"]
s = np.array2string(self._value, prefix=prefix, suffix=',',
separator=', ', max_line_width=line_width)
if self.size == 0:
s = f"[], shape={self.shape}"
else:
s = np.array2string(self._value, prefix=prefix, suffix=',',
separator=', ', max_line_width=line_width)
last_line_len = len(s) - s.rfind('\n') + 1
sep = ' '
if last_line_len + len(dtype_str) + 1 > line_width:
Expand Down
6 changes: 6 additions & 0 deletions tests/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,12 @@ def test_repr(self):
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
self.assertStartsWith(repr(arr), "Array(")

def test_empty_repr(self):
shape = (0, 5)
dtype = 'float32'
x = jnp.empty(shape, dtype)
self.assertEqual(repr(x), f"Array([], shape={shape}, dtype={dtype})")

def test_jnp_array(self):
arr = jnp.array([1, 2, 3])
self.assertIsInstance(arr, array.ArrayImpl)
Expand Down

0 comments on commit d9cbd7b

Please sign in to comment.