Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use jax.Array for type checking in tests #1372

Merged
merged 1 commit into from
Dec 15, 2022
Merged

Conversation

rlouf
Copy link
Member

@rlouf rlouf commented Dec 15, 2022

JAX's 0.4.1 release introduces jax.Array which replaces DeviceArray. Tests were failing in CI because they are currently checking that the arrays returned by the compiled functions are DeviceArray.

I updated the tests and added a constraint on the JAX version in requirements.txt.

@rlouf rlouf added bug Something isn't working JAX Involves JAX transpilation labels Dec 15, 2022
@codecov
Copy link

codecov bot commented Dec 15, 2022

Codecov Report

Merging #1372 (607bfc4) into main (2434cb4) will increase coverage by 0.30%.
The diff coverage is 93.46%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1372      +/-   ##
==========================================
+ Coverage   74.35%   74.66%   +0.30%     
==========================================
  Files         177      177              
  Lines       49046    49050       +4     
  Branches    10379    10400      +21     
==========================================
+ Hits        36468    36623     +155     
+ Misses      10285    10131     -154     
- Partials     2293     2296       +3     
Impacted Files Coverage Δ
aesara/link/jax/dispatch/elemwise.py 80.59% <50.00%> (ø)
aesara/link/jax/dispatch/shape.py 94.82% <75.00%> (+6.36%) ⬆️
aesara/tensor/rewriting/jax.py 86.44% <86.44%> (ø)
aesara/link/jax/dispatch/scalar.py 96.72% <95.74%> (-0.69%) ⬇️
aesara/link/jax/dispatch/basic.py 92.59% <100.00%> (+8.72%) ⬆️
aesara/link/jax/dispatch/random.py 100.00% <100.00%> (ø)
aesara/link/jax/dispatch/subtensor.py 100.00% <100.00%> (+32.07%) ⬆️
aesara/link/jax/dispatch/tensor_basic.py 97.22% <100.00%> (+5.15%) ⬆️
aesara/tensor/basic.py 89.88% <0.00%> (-0.07%) ⬇️
tests/link/jax/test_subtensor.py
... and 3 more

@rlouf rlouf merged commit 67ea5b6 into aesara-devs:main Dec 15, 2022
@rlouf rlouf deleted the fix-jax-tests branch December 15, 2022 15:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working JAX Involves JAX transpilation
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant