From df0c822a8c4b4290b5936d323e1c2808ca7bb366 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Sun, 12 Oct 2025 13:58:50 -0400 Subject: [PATCH 01/42] Update LLVM version --- mlir/Makefile | 2 +- mlir/include/Catalyst/IR/CatalystOps.td | 13 +++++-- mlir/include/Gradient/IR/GradientOps.td | 35 ++++++++++++++----- mlir/lib/Driver/CMakeLists.txt | 3 ++ mlir/lib/Driver/Pipelines.cpp | 19 ++++++++-- .../BufferizableOpInterfaceImpl.cpp | 14 ++++---- .../Transforms/GradMethods/JVPVJPPatterns.cpp | 22 ++++-------- mlir/tools/quantum-lsp-server/CMakeLists.txt | 1 + mlir/tools/quantum-opt/CMakeLists.txt | 2 ++ 9 files changed, 75 insertions(+), 36 deletions(-) diff --git a/mlir/Makefile b/mlir/Makefile index 8fc76e11e7..4628b99bd7 100644 --- a/mlir/Makefile +++ b/mlir/Makefile @@ -138,7 +138,7 @@ enzyme: -DCMAKE_CXX_VISIBILITY_PRESET=$(SYMBOL_VISIBILITY) \ -DCMAKE_POLICY_DEFAULT_CMP0116=NEW - cmake --build $(ENZYME_BUILD_DIR) --target EnzymeStatic-21 + cmake --build $(ENZYME_BUILD_DIR) --target EnzymeStatic-22 .PHONY: plugin plugin: diff --git a/mlir/include/Catalyst/IR/CatalystOps.td b/mlir/include/Catalyst/IR/CatalystOps.td index 12daf4f6e9..c3c60fc840 100644 --- a/mlir/include/Catalyst/IR/CatalystOps.td +++ b/mlir/include/Catalyst/IR/CatalystOps.td @@ -138,8 +138,17 @@ def CallbackOp : Catalyst_Op<"callback", let builders = [OpBuilder<(ins "mlir::StringRef":$name, "mlir::FunctionType":$type, - CArg<"mlir::ArrayRef", "{}">:$attrs) - >]; + CArg<"mlir::ArrayRef", "{}">:$attrs), [{ + $_state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), + $_builder.getStringAttr(name)); + $_state.addAttribute("function_type", mlir::TypeAttr::get(type)); + $_state.addAttribute("id", $_builder.getI64IntegerAttr(0)); + $_state.addAttribute("argc", $_builder.getI64IntegerAttr(type.getNumInputs())); + $_state.addAttribute("resc", $_builder.getI64IntegerAttr(type.getNumResults())); + $_state.attributes.append(attrs.begin(), attrs.end()); + $_state.addRegion(); + }]> + ]; let extraClassDeclaration = [{ //===------------------------------------------------------------------===// diff --git a/mlir/include/Gradient/IR/GradientOps.td b/mlir/include/Gradient/IR/GradientOps.td index fb81419b99..75905049aa 100644 --- a/mlir/include/Gradient/IR/GradientOps.td +++ b/mlir/include/Gradient/IR/GradientOps.td @@ -28,7 +28,7 @@ include "Gradient/IR/GradientInterfaces.td" def GradOp : Gradient_Op<"grad", [ DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, GradientOpInterface ]> { let summary = "Compute the gradient of a function."; @@ -287,7 +287,7 @@ def ForwardOp : Gradient_Op<"forward", Then: followed by the original return type, if any. - + since there is none, then: %returnTy = { %tape } @@ -302,7 +302,7 @@ def ForwardOp : Gradient_Op<"forward", One thing that was found experimentally and through tests in Enzyme is that the tape can also be a pointer. We use this in the case when there is no tape to return. Instead of returning an empty struct, we return a null pointer that is just never dereferenced. - + }]; let arguments = (ins @@ -320,8 +320,18 @@ def ForwardOp : Gradient_Op<"forward", let builders = [OpBuilder<(ins "mlir::StringRef":$name, "mlir::FunctionType":$type, - CArg<"mlir::ArrayRef", "{}">:$attrs) - >]; + CArg<"mlir::ArrayRef", "{}">:$attrs), [{ + $_state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), + $_builder.getStringAttr(name)); + $_state.addAttribute("function_type", mlir::TypeAttr::get(type)); + $_state.addAttribute("implementation", mlir::FlatSymbolRefAttr::get($_builder.getStringAttr(""))); + $_state.addAttribute("argc", $_builder.getI64IntegerAttr(0)); + $_state.addAttribute("resc", $_builder.getI64IntegerAttr(0)); + $_state.addAttribute("tape", $_builder.getI64IntegerAttr(0)); + $_state.attributes.append(attrs.begin(), attrs.end()); + $_state.addRegion(); + }]> + ]; let extraClassDeclaration = [{ //===------------------------------------------------------------------===// @@ -358,7 +368,6 @@ def ReverseOp : Gradient_Op<"reverse", %returnTy = { %tape } - }]; let arguments = (ins @@ -376,8 +385,18 @@ def ReverseOp : Gradient_Op<"reverse", let builders = [OpBuilder<(ins "mlir::StringRef":$name, "mlir::FunctionType":$type, - CArg<"mlir::ArrayRef", "{}">:$attrs) - >]; + CArg<"mlir::ArrayRef", "{}">:$attrs), [{ + $_state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), + $_builder.getStringAttr(name)); + $_state.addAttribute("function_type", mlir::TypeAttr::get(type)); + $_state.addAttribute("implementation", mlir::FlatSymbolRefAttr::get($_builder.getStringAttr(""))); + $_state.addAttribute("argc", $_builder.getI64IntegerAttr(0)); + $_state.addAttribute("resc", $_builder.getI64IntegerAttr(0)); + $_state.addAttribute("tape", $_builder.getI64IntegerAttr(0)); + $_state.attributes.append(attrs.begin(), attrs.end()); + $_state.addRegion(); + }]> + ]; let extraClassDeclaration = [{ //===------------------------------------------------------------------===// diff --git a/mlir/lib/Driver/CMakeLists.txt b/mlir/lib/Driver/CMakeLists.txt index c3a82be9ec..5ec6857426 100644 --- a/mlir/lib/Driver/CMakeLists.txt +++ b/mlir/lib/Driver/CMakeLists.txt @@ -27,6 +27,9 @@ set(LIBS ${translation_libs} ExternalStablehloLib MLIROptLib + MLIRRegisterAllDialects + MLIRRegisterAllPasses + MLIRRegisterAllExtensions MLIRCatalyst catalyst-transforms MLIRQuantum diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index ccf4e4ab4f..564b2c2f07 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -14,9 +14,22 @@ #include -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/InitAllDialects.h" -#include "mlir/InitAllPasses.h" +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" +#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" +#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" +#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/MathToLibm/MathToLibm.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" #include "stablehlo/conversions/linalg/transforms/Passes.h" diff --git a/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp index 54cbc85aac..5d232a4ec2 100644 --- a/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp @@ -124,7 +124,7 @@ void TensorType2MemrefType(const TypeRange &inTypes, SmallVector &converte } } -static BaseMemRefType +static bufferization::BufferLikeType getBufferizedFunctionArgType(FunctionOpInterface funcOp, int64_t index, const bufferization::BufferizationOptions &options) { @@ -134,7 +134,7 @@ getBufferizedFunctionArgType(FunctionOpInterface funcOp, int64_t index, BaseMemRefType memrefType = options.functionArgTypeConverterFn( tensorType, *options.defaultMemorySpaceFn(tensorType), nullptr, options); - return memrefType; + return cast(memrefType); } static ReturnOp getAssumedUniqueReturnOp(FunctionOpInterface funcOp) @@ -402,7 +402,7 @@ struct ForwardOpInterface return {}; } - FailureOr getBufferType(Operation *op, Value value, + FailureOr getBufferType(Operation *op, Value value, const bufferization::BufferizationOptions &options, const bufferization::BufferizationState &state, SmallVector &invocationStack) const @@ -526,10 +526,10 @@ struct ReverseOpInterface return {}; } - FailureOr getBufferType(Operation *op, Value value, - const bufferization::BufferizationOptions &options, - const bufferization::BufferizationState &state, - SmallVector &invocationStack) const + FailureOr + getBufferType(Operation *op, Value value, const bufferization::BufferizationOptions &options, + const bufferization::BufferizationState &state, + SmallVector &invocationStack) const { // See comment on the getBufferType() method on forward op. auto reverseOp = cast(op); diff --git a/mlir/lib/Gradient/Transforms/GradMethods/JVPVJPPatterns.cpp b/mlir/lib/Gradient/Transforms/GradMethods/JVPVJPPatterns.cpp index 672819d3e7..ad0e26cdf0 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/JVPVJPPatterns.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/JVPVJPPatterns.cpp @@ -23,6 +23,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/SymbolTable.h" #include "Gradient/Utils/EinsumLinalgGeneric.h" @@ -60,8 +61,6 @@ template std::vector _tovec(const T &x) LogicalResult JVPLoweringPattern::matchAndRewrite(JVPOp op, PatternRewriter &rewriter) const { - MLIRContext *ctx = getContext(); - Location loc = op.getLoc(); auto func_diff_operand_indices = computeDiffArgIndices(op.getDiffArgIndices()); @@ -159,12 +158,9 @@ LogicalResult JVPLoweringPattern::matchAndRewrite(JVPOp op, PatternRewriter &rew } else { assert(acc.value().getType() == res.getType()); - - auto add_op = rewriter.create( - loc, res.getType(), ValueRange({acc.value(), res}), acc.value(), - linalg::BinaryFnAttr::get(ctx, linalg::BinaryFn::add), - linalg::TypeFnAttr::get(ctx, linalg::TypeFn::cast_signed)); - acc = add_op.getResultTensors()[0]; + auto addOp = rewriter.create( + loc, res.getType(), ValueRange{acc.value(), res}, ValueRange{acc.value()}); + acc = addOp.getResultTensors()[0]; } } assert(acc.has_value()); @@ -181,8 +177,6 @@ LogicalResult JVPLoweringPattern::matchAndRewrite(JVPOp op, PatternRewriter &rew LogicalResult VJPLoweringPattern::matchAndRewrite(VJPOp op, PatternRewriter &rewriter) const { - MLIRContext *ctx = getContext(); - Location loc = op.getLoc(); auto func_diff_operand_indices = computeDiffArgIndices(op.getDiffArgIndices()); @@ -278,11 +272,9 @@ LogicalResult VJPLoweringPattern::matchAndRewrite(VJPOp op, PatternRewriter &rew else { assert(acc.value().getType() == res.getType()); - auto add_op = rewriter.create( - loc, res.getType(), ValueRange({acc.value(), res}), acc.value(), - linalg::BinaryFnAttr::get(ctx, linalg::BinaryFn::add), - linalg::TypeFnAttr::get(ctx, linalg::TypeFn::cast_signed)); - acc = add_op.getResultTensors()[0]; + auto addOp = rewriter.create( + loc, res.getType(), ValueRange{acc.value(), res}, ValueRange{acc.value()}); + acc = addOp.getResultTensors()[0]; } } assert(acc.has_value()); diff --git a/mlir/tools/quantum-lsp-server/CMakeLists.txt b/mlir/tools/quantum-lsp-server/CMakeLists.txt index f4a7c2e727..507480ef00 100644 --- a/mlir/tools/quantum-lsp-server/CMakeLists.txt +++ b/mlir/tools/quantum-lsp-server/CMakeLists.txt @@ -5,6 +5,7 @@ set(LIBS ${conversion_libs} ExternalStablehloLib MLIRLspServerLib + MLIRRegisterAllDialects MLIRCatalyst MLIRQuantum MLIRQEC diff --git a/mlir/tools/quantum-opt/CMakeLists.txt b/mlir/tools/quantum-opt/CMakeLists.txt index 10c6ed5a0f..617398b03c 100644 --- a/mlir/tools/quantum-opt/CMakeLists.txt +++ b/mlir/tools/quantum-opt/CMakeLists.txt @@ -7,6 +7,8 @@ set(LIBS ${extension_libs} ExternalStablehloLib MLIROptLib + MLIRRegisterAllDialects + MLIRRegisterAllPasses MLIRCatalyst catalyst-transforms catalyst-stablehlo-transforms From e7d2a5bdb7b56ebe817ba04f299e048e26b34f27 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Thu, 16 Oct 2025 13:53:00 -0400 Subject: [PATCH 02/42] Update .dep-versions --- .dep-versions | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.dep-versions b/.dep-versions index df5a76e781..33496f09da 100644 --- a/.dep-versions +++ b/.dep-versions @@ -2,9 +2,9 @@ # To update JAX version alongside compatible dependency tags, run the following script: # python3 .github/workflows/set_dep_versions.py {JAX_version} jax=0.6.2 -stablehlo=69d6dae46e1c7de36e6e6973654754f05353cba5 -llvm=f8cb7987c64dcffb72414a40560055cb717dbf74 -enzyme=v0.0.186 +stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d +llvm=113f01aa82d055410f22a9d03b3468fa68600589 +enzyme=v0.0.203 # Always remove custom PL/LQ versions before release. From b1082d0a0ffa38624aedff39297c9fcf9cd80058 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Thu, 16 Oct 2025 13:54:57 -0400 Subject: [PATCH 03/42] Fix formatting --- .../Gradient/Transforms/BufferizableOpInterfaceImpl.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp index 5d232a4ec2..4397aafab2 100644 --- a/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp @@ -402,10 +402,10 @@ struct ForwardOpInterface return {}; } - FailureOr getBufferType(Operation *op, Value value, - const bufferization::BufferizationOptions &options, - const bufferization::BufferizationState &state, - SmallVector &invocationStack) const + FailureOr + getBufferType(Operation *op, Value value, const bufferization::BufferizationOptions &options, + const bufferization::BufferizationState &state, + SmallVector &invocationStack) const { // The getBufferType() method is called on either BlockArguments or OpResults. // https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td#L506 From c06245b417c42ae30fd2899fb726103fb8a2d8f9 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Sun, 26 Oct 2025 22:27:34 -0400 Subject: [PATCH 04/42] mock register_traceback_file_exclusion --- frontend/catalyst/__init__.py | 11 +++++++++++ frontend/catalyst/jax_extras/patches.py | 23 +++++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/frontend/catalyst/__init__.py b/frontend/catalyst/__init__.py index eeaf6aa1be..362d14b70e 100644 --- a/frontend/catalyst/__init__.py +++ b/frontend/catalyst/__init__.py @@ -68,6 +68,17 @@ sys.modules["mlir_quantum.ir"] = __import__("jaxlib.mlir.ir").mlir.ir sys.modules["mlir_quantum._mlir_libs"] = __import__("jaxlib.mlir._mlir_libs").mlir._mlir_libs +# Mock _ods_cext.globals.register_traceback_file_exclusion due to API conflicts between +# Catalyst's MLIR version and the MLIR version used by JAX. The current JAX version has not +# yet updated to the latest MLIR, causing compatibility issues. This workaround will be removed +# once JAX updates to a compatible MLIR version. +from catalyst.jax_extras.patches import mock_attributes +from jaxlib.mlir._mlir_libs import _mlir as _ods_cext + +_ods_cext.globals = mock_attributes( + _ods_cext.globals, {"register_traceback_file_exclusion": lambda x: None} +) + from catalyst import debug, logging, passes from catalyst.api_extensions import * from catalyst.api_extensions import __all__ as _api_extension_list diff --git a/frontend/catalyst/jax_extras/patches.py b/frontend/catalyst/jax_extras/patches.py index ba32f498b9..82ea21984b 100644 --- a/frontend/catalyst/jax_extras/patches.py +++ b/frontend/catalyst/jax_extras/patches.py @@ -43,6 +43,29 @@ ) +def mock_attributes(obj, attrs: dict[str, any]): + """Mock the attribute of an object by returning a wrapper. + + Args: + obj: The object to mock the attributes of. + attrs: A dictionary of attributes to mock. + Example: {"attribute_name": attribute_value} + """ + + class MockAttributeWrapper: + """Wrapper to mock the attribute of an object.""" + + def __init__(self, original): + self.original = original + + def __getattr__(self, name): + if name in attrs: + return attrs[name] + return getattr(self.original, name) + + return MockAttributeWrapper(obj) + + def _drop_unused_vars2(jaxpr, constvals): """ A patch to not drop unused vars during classical tracing of control flow. From 40240754bb0c4d5c4bca3c516cdc2adc038e6af9 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Sun, 26 Oct 2025 22:39:17 -0400 Subject: [PATCH 05/42] fix formatting --- frontend/catalyst/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/__init__.py b/frontend/catalyst/__init__.py index 362d14b70e..391768cdf4 100644 --- a/frontend/catalyst/__init__.py +++ b/frontend/catalyst/__init__.py @@ -36,6 +36,7 @@ from catalyst._configuration import INSTALLED from catalyst._version import __version__ +from catalyst.jax_extras.patches import mock_attributes try: if INSTALLED: @@ -72,7 +73,6 @@ # Catalyst's MLIR version and the MLIR version used by JAX. The current JAX version has not # yet updated to the latest MLIR, causing compatibility issues. This workaround will be removed # once JAX updates to a compatible MLIR version. -from catalyst.jax_extras.patches import mock_attributes from jaxlib.mlir._mlir_libs import _mlir as _ods_cext _ods_cext.globals = mock_attributes( From 496fc73bcef68e401258aa7f2a0f105cb699e136 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Sun, 26 Oct 2025 22:41:09 -0400 Subject: [PATCH 06/42] fix formatting --- frontend/catalyst/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/frontend/catalyst/__init__.py b/frontend/catalyst/__init__.py index 391768cdf4..0e46b9a3e0 100644 --- a/frontend/catalyst/__init__.py +++ b/frontend/catalyst/__init__.py @@ -37,6 +37,7 @@ from catalyst._configuration import INSTALLED from catalyst._version import __version__ from catalyst.jax_extras.patches import mock_attributes +from jaxlib.mlir._mlir_libs import _mlir as _ods_cext try: if INSTALLED: @@ -73,8 +74,6 @@ # Catalyst's MLIR version and the MLIR version used by JAX. The current JAX version has not # yet updated to the latest MLIR, causing compatibility issues. This workaround will be removed # once JAX updates to a compatible MLIR version. -from jaxlib.mlir._mlir_libs import _mlir as _ods_cext - _ods_cext.globals = mock_attributes( _ods_cext.globals, {"register_traceback_file_exclusion": lambda x: None} ) From 3fb861132d005b16c83490df865357febf769742 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Sun, 26 Oct 2025 22:43:04 -0400 Subject: [PATCH 07/42] fix pylint --- frontend/catalyst/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/catalyst/__init__.py b/frontend/catalyst/__init__.py index 0e46b9a3e0..69adc5d313 100644 --- a/frontend/catalyst/__init__.py +++ b/frontend/catalyst/__init__.py @@ -37,7 +37,6 @@ from catalyst._configuration import INSTALLED from catalyst._version import __version__ from catalyst.jax_extras.patches import mock_attributes -from jaxlib.mlir._mlir_libs import _mlir as _ods_cext try: if INSTALLED: @@ -73,7 +72,8 @@ # Mock _ods_cext.globals.register_traceback_file_exclusion due to API conflicts between # Catalyst's MLIR version and the MLIR version used by JAX. The current JAX version has not # yet updated to the latest MLIR, causing compatibility issues. This workaround will be removed -# once JAX updates to a compatible MLIR version. +# once JAX updates to a compatible MLIR version +from jaxlib.mlir._mlir_libs import _mlir as _ods_cext # pylint: disable=ungrouped-imports _ods_cext.globals = mock_attributes( _ods_cext.globals, {"register_traceback_file_exclusion": lambda x: None} ) From 0e1543c08debfff63c87a43820d825abd9283ccb Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Sun, 26 Oct 2025 22:46:23 -0400 Subject: [PATCH 08/42] fix pylint --- frontend/catalyst/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/__init__.py b/frontend/catalyst/__init__.py index 69adc5d313..16eea00647 100644 --- a/frontend/catalyst/__init__.py +++ b/frontend/catalyst/__init__.py @@ -73,11 +73,13 @@ # Catalyst's MLIR version and the MLIR version used by JAX. The current JAX version has not # yet updated to the latest MLIR, causing compatibility issues. This workaround will be removed # once JAX updates to a compatible MLIR version -from jaxlib.mlir._mlir_libs import _mlir as _ods_cext # pylint: disable=ungrouped-imports +from jaxlib.mlir._mlir_libs import _mlir as _ods_cext + _ods_cext.globals = mock_attributes( _ods_cext.globals, {"register_traceback_file_exclusion": lambda x: None} ) +# pylint: disable=ungrouped-imports from catalyst import debug, logging, passes from catalyst.api_extensions import * from catalyst.api_extensions import __all__ as _api_extension_list From 87b7de2b77bbcca7917f18a844b21e5449227b36 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 28 Oct 2025 14:06:14 -0400 Subject: [PATCH 09/42] update commit hash --- mlir/Enzyme | 2 +- mlir/llvm-project | 2 +- mlir/stablehlo | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/Enzyme b/mlir/Enzyme index 8c1a596158..476c8e3193 160000 --- a/mlir/Enzyme +++ b/mlir/Enzyme @@ -1 +1 @@ -Subproject commit 8c1a596158f6194f10e8ffd56a1660a61c54337e +Subproject commit 476c8e3193a8577ba24ff845ae2294109225f83a diff --git a/mlir/llvm-project b/mlir/llvm-project index f8cb7987c6..113f01aa82 160000 --- a/mlir/llvm-project +++ b/mlir/llvm-project @@ -1 +1 @@ -Subproject commit f8cb7987c64dcffb72414a40560055cb717dbf74 +Subproject commit 113f01aa82d055410f22a9d03b3468fa68600589 diff --git a/mlir/stablehlo b/mlir/stablehlo index 69d6dae46e..0a4440a5c8 160000 --- a/mlir/stablehlo +++ b/mlir/stablehlo @@ -1 +1 @@ -Subproject commit 69d6dae46e1c7de36e6e6973654754f05353cba5 +Subproject commit 0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d From 61cc8b4e8eebcd2125c06678577196e16161367d Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 28 Oct 2025 14:40:44 -0400 Subject: [PATCH 10/42] fix coverage --- frontend/catalyst/jax_extras/patches.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/frontend/catalyst/jax_extras/patches.py b/frontend/catalyst/jax_extras/patches.py index 82ea21984b..e89c45c88e 100644 --- a/frontend/catalyst/jax_extras/patches.py +++ b/frontend/catalyst/jax_extras/patches.py @@ -59,9 +59,7 @@ def __init__(self, original): self.original = original def __getattr__(self, name): - if name in attrs: - return attrs[name] - return getattr(self.original, name) + return attrs.get(name, getattr(self.original, name)) return MockAttributeWrapper(obj) From 5fc49be575ce379e9b58e54a7e367a49c931a64a Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 28 Oct 2025 15:18:44 -0400 Subject: [PATCH 11/42] revert --- frontend/catalyst/jax_extras/patches.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/jax_extras/patches.py b/frontend/catalyst/jax_extras/patches.py index e89c45c88e..82ea21984b 100644 --- a/frontend/catalyst/jax_extras/patches.py +++ b/frontend/catalyst/jax_extras/patches.py @@ -59,7 +59,9 @@ def __init__(self, original): self.original = original def __getattr__(self, name): - return attrs.get(name, getattr(self.original, name)) + if name in attrs: + return attrs[name] + return getattr(self.original, name) return MockAttributeWrapper(obj) From 21b90972d88066ca25b3f5584c606de002a4dfde Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 28 Oct 2025 16:50:19 -0400 Subject: [PATCH 12/42] fix coverage --- .../test/pytest/test_jax_extras_patches.py | 113 ++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 frontend/test/pytest/test_jax_extras_patches.py diff --git a/frontend/test/pytest/test_jax_extras_patches.py b/frontend/test/pytest/test_jax_extras_patches.py new file mode 100644 index 0000000000..0694693144 --- /dev/null +++ b/frontend/test/pytest/test_jax_extras_patches.py @@ -0,0 +1,113 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for the jax_extras.patches module""" + +from catalyst.jax_extras.patches import mock_attributes + + +class TestMockAttributes: + """Test the mock_attributes function and MockAttributeWrapper class.""" + + def test_mock_attributes_returns_mocked_value(self): + """Test that accessing a mocked attribute returns the mocked value.""" + + class DummyClass: + def __init__(self): + self.original_attr = "original" + + obj = DummyClass() + mocked = mock_attributes(obj, {"mocked_attr": "mocked_value"}) + + # Access the mocked attribute - this should come from the attrs dict + assert mocked.mocked_attr == "mocked_value" + + def test_mock_attributes_returns_original_value(self): + """Test that accessing an unmocked attribute returns the original value.""" + + class DummyClass: + def __init__(self): + self.original_attr = "original" + + obj = DummyClass() + mocked = mock_attributes(obj, {"mocked_attr": "mocked_value"}) + + # Access the original attribute - this should come from the original object + # This tests the else branch in __getattr__ + assert mocked.original_attr == "original" + + def test_mock_attributes_with_methods(self): + """Test that calling original methods works through the wrapper.""" + + class DummyClass: + def __init__(self): + self.value = 42 + + def get_value(self): + return self.value + + obj = DummyClass() + + def mocked_method(): + return "mocked" + + mocked = mock_attributes(obj, {"mocked_method": mocked_method}) + + # Access the mocked method + assert mocked.mocked_method() == "mocked" + + # Access the original method - tests the getattr fallback + assert mocked.get_value() == 42 + + def test_mock_attributes_with_callable(self): + """Test mocking with callable attributes like lambda functions.""" + + class DummyClass: + def __init__(self): + self.original_func = lambda x: x * 2 + + obj = DummyClass() + mocked = mock_attributes(obj, {"new_func": lambda x: x * 3}) + + # Access the mocked callable + assert mocked.new_func(5) == 15 + + # Access the original callable - tests the getattr fallback + assert mocked.original_func(5) == 10 + + def test_mock_attributes_override_existing(self): + """Test that mocking can override existing attributes.""" + + class DummyClass: + def __init__(self): + self.attr = "original" + + obj = DummyClass() + mocked = mock_attributes(obj, {"attr": "overridden"}) + + # The mocked value should take precedence + assert mocked.attr == "overridden" + + def test_mock_attributes_stores_original(self): + """Test that the original object is accessible through the wrapper.""" + + class DummyClass: + def __init__(self): + self.value = 100 + + obj = DummyClass() + mocked = mock_attributes(obj, {}) + + # The wrapper should store the original object + assert mocked.original is obj + assert mocked.original.value == 100 From 45b3f9d5d61968bf8baf5fd26e3e18b7010d7e88 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 28 Oct 2025 16:54:28 -0400 Subject: [PATCH 13/42] fix pylint --- frontend/test/pytest/test_jax_extras_patches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/test_jax_extras_patches.py b/frontend/test/pytest/test_jax_extras_patches.py index 0694693144..57a350c6bd 100644 --- a/frontend/test/pytest/test_jax_extras_patches.py +++ b/frontend/test/pytest/test_jax_extras_patches.py @@ -15,7 +15,7 @@ from catalyst.jax_extras.patches import mock_attributes - +# pylint: disable=missing-class-docstring,missing-function-docstring class TestMockAttributes: """Test the mock_attributes function and MockAttributeWrapper class.""" From 3d07aac0858b3cf390efafffd01b986af7baa9be Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Wed, 29 Oct 2025 12:21:10 -0400 Subject: [PATCH 14/42] patch ods_cext with patcher --- frontend/catalyst/__init__.py | 49 +++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/frontend/catalyst/__init__.py b/frontend/catalyst/__init__.py index 16eea00647..3fb98d9763 100644 --- a/frontend/catalyst/__init__.py +++ b/frontend/catalyst/__init__.py @@ -36,7 +36,7 @@ from catalyst._configuration import INSTALLED from catalyst._version import __version__ -from catalyst.jax_extras.patches import mock_attributes +from catalyst.utils.patching import Patcher try: if INSTALLED: @@ -75,26 +75,35 @@ # once JAX updates to a compatible MLIR version from jaxlib.mlir._mlir_libs import _mlir as _ods_cext -_ods_cext.globals = mock_attributes( - _ods_cext.globals, {"register_traceback_file_exclusion": lambda x: None} -) - # pylint: disable=ungrouped-imports -from catalyst import debug, logging, passes -from catalyst.api_extensions import * -from catalyst.api_extensions import __all__ as _api_extension_list -from catalyst.autograph import * -from catalyst.autograph import __all__ as _autograph_functions -from catalyst.compiler import CompileOptions -from catalyst.debug.assertion import debug_assert -from catalyst.jit import QJIT, qjit -from catalyst.passes.pass_api import pipeline -from catalyst.utils.exceptions import ( - AutoGraphError, - CompileError, - DifferentiableCompileError, - PlxprCaptureCFCompatibilityError, -) +from catalyst.jax_extras.patches import mock_attributes + +with Patcher( + ( + _ods_cext, + "globals", + mock_attributes( + # pylint: disable=c-extension-no-member + _ods_cext.globals, + {"register_traceback_file_exclusion": lambda x: None}, + ), + ), +): + from catalyst import debug, logging, passes + from catalyst.api_extensions import * + from catalyst.api_extensions import __all__ as _api_extension_list + from catalyst.autograph import * + from catalyst.autograph import __all__ as _autograph_functions + from catalyst.compiler import CompileOptions + from catalyst.debug.assertion import debug_assert + from catalyst.jit import QJIT, qjit + from catalyst.passes.pass_api import pipeline + from catalyst.utils.exceptions import ( + AutoGraphError, + CompileError, + DifferentiableCompileError, + PlxprCaptureCFCompatibilityError, + ) autograph_ignore_fallbacks = False """bool: Specify whether AutoGraph should avoid raising From 11aee123f6758e23e9de49de23477a29c4c82fd2 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Wed, 29 Oct 2025 15:03:50 -0400 Subject: [PATCH 15/42] move the patch to jax_primitives.py --- frontend/catalyst/__init__.py | 71 ++++++++------- frontend/catalyst/jax_primitives.py | 135 ++++++++++++++++------------ 2 files changed, 112 insertions(+), 94 deletions(-) diff --git a/frontend/catalyst/__init__.py b/frontend/catalyst/__init__.py index 3fb98d9763..a47aa8626c 100644 --- a/frontend/catalyst/__init__.py +++ b/frontend/catalyst/__init__.py @@ -36,7 +36,6 @@ from catalyst._configuration import INSTALLED from catalyst._version import __version__ -from catalyst.utils.patching import Patcher try: if INSTALLED: @@ -69,41 +68,41 @@ sys.modules["mlir_quantum.ir"] = __import__("jaxlib.mlir.ir").mlir.ir sys.modules["mlir_quantum._mlir_libs"] = __import__("jaxlib.mlir._mlir_libs").mlir._mlir_libs -# Mock _ods_cext.globals.register_traceback_file_exclusion due to API conflicts between -# Catalyst's MLIR version and the MLIR version used by JAX. The current JAX version has not -# yet updated to the latest MLIR, causing compatibility issues. This workaround will be removed -# once JAX updates to a compatible MLIR version -from jaxlib.mlir._mlir_libs import _mlir as _ods_cext - -# pylint: disable=ungrouped-imports -from catalyst.jax_extras.patches import mock_attributes - -with Patcher( - ( - _ods_cext, - "globals", - mock_attributes( - # pylint: disable=c-extension-no-member - _ods_cext.globals, - {"register_traceback_file_exclusion": lambda x: None}, - ), - ), -): - from catalyst import debug, logging, passes - from catalyst.api_extensions import * - from catalyst.api_extensions import __all__ as _api_extension_list - from catalyst.autograph import * - from catalyst.autograph import __all__ as _autograph_functions - from catalyst.compiler import CompileOptions - from catalyst.debug.assertion import debug_assert - from catalyst.jit import QJIT, qjit - from catalyst.passes.pass_api import pipeline - from catalyst.utils.exceptions import ( - AutoGraphError, - CompileError, - DifferentiableCompileError, - PlxprCaptureCFCompatibilityError, - ) +# # Mock _ods_cext.globals.register_traceback_file_exclusion due to API conflicts between +# # Catalyst's MLIR version and the MLIR version used by JAX. The current JAX version has not +# # yet updated to the latest MLIR, causing compatibility issues. This workaround will be removed +# # once JAX updates to a compatible MLIR version +# from jaxlib.mlir._mlir_libs import _mlir as _ods_cext + +# # pylint: disable=ungrouped-imports +# from catalyst.jax_extras.patches import mock_attributes + +# with Patcher( +# ( +# _ods_cext, +# "globals", +# mock_attributes( +# # pylint: disable=c-extension-no-member +# _ods_cext.globals, +# {"register_traceback_file_exclusion": lambda x: None}, +# ), +# ), +# ): +from catalyst import debug, logging, passes +from catalyst.api_extensions import * +from catalyst.api_extensions import __all__ as _api_extension_list +from catalyst.autograph import * +from catalyst.autograph import __all__ as _autograph_functions +from catalyst.compiler import CompileOptions +from catalyst.debug.assertion import debug_assert +from catalyst.jit import QJIT, qjit +from catalyst.passes.pass_api import pipeline +from catalyst.utils.exceptions import ( + AutoGraphError, + CompileError, + DifferentiableCompileError, + PlxprCaptureCFCompatibilityError, +) autograph_ignore_fallbacks = False """bool: Specify whether AutoGraph should avoid raising diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index fec8e41684..649ecb2307 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -52,54 +52,83 @@ from jaxlib.mlir.dialects.scf import ConditionOp, ForOp, IfOp, WhileOp, YieldOp from jaxlib.mlir.dialects.stablehlo import ConstantOp as StableHLOConstantOp from jaxlib.mlir.dialects.stablehlo import ConvertOp as StableHLOConvertOp -from mlir_quantum.dialects.catalyst import ( - AssertionOp, - CallbackCallOp, - CallbackOp, - PrintOp, -) -from mlir_quantum.dialects.gradient import ( - CustomGradOp, - ForwardOp, - GradOp, - JVPOp, - ReverseOp, - ValueAndGradOp, - VJPOp, -) -from mlir_quantum.dialects.mbqc import MeasureInBasisOp -from mlir_quantum.dialects.mitigation import ZneOp -from mlir_quantum.dialects.quantum import ( - AdjointOp, - AllocOp, - ComputationalBasisOp, - CountsOp, - CustomOp, - DeallocOp, - DeallocQubitOp, - DeviceInitOp, - DeviceReleaseOp, - ExpvalOp, - ExtractOp, - GlobalPhaseOp, - HamiltonianOp, - HermitianOp, - InsertOp, - MeasureOp, - MultiRZOp, - NamedObsOp, - NumQubitsOp, - PCPhaseOp, - ProbsOp, - QubitUnitaryOp, - SampleOp, - SetBasisStateOp, - SetStateOp, - StateOp, - TensorOp, - VarianceOp, -) -from mlir_quantum.dialects.quantum import YieldOp as QYieldOp + +# Mock _ods_cext.globals.register_traceback_file_exclusion due to API conflicts between +# Catalyst's MLIR version and the MLIR version used by JAX. The current JAX version has not +# yet updated to the latest MLIR, causing compatibility issues. This workaround will be removed +# once JAX updates to a compatible MLIR version +# pylint: disable=ungrouped-imports +from catalyst.jax_extras.patches import mock_attributes +from jaxlib.mlir._mlir_libs import _mlir as _ods_cext +from catalyst.utils.patching import Patcher +with Patcher( + ( + _ods_cext, + "globals", + mock_attributes( + # pylint: disable=c-extension-no-member + _ods_cext.globals, + {"register_traceback_file_exclusion": lambda x: None}, + ), + ), +): + from mlir_quantum.dialects.catalyst import ( + AssertionOp, + CallbackCallOp, + CallbackOp, + PrintOp, + ) + from mlir_quantum.dialects.gradient import ( + CustomGradOp, + ForwardOp, + GradOp, + JVPOp, + ReverseOp, + ValueAndGradOp, + VJPOp, + ) + from mlir_quantum.dialects.mbqc import MeasureInBasisOp + from mlir_quantum.dialects.mitigation import ZneOp + from mlir_quantum.dialects.quantum import ( + AdjointOp, + AllocOp, + ComputationalBasisOp, + CountsOp, + CustomOp, + DeallocOp, + DeallocQubitOp, + DeviceInitOp, + DeviceReleaseOp, + ExpvalOp, + ExtractOp, + GlobalPhaseOp, + HamiltonianOp, + HermitianOp, + InsertOp, + MeasureOp, + MultiRZOp, + NamedObsOp, + NumQubitsOp, + PCPhaseOp, + ProbsOp, + QubitUnitaryOp, + SampleOp, + SetBasisStateOp, + SetStateOp, + StateOp, + TensorOp, + VarianceOp, + ) + from mlir_quantum.dialects.quantum import YieldOp as QYieldOp + from catalyst.jax_primitives_utils import ( + cache, + create_call_op, + get_cached, + get_call_jaxpr, + get_symbolref, + lower_callable, + lower_jaxpr, + ) from catalyst.compiler import get_lib_path from catalyst.jax_extras import ( @@ -110,19 +139,9 @@ infer_output_type_jaxpr, while_loop_expansion_strategy, ) -from catalyst.jax_primitives_utils import ( - cache, - create_call_op, - get_cached, - get_call_jaxpr, - get_symbolref, - lower_callable, - lower_jaxpr, -) from catalyst.utils.calculate_grad_shape import Signature, calculate_grad_shape from catalyst.utils.exceptions import CompileError from catalyst.utils.extra_bindings import FromElementsOp, TensorExtractOp -from catalyst.utils.patching import Patcher from catalyst.utils.types import convert_shaped_arrays_to_tensors # pylint: disable=unused-argument,too-many-lines,too-many-statements,protected-access From 096d7a2e98d4628a4df87fa7495cfb0e1716929b Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Wed, 29 Oct 2025 15:05:44 -0400 Subject: [PATCH 16/42] fix formatting --- frontend/catalyst/jax_primitives.py | 1 + frontend/test/pytest/test_jax_extras_patches.py | 1 + 2 files changed, 2 insertions(+) diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 649ecb2307..24cc6381d4 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -61,6 +61,7 @@ from catalyst.jax_extras.patches import mock_attributes from jaxlib.mlir._mlir_libs import _mlir as _ods_cext from catalyst.utils.patching import Patcher + with Patcher( ( _ods_cext, diff --git a/frontend/test/pytest/test_jax_extras_patches.py b/frontend/test/pytest/test_jax_extras_patches.py index 57a350c6bd..c7ee8346d5 100644 --- a/frontend/test/pytest/test_jax_extras_patches.py +++ b/frontend/test/pytest/test_jax_extras_patches.py @@ -15,6 +15,7 @@ from catalyst.jax_extras.patches import mock_attributes + # pylint: disable=missing-class-docstring,missing-function-docstring class TestMockAttributes: """Test the mock_attributes function and MockAttributeWrapper class.""" From f9f7763dd52e5e312a2f5deca7d49ae49a2d1d45 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Wed, 29 Oct 2025 15:08:37 -0400 Subject: [PATCH 17/42] fix formatting --- frontend/catalyst/jax_primitives.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 24cc6381d4..787962d284 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -39,6 +39,7 @@ from jax.interpreters import mlir from jax.tree_util import PyTreeDef, tree_unflatten from jaxlib.hlo_helpers import shape_dtype_to_ir_type +from jaxlib.mlir._mlir_libs import _mlir as _ods_cext from jaxlib.mlir.dialects.arith import ( AddIOp, CeilDivSIOp, @@ -59,7 +60,6 @@ # once JAX updates to a compatible MLIR version # pylint: disable=ungrouped-imports from catalyst.jax_extras.patches import mock_attributes -from jaxlib.mlir._mlir_libs import _mlir as _ods_cext from catalyst.utils.patching import Patcher with Patcher( From 73d87139387098d545f212a6af189c794dd87295 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Wed, 29 Oct 2025 17:18:48 -0400 Subject: [PATCH 18/42] remove redundant --- frontend/catalyst/__init__.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/frontend/catalyst/__init__.py b/frontend/catalyst/__init__.py index a47aa8626c..eeaf6aa1be 100644 --- a/frontend/catalyst/__init__.py +++ b/frontend/catalyst/__init__.py @@ -68,26 +68,6 @@ sys.modules["mlir_quantum.ir"] = __import__("jaxlib.mlir.ir").mlir.ir sys.modules["mlir_quantum._mlir_libs"] = __import__("jaxlib.mlir._mlir_libs").mlir._mlir_libs -# # Mock _ods_cext.globals.register_traceback_file_exclusion due to API conflicts between -# # Catalyst's MLIR version and the MLIR version used by JAX. The current JAX version has not -# # yet updated to the latest MLIR, causing compatibility issues. This workaround will be removed -# # once JAX updates to a compatible MLIR version -# from jaxlib.mlir._mlir_libs import _mlir as _ods_cext - -# # pylint: disable=ungrouped-imports -# from catalyst.jax_extras.patches import mock_attributes - -# with Patcher( -# ( -# _ods_cext, -# "globals", -# mock_attributes( -# # pylint: disable=c-extension-no-member -# _ods_cext.globals, -# {"register_traceback_file_exclusion": lambda x: None}, -# ), -# ), -# ): from catalyst import debug, logging, passes from catalyst.api_extensions import * from catalyst.api_extensions import __all__ as _api_extension_list From 6b518df1a6e234a8d6f5464713de504468d25f39 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Wed, 29 Oct 2025 17:29:41 -0400 Subject: [PATCH 19/42] Update frontend/catalyst/jax_primitives.py Co-authored-by: David Ittah --- frontend/catalyst/jax_primitives.py | 1 + 1 file changed, 1 insertion(+) diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 787962d284..7d710362b7 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -54,6 +54,7 @@ from jaxlib.mlir.dialects.stablehlo import ConstantOp as StableHLOConstantOp from jaxlib.mlir.dialects.stablehlo import ConvertOp as StableHLOConvertOp +# TODO: remove after jax v0.7.2 upgrade # Mock _ods_cext.globals.register_traceback_file_exclusion due to API conflicts between # Catalyst's MLIR version and the MLIR version used by JAX. The current JAX version has not # yet updated to the latest MLIR, causing compatibility issues. This workaround will be removed From 2ccfbaea41b00c6efdd04f137f55b5e4b3d580ca Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 4 Nov 2025 14:36:54 -0500 Subject: [PATCH 20/42] try to fix decompose failure --- mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp index 6715808aae..3005c9fcdd 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp @@ -436,7 +436,7 @@ class MultiRZOpSignatureAnalyzer : public BaseSignatureAnalyzer { MultiRZOpSignatureAnalyzer() = delete; MultiRZOpSignatureAnalyzer(MultiRZOp op, bool enableQregMode) - : BaseSignatureAnalyzer(op, op.getTheta(), op.getNonCtrlQubitOperands(), + : BaseSignatureAnalyzer(op, mlir::ValueRange(op.getTheta()), op.getNonCtrlQubitOperands(), op.getCtrlQubitOperands(), op.getCtrlValueOperands(), op.getNonCtrlQubitResults(), op.getCtrlQubitResults(), enableQregMode) From 028ad0ec4911cc74c5785afa8e5d04c51741b12a Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 4 Nov 2025 14:56:42 -0500 Subject: [PATCH 21/42] fix decomp --- .../Transforms/DecomposeLoweringImpl.hpp | 39 +++++++++---------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp index 3005c9fcdd..a0c41f8838 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp @@ -78,7 +78,7 @@ class BaseSignatureAnalyzer { protected: bool isValid = true; - llvm::SmallVector paramsStorage; + llvm::SmallVector paramsStorage; // Unified Signature Structure: All parameters, regardless of source (params or theta), // are stored in a ValueRange for generalized processing. @@ -97,21 +97,20 @@ class BaseSignatureAnalyzer { llvm::SmallVector outCtrlQubitIndices; } signature; - BaseSignatureAnalyzer(mlir::Operation *op, mlir::ValueRange params, mlir::ValueRange inQubits, - mlir::ValueRange inCtrlQubits, mlir::ValueRange inCtrlValues, - mlir::ValueRange outQubits, mlir::ValueRange outCtrlQubits, - bool enableQregMode) - : paramsStorage(params.begin(), params.end()), - signature(Signature{.params = mlir::ValueRange(paramsStorage), - .inQubits = inQubits, - .inCtrlQubits = inCtrlQubits, - .inCtrlValues = inCtrlValues, - .outQubits = outQubits, - .outCtrlQubits = outCtrlQubits, - .inWireIndices = {}, - .inCtrlWireIndices = {}, - .outQubitIndices = {}, - .outCtrlQubitIndices = {}}) + BaseSignatureAnalyzer(mlir::Operation *op, llvm::SmallVector params, + mlir::ValueRange inQubits, mlir::ValueRange inCtrlQubits, + mlir::ValueRange inCtrlValues, mlir::ValueRange outQubits, + mlir::ValueRange outCtrlQubits, bool enableQregMode) + : paramsStorage(params), signature(Signature{.params = paramsStorage, + .inQubits = inQubits, + .inCtrlQubits = inCtrlQubits, + .inCtrlValues = inCtrlValues, + .outQubits = outQubits, + .outCtrlQubits = outCtrlQubits, + .inWireIndices = {}, + .inCtrlWireIndices = {}, + .outQubitIndices = {}, + .outCtrlQubitIndices = {}}) { initializeQregMode(op, enableQregMode); } @@ -436,10 +435,10 @@ class MultiRZOpSignatureAnalyzer : public BaseSignatureAnalyzer { MultiRZOpSignatureAnalyzer() = delete; MultiRZOpSignatureAnalyzer(MultiRZOp op, bool enableQregMode) - : BaseSignatureAnalyzer(op, mlir::ValueRange(op.getTheta()), op.getNonCtrlQubitOperands(), - op.getCtrlQubitOperands(), op.getCtrlValueOperands(), - op.getNonCtrlQubitResults(), op.getCtrlQubitResults(), - enableQregMode) + : BaseSignatureAnalyzer(op, llvm::SmallVector{op.getTheta()}, + op.getNonCtrlQubitOperands(), op.getCtrlQubitOperands(), + op.getCtrlValueOperands(), op.getNonCtrlQubitResults(), + op.getCtrlQubitResults(), enableQregMode) { } }; From 8add6232674f6a68a0064e046448dc55baadd6da Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 4 Nov 2025 15:10:06 -0500 Subject: [PATCH 22/42] fix test --- mlir/test/Quantum/DecomposeLoweringTest.mlir | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mlir/test/Quantum/DecomposeLoweringTest.mlir b/mlir/test/Quantum/DecomposeLoweringTest.mlir index 466983e261..28afc00db6 100644 --- a/mlir/test/Quantum/DecomposeLoweringTest.mlir +++ b/mlir/test/Quantum/DecomposeLoweringTest.mlir @@ -561,8 +561,6 @@ module @circuit_with_multirz { %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit // CHECK: func.func public @test_with_multirz() -> tensor<4xf64> // CHECK: [[CST_RZ:%.+]] = arith.constant 5.000000e-01 : f64 - // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64 - // CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64 // CHECK: [[REG:%.+]] = quantum.alloc( 2) : !quantum.reg // CHECK: [[QUBIT1:%.+]] = quantum.custom "RZ"([[CST_RZ]]) {{%.+}} : !quantum.bit @@ -571,8 +569,8 @@ module @circuit_with_multirz { %extracted_2 = tensor.extract %cst[] : tensor %out_qubits = quantum.multirz(%extracted_2) %1 : !quantum.bit - // CHECK: [[QUBIT3:%.+]] = quantum.custom "RZ"([[CST_PI]]) {{%.+}} : !quantum.bit - // CHECK: [[QUBIT4:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT3]] : !quantum.bit + // CHECK: [[QUBIT3:%.+]] = quantum.custom "RZ"([[CST_RZ:%.+]]) {{%.+}} : !quantum.bit + // CHECK: [[QUBIT4:%.+]] = quantum.custom "RY"([[CST_RY:%.+]]) [[QUBIT3]] : !quantum.bit // CHECK-NOT: quantum.custom "Hadamard" %out_qubits_0 = quantum.custom "Hadamard"() %out_qubits : !quantum.bit From f4615f3f65ff6849c0a80dadd9c836ea86cd5913 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 4 Nov 2025 15:23:47 -0500 Subject: [PATCH 23/42] fix --- .../Transforms/DecomposeLoweringImpl.hpp | 39 ++++++++++--------- mlir/test/Quantum/DecomposeLoweringTest.mlir | 12 +++--- 2 files changed, 27 insertions(+), 24 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp index a0c41f8838..6715808aae 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp @@ -78,7 +78,7 @@ class BaseSignatureAnalyzer { protected: bool isValid = true; - llvm::SmallVector paramsStorage; + llvm::SmallVector paramsStorage; // Unified Signature Structure: All parameters, regardless of source (params or theta), // are stored in a ValueRange for generalized processing. @@ -97,20 +97,21 @@ class BaseSignatureAnalyzer { llvm::SmallVector outCtrlQubitIndices; } signature; - BaseSignatureAnalyzer(mlir::Operation *op, llvm::SmallVector params, - mlir::ValueRange inQubits, mlir::ValueRange inCtrlQubits, - mlir::ValueRange inCtrlValues, mlir::ValueRange outQubits, - mlir::ValueRange outCtrlQubits, bool enableQregMode) - : paramsStorage(params), signature(Signature{.params = paramsStorage, - .inQubits = inQubits, - .inCtrlQubits = inCtrlQubits, - .inCtrlValues = inCtrlValues, - .outQubits = outQubits, - .outCtrlQubits = outCtrlQubits, - .inWireIndices = {}, - .inCtrlWireIndices = {}, - .outQubitIndices = {}, - .outCtrlQubitIndices = {}}) + BaseSignatureAnalyzer(mlir::Operation *op, mlir::ValueRange params, mlir::ValueRange inQubits, + mlir::ValueRange inCtrlQubits, mlir::ValueRange inCtrlValues, + mlir::ValueRange outQubits, mlir::ValueRange outCtrlQubits, + bool enableQregMode) + : paramsStorage(params.begin(), params.end()), + signature(Signature{.params = mlir::ValueRange(paramsStorage), + .inQubits = inQubits, + .inCtrlQubits = inCtrlQubits, + .inCtrlValues = inCtrlValues, + .outQubits = outQubits, + .outCtrlQubits = outCtrlQubits, + .inWireIndices = {}, + .inCtrlWireIndices = {}, + .outQubitIndices = {}, + .outCtrlQubitIndices = {}}) { initializeQregMode(op, enableQregMode); } @@ -435,10 +436,10 @@ class MultiRZOpSignatureAnalyzer : public BaseSignatureAnalyzer { MultiRZOpSignatureAnalyzer() = delete; MultiRZOpSignatureAnalyzer(MultiRZOp op, bool enableQregMode) - : BaseSignatureAnalyzer(op, llvm::SmallVector{op.getTheta()}, - op.getNonCtrlQubitOperands(), op.getCtrlQubitOperands(), - op.getCtrlValueOperands(), op.getNonCtrlQubitResults(), - op.getCtrlQubitResults(), enableQregMode) + : BaseSignatureAnalyzer(op, op.getTheta(), op.getNonCtrlQubitOperands(), + op.getCtrlQubitOperands(), op.getCtrlValueOperands(), + op.getNonCtrlQubitResults(), op.getCtrlQubitResults(), + enableQregMode) { } }; diff --git a/mlir/test/Quantum/DecomposeLoweringTest.mlir b/mlir/test/Quantum/DecomposeLoweringTest.mlir index 28afc00db6..43fdf9c1ad 100644 --- a/mlir/test/Quantum/DecomposeLoweringTest.mlir +++ b/mlir/test/Quantum/DecomposeLoweringTest.mlir @@ -388,10 +388,12 @@ module @multi_wire_cnot_decomposition { // CHECK: [[QUBIT1:%.+]] = quantum.extract [[REG]][[[EXTRACTED]]] : !quantum.reg -> !quantum.bit // CHECK: [[RZ1:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT1]] : !quantum.bit // CHECK: [[RY1:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[RZ1]] : !quantum.bit + // CHECK: [[INSERT1:%.+]] = quantum.insert [[REG]][[[EXTRACTED]]], [[RY1]] : !quantum.reg, !quantum.bit // CHECK: [[EXTRACTED2:%.+]] = tensor.extract [[RESHAPE1]][] : tensor - // CHECK: [[QUBIT0:%.+]] = quantum.extract [[REG]][[[EXTRACTED2]]] : !quantum.reg -> !quantum.bit - // CHECK: [[CZ_RESULT:%.+]]:2 = quantum.custom "CZ"() [[QUBIT0]], [[RY1]] : !quantum.bit, !quantum.bit - // CHECK: [[INSERT2:%.+]] = quantum.insert [[REG]][[[EXTRACTED2]]], [[CZ_RESULT]]#0 : !quantum.reg, !quantum.bit + // CHECK: [[QUBIT0:%.+]] = quantum.extract [[INSERT1]][[[EXTRACTED2]]] : !quantum.reg -> !quantum.bit + // CHECK: [[QUBIT1_2:%.+]] = quantum.extract [[INSERT1]][[[EXTRACTED]]] : !quantum.reg -> !quantum.bit + // CHECK: [[CZ_RESULT:%.+]]:2 = quantum.custom "CZ"() [[QUBIT0]], [[QUBIT1_2]] : !quantum.bit, !quantum.bit + // CHECK: [[INSERT2:%.+]] = quantum.insert [[INSERT1]][[[EXTRACTED2]]], [[CZ_RESULT]]#0 : !quantum.reg, !quantum.bit // CHECK: [[RZ2:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[CZ_RESULT]]#1 : !quantum.bit // CHECK: [[RY2:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[RZ2]] : !quantum.bit // CHECK: [[INSERT3:%.+]] = quantum.insert [[INSERT2]][[[EXTRACTED]]], [[RY2]] : !quantum.reg, !quantum.bit @@ -569,8 +571,8 @@ module @circuit_with_multirz { %extracted_2 = tensor.extract %cst[] : tensor %out_qubits = quantum.multirz(%extracted_2) %1 : !quantum.bit - // CHECK: [[QUBIT3:%.+]] = quantum.custom "RZ"([[CST_RZ:%.+]]) {{%.+}} : !quantum.bit - // CHECK: [[QUBIT4:%.+]] = quantum.custom "RY"([[CST_RY:%.+]]) [[QUBIT3]] : !quantum.bit + // CHECK: [[QUBIT3:%.+]] = quantum.custom "RZ" + // CHECK: [[QUBIT4:%.+]] = quantum.custom "RY" // CHECK-NOT: quantum.custom "Hadamard" %out_qubits_0 = quantum.custom "Hadamard"() %out_qubits : !quantum.bit From 3a71e654027395d69de170efc2bb8f1a7cc44eb3 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 4 Nov 2025 15:24:04 -0500 Subject: [PATCH 24/42] fix --- mlir/test/Quantum/DecomposeLoweringTest.mlir | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/test/Quantum/DecomposeLoweringTest.mlir b/mlir/test/Quantum/DecomposeLoweringTest.mlir index 43fdf9c1ad..b84c8a18d9 100644 --- a/mlir/test/Quantum/DecomposeLoweringTest.mlir +++ b/mlir/test/Quantum/DecomposeLoweringTest.mlir @@ -563,6 +563,8 @@ module @circuit_with_multirz { %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit // CHECK: func.func public @test_with_multirz() -> tensor<4xf64> // CHECK: [[CST_RZ:%.+]] = arith.constant 5.000000e-01 : f64 + // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64 + // CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64 // CHECK: [[REG:%.+]] = quantum.alloc( 2) : !quantum.reg // CHECK: [[QUBIT1:%.+]] = quantum.custom "RZ"([[CST_RZ]]) {{%.+}} : !quantum.bit @@ -571,8 +573,8 @@ module @circuit_with_multirz { %extracted_2 = tensor.extract %cst[] : tensor %out_qubits = quantum.multirz(%extracted_2) %1 : !quantum.bit - // CHECK: [[QUBIT3:%.+]] = quantum.custom "RZ" - // CHECK: [[QUBIT4:%.+]] = quantum.custom "RY" + // CHECK: [[QUBIT3:%.+]] = quantum.custom "RZ"([[CST_PI]]) {{%.+}} : !quantum.bit + // CHECK: [[QUBIT4:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT3]] : !quantum.bit // CHECK-NOT: quantum.custom "Hadamard" %out_qubits_0 = quantum.custom "Hadamard"() %out_qubits : !quantum.bit From 0545b4be72173e337570b41be8bb7c57070c447d Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 4 Nov 2025 15:36:55 -0500 Subject: [PATCH 25/42] fix --- mlir/test/Quantum/DecomposeLoweringTest.mlir | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/mlir/test/Quantum/DecomposeLoweringTest.mlir b/mlir/test/Quantum/DecomposeLoweringTest.mlir index b84c8a18d9..15ea2d2e20 100644 --- a/mlir/test/Quantum/DecomposeLoweringTest.mlir +++ b/mlir/test/Quantum/DecomposeLoweringTest.mlir @@ -388,12 +388,10 @@ module @multi_wire_cnot_decomposition { // CHECK: [[QUBIT1:%.+]] = quantum.extract [[REG]][[[EXTRACTED]]] : !quantum.reg -> !quantum.bit // CHECK: [[RZ1:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT1]] : !quantum.bit // CHECK: [[RY1:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[RZ1]] : !quantum.bit - // CHECK: [[INSERT1:%.+]] = quantum.insert [[REG]][[[EXTRACTED]]], [[RY1]] : !quantum.reg, !quantum.bit // CHECK: [[EXTRACTED2:%.+]] = tensor.extract [[RESHAPE1]][] : tensor - // CHECK: [[QUBIT0:%.+]] = quantum.extract [[INSERT1]][[[EXTRACTED2]]] : !quantum.reg -> !quantum.bit - // CHECK: [[QUBIT1_2:%.+]] = quantum.extract [[INSERT1]][[[EXTRACTED]]] : !quantum.reg -> !quantum.bit - // CHECK: [[CZ_RESULT:%.+]]:2 = quantum.custom "CZ"() [[QUBIT0]], [[QUBIT1_2]] : !quantum.bit, !quantum.bit - // CHECK: [[INSERT2:%.+]] = quantum.insert [[INSERT1]][[[EXTRACTED2]]], [[CZ_RESULT]]#0 : !quantum.reg, !quantum.bit + // CHECK: [[QUBIT0:%.+]] = quantum.extract [[REG]][[[EXTRACTED2]]] : !quantum.reg -> !quantum.bit + // CHECK: [[CZ_RESULT:%.+]]:2 = quantum.custom "CZ"() [[QUBIT0]], [[RY1]] : !quantum.bit, !quantum.bit + // CHECK: [[INSERT2:%.+]] = quantum.insert [[REG]][[[EXTRACTED2]]], [[CZ_RESULT]]#0 : !quantum.reg, !quantum.bit // CHECK: [[RZ2:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[CZ_RESULT]]#1 : !quantum.bit // CHECK: [[RY2:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[RZ2]] : !quantum.bit // CHECK: [[INSERT3:%.+]] = quantum.insert [[INSERT2]][[[EXTRACTED]]], [[RY2]] : !quantum.reg, !quantum.bit @@ -608,4 +606,4 @@ module @circuit_with_multirz { %3 = quantum.insert %arg0[%extracted_1], %out_qubits : !quantum.reg, !quantum.bit return %3 : !quantum.reg } -} +} \ No newline at end of file From 1a093cbe49ec7442e425d26e18fcbec21966fba7 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 4 Nov 2025 15:49:18 -0500 Subject: [PATCH 26/42] test --- mlir/test/Quantum/DecomposeLoweringTest.mlir | 9 --------- 1 file changed, 9 deletions(-) diff --git a/mlir/test/Quantum/DecomposeLoweringTest.mlir b/mlir/test/Quantum/DecomposeLoweringTest.mlir index 15ea2d2e20..b07493a669 100644 --- a/mlir/test/Quantum/DecomposeLoweringTest.mlir +++ b/mlir/test/Quantum/DecomposeLoweringTest.mlir @@ -559,24 +559,15 @@ module @circuit_with_multirz { func.func public @test_with_multirz() -> tensor<4xf64> { %0 = quantum.alloc( 2) : !quantum.reg %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit - // CHECK: func.func public @test_with_multirz() -> tensor<4xf64> - // CHECK: [[CST_RZ:%.+]] = arith.constant 5.000000e-01 : f64 - // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64 - // CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64 - // CHECK: [[REG:%.+]] = quantum.alloc( 2) : !quantum.reg - // CHECK: [[QUBIT1:%.+]] = quantum.custom "RZ"([[CST_RZ]]) {{%.+}} : !quantum.bit // CHECK-NOT: quantum.multirz %cst = stablehlo.constant dense<5.000000e-01> : tensor %extracted_2 = tensor.extract %cst[] : tensor %out_qubits = quantum.multirz(%extracted_2) %1 : !quantum.bit - // CHECK: [[QUBIT3:%.+]] = quantum.custom "RZ"([[CST_PI]]) {{%.+}} : !quantum.bit - // CHECK: [[QUBIT4:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT3]] : !quantum.bit // CHECK-NOT: quantum.custom "Hadamard" %out_qubits_0 = quantum.custom "Hadamard"() %out_qubits : !quantum.bit - // CHECK: [[UPDATED_REG:%.+]] = quantum.insert [[REG]][ 0], [[QUBIT4]] : !quantum.reg, !quantum.bit %2 = quantum.insert %0[ 0], %out_qubits_0 : !quantum.reg, !quantum.bit %3 = quantum.compbasis qreg %2 : !quantum.obs %4 = quantum.probs %3 : tensor<4xf64> From 22870fa64e0bc0f65fd3551489229acc6bc5ee8d Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 4 Nov 2025 15:58:38 -0500 Subject: [PATCH 27/42] CI test --- mlir/test/Quantum/DecomposeLoweringTest.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Quantum/DecomposeLoweringTest.mlir b/mlir/test/Quantum/DecomposeLoweringTest.mlir index b07493a669..9868f4c7fd 100644 --- a/mlir/test/Quantum/DecomposeLoweringTest.mlir +++ b/mlir/test/Quantum/DecomposeLoweringTest.mlir @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// RUN: quantum-opt --decompose-lowering --split-input-file -verify-diagnostics %s | FileCheck %s +// RUN: quantum-opt --decompose-lowering --split-input-file -verify-diagnostics %s | tee /dev/stderr | FileCheck %s module @two_hadamards { func.func public @test_two_hadamards() -> tensor<4xf64> { From 5e93dbdd10f8c123439458374048d5f2fac190ba Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 4 Nov 2025 16:11:56 -0500 Subject: [PATCH 28/42] fix decomp --- mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp index 6715808aae..b813573920 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp @@ -120,7 +120,7 @@ class BaseSignatureAnalyzer { mlir::ValueRange inCtrlQubits, mlir::ValueRange inCtrlValues, mlir::ValueRange outQubits, mlir::ValueRange outCtrlQubits, bool enableQregMode) - : paramsStorage(mlir::ValueRange(param).begin(), mlir::ValueRange(param).end()), + : paramsStorage{param}, signature(Signature{.params = mlir::ValueRange(paramsStorage), .inQubits = inQubits, .inCtrlQubits = inCtrlQubits, From 257a203cbd562fcc430024047705c2f9386a6072 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 4 Nov 2025 16:15:24 -0500 Subject: [PATCH 29/42] fix formatting --- .../Transforms/DecomposeLoweringImpl.hpp | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp index b813573920..1f0ddf871e 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp @@ -120,17 +120,16 @@ class BaseSignatureAnalyzer { mlir::ValueRange inCtrlQubits, mlir::ValueRange inCtrlValues, mlir::ValueRange outQubits, mlir::ValueRange outCtrlQubits, bool enableQregMode) - : paramsStorage{param}, - signature(Signature{.params = mlir::ValueRange(paramsStorage), - .inQubits = inQubits, - .inCtrlQubits = inCtrlQubits, - .inCtrlValues = inCtrlValues, - .outQubits = outQubits, - .outCtrlQubits = outCtrlQubits, - .inWireIndices = {}, - .inCtrlWireIndices = {}, - .outQubitIndices = {}, - .outCtrlQubitIndices = {}}) + : paramsStorage{param}, signature(Signature{.params = mlir::ValueRange(paramsStorage), + .inQubits = inQubits, + .inCtrlQubits = inCtrlQubits, + .inCtrlValues = inCtrlValues, + .outQubits = outQubits, + .outCtrlQubits = outCtrlQubits, + .inWireIndices = {}, + .inCtrlWireIndices = {}, + .outQubitIndices = {}, + .outCtrlQubitIndices = {}}) { initializeQregMode(op, enableQregMode); } From 6dafb845c9d38f68b64b9b1295351eabcb365ee0 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 4 Nov 2025 17:34:09 -0500 Subject: [PATCH 30/42] fix --- mlir/test/Quantum/DecomposeLoweringTest.mlir | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/mlir/test/Quantum/DecomposeLoweringTest.mlir b/mlir/test/Quantum/DecomposeLoweringTest.mlir index 9868f4c7fd..da855c67d7 100644 --- a/mlir/test/Quantum/DecomposeLoweringTest.mlir +++ b/mlir/test/Quantum/DecomposeLoweringTest.mlir @@ -559,15 +559,24 @@ module @circuit_with_multirz { func.func public @test_with_multirz() -> tensor<4xf64> { %0 = quantum.alloc( 2) : !quantum.reg %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + // CHECK: func.func public @test_with_multirz() -> tensor<4xf64> + // CHECK: [[CST_RZ:%.+]] = arith.constant 5.000000e-01 : f64 + // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64 + // CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64 + // CHECK: [[REG:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[QUBIT1:%.+]] = quantum.custom "RZ"([[CST_RZ]]) {{%.+}} : !quantum.bit // CHECK-NOT: quantum.multirz %cst = stablehlo.constant dense<5.000000e-01> : tensor %extracted_2 = tensor.extract %cst[] : tensor %out_qubits = quantum.multirz(%extracted_2) %1 : !quantum.bit + // CHECK: [[QUBIT3:%.+]] = quantum.custom "RZ"([[CST_PI]]) {{%.+}} : !quantum.bit + // CHECK: [[QUBIT4:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT3]] : !quantum.bit // CHECK-NOT: quantum.custom "Hadamard" %out_qubits_0 = quantum.custom "Hadamard"() %out_qubits : !quantum.bit + // CHECK: [[UPDATED_REG:%.+]] = quantum.insert [[REG]][ 0], [[QUBIT4]] : !quantum.reg, !quantum.bit %2 = quantum.insert %0[ 0], %out_qubits_0 : !quantum.reg, !quantum.bit %3 = quantum.compbasis qreg %2 : !quantum.obs %4 = quantum.probs %3 : tensor<4xf64> @@ -597,4 +606,4 @@ module @circuit_with_multirz { %3 = quantum.insert %arg0[%extracted_1], %out_qubits : !quantum.reg, !quantum.bit return %3 : !quantum.reg } -} \ No newline at end of file +} From fd46638e7bc5f072667fe57c5028f2e473c46ec6 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 4 Nov 2025 17:34:30 -0500 Subject: [PATCH 31/42] update --- mlir/test/Quantum/DecomposeLoweringTest.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Quantum/DecomposeLoweringTest.mlir b/mlir/test/Quantum/DecomposeLoweringTest.mlir index da855c67d7..466983e261 100644 --- a/mlir/test/Quantum/DecomposeLoweringTest.mlir +++ b/mlir/test/Quantum/DecomposeLoweringTest.mlir @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// RUN: quantum-opt --decompose-lowering --split-input-file -verify-diagnostics %s | tee /dev/stderr | FileCheck %s +// RUN: quantum-opt --decompose-lowering --split-input-file -verify-diagnostics %s | FileCheck %s module @two_hadamards { func.func public @test_two_hadamards() -> tensor<4xf64> { From 82068da9797b23de8999f701645fc7608da0042e Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 4 Nov 2025 20:48:06 -0500 Subject: [PATCH 32/42] update and trigger compiler to recompile hpp --- .../Quantum/Transforms/DecomposeLoweringPatterns.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp index 6aa687ae22..11e014a2bb 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp @@ -71,7 +71,10 @@ struct DLCustomOpPattern : public OpRewritePattern { auto enableQreg = isa(decompFunc.getFunctionType().getInput(0)); auto analyzer = CustomOpSignatureAnalyzer(op, enableQreg); - assert(analyzer && "Analyzer should be valid"); + if (!analyzer) { + op.emitError("Failed to create CustomOpSignatureAnalyzer"); + return failure(); + } auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc()); auto callOp = @@ -145,7 +148,10 @@ struct DLMultiRZOpPattern : public OpRewritePattern { } auto analyzer = MultiRZOpSignatureAnalyzer(op, enableQreg); - assert(analyzer && "Analyzer should be valid"); + if (!analyzer) { + op.emitError("Failed to create MultiRZOpSignatureAnalyzer"); + return failure(); + } auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc()); auto callOp = From 9c9ba77d59bf17156a7a016edc259159c464dd9f Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 4 Nov 2025 20:58:13 -0500 Subject: [PATCH 33/42] include decompose impl --- mlir/lib/Quantum/Transforms/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/Quantum/Transforms/CMakeLists.txt b/mlir/lib/Quantum/Transforms/CMakeLists.txt index f1ae85d1ff..ec76d20b65 100644 --- a/mlir/lib/Quantum/Transforms/CMakeLists.txt +++ b/mlir/lib/Quantum/Transforms/CMakeLists.txt @@ -15,6 +15,7 @@ file(GLOB SRC merge_rotation.cpp MergeRotationsPatterns.cpp decompose_lowering.cpp + DecomposeLoweringImpl.hpp DecomposeLoweringPatterns.cpp DisentangleSWAP.cpp DisentangleCNOT.cpp From 0f1e68a8d564431ab82c0b5114d22c286976a6e2 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 4 Nov 2025 21:13:16 -0500 Subject: [PATCH 34/42] update --- .../Transforms/DecomposeLoweringImpl.hpp | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp index 1f0ddf871e..c0f4a4f43f 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp @@ -13,6 +13,7 @@ // limitations under the License. #include +#include #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSet.h" @@ -120,16 +121,17 @@ class BaseSignatureAnalyzer { mlir::ValueRange inCtrlQubits, mlir::ValueRange inCtrlValues, mlir::ValueRange outQubits, mlir::ValueRange outCtrlQubits, bool enableQregMode) - : paramsStorage{param}, signature(Signature{.params = mlir::ValueRange(paramsStorage), - .inQubits = inQubits, - .inCtrlQubits = inCtrlQubits, - .inCtrlValues = inCtrlValues, - .outQubits = outQubits, - .outCtrlQubits = outCtrlQubits, - .inWireIndices = {}, - .inCtrlWireIndices = {}, - .outQubitIndices = {}, - .outCtrlQubitIndices = {}}) + : paramsStorage(mlir::ValueRange(param).begin(), mlir::ValueRange(param).end()), + signature(Signature{.params = mlir::ValueRange(paramsStorage), + .inQubits = inQubits, + .inCtrlQubits = inCtrlQubits, + .inCtrlValues = inCtrlValues, + .outQubits = outQubits, + .outCtrlQubits = outCtrlQubits, + .inWireIndices = {}, + .inCtrlWireIndices = {}, + .outQubitIndices = {}, + .outCtrlQubitIndices = {}}) { initializeQregMode(op, enableQregMode); } From 828b9aee748c3b9aee964ad0c545bed63406d53b Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 4 Nov 2025 21:15:18 -0500 Subject: [PATCH 35/42] formatting --- mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp index c0f4a4f43f..97e1b393d9 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include +#include #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSet.h" From cd5aad30b12e8991faf19cb60ef0bb89665439ec Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 4 Nov 2025 21:30:36 -0500 Subject: [PATCH 36/42] debug CI --- mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp index 11e014a2bb..4b41a41f4b 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp @@ -111,8 +111,11 @@ struct DLMultiRZOpPattern : public OpRewritePattern { { StringRef gateName = "MultiRZ"; + llvm::errs() << "Decomposing MultiRZOp: " << gateName << "\n"; + // Only decompose the op if it is not in the target gate set if (targetGateSet.contains(gateName)) { + llvm::errs() << "MultiRZOp is in the target gate set, skipping\n"; return failure(); } @@ -122,6 +125,7 @@ struct DLMultiRZOpPattern : public OpRewritePattern { auto it = decompositionRegistry.find(MRZNameWithQubits.str()); if (it == decompositionRegistry.end()) { + llvm::errs() << "No decomposition function found for " << MRZNameWithQubits.str() << "\n"; return failure(); } @@ -149,6 +153,7 @@ struct DLMultiRZOpPattern : public OpRewritePattern { auto analyzer = MultiRZOpSignatureAnalyzer(op, enableQreg); if (!analyzer) { + llvm::errs() << "Failed to create MultiRZOpSignatureAnalyzer\n"; op.emitError("Failed to create MultiRZOpSignatureAnalyzer"); return failure(); } From df79232d3402a75b23eb009a30d431aa2075c77e Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 4 Nov 2025 21:32:17 -0500 Subject: [PATCH 37/42] formatting --- mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp index 4b41a41f4b..f4841e264f 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp @@ -125,7 +125,8 @@ struct DLMultiRZOpPattern : public OpRewritePattern { auto it = decompositionRegistry.find(MRZNameWithQubits.str()); if (it == decompositionRegistry.end()) { - llvm::errs() << "No decomposition function found for " << MRZNameWithQubits.str() << "\n"; + llvm::errs() << "No decomposition function found for " << MRZNameWithQubits.str() + << "\n"; return failure(); } From b37de9d00fd00e2261c21fa1f0ff2b8970a15be1 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 4 Nov 2025 21:43:17 -0500 Subject: [PATCH 38/42] debug on CI --- mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp index f4841e264f..f73b03f0fe 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp @@ -109,7 +109,7 @@ struct DLMultiRZOpPattern : public OpRewritePattern { LogicalResult matchAndRewrite(MultiRZOp op, PatternRewriter &rewriter) const override { - StringRef gateName = "MultiRZ"; + std::string gateName = "MultiRZ"; llvm::errs() << "Decomposing MultiRZOp: " << gateName << "\n"; From 0bce06ba5028aae23db7c4e20eb8f0772b88ea91 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 4 Nov 2025 21:52:44 -0500 Subject: [PATCH 39/42] fix --- mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp index f73b03f0fe..ee32719b6f 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp @@ -123,10 +123,9 @@ struct DLMultiRZOpPattern : public OpRewritePattern { auto numQubits = op.getInQubits().size(); auto MRZNameWithQubits = gateName + "_" + std::to_string(numQubits); - auto it = decompositionRegistry.find(MRZNameWithQubits.str()); + auto it = decompositionRegistry.find(MRZNameWithQubits); if (it == decompositionRegistry.end()) { - llvm::errs() << "No decomposition function found for " << MRZNameWithQubits.str() - << "\n"; + llvm::errs() << "No decomposition function found for " << MRZNameWithQubits << "\n"; return failure(); } From f53736072e6defff3667bf82ce83b0cf33741311 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 4 Nov 2025 22:08:26 -0500 Subject: [PATCH 40/42] fix --- mlir/lib/Quantum/Transforms/decompose_lowering.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/decompose_lowering.cpp b/mlir/lib/Quantum/Transforms/decompose_lowering.cpp index 681fd73f47..c0f27d196e 100644 --- a/mlir/lib/Quantum/Transforms/decompose_lowering.cpp +++ b/mlir/lib/Quantum/Transforms/decompose_lowering.cpp @@ -102,13 +102,15 @@ struct DecomposeLoweringPass : impl::DecomposeLoweringPassBase Date: Tue, 4 Nov 2025 23:01:05 -0500 Subject: [PATCH 41/42] fix --- mlir/lib/Quantum/Transforms/CMakeLists.txt | 1 - .../Quantum/Transforms/DecomposeLoweringImpl.hpp | 1 - .../Transforms/DecomposeLoweringPatterns.cpp | 15 ++------------- .../lib/Quantum/Transforms/decompose_lowering.cpp | 7 ++++--- 4 files changed, 6 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/CMakeLists.txt b/mlir/lib/Quantum/Transforms/CMakeLists.txt index ec76d20b65..f1ae85d1ff 100644 --- a/mlir/lib/Quantum/Transforms/CMakeLists.txt +++ b/mlir/lib/Quantum/Transforms/CMakeLists.txt @@ -15,7 +15,6 @@ file(GLOB SRC merge_rotation.cpp MergeRotationsPatterns.cpp decompose_lowering.cpp - DecomposeLoweringImpl.hpp DecomposeLoweringPatterns.cpp DisentangleSWAP.cpp DisentangleCNOT.cpp diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp index 97e1b393d9..6715808aae 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringImpl.hpp @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include #include "llvm/ADT/StringMap.h" diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp index ee32719b6f..3d5ccc8775 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp @@ -71,10 +71,7 @@ struct DLCustomOpPattern : public OpRewritePattern { auto enableQreg = isa(decompFunc.getFunctionType().getInput(0)); auto analyzer = CustomOpSignatureAnalyzer(op, enableQreg); - if (!analyzer) { - op.emitError("Failed to create CustomOpSignatureAnalyzer"); - return failure(); - } + assert(analyzer && "Analyzer should be valid"); auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc()); auto callOp = @@ -111,11 +108,8 @@ struct DLMultiRZOpPattern : public OpRewritePattern { { std::string gateName = "MultiRZ"; - llvm::errs() << "Decomposing MultiRZOp: " << gateName << "\n"; - // Only decompose the op if it is not in the target gate set if (targetGateSet.contains(gateName)) { - llvm::errs() << "MultiRZOp is in the target gate set, skipping\n"; return failure(); } @@ -125,7 +119,6 @@ struct DLMultiRZOpPattern : public OpRewritePattern { auto it = decompositionRegistry.find(MRZNameWithQubits); if (it == decompositionRegistry.end()) { - llvm::errs() << "No decomposition function found for " << MRZNameWithQubits << "\n"; return failure(); } @@ -152,11 +145,7 @@ struct DLMultiRZOpPattern : public OpRewritePattern { } auto analyzer = MultiRZOpSignatureAnalyzer(op, enableQreg); - if (!analyzer) { - llvm::errs() << "Failed to create MultiRZOpSignatureAnalyzer\n"; - op.emitError("Failed to create MultiRZOpSignatureAnalyzer"); - return failure(); - } + assert(analyzer && "Analyzer should be valid"); auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc()); auto callOp = diff --git a/mlir/lib/Quantum/Transforms/decompose_lowering.cpp b/mlir/lib/Quantum/Transforms/decompose_lowering.cpp index c0f27d196e..ef6931bcd0 100644 --- a/mlir/lib/Quantum/Transforms/decompose_lowering.cpp +++ b/mlir/lib/Quantum/Transforms/decompose_lowering.cpp @@ -102,15 +102,16 @@ struct DecomposeLoweringPass : impl::DecomposeLoweringPassBase Date: Tue, 4 Nov 2025 23:06:46 -0500 Subject: [PATCH 42/42] fix --- mlir/lib/Quantum/Transforms/decompose_lowering.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mlir/lib/Quantum/Transforms/decompose_lowering.cpp b/mlir/lib/Quantum/Transforms/decompose_lowering.cpp index ef6931bcd0..8f311f7681 100644 --- a/mlir/lib/Quantum/Transforms/decompose_lowering.cpp +++ b/mlir/lib/Quantum/Transforms/decompose_lowering.cpp @@ -109,9 +109,7 @@ struct DecomposeLoweringPass : impl::DecomposeLoweringPassBase