diff --git a/.dep-versions b/.dep-versions index 5b6bea4912..aecf39d601 100644 --- a/.dep-versions +++ b/.dep-versions @@ -1,7 +1,7 @@ # Always update the version check in catalyst.__init__ when changing the JAX version. # To update JAX version alongside compatible dependency tags, run the following script: # python3 .github/workflows/set_dep_versions.py {JAX_version} -jax=0.6.2 +jax=0.7.0 stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d llvm=113f01aa82d055410f22a9d03b3468fa68600589 enzyme=v0.0.203 @@ -10,7 +10,9 @@ enzyme=v0.0.203 # For a custom PL version, update the package version here and at # 'doc/requirements.txt' -pennylane=0.44.0.dev31 +# TODO: uncomment and update to latest version of pennylane +# after https://github.com/PennyLaneAI/pennylane/pull/8525 is merged. +# pennylane=0.44.0.dev31 # For a custom LQ/LK version, update the package version here and at # 'doc/requirements.txt' diff --git a/Makefile b/Makefile index d0ac210b8d..bd921e823f 100644 --- a/Makefile +++ b/Makefile @@ -117,6 +117,8 @@ frontend: # versions of a package with the same version tag (e.g. 0.38-dev0). $(PYTHON) -m pip uninstall -y pennylane $(PYTHON) -m pip install -e . --extra-index-url https://test.pypi.org/simple $(PIP_VERBOSE_FLAG) + # TODO: remove after https://github.com/PennyLaneAI/pennylane/pull/8525 is merged. + $(PYTHON) -m pip install git+https://github.com/PennyLaneAI/pennylane@bump-jax-api-hashability rm -r frontend/pennylane_catalyst.egg-info .PHONY: mlir llvm stablehlo enzyme dialects runtime oqc @@ -134,7 +136,7 @@ enzyme: dialects: $(MAKE) -C mlir dialects - + .PHONY: dialect-docs dialect-docs: $(MAKE) -C mlir dialect-docs diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 46e4f31da7..f653af2aa6 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -69,6 +69,13 @@

Improvements 🛠

