Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
df0c822
Update LLVM version
rniczh Oct 12, 2025
e7d2a5b
Update .dep-versions
rniczh Oct 16, 2025
b1082d0
Fix formatting
rniczh Oct 16, 2025
9e8e1f2
Merge branch 'main' into rniczh/update-llvm-version-20251009
rniczh Oct 16, 2025
c06245b
mock register_traceback_file_exclusion
rniczh Oct 27, 2025
28ba2da
Merge branch 'main' into rniczh/update-llvm-version-20251009
rniczh Oct 27, 2025
4024075
fix formatting
rniczh Oct 27, 2025
496fc73
fix formatting
rniczh Oct 27, 2025
3fb8611
fix pylint
rniczh Oct 27, 2025
0e1543c
fix pylint
rniczh Oct 27, 2025
0795360
merge from upstream
rniczh Oct 28, 2025
87b7de2
update commit hash
rniczh Oct 28, 2025
61cc8b4
fix coverage
rniczh Oct 28, 2025
5fc49be
revert
rniczh Oct 28, 2025
21b9097
fix coverage
rniczh Oct 28, 2025
45b3f9d
fix pylint
rniczh Oct 28, 2025
3d07aac
patch ods_cext with patcher
rniczh Oct 29, 2025
11aee12
move the patch to jax_primitives.py
rniczh Oct 29, 2025
096d7a2
fix formatting
rniczh Oct 29, 2025
f9f7763
fix formatting
rniczh Oct 29, 2025
875ef7c
Merge branch 'main' into rniczh/update-llvm-version-20251009
rniczh Oct 29, 2025
73d8713
remove redundant
rniczh Oct 29, 2025
6b518df
Update frontend/catalyst/jax_primitives.py
rniczh Oct 29, 2025
cb5df00
resolve conflict
rniczh Nov 4, 2025
0b4591a
Merge branch 'main' into rniczh/update-llvm-version-20251009
rniczh Nov 4, 2025
2ccfbae
try to fix decompose failure
rniczh Nov 4, 2025
028ad0e
fix decomp
rniczh Nov 4, 2025
8add623
fix test
rniczh Nov 4, 2025
f4615f3
fix
rniczh Nov 4, 2025
3a71e65
fix
rniczh Nov 4, 2025
0545b4b
fix
rniczh Nov 4, 2025
1a093cb
test
rniczh Nov 4, 2025
22870fa
CI test
rniczh Nov 4, 2025
5e93dbd
fix decomp
rniczh Nov 4, 2025
257a203
fix formatting
rniczh Nov 4, 2025
b3c57c6
Merge branch 'main' into rniczh/update-llvm-version-20251009
rniczh Nov 4, 2025
6dafb84
fix
rniczh Nov 4, 2025
fd46638
update
rniczh Nov 4, 2025
1b5a08d
Merge branch 'rniczh/update-llvm-version-20251009' of github.com:Penn…
rniczh Nov 4, 2025
82068da
update and trigger compiler to recompile hpp
rniczh Nov 5, 2025
9c9ba77
include decompose impl
rniczh Nov 5, 2025
0f1e68a
update
rniczh Nov 5, 2025
828b9ae
formatting
rniczh Nov 5, 2025
cd5aad3
debug CI
rniczh Nov 5, 2025
df79232
formatting
rniczh Nov 5, 2025
b37de9d
debug on CI
rniczh Nov 5, 2025
0bce06b
fix
rniczh Nov 5, 2025
f537360
fix
rniczh Nov 5, 2025
9360e8d
fix
rniczh Nov 5, 2025
232ca6c
fix
rniczh Nov 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .dep-versions
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
# To update JAX version alongside compatible dependency tags, run the following script:
# python3 .github/workflows/set_dep_versions.py {JAX_version}
jax=0.6.2
stablehlo=69d6dae46e1c7de36e6e6973654754f05353cba5
llvm=f8cb7987c64dcffb72414a40560055cb717dbf74
enzyme=v0.0.186
stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d
llvm=113f01aa82d055410f22a9d03b3468fa68600589
enzyme=v0.0.203

# Always remove custom PL/LQ versions before release.

Expand Down
23 changes: 23 additions & 0 deletions frontend/catalyst/jax_extras/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,29 @@
)


def mock_attributes(obj, attrs: dict[str, any]):
"""Mock the attribute of an object by returning a wrapper.
Args:
obj: The object to mock the attributes of.
attrs: A dictionary of attributes to mock.
Example: {"attribute_name": attribute_value}
"""

class MockAttributeWrapper:
"""Wrapper to mock the attribute of an object."""

def __init__(self, original):
self.original = original

def __getattr__(self, name):
if name in attrs:
return attrs[name]
return getattr(self.original, name)

