diff --git a/.dep-versions b/.dep-versions index 6fbf3a9fb1..727441ee1b 100644 --- a/.dep-versions +++ b/.dep-versions @@ -2,9 +2,9 @@ # To update JAX version alongside compatible dependency tags, run the following script: # python3 .github/workflows/set_dep_versions.py {JAX_version} jax=0.6.2 -stablehlo=69d6dae46e1c7de36e6e6973654754f05353cba5 -llvm=f8cb7987c64dcffb72414a40560055cb717dbf74 -enzyme=v0.0.186 +stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d +llvm=113f01aa82d055410f22a9d03b3468fa68600589 +enzyme=v0.0.203 # Always remove custom PL/LQ versions before release. diff --git a/frontend/catalyst/jax_extras/patches.py b/frontend/catalyst/jax_extras/patches.py index ba32f498b9..82ea21984b 100644 --- a/frontend/catalyst/jax_extras/patches.py +++ b/frontend/catalyst/jax_extras/patches.py @@ -43,6 +43,29 @@ ) +def mock_attributes(obj, attrs: dict[str, any]): + """Mock the attribute of an object by returning a wrapper. + + Args: + obj: The object to mock the attributes of. + attrs: A dictionary of attributes to mock. + Example: {"attribute_name": attribute_value} + """ + + class MockAttributeWrapper: + """Wrapper to mock the attribute of an object.""" + + def __init__(self, original): + self.original = original + + def __getattr__(self, name): + if name in attrs: + return attrs[name] + return getattr(self.original, name) + + return MockAttributeWrapper(obj) + + def _drop_unused_vars2(jaxpr, constvals): """ A patch to not drop unused vars during classical tracing of control flow. diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 68625edc81..2d050aef63 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -39,6 +39,7 @@ from jax.interpreters import mlir from jax.tree_util import PyTreeDef, tree_unflatten from jaxlib.hlo_helpers import shape_dtype_to_ir_type +from jaxlib.mlir._mlir_libs import _mlir as _ods_cext from jaxlib.mlir.dialects.arith import ( AddIOp, CeilDivSIOp, @@ -52,54 +53,85 @@ from jaxlib.mlir.dialects.scf import ConditionOp, ForOp, IfOp, WhileOp, YieldOp from jaxlib.mlir.dialects.stablehlo import ConstantOp as StableHLOConstantOp from jaxlib.mlir.dialects.stablehlo import ConvertOp as StableHLOConvertOp -from mlir_quantum.dialects.catalyst import ( - AssertionOp, - CallbackCallOp, - CallbackOp, - PrintOp, -) -from mlir_quantum.dialects.gradient import ( - CustomGradOp, - ForwardOp, - GradOp, - JVPOp, - ReverseOp, - ValueAndGradOp, - VJPOp, -) -from mlir_quantum.dialects.mbqc import MeasureInBasisOp -from mlir_quantum.dialects.mitigation import ZneOp -from mlir_quantum.dialects.quantum import ( - AdjointOp, - AllocOp, - ComputationalBasisOp, - CountsOp, - CustomOp, - DeallocOp, - DeallocQubitOp, - DeviceInitOp, - DeviceReleaseOp, - ExpvalOp, - ExtractOp, - GlobalPhaseOp, - HamiltonianOp, - HermitianOp, - InsertOp, - MeasureOp, - MultiRZOp, - NamedObsOp, - NumQubitsOp, - PCPhaseOp, - ProbsOp, - QubitUnitaryOp, - SampleOp, - SetBasisStateOp, - SetStateOp, - StateOp, - TensorOp, - VarianceOp, -) -from mlir_quantum.dialects.quantum import YieldOp as QYieldOp + +# TODO: remove after jax v0.7.2 upgrade +# Mock _ods_cext.globals.register_traceback_file_exclusion due to API conflicts between +# Catalyst's MLIR version and the MLIR version used by JAX. The current JAX version has not +# yet updated to the latest MLIR, causing compatibility issues. This workaround will be removed +# once JAX updates to a compatible MLIR version +# pylint: disable=ungrouped-imports +from catalyst.jax_extras.patches import mock_attributes +from catalyst.utils.patching import Patcher + +with Patcher( + ( + _ods_cext, + "globals", + mock_attributes( + # pylint: disable=c-extension-no-member + _ods_cext.globals, + {"register_traceback_file_exclusion": lambda x: None}, + ), + ), +): + from mlir_quantum.dialects.catalyst import ( + AssertionOp, + CallbackCallOp, + CallbackOp, + PrintOp, + ) + from mlir_quantum.dialects.gradient import ( + CustomGradOp, + ForwardOp, + GradOp, + JVPOp, + ReverseOp, + ValueAndGradOp, + VJPOp, + ) + from mlir_quantum.dialects.mbqc import MeasureInBasisOp + from mlir_quantum.dialects.mitigation import ZneOp + from mlir_quantum.dialects.quantum import ( + AdjointOp, + AllocOp, + ComputationalBasisOp, + CountsOp, + CustomOp, + DeallocOp, + DeallocQubitOp, + DeviceInitOp, + DeviceReleaseOp, + ExpvalOp, + ExtractOp, + GlobalPhaseOp, + HamiltonianOp, + HermitianOp, + InsertOp, + MeasureOp, + MultiRZOp, + NamedObsOp, + NumQubitsOp, + PCPhaseOp, + ProbsOp, + QubitUnitaryOp, + SampleOp, + SetBasisStateOp, + SetStateOp, + StateOp, + TensorOp, + VarianceOp, + ) + from mlir_quantum.dialects.quantum import YieldOp as QYieldOp + from catalyst.jax_primitives_utils import ( + cache, + create_call_op, + get_cached, + get_call_jaxpr, + get_symbolref, + lower_callable, + lower_jaxpr, + ) + from pennylane.capture.primitives import jacobian_prim as pl_jac_prim from catalyst.compiler import get_lib_path @@ -111,19 +143,9 @@ infer_output_type_jaxpr, while_loop_expansion_strategy, ) -from catalyst.jax_primitives_utils import ( - cache, - create_call_op, - get_cached, - get_call_jaxpr, - get_symbolref, - lower_callable, - lower_jaxpr, -) from catalyst.utils.calculate_grad_shape import Signature, calculate_grad_shape from catalyst.utils.exceptions import CompileError from catalyst.utils.extra_bindings import FromElementsOp, TensorExtractOp -from catalyst.utils.patching import Patcher from catalyst.utils.types import convert_shaped_arrays_to_tensors # pylint: disable=unused-argument,too-many-lines,too-many-statements,protected-access diff --git a/frontend/test/pytest/test_jax_extras_patches.py b/frontend/test/pytest/test_jax_extras_patches.py new file mode 100644 index 0000000000..c7ee8346d5 --- /dev/null +++ b/frontend/test/pytest/test_jax_extras_patches.py @@ -0,0 +1,114 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for the jax_extras.patches module""" + +from catalyst.jax_extras.patches import mock_attributes + + +# pylint: disable=missing-class-docstring,missing-function-docstring +class TestMockAttributes: + """Test the mock_attributes function and MockAttributeWrapper class.""" + + def test_mock_attributes_returns_mocked_value(self): + """Test that accessing a mocked attribute returns the mocked value.""" + + class DummyClass: + def __init__(self): + self.original_attr = "original" + + obj = DummyClass() + mocked = mock_attributes(obj, {"mocked_attr": "mocked_value"}) + + # Access the mocked attribute - this should come from the attrs dict + assert mocked.mocked_attr == "mocked_value" + + def test_mock_attributes_returns_original_value(self): + """Test that accessing an unmocked attribute returns the original value.""" + + class DummyClass: + def __init__(self): + self.original_attr = "original" + + obj = DummyClass() + mocked = mock_attributes(obj, {"mocked_attr": "mocked_value"}) + + # Access the original attribute - this should come from the original object + # This tests the else branch in __getattr__ + assert mocked.original_attr == "original" + + def test_mock_attributes_with_methods(self): + """Test that calling original methods works through the wrapper.""" + + class DummyClass: + def __init__(self): + self.value = 42 + + def get_value(self): + return self.value + + obj = DummyClass() + + def mocked_method(): + return "mocked" + + mocked = mock_attributes(obj, {"mocked_method": mocked_method}) + + # Access the mocked method + assert mocked.mocked_method() == "mocked" + + # Access the original method - tests the getattr fallback + assert mocked.get_value() == 42 + + def test_mock_attributes_with_callable(self): + """Test mocking with callable attributes like lambda functions.""" + + class DummyClass: + def __init__(self): + self.original_func = lambda x: x * 2 + + obj = DummyClass() + mocked = mock_attributes(obj, {"new_func": lambda x: x * 3}) + + # Access the mocked callable + assert mocked.new_func(5) == 15 + + # Access the original callable - tests the getattr fallback + assert mocked.original_func(5) == 10 + + def test_mock_attributes_override_existing(self): + """Test that mocking can override existing attributes.""" + + class DummyClass: + def __init__(self): + self.attr = "original" + + obj = DummyClass() + mocked = mock_attributes(obj, {"attr": "overridden"}) + + # The mocked value should take precedence + assert mocked.attr == "overridden" + + def test_mock_attributes_stores_original(self): + """Test that the original object is accessible through the wrapper.""" + + class DummyClass: + def __init__(self): + self.value = 100 + + obj = DummyClass() + mocked = mock_attributes(obj, {}) + + # The wrapper should store the original object + assert mocked.original is obj + assert mocked.original.value == 100 diff --git a/mlir/Enzyme b/mlir/Enzyme index 8c1a596158..476c8e3193 160000 --- a/mlir/Enzyme +++ b/mlir/Enzyme @@ -1 +1 @@ -Subproject commit 8c1a596158f6194f10e8ffd56a1660a61c54337e +Subproject commit 476c8e3193a8577ba24ff845ae2294109225f83a diff --git a/mlir/Makefile b/mlir/Makefile index 8fc76e11e7..4628b99bd7 100644 --- a/mlir/Makefile +++ b/mlir/Makefile @@ -138,7 +138,7 @@ enzyme: -DCMAKE_CXX_VISIBILITY_PRESET=$(SYMBOL_VISIBILITY) \ -DCMAKE_POLICY_DEFAULT_CMP0116=NEW - cmake --build $(ENZYME_BUILD_DIR) --target EnzymeStatic-21 + cmake --build $(ENZYME_BUILD_DIR) --target EnzymeStatic-22 .PHONY: plugin plugin: diff --git a/mlir/include/Catalyst/IR/CatalystOps.td b/mlir/include/Catalyst/IR/CatalystOps.td index 12daf4f6e9..c3c60fc840 100644 --- a/mlir/include/Catalyst/IR/CatalystOps.td +++ b/mlir/include/Catalyst/IR/CatalystOps.td @@ -138,8 +138,17 @@ def CallbackOp : Catalyst_Op<"callback", let builders = [OpBuilder<(ins "mlir::StringRef":$name, "mlir::FunctionType":$type, - CArg<"mlir::ArrayRef", "{}">:$attrs) - >]; + CArg<"mlir::ArrayRef", "{}">:$attrs), [{ + $_state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), + $_builder.getStringAttr(name)); + $_state.addAttribute("function_type", mlir::TypeAttr::get(type)); + $_state.addAttribute("id", $_builder.getI64IntegerAttr(0)); + $_state.addAttribute("argc", $_builder.getI64IntegerAttr(type.getNumInputs())); + $_state.addAttribute("resc", $_builder.getI64IntegerAttr(type.getNumResults())); + $_state.attributes.append(attrs.begin(), attrs.end()); + $_state.addRegion(); + }]> + ]; let extraClassDeclaration = [{ //===------------------------------------------------------------------===// diff --git a/mlir/include/Gradient/IR/GradientOps.td b/mlir/include/Gradient/IR/GradientOps.td index fb81419b99..75905049aa 100644 --- a/mlir/include/Gradient/IR/GradientOps.td +++ b/mlir/include/Gradient/IR/GradientOps.td @@ -28,7 +28,7 @@ include "Gradient/IR/GradientInterfaces.td" def GradOp : Gradient_Op<"grad", [ DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, GradientOpInterface ]> { let summary = "Compute the gradient of a function."; @@ -287,7 +287,7 @@ def ForwardOp : Gradient_Op<"forward", Then: followed by the original return type, if any. - + since there is none, then: %returnTy = { %tape } @@ -302,7 +302,7 @@ def ForwardOp : Gradient_Op<"forward", One thing that was found experimentally and through tests in Enzyme is that the tape can also be a pointer. We use this in the case when there is no tape to return. Instead of returning an empty struct, we return a null pointer that is just never dereferenced. - + }]; let arguments = (ins @@ -320,8 +320,18 @@ def ForwardOp : Gradient_Op<"forward", let builders = [OpBuilder<(ins "mlir::StringRef":$name, "mlir::FunctionType":$type, - CArg<"mlir::ArrayRef", "{}">:$attrs) - >]; + CArg<"mlir::ArrayRef", "{}">:$attrs), [{ + $_state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), + $_builder.getStringAttr(name)); + $_state.addAttribute("function_type", mlir::TypeAttr::get(type)); + $_state.addAttribute("implementation", mlir::FlatSymbolRefAttr::get($_builder.getStringAttr(""))); + $_state.addAttribute("argc", $_builder.getI64IntegerAttr(0)); + $_state.addAttribute("resc", $_builder.getI64IntegerAttr(0)); + $_state.addAttribute("tape", $_builder.getI64IntegerAttr(0)); + $_state.attributes.append(attrs.begin(), attrs.end()); + $_state.addRegion(); + }]> + ]; let extraClassDeclaration = [{ //===------------------------------------------------------------------===// @@ -358,7 +368,6 @@ def ReverseOp : Gradient_Op<"reverse", %returnTy = { %tape } - }]; let arguments = (ins @@ -376,8 +385,18 @@ def ReverseOp : Gradient_Op<"reverse", let builders = [OpBuilder<(ins "mlir::StringRef":$name, "mlir::FunctionType":$type, - CArg<"mlir::ArrayRef", "{}">:$attrs) - >]; + CArg<"mlir::ArrayRef", "{}">:$attrs), [{ + $_state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), + $_builder.getStringAttr(name)); + $_state.addAttribute("function_type", mlir::TypeAttr::get(type)); + $_state.addAttribute("implementation", mlir::FlatSymbolRefAttr::get($_builder.getStringAttr(""))); + $_state.addAttribute("argc", $_builder.getI64IntegerAttr(0)); + $_state.addAttribute("resc", $_builder.getI64IntegerAttr(0)); + $_state.addAttribute("tape", $_builder.getI64IntegerAttr(0)); + $_state.attributes.append(attrs.begin(), attrs.end()); + $_state.addRegion(); + }]> + ]; let extraClassDeclaration = [{ //===------------------------------------------------------------------===// diff --git a/mlir/lib/Driver/CMakeLists.txt b/mlir/lib/Driver/CMakeLists.txt index c3a82be9ec..5ec6857426 100644 --- a/mlir/lib/Driver/CMakeLists.txt +++ b/mlir/lib/Driver/CMakeLists.txt @@ -27,6 +27,9 @@ set(LIBS ${translation_libs} ExternalStablehloLib MLIROptLib + MLIRRegisterAllDialects + MLIRRegisterAllPasses + MLIRRegisterAllExtensions MLIRCatalyst catalyst-transforms MLIRQuantum diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index 4f58be09d3..a05e70e41b 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -14,10 +14,23 @@ #include +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" +#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" +#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" +#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/MathToLibm/MathToLibm.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/InitAllDialects.h" -#include "mlir/InitAllPasses.h" -#include "mlir/Pass/Pass.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" #include "stablehlo/conversions/linalg/transforms/Passes.h" diff --git a/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp index 54cbc85aac..4397aafab2 100644 --- a/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp @@ -124,7 +124,7 @@ void TensorType2MemrefType(const TypeRange &inTypes, SmallVector &converte } } -static BaseMemRefType +static bufferization::BufferLikeType getBufferizedFunctionArgType(FunctionOpInterface funcOp, int64_t index, const bufferization::BufferizationOptions &options) { @@ -134,7 +134,7 @@ getBufferizedFunctionArgType(FunctionOpInterface funcOp, int64_t index, BaseMemRefType memrefType = options.functionArgTypeConverterFn( tensorType, *options.defaultMemorySpaceFn(tensorType), nullptr, options); - return memrefType; + return cast(memrefType); } static ReturnOp getAssumedUniqueReturnOp(FunctionOpInterface funcOp) @@ -402,10 +402,10 @@ struct ForwardOpInterface return {}; } - FailureOr getBufferType(Operation *op, Value value, - const bufferization::BufferizationOptions &options, - const bufferization::BufferizationState &state, - SmallVector &invocationStack) const + FailureOr + getBufferType(Operation *op, Value value, const bufferization::BufferizationOptions &options, + const bufferization::BufferizationState &state, + SmallVector &invocationStack) const { // The getBufferType() method is called on either BlockArguments or OpResults. // https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td#L506 @@ -526,10 +526,10 @@ struct ReverseOpInterface return {}; } - FailureOr getBufferType(Operation *op, Value value, - const bufferization::BufferizationOptions &options, - const bufferization::BufferizationState &state, - SmallVector &invocationStack) const + FailureOr + getBufferType(Operation *op, Value value, const bufferization::BufferizationOptions &options, + const bufferization::BufferizationState &state, + SmallVector &invocationStack) const { // See comment on the getBufferType() method on forward op. auto reverseOp = cast(op); diff --git a/mlir/lib/Gradient/Transforms/GradMethods/JVPVJPPatterns.cpp b/mlir/lib/Gradient/Transforms/GradMethods/JVPVJPPatterns.cpp index 672819d3e7..ad0e26cdf0 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/JVPVJPPatterns.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/JVPVJPPatterns.cpp @@ -23,6 +23,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/SymbolTable.h" #include "Gradient/Utils/EinsumLinalgGeneric.h" @@ -60,8 +61,6 @@ template std::vector _tovec(const T &x) LogicalResult JVPLoweringPattern::matchAndRewrite(JVPOp op, PatternRewriter &rewriter) const { - MLIRContext *ctx = getContext(); - Location loc = op.getLoc(); auto func_diff_operand_indices = computeDiffArgIndices(op.getDiffArgIndices()); @@ -159,12 +158,9 @@ LogicalResult JVPLoweringPattern::matchAndRewrite(JVPOp op, PatternRewriter &rew } else { assert(acc.value().getType() == res.getType()); - - auto add_op = rewriter.create( - loc, res.getType(), ValueRange({acc.value(), res}), acc.value(), - linalg::BinaryFnAttr::get(ctx, linalg::BinaryFn::add), - linalg::TypeFnAttr::get(ctx, linalg::TypeFn::cast_signed)); - acc = add_op.getResultTensors()[0]; + auto addOp = rewriter.create( + loc, res.getType(), ValueRange{acc.value(), res}, ValueRange{acc.value()}); + acc = addOp.getResultTensors()[0]; } } assert(acc.has_value()); @@ -181,8 +177,6 @@ LogicalResult JVPLoweringPattern::matchAndRewrite(JVPOp op, PatternRewriter &rew LogicalResult VJPLoweringPattern::matchAndRewrite(VJPOp op, PatternRewriter &rewriter) const { - MLIRContext *ctx = getContext(); - Location loc = op.getLoc(); auto func_diff_operand_indices = computeDiffArgIndices(op.getDiffArgIndices()); @@ -278,11 +272,9 @@ LogicalResult VJPLoweringPattern::matchAndRewrite(VJPOp op, PatternRewriter &rew else { assert(acc.value().getType() == res.getType()); - auto add_op = rewriter.create( - loc, res.getType(), ValueRange({acc.value(), res}), acc.value(), - linalg::BinaryFnAttr::get(ctx, linalg::BinaryFn::add), - linalg::TypeFnAttr::get(ctx, linalg::TypeFn::cast_signed)); - acc = add_op.getResultTensors()[0]; + auto addOp = rewriter.create( + loc, res.getType(), ValueRange{acc.value(), res}, ValueRange{acc.value()}); + acc = addOp.getResultTensors()[0]; } } assert(acc.has_value()); diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp index 6aa687ae22..3d5ccc8775 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp @@ -106,7 +106,7 @@ struct DLMultiRZOpPattern : public OpRewritePattern { LogicalResult matchAndRewrite(MultiRZOp op, PatternRewriter &rewriter) const override { - StringRef gateName = "MultiRZ"; + std::string gateName = "MultiRZ"; // Only decompose the op if it is not in the target gate set if (targetGateSet.contains(gateName)) { @@ -117,7 +117,7 @@ struct DLMultiRZOpPattern : public OpRewritePattern { auto numQubits = op.getInQubits().size(); auto MRZNameWithQubits = gateName + "_" + std::to_string(numQubits); - auto it = decompositionRegistry.find(MRZNameWithQubits.str()); + auto it = decompositionRegistry.find(MRZNameWithQubits); if (it == decompositionRegistry.end()) { return failure(); } diff --git a/mlir/lib/Quantum/Transforms/decompose_lowering.cpp b/mlir/lib/Quantum/Transforms/decompose_lowering.cpp index 681fd73f47..8f311f7681 100644 --- a/mlir/lib/Quantum/Transforms/decompose_lowering.cpp +++ b/mlir/lib/Quantum/Transforms/decompose_lowering.cpp @@ -106,9 +106,10 @@ struct DecomposeLoweringPass : impl::DecomposeLoweringPassBase