diff --git a/ivy/functional/frontends/jax/devicearray.py b/ivy/functional/frontends/jax/devicearray.py index b04cecb8a1095..dabd2464a1f7a 100644 --- a/ivy/functional/frontends/jax/devicearray.py +++ b/ivy/functional/frontends/jax/devicearray.py @@ -41,6 +41,10 @@ def shape(self): def at(self): return jax_frontend._src.numpy.lax_numpy._IndexUpdateHelper(self.ivy_array) + @property + def T(self): + return self.ivy_array.T + # Instance Methods # # ---------------- # diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_devicearray.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_devicearray.py index a2401a4c69a30..9089fbb60a565 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_devicearray.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_devicearray.py @@ -59,6 +59,30 @@ def test_jax_devicearray_property_shape( assert x.shape == shape +@st.composite +def _transpose_helper(draw): + dtype_x = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid", prune_function=False), + min_num_dims=2, + max_num_dims=2, + min_dim_size=2, + ) + ) + + _, data = dtype_x + x = data[0] + xT = np.transpose(x) + return x, xT + + +@given(x_transpose=_transpose_helper()) +def test_jax_devicearray_property_T(x_transpose): + x, xT = x_transpose + x = DeviceArray(x) + assert np.array_equal(x.T, xT) + + @st.composite def _at_helper(draw): _, data, shape = draw(