scipy: unpin jax #5
Conversation
GPU-less too, but I'm fairly sure that this didn't work and things in conda-forge are still broken:
So let's not unpin unless we can confirm on a machine with CUDA that the JAX and PyTorch tests are actually passing.
@rgommers, there's likely a bug in the solver you're using (it is not supposed to pick up broken packages; it should just fail), and it is triggered by the fact that pytorch is required for the array-api-cuda env (which hasn't been rebuilt for the new cudnn yet, conda-forge/pytorch-cpu-feedstock#259).
@wolfv any idea why we're picking up a broken package rather than failing here?
Did you remove the lockfile?
```diff
@@ -152,7 +152,7 @@ test-cupy = { cmd = "python dev.py test -b cupy", cwd = "scipy", env = { SCIPY_A
 platforms = ["linux-64", "osx-arm64"]

 [feature.jax.dependencies]
-jax = "=0.4.28"
+jax = "*"
```
I am afraid that this may install a cpu-only jax if there is a conflicting dependency (e.g. a new jax version is available, but not for the cuda+cudnn combination that pytorch requires, as happens right now: the latest jax requires cudnn==9 and there is no CUDA-enabled pytorch build compatible with it). I guess you can add `jaxlib = { version = "*", build = "*cuda*" }` to make sure that a CUDA-enabled jax gets installed; see ami-iit/jaxsim#221 (comment) for a related discussion.
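For concreteness, the suggested constraint would look roughly like this in the manifest (a sketch of the idea only, not the actual scipy pixi.toml):

```toml
# Sketch of the suggestion above (illustrative, not the actual scipy pixi.toml):
# leave jax itself unpinned, but restrict jaxlib to builds whose build string
# contains "cuda", so the solver cannot silently fall back to a cpu-only build.
[feature.jax.dependencies]
jax = "*"
jaxlib = { version = "*", build = "*cuda*" }
```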
That by itself can't be the needed change (if a change is needed), because we also have environments where we want cpu-only jax. There are `cpu` and `cuda` features that follow the example in https://pixi.sh/latest/features/multi_environment/#real-world-example-use-cases. The environment definition then combines the features like:

`env-name = ["cuda", "jax", "pytorch"]`

If CUDA 12 is present in the environment, the pytorch/jax CUDA versions should be chosen. If that doesn't happen then there's indeed a problem, but I haven't seen that yet.
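For readers unfamiliar with that pattern, a minimal layout along those lines could look roughly like this (a sketch with illustrative names and specs, not the actual scipy pixi.toml):

```toml
# Minimal sketch of the cpu/cuda multi-environment pattern (illustrative only).
[feature.cuda.system-requirements]
cuda = "12"                      # only solvable on machines providing CUDA 12

[feature.jax.dependencies]
jax = "*"

[feature.pytorch.dependencies]
pytorch = "*"

[environments]
array-api = ["jax", "pytorch"]               # cpu-only environment
array-api-cuda = ["cuda", "jax", "pytorch"]  # CUDA 12 environment
```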
Do you happen to know if `jax`/`jaxlib` in conda-forge uses the same mutex as `pytorch` for selecting the CPU or CUDA build flavor?
> That by itself can't be the needed change (if a change is needed), because we also have environments where we want cpu-only jax.

My bad, I did not check the full pixi.toml. In that case the constraint needs to be added in an appropriate cuda-related feature, I guess.

> Do you happen to know if jax/jaxlib in conda-forge uses the same mutex as pytorch for selecting the CPU or CUDA build flavor?

I am afraid there is no mutex package for pytorch (in the sense of `pytorch-mutex`) on conda-forge. I guess for both jaxlib and pytorch the cuda build strings start with `cuda`.
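Concretely, moving the constraint into a cuda-related feature could look something like this (a sketch assuming a `cuda` feature as in the layout above; the specs are illustrative):

```toml
# Sketch: keep the build-string constraints inside the cuda feature only,
# so cpu-only environments still resolve CPU builds (illustrative specs).
[feature.cuda.dependencies]
jaxlib = { version = "*", build = "*cuda*" }
pytorch = { version = "*", build = "*cuda*" }
```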
Ah :( IIRC there was a mutex years ago, but conda-forge/pytorch-cpu-feedstock#247 discusses the current state and it's quite messy.

Adding the build string really doesn't compose well. I'll look into this more. What I've seen so far is that it works as advertised; probably pixi is relying on `system-requirements = { cuda = "12" }` to gate selecting `cuda` build variants. I haven't seen it pick a `cpu` build yet with that present; however, it's not clear to me whether cuda/cpu is retained when there's a version conflict. Will do some more in-depth testing.
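For reference, that gate is just a feature-level system requirement; the inline form above and the sub-table form are equivalent spellings (illustrative sketch, not the actual manifest):

```toml
# The CUDA gate as an inline table on the feature...
[feature.cuda]
system-requirements = { cuda = "12" }

# ...which is equivalent to the sub-table form:
# [feature.cuda.system-requirements]
# cuda = "12"
```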
> Adding the build string really doesn't compose well.

Yes, I am not a big fan of that either, but unfortunately, as far as I know, there are not many alternatives. Leaving aside pytorch, most packages do not provide any kind of mutex or similar mechanism to select the cuda variants, and even if the cuda variants have higher priority, you can still end up with the non-cuda version if the latest cuda-enabled builds are not compatible with each other (as with pytorch and jax right now, due to the cudnn version).

> I'll look into this more. What I've seen so far is that it works as advertised; probably pixi is relying on system-requirements = { cuda = "12" } to gate selecting cuda build variants. I haven't seen it pick a cpu build yet with that present; however, it's not clear to me whether cuda/cpu is retained when there's a version conflict.

I hope pixi does something super smart here, but I am afraid that in the last few months we were just lucky: until a few weeks ago the latest versions of jaxlib and pytorch both used cuda 12 and cudnn 8.
I updated it without deleting it.
@lucascolley a broken package in the sense that this package was moved to broken already? I don't think that we would ever be able to resolve one of those ... since they are not part of the repodata anymore.
Yeah. I'll try deleting the lockfile and regenerating it, rather than just updating it.
After deleting and regenerating the lockfile, a CUDA version of jaxlib 0.4.31 is still picked up:
Is that supposed to be marked as broken and impossible to pick up, @traversaro?
Okay. I think the "mark as broken" business didn't actually work :(
This PR corrects the labeling fiasco: conda-forge/admin-requests#1065
Can you try again?
Force-pushed from fa97c9b to 3396315.
@traversaro are you thinking of something like fb8577b to force cuda?
Yes, but I guess you need something similar for each package you want to ensure uses the cuda variant (i.e. pytorch).
Force-pushed from 110417b to 9b9b591.
Force-pushed from 9b9b591 to e19db74.
I've removed the …
Just FYI, unfortunately the latest jaxlib on conda-forge is also not working on cuda; see conda-forge/jaxlib-feedstock#285 and conda-forge/jax-feedstock#162. There is work in progress on fixing this in conda-forge/jaxlib-feedstock#288, but it still has problems.
I'll resolve the lock file conflict locally as well. My experience so far is that such conflicts are hard to avoid; they're sometimes easier to handle if the manifest and lock files are updated in separate commits, so you can drop one and regenerate.
ah, good point! certainly easier than resolving the lock file conflict manually.
Grumble ... we need that GPU CI; testing manually that all envs work is getting a bit cumbersome. Looks like it's all happy though. Thanks Lucas & everyone else who weighed in! Merged now, closing.
I'm back with a GPU and can confirm that removing this pin demonstrates the issues Silvio mentioned:
We could reinstate the pin, but it sounds like the fix might be ready! |
Interesting. From the linked discussion it looks like this doesn't happen if a system-level CUDA is installed - that's the case for me, hence I didn't catch this in local testing. It looks like we should wait until the fix is merged and then bump to the fixed version.
Allegedly we can revert this pin now!
I am GPU-less at the moment but can test later if needed.