Skip to content

Commit c9ced7d

Browse files
Tzung-Han Juangdime10erick-xanadu
authored
Update JAX & MLIR dependency chain to v0.4.28 (#931)
**Context:** We target at 0.4.28 instead of 0.4.30 because many bufferization passes are removed after [this llvm commit](llvm/llvm-project#93535). **Description of the Change:** ***Mandatory Updates:*** * Cmake - Remove `MhloShapeOpsToStandard` [link](tensorflow/mlir-hlo@57d2124) - Add `StablehloPasses` - Add `MhloQuantToIntConversion` (This will be removed after 0.4.29) - `EnzymeStatic-18` => `EnzymeStatic-19` - `RunnerUtils.h` requires `Float16Bits.h` [link](llvm/llvm-project@7bc6c4a) * LLVM - `setDataLayout` must happen before code generation or they will use the default one. [link](https://discord.com/channels/636084430946959380/636732535434510338/1265407221324451871) * MLIR - `updateRootInPlace` => `modifyOpInPlace` - `startRootUpdate` => `startOpModification` - `finalizeRootUpdate` => `finalizeOpModification` - The order of transformed mlir expressions is different and required fine-tuning for `CHECK-DAG`s. - Using rewriter's method to release MeasurementOps (Caused by LLVM commit-`b840d2968391dd610b792a65133a1edc1bcc397c`). [link](llvm/llvm-project@b840d29) - Allow `replaceTerminatorWithUnconditionalJumpToSuccessBlock` to accept `LLVM:br`. (New LLVM/MLIR will reuse `LLVM::UnreachableOp` and use `LLVM::br` to reach it.) * Frontend - `jax.linear_util` => `jax.extend.linear_util` [link](https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-24-feb-6-2024) - `jax_ctx.module_context.replace` => `jax_ctx.replace` - `gensym(jaxprs, suffix)` => `gensym(suffix)` [link](jax-ml/jax@67df647) - Move `name_stack` out of mlir.ModuleContext (functions like `lower_jaxpr_to_fun` is taking `name_stack` now). [link](jax-ml/jax#19856). - Pass `LoweringRuleContext.ModuleContext` instead of `LoweringRuleContext` to `jaxpr_subcomp` - Patch new `_sin_lowering` and `_cos_lowering` with `_nary_lower_hlo(sine/cosine)`. - Variable names (like `%0` => `%cst`) and orders in FileCheck. ***Deprecations/Warnings:*** * MLIR - `x.cast<T>()` => `mlir::cast<T>(x)` [link](https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443) - `x.dyn_cast<T>()` => `mlir::dyn_cast<T>(x)` [link](https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443) - `x.isa<T>()` => `mlir::isa<T>(x)` [link](https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443) * Enzyme - Add `DCMAKE_POLICY_DEFAULT_CMP0116` **Related GitHub Issues:** #863 [sc-67111] --------- Co-authored-by: David Ittah <dime10@users.noreply.github.com> Co-authored-by: erick-xanadu <110487834+erick-xanadu@users.noreply.github.com>
1 parent 5fa4b21 commit c9ced7d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+350
-267
lines changed

.dep-versions

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Always update the version check in catalyst.__init__ when changing the JAX version.
2-
jax=0.4.23
3-
mhlo=4611968a5f6818e6bdfb82217b9e836e0400bba9
4-
llvm=cd9a641613eddf25d4b25eaa96b2c393d401d42c
2+
jax=0.4.28
3+
mhlo=89a891c986650c33df76885f5620e0a92150d90f
4+
llvm=3a8316216807d64a586b971f51695e23883331f7
55
enzyme=v0.0.130
66

77
# Always remove custom PL/LQ versions before release.

doc/changelog.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@
170170
[(#945)](https://github.com/PennyLaneAI/catalyst/pull/945)
171171
[(#962)](https://github.com/PennyLaneAI/catalyst/pull/962)
172172

173+
* Update JAX to `v0.4.28`. [(#931)](https://github.com/PennyLaneAI/catalyst/pull/931)
174+
173175
<h3>Breaking changes</h3>
174176

175177
* Return values of qjit-compiled functions that were previously `numpy.ndarray` are now of type

frontend/catalyst/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
import jaxlib as _jaxlib
2525

26-
_jaxlib_version = "0.4.23"
26+
_jaxlib_version = "0.4.28"
2727
if _jaxlib.__version__ != _jaxlib_version:
2828
import warnings
2929

frontend/catalyst/compiler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,6 @@ def run_writing_command(command: List[str], compile_options: Optional[CompileOpt
165165
"func.func(chlo-legalize-to-hlo)",
166166
"stablehlo-legalize-to-hlo",
167167
"func.func(mhlo-legalize-control-flow)",
168-
"func.func(hlo-legalize-shapeops-to-standard)",
169168
"func.func(hlo-legalize-to-linalg)",
170169
"func.func(mhlo-legalize-to-std)",
171170
"func.func(hlo-legalize-sort)",

frontend/catalyst/jax_extras/lowering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ def custom_lower_jaxpr_to_module(
125125
backend_or_name=None,
126126
platforms=[platform],
127127
axis_context=axis_context,
128-
name_stack=name_stack,
129128
keepalives=keepalives,
130129
channel_iterator=channel_iter,
131130
host_callbacks=host_callbacks,
@@ -149,6 +148,7 @@ def custom_lower_jaxpr_to_module(
149148
replicated_args=replicated_args,
150149
arg_shardings=arg_shardings,
151150
result_shardings=result_shardings,
151+
name_stack=name_stack,
152152
)
153153

154154
for op in ctx.module.body.operations:

frontend/catalyst/jax_extras/patches.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,23 @@
1818
from __future__ import annotations
1919

2020
import jax
21+
from jax._src.lax.lax import _nary_lower_hlo
2122
from jax._src.lax.slicing import (
2223
_gather_shape_computation,
2324
_is_sorted,
2425
_no_duplicate_dims,
2526
_rank,
2627
_sorted_dims_in_range,
2728
)
29+
from jax._src.lib.mlir.dialects import hlo
2830
from jax.core import AbstractValue, Tracer, concrete_aval
2931

3032
__all__ = (
3133
"get_aval2",
3234
"_no_clean_up_dead_vars",
3335
"_gather_shape_rule_dynamic",
36+
"_sin_lowering2",
37+
"_cos_lowering2",
3438
)
3539

3640

@@ -180,3 +184,13 @@ def _gather_shape_rule_dynamic(
180184
)
181185

182186
return _gather_shape_computation(indices, dimension_numbers, slice_sizes)
187+
188+
189+
def _sin_lowering2(ctx, x):
190+
"""Use hlo.sine lowering instead of the new sin lowering from jax 0.4.28"""
191+
return _nary_lower_hlo(hlo.sine, ctx, x)
192+
193+
194+
def _cos_lowering2(ctx, x):
195+
"""Use hlo.cosine lowering instead of the new cosine lowering from jax 0.4.28"""
196+
return _nary_lower_hlo(hlo.cosine, ctx, x)

frontend/catalyst/jax_extras/tracing.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
trace_to_jaxpr_dynamic2,
4545
)
4646
from jax._src.lax.control_flow import _initial_style_jaxpr
47-
from jax._src.lax.lax import _abstractify
47+
from jax._src.lax.lax import _abstractify, cos_p, sin_p
4848
from jax._src.lax.slicing import (
4949
_argnum_weak_type,
5050
_gather_dtype_rule,
@@ -80,14 +80,14 @@
8080
new_jaxpr_eqn,
8181
thread_local_state,
8282
)
83+
from jax.extend.linear_util import transformation_with_aux, wrap_init
8384
from jax.interpreters.partial_eval import (
8485
DynamicJaxprTrace,
8586
DynamicJaxprTracer,
8687
convert_constvars_jaxpr,
8788
make_jaxpr_effects,
8889
)
8990
from jax.lax import convert_element_type
90-
from jax.linear_util import transformation_with_aux, wrap_init
9191
from jax.tree_util import (
9292
PyTreeDef,
9393
tree_flatten,
@@ -97,7 +97,12 @@
9797
)
9898
from jaxlib.xla_extension import PyTreeRegistry
9999

100-
from catalyst.jax_extras.patches import _gather_shape_rule_dynamic, get_aval2
100+
from catalyst.jax_extras.patches import (
101+
_cos_lowering2,
102+
_gather_shape_rule_dynamic,
103+
_sin_lowering2,
104+
get_aval2,
105+
)
101106
from catalyst.logging import debug_logger
102107
from catalyst.tracing.type_signatures import verify_static_argnums_type
103108
from catalyst.utils.patching import Patcher
@@ -288,7 +293,7 @@ def __init__(self, boxid: int, e: JaxprEqn):
288293
def jaxpr_pad_consts(jaxprs: List[Jaxpr]) -> List[ClosedJaxpr]:
289294
"""Align the constants of Jaxpr programs. Return the list of corresponding programs accepting
290295
the same constants."""
291-
newvar = gensym(jaxprs, suffix="_")
296+
newvar = gensym("_")
292297

293298
# List of constant variables of all jaxprs, preprended with '_'
294299
all_mangled_constvars: List[List[Var]] = []
@@ -519,6 +524,10 @@ def abstractify(args, kwargs):
519524
)
520525
register_lowering(gather2_p, _gather_lower)
521526

527+
# TBD
528+
register_lowering(sin_p, _sin_lowering2)
529+
register_lowering(cos_p, _cos_lowering2)
530+
522531
primitive_batchers2 = jax._src.interpreters.batching.primitive_batchers.copy()
523532
for primitive in jax._src.interpreters.batching.primitive_batchers.keys():
524533
if primitive.name == "gather":
@@ -532,6 +541,8 @@ def make_jaxpr_f(*args, **kwargs):
532541
(jax._src.interpreters.partial_eval, "get_aval", get_aval2),
533542
(jax._src.lax.slicing, "gather_p", gather2_p),
534543
(jax._src.interpreters.batching, "primitive_batchers", primitive_batchers2),
544+
(jax._src.lax.lax, "_sin_lowering", _sin_lowering2),
545+
(jax._src.lax.lax, "_cos_lowering", _cos_lowering2),
535546
), ExitStack():
536547
f = wrap_init(fun)
537548
if static_argnums:

frontend/catalyst/jax_primitives.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -330,8 +330,8 @@ def _python_callback_lowering(
330330
fwd_jaxpr = custom_grad._fwd_jaxpr
331331
rev_jaxpr = custom_grad._bwd_jaxpr
332332
ctx = jax_ctx.module_context
333-
mlir_fwd = _func_def_lowering(ctx, call_jaxpr=fwd_jaxpr, fn=fwd)
334-
mlir_rev = _func_def_lowering(ctx, call_jaxpr=rev_jaxpr, fn=rev)
333+
mlir_fwd = _func_def_lowering(ctx, call_jaxpr=fwd_jaxpr, fn=fwd, name_stack=jax_ctx.name_stack)
334+
mlir_rev = _func_def_lowering(ctx, call_jaxpr=rev_jaxpr, fn=rev, name_stack=jax_ctx.name_stack)
335335
sym_fwd = mlir_fwd.sym_name.value + ".fwd"
336336

337337
argc = len(args)
@@ -397,11 +397,11 @@ def _func_def_impl(ctx, *args, call_jaxpr, fn, call=True): # pragma: no cover
397397
raise NotImplementedError()
398398

399399

400-
def _func_def_lowering(ctx, fn, call_jaxpr) -> str:
400+
def _func_def_lowering(ctx, fn, call_jaxpr, name_stack) -> str:
401401
"""Create a func::FuncOp from JAXPR."""
402402
if isinstance(call_jaxpr, core.Jaxpr):
403403
call_jaxpr = core.ClosedJaxpr(call_jaxpr, ())
404-
func_op = mlir.lower_jaxpr_to_fun(ctx, fn.__name__, call_jaxpr, tuple())
404+
func_op = mlir.lower_jaxpr_to_fun(ctx, fn.__name__, call_jaxpr, tuple(), name_stack=name_stack)
405405

406406
if isinstance(fn, qml.QNode):
407407
func_op.attributes["qnode"] = ir.UnitAttr.get()
@@ -443,7 +443,7 @@ def _func_lowering(ctx, *args, call_jaxpr, fn, call=True):
443443
if fn in mlir_fn_cache:
444444
func_op = mlir_fn_cache[fn]
445445
else:
446-
func_op = _func_def_lowering(ctx.module_context, fn, call_jaxpr)
446+
func_op = _func_def_lowering(ctx.module_context, fn, call_jaxpr, name_stack=ctx.name_stack)
447447
mlir_fn_cache[fn] = func_op
448448

449449
symbol_name = func_op.name.value
@@ -1585,14 +1585,13 @@ def emit_branches(preds, branch_jaxprs, ip):
15851585

15861586
# if block
15871587
source_info_util.extend_name_stack("if")
1588-
if_ctx = jax_ctx.module_context.replace(
1589-
name_stack=jax_ctx.module_context.name_stack.extend("if")
1590-
)
1588+
if_ctx = jax_ctx.replace(name_stack=jax_ctx.name_stack.extend("if"))
15911589
with ir.InsertionPoint(if_block):
15921590
# recursively generate the mlir for the if block
15931591
out = mlir.jaxpr_subcomp(
1594-
if_ctx,
1592+
if_ctx.module_context,
15951593
true_jaxpr.jaxpr,
1594+
if_ctx.name_stack,
15961595
mlir.TokenSet(),
15971596
[mlir.ir_constants(c) for c in true_jaxpr.consts],
15981597
*([a] for a in flat_args_plus_consts), # fn expects [a1], [a2], [a3] format
@@ -1603,17 +1602,16 @@ def emit_branches(preds, branch_jaxprs, ip):
16031602

16041603
# else block
16051604
source_info_util.extend_name_stack("else")
1606-
else_ctx = jax_ctx.module_context.replace(
1607-
name_stack=jax_ctx.module_context.name_stack.extend("else")
1608-
)
1605+
else_ctx = jax_ctx.replace(name_stack=jax_ctx.name_stack.extend("else"))
16091606
else_block = if_op_scf.else_block
16101607
if len(preds) == 1:
16111608
# Base case: reached the otherwise block
16121609
otherwise_jaxpr = branch_jaxprs[-1]
16131610
with ir.InsertionPoint(else_block):
16141611
out = mlir.jaxpr_subcomp(
1615-
else_ctx,
1612+
else_ctx.module_context,
16161613
otherwise_jaxpr.jaxpr,
1614+
else_ctx.name_stack,
16171615
mlir.TokenSet(),
16181616
[mlir.ir_constants(c) for c in otherwise_jaxpr.consts],
16191617
*([a] for a in flat_args_plus_consts),
@@ -1695,15 +1693,16 @@ def _while_loop_lowering(
16951693

16961694
# cond block
16971695
cond_block = while_op_scf.regions[0].blocks.append(*loop_carry_types)
1698-
name_stack = jax_ctx.module_context.name_stack.extend("while")
1699-
cond_ctx = jax_ctx.module_context.replace(name_stack=name_stack.extend("cond"))
1696+
name_stack = jax_ctx.name_stack.extend("while")
1697+
cond_ctx = jax_ctx.replace(name_stack=name_stack.extend("cond"))
17001698
with ir.InsertionPoint(cond_block):
17011699
cond_args = [cond_block.arguments[i] for i in range(len(loop_carry_types))]
17021700

17031701
# recursively generate the mlir for the while cond
17041702
((pred,),), _ = mlir.jaxpr_subcomp(
1705-
cond_ctx,
1703+
cond_ctx.module_context,
17061704
cond_jaxpr.jaxpr,
1705+
cond_ctx.name_stack,
17071706
mlir.TokenSet(),
17081707
[mlir.ir_constants(c) for c in cond_jaxpr.consts],
17091708
*([a] for a in (cond_consts + cond_args)), # fn expects [a1], [a2], [a3] format
@@ -1715,14 +1714,15 @@ def _while_loop_lowering(
17151714

17161715
# body block
17171716
body_block = while_op_scf.regions[1].blocks.append(*loop_carry_types)
1718-
body_ctx = jax_ctx.module_context.replace(name_stack=name_stack.extend("body"))
1717+
body_ctx = jax_ctx.replace(name_stack=name_stack.extend("body"))
17191718
with ir.InsertionPoint(body_block):
17201719
body_args = [body_block.arguments[i] for i in range(len(loop_carry_types))]
17211720

17221721
# recursively generate the mlir for the while body
17231722
out, _ = mlir.jaxpr_subcomp(
1724-
body_ctx,
1723+
body_ctx.module_context,
17251724
body_jaxpr.jaxpr,
1725+
body_ctx.name_stack,
17261726
mlir.TokenSet(),
17271727
[mlir.ir_constants(c) for c in cond_jaxpr.consts],
17281728
*([a] for a in (body_consts + body_args)), # fn expects [a1], [a2], [a3] format
@@ -1831,9 +1831,9 @@ def _cast_to_index(p):
18311831

18321832
for_op_scf = ForOp(lower_bound, upper_bound, step, iter_args=loop_args)
18331833

1834-
name_stack = jax_ctx.module_context.name_stack.extend("for")
1834+
name_stack = jax_ctx.name_stack.extend("for")
18351835
body_block = for_op_scf.body
1836-
body_ctx = jax_ctx.module_context.replace(name_stack=name_stack.extend("body"))
1836+
body_ctx = jax_ctx.replace(name_stack=name_stack.extend("body"))
18371837

18381838
with ir.InsertionPoint(body_block):
18391839
body_args = list(body_block.arguments)
@@ -1858,8 +1858,9 @@ def _cast_to_index(p):
18581858

18591859
# Recursively generate the mlir for the loop body
18601860
out, _ = mlir.jaxpr_subcomp(
1861-
body_ctx,
1861+
body_ctx.module_context,
18621862
body_jaxpr.jaxpr,
1863+
body_ctx.name_stack,
18631864
mlir.TokenSet(),
18641865
[mlir.ir_constants(c) for c in body_jaxpr.consts],
18651866
*body_args,
@@ -1935,10 +1936,9 @@ def _adjoint_lowering(
19351936
with ir.InsertionPoint(adjoint_block):
19361937
source_info_util.extend_name_stack("adjoint")
19371938
out, _ = mlir.jaxpr_subcomp(
1938-
jax_ctx.module_context.replace(
1939-
name_stack=jax_ctx.module_context.name_stack.extend("adjoint")
1940-
),
1939+
jax_ctx.module_context,
19411940
jaxpr.jaxpr,
1941+
jax_ctx.name_stack.extend("adjoint"),
19421942
mlir.TokenSet(),
19431943
[mlir.ir_constants(c) for c in jaxpr.consts],
19441944
*([a] for a in chain(consts, cargs, adjoint_block.arguments)), # [3]

0 commit comments

Comments
 (0)