
scipy: unpin jax #5

Closed

Conversation

lucascolley
Contributor

Allegedly we can revert this pin now!

I am GPU-less at the moment but can test later if needed.

@rgommers
Owner

rgommers commented Sep 2, 2024

> I am GPU-less at the moment but can test later if needed.

GPU-less too, but I'm fairly sure that this didn't work and things in conda-forge are still broken:

% pixi ls -e array-api-cuda --platform linux-64 | rg "cudnn|jax"
Environment: array-api-cuda
cudnn                                 8.9.7.29     h092f7fd_3                 446.6 MiB  conda  cudnn-8.9.7.29-h092f7fd_3.conda
jax                                   0.4.31       pyhd8ed1ab_0               1.3 MiB    conda  jax-0.4.31-pyhd8ed1ab_0.conda
jaxlib                                0.4.31       cuda120py312h4008524_200   89.2 MiB   conda  jaxlib-0.4.31-cuda120py312h4008524_200.conda

So let's not unpin unless we can confirm on a machine with CUDA that the JAX and PyTorch tests are actually passing.

@ngam

ngam commented Sep 2, 2024

@rgommers, there's likely a bug in the solver you're using (it is not supposed to pick up broken packages; it should just fail). It's being triggered by the fact that pytorch is required for the array-api-cuda env, which hasn't been rebuilt for the new cudnn yet (conda-forge/pytorch-cpu-feedstock#259).

@lucascolley
Contributor Author

@wolfv any idea why we're picking up a broken package rather than failing here?

@traversaro

Did you remove the pixi.lock and re-create it from scratch, or just update it without deleting it? If I recall correctly from a discussion with @baszalmstra, he mentioned that the solver tries as much as possible not to modify packages that are already in pixi.lock; perhaps there is some kind of bug that fails to avoid using a broken package if it is already in the pixi.lock file.

@@ -152,7 +152,7 @@ test-cupy = { cmd = "python dev.py test -b cupy", cwd = "scipy", env = { SCIPY_A
 platforms = ["linux-64", "osx-arm64"]

 [feature.jax.dependencies]
-jax = "=0.4.28"
+jax = "*"

@traversaro

I am afraid that this may install a cpu-only jax if there is a conflicting dependency (like a new jax version being available, but not for the cuda+cudnn version that pytorch requires, as happens now where the latest jax requires cudnn==9 and there is no cuda-enabled pytorch build compatible with it). I guess you can add jaxlib = {version = "*", build = "*cuda*"} to make sure that a cuda-enabled jax gets installed, see ami-iit/jaxsim#221 (comment) for a related discussion.
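
For illustration, a minimal sketch of that constraint inside a feature's dependency table in pixi.toml (the surrounding layout here is assumed, not copied from this repo):

[feature.jax.dependencies]
jax = "*"
# build-string glob so only CUDA-enabled jaxlib builds can be selected;
# note this would also affect environments that are meant to stay CPU-only
jaxlib = { version = "*", build = "*cuda*" }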

@rgommers
Owner


That by itself can't be the needed change (if a change is needed), because we also have environments where we want cpu-only jax. There are cpu and cuda features that follow the example in https://pixi.sh/latest/features/multi_environment/#real-world-example-use-cases. The environment definition then combines the features like:

env-name = ["cuda", "jax", "pytorch"]

If CUDA 12 is present in the environment, the pytorch/jax CUDA versions should be chosen. If that doesn't happen then there's indeed a problem, but I haven't seen that yet.
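
For reference, a rough sketch of that layout in pixi.toml, following the pixi docs example (feature and environment names here are illustrative, not copied verbatim from this repo):

[feature.cuda.system-requirements]
# tell pixi the system provides CUDA 12, gating selection of cuda build variants
cuda = "12"

[feature.jax.dependencies]
jax = "*"

[feature.pytorch.dependencies]
pytorch = "*"

[environments]
# the same jax/pytorch features are reused with and without the cuda feature
array-api = ["jax", "pytorch"]
array-api-cuda = ["cuda", "jax", "pytorch"]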

Do you happen to know if jax/jaxlib in conda-forge uses the same mutex as pytorch for selecting the CPU or CUDA build flavor?

@traversaro

> That by itself can't be the needed change (if a change is needed), because we also have environments where we want cpu-only jax.

My bad, I did not check the full pixi.toml. In that case the constraint needs to be added in an appropriate cuda-related feature, I guess.

> Do you happen to know if jax/jaxlib in conda-forge uses the same mutex as pytorch for selecting the CPU or CUDA build flavor?

I am afraid there is no mutex package for pytorch (in the sense of pytorch-mutex) on conda-forge. I guess for both jaxlib and pytorch the cuda build strings start with cuda.

@rgommers
Owner


Ah :( IIRC there was a mutex years ago, but conda-forge/pytorch-cpu-feedstock#247 discusses the current state and it's quite messy.

Adding the build string really doesn't compose well. I'll look into this more. What I've seen so far is that it works as advertised; probably pixi is relying on system-requirements = { cuda = "12" } to gate the selection of cuda build variants. I haven't seen it pick a cpu build yet with that present; however, it's not clear to me whether cuda/cpu is retained when there's a version conflict. Will do some more in-depth testing.

@traversaro

> Adding the build string really doesn't compose well.

Yes, I am not a big fan of that either, but unfortunately, as far as I know, there are not a lot of alternatives. Leaving aside pytorch, most packages do not provide any kind of mutex or similar to select the cuda versions, and even if the cuda versions have higher priority, you can still end up with the non-cuda version if the latest cuda-enabled builds are not compatible (as with pytorch and jax right now, due to the cudnn version).

> I'll look into this more. What I've seen so far is that it works as advertised; probably pixi is relying on system-requirements = { cuda = "12" } to gate the selection of cuda build variants. I haven't seen it pick a cpu build yet with that present; however, it's not clear to me whether cuda/cpu is retained when there's a version conflict.

I hope pixi does something super smart here, but I am afraid that in recent months we were just lucky: until a few weeks ago, the latest versions of both jaxlib and pytorch used cuda 12 and cudnn 8.

@lucascolley
Contributor Author

> Did you remove the pixi.lock and re-create it from scratch, or just update it without deleting it?

I updated it without deleting it.

@wolfv

wolfv commented Sep 2, 2024

@lucascolley a broken package in the sense that this package was moved to broken already? I don't think that we would ever be able to resolve one of those ... since they are not part of the repodata anymore.

@lucascolley
Contributor Author

> @lucascolley a broken package in the sense that this package was moved to broken already? I don't think that we would ever be able to resolve one of those ... since they are not part of the repodata anymore.

Yeah. I'll try deleting the lockfile and regenerating it, rather than just pixi update, as per the suggestion above.

@lucascolley
Contributor Author

After deleting and regenerating the lockfile, a CUDA version of jaxlib 0.4.31 is still picked up:

jax                                   0.4.31       pyhd8ed1ab_0               1.3 MiB    conda  jax-0.4.31-pyhd8ed1ab_0.conda
jaxlib                                0.4.31       cuda120py312h4008524_200   89.2 MiB   conda  jaxlib-0.4.31-cuda120py312h4008524_200.conda

Is that supposed to be marked as broken and impossible to pick up, @traversaro?

@ngam

ngam commented Sep 2, 2024

Okay. I think the "mark as broken" business didn't actually work :(

@ngam

ngam commented Sep 2, 2024

This PR corrects the labeling fiasco: conda-forge/admin-requests#1065

lucascolley marked this pull request as draft on September 2, 2024 at 17:48
@ngam

ngam commented Sep 3, 2024

Can you try again?

@lucascolley
Contributor Author

lucascolley commented Sep 3, 2024

pixi update now installs jaxlib 0.4.31 CPU, so the broken labelling is fixed 🙏 . As @rgommers said above, it looks like we'll need some additional way to force CUDA when version conflicts occur.

@lucascolley
Contributor Author

@traversaro are you thinking of something like fb8577b to force cuda?

@traversaro

> @traversaro are you thinking of something like fb8577b to force cuda?

Yes, but I guess you need something similar for each package you want to ensure uses the cuda variant (e.g. pytorch).

lucascolley marked this pull request as ready for review on September 10, 2024 at 15:30
@lucascolley
Contributor Author

I've removed the torch changes from this PR - it now just unpins JAX and splits the jax feature into jax-cpu and jax-cuda, each of which depends on jaxlib with the respective build string.
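
Roughly along these lines (a sketch of the split; the exact feature names and build-string globs in the actual pixi.toml may differ):

[feature.jax-cpu.dependencies]
jax = "*"
# CPU build of jaxlib for the non-CUDA environments
jaxlib = { version = "*", build = "*cpu*" }

[feature.jax-cuda.dependencies]
jax = "*"
# CUDA-enabled build of jaxlib for the CUDA environments
jaxlib = { version = "*", build = "*cuda*" }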

@traversaro

Just FYI, unfortunately the latest jaxlib on conda-forge is also not working on cuda, see conda-forge/jaxlib-feedstock#285 and conda-forge/jax-feedstock#162. There is some work in progress on fixing this in conda-forge/jaxlib-feedstock#288, but it still has problems.

@rgommers
Owner

The jax-cpu/jax-cuda split seems like a good idea. I'll try to test this locally; if it's all happy, I'll finally get this in.

I'll resolve the lock file conflict locally as well. My experience so far is that such conflicts are hard to avoid; they're sometimes easier to handle if the manifest and lock files are updated in separate commits, so you can drop one and regenerate.

@lucascolley
Contributor Author

lucascolley commented Nov 28, 2024

> sometimes easier to handle if the manifest and lock files are updated in separate commits, so you can drop one and regenerate.

ah, good point! certainly easier than git reset @~1 --soft

@rgommers
Owner

Grumble ... we need that GPU CI; testing manually that all envs work is getting a bit cumbersome.

Looks like it's all happy though. Thanks Lucas & everyone else who weighed in! Merged now, closing.

rgommers closed this on Nov 28, 2024
@lucascolley
Contributor Author

I'm back with a GPU and can confirm that removing this pin demonstrates the issues Silvio mentioned:

jaxlib                       0.4.34       cuda120py312h9b6c45b_200
...
In [2]: x = jnp.arange(2, device='cuda')
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

We could reinstate the pin, but it sounds like the fix might be ready!

@rgommers
Owner

rgommers commented Dec 9, 2024

Interesting. From the linked discussion it looks like this doesn't happen if a system-level CUDA is installed - that's the case for me, hence I didn't catch this in local testing.

It looks like we should wait until the fix is merged and then bump to the fixed version.
