-
Notifications
You must be signed in to change notification settings - Fork 59
Bump jax to 0.7.0 #2131
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Bump jax to 0.7.0 #2131
Conversation
…n of dependencies
….0.dev20250703+cd1b9520b
…alyst into rniczh/bump-jax-to-0.7.0
Co-authored-by: Yushao Chen (Jerry) <chenys13@outlook.com>
|
Fantastic, now even Codecov passes! @rniczh |
**Context:** Basically, for Catalyst, the new `bump-jax-patching` branch means there's no patching happenning by default anymore. In this way, we can better separate the impact of patching, instead of causing inheritance of patching from upstream to downstream. We could also better observe what patches are actually needed for other packages depending on Pennylane. **Description of the Change:** **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:**
|
Now Catalyst's jax bumping branch should be independent of whether the PL Jax patches are global or local. |
JerryChen97
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hold on temporarily until the PennyLane PR is also ready-to-merge
|
@rniczh Merge? |
Context:
Description of the Change:
lower_jaxpr_to_funhas changed, replacingpublicwithmain_function. Now, custom function names (e.g.,jit_…) were not permitted, withmainbeing the default. To address this,main_functionis now set to False, whilesym_visibilityis set to True for MLIR, retaining thejit_funcfunction name. Additionally,name_stackhas been removed. Do not propagate name_stacks into lower_jaxpr_to_fun. jax-ml/jax#29783make_varandget_varhas been removed. Should set the var (val) when initializing theDynamicJaxprTracer. Avoid strong refs to tracers in DynamicJaxprTrace. jax-ml/jax#29937. There is a typo in jax Fix typo in pjit.py jax-ml/jax#32718, we will patch the installation of jax in this PR.flatten_lowering_ir_argswithflatten_ir_values[JAX] Remove jax.interpreters.mlir.flatten_lowering_ir_args. jax-ml/jax#29706TracingEqnand replaceframe.eqns(originally usingJaxprEqn) withframe.tracing_eqns. roll-forward with fixes jax-ml/jax#30135pjit_phas been removed, usejit_pinsteadbin/patch_jax_installation.py)frontend/catalyst/jax_extras/patches.py)Benefits:
Possible Drawbacks:
Related GitHub Issues: