Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
129 commits
Select commit Hold shift + click to select a range
df0c822
Update LLVM version
rniczh Oct 12, 2025
e7d2a5b
Update .dep-versions
rniczh Oct 16, 2025
b1082d0
Fix formatting
rniczh Oct 16, 2025
2502cba
Add commit hash support to setup.py for better controlling the versio…
rniczh Oct 16, 2025
f61b55d
Merge branch 'main' into rniczh/update-jax-to-0.7.2
rniczh Oct 16, 2025
9e8e1f2
Merge branch 'main' into rniczh/update-llvm-version-20251009
rniczh Oct 16, 2025
652d3d1
Merge branch 'rniczh/update-llvm-version-20251009' into rniczh/update…
rniczh Oct 16, 2025
bedfa21
Bump jax to 0.7.0.dev20250703+cd1b9520b
rniczh Oct 17, 2025
5c49b4d
formatting
rniczh Oct 17, 2025
1f7ebce
update to jax 0.7.0
rniczh Oct 20, 2025
34dbdba
change ods file patch method
rniczh Oct 20, 2025
a13f970
remove
rniczh Oct 21, 2025
c06245b
mock register_traceback_file_exclusion
rniczh Oct 27, 2025
28ba2da
Merge branch 'main' into rniczh/update-llvm-version-20251009
rniczh Oct 27, 2025
4024075
fix formatting
rniczh Oct 27, 2025
496fc73
fix formatting
rniczh Oct 27, 2025
3fb8611
fix pylint
rniczh Oct 27, 2025
0e1543c
fix pylint
rniczh Oct 27, 2025
408d0c0
Merge branch 'rniczh/update-jax-to-0.7.2' into rniczh/bump-jax-to-0.7…
rniczh Oct 27, 2025
02598d2
merge upstream
rniczh Oct 27, 2025
ac90188
merge from upstream
rniczh Oct 27, 2025
1d53800
udpate
rniczh Oct 27, 2025
a0d8c1c
link to PL 0.7.0 branch
rniczh Oct 27, 2025
8b75df2
fix formatting
rniczh Oct 27, 2025
8bf8808
update
rniczh Oct 27, 2025
fff771a
fix formatting
rniczh Oct 27, 2025
c1eee61
fix formatting
rniczh Oct 27, 2025
0795360
merge from upstream
rniczh Oct 28, 2025
e22029d
Patch: tuple2slice and tuple2dict
JerryChen97 Oct 28, 2025
851197a
fix formatting
rniczh Oct 28, 2025
2995927
Patch: tuple2slice and tuple2dict
JerryChen97 Oct 28, 2025
d8a5dc9
Merge remote-tracking branch 'refs/remotes/origin/rniczh/bump-jax-to-…
JerryChen97 Oct 28, 2025
3baa2d8
Merge remote-tracking branch 'refs/remotes/origin/rniczh/bump-jax-to-…
JerryChen97 Oct 28, 2025
0c94942
Update submodules: Enzyme v0.0.203, LLVM 113f01aa, StableHLO 0a4440a5
JerryChen97 Oct 28, 2025
75cb75c
Merge remote-tracking branch 'refs/remotes/origin/rniczh/bump-jax-to-…
JerryChen97 Oct 28, 2025
c8c5629
Merge branch 'rniczh/update-llvm-version-20251009' into rniczh/bump-j…
JerryChen97 Oct 28, 2025
5d1e65d
remove source code patch
rniczh Oct 28, 2025
345911f
unwrap callable tracing with helper function
rniczh Oct 28, 2025
55ac161
jax version using 0.7.0
rniczh Oct 28, 2025
dc65497
Fix formatting
rniczh Oct 28, 2025
87b7de2
update commit hash
rniczh Oct 28, 2025
7d11d2a
Merge branch 'rniczh/update-llvm-version-20251009' of github.com:Penn…
rniczh Oct 28, 2025
61cc8b4
fix coverage
rniczh Oct 28, 2025
5acc74e
revert git reference thing
rniczh Oct 28, 2025
d871d37
revert git reference thing
rniczh Oct 28, 2025
17ff2cf
update pl version in doc
rniczh Oct 28, 2025
586fbcf
fix
rniczh Oct 28, 2025
509e8b7
revert
rniczh Oct 28, 2025
5fc49be
revert
rniczh Oct 28, 2025
f9d1531
Merge branch 'rniczh/update-llvm-version-20251009' of github.com:Penn…
rniczh Oct 28, 2025
8a333ae
Fix test_adjoint.py: Use is_verified_hermitian instead of deprecated …
JerryChen97 Oct 28, 2025
70a8114
fx quantum control
JerryChen97 Oct 28, 2025
a4a63a3
fix
rniczh Oct 28, 2025
bfa7593
Merge branch 'rniczh/bump-jax-to-0.7.0' of github.com:PennyLaneAI/cat…
rniczh Oct 28, 2025
21b9097
fix coverage
rniczh Oct 28, 2025
45b3f9d
fix pylint
rniczh Oct 28, 2025
be1cc61
Merge branch 'rniczh/update-llvm-version-20251009' of github.com:Penn…
rniczh Oct 28, 2025
656353e
fix formatting
rniczh Oct 28, 2025
c730778
make format
JerryChen97 Oct 28, 2025
3e9bf78
remove jax patch
rniczh Oct 28, 2025
413e1bd
Merge branch 'rniczh/bump-jax-to-0.7.0' of github.com:PennyLaneAI/cat…
rniczh Oct 28, 2025
bff1f7f
comment pl version temporarily
rniczh Oct 28, 2025
7bcc8e8
remove git referencec code
rniczh Oct 28, 2025
c8ec2ac
Update Makefile
rniczh Oct 29, 2025
a483897
add comment
rniczh Oct 29, 2025
8b232fa
Update frontend/catalyst/from_plxpr/from_plxpr.py
rniczh Oct 29, 2025
297c7d6
apply suggestion https://github.com/PennyLaneAI/catalyst/pull/2131#di…
JerryChen97 Oct 29, 2025
d6d1c6a
Fix patch
rniczh Oct 29, 2025
1aa2e0a
Merge branch 'rniczh/bump-jax-to-0.7.0' of github.com:PennyLaneAI/cat…
rniczh Oct 29, 2025
709cb87
Add dict patcher
rniczh Oct 29, 2025
aca66f0
refine
rniczh Oct 29, 2025
672f230
remove unreachable code
rniczh Oct 29, 2025
3d07aac
patch ods_cext with patcher
rniczh Oct 29, 2025
ebff8a4
update patch
rniczh Oct 29, 2025
11aee12
move the patch to jax_primitives.py
rniczh Oct 29, 2025
096d7a2
fix formatting
rniczh Oct 29, 2025
f9f7763
fix formatting
rniczh Oct 29, 2025
d9b48fa
merge from upstream
rniczh Oct 29, 2025
020537f
rename
rniczh Oct 29, 2025
875ef7c
Merge branch 'main' into rniczh/update-llvm-version-20251009
rniczh Oct 29, 2025
296d2c0
merge from upstream
rniczh Oct 29, 2025
73d8713
remove redundant
rniczh Oct 29, 2025
6b518df
Update frontend/catalyst/jax_primitives.py
rniczh Oct 29, 2025
f6db500
fix rename
rniczh Oct 29, 2025
939ec48
move patch
rniczh Oct 30, 2025
6ba990c
merge from upstream
rniczh Oct 30, 2025
69856d7
update
rniczh Oct 30, 2025
e373a66
merge from upstream
rniczh Nov 5, 2025
fb60aa4
Merge branch 'main' into rniczh/bump-jax-to-0.7.0
rniczh Nov 6, 2025
fdc3bcb
CI
rniczh Nov 6, 2025
8fdad5f
fix
rniczh Nov 6, 2025
3b9cbbf
merge
rniczh Nov 12, 2025
b157d34
merge
rniczh Nov 13, 2025
55df257
Merge branch 'main' into rniczh/bump-jax-to-0.7.0
JerryChen97 Nov 14, 2025
9310f0a
Merge branch 'main' into rniczh/bump-jax-to-0.7.0
rniczh Nov 17, 2025
68a98f8
trigger
JerryChen97 Nov 17, 2025
db402f0
update
rniczh Nov 18, 2025
2f018d5
Merge branch 'main' into rniczh/bump-jax-to-0.7.0
rniczh Nov 18, 2025
9c23ff0
refactor for codefactor
rniczh Nov 20, 2025
e764b99
Remove RUNTIME_OPERATIONS from the frontend
maliasadi Nov 20, 2025
df72764
Merge branch 'main' into rniczh/bump-jax-to-0.7.0
rniczh Nov 20, 2025
5033566
formatting
rniczh Nov 20, 2025
fc4726b
Merge branch 'rniczh/bump-jax-to-0.7.0' of github.com:PennyLaneAI/cat…
rniczh Nov 20, 2025
8252422
formatting
rniczh Nov 20, 2025
fd1561b
reformatting
rniczh Nov 20, 2025
6c665a2
codecov coverage
rniczh Nov 20, 2025
654b2bf
Update frontend/catalyst/jax_tracer.py
rniczh Nov 20, 2025
68b85e1
suggesstion from JerryChen
rniczh Nov 20, 2025
0caedf6
Merge branch 'rniczh/bump-jax-to-0.7.0' of github.com:PennyLaneAI/cat…
rniczh Nov 20, 2025
65f6361
formatting
rniczh Nov 20, 2025
d3db19a
Update frontend/catalyst/utils/patching.py
rniczh Nov 20, 2025
d0a2f0a
codecov
rniczh Nov 20, 2025
b660c73
coverage
rniczh Nov 20, 2025
ce7b84e
coverage
rniczh Nov 20, 2025
0065b8e
formatting
rniczh Nov 20, 2025
722c2db
test
rniczh Nov 20, 2025
d50c7f8
github-status
rniczh Nov 20, 2025
af010cc
Merge branch 'main' into rniczh/bump-jax-to-0.7.0
JerryChen97 Nov 21, 2025
41f08c2
Merge branch 'main' into rniczh/bump-jax-to-0.7.0
JerryChen97 Nov 21, 2025
fdd0d37
trigger to see what if no sort in dict hash
JerryChen97 Nov 21, 2025
ec6ce48
Merge branch 'main' into rniczh/bump-jax-to-0.7.0
JerryChen97 Nov 21, 2025
8c6cc95
frozenset?
JerryChen97 Nov 21, 2025
b25d701
Merge branch 'main' into rniczh/bump-jax-to-0.7.0
JerryChen97 Nov 21, 2025
7793fcd
Pennylane Jax patcher test: bump jax to 0.7.0 (#2205)
JerryChen97 Nov 21, 2025
74c3a64
Merge branch 'main' into rniczh/bump-jax-to-0.7.0
JerryChen97 Nov 25, 2025
325e376
TEMP try: use the no patch PL branch
JerryChen97 Nov 25, 2025
94b769a
Merge branch 'main' into rniczh/bump-jax-to-0.7.0
rniczh Nov 26, 2025
e0e7659
Merge branch 'main' into rniczh/bump-jax-to-0.7.0
JerryChen97 Nov 27, 2025
7f535da
Merge branch 'main' into rniczh/bump-jax-to-0.7.0
rniczh Nov 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .dep-versions
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Always update the version check in catalyst.__init__ when changing the JAX version.
# To update JAX version alongside compatible dependency tags, run the following script:
# python3 .github/workflows/set_dep_versions.py {JAX_version}
jax=0.6.2
jax=0.7.0
stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d
llvm=113f01aa82d055410f22a9d03b3468fa68600589
enzyme=v0.0.203
Expand All @@ -10,7 +10,9 @@ enzyme=v0.0.203

# For a custom PL version, update the package version here and at
# 'doc/requirements.txt'
pennylane=0.44.0.dev31
# TODO: uncomment and update to latest version of pennylane
# after https://github.com/PennyLaneAI/pennylane/pull/8525 is merged.
# pennylane=0.44.0.dev31

# For a custom LQ/LK version, update the package version here and at
# 'doc/requirements.txt'
Expand Down
4 changes: 3 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ frontend:
# versions of a package with the same version tag (e.g. 0.38-dev0).
$(PYTHON) -m pip uninstall -y pennylane
$(PYTHON) -m pip install -e . --extra-index-url https://test.pypi.org/simple $(PIP_VERBOSE_FLAG)
# TODO: remove after https://github.com/PennyLaneAI/pennylane/pull/8525 is merged.
$(PYTHON) -m pip install git+https://github.com/PennyLaneAI/pennylane@bump-jax-api-hashability
rm -r frontend/pennylane_catalyst.egg-info

.PHONY: mlir llvm stablehlo enzyme dialects runtime oqc
Expand All @@ -134,7 +136,7 @@ enzyme:

dialects:
$(MAKE) -C mlir dialects

.PHONY: dialect-docs
dialect-docs:
$(MAKE) -C mlir dialect-docs
Expand Down
7 changes: 7 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@

<h3>Improvements 🛠</h3>

* Remove the hardcoded list of runtime operations in the frontend.
This will allow arbitrary PL gates to be represented without hyperparameters in MLIR.
For gates that do not have a QIR representation, a runtime error will be raised at execution.
Users can still decompose these gates via `qml.transforms.decompose`
when both capture and graph-decomposition are enabled.
[(#2215)](https://github.com/PennyLaneAI/catalyst/pull/2215)

* `qml.PCPhase` can be compiled and executed with capture enabled.
[(#2226)](https://github.com/PennyLaneAI/catalyst/pull/2226)

Expand Down
3 changes: 1 addition & 2 deletions frontend/catalyst/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
# pylint: disable=wrong-import-position

import sys
import types
from os.path import dirname

import jaxlib as _jaxlib

_jaxlib_version = "0.6.2"
_jaxlib_version = "0.7.0"
if _jaxlib.__version__ != _jaxlib_version:
import warnings

Expand Down
6 changes: 3 additions & 3 deletions frontend/catalyst/autograph/ag_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def get_program_length(reference_tracers):

if EvaluationContext.is_tracing(): # pragma: no branch
jaxpr_frame = EvaluationContext.find_jaxpr_frame(reference_tracers)
num_jaxpr_eqns = len(jaxpr_frame.eqns)
num_jaxpr_eqns = len(jaxpr_frame.tracing_eqns)

if EvaluationContext.is_quantum_tracing():
quantum_queue = EvaluationContext.find_quantum_queue()
Expand All @@ -79,8 +79,8 @@ def reset_program_to_length(reference_tracers, num_jaxpr_eqns, num_tape_ops):

if EvaluationContext.is_tracing(): # pragma: no branch
jaxpr_frame = EvaluationContext.find_jaxpr_frame(reference_tracers)
while len(jaxpr_frame.eqns) > num_jaxpr_eqns:
jaxpr_frame.eqns.pop()
while len(jaxpr_frame.tracing_eqns) > num_jaxpr_eqns:
jaxpr_frame.tracing_eqns.pop()

if EvaluationContext.is_quantum_tracing():
quantum_queue = EvaluationContext.find_quantum_queue()
Expand Down
58 changes: 13 additions & 45 deletions frontend/catalyst/device/qjit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,44 +56,6 @@
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

RUNTIME_OPERATIONS = [
"CNOT",
"ControlledPhaseShift",
"CRot",
"CRX",
"CRY",
"CRZ",
"CSWAP",
"CY",
"CZ",
"Hadamard",
"Identity",
"IsingXX",
"IsingXY",
"IsingYY",
"IsingZZ",
"SingleExcitation",
"DoubleExcitation",
"ISWAP",
"MultiRZ",
"PauliX",
"PauliY",
"PauliZ",
"PCPhase",
"PhaseShift",
"PSWAP",
"QubitUnitary",
"Rot",
"RX",
"RY",
"RZ",
"S",
"SWAP",
"T",
"Toffoli",
"GlobalPhase",
]

RUNTIME_OBSERVABLES = [
"Identity",
"PauliX",
Expand All @@ -109,11 +71,9 @@

RUNTIME_MPS = ["ExpectationMP", "SampleMP", "VarianceMP", "CountsMP", "StateMP", "ProbabilityMP"]

# The runtime interface does not care about specific gate properties, so set them all to True.
RUNTIME_OPERATIONS = {
op: OperatorProperties(invertible=True, controllable=True, differentiable=True)
for op in RUNTIME_OPERATIONS
}
# A list of custom operations supported by the Catalyst compiler.
# This is useful especially for testing a device with custom operations.
CUSTOM_OPERATIONS = {}

RUNTIME_OBSERVABLES = {
obs: OperatorProperties(invertible=True, controllable=True, differentiable=True)
Expand Down Expand Up @@ -199,6 +159,14 @@ def extract_backend_info(device: qml.devices.QubitDevice) -> BackendInfo:
return BackendInfo(dname, device_name, device_lpath, device_kwargs)


def union_operations(
a: Dict[str, OperatorProperties], b: Dict[str, OperatorProperties]
) -> Dict[str, OperatorProperties]:
"""Union of two sets of operator properties"""
return {**a, **b}
# return {k: a[k] & b[k] for k in (a.keys() & b.keys())}


def intersect_operations(
a: Dict[str, OperatorProperties], b: Dict[str, OperatorProperties]
) -> Dict[str, OperatorProperties]:
Expand All @@ -223,8 +191,8 @@ def get_qjit_device_capabilities(target_capabilities: DeviceCapabilities) -> Dev
qjit_capabilities = deepcopy(target_capabilities)

# Intersection of gates and observables supported by the device and by Catalyst runtime.
qjit_capabilities.operations = intersect_operations(
target_capabilities.operations, RUNTIME_OPERATIONS
qjit_capabilities.operations = union_operations(
target_capabilities.operations, CUSTOM_OPERATIONS
)
qjit_capabilities.observables = intersect_operations(
target_capabilities.observables, RUNTIME_OBSERVABLES
Expand Down
80 changes: 59 additions & 21 deletions frontend/catalyst/from_plxpr/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
from pennylane.capture.primitives import for_loop_prim as plxpr_for_loop_prim
from pennylane.capture.primitives import while_loop_prim as plxpr_while_loop_prim

from catalyst.from_plxpr.from_plxpr import PLxPRToQuantumJaxprInterpreter, WorkflowInterpreter
from catalyst.from_plxpr.from_plxpr import (
PLxPRToQuantumJaxprInterpreter,
WorkflowInterpreter,
_tuple_to_slice,
)
from catalyst.from_plxpr.qubit_handler import (
QubitHandler,
QubitIndexRecorder,
Expand Down Expand Up @@ -101,16 +105,21 @@ def _to_bool_if_not(arg):

@WorkflowInterpreter.register_primitive(plxpr_cond_prim)
def workflow_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice):
"""Handle the conversion from plxpr to Catalyst jaxpr for the cond primitive"""
args = plxpr_invals[args_slice]
"""Handle the conversion from plxpr to Catalyst jaxpr for the cond primitive

Args:
consts_slices: List of tuples (start, stop, step) to slice consts for each branch
args_slice: Tuple (start, stop, step) to slice args from plxpr_invals
"""
args = plxpr_invals[_tuple_to_slice(args_slice)]
converted_jaxpr_branches = []
all_consts = []

# Convert each branch from plxpr to jaxpr
for const_slice, plxpr_branch in zip(consts_slices, jaxpr_branches):

# Store all branches consts in a flat list
branch_consts = plxpr_invals[const_slice]
branch_consts = plxpr_invals[_tuple_to_slice(const_slice)]

evaluator = partial(copy(self).eval, plxpr_branch, branch_consts)
new_jaxpr = jax.make_jaxpr(evaluator)(*args)
Expand All @@ -132,8 +141,13 @@ def workflow_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice

@PLxPRToQuantumJaxprInterpreter.register_primitive(plxpr_cond_prim)
def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice):
"""Handle the conversion from plxpr to Catalyst jaxpr for the cond primitive"""
args = plxpr_invals[args_slice]
"""Handle the conversion from plxpr to Catalyst jaxpr for the cond primitive

Args:
consts_slices: List of tuples (start, stop, step) to slice consts for each branch
args_slice: Tuple (start, stop, step) to slice args from plxpr_invals
"""
args = plxpr_invals[_tuple_to_slice(args_slice)]
self.init_qreg.insert_all_dangling_qubits()

dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs(
Expand All @@ -154,7 +168,7 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice):
for const_slice, plxpr_branch in zip(consts_slices, jaxpr_branches):

# Store all branches consts in a flat list
branch_consts = plxpr_invals[const_slice]
branch_consts = plxpr_invals[_tuple_to_slice(const_slice)]

converted_jaxpr_branch = None
closed_jaxpr = ClosedJaxpr(plxpr_branch, branch_consts)
Expand Down Expand Up @@ -205,11 +219,17 @@ def workflow_for_loop(
args_slice,
abstract_shapes_slice,
):
"""Handle the conversion from plxpr to Catalyst jaxpr for the for loop primitive"""
"""Handle the conversion from plxpr to Catalyst jaxpr for the for loop primitive

Args:
consts_slice: Tuple (start, stop, step) to slice consts from plxpr_invals
args_slice: Tuple (start, stop, step) to slice args from plxpr_invals
abstract_shapes_slice: Tuple (start, stop, step) to slice abstract shapes
"""
assert jaxpr_body_fn is not None
args = plxpr_invals[args_slice]
args = plxpr_invals[_tuple_to_slice(args_slice)]

consts = plxpr_invals[consts_slice]
consts = plxpr_invals[_tuple_to_slice(consts_slice)]

converter = copy(self)
evaluator = partial(converter.eval, jaxpr_body_fn, consts)
Expand Down Expand Up @@ -250,9 +270,15 @@ def handle_for_loop(
args_slice,
abstract_shapes_slice,
):
"""Handle the conversion from plxpr to Catalyst jaxpr for the for loop primitive"""
"""Handle the conversion from plxpr to Catalyst jaxpr for the for loop primitive

Args:
consts_slice: Tuple (start, stop, step) to slice consts from plxpr_invals
args_slice: Tuple (start, stop, step) to slice args from plxpr_invals
abstract_shapes_slice: Tuple (start, stop, step) to slice abstract shapes
"""
assert jaxpr_body_fn is not None
args = plxpr_invals[args_slice]
args = plxpr_invals[_tuple_to_slice(args_slice)]

# Add the iteration start and the qreg to the args
self.init_qreg.insert_all_dangling_qubits()
Expand All @@ -268,7 +294,7 @@ def handle_for_loop(
self.init_qreg.get(),
]

consts = plxpr_invals[consts_slice]
consts = plxpr_invals[_tuple_to_slice(consts_slice)]

jaxpr = ClosedJaxpr(jaxpr_body_fn, consts)

Expand Down Expand Up @@ -326,10 +352,16 @@ def workflow_while_loop(
cond_slice,
args_slice,
):
"""Handle the conversion from plxpr to Catalyst jaxpr for the while loop primitive"""
consts_body = plxpr_invals[body_slice]
consts_cond = plxpr_invals[cond_slice]
args = plxpr_invals[args_slice]
"""Handle the conversion from plxpr to Catalyst jaxpr for the while loop primitive

Args:
body_slice: Tuple (start, stop, step) to slice body consts from plxpr_invals
cond_slice: Tuple (start, stop, step) to slice cond consts from plxpr_invals
args_slice: Tuple (start, stop, step) to slice args from plxpr_invals
"""
consts_body = plxpr_invals[_tuple_to_slice(body_slice)]
consts_cond = plxpr_invals[_tuple_to_slice(cond_slice)]
args = plxpr_invals[_tuple_to_slice(args_slice)]

evaluator_body = partial(copy(self).eval, jaxpr_body_fn, consts_body)
new_body_jaxpr = jax.make_jaxpr(evaluator_body)(*args)
Expand Down Expand Up @@ -367,14 +399,20 @@ def handle_while_loop(
cond_slice,
args_slice,
):
"""Handle the conversion from plxpr to Catalyst jaxpr for the while loop primitive"""
"""Handle the conversion from plxpr to Catalyst jaxpr for the while loop primitive

Args:
body_slice: Tuple (start, stop, step) to slice body consts from plxpr_invals
cond_slice: Tuple (start, stop, step) to slice cond consts from plxpr_invals
args_slice: Tuple (start, stop, step) to slice args from plxpr_invals
"""
self.init_qreg.insert_all_dangling_qubits()
dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs(
plxpr_invals, self.qubit_index_recorder, self.init_qreg
)
consts_body = plxpr_invals[body_slice]
consts_cond = plxpr_invals[cond_slice]
args = plxpr_invals[args_slice]
consts_body = plxpr_invals[_tuple_to_slice(body_slice)]
consts_cond = plxpr_invals[_tuple_to_slice(cond_slice)]
args = plxpr_invals[_tuple_to_slice(args_slice)]
args_plus_qreg = [
*args,
*[dyn_qreg.get() for dyn_qreg in dynalloced_qregs],
Expand Down
Loading