diff --git a/.dep-versions b/.dep-versions
index 5b6bea4912..aecf39d601 100644
--- a/.dep-versions
+++ b/.dep-versions
@@ -1,7 +1,7 @@
# Always update the version check in catalyst.__init__ when changing the JAX version.
# To update JAX version alongside compatible dependency tags, run the following script:
# python3 .github/workflows/set_dep_versions.py {JAX_version}
-jax=0.6.2
+jax=0.7.0
stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d
llvm=113f01aa82d055410f22a9d03b3468fa68600589
enzyme=v0.0.203
@@ -10,7 +10,9 @@ enzyme=v0.0.203
# For a custom PL version, update the package version here and at
# 'doc/requirements.txt'
-pennylane=0.44.0.dev31
+# TODO: uncomment and update to latest version of pennylane
+# after https://github.com/PennyLaneAI/pennylane/pull/8525 is merged.
+# pennylane=0.44.0.dev31
# For a custom LQ/LK version, update the package version here and at
# 'doc/requirements.txt'
diff --git a/Makefile b/Makefile
index d0ac210b8d..bd921e823f 100644
--- a/Makefile
+++ b/Makefile
@@ -117,6 +117,8 @@ frontend:
# versions of a package with the same version tag (e.g. 0.38-dev0).
$(PYTHON) -m pip uninstall -y pennylane
$(PYTHON) -m pip install -e . --extra-index-url https://test.pypi.org/simple $(PIP_VERBOSE_FLAG)
+ # TODO: remove after https://github.com/PennyLaneAI/pennylane/pull/8525 is merged.
+ $(PYTHON) -m pip install git+https://github.com/PennyLaneAI/pennylane@bump-jax-api-hashability
rm -r frontend/pennylane_catalyst.egg-info
.PHONY: mlir llvm stablehlo enzyme dialects runtime oqc
@@ -134,7 +136,7 @@ enzyme:
dialects:
$(MAKE) -C mlir dialects
-
+
.PHONY: dialect-docs
dialect-docs:
$(MAKE) -C mlir dialect-docs
diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 46e4f31da7..f653af2aa6 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -69,6 +69,13 @@
Improvements ðŸ›
+* Remove the hardcoded list of runtime operations in the frontend.
+ This will allow arbitrary PL gates to be represented without hyperparameters in MLIR.
+ For gates that do not have a QIR representation, a runtime error will be raised at execution.
+ Users can still decompose these gates via `qml.transforms.decompose`
+ when both capture and graph-decomposition are enabled.
+ [(#2215)](https://github.com/PennyLaneAI/catalyst/pull/2215)
+
* `qml.PCPhase` can be compiled and executed with capture enabled.
[(#2226)](https://github.com/PennyLaneAI/catalyst/pull/2226)
diff --git a/frontend/catalyst/__init__.py b/frontend/catalyst/__init__.py
index eeaf6aa1be..4a50fb5859 100644
--- a/frontend/catalyst/__init__.py
+++ b/frontend/catalyst/__init__.py
@@ -18,12 +18,11 @@
# pylint: disable=wrong-import-position
import sys
-import types
from os.path import dirname
import jaxlib as _jaxlib
-_jaxlib_version = "0.6.2"
+_jaxlib_version = "0.7.0"
if _jaxlib.__version__ != _jaxlib_version:
import warnings
diff --git a/frontend/catalyst/autograph/ag_primitives.py b/frontend/catalyst/autograph/ag_primitives.py
index 61c0a06839..bed3c5cbfd 100644
--- a/frontend/catalyst/autograph/ag_primitives.py
+++ b/frontend/catalyst/autograph/ag_primitives.py
@@ -62,7 +62,7 @@ def get_program_length(reference_tracers):
if EvaluationContext.is_tracing(): # pragma: no branch
jaxpr_frame = EvaluationContext.find_jaxpr_frame(reference_tracers)
- num_jaxpr_eqns = len(jaxpr_frame.eqns)
+ num_jaxpr_eqns = len(jaxpr_frame.tracing_eqns)
if EvaluationContext.is_quantum_tracing():
quantum_queue = EvaluationContext.find_quantum_queue()
@@ -79,8 +79,8 @@ def reset_program_to_length(reference_tracers, num_jaxpr_eqns, num_tape_ops):
if EvaluationContext.is_tracing(): # pragma: no branch
jaxpr_frame = EvaluationContext.find_jaxpr_frame(reference_tracers)
- while len(jaxpr_frame.eqns) > num_jaxpr_eqns:
- jaxpr_frame.eqns.pop()
+ while len(jaxpr_frame.tracing_eqns) > num_jaxpr_eqns:
+ jaxpr_frame.tracing_eqns.pop()
if EvaluationContext.is_quantum_tracing():
quantum_queue = EvaluationContext.find_quantum_queue()
diff --git a/frontend/catalyst/device/qjit_device.py b/frontend/catalyst/device/qjit_device.py
index b4f92bf022..848a638d48 100644
--- a/frontend/catalyst/device/qjit_device.py
+++ b/frontend/catalyst/device/qjit_device.py
@@ -56,44 +56,6 @@
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
-RUNTIME_OPERATIONS = [
- "CNOT",
- "ControlledPhaseShift",
- "CRot",
- "CRX",
- "CRY",
- "CRZ",
- "CSWAP",
- "CY",
- "CZ",
- "Hadamard",
- "Identity",
- "IsingXX",
- "IsingXY",
- "IsingYY",
- "IsingZZ",
- "SingleExcitation",
- "DoubleExcitation",
- "ISWAP",
- "MultiRZ",
- "PauliX",
- "PauliY",
- "PauliZ",
- "PCPhase",
- "PhaseShift",
- "PSWAP",
- "QubitUnitary",
- "Rot",
- "RX",
- "RY",
- "RZ",
- "S",
- "SWAP",
- "T",
- "Toffoli",
- "GlobalPhase",
-]
-
RUNTIME_OBSERVABLES = [
"Identity",
"PauliX",
@@ -109,11 +71,9 @@
RUNTIME_MPS = ["ExpectationMP", "SampleMP", "VarianceMP", "CountsMP", "StateMP", "ProbabilityMP"]
-# The runtime interface does not care about specific gate properties, so set them all to True.
-RUNTIME_OPERATIONS = {
- op: OperatorProperties(invertible=True, controllable=True, differentiable=True)
- for op in RUNTIME_OPERATIONS
-}
+# A list of custom operations supported by the Catalyst compiler.
+# This is useful especially for testing a device with custom operations.
+CUSTOM_OPERATIONS = {}
RUNTIME_OBSERVABLES = {
obs: OperatorProperties(invertible=True, controllable=True, differentiable=True)
@@ -199,6 +159,14 @@ def extract_backend_info(device: qml.devices.QubitDevice) -> BackendInfo:
return BackendInfo(dname, device_name, device_lpath, device_kwargs)
+def union_operations(
+ a: Dict[str, OperatorProperties], b: Dict[str, OperatorProperties]
+) -> Dict[str, OperatorProperties]:
+ """Union of two sets of operator properties"""
+ return {**a, **b}
+ # return {k: a[k] & b[k] for k in (a.keys() & b.keys())}
+
+
def intersect_operations(
a: Dict[str, OperatorProperties], b: Dict[str, OperatorProperties]
) -> Dict[str, OperatorProperties]:
@@ -223,8 +191,8 @@ def get_qjit_device_capabilities(target_capabilities: DeviceCapabilities) -> Dev
qjit_capabilities = deepcopy(target_capabilities)
# Intersection of gates and observables supported by the device and by Catalyst runtime.
- qjit_capabilities.operations = intersect_operations(
- target_capabilities.operations, RUNTIME_OPERATIONS
+ qjit_capabilities.operations = union_operations(
+ target_capabilities.operations, CUSTOM_OPERATIONS
)
qjit_capabilities.observables = intersect_operations(
target_capabilities.observables, RUNTIME_OBSERVABLES
diff --git a/frontend/catalyst/from_plxpr/control_flow.py b/frontend/catalyst/from_plxpr/control_flow.py
index 1569daf60b..1a42c3bdc9 100644
--- a/frontend/catalyst/from_plxpr/control_flow.py
+++ b/frontend/catalyst/from_plxpr/control_flow.py
@@ -25,7 +25,11 @@
from pennylane.capture.primitives import for_loop_prim as plxpr_for_loop_prim
from pennylane.capture.primitives import while_loop_prim as plxpr_while_loop_prim
-from catalyst.from_plxpr.from_plxpr import PLxPRToQuantumJaxprInterpreter, WorkflowInterpreter
+from catalyst.from_plxpr.from_plxpr import (
+ PLxPRToQuantumJaxprInterpreter,
+ WorkflowInterpreter,
+ _tuple_to_slice,
+)
from catalyst.from_plxpr.qubit_handler import (
QubitHandler,
QubitIndexRecorder,
@@ -101,8 +105,13 @@ def _to_bool_if_not(arg):
@WorkflowInterpreter.register_primitive(plxpr_cond_prim)
def workflow_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice):
- """Handle the conversion from plxpr to Catalyst jaxpr for the cond primitive"""
- args = plxpr_invals[args_slice]
+ """Handle the conversion from plxpr to Catalyst jaxpr for the cond primitive
+
+ Args:
+ consts_slices: List of tuples (start, stop, step) to slice consts for each branch
+ args_slice: Tuple (start, stop, step) to slice args from plxpr_invals
+ """
+ args = plxpr_invals[_tuple_to_slice(args_slice)]
converted_jaxpr_branches = []
all_consts = []
@@ -110,7 +119,7 @@ def workflow_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice
for const_slice, plxpr_branch in zip(consts_slices, jaxpr_branches):
# Store all branches consts in a flat list
- branch_consts = plxpr_invals[const_slice]
+ branch_consts = plxpr_invals[_tuple_to_slice(const_slice)]
evaluator = partial(copy(self).eval, plxpr_branch, branch_consts)
new_jaxpr = jax.make_jaxpr(evaluator)(*args)
@@ -132,8 +141,13 @@ def workflow_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice
@PLxPRToQuantumJaxprInterpreter.register_primitive(plxpr_cond_prim)
def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice):
- """Handle the conversion from plxpr to Catalyst jaxpr for the cond primitive"""
- args = plxpr_invals[args_slice]
+ """Handle the conversion from plxpr to Catalyst jaxpr for the cond primitive
+
+ Args:
+ consts_slices: List of tuples (start, stop, step) to slice consts for each branch
+ args_slice: Tuple (start, stop, step) to slice args from plxpr_invals
+ """
+ args = plxpr_invals[_tuple_to_slice(args_slice)]
self.init_qreg.insert_all_dangling_qubits()
dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs(
@@ -154,7 +168,7 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice):
for const_slice, plxpr_branch in zip(consts_slices, jaxpr_branches):
# Store all branches consts in a flat list
- branch_consts = plxpr_invals[const_slice]
+ branch_consts = plxpr_invals[_tuple_to_slice(const_slice)]
converted_jaxpr_branch = None
closed_jaxpr = ClosedJaxpr(plxpr_branch, branch_consts)
@@ -205,11 +219,17 @@ def workflow_for_loop(
args_slice,
abstract_shapes_slice,
):
- """Handle the conversion from plxpr to Catalyst jaxpr for the for loop primitive"""
+ """Handle the conversion from plxpr to Catalyst jaxpr for the for loop primitive
+
+ Args:
+ consts_slice: Tuple (start, stop, step) to slice consts from plxpr_invals
+ args_slice: Tuple (start, stop, step) to slice args from plxpr_invals
+ abstract_shapes_slice: Tuple (start, stop, step) to slice abstract shapes
+ """
assert jaxpr_body_fn is not None
- args = plxpr_invals[args_slice]
+ args = plxpr_invals[_tuple_to_slice(args_slice)]
- consts = plxpr_invals[consts_slice]
+ consts = plxpr_invals[_tuple_to_slice(consts_slice)]
converter = copy(self)
evaluator = partial(converter.eval, jaxpr_body_fn, consts)
@@ -250,9 +270,15 @@ def handle_for_loop(
args_slice,
abstract_shapes_slice,
):
- """Handle the conversion from plxpr to Catalyst jaxpr for the for loop primitive"""
+ """Handle the conversion from plxpr to Catalyst jaxpr for the for loop primitive
+
+ Args:
+ consts_slice: Tuple (start, stop, step) to slice consts from plxpr_invals
+ args_slice: Tuple (start, stop, step) to slice args from plxpr_invals
+ abstract_shapes_slice: Tuple (start, stop, step) to slice abstract shapes
+ """
assert jaxpr_body_fn is not None
- args = plxpr_invals[args_slice]
+ args = plxpr_invals[_tuple_to_slice(args_slice)]
# Add the iteration start and the qreg to the args
self.init_qreg.insert_all_dangling_qubits()
@@ -268,7 +294,7 @@ def handle_for_loop(
self.init_qreg.get(),
]
- consts = plxpr_invals[consts_slice]
+ consts = plxpr_invals[_tuple_to_slice(consts_slice)]
jaxpr = ClosedJaxpr(jaxpr_body_fn, consts)
@@ -326,10 +352,16 @@ def workflow_while_loop(
cond_slice,
args_slice,
):
- """Handle the conversion from plxpr to Catalyst jaxpr for the while loop primitive"""
- consts_body = plxpr_invals[body_slice]
- consts_cond = plxpr_invals[cond_slice]
- args = plxpr_invals[args_slice]
+ """Handle the conversion from plxpr to Catalyst jaxpr for the while loop primitive
+
+ Args:
+ body_slice: Tuple (start, stop, step) to slice body consts from plxpr_invals
+ cond_slice: Tuple (start, stop, step) to slice cond consts from plxpr_invals
+ args_slice: Tuple (start, stop, step) to slice args from plxpr_invals
+ """
+ consts_body = plxpr_invals[_tuple_to_slice(body_slice)]
+ consts_cond = plxpr_invals[_tuple_to_slice(cond_slice)]
+ args = plxpr_invals[_tuple_to_slice(args_slice)]
evaluator_body = partial(copy(self).eval, jaxpr_body_fn, consts_body)
new_body_jaxpr = jax.make_jaxpr(evaluator_body)(*args)
@@ -367,14 +399,20 @@ def handle_while_loop(
cond_slice,
args_slice,
):
- """Handle the conversion from plxpr to Catalyst jaxpr for the while loop primitive"""
+ """Handle the conversion from plxpr to Catalyst jaxpr for the while loop primitive
+
+ Args:
+ body_slice: Tuple (start, stop, step) to slice body consts from plxpr_invals
+ cond_slice: Tuple (start, stop, step) to slice cond consts from plxpr_invals
+ args_slice: Tuple (start, stop, step) to slice args from plxpr_invals
+ """
self.init_qreg.insert_all_dangling_qubits()
dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs(
plxpr_invals, self.qubit_index_recorder, self.init_qreg
)
- consts_body = plxpr_invals[body_slice]
- consts_cond = plxpr_invals[cond_slice]
- args = plxpr_invals[args_slice]
+ consts_body = plxpr_invals[_tuple_to_slice(body_slice)]
+ consts_cond = plxpr_invals[_tuple_to_slice(cond_slice)]
+ args = plxpr_invals[_tuple_to_slice(args_slice)]
args_plus_qreg = [
*args,
*[dyn_qreg.get() for dyn_qreg in dynalloced_qregs],
diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py
index d819d55a5e..e46bec5349 100644
--- a/frontend/catalyst/from_plxpr/from_plxpr.py
+++ b/frontend/catalyst/from_plxpr/from_plxpr.py
@@ -40,6 +40,7 @@
from catalyst.device import extract_backend_info
from catalyst.from_plxpr.decompose import COMPILER_OPS_FOR_DECOMPOSITION, DecompRuleInterpreter
from catalyst.jax_extras import make_jaxpr2, transient_jax_config
+from catalyst.jax_extras.patches import patched_make_eqn
from catalyst.jax_primitives import (
device_init_p,
device_release_p,
@@ -48,6 +49,7 @@
quantum_kernel_p,
)
from catalyst.passes.pass_api import Pass
+from catalyst.utils.patching import Patcher
from .qfunc_interpreter import PLxPRToQuantumJaxprInterpreter
from .qubit_handler import (
@@ -56,6 +58,64 @@
)
+def _tuple_to_slice(t):
+ """Convert a tuple representation of a slice back to a slice object.
+
+ JAX converts slice objects to tuples for hashability in jaxpr parameters.
+ This function converts them back to slice objects for use with indexing.
+
+ Args:
+ t: Either a slice object (returned as-is) or a tuple (start, stop, step)
+
+ Returns:
+ slice: A slice object
+ """
+ assert (
+ isinstance(t, tuple) and len(t) == 3
+ ), "Please only use _tuple_to_slice on a tuple of length 3!"
+ return slice(*t)
+
+
+def _is_dict_like_tuple(t):
+ """Checks if a tuple t is structured like a list of (key, value) pairs."""
+ return isinstance(t, tuple) and all(isinstance(item, tuple) and len(item) == 2 for item in t)
+
+
+def _tuple_to_dict(t):
+ """
+ Recursively converts JAX-hashable tuple representations back to dicts,
+ and list-like tuples back to lists.
+
+ Args:
+ t: The item to convert. Can be a dict, a tuple, or a scalar.
+
+ Returns:
+ The converted dict, list, or the original scalar value.
+ """
+
+ if not isinstance(t, (dict, tuple, list)):
+ return t
+
+ if isinstance(t, dict): # pragma: no cover
+ return {k: _tuple_to_dict(v) for k, v in t.items()}
+
+ if isinstance(t, list): # pragma: no cover
+ return [_tuple_to_dict(item) for item in t]
+
+ if isinstance(t, tuple):
+
+ # A. Dict-like tuple: Convert to dict, then recurse on values
+ if _is_dict_like_tuple(t):
+ # This handles the main (key, value) pair structure
+ return {key: _tuple_to_dict(value) for key, value in t}
+
+ # B. List-like tuple: Convert to list, then recurse on elements
+ else:
+ return [_tuple_to_dict(item) for item in t]
+
+ return t # pragma: no cover
+
+
def _get_device_kwargs(device) -> dict:
"""Calulcate the params for a device equation."""
info = extract_backend_info(device)
@@ -132,7 +192,19 @@ def f(x):
in (b,) }
"""
- return jax.make_jaxpr(partial(WorkflowInterpreter().eval, plxpr.jaxpr, plxpr.consts))
+
+ original_fn = partial(WorkflowInterpreter().eval, plxpr.jaxpr, plxpr.consts)
+
+ # pylint: disable=import-outside-toplevel
+ from jax._src.interpreters.partial_eval import DynamicJaxprTrace
+
+ def wrapped_fn(*args, **kwargs):
+ with Patcher(
+ (DynamicJaxprTrace, "make_eqn", patched_make_eqn),
+ ):
+ return jax.make_jaxpr(original_fn)(*args, **kwargs)
+
+ return wrapped_fn
class WorkflowInterpreter(PlxprInterpreter):
@@ -291,9 +363,10 @@ def handle_transform(
):
"""Handle the conversion from plxpr to Catalyst jaxpr for a
PL transform."""
- consts = args[consts_slice]
- non_const_args = args[args_slice]
- targs = args[targs_slice]
+ consts = args[_tuple_to_slice(consts_slice)]
+ non_const_args = args[_tuple_to_slice(args_slice)]
+ targs = args[_tuple_to_slice(targs_slice)]
+ tkwargs = _tuple_to_dict(tkwargs)
# If the transform is a decomposition transform
# and the graph-based decomposition is enabled
@@ -302,6 +375,7 @@ def handle_transform(
and transform._plxpr_transform.__name__ == "decompose_plxpr_to_plxpr"
and qml.decomposition.enabled_graph()
):
+ # Handle the conversion from plxpr to Catalyst jaxpr for a PL transform.
if not self.requires_decompose_lowering:
self.requires_decompose_lowering = True
else:
@@ -397,7 +471,34 @@ def trace_from_pennylane(
Tuple[Any]: the dynamic argument signature
"""
- with transient_jax_config({"jax_dynamic_shapes": True}):
+ # pylint: disable=import-outside-toplevel
+ import jax._src.interpreters.partial_eval as pe
+ from jax._src.interpreters.partial_eval import DynamicJaxprTrace
+ from jax._src.lax import lax
+ from jax._src.pjit import jit_p
+
+ from catalyst.jax_extras.patches import (
+ get_aval2,
+ patched_drop_unused_vars,
+ patched_dyn_shape_staging_rule,
+ patched_pjit_staging_rule,
+ )
+ from catalyst.utils.patching import DictPatchWrapper
+
+ with transient_jax_config(
+ {"jax_dynamic_shapes": True, "jax_use_shardy_partitioner": False}
+ ), Patcher(
+ (pe, "_drop_unused_vars", patched_drop_unused_vars),
+ (DynamicJaxprTrace, "make_eqn", patched_make_eqn),
+ (lax, "_dyn_shape_staging_rule", patched_dyn_shape_staging_rule),
+ (
+ jax._src.pjit, # pylint: disable=protected-access
+ "pjit_staging_rule",
+ patched_pjit_staging_rule,
+ ),
+ (DictPatchWrapper(pe.custom_staging_rules, jit_p), "value", patched_pjit_staging_rule),
+ (pe, "get_aval", get_aval2),
+ ):
make_jaxpr_kwargs = {
"static_argnums": static_argnums,
diff --git a/frontend/catalyst/jax_extras/lowering.py b/frontend/catalyst/jax_extras/lowering.py
index cd9a84a121..ec7b583d52 100644
--- a/frontend/catalyst/jax_extras/lowering.py
+++ b/frontend/catalyst/jax_extras/lowering.py
@@ -23,8 +23,6 @@
from jax._src.effects import ordered_effects as jax_ordered_effects
from jax._src.interpreters.mlir import _module_name_regex
from jax._src.sharding_impls import AxisEnv, ReplicaAxisContext
-from jax._src.source_info_util import new_name_stack
-from jax._src.util import wrap_name
from jax.extend.core import ClosedJaxpr
from jax.interpreters.mlir import (
AxisContext,
@@ -73,7 +71,6 @@ def jaxpr_to_mlir(jaxpr, func_name, arg_names):
nrep = jaxpr_replicas(jaxpr)
effects = jax_ordered_effects.filter_in(jaxpr.effects)
axis_context = ReplicaAxisContext(AxisEnv(nrep, (), ()))
- name_stack = new_name_stack(wrap_name("ok", "jit"))
module, context = custom_lower_jaxpr_to_module(
func_name="jit_" + func_name,
module_name=func_name,
@@ -81,7 +78,6 @@ def jaxpr_to_mlir(jaxpr, func_name, arg_names):
effects=effects,
platform="cpu",
axis_context=axis_context,
- name_stack=name_stack,
arg_names=arg_names,
)
@@ -97,7 +93,6 @@ def custom_lower_jaxpr_to_module(
effects,
platform: str,
axis_context: AxisContext,
- name_stack,
replicated_args=None,
arg_names=None,
arg_shardings=None,
@@ -145,27 +140,34 @@ def custom_lower_jaxpr_to_module(
# XLA computation preserves the module name.
module_name = _module_name_regex.sub("_", module_name)
ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(module_name)
+
+ # Use main_function=False to preserve the function name (e.g., "jit_func")
+ # instead of renaming it to "main"
lower_jaxpr_to_fun(
ctx,
func_name,
jaxpr,
effects,
- public=True,
+ main_function=False,
replicated_args=replicated_args,
arg_names=arg_names,
arg_shardings=arg_shardings,
result_shardings=result_shardings,
- name_stack=name_stack,
)
+ # Set the entry point function visibility to public and other functions to internal
worklist = [*ctx.module.body.operations]
while worklist:
op = worklist.pop()
func_name = str(op.name)
is_entry_point = func_name.startswith('"jit_')
+
if is_entry_point:
+ # Keep entry point functions public
+ op.attributes["sym_visibility"] = ir.StringAttr.get("public")
continue
if isinstance(op, FuncOp):
+ # Set non-entry functions to internal linkage
op.attributes["llvm.linkage"] = ir.Attribute.parse("#llvm.linkage")
if isinstance(op, ModuleOp):
worklist += [*op.body.operations]
diff --git a/frontend/catalyst/jax_extras/patches.py b/frontend/catalyst/jax_extras/patches.py
index 82ea21984b..dee87eea55 100644
--- a/frontend/catalyst/jax_extras/patches.py
+++ b/frontend/catalyst/jax_extras/patches.py
@@ -20,7 +20,15 @@
from functools import partial
import jax
-from jax._src.core import abstractify, standard_vma_rule
+import jax._src.interpreters.partial_eval as pe
+from jax._src import config, core, source_info_util
+from jax._src.core import JaxprEqnContext, abstractify, standard_vma_rule
+from jax._src.interpreters.partial_eval import (
+ DynamicJaxprTracer,
+ TracingEqn,
+ compute_on,
+ xla_metadata_lib,
+)
from jax._src.lax.slicing import (
_argnum_weak_type,
_gather_dtype_rule,
@@ -32,6 +40,8 @@
_sorted_dims_in_range,
standard_primitive,
)
+from jax._src.pjit import _out_type, _pjit_forwarding, jit_p
+from jax._src.sharding_impls import UnspecifiedValue
from jax.core import AbstractValue, Tracer
__all__ = (
@@ -40,6 +50,10 @@
"_no_clean_up_dead_vars",
"_gather_shape_rule_dynamic",
"gather2_p",
+ "patched_drop_unused_vars",
+ "patched_make_eqn",
+ "patched_dyn_shape_staging_rule",
+ "patched_pjit_staging_rule",
)
@@ -66,11 +80,13 @@ def __getattr__(self, name):
return MockAttributeWrapper(obj)
-def _drop_unused_vars2(jaxpr, constvals):
+def _drop_unused_vars2(
+ constvars, constvals, eqns=None, outvars=None
+): # pylint: disable=unused-argument
"""
A patch to not drop unused vars during classical tracing of control flow.
"""
- return jaxpr, list(constvals)
+ return constvars, list(constvals)
def get_aval2(x):
@@ -231,3 +247,135 @@ def _gather_shape_rule_dynamic(
sharding_rule=_gather_sharding_rule,
vma_rule=partial(standard_vma_rule, "gather"),
)
+
+# pylint: disable=protected-access
+original_drop_unused_vars = pe._drop_unused_vars
+
+
+# pylint: disable=too-many-function-args
+def patched_drop_unused_vars(constvars, constvals, eqns=None, outvars=None):
+ """Patched drop_unused_vars to ensure constvals is a list."""
+ constvars, constvals = original_drop_unused_vars(constvars, constvals, eqns, outvars)
+ return constvars, list(constvals)
+
+
+# pylint: disable=too-many-positional-arguments
+def patched_make_eqn(
+ self,
+ in_tracers,
+ out_avals,
+ primitive,
+ params,
+ effects,
+ source_info=None,
+ ctx=None,
+ out_tracers=None,
+):
+ """Patched make_eqn for DynamicJaxprTrace"""
+
+ # Helper function (replaces make_eqn_internal)
+ def make_eqn_internal(out_avals_list, out_tracers):
+ source_info_final = source_info or source_info_util.new_source_info()
+ ctx_final = ctx or JaxprEqnContext(
+ compute_on.current_compute_type(),
+ config.threefry_partitionable.value,
+ xla_metadata_lib.current_xla_metadata(),
+ )
+
+ if out_tracers is not None:
+ outvars = [tracer.val for tracer in out_tracers]
+ eqn = TracingEqn(
+ in_tracers,
+ outvars,
+ primitive,
+ params,
+ effects,
+ source_info_final,
+ ctx_final,
+ )
+ return eqn, out_tracers
+ else:
+ outvars = list(map(self.frame.newvar, out_avals_list))
+ eqn = TracingEqn(
+ in_tracers,
+ outvars,
+ primitive,
+ params,
+ effects,
+ source_info_final,
+ ctx_final,
+ )
+ out_tracers_new = [
+ DynamicJaxprTracer(self, aval, v, source_info_final, eqn)
+ for aval, v in zip(out_avals_list, outvars)
+ ]
+ return eqn, out_tracers_new
+
+ # Normalize out_avals to a list if it's a single AbstractValue
+ if not isinstance(out_avals, (list, tuple)):
+ # It's a single aval, wrap it in a list
+ out_avals_list = [out_avals]
+ eqn, out_tracers_result = make_eqn_internal(out_avals_list, out_tracers)
+ # Return single tracer instead of list
+ return eqn, (out_tracers_result[0] if len(out_tracers_result) == 1 else out_tracers_result)
+ else:
+ return make_eqn_internal(out_avals, out_tracers)
+
+
+def patched_dyn_shape_staging_rule(trace, source_info, prim, out_aval, *args, **params):
+ """Patched _dyn_shape_staging_rule for dynamic shape handling."""
+ eqn, out_tracer = trace.make_eqn(args, out_aval, prim, params, core.no_effects, source_info)
+ trace.frame.add_eqn(eqn)
+ return out_tracer
+
+
+def patched_pjit_staging_rule(trace, source_info, *args, **params):
+ """Patched pjit_staging_rule for pjit compatibility."""
+ # If we're inlining, no need to compute forwarding information; the inlined
+ # computation will in effect forward things.
+ if (
+ params["inline"]
+ and all(isinstance(i, UnspecifiedValue) for i in params["in_shardings"])
+ and all(isinstance(o, UnspecifiedValue) for o in params["out_shardings"])
+ and all(i is None for i in params["in_layouts"])
+ and all(o is None for o in params["out_layouts"])
+ ):
+ jaxpr = params["jaxpr"]
+ if config.dynamic_shapes.value:
+ # Inline jaxpr doesn't handle dynamic shapes when inlining. If dynamic
+ # shapes are enabled, use eval_jaxpr, which uses the tracing machinery,
+ # but redundantly performs abstract evaluation again.
+ with core.set_current_trace(trace):
+ out = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args, propagate_source_info=False)
+ else:
+ out = pe.inline_jaxpr_into_trace(trace, source_info, jaxpr.jaxpr, jaxpr.consts, *args)
+ return [trace.to_jaxpr_tracer(x, source_info) for x in out]
+
+ jaxpr = params["jaxpr"]
+ if config.dynamic_shapes.value:
+ jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding(
+ jaxpr, params["out_shardings"], params["out_layouts"]
+ )
+ params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings, out_layouts=out_layouts)
+ outvars = list(map(trace.frame.newvar, _out_type(jaxpr)))
+ out_avals = [v.aval for v in outvars]
+ out_tracers = [
+ pe.DynamicJaxprTracer(trace, aval, v, source_info)
+ for aval, v in zip(out_avals, outvars)
+ ]
+ eqn, out_tracers = trace.make_eqn(
+ args,
+ out_avals,
+ jit_p,
+ params,
+ jaxpr.effects,
+ source_info,
+ out_tracers=out_tracers,
+ )
+ trace.frame.add_eqn(eqn)
+ out_tracers_ = iter(out_tracers)
+ out_tracers = [args[f] if isinstance(f, int) else next(out_tracers_) for f in in_fwd]
+ assert next(out_tracers_, None) is None
+ else:
+ out_tracers = trace.default_process_primitive(jit_p, args, params, source_info=source_info)
+ return out_tracers
diff --git a/frontend/catalyst/jax_extras/tracing.py b/frontend/catalyst/jax_extras/tracing.py
index 8798fb65d9..969b7edb4c 100644
--- a/frontend/catalyst/jax_extras/tracing.py
+++ b/frontend/catalyst/jax_extras/tracing.py
@@ -13,7 +13,7 @@
# limitations under the License.
"""Jax extras module containing functions related to the Python program tracing"""
-# pylint: disable=line-too-long
+# pylint: disable=line-too-long,too-many-lines
from __future__ import annotations
@@ -47,7 +47,6 @@
eval_jaxpr,
find_top_trace,
gensym,
- new_jaxpr_eqn,
)
from jax.extend.core import ClosedJaxpr, Jaxpr, JaxprEqn
from jax.extend.core import Primitive as JaxprPrimitive
@@ -197,8 +196,23 @@ def stable_toposort(end_nodes: list) -> list:
return sorted_nodes
-def sort_eqns(eqns: List[JaxprEqn], forced_order_primitives: Set[JaxprPrimitive]) -> List[JaxprEqn]:
- """Topologically sort JAXRR equations in a unsorted list of equations, based on their
+class Box:
+ """Wrapper for TracingEqn keeping track of its id and parents."""
+
+ def __init__(self, boxid: int, e: JaxprEqn):
+ self.id: int = boxid
+ self.e: JaxprEqn = e
+ self.parents: List["Box"] = [] # to be filled later
+
+ def __lt__(self, other):
+ return self.id < other.id
+
+
+def sort_eqns(
+ eqns: List[JaxprEqn | Callable[[], JaxprEqn]],
+ forced_order_primitives: Set[JaxprPrimitive],
+) -> List[JaxprEqn | Callable[[], JaxprEqn]]:
+ """Topologically sort TracingEqns in a unsorted list of equations, based on their
input/output variables and additional criterias."""
# The procedure goes as follows: [1] - initialize the `origin` map mapping variable identifiers
@@ -206,18 +220,18 @@ def sort_eqns(eqns: List[JaxprEqn], forced_order_primitives: Set[JaxprPrimitive]
# correct values, [3] - add additional equation order restrictions to boxes, [4] - call the
# topological sorting.
- class Box:
- """Wrapper for JaxprEqn keeping track of its id and parents."""
-
- def __init__(self, boxid: int, e: JaxprEqn):
- self.id: int = boxid
- self.e: JaxprEqn = e
- self.parents: List["Box"] = [] # to be filled later
+ # JAX 0.7+: eqns might be lambda functions that return TracingEqn or weakrefs
+ # We need to preserve the mapping from actual_eqn to the original callable
+ actual_eqns = []
+ eqn_to_callable = {} # Maps actual TracingEqn to its original callable/wrapper
+ for eqn_or_callable in eqns:
+ eqn = eqn_or_callable() if callable(eqn_or_callable) else eqn_or_callable
+ assert eqn is not None
+ actual_eqns.append(eqn)
+ eqn_to_callable[id(eqn)] = eqn_or_callable
- def __lt__(self, other):
- return self.id < other.id
+ boxes = [Box(i, e) for i, e in enumerate(actual_eqns)]
- boxes = [Box(i, e) for i, e in enumerate(eqns)]
fixedorder = [(i, b) for (i, b) in enumerate(boxes) if b.e.primitive in forced_order_primitives]
origin: Dict[int, Box] = {}
for b in boxes:
@@ -225,15 +239,28 @@ def __lt__(self, other):
for b in boxes:
b.parents = []
for v in b.e.invars:
- if not isinstance(v, jax._src.core.Var):
+ if hasattr(v, "val") and isinstance(v.val, jax._src.core.Var):
+ actual_var = v.val
+ elif isinstance(v, jax._src.core.Var):
+ actual_var = v
+ else:
# constant literal invar, no need to track def use order
continue
- if v.count in origin:
- b.parents.append(origin[v.count]) # [2]
+
+ if actual_var.count in origin:
+ b.parents.append(origin[actual_var.count]) # [2]
for i, q in fixedorder:
for b in boxes[i + 1 :]:
b.parents.append(q) # [3]
- return [b.e for b in stable_toposort(boxes)] # [4]
+
+ sorted_boxes = stable_toposort(boxes)
+
+ # Restore the original callables/wrappers for the sorted equations
+ result = []
+ for b in sorted_boxes:
+ result.append(eqn_to_callable.get(id(b.e), b.e))
+
+ return result # [4]
def jaxpr_pad_consts(jaxprs: List[Jaxpr]) -> List[ClosedJaxpr]:
@@ -428,9 +455,8 @@ def trace_to_jaxpr(
def new_inner_tracer(trace: DynamicJaxprTrace, aval) -> DynamicJaxprTracer:
"""Create a JAX tracer tracing an abstract value ``aval`, without specifying its source
primitive."""
- dt = DynamicJaxprTracer(trace, aval, current_source_info())
- trace.frame.tracers.append(dt)
- trace.frame.tracer_to_var[id(dt)] = trace.frame.newvar(aval)
+ atom = trace.frame.newvar(aval)
+ dt = DynamicJaxprTracer(trace, aval, atom, line_info=current_source_info())
return dt
@@ -913,13 +939,16 @@ def bind(self, *args, **params):
# `abstract_eval` returned `out_type` calculated for empty constants.
[],
tracers,
- maker=lambda a: DynamicJaxprTracer(trace, a, source_info),
+ maker=lambda aval: new_inner_tracer(trace, aval),
)
- invars = map(trace.getvar, tracers)
- outvars = map(trace.makevar, out_tracers)
-
- eqn = new_jaxpr_eqn(invars, outvars, self, params, [], source_info)
+ # invars = map(lambda t: t.val, tracers)
+ # outvars = map(lambda t: t.val, out_tracers)
+ # eqn = new_jaxpr_eqn(invars, outvars, self, params, [], source_info)
+ out_avals = [t.aval for t in out_tracers]
+ eqn, out_tracers = trace.make_eqn(
+ tracers, out_avals, self, params, [], source_info, out_tracers=out_tracers
+ )
trace.frame.add_eqn(eqn)
return out_tracers if self.multiple_results else out_tracers.pop()
diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py
index c1aa61e124..e0d86cb4d6 100644
--- a/frontend/catalyst/jax_primitives.py
+++ b/frontend/catalyst/jax_primitives.py
@@ -32,9 +32,8 @@
from jax._src.lax.lax import _merge_dyn_shape, _nary_lower_hlo, cos_p, sin_p
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
-from jax._src.pjit import _pjit_lowering
+from jax._src.pjit import _pjit_lowering, jit_p
from jax.core import AbstractValue
-from jax.experimental.pjit import pjit_p
from jax.extend.core import Primitive
from jax.interpreters import mlir
from jax.tree_util import PyTreeDef, tree_unflatten
@@ -350,7 +349,7 @@ class MeasurementPlane(Enum):
decomprule_p = core.Primitive("decomposition_rule")
decomprule_p.multiple_results = True
-quantum_subroutine_p = copy.deepcopy(pjit_p)
+quantum_subroutine_p = copy.deepcopy(jit_p)
quantum_subroutine_p.name = "quantum_subroutine_p"
subroutine_cache: dict[callable, callable] = {}
@@ -394,15 +393,15 @@ def main():
# pylint: disable-next=import-outside-toplevel
from catalyst.api_extensions.callbacks import WRAPPER_ASSIGNMENTS
- old_pjit = jax._src.pjit.pjit_p
+ old_jit_p = jax._src.pjit.jit_p
@functools.wraps(func, assigned=WRAPPER_ASSIGNMENTS)
def inside(*args, **kwargs):
with Patcher(
(
jax._src.pjit,
- "pjit_p",
- old_pjit,
+ "jit_p",
+ old_jit_p,
),
):
return func(*args, **kwargs)
@@ -417,7 +416,7 @@ def wrapper(*args, **kwargs):
with Patcher(
(
jax._src.pjit,
- "pjit_p",
+ "jit_p",
quantum_subroutine_p,
),
):
@@ -714,7 +713,7 @@ def _grad_lowering(ctx, *args, jaxpr, fn, grad_params):
flat_output_types,
ir.StringAttr.get(method),
symbol_ref,
- mlir.flatten_lowering_ir_args(args_and_consts),
+ mlir.flatten_ir_values(args_and_consts),
diffArgIndices=diffArgIndices,
finiteDiffParam=finiteDiffParam,
).results
@@ -738,7 +737,7 @@ def _capture_grad_lowering(ctx, *args, argnums, jaxpr, n_consts, method, h, fn,
flat_output_types,
ir.StringAttr.get(method),
symbol_ref,
- mlir.flatten_lowering_ir_args(args),
+ mlir.flatten_ir_values(args),
diffArgIndices=diffArgIndices,
finiteDiffParam=finiteDiffParam,
).results
@@ -815,7 +814,7 @@ def _value_and_grad_lowering(ctx, *args, jaxpr, fn, grad_params):
gradient_result_types,
ir.StringAttr.get(method),
symbol_ref,
- mlir.flatten_lowering_ir_args(func_args),
+ mlir.flatten_ir_values(func_args),
diffArgIndices=ir.DenseIntElementsAttr.get(new_argnums),
finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None,
).results
@@ -868,8 +867,8 @@ def _jvp_lowering(ctx, *args, jaxpr, fn, grad_params):
flat_output_types[len(flat_output_types) // 2 :],
ir.StringAttr.get(method),
symbol_ref,
- mlir.flatten_lowering_ir_args(func_args),
- mlir.flatten_lowering_ir_args(tang_args),
+ mlir.flatten_ir_values(func_args),
+ mlir.flatten_ir_values(tang_args),
diffArgIndices=ir.DenseIntElementsAttr.get(new_argnums),
finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None,
).results
@@ -918,8 +917,8 @@ def _vjp_lowering(ctx, *args, jaxpr, fn, grad_params):
vjp_result_types,
ir.StringAttr.get(method),
symbol_ref,
- mlir.flatten_lowering_ir_args(func_args),
- mlir.flatten_lowering_ir_args(cotang_args),
+ mlir.flatten_ir_values(func_args),
+ mlir.flatten_ir_values(cotang_args),
diffArgIndices=ir.DenseIntElementsAttr.get(new_argnums),
finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None,
).results
@@ -982,7 +981,7 @@ def _zne_lowering(ctx, *args, folding, jaxpr, fn):
return ZneOp(
flat_output_types,
symbol_ref,
- mlir.flatten_lowering_ir_args(args_and_consts),
+ mlir.flatten_ir_values(args_and_consts),
_folding_attribute(ctx, folding),
num_folds,
).results
@@ -1852,21 +1851,11 @@ def custom_measurement_staging_rule(
else:
out_shapes = tuple(core.DShapedArray(shape, dtype) for dtype in dtypes)
- invars = [jaxpr_trace.getvar(obs)]
- for dyn_dim in dynamic_shape:
- invars.append(jaxpr_trace.getvar(dyn_dim))
+ in_tracers = [obs] + list(dynamic_shape)
params = {"static_shape": static_shape}
- out_tracers = tuple(pe.DynamicJaxprTracer(jaxpr_trace, out_shape) for out_shape in out_shapes)
-
- eqn = pe.new_jaxpr_eqn(
- invars,
- [jaxpr_trace.makevar(out_tracer) for out_tracer in out_tracers],
- primitive,
- params,
- jax.core.no_effects,
- )
+ eqn, out_tracers = jaxpr_trace.make_eqn(in_tracers, out_shapes, primitive, params, [])
jaxpr_trace.frame.add_eqn(eqn)
return out_tracers if len(out_tracers) > 1 else out_tracers[0]
@@ -2172,7 +2161,7 @@ def _cond_lowering(
num_preds = len(branch_jaxprs) - 1
preds = preds_and_branch_args_plus_consts[:num_preds]
branch_args_plus_consts = preds_and_branch_args_plus_consts[num_preds:]
- flat_args_plus_consts = mlir.flatten_lowering_ir_args(branch_args_plus_consts)
+ flat_args_plus_consts = mlir.flatten_ir_values(branch_args_plus_consts)
# recursively lower if-else chains to nested IfOps
def emit_branches(preds, branch_jaxprs, ip):
@@ -2263,7 +2252,7 @@ def _switch_lowering(
# the last branch is default and does not have a case
cases = index_and_cases_and_branch_args_plus_consts[1 : len(branch_jaxprs)]
branch_args_plus_consts = index_and_cases_and_branch_args_plus_consts[len(branch_jaxprs) :]
- flat_args_plus_consts = mlir.flatten_lowering_ir_args(branch_args_plus_consts)
+ flat_args_plus_consts = mlir.flatten_ir_values(branch_args_plus_consts)
index = _cast_to_index(index)
@@ -2357,7 +2346,7 @@ def _while_loop_lowering(
preserve_dimensions: bool,
):
loop_carry_types_plus_consts = [mlir.aval_to_ir_types(a)[0] for a in jax_ctx.avals_in]
- flat_args_plus_consts = mlir.flatten_lowering_ir_args(iter_args_plus_consts)
+ flat_args_plus_consts = mlir.flatten_ir_values(iter_args_plus_consts)
assert [val.type for val in flat_args_plus_consts] == loop_carry_types_plus_consts
# split the argument list into 3 separate groups
diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py
index 8d7e27acfe..99d2d181d1 100644
--- a/frontend/catalyst/jax_primitives_utils.py
+++ b/frontend/catalyst/jax_primitives_utils.py
@@ -176,13 +176,11 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False):
kwargs["name"] = name
kwargs["jaxpr"] = call_jaxpr
kwargs["effects"] = []
- kwargs["name_stack"] = ctx.name_stack
-
- # Make the visibility of the function public=True
- # to avoid elimination by the compiler
- kwargs["public"] = public
+ kwargs["main_function"] = False
func_op = mlir.lower_jaxpr_to_fun(**kwargs)
+ if public:
+ func_op.attributes["sym_visibility"] = ir.StringAttr.get("public")
if isinstance(callable_, qml.QNode):
func_op.attributes["qnode"] = ir.UnitAttr.get()
@@ -281,7 +279,7 @@ def create_call_op(ctx, func_op, *args):
"""Create a func::CallOp from JAXPR."""
output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
flat_output_types = util.flatten(output_types)
- mlir_args = mlir.flatten_lowering_ir_args(args)
+ mlir_args = mlir.flatten_ir_values(args)
symbol_ref = get_symbolref(ctx, func_op)
is_call_same_module = ctx.module_context.module.operation == func_op.parent
constructor = CallOp if is_call_same_module else LaunchKernelOp
diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py
index 4f0448ad7f..4402276b68 100644
--- a/frontend/catalyst/jax_tracer.py
+++ b/frontend/catalyst/jax_tracer.py
@@ -26,9 +26,12 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import jax
+import jax._src.interpreters.partial_eval as pe
import jax.numpy as jnp
import pennylane as qml
+from jax._src.lax import lax
from jax._src.lax.lax import _extract_tracers_dyn_shape
+from jax._src.pjit import jit_p
from jax._src.source_info_util import current as current_source_info
from jax.api_util import debug_info as jdb
from jax.core import get_aval
@@ -75,6 +78,12 @@
tree_unflatten,
wrap_init,
)
+from catalyst.jax_extras.patches import (
+ patched_drop_unused_vars,
+ patched_dyn_shape_staging_rule,
+ patched_make_eqn,
+ patched_pjit_staging_rule,
+)
from catalyst.jax_extras.tracing import uses_transform
from catalyst.jax_primitives import (
AbstractQreg,
@@ -106,6 +115,7 @@
from catalyst.logging import debug_logger, debug_logger_init
from catalyst.tracing.contexts import EvaluationContext, EvaluationMode
from catalyst.utils.exceptions import CompileError
+from catalyst.utils.patching import DictPatchWrapper, Patcher
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
@@ -136,6 +146,17 @@ def mark_gradient_tracing(method: str):
TRACING_GRADIENTS.pop()
+def _get_eqn_from_tracing_eqn(eqn_or_callable):
+ """Helper function to extract the actual equation from JAX 0.7's TracingEqn wrapper."""
+ if callable(eqn_or_callable):
+ actual_eqn = eqn_or_callable()
+ if actual_eqn is None: # pragma: no cover
+ raise RuntimeError("TracingEqn weakref was garbage collected")
+ return actual_eqn
+ else: # pragma: no cover
+ return eqn_or_callable
+
+
def _make_execution_config(qnode):
"""Updates the execution_config object with information about execution. This is
used in preprocess to determine what decomposition and validation is needed."""
@@ -516,13 +537,14 @@ def bind_overwrite_classical_tracers(
as-is.
"""
assert self.binder is not None, "HybridOp should set a binder"
+
# Here, we are binding any of the possible hybrid ops.
# which includes: for_loop, while_loop, cond, measure.
# This will place an equation at the very end of the current list of equations.
out_quantum_tracer = self.binder(*in_expanded_tracers, **kwargs)[-1]
trace = EvaluationContext.get_current_trace()
- eqn = trace.frame.eqns[-1]
+ eqn = _get_eqn_from_tracing_eqn(trace.frame.tracing_eqns[-1])
frame = trace.frame
assert len(eqn.outvars[:-1]) == len(
@@ -532,22 +554,26 @@ def bind_overwrite_classical_tracers(
jaxpr_variables = cached_vars.get(frame, set())
if not jaxpr_variables:
# We get all variables in the current frame
- outvars = itertools.chain.from_iterable([e.outvars for e in frame.eqns])
+ outvars = itertools.chain.from_iterable(
+ [_get_eqn_from_tracing_eqn(e).outvars for e in frame.tracing_eqns]
+ )
jaxpr_variables = set(outvars)
jaxpr_variables.update(frame.invars)
jaxpr_variables.update(frame.constvar_to_val.keys())
cached_vars[frame] = jaxpr_variables
- for outvar in frame.eqns[-1].outvars:
+ last_eqn = _get_eqn_from_tracing_eqn(frame.tracing_eqns[-1])
+ for outvar in last_eqn.outvars:
# With the exception of the output variables from the current equation.
jaxpr_variables.discard(outvar)
for i, t in enumerate(out_expanded_tracers):
# We look for what were the previous output tracers.
# If they haven't changed, then we leave them unchanged.
- if trace.getvar(t) in jaxpr_variables:
+ if t.val in jaxpr_variables:
continue
+ # For other hybrid ops, use the original logic
# If the variable cannot be found in the current frame
# it is because we have created it via new_inner_tracer
# which uses JAX internals to create a tracer without associating
@@ -587,7 +613,7 @@ def bind_overwrite_classical_tracers(
# qrp2 = op.trace_quantum(ctx, device, trace, qrp, **kwargs)
#
# So it should be safe to cache the tracers as we are doing it.
- eqn.outvars[i] = trace.getvar(t)
+ eqn.outvars[i] = t.val
# Now, the output variables can be considered as part of the current frame.
# This allows us to avoid importing all equations again next time.
@@ -642,7 +668,19 @@ def trace_to_jaxpr(func, static_argnums, abstracted_axes, args, kwargs, debug_in
PyTreeDef: PyTree-shape of the return values in ``PyTreeDef``
"""
- with transient_jax_config({"jax_dynamic_shapes": True}):
+ with transient_jax_config(
+ {"jax_dynamic_shapes": True, "jax_use_shardy_partitioner": False}
+ ), Patcher(
+ (pe, "_drop_unused_vars", patched_drop_unused_vars),
+ (DynamicJaxprTrace, "make_eqn", patched_make_eqn),
+ (lax, "_dyn_shape_staging_rule", patched_dyn_shape_staging_rule),
+ (
+ jax._src.pjit, # pylint: disable=protected-access
+ "pjit_staging_rule",
+ patched_pjit_staging_rule,
+ ),
+ (DictPatchWrapper(pe.custom_staging_rules, jit_p), "value", patched_pjit_staging_rule),
+ ):
make_jaxpr_kwargs = {
"static_argnums": static_argnums,
"abstracted_axes": abstracted_axes,
@@ -671,7 +709,19 @@ def lower_jaxpr_to_mlir(jaxpr, func_name, arg_names):
MemrefCallable.clearcache()
- with transient_jax_config({"jax_dynamic_shapes": True}):
+ # Apply JAX 0.7.0 compatibility patches during MLIR lowering.
+ # JAX internally calls trace_to_jaxpr_dynamic2 during lowering of nested @jit primitives
+ # (e.g., in jax.scipy.linalg.expm and jax.scipy.linalg.solve), which triggers two bugs:
+ # 1. make_eqn signature changed to include out_tracers parameter
+ # 2. pjit_staging_rule creates JaxprEqn instead of TracingEqn (AssertionError at partial_eval.py:1790)
+ with transient_jax_config(
+ {"jax_dynamic_shapes": True, "jax_use_shardy_partitioner": False}
+ ), Patcher(
+ # Fix make_eqn signature change (handles both old/new JAX versions)
+ (DynamicJaxprTrace, "make_eqn", patched_make_eqn),
+ # Fix pjit_staging_rule creating wrong equation type
+ (DictPatchWrapper(pe.custom_staging_rules, jit_p), "value", patched_pjit_staging_rule),
+ ):
mlir_module, ctx = jaxpr_to_mlir(jaxpr, func_name, arg_names)
return mlir_module, ctx
@@ -883,7 +933,7 @@ def bind_native_operation(qrp, op, controlled_wires, controlled_values, adjoint=
assert qrp2 is not None
qrp = qrp2
trace = EvaluationContext.get_current_trace()
- trace.frame.eqns = sort_eqns(trace.frame.eqns, FORCED_ORDER_PRIMITIVES) # [1]
+ trace.frame.tracing_eqns = sort_eqns(trace.frame.tracing_eqns, FORCED_ORDER_PRIMITIVES) # [1]
return qrp
diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py
index 81f873917c..100926b932 100644
--- a/frontend/catalyst/jit.py
+++ b/frontend/catalyst/jit.py
@@ -36,7 +36,6 @@
from catalyst.compiler import CompileOptions, Compiler, canonicalize, to_llvmir, to_mlir_opt
from catalyst.debug.instruments import instrument
from catalyst.from_plxpr import trace_from_pennylane
-from catalyst.jax_extras.patches import get_aval2
from catalyst.jax_tracer import lower_jaxpr_to_mlir, trace_to_jaxpr
from catalyst.logging import debug_logger, debug_logger_init
from catalyst.qfunc import QFunc
@@ -732,22 +731,15 @@ def capture(self, args, **kwargs):
dbg = debug_info("qjit_capture", self.user_function, args, kwargs)
if qml.capture.enabled():
- with Patcher(
- (
- jax._src.interpreters.partial_eval, # pylint: disable=protected-access
- "get_aval",
- get_aval2,
- ),
- ):
- return trace_from_pennylane(
- self.user_function,
- static_argnums,
- dynamic_args,
- abstracted_axes,
- full_sig,
- kwargs,
- debug_info=dbg,
- )
+ return trace_from_pennylane(
+ self.user_function,
+ static_argnums,
+ dynamic_args,
+ abstracted_axes,
+ full_sig,
+ kwargs,
+ debug_info=dbg,
+ )
def closure(qnode, *args, **kwargs):
params = {}
diff --git a/frontend/catalyst/utils/patching.py b/frontend/catalyst/utils/patching.py
index 1f6734f06e..b547d8fd5d 100644
--- a/frontend/catalyst/utils/patching.py
+++ b/frontend/catalyst/utils/patching.py
@@ -17,6 +17,37 @@
"""
+class DictPatchWrapper:
+ """A wrapper to enable dictionary item patching using attribute-like access.
+
+ This allows the Patcher class to patch dictionary items by wrapping the dictionary
+ and key into an object where the item can be accessed as an attribute.
+
+ Args:
+ dictionary: The dictionary to wrap
+ key: The key to access in the dictionary
+ """
+
+ def __init__(self, dictionary, key):
+ self.dictionary = dictionary
+ self.key = key
+
+ def __getattr__(self, name):
+ if name in ("dictionary", "key"): # pragma: no cover
+ return object.__getattribute__(self, name)
+ if name == "value":
+ return self.dictionary[self.key]
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
+ def __setattr__(self, name, value):
+ if name in ("dictionary", "key"):
+ object.__setattr__(self, name, value)
+ elif name == "value":
+ self.dictionary[self.key] = value
+ else: # pragma: no cover
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
+
class Patcher:
"""Patcher, a class to replace object attributes.
diff --git a/frontend/test/lit/test_jax_dynamic_api.py b/frontend/test/lit/test_jax_dynamic_api.py
index ff8b810d51..c23c77680d 100644
--- a/frontend/test/lit/test_jax_dynamic_api.py
+++ b/frontend/test/lit/test_jax_dynamic_api.py
@@ -41,7 +41,7 @@ def test_qnode_dynamic_arg(a):
"""Test passing a dynamic argument to qnode"""
# CHECK: { lambda ; [[a:.]]:i64[] [[b:.]]:i64[[[a]]]. let
- # CHECK: [[c:.]]:i64[[[a]]] = quantum_kernel[
+ # CHECK: [[c:.]]:i64[InDBIdx(val=0)] = quantum_kernel[
# CHECK: ] [[a]] [[b]]
# CHECK: in ([[c]],) }
@qml.qnode(qml.device("lightning.qubit", wires=1))
@@ -75,9 +75,9 @@ def test_qnode_dynamic_result(a):
"""Test getting a dynamic result from qnode"""
# CHECK: { lambda ; [[a:.]]:i64[]. let
- # CHECK: [[b:.]]:i64[] [[c:.]]:f64[[[b]]] = quantum_kernel[
+ # CHECK: {{.+}}:i64[] [[c:.]]:f64[OutDBIdx(val=0)] = quantum_kernel[
# CHECK: ] [[a]]
- # CHECK: in ([[b]], [[c]]) }
+ # CHECK: in ([[c]],) }
@qml.qnode(qml.device("lightning.qubit", wires=1))
def _circuit(a):
return jnp.ones((a + 1,), dtype=float)
diff --git a/frontend/test/pytest/from_plxpr/test_decompose_transform.py b/frontend/test/pytest/from_plxpr/test_decompose_transform.py
index 811eaa9b1b..861bb80905 100644
--- a/frontend/test/pytest/from_plxpr/test_decompose_transform.py
+++ b/frontend/test/pytest/from_plxpr/test_decompose_transform.py
@@ -428,7 +428,7 @@ def test_multi_passes(self):
@qml.transforms.cancel_inverses
@partial(
qml.transforms.decompose,
- gate_set={"RZ", "RY", "CNOT", "GlobalPhase"},
+ gate_set=frozenset({"RZ", "RY", "CNOT", "GlobalPhase"}),
)
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit():
diff --git a/frontend/test/pytest/from_plxpr/test_from_plxpr_qubit_handler.py b/frontend/test/pytest/from_plxpr/test_from_plxpr_qubit_handler.py
index 4d57bdf15b..3ecdbd6a03 100644
--- a/frontend/test/pytest/from_plxpr/test_from_plxpr_qubit_handler.py
+++ b/frontend/test/pytest/from_plxpr/test_from_plxpr_qubit_handler.py
@@ -45,6 +45,7 @@
from catalyst.from_plxpr.qubit_handler import QubitHandler, QubitIndexRecorder
from catalyst.jax_primitives import AbstractQbit, AbstractQreg, qalloc_p, qextract_p
+from catalyst.jax_tracer import _get_eqn_from_tracing_eqn
from catalyst.utils.exceptions import CompileError
@@ -239,13 +240,17 @@ def test_auto_extract(self):
assert qubit_handler[0] is new_qubit
with take_current_trace() as trace:
# Check that an extract primitive is added
- assert trace.frame.eqns[-1].primitive is qextract_p
+ last_eqn = _get_eqn_from_tracing_eqn(trace.frame.tracing_eqns[-1])
+ assert last_eqn.primitive is qextract_p
# Check that the extract primitive follows the wire index in the qreg manager
# __getitem__ method
- extract_p_index_invar = trace.frame.eqns[-1].invars[-1]
- assert isinstance(extract_p_index_invar, Literal)
- assert extract_p_index_invar.val == 0
+ extract_p_index_invar = last_eqn.invars[-1]
+ if isinstance(extract_p_index_invar, Literal):
+ assert extract_p_index_invar.val == 0
+ else:
+ assert isinstance(extract_p_index_invar.val, Literal)
+ assert extract_p_index_invar.val.val == 0
def test_no_overwriting_extract(self):
"""Test that no new qubit is extracted when indexing into an existing wire"""
@@ -273,9 +278,10 @@ def test_simple_gate(self):
# Also check with actual jaxpr variables
with take_current_trace() as trace:
- gate_out_qubits = trace.frame.eqns[-1].outvars
- assert trace.frame.tracer_to_var[id(qubit_handler[0])] == gate_out_qubits[0]
- assert trace.frame.tracer_to_var[id(qubit_handler[1])] == gate_out_qubits[1]
+ last_eqn = _get_eqn_from_tracing_eqn(trace.frame.tracing_eqns[-1])
+ gate_out_qubits = last_eqn.outvars
+ assert qubit_handler[0].val == gate_out_qubits[0]
+ assert qubit_handler[1].val == gate_out_qubits[1]
def test_iter(self):
"""Test __iter__ in the qreg manager"""
@@ -315,9 +321,10 @@ def test_chained_gate(self):
# Also check with actual jaxpr variables
with take_current_trace() as trace:
- gate_out_qubits = trace.frame.eqns[-1].outvars
- assert trace.frame.tracer_to_var[id(qubit_handler[0])] == gate_out_qubits[0]
- assert trace.frame.tracer_to_var[id(qubit_handler[1])] == gate_out_qubits[1]
+ last_eqn = _get_eqn_from_tracing_eqn(trace.frame.tracing_eqns[-1])
+ gate_out_qubits = last_eqn.outvars
+ assert qubit_handler[0].val == gate_out_qubits[0]
+ assert qubit_handler[1].val == gate_out_qubits[1]
def test_insert_all_dangling_qubits(self):
"""
diff --git a/frontend/test/pytest/test_preprocess.py b/frontend/test/pytest/test_preprocess.py
index a918ec09d4..916e3e7c97 100644
--- a/frontend/test/pytest/test_preprocess.py
+++ b/frontend/test/pytest/test_preprocess.py
@@ -118,14 +118,14 @@ def test_decompose_integration(self):
@qml.qnode(dev)
def circuit(theta: float):
- qml.SingleExcitationPlus(theta, wires=[0, 1])
+ qml.OrbitalRotation(theta, wires=[0, 1, 2, 3])
return qml.state()
mlir = qjit(circuit, target="mlir").mlir
+ assert "SingleExcitation" in mlir
assert "Hadamard" in mlir
- assert "CNOT" in mlir
- assert "RY" in mlir
- assert "SingleExcitationPlus" not in mlir
+ assert "RX" in mlir
+ assert "OrbitalRotation" not in mlir
def test_decompose_ops_to_unitary(self):
"""Test the decompose ops to unitary transform."""
diff --git a/frontend/test/pytest/test_verification.py b/frontend/test/pytest/test_verification.py
index e50d5c10ce..b88cf26642 100644
--- a/frontend/test/pytest/test_verification.py
+++ b/frontend/test/pytest/test_verification.py
@@ -37,7 +37,7 @@
from catalyst.api_extensions import HybridAdjoint, HybridCtrl
from catalyst.compiler import get_lib_path
from catalyst.device import get_device_capabilities
-from catalyst.device.qjit_device import RUNTIME_OPERATIONS, get_qjit_device_capabilities
+from catalyst.device.qjit_device import CUSTOM_OPERATIONS, get_qjit_device_capabilities
from catalyst.device.verification import validate_measurements
# pylint: disable = unused-argument, unnecessary-lambda-assignment, unnecessary-lambda
@@ -290,7 +290,7 @@ def test_non_controllable_gate_hybridctrl(self):
# Note: The HybridCtrl operator is not currently supported with the QJIT device, but the
# verification structure is in place, so we test the verification of its nested operators by
# adding HybridCtrl to the list of native gates for the custom base device and by patching
- # the list of RUNTIME_OPERATIONS for the QJIT device to include HybridCtrl for this test.
+ # the list of CUSTOM_OPERATIONS for the QJIT device to include HybridCtrl for this test.
@qml.qnode(
get_custom_device(
@@ -302,12 +302,12 @@ def f(x: float):
assert isinstance(op, HybridCtrl), f"op expected to be HybridCtrl but got {type(op)}"
return qml.expval(qml.PauliX(0))
- runtime_ops_with_qctrl = deepcopy(RUNTIME_OPERATIONS)
+ runtime_ops_with_qctrl = deepcopy(CUSTOM_OPERATIONS)
runtime_ops_with_qctrl["HybridCtrl"] = OperatorProperties(
invertible=True, controllable=True, differentiable=True
)
- with patch("catalyst.device.qjit_device.RUNTIME_OPERATIONS", runtime_ops_with_qctrl):
+ with patch("catalyst.device.qjit_device.CUSTOM_OPERATIONS", runtime_ops_with_qctrl):
with pytest.raises(CompileError, match="PauliZ is not controllable"):
qjit(f)(1.2)
@@ -321,7 +321,7 @@ def test_hybridctrl_raises_error(self):
"""Test that a HybridCtrl operator is rejected by the verification."""
# TODO: If you are deleting this test because HybridCtrl support has been added, consider
- # updating the tests that patch RUNTIME_OPERATIONS to inclue HybridCtrl accordingly
+ # updating the tests that patch CUSTOM_OPERATIONS to inclue HybridCtrl accordingly
@qml.qnode(get_custom_device(non_controllable_gates={"PauliZ"}, wires=4))
def f(x: float):
@@ -391,7 +391,7 @@ def test_hybrid_ctrl_containing_adjoint(self, adjoint_type, unsupported_gate_att
# Note: The HybridCtrl operator is not currently supported with the QJIT device, but the
# verification structure is in place, so we test the verification of its nested operators by
# adding HybridCtrl to the list of native gates for the custom base device and by patching
- # the list of RUNTIME_OPERATIONS for the QJIT device to include HybridCtrl for this test.
+ # the list of CUSTOM_OPERATIONS for the QJIT device to include HybridCtrl for this test.
def _ops(x, wires):
if adjoint_type == HybridAdjoint:
@@ -410,12 +410,12 @@ def f(x: float):
assert isinstance(base, adjoint_type), f"expected {adjoint_type} but got {type(op)}"
return qml.expval(qml.PauliX(0))
- runtime_ops_with_qctrl = deepcopy(RUNTIME_OPERATIONS)
+ runtime_ops_with_qctrl = deepcopy(CUSTOM_OPERATIONS)
runtime_ops_with_qctrl["HybridCtrl"] = OperatorProperties(
invertible=True, controllable=True, differentiable=True
)
- with patch("catalyst.device.qjit_device.RUNTIME_OPERATIONS", runtime_ops_with_qctrl):
+ with patch("catalyst.device.qjit_device.CUSTOM_OPERATIONS", runtime_ops_with_qctrl):
with pytest.raises(CompileError, match=f"PauliZ is not {unsupported_gate_attribute}"):
qjit(f)(1.2)
@@ -434,7 +434,7 @@ def test_hybrid_adjoint_containing_hybrid_ctrl(self, ctrl_type, unsupported_gate
# Note: The HybridCtrl operator is not currently supported with the QJIT device, but the
# verification structure is in place, so we test the verification of its nested operators by
# adding HybridCtrl to the list of native gates for the custom base device and by patching
- # the list of RUNTIME_OPERATIONS for the QJIT device to include HybridCtrl for this test.
+ # the list of CUSTOM_OPERATIONS for the QJIT device to include HybridCtrl for this test.
def _ops(x, wires):
if ctrl_type == HybridCtrl:
@@ -453,12 +453,12 @@ def f(x: float):
assert isinstance(base, ctrl_type), f"expected {ctrl_type} but got {type(op)}"
return qml.expval(qml.PauliX(0))
- runtime_ops_with_qctrl = deepcopy(RUNTIME_OPERATIONS)
+ runtime_ops_with_qctrl = deepcopy(CUSTOM_OPERATIONS)
runtime_ops_with_qctrl["HybridCtrl"] = OperatorProperties(
invertible=True, controllable=True, differentiable=True
)
- with patch("catalyst.device.qjit_device.RUNTIME_OPERATIONS", runtime_ops_with_qctrl):
+ with patch("catalyst.device.qjit_device.CUSTOM_OPERATIONS", runtime_ops_with_qctrl):
with pytest.raises(CompileError, match=f"PauliZ is not {unsupported_gate_attribute}"):
qjit(f)(1.2)