@@ -71,9 +71,7 @@ def compare_jax_and_py(
7171
7272 if must_be_device_array :
7373 if isinstance (jax_res , list ):
74- assert all (
75- isinstance (res , jax .interpreters .xla .DeviceArray ) for res in jax_res
76- )
74+ assert all (isinstance (res , jax .Array ) for res in jax_res )
7775 else :
7876 assert isinstance (jax_res , jax .interpreters .xla .DeviceArray )
7977
@@ -146,13 +144,13 @@ def test_shared():
146144 pytensor_jax_fn = function ([], a , mode = "JAX" )
147145 jax_res = pytensor_jax_fn ()
148146
149- assert isinstance (jax_res , jax .interpreters . xla . DeviceArray )
147+ assert isinstance (jax_res , jax .Array )
150148 np .testing .assert_allclose (jax_res , a .get_value ())
151149
152150 pytensor_jax_fn = function ([], a * 2 , mode = "JAX" )
153151 jax_res = pytensor_jax_fn ()
154152
155- assert isinstance (jax_res , jax .interpreters . xla . DeviceArray )
153+ assert isinstance (jax_res , jax .Array )
156154 np .testing .assert_allclose (jax_res , a .get_value () * 2 )
157155
158156 # Changed the shared value and make sure that the JAX-compiled
@@ -161,7 +159,7 @@ def test_shared():
161159 a .set_value (new_a_value )
162160
163161 jax_res = pytensor_jax_fn ()
164- assert isinstance (jax_res , jax .interpreters . xla . DeviceArray )
162+ assert isinstance (jax_res , jax .Array )
165163 np .testing .assert_allclose (jax_res , new_a_value * 2 )
166164
167165
0 commit comments