return MockAttributeWrapper(obj)


def _drop_unused_vars2(jaxpr, constvals):
"""
A patch to not drop unused vars during classical tracing of control flow.
Expand Down
138 changes: 80 additions & 58 deletions frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from jax.interpreters import mlir
from jax.tree_util import PyTreeDef, tree_unflatten
from jaxlib.hlo_helpers import shape_dtype_to_ir_type
from jaxlib.mlir._mlir_libs import _mlir as _ods_cext
from jaxlib.mlir.dialects.arith import (
AddIOp,
CeilDivSIOp,
Expand All @@ -52,54 +53,85 @@
from jaxlib.mlir.dialects.scf import ConditionOp, ForOp, IfOp, WhileOp, YieldOp
from jaxlib.mlir.dialects.stablehlo import ConstantOp as StableHLOConstantOp
from jaxlib.mlir.dialects.stablehlo import ConvertOp as StableHLOConvertOp
from mlir_quantum.dialects.catalyst import (
AssertionOp,
CallbackCallOp,
CallbackOp,
PrintOp,
)
from mlir_quantum.dialects.gradient import (
CustomGradOp,
ForwardOp,
GradOp,
JVPOp,
ReverseOp,
ValueAndGradOp,
VJPOp,
)
from mlir_quantum.dialects.mbqc import MeasureInBasisOp
from mlir_quantum.dialects.mitigation import ZneOp
from mlir_quantum.dialects.quantum import (
AdjointOp,
AllocOp,
ComputationalBasisOp,
CountsOp,
CustomOp,
DeallocOp,
DeallocQubitOp,
DeviceInitOp,
DeviceReleaseOp,
ExpvalOp,
ExtractOp,
GlobalPhaseOp,
HamiltonianOp,
HermitianOp,
InsertOp,
MeasureOp,
MultiRZOp,
NamedObsOp,
NumQubitsOp,
PCPhaseOp,
ProbsOp,
QubitUnitaryOp,
SampleOp,
SetBasisStateOp,
SetStateOp,
StateOp,
TensorOp,
VarianceOp,
)
from mlir_quantum.dialects.quantum import YieldOp as QYieldOp

# TODO: remove after jax v0.7.2 upgrade
# Mock _ods_cext.globals.register_traceback_file_exclusion due to API conflicts between
# Catalyst's MLIR version and the MLIR version used by JAX. The current JAX version has not
# yet updated to the latest MLIR, causing compatibility issues. This workaround will be removed
# once JAX updates to a compatible MLIR version
# pylint: disable=ungrouped-imports
from catalyst.jax_extras.patches import mock_attributes
from catalyst.utils.patching import Patcher

with Patcher(
(
_ods_cext,
"globals",
mock_attributes(
# pylint: disable=c-extension-no-member
_ods_cext.globals,
{"register_traceback_file_exclusion": lambda x: None},
),
),
):
from mlir_quantum.dialects.catalyst import (
AssertionOp,
CallbackCallOp,
CallbackOp,
PrintOp,
)
from mlir_quantum.dialects.gradient import (
CustomGradOp,
ForwardOp,
GradOp,
JVPOp,
ReverseOp,
ValueAndGradOp,
VJPOp,
)
from mlir_quantum.dialects.mbqc import MeasureInBasisOp
from mlir_quantum.dialects.mitigation import ZneOp
from mlir_quantum.dialects.quantum import (
AdjointOp,
AllocOp,
ComputationalBasisOp,
CountsOp,
CustomOp,
DeallocOp,
DeallocQubitOp,
DeviceInitOp,
DeviceReleaseOp,
ExpvalOp,
ExtractOp,
GlobalPhaseOp,
HamiltonianOp,
HermitianOp,
InsertOp,
MeasureOp,
MultiRZOp,
NamedObsOp,
NumQubitsOp,
PCPhaseOp,
ProbsOp,
QubitUnitaryOp,
SampleOp,
SetBasisStateOp,
SetStateOp,
StateOp,
TensorOp,
VarianceOp,
)
from mlir_quantum.dialects.quantum import YieldOp as QYieldOp
from catalyst.jax_primitives_utils import (
cache,
create_call_op,
get_cached,
get_call_jaxpr,
get_symbolref,
lower_callable,
lower_jaxpr,
)

from pennylane.capture.primitives import jacobian_prim as pl_jac_prim

from catalyst.compiler import get_lib_path
Expand All @@ -111,19 +143,9 @@
infer_output_type_jaxpr,
while_loop_expansion_strategy,
)
from catalyst.jax_primitives_utils import (
cache,
create_call_op,
get_cached,
get_call_jaxpr,
get_symbolref,
lower_callable,
lower_jaxpr,
)
from catalyst.utils.calculate_grad_shape import Signature, calculate_grad_shape
from catalyst.utils.exceptions import CompileError
from catalyst.utils.extra_bindings import FromElementsOp, TensorExtractOp
from catalyst.utils.patching import Patcher
from catalyst.utils.types import convert_shaped_arrays_to_tensors

# pylint: disable=unused-argument,too-many-lines,too-many-statements,protected-access
Expand Down
114 changes: 114 additions & 0 deletions frontend/test/pytest/test_jax_extras_patches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the jax_extras.patches module"""

from catalyst.jax_extras.patches import mock_attributes


# pylint: disable=missing-class-docstring,missing-function-docstring
class TestMockAttributes:
"""Test the mock_attributes function and MockAttributeWrapper class."""

def test_mock_attributes_returns_mocked_value(self):
"""Test that accessing a mocked attribute returns the mocked value."""

class DummyClass:
def __init__(self):
self.original_attr = "original"

obj = DummyClass()
mocked = mock_attributes(obj, {"mocked_attr": "mocked_value"})

# Access the mocked attribute - this should come from the attrs dict
assert mocked.mocked_attr == "mocked_value"

def test_mock_attributes_returns_original_value(self):
"""Test that accessing an unmocked attribute returns the original value."""

class DummyClass:
def __init__(self):
self.original_attr = "original"

obj = DummyClass()
mocked = mock_attributes(obj, {"mocked_attr": "mocked_value"})

# Access the original attribute - this should come from the original object
# This tests the else branch in __getattr__
assert mocked.original_attr == "original"

def test_mock_attributes_with_methods(self):
"""Test that calling original methods works through the wrapper."""

class DummyClass:
def __init__(self):
self.value = 42

def get_value(self):
return self.value

obj = DummyClass()

def mocked_method():
return "mocked"

mocked = mock_attributes(obj, {"mocked_method": mocked_method})

# Access the mocked method
assert mocked.mocked_method() == "mocked"

# Access the original method - tests the getattr fallback
assert mocked.get_value() == 42

def test_mock_attributes_with_callable(self):
"""Test mocking with callable attributes like lambda functions."""

class DummyClass:
def __init__(self):
self.original_func = lambda x: x * 2

obj = DummyClass()
mocked = mock_attributes(obj, {"new_func": lambda x: x * 3})

# Access the mocked callable
assert mocked.new_func(5) == 15

# Access the original callable - tests the getattr fallback
assert mocked.original_func(5) == 10

def test_mock_attributes_override_existing(self):
"""Test that mocking can override existing attributes."""

class DummyClass:
def __init__(self):
self.attr = "original"

obj = DummyClass()
mocked = mock_attributes(obj, {"attr": "overridden"})

# The mocked value should take precedence
assert mocked.attr == "overridden"

def test_mock_attributes_stores_original(self):
"""Test that the original object is accessible through the wrapper."""

class DummyClass:
def __init__(self):
self.value = 100

obj = DummyClass()
mocked = mock_attributes(obj, {})

# The wrapper should store the original object
assert mocked.original is obj
assert mocked.original.value == 100
2 changes: 1 addition & 1 deletion mlir/Enzyme
Submodule Enzyme updated 153 files
2 changes: 1 addition & 1 deletion mlir/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ enzyme:
-DCMAKE_CXX_VISIBILITY_PRESET=$(SYMBOL_VISIBILITY) \
-DCMAKE_POLICY_DEFAULT_CMP0116=NEW

cmake --build $(ENZYME_BUILD_DIR) --target EnzymeStatic-21
cmake --build $(ENZYME_BUILD_DIR) --target EnzymeStatic-22

.PHONY: plugin
plugin:
Expand Down
13 changes: 11 additions & 2 deletions mlir/include/Catalyst/IR/CatalystOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,17 @@ def CallbackOp : Catalyst_Op<"callback",

let builders = [OpBuilder<(ins
"mlir::StringRef":$name, "mlir::FunctionType":$type,
CArg<"mlir::ArrayRef<mlir::NamedAttribute>", "{}">:$attrs)
>];
CArg<"mlir::ArrayRef<mlir::NamedAttribute>", "{}">:$attrs), [{
$_state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
$_builder.getStringAttr(name));
$_state.addAttribute("function_type", mlir::TypeAttr::get(type));
$_state.addAttribute("id", $_builder.getI64IntegerAttr(0));
$_state.addAttribute("argc", $_builder.getI64IntegerAttr(type.getNumInputs()));
$_state.addAttribute("resc", $_builder.getI64IntegerAttr(type.getNumResults()));
$_state.attributes.append(attrs.begin(), attrs.end());
$_state.addRegion();
}]>
];

let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
Expand Down
Loading