+* Remove the hardcoded list of runtime operations in the frontend. + This will allow arbitrary PL gates to be represented without hyperparameters in MLIR. + For gates that do not have a QIR representation, a runtime error will be raised at execution. + Users can still decompose these gates via `qml.transforms.decompose` + when both capture and graph-decomposition are enabled. + [(#2215)](https://github.com/PennyLaneAI/catalyst/pull/2215) + * `qml.PCPhase` can be compiled and executed with capture enabled. [(#2226)](https://github.com/PennyLaneAI/catalyst/pull/2226) diff --git a/frontend/catalyst/__init__.py b/frontend/catalyst/__init__.py index eeaf6aa1be..4a50fb5859 100644 --- a/frontend/catalyst/__init__.py +++ b/frontend/catalyst/__init__.py @@ -18,12 +18,11 @@ # pylint: disable=wrong-import-position import sys -import types from os.path import dirname import jaxlib as _jaxlib -_jaxlib_version = "0.6.2" +_jaxlib_version = "0.7.0" if _jaxlib.__version__ != _jaxlib_version: import warnings diff --git a/frontend/catalyst/autograph/ag_primitives.py b/frontend/catalyst/autograph/ag_primitives.py index 61c0a06839..bed3c5cbfd 100644 --- a/frontend/catalyst/autograph/ag_primitives.py +++ b/frontend/catalyst/autograph/ag_primitives.py @@ -62,7 +62,7 @@ def get_program_length(reference_tracers): if EvaluationContext.is_tracing(): # pragma: no branch jaxpr_frame = EvaluationContext.find_jaxpr_frame(reference_tracers) - num_jaxpr_eqns = len(jaxpr_frame.eqns) + num_jaxpr_eqns = len(jaxpr_frame.tracing_eqns) if EvaluationContext.is_quantum_tracing(): quantum_queue = EvaluationContext.find_quantum_queue() @@ -79,8 +79,8 @@ def reset_program_to_length(reference_tracers, num_jaxpr_eqns, num_tape_ops): if EvaluationContext.is_tracing(): # pragma: no branch jaxpr_frame = EvaluationContext.find_jaxpr_frame(reference_tracers) - while len(jaxpr_frame.eqns) > num_jaxpr_eqns: - jaxpr_frame.eqns.pop() + while len(jaxpr_frame.tracing_eqns) > num_jaxpr_eqns: + jaxpr_frame.tracing_eqns.pop() if EvaluationContext.is_quantum_tracing(): quantum_queue = EvaluationContext.find_quantum_queue() diff --git a/frontend/catalyst/device/qjit_device.py b/frontend/catalyst/device/qjit_device.py index b4f92bf022..848a638d48 100644 --- a/frontend/catalyst/device/qjit_device.py +++ b/frontend/catalyst/device/qjit_device.py @@ -56,44 +56,6 @@ logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) -RUNTIME_OPERATIONS = [ - "CNOT", - "ControlledPhaseShift", - "CRot", - "CRX", - "CRY", - "CRZ", - "CSWAP", - "CY", - "CZ", - "Hadamard", - "Identity", - "IsingXX", - "IsingXY", - "IsingYY", - "IsingZZ", - "SingleExcitation", - "DoubleExcitation", - "ISWAP", - "MultiRZ", - "PauliX", - "PauliY", - "PauliZ", - "PCPhase", - "PhaseShift", - "PSWAP", - "QubitUnitary", - "Rot", - "RX", - "RY", - "RZ", - "S", - "SWAP", - "T", - "Toffoli", - "GlobalPhase", -] - RUNTIME_OBSERVABLES = [ "Identity", "PauliX", @@ -109,11 +71,9 @@ RUNTIME_MPS = ["ExpectationMP", "SampleMP", "VarianceMP", "CountsMP", "StateMP", "ProbabilityMP"] -# The runtime interface does not care about specific gate properties, so set them all to True. -RUNTIME_OPERATIONS = { - op: OperatorProperties(invertible=True, controllable=True, differentiable=True) - for op in RUNTIME_OPERATIONS -} +# A list of custom operations supported by the Catalyst compiler. +# This is useful especially for testing a device with custom operations. +CUSTOM_OPERATIONS = {} RUNTIME_OBSERVABLES = { obs: OperatorProperties(invertible=True, controllable=True, differentiable=True) @@ -199,6 +159,14 @@ def extract_backend_info(device: qml.devices.QubitDevice) -> BackendInfo: return BackendInfo(dname, device_name, device_lpath, device_kwargs) +def union_operations( + a: Dict[str, OperatorProperties], b: Dict[str, OperatorProperties] +) -> Dict[str, OperatorProperties]: + """Union of two sets of operator properties""" + return {**a, **b} + # return {k: a[k] & b[k] for k in (a.keys() & b.keys())} + + def intersect_operations( a: Dict[str, OperatorProperties], b: Dict[str, OperatorProperties] ) -> Dict[str, OperatorProperties]: @@ -223,8 +191,8 @@ def get_qjit_device_capabilities(target_capabilities: DeviceCapabilities) -> Dev qjit_capabilities = deepcopy(target_capabilities) # Intersection of gates and observables supported by the device and by Catalyst runtime. - qjit_capabilities.operations = intersect_operations( - target_capabilities.operations, RUNTIME_OPERATIONS + qjit_capabilities.operations = union_operations( + target_capabilities.operations, CUSTOM_OPERATIONS ) qjit_capabilities.observables = intersect_operations( target_capabilities.observables, RUNTIME_OBSERVABLES diff --git a/frontend/catalyst/from_plxpr/control_flow.py b/frontend/catalyst/from_plxpr/control_flow.py index 1569daf60b..1a42c3bdc9 100644 --- a/frontend/catalyst/from_plxpr/control_flow.py +++ b/frontend/catalyst/from_plxpr/control_flow.py @@ -25,7 +25,11 @@ from pennylane.capture.primitives import for_loop_prim as plxpr_for_loop_prim from pennylane.capture.primitives import while_loop_prim as plxpr_while_loop_prim -from catalyst.from_plxpr.from_plxpr import PLxPRToQuantumJaxprInterpreter, WorkflowInterpreter +from catalyst.from_plxpr.from_plxpr import ( + PLxPRToQuantumJaxprInterpreter, + WorkflowInterpreter, + _tuple_to_slice, +) from catalyst.from_plxpr.qubit_handler import ( QubitHandler, QubitIndexRecorder, @@ -101,8 +105,13 @@ def _to_bool_if_not(arg): @WorkflowInterpreter.register_primitive(plxpr_cond_prim) def workflow_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice): - """Handle the conversion from plxpr to Catalyst jaxpr for the cond primitive""" - args = plxpr_invals[args_slice] + """Handle the conversion from plxpr to Catalyst jaxpr for the cond primitive + + Args: + consts_slices: List of tuples (start, stop, step) to slice consts for each branch + args_slice: Tuple (start, stop, step) to slice args from plxpr_invals + """ + args = plxpr_invals[_tuple_to_slice(args_slice)] converted_jaxpr_branches = [] all_consts = [] @@ -110,7 +119,7 @@ def workflow_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice for const_slice, plxpr_branch in zip(consts_slices, jaxpr_branches): # Store all branches consts in a flat list - branch_consts = plxpr_invals[const_slice] + branch_consts = plxpr_invals[_tuple_to_slice(const_slice)] evaluator = partial(copy(self).eval, plxpr_branch, branch_consts) new_jaxpr = jax.make_jaxpr(evaluator)(*args) @@ -132,8 +141,13 @@ def workflow_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice @PLxPRToQuantumJaxprInterpreter.register_primitive(plxpr_cond_prim) def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice): - """Handle the conversion from plxpr to Catalyst jaxpr for the cond primitive""" - args = plxpr_invals[args_slice] + """Handle the conversion from plxpr to Catalyst jaxpr for the cond primitive + + Args: + consts_slices: List of tuples (start, stop, step) to slice consts for each branch + args_slice: Tuple (start, stop, step) to slice args from plxpr_invals + """ + args = plxpr_invals[_tuple_to_slice(args_slice)] self.init_qreg.insert_all_dangling_qubits() dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs( @@ -154,7 +168,7 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice): for const_slice, plxpr_branch in zip(consts_slices, jaxpr_branches): # Store all branches consts in a flat list - branch_consts = plxpr_invals[const_slice] + branch_consts = plxpr_invals[_tuple_to_slice(const_slice)] converted_jaxpr_branch = None closed_jaxpr = ClosedJaxpr(plxpr_branch, branch_consts) @@ -205,11 +219,17 @@ def workflow_for_loop( args_slice, abstract_shapes_slice, ): - """Handle the conversion from plxpr to Catalyst jaxpr for the for loop primitive""" + """Handle the conversion from plxpr to Catalyst jaxpr for the for loop primitive + + Args: + consts_slice: Tuple (start, stop, step) to slice consts from plxpr_invals + args_slice: Tuple (start, stop, step) to slice args from plxpr_invals + abstract_shapes_slice: Tuple (start, stop, step) to slice abstract shapes + """ assert jaxpr_body_fn is not None - args = plxpr_invals[args_slice] + args = plxpr_invals[_tuple_to_slice(args_slice)] - consts = plxpr_invals[consts_slice] + consts = plxpr_invals[_tuple_to_slice(consts_slice)] converter = copy(self) evaluator = partial(converter.eval, jaxpr_body_fn, consts) @@ -250,9 +270,15 @@ def handle_for_loop( args_slice, abstract_shapes_slice, ): - """Handle the conversion from plxpr to Catalyst jaxpr for the for loop primitive""" + """Handle the conversion from plxpr to Catalyst jaxpr for the for loop primitive + + Args: + consts_slice: Tuple (start, stop, step) to slice consts from plxpr_invals + args_slice: Tuple (start, stop, step) to slice args from plxpr_invals + abstract_shapes_slice: Tuple (start, stop, step) to slice abstract shapes + """ assert jaxpr_body_fn is not None - args = plxpr_invals[args_slice] + args = plxpr_invals[_tuple_to_slice(args_slice)] # Add the iteration start and the qreg to the args self.init_qreg.insert_all_dangling_qubits() @@ -268,7 +294,7 @@ def handle_for_loop( self.init_qreg.get(), ] - consts = plxpr_invals[consts_slice] + consts = plxpr_invals[_tuple_to_slice(consts_slice)] jaxpr = ClosedJaxpr(jaxpr_body_fn, consts) @@ -326,10 +352,16 @@ def workflow_while_loop( cond_slice, args_slice, ): - """Handle the conversion from plxpr to Catalyst jaxpr for the while loop primitive""" - consts_body = plxpr_invals[body_slice] - consts_cond = plxpr_invals[cond_slice] - args = plxpr_invals[args_slice] + """Handle the conversion from plxpr to Catalyst jaxpr for the while loop primitive + + Args: + body_slice: Tuple (start, stop, step) to slice body consts from plxpr_invals + cond_slice: Tuple (start, stop, step) to slice cond consts from plxpr_invals + args_slice: Tuple (start, stop, step) to slice args from plxpr_invals + """ + consts_body = plxpr_invals[_tuple_to_slice(body_slice)] + consts_cond = plxpr_invals[_tuple_to_slice(cond_slice)] + args = plxpr_invals[_tuple_to_slice(args_slice)] evaluator_body = partial(copy(self).eval, jaxpr_body_fn, consts_body) new_body_jaxpr = jax.make_jaxpr(evaluator_body)(*args) @@ -367,14 +399,20 @@ def handle_while_loop( cond_slice, args_slice, ): - """Handle the conversion from plxpr to Catalyst jaxpr for the while loop primitive""" + """Handle the conversion from plxpr to Catalyst jaxpr for the while loop primitive + + Args: + body_slice: Tuple (start, stop, step) to slice body consts from plxpr_invals + cond_slice: Tuple (start, stop, step) to slice cond consts from plxpr_invals + args_slice: Tuple (start, stop, step) to slice args from plxpr_invals + """ self.init_qreg.insert_all_dangling_qubits() dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs( plxpr_invals, self.qubit_index_recorder, self.init_qreg ) - consts_body = plxpr_invals[body_slice] - consts_cond = plxpr_invals[cond_slice] - args = plxpr_invals[args_slice] + consts_body = plxpr_invals[_tuple_to_slice(body_slice)] + consts_cond = plxpr_invals[_tuple_to_slice(cond_slice)] + args = plxpr_invals[_tuple_to_slice(args_slice)] args_plus_qreg = [ *args, *[dyn_qreg.get() for dyn_qreg in dynalloced_qregs], diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index d819d55a5e..e46bec5349 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -40,6 +40,7 @@ from catalyst.device import extract_backend_info from catalyst.from_plxpr.decompose import COMPILER_OPS_FOR_DECOMPOSITION, DecompRuleInterpreter from catalyst.jax_extras import make_jaxpr2, transient_jax_config +from catalyst.jax_extras.patches import patched_make_eqn from catalyst.jax_primitives import ( device_init_p, device_release_p, @@ -48,6 +49,7 @@ quantum_kernel_p, ) from catalyst.passes.pass_api import Pass +from catalyst.utils.patching import Patcher from .qfunc_interpreter import PLxPRToQuantumJaxprInterpreter from .qubit_handler import ( @@ -56,6 +58,64 @@ ) +def _tuple_to_slice(t): + """Convert a tuple representation of a slice back to a slice object. + + JAX converts slice objects to tuples for hashability in jaxpr parameters. + This function converts them back to slice objects for use with indexing. + + Args: + t: Either a slice object (returned as-is) or a tuple (start, stop, step) + + Returns: + slice: A slice object + """ + assert ( + isinstance(t, tuple) and len(t) == 3 + ), "Please only use _tuple_to_slice on a tuple of length 3!" + return slice(*t) + + +def _is_dict_like_tuple(t): + """Checks if a tuple t is structured like a list of (key, value) pairs.""" + return isinstance(t, tuple) and all(isinstance(item, tuple) and len(item) == 2 for item in t) + + +def _tuple_to_dict(t): + """ + Recursively converts JAX-hashable tuple representations back to dicts, + and list-like tuples back to lists. + + Args: + t: The item to convert. Can be a dict, a tuple, or a scalar. + + Returns: + The converted dict, list, or the original scalar value. + """ + + if not isinstance(t, (dict, tuple, list)): + return t + + if isinstance(t, dict): # pragma: no cover + return {k: _tuple_to_dict(v) for k, v in t.items()} + + if isinstance(t, list): # pragma: no cover + return [_tuple_to_dict(item) for item in t] + + if isinstance(t, tuple): + + # A. Dict-like tuple: Convert to dict, then recurse on values + if _is_dict_like_tuple(t): + # This handles the main (key, value) pair structure + return {key: _tuple_to_dict(value) for key, value in t} + + # B. List-like tuple: Convert to list, then recurse on elements + else: + return [_tuple_to_dict(item) for item in t] + + return t # pragma: no cover + + def _get_device_kwargs(device) -> dict: """Calulcate the params for a device equation.""" info = extract_backend_info(device) @@ -132,7 +192,19 @@ def f(x): in (b,) } """ - return jax.make_jaxpr(partial(WorkflowInterpreter().eval, plxpr.jaxpr, plxpr.consts)) + + original_fn = partial(WorkflowInterpreter().eval, plxpr.jaxpr, plxpr.consts) + + # pylint: disable=import-outside-toplevel + from jax._src.interpreters.partial_eval import DynamicJaxprTrace + + def wrapped_fn(*args, **kwargs): + with Patcher( + (DynamicJaxprTrace, "make_eqn", patched_make_eqn), + ): + return jax.make_jaxpr(original_fn)(*args, **kwargs) + + return wrapped_fn class WorkflowInterpreter(PlxprInterpreter): @@ -291,9 +363,10 @@ def handle_transform( ): """Handle the conversion from plxpr to Catalyst jaxpr for a PL transform.""" - consts = args[consts_slice] - non_const_args = args[args_slice] - targs = args[targs_slice] + consts = args[_tuple_to_slice(consts_slice)] + non_const_args = args[_tuple_to_slice(args_slice)] + targs = args[_tuple_to_slice(targs_slice)] + tkwargs = _tuple_to_dict(tkwargs) # If the transform is a decomposition transform # and the graph-based decomposition is enabled @@ -302,6 +375,7 @@ def handle_transform( and transform._plxpr_transform.__name__ == "decompose_plxpr_to_plxpr" and qml.decomposition.enabled_graph() ): + # Handle the conversion from plxpr to Catalyst jaxpr for a PL transform. if not self.requires_decompose_lowering: self.requires_decompose_lowering = True else: @@ -397,7 +471,34 @@ def trace_from_pennylane( Tuple[Any]: the dynamic argument signature """ - with transient_jax_config({"jax_dynamic_shapes": True}): + # pylint: disable=import-outside-toplevel + import jax._src.interpreters.partial_eval as pe + from jax._src.interpreters.partial_eval import DynamicJaxprTrace + from jax._src.lax import lax + from jax._src.pjit import jit_p + + from catalyst.jax_extras.patches import ( + get_aval2, + patched_drop_unused_vars, + patched_dyn_shape_staging_rule, + patched_pjit_staging_rule, + ) + from catalyst.utils.patching import DictPatchWrapper + + with transient_jax_config( + {"jax_dynamic_shapes": True, "jax_use_shardy_partitioner": False} + ), Patcher( + (pe, "_drop_unused_vars", patched_drop_unused_vars), + (DynamicJaxprTrace, "make_eqn", patched_make_eqn), + (lax, "_dyn_shape_staging_rule", patched_dyn_shape_staging_rule), + ( + jax._src.pjit, # pylint: disable=protected-access + "pjit_staging_rule", + patched_pjit_staging_rule, + ), + (DictPatchWrapper(pe.custom_staging_rules, jit_p), "value", patched_pjit_staging_rule), + (pe, "get_aval", get_aval2), + ): make_jaxpr_kwargs = { "static_argnums": static_argnums, diff --git a/frontend/catalyst/jax_extras/lowering.py b/frontend/catalyst/jax_extras/lowering.py index cd9a84a121..ec7b583d52 100644 --- a/frontend/catalyst/jax_extras/lowering.py +++ b/frontend/catalyst/jax_extras/lowering.py @@ -23,8 +23,6 @@ from jax._src.effects import ordered_effects as jax_ordered_effects from jax._src.interpreters.mlir import _module_name_regex from jax._src.sharding_impls import AxisEnv, ReplicaAxisContext -from jax._src.source_info_util import new_name_stack -from jax._src.util import wrap_name from jax.extend.core import ClosedJaxpr from jax.interpreters.mlir import ( AxisContext, @@ -73,7 +71,6 @@ def jaxpr_to_mlir(jaxpr, func_name, arg_names): nrep = jaxpr_replicas(jaxpr) effects = jax_ordered_effects.filter_in(jaxpr.effects) axis_context = ReplicaAxisContext(AxisEnv(nrep, (), ())) - name_stack = new_name_stack(wrap_name("ok", "jit")) module, context = custom_lower_jaxpr_to_module( func_name="jit_" + func_name, module_name=func_name, @@ -81,7 +78,6 @@ def jaxpr_to_mlir(jaxpr, func_name, arg_names): effects=effects, platform="cpu", axis_context=axis_context, - name_stack=name_stack, arg_names=arg_names, ) @@ -97,7 +93,6 @@ def custom_lower_jaxpr_to_module( effects, platform: str, axis_context: AxisContext, - name_stack, replicated_args=None, arg_names=None, arg_shardings=None, @@ -145,27 +140,34 @@ def custom_lower_jaxpr_to_module( # XLA computation preserves the module name. module_name = _module_name_regex.sub("_", module_name) ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(module_name) + + # Use main_function=False to preserve the function name (e.g., "jit_func") + # instead of renaming it to "main" lower_jaxpr_to_fun( ctx, func_name, jaxpr, effects, - public=True, + main_function=False, replicated_args=replicated_args, arg_names=arg_names, arg_shardings=arg_shardings, result_shardings=result_shardings, - name_stack=name_stack, ) + # Set the entry point function visibility to public and other functions to internal worklist = [*ctx.module.body.operations] while worklist: op = worklist.pop() func_name = str(op.name) is_entry_point = func_name.startswith('"jit_') + if is_entry_point: + # Keep entry point functions public + op.attributes["sym_visibility"] = ir.StringAttr.get("public") continue if isinstance(op, FuncOp): + # Set non-entry functions to internal linkage op.attributes["llvm.linkage"] = ir.Attribute.parse("#llvm.linkage") if isinstance(op, ModuleOp): worklist += [*op.body.operations] diff --git a/frontend/catalyst/jax_extras/patches.py b/frontend/catalyst/jax_extras/patches.py index 82ea21984b..dee87eea55 100644 --- a/frontend/catalyst/jax_extras/patches.py +++ b/frontend/catalyst/jax_extras/patches.py @@ -20,7 +20,15 @@ from functools import partial import jax -from jax._src.core import abstractify, standard_vma_rule +import jax._src.interpreters.partial_eval as pe +from jax._src import config, core, source_info_util +from jax._src.core import JaxprEqnContext, abstractify, standard_vma_rule +from jax._src.interpreters.partial_eval import ( + DynamicJaxprTracer, + TracingEqn, + compute_on, + xla_metadata_lib, +) from jax._src.lax.slicing import ( _argnum_weak_type, _gather_dtype_rule, @@ -32,6 +40,8 @@ _sorted_dims_in_range, standard_primitive, ) +from jax._src.pjit import _out_type, _pjit_forwarding, jit_p +from jax._src.sharding_impls import UnspecifiedValue from jax.core import AbstractValue, Tracer __all__ = ( @@ -40,6 +50,10 @@ "_no_clean_up_dead_vars", "_gather_shape_rule_dynamic", "gather2_p", + "patched_drop_unused_vars", + "patched_make_eqn", + "patched_dyn_shape_staging_rule", + "patched_pjit_staging_rule", ) @@ -66,11 +80,13 @@ def __getattr__(self, name): return MockAttributeWrapper(obj) -def _drop_unused_vars2(jaxpr, constvals): +def _drop_unused_vars2( + constvars, constvals, eqns=None, outvars=None +): # pylint: disable=unused-argument """ A patch to not drop unused vars during classical tracing of control flow. """ - return jaxpr, list(constvals) + return constvars, list(constvals) def get_aval2(x): @@ -231,3 +247,135 @@ def _gather_shape_rule_dynamic( sharding_rule=_gather_sharding_rule, vma_rule=partial(standard_vma_rule, "gather"), ) + +# pylint: disable=protected-access +original_drop_unused_vars = pe._drop_unused_vars + + +# pylint: disable=too-many-function-args +def patched_drop_unused_vars(constvars, constvals, eqns=None, outvars=None): + """Patched drop_unused_vars to ensure constvals is a list.""" + constvars, constvals = original_drop_unused_vars(constvars, constvals, eqns, outvars) + return constvars, list(constvals) + + +# pylint: disable=too-many-positional-arguments +def patched_make_eqn( + self, + in_tracers, + out_avals, + primitive, + params, + effects, + source_info=None, + ctx=None, + out_tracers=None, +): + """Patched make_eqn for DynamicJaxprTrace""" + + # Helper function (replaces make_eqn_internal) + def make_eqn_internal(out_avals_list, out_tracers): + source_info_final = source_info or source_info_util.new_source_info() + ctx_final = ctx or JaxprEqnContext( + compute_on.current_compute_type(), + config.threefry_partitionable.value, + xla_metadata_lib.current_xla_metadata(), + ) + + if out_tracers is not None: + outvars = [tracer.val for tracer in out_tracers] + eqn = TracingEqn( + in_tracers, + outvars, + primitive, + params, + effects, + source_info_final, + ctx_final, + ) + return eqn, out_tracers + else: + outvars = list(map(self.frame.newvar, out_avals_list)) + eqn = TracingEqn( + in_tracers, + outvars, + primitive, + params, + effects, + source_info_final, + ctx_final, + ) + out_tracers_new = [ + DynamicJaxprTracer(self, aval, v, source_info_final, eqn) + for aval, v in zip(out_avals_list, outvars) + ] + return eqn, out_tracers_new + + # Normalize out_avals to a list if it's a single AbstractValue + if not isinstance(out_avals, (list, tuple)): + # It's a single aval, wrap it in a list + out_avals_list = [out_avals] + eqn, out_tracers_result = make_eqn_internal(out_avals_list, out_tracers) + # Return single tracer instead of list + return eqn, (out_tracers_result[0] if len(out_tracers_result) == 1 else out_tracers_result) + else: + return make_eqn_internal(out_avals, out_tracers) + + +def patched_dyn_shape_staging_rule(trace, source_info, prim, out_aval, *args, **params): + """Patched _dyn_shape_staging_rule for dynamic shape handling.""" + eqn, out_tracer = trace.make_eqn(args, out_aval, prim, params, core.no_effects, source_info) + trace.frame.add_eqn(eqn) + return out_tracer + + +def patched_pjit_staging_rule(trace, source_info, *args, **params): + """Patched pjit_staging_rule for pjit compatibility.""" + # If we're inlining, no need to compute forwarding information; the inlined + # computation will in effect forward things. + if ( + params["inline"] + and all(isinstance(i, UnspecifiedValue) for i in params["in_shardings"]) + and all(isinstance(o, UnspecifiedValue) for o in params["out_shardings"]) + and all(i is None for i in params["in_layouts"]) + and all(o is None for o in params["out_layouts"]) + ): + jaxpr = params["jaxpr"] + if config.dynamic_shapes.value: + # Inline jaxpr doesn't handle dynamic shapes when inlining. If dynamic + # shapes are enabled, use eval_jaxpr, which uses the tracing machinery, + # but redundantly performs abstract evaluation again. + with core.set_current_trace(trace): + out = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args, propagate_source_info=False) + else: + out = pe.inline_jaxpr_into_trace(trace, source_info, jaxpr.jaxpr, jaxpr.consts, *args) + return [trace.to_jaxpr_tracer(x, source_info) for x in out] + + jaxpr = params["jaxpr"] + if config.dynamic_shapes.value: + jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( + jaxpr, params["out_shardings"], params["out_layouts"] + ) + params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings, out_layouts=out_layouts) + outvars = list(map(trace.frame.newvar, _out_type(jaxpr))) + out_avals = [v.aval for v in outvars] + out_tracers = [ + pe.DynamicJaxprTracer(trace, aval, v, source_info) + for aval, v in zip(out_avals, outvars) + ] + eqn, out_tracers = trace.make_eqn( + args, + out_avals, + jit_p, + params, + jaxpr.effects, + source_info, + out_tracers=out_tracers, + ) + trace.frame.add_eqn(eqn) + out_tracers_ = iter(out_tracers) + out_tracers = [args[f] if isinstance(f, int) else next(out_tracers_) for f in in_fwd] + assert next(out_tracers_, None) is None + else: + out_tracers = trace.default_process_primitive(jit_p, args, params, source_info=source_info) + return out_tracers diff --git a/frontend/catalyst/jax_extras/tracing.py b/frontend/catalyst/jax_extras/tracing.py index 8798fb65d9..969b7edb4c 100644 --- a/frontend/catalyst/jax_extras/tracing.py +++ b/frontend/catalyst/jax_extras/tracing.py @@ -13,7 +13,7 @@ # limitations under the License. """Jax extras module containing functions related to the Python program tracing""" -# pylint: disable=line-too-long +# pylint: disable=line-too-long,too-many-lines from __future__ import annotations @@ -47,7 +47,6 @@ eval_jaxpr, find_top_trace, gensym, - new_jaxpr_eqn, ) from jax.extend.core import ClosedJaxpr, Jaxpr, JaxprEqn from jax.extend.core import Primitive as JaxprPrimitive @@ -197,8 +196,23 @@ def stable_toposort(end_nodes: list) -> list: return sorted_nodes -def sort_eqns(eqns: List[JaxprEqn], forced_order_primitives: Set[JaxprPrimitive]) -> List[JaxprEqn]: - """Topologically sort JAXRR equations in a unsorted list of equations, based on their +class Box: + """Wrapper for TracingEqn keeping track of its id and parents.""" + + def __init__(self, boxid: int, e: JaxprEqn): + self.id: int = boxid + self.e: JaxprEqn = e + self.parents: List["Box"] = [] # to be filled later + + def __lt__(self, other): + return self.id < other.id + + +def sort_eqns( + eqns: List[JaxprEqn | Callable[[], JaxprEqn]], + forced_order_primitives: Set[JaxprPrimitive], +) -> List[JaxprEqn | Callable[[], JaxprEqn]]: + """Topologically sort TracingEqns in a unsorted list of equations, based on their input/output variables and additional criterias.""" # The procedure goes as follows: [1] - initialize the `origin` map mapping variable identifiers @@ -206,18 +220,18 @@ def sort_eqns(eqns: List[JaxprEqn], forced_order_primitives: Set[JaxprPrimitive] # correct values, [3] - add additional equation order restrictions to boxes, [4] - call the # topological sorting. - class Box: - """Wrapper for JaxprEqn keeping track of its id and parents.""" - - def __init__(self, boxid: int, e: JaxprEqn): - self.id: int = boxid - self.e: JaxprEqn = e - self.parents: List["Box"] = [] # to be filled later + # JAX 0.7+: eqns might be lambda functions that return TracingEqn or weakrefs + # We need to preserve the mapping from actual_eqn to the original callable + actual_eqns = [] + eqn_to_callable = {} # Maps actual TracingEqn to its original callable/wrapper + for eqn_or_callable in eqns: + eqn = eqn_or_callable() if callable(eqn_or_callable) else eqn_or_callable + assert eqn is not None + actual_eqns.append(eqn) + eqn_to_callable[id(eqn)] = eqn_or_callable - def __lt__(self, other): - return self.id < other.id + boxes = [Box(i, e) for i, e in enumerate(actual_eqns)] - boxes = [Box(i, e) for i, e in enumerate(eqns)] fixedorder = [(i, b) for (i, b) in enumerate(boxes) if b.e.primitive in forced_order_primitives] origin: Dict[int, Box] = {} for b in boxes: @@ -225,15 +239,28 @@ def __lt__(self, other): for b in boxes: b.parents = [] for v in b.e.invars: - if not isinstance(v, jax._src.core.Var): + if hasattr(v, "val") and isinstance(v.val, jax._src.core.Var): + actual_var = v.val + elif isinstance(v, jax._src.core.Var): + actual_var = v + else: # constant literal invar, no need to track def use order continue - if v.count in origin: - b.parents.append(origin[v.count]) # [2] + + if actual_var.count in origin: + b.parents.append(origin[actual_var.count]) # [2] for i, q in fixedorder: for b in boxes[i + 1 :]: b.parents.append(q) # [3] - return [b.e for b in stable_toposort(boxes)] # [4] + + sorted_boxes = stable_toposort(boxes) + + # Restore the original callables/wrappers for the sorted equations + result = [] + for b in sorted_boxes: + result.append(eqn_to_callable.get(id(b.e), b.e)) + + return result # [4] def jaxpr_pad_consts(jaxprs: List[Jaxpr]) -> List[ClosedJaxpr]: @@ -428,9 +455,8 @@ def trace_to_jaxpr( def new_inner_tracer(trace: DynamicJaxprTrace, aval) -> DynamicJaxprTracer: """Create a JAX tracer tracing an abstract value ``aval`, without specifying its source primitive.""" - dt = DynamicJaxprTracer(trace, aval, current_source_info()) - trace.frame.tracers.append(dt) - trace.frame.tracer_to_var[id(dt)] = trace.frame.newvar(aval) + atom = trace.frame.newvar(aval) + dt = DynamicJaxprTracer(trace, aval, atom, line_info=current_source_info()) return dt @@ -913,13 +939,16 @@ def bind(self, *args, **params): # `abstract_eval` returned `out_type` calculated for empty constants. [], tracers, - maker=lambda a: DynamicJaxprTracer(trace, a, source_info), + maker=lambda aval: new_inner_tracer(trace, aval), ) - invars = map(trace.getvar, tracers) - outvars = map(trace.makevar, out_tracers) - - eqn = new_jaxpr_eqn(invars, outvars, self, params, [], source_info) + # invars = map(lambda t: t.val, tracers) + # outvars = map(lambda t: t.val, out_tracers) + # eqn = new_jaxpr_eqn(invars, outvars, self, params, [], source_info) + out_avals = [t.aval for t in out_tracers] + eqn, out_tracers = trace.make_eqn( + tracers, out_avals, self, params, [], source_info, out_tracers=out_tracers + ) trace.frame.add_eqn(eqn) return out_tracers if self.multiple_results else out_tracers.pop() diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index c1aa61e124..e0d86cb4d6 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -32,9 +32,8 @@ from jax._src.lax.lax import _merge_dyn_shape, _nary_lower_hlo, cos_p, sin_p from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo -from jax._src.pjit import _pjit_lowering +from jax._src.pjit import _pjit_lowering, jit_p from jax.core import AbstractValue -from jax.experimental.pjit import pjit_p from jax.extend.core import Primitive from jax.interpreters import mlir from jax.tree_util import PyTreeDef, tree_unflatten @@ -350,7 +349,7 @@ class MeasurementPlane(Enum): decomprule_p = core.Primitive("decomposition_rule") decomprule_p.multiple_results = True -quantum_subroutine_p = copy.deepcopy(pjit_p) +quantum_subroutine_p = copy.deepcopy(jit_p) quantum_subroutine_p.name = "quantum_subroutine_p" subroutine_cache: dict[callable, callable] = {} @@ -394,15 +393,15 @@ def main(): # pylint: disable-next=import-outside-toplevel from catalyst.api_extensions.callbacks import WRAPPER_ASSIGNMENTS - old_pjit = jax._src.pjit.pjit_p + old_jit_p = jax._src.pjit.jit_p @functools.wraps(func, assigned=WRAPPER_ASSIGNMENTS) def inside(*args, **kwargs): with Patcher( ( jax._src.pjit, - "pjit_p", - old_pjit, + "jit_p", + old_jit_p, ), ): return func(*args, **kwargs) @@ -417,7 +416,7 @@ def wrapper(*args, **kwargs): with Patcher( ( jax._src.pjit, - "pjit_p", + "jit_p", quantum_subroutine_p, ), ): @@ -714,7 +713,7 @@ def _grad_lowering(ctx, *args, jaxpr, fn, grad_params): flat_output_types, ir.StringAttr.get(method), symbol_ref, - mlir.flatten_lowering_ir_args(args_and_consts), + mlir.flatten_ir_values(args_and_consts), diffArgIndices=diffArgIndices, finiteDiffParam=finiteDiffParam, ).results @@ -738,7 +737,7 @@ def _capture_grad_lowering(ctx, *args, argnums, jaxpr, n_consts, method, h, fn, flat_output_types, ir.StringAttr.get(method), symbol_ref, - mlir.flatten_lowering_ir_args(args), + mlir.flatten_ir_values(args), diffArgIndices=diffArgIndices, finiteDiffParam=finiteDiffParam, ).results @@ -815,7 +814,7 @@ def _value_and_grad_lowering(ctx, *args, jaxpr, fn, grad_params): gradient_result_types, ir.StringAttr.get(method), symbol_ref, - mlir.flatten_lowering_ir_args(func_args), + mlir.flatten_ir_values(func_args), diffArgIndices=ir.DenseIntElementsAttr.get(new_argnums), finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None, ).results @@ -868,8 +867,8 @@ def _jvp_lowering(ctx, *args, jaxpr, fn, grad_params): flat_output_types[len(flat_output_types) // 2 :], ir.StringAttr.get(method), symbol_ref, - mlir.flatten_lowering_ir_args(func_args), - mlir.flatten_lowering_ir_args(tang_args), + mlir.flatten_ir_values(func_args), + mlir.flatten_ir_values(tang_args), diffArgIndices=ir.DenseIntElementsAttr.get(new_argnums), finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None, ).results @@ -918,8 +917,8 @@ def _vjp_lowering(ctx, *args, jaxpr, fn, grad_params): vjp_result_types, ir.StringAttr.get(method), symbol_ref, - mlir.flatten_lowering_ir_args(func_args), - mlir.flatten_lowering_ir_args(cotang_args), + mlir.flatten_ir_values(func_args), + mlir.flatten_ir_values(cotang_args), diffArgIndices=ir.DenseIntElementsAttr.get(new_argnums), finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None, ).results @@ -982,7 +981,7 @@ def _zne_lowering(ctx, *args, folding, jaxpr, fn): return ZneOp( flat_output_types, symbol_ref, - mlir.flatten_lowering_ir_args(args_and_consts), + mlir.flatten_ir_values(args_and_consts), _folding_attribute(ctx, folding), num_folds, ).results @@ -1852,21 +1851,11 @@ def custom_measurement_staging_rule( else: out_shapes = tuple(core.DShapedArray(shape, dtype) for dtype in dtypes) - invars = [jaxpr_trace.getvar(obs)] - for dyn_dim in dynamic_shape: - invars.append(jaxpr_trace.getvar(dyn_dim)) + in_tracers = [obs] + list(dynamic_shape) params = {"static_shape": static_shape} - out_tracers = tuple(pe.DynamicJaxprTracer(jaxpr_trace, out_shape) for out_shape in out_shapes) - - eqn = pe.new_jaxpr_eqn( - invars, - [jaxpr_trace.makevar(out_tracer) for out_tracer in out_tracers], - primitive, - params, - jax.core.no_effects, - ) + eqn, out_tracers = jaxpr_trace.make_eqn(in_tracers, out_shapes, primitive, params, []) jaxpr_trace.frame.add_eqn(eqn) return out_tracers if len(out_tracers) > 1 else out_tracers[0] @@ -2172,7 +2161,7 @@ def _cond_lowering( num_preds = len(branch_jaxprs) - 1 preds = preds_and_branch_args_plus_consts[:num_preds] branch_args_plus_consts = preds_and_branch_args_plus_consts[num_preds:] - flat_args_plus_consts = mlir.flatten_lowering_ir_args(branch_args_plus_consts) + flat_args_plus_consts = mlir.flatten_ir_values(branch_args_plus_consts) # recursively lower if-else chains to nested IfOps def emit_branches(preds, branch_jaxprs, ip): @@ -2263,7 +2252,7 @@ def _switch_lowering( # the last branch is default and does not have a case cases = index_and_cases_and_branch_args_plus_consts[1 : len(branch_jaxprs)] branch_args_plus_consts = index_and_cases_and_branch_args_plus_consts[len(branch_jaxprs) :] - flat_args_plus_consts = mlir.flatten_lowering_ir_args(branch_args_plus_consts) + flat_args_plus_consts = mlir.flatten_ir_values(branch_args_plus_consts) index = _cast_to_index(index) @@ -2357,7 +2346,7 @@ def _while_loop_lowering( preserve_dimensions: bool, ): loop_carry_types_plus_consts = [mlir.aval_to_ir_types(a)[0] for a in jax_ctx.avals_in] - flat_args_plus_consts = mlir.flatten_lowering_ir_args(iter_args_plus_consts) + flat_args_plus_consts = mlir.flatten_ir_values(iter_args_plus_consts) assert [val.type for val in flat_args_plus_consts] == loop_carry_types_plus_consts # split the argument list into 3 separate groups diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py index 8d7e27acfe..99d2d181d1 100644 --- a/frontend/catalyst/jax_primitives_utils.py +++ b/frontend/catalyst/jax_primitives_utils.py @@ -176,13 +176,11 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False): kwargs["name"] = name kwargs["jaxpr"] = call_jaxpr kwargs["effects"] = [] - kwargs["name_stack"] = ctx.name_stack - - # Make the visibility of the function public=True - # to avoid elimination by the compiler - kwargs["public"] = public + kwargs["main_function"] = False func_op = mlir.lower_jaxpr_to_fun(**kwargs) + if public: + func_op.attributes["sym_visibility"] = ir.StringAttr.get("public") if isinstance(callable_, qml.QNode): func_op.attributes["qnode"] = ir.UnitAttr.get() @@ -281,7 +279,7 @@ def create_call_op(ctx, func_op, *args): """Create a func::CallOp from JAXPR.""" output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out)) flat_output_types = util.flatten(output_types) - mlir_args = mlir.flatten_lowering_ir_args(args) + mlir_args = mlir.flatten_ir_values(args) symbol_ref = get_symbolref(ctx, func_op) is_call_same_module = ctx.module_context.module.operation == func_op.parent constructor = CallOp if is_call_same_module else LaunchKernelOp diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 4f0448ad7f..4402276b68 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -26,9 +26,12 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import jax +import jax._src.interpreters.partial_eval as pe import jax.numpy as jnp import pennylane as qml +from jax._src.lax import lax from jax._src.lax.lax import _extract_tracers_dyn_shape +from jax._src.pjit import jit_p from jax._src.source_info_util import current as current_source_info from jax.api_util import debug_info as jdb from jax.core import get_aval @@ -75,6 +78,12 @@ tree_unflatten, wrap_init, ) +from catalyst.jax_extras.patches import ( + patched_drop_unused_vars, + patched_dyn_shape_staging_rule, + patched_make_eqn, + patched_pjit_staging_rule, +) from catalyst.jax_extras.tracing import uses_transform from catalyst.jax_primitives import ( AbstractQreg, @@ -106,6 +115,7 @@ from catalyst.logging import debug_logger, debug_logger_init from catalyst.tracing.contexts import EvaluationContext, EvaluationMode from catalyst.utils.exceptions import CompileError +from catalyst.utils.patching import DictPatchWrapper, Patcher logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) @@ -136,6 +146,17 @@ def mark_gradient_tracing(method: str): TRACING_GRADIENTS.pop() +def _get_eqn_from_tracing_eqn(eqn_or_callable): + """Helper function to extract the actual equation from JAX 0.7's TracingEqn wrapper.""" + if callable(eqn_or_callable): + actual_eqn = eqn_or_callable() + if actual_eqn is None: # pragma: no cover + raise RuntimeError("TracingEqn weakref was garbage collected") + return actual_eqn + else: # pragma: no cover + return eqn_or_callable + + def _make_execution_config(qnode): """Updates the execution_config object with information about execution. This is used in preprocess to determine what decomposition and validation is needed.""" @@ -516,13 +537,14 @@ def bind_overwrite_classical_tracers( as-is. """ assert self.binder is not None, "HybridOp should set a binder" + # Here, we are binding any of the possible hybrid ops. # which includes: for_loop, while_loop, cond, measure. # This will place an equation at the very end of the current list of equations. out_quantum_tracer = self.binder(*in_expanded_tracers, **kwargs)[-1] trace = EvaluationContext.get_current_trace() - eqn = trace.frame.eqns[-1] + eqn = _get_eqn_from_tracing_eqn(trace.frame.tracing_eqns[-1]) frame = trace.frame assert len(eqn.outvars[:-1]) == len( @@ -532,22 +554,26 @@ def bind_overwrite_classical_tracers( jaxpr_variables = cached_vars.get(frame, set()) if not jaxpr_variables: # We get all variables in the current frame - outvars = itertools.chain.from_iterable([e.outvars for e in frame.eqns]) + outvars = itertools.chain.from_iterable( + [_get_eqn_from_tracing_eqn(e).outvars for e in frame.tracing_eqns] + ) jaxpr_variables = set(outvars) jaxpr_variables.update(frame.invars) jaxpr_variables.update(frame.constvar_to_val.keys()) cached_vars[frame] = jaxpr_variables - for outvar in frame.eqns[-1].outvars: + last_eqn = _get_eqn_from_tracing_eqn(frame.tracing_eqns[-1]) + for outvar in last_eqn.outvars: # With the exception of the output variables from the current equation. jaxpr_variables.discard(outvar) for i, t in enumerate(out_expanded_tracers): # We look for what were the previous output tracers. # If they haven't changed, then we leave them unchanged. - if trace.getvar(t) in jaxpr_variables: + if t.val in jaxpr_variables: continue + # For other hybrid ops, use the original logic # If the variable cannot be found in the current frame # it is because we have created it via new_inner_tracer # which uses JAX internals to create a tracer without associating @@ -587,7 +613,7 @@ def bind_overwrite_classical_tracers( # qrp2 = op.trace_quantum(ctx, device, trace, qrp, **kwargs) # # So it should be safe to cache the tracers as we are doing it. - eqn.outvars[i] = trace.getvar(t) + eqn.outvars[i] = t.val # Now, the output variables can be considered as part of the current frame. # This allows us to avoid importing all equations again next time. @@ -642,7 +668,19 @@ def trace_to_jaxpr(func, static_argnums, abstracted_axes, args, kwargs, debug_in PyTreeDef: PyTree-shape of the return values in ``PyTreeDef`` """ - with transient_jax_config({"jax_dynamic_shapes": True}): + with transient_jax_config( + {"jax_dynamic_shapes": True, "jax_use_shardy_partitioner": False} + ), Patcher( + (pe, "_drop_unused_vars", patched_drop_unused_vars), + (DynamicJaxprTrace, "make_eqn", patched_make_eqn), + (lax, "_dyn_shape_staging_rule", patched_dyn_shape_staging_rule), + ( + jax._src.pjit, # pylint: disable=protected-access + "pjit_staging_rule", + patched_pjit_staging_rule, + ), + (DictPatchWrapper(pe.custom_staging_rules, jit_p), "value", patched_pjit_staging_rule), + ): make_jaxpr_kwargs = { "static_argnums": static_argnums, "abstracted_axes": abstracted_axes, @@ -671,7 +709,19 @@ def lower_jaxpr_to_mlir(jaxpr, func_name, arg_names): MemrefCallable.clearcache() - with transient_jax_config({"jax_dynamic_shapes": True}): + # Apply JAX 0.7.0 compatibility patches during MLIR lowering. + # JAX internally calls trace_to_jaxpr_dynamic2 during lowering of nested @jit primitives + # (e.g., in jax.scipy.linalg.expm and jax.scipy.linalg.solve), which triggers two bugs: + # 1. make_eqn signature changed to include out_tracers parameter + # 2. pjit_staging_rule creates JaxprEqn instead of TracingEqn (AssertionError at partial_eval.py:1790) + with transient_jax_config( + {"jax_dynamic_shapes": True, "jax_use_shardy_partitioner": False} + ), Patcher( + # Fix make_eqn signature change (handles both old/new JAX versions) + (DynamicJaxprTrace, "make_eqn", patched_make_eqn), + # Fix pjit_staging_rule creating wrong equation type + (DictPatchWrapper(pe.custom_staging_rules, jit_p), "value", patched_pjit_staging_rule), + ): mlir_module, ctx = jaxpr_to_mlir(jaxpr, func_name, arg_names) return mlir_module, ctx @@ -883,7 +933,7 @@ def bind_native_operation(qrp, op, controlled_wires, controlled_values, adjoint= assert qrp2 is not None qrp = qrp2 trace = EvaluationContext.get_current_trace() - trace.frame.eqns = sort_eqns(trace.frame.eqns, FORCED_ORDER_PRIMITIVES) # [1] + trace.frame.tracing_eqns = sort_eqns(trace.frame.tracing_eqns, FORCED_ORDER_PRIMITIVES) # [1] return qrp diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py index 81f873917c..100926b932 100644 --- a/frontend/catalyst/jit.py +++ b/frontend/catalyst/jit.py @@ -36,7 +36,6 @@ from catalyst.compiler import CompileOptions, Compiler, canonicalize, to_llvmir, to_mlir_opt from catalyst.debug.instruments import instrument from catalyst.from_plxpr import trace_from_pennylane -from catalyst.jax_extras.patches import get_aval2 from catalyst.jax_tracer import lower_jaxpr_to_mlir, trace_to_jaxpr from catalyst.logging import debug_logger, debug_logger_init from catalyst.qfunc import QFunc @@ -732,22 +731,15 @@ def capture(self, args, **kwargs): dbg = debug_info("qjit_capture", self.user_function, args, kwargs) if qml.capture.enabled(): - with Patcher( - ( - jax._src.interpreters.partial_eval, # pylint: disable=protected-access - "get_aval", - get_aval2, - ), - ): - return trace_from_pennylane( - self.user_function, - static_argnums, - dynamic_args, - abstracted_axes, - full_sig, - kwargs, - debug_info=dbg, - ) + return trace_from_pennylane( + self.user_function, + static_argnums, + dynamic_args, + abstracted_axes, + full_sig, + kwargs, + debug_info=dbg, + ) def closure(qnode, *args, **kwargs): params = {} diff --git a/frontend/catalyst/utils/patching.py b/frontend/catalyst/utils/patching.py index 1f6734f06e..b547d8fd5d 100644 --- a/frontend/catalyst/utils/patching.py +++ b/frontend/catalyst/utils/patching.py @@ -17,6 +17,37 @@ """ +class DictPatchWrapper: + """A wrapper to enable dictionary item patching using attribute-like access. + + This allows the Patcher class to patch dictionary items by wrapping the dictionary + and key into an object where the item can be accessed as an attribute. + + Args: + dictionary: The dictionary to wrap + key: The key to access in the dictionary + """ + + def __init__(self, dictionary, key): + self.dictionary = dictionary + self.key = key + + def __getattr__(self, name): + if name in ("dictionary", "key"): # pragma: no cover + return object.__getattribute__(self, name) + if name == "value": + return self.dictionary[self.key] + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + def __setattr__(self, name, value): + if name in ("dictionary", "key"): + object.__setattr__(self, name, value) + elif name == "value": + self.dictionary[self.key] = value + else: # pragma: no cover + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + class Patcher: """Patcher, a class to replace object attributes. diff --git a/frontend/test/lit/test_jax_dynamic_api.py b/frontend/test/lit/test_jax_dynamic_api.py index ff8b810d51..c23c77680d 100644 --- a/frontend/test/lit/test_jax_dynamic_api.py +++ b/frontend/test/lit/test_jax_dynamic_api.py @@ -41,7 +41,7 @@ def test_qnode_dynamic_arg(a): """Test passing a dynamic argument to qnode""" # CHECK: { lambda ; [[a:.]]:i64[] [[b:.]]:i64[[[a]]]. let - # CHECK: [[c:.]]:i64[[[a]]] = quantum_kernel[ + # CHECK: [[c:.]]:i64[InDBIdx(val=0)] = quantum_kernel[ # CHECK: ] [[a]] [[b]] # CHECK: in ([[c]],) } @qml.qnode(qml.device("lightning.qubit", wires=1)) @@ -75,9 +75,9 @@ def test_qnode_dynamic_result(a): """Test getting a dynamic result from qnode""" # CHECK: { lambda ; [[a:.]]:i64[]. let - # CHECK: [[b:.]]:i64[] [[c:.]]:f64[[[b]]] = quantum_kernel[ + # CHECK: {{.+}}:i64[] [[c:.]]:f64[OutDBIdx(val=0)] = quantum_kernel[ # CHECK: ] [[a]] - # CHECK: in ([[b]], [[c]]) } + # CHECK: in ([[c]],) } @qml.qnode(qml.device("lightning.qubit", wires=1)) def _circuit(a): return jnp.ones((a + 1,), dtype=float) diff --git a/frontend/test/pytest/from_plxpr/test_decompose_transform.py b/frontend/test/pytest/from_plxpr/test_decompose_transform.py index 811eaa9b1b..861bb80905 100644 --- a/frontend/test/pytest/from_plxpr/test_decompose_transform.py +++ b/frontend/test/pytest/from_plxpr/test_decompose_transform.py @@ -428,7 +428,7 @@ def test_multi_passes(self): @qml.transforms.cancel_inverses @partial( qml.transforms.decompose, - gate_set={"RZ", "RY", "CNOT", "GlobalPhase"}, + gate_set=frozenset({"RZ", "RY", "CNOT", "GlobalPhase"}), ) @qml.qnode(qml.device("lightning.qubit", wires=1)) def circuit(): diff --git a/frontend/test/pytest/from_plxpr/test_from_plxpr_qubit_handler.py b/frontend/test/pytest/from_plxpr/test_from_plxpr_qubit_handler.py index 4d57bdf15b..3ecdbd6a03 100644 --- a/frontend/test/pytest/from_plxpr/test_from_plxpr_qubit_handler.py +++ b/frontend/test/pytest/from_plxpr/test_from_plxpr_qubit_handler.py @@ -45,6 +45,7 @@ from catalyst.from_plxpr.qubit_handler import QubitHandler, QubitIndexRecorder from catalyst.jax_primitives import AbstractQbit, AbstractQreg, qalloc_p, qextract_p +from catalyst.jax_tracer import _get_eqn_from_tracing_eqn from catalyst.utils.exceptions import CompileError @@ -239,13 +240,17 @@ def test_auto_extract(self): assert qubit_handler[0] is new_qubit with take_current_trace() as trace: # Check that an extract primitive is added - assert trace.frame.eqns[-1].primitive is qextract_p + last_eqn = _get_eqn_from_tracing_eqn(trace.frame.tracing_eqns[-1]) + assert last_eqn.primitive is qextract_p # Check that the extract primitive follows the wire index in the qreg manager # __getitem__ method - extract_p_index_invar = trace.frame.eqns[-1].invars[-1] - assert isinstance(extract_p_index_invar, Literal) - assert extract_p_index_invar.val == 0 + extract_p_index_invar = last_eqn.invars[-1] + if isinstance(extract_p_index_invar, Literal): + assert extract_p_index_invar.val == 0 + else: + assert isinstance(extract_p_index_invar.val, Literal) + assert extract_p_index_invar.val.val == 0 def test_no_overwriting_extract(self): """Test that no new qubit is extracted when indexing into an existing wire""" @@ -273,9 +278,10 @@ def test_simple_gate(self): # Also check with actual jaxpr variables with take_current_trace() as trace: - gate_out_qubits = trace.frame.eqns[-1].outvars - assert trace.frame.tracer_to_var[id(qubit_handler[0])] == gate_out_qubits[0] - assert trace.frame.tracer_to_var[id(qubit_handler[1])] == gate_out_qubits[1] + last_eqn = _get_eqn_from_tracing_eqn(trace.frame.tracing_eqns[-1]) + gate_out_qubits = last_eqn.outvars + assert qubit_handler[0].val == gate_out_qubits[0] + assert qubit_handler[1].val == gate_out_qubits[1] def test_iter(self): """Test __iter__ in the qreg manager""" @@ -315,9 +321,10 @@ def test_chained_gate(self): # Also check with actual jaxpr variables with take_current_trace() as trace: - gate_out_qubits = trace.frame.eqns[-1].outvars - assert trace.frame.tracer_to_var[id(qubit_handler[0])] == gate_out_qubits[0] - assert trace.frame.tracer_to_var[id(qubit_handler[1])] == gate_out_qubits[1] + last_eqn = _get_eqn_from_tracing_eqn(trace.frame.tracing_eqns[-1]) + gate_out_qubits = last_eqn.outvars + assert qubit_handler[0].val == gate_out_qubits[0] + assert qubit_handler[1].val == gate_out_qubits[1] def test_insert_all_dangling_qubits(self): """ diff --git a/frontend/test/pytest/test_preprocess.py b/frontend/test/pytest/test_preprocess.py index a918ec09d4..916e3e7c97 100644 --- a/frontend/test/pytest/test_preprocess.py +++ b/frontend/test/pytest/test_preprocess.py @@ -118,14 +118,14 @@ def test_decompose_integration(self): @qml.qnode(dev) def circuit(theta: float): - qml.SingleExcitationPlus(theta, wires=[0, 1]) + qml.OrbitalRotation(theta, wires=[0, 1, 2, 3]) return qml.state() mlir = qjit(circuit, target="mlir").mlir + assert "SingleExcitation" in mlir assert "Hadamard" in mlir - assert "CNOT" in mlir - assert "RY" in mlir - assert "SingleExcitationPlus" not in mlir + assert "RX" in mlir + assert "OrbitalRotation" not in mlir def test_decompose_ops_to_unitary(self): """Test the decompose ops to unitary transform.""" diff --git a/frontend/test/pytest/test_verification.py b/frontend/test/pytest/test_verification.py index e50d5c10ce..b88cf26642 100644 --- a/frontend/test/pytest/test_verification.py +++ b/frontend/test/pytest/test_verification.py @@ -37,7 +37,7 @@ from catalyst.api_extensions import HybridAdjoint, HybridCtrl from catalyst.compiler import get_lib_path from catalyst.device import get_device_capabilities -from catalyst.device.qjit_device import RUNTIME_OPERATIONS, get_qjit_device_capabilities +from catalyst.device.qjit_device import CUSTOM_OPERATIONS, get_qjit_device_capabilities from catalyst.device.verification import validate_measurements # pylint: disable = unused-argument, unnecessary-lambda-assignment, unnecessary-lambda @@ -290,7 +290,7 @@ def test_non_controllable_gate_hybridctrl(self): # Note: The HybridCtrl operator is not currently supported with the QJIT device, but the # verification structure is in place, so we test the verification of its nested operators by # adding HybridCtrl to the list of native gates for the custom base device and by patching - # the list of RUNTIME_OPERATIONS for the QJIT device to include HybridCtrl for this test. + # the list of CUSTOM_OPERATIONS for the QJIT device to include HybridCtrl for this test. @qml.qnode( get_custom_device( @@ -302,12 +302,12 @@ def f(x: float): assert isinstance(op, HybridCtrl), f"op expected to be HybridCtrl but got {type(op)}" return qml.expval(qml.PauliX(0)) - runtime_ops_with_qctrl = deepcopy(RUNTIME_OPERATIONS) + runtime_ops_with_qctrl = deepcopy(CUSTOM_OPERATIONS) runtime_ops_with_qctrl["HybridCtrl"] = OperatorProperties( invertible=True, controllable=True, differentiable=True ) - with patch("catalyst.device.qjit_device.RUNTIME_OPERATIONS", runtime_ops_with_qctrl): + with patch("catalyst.device.qjit_device.CUSTOM_OPERATIONS", runtime_ops_with_qctrl): with pytest.raises(CompileError, match="PauliZ is not controllable"): qjit(f)(1.2) @@ -321,7 +321,7 @@ def test_hybridctrl_raises_error(self): """Test that a HybridCtrl operator is rejected by the verification.""" # TODO: If you are deleting this test because HybridCtrl support has been added, consider - # updating the tests that patch RUNTIME_OPERATIONS to inclue HybridCtrl accordingly + # updating the tests that patch CUSTOM_OPERATIONS to inclue HybridCtrl accordingly @qml.qnode(get_custom_device(non_controllable_gates={"PauliZ"}, wires=4)) def f(x: float): @@ -391,7 +391,7 @@ def test_hybrid_ctrl_containing_adjoint(self, adjoint_type, unsupported_gate_att # Note: The HybridCtrl operator is not currently supported with the QJIT device, but the # verification structure is in place, so we test the verification of its nested operators by # adding HybridCtrl to the list of native gates for the custom base device and by patching - # the list of RUNTIME_OPERATIONS for the QJIT device to include HybridCtrl for this test. + # the list of CUSTOM_OPERATIONS for the QJIT device to include HybridCtrl for this test. def _ops(x, wires): if adjoint_type == HybridAdjoint: @@ -410,12 +410,12 @@ def f(x: float): assert isinstance(base, adjoint_type), f"expected {adjoint_type} but got {type(op)}" return qml.expval(qml.PauliX(0)) - runtime_ops_with_qctrl = deepcopy(RUNTIME_OPERATIONS) + runtime_ops_with_qctrl = deepcopy(CUSTOM_OPERATIONS) runtime_ops_with_qctrl["HybridCtrl"] = OperatorProperties( invertible=True, controllable=True, differentiable=True ) - with patch("catalyst.device.qjit_device.RUNTIME_OPERATIONS", runtime_ops_with_qctrl): + with patch("catalyst.device.qjit_device.CUSTOM_OPERATIONS", runtime_ops_with_qctrl): with pytest.raises(CompileError, match=f"PauliZ is not {unsupported_gate_attribute}"): qjit(f)(1.2) @@ -434,7 +434,7 @@ def test_hybrid_adjoint_containing_hybrid_ctrl(self, ctrl_type, unsupported_gate # Note: The HybridCtrl operator is not currently supported with the QJIT device, but the # verification structure is in place, so we test the verification of its nested operators by # adding HybridCtrl to the list of native gates for the custom base device and by patching - # the list of RUNTIME_OPERATIONS for the QJIT device to include HybridCtrl for this test. + # the list of CUSTOM_OPERATIONS for the QJIT device to include HybridCtrl for this test. def _ops(x, wires): if ctrl_type == HybridCtrl: @@ -453,12 +453,12 @@ def f(x: float): assert isinstance(base, ctrl_type), f"expected {ctrl_type} but got {type(op)}" return qml.expval(qml.PauliX(0)) - runtime_ops_with_qctrl = deepcopy(RUNTIME_OPERATIONS) + runtime_ops_with_qctrl = deepcopy(CUSTOM_OPERATIONS) runtime_ops_with_qctrl["HybridCtrl"] = OperatorProperties( invertible=True, controllable=True, differentiable=True ) - with patch("catalyst.device.qjit_device.RUNTIME_OPERATIONS", runtime_ops_with_qctrl): + with patch("catalyst.device.qjit_device.CUSTOM_OPERATIONS", runtime_ops_with_qctrl): with pytest.raises(CompileError, match=f"PauliZ is not {unsupported_gate_attribute}"): qjit(f)(1.2)