diff --git a/mlir/include/mlir/Dialect/Flux/Builder/FluxProgramBuilder.h b/mlir/include/mlir/Dialect/Flux/Builder/FluxProgramBuilder.h index 7ab595435..5edd8f5af 100644 --- a/mlir/include/mlir/Dialect/Flux/Builder/FluxProgramBuilder.h +++ b/mlir/include/mlir/Dialect/Flux/Builder/FluxProgramBuilder.h @@ -994,6 +994,30 @@ class FluxProgramBuilder final : public OpBuilder { ctrl(ValueRange controls, ValueRange targets, const std::function& body); + /** + * @brief Apply an inverse operation + * + * @param targets Target qubits + * @param body Function that builds the body containing the target operation + * @return Output qubits + * + * @par Example: + * ```c++ + * targets_out = builder.inv(targets_in, [&](auto& b) { + * auto targets_res = b.s(targets_in); + * return {targets_res}; + * }); + * ``` + * ```mlir + * %targets_out = flux.inv %targets_in { + * %targets_res = flux.s %targets_in : !flux.qubit -> !flux.qubit + * flux.yield %targets_res + * } : {!flux.qubit} -> {!flux.qubit} + * ``` + */ + ValueRange inv(ValueRange targets, + const std::function& body); + //===--------------------------------------------------------------------===// // Deallocation //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Flux/IR/FluxOps.td b/mlir/include/mlir/Dialect/Flux/IR/FluxOps.td index 9741ec0d3..15bfdd617 100644 --- a/mlir/include/mlir/Dialect/Flux/IR/FluxOps.td +++ b/mlir/include/mlir/Dialect/Flux/IR/FluxOps.td @@ -1108,4 +1108,71 @@ def CtrlOp : FluxOp<"ctrl", traits = let hasVerifier = 1; } +def InvOp : FluxOp<"inv", traits = + [ + UnitaryOpInterface, + SameOperandsAndResultType, + SameOperandsAndResultShape, + SingleBlock + ]> { + let summary = "Invert a unitary operation"; + let description = [{ + A modifier operation that inverts the unitary operation defined in its body + region. The operation takes a variadic number of target qubits as inputs and + produces corresponding output qubits. + + Example: + ```mlir + %targets_out = flux.inv %targets_in { + %targets_res = flux.s %targets_in : !flux.qubit -> !flux.qubit + flux.yield %targets_res : !flux.qubit + } : {!flux.qubit} -> {!flux.qubit} + ``` + }]; + + let arguments = (ins Arg, "the target qubits", [MemRead]>:$targets_in); + let results = (outs Variadic:$targets_out); + let regions = (region SizedRegion<1>:$body); + let assemblyFormat = [{ + $targets_in + $body attr-dict `:` + `{` type($targets_in) `}` + `->` + `{` type($targets_out) `}` + }]; + + let extraClassDeclaration = [{ + UnitaryOpInterface getBodyUnitary(); + size_t getNumQubits(); + size_t getNumTargets(); + size_t getNumControls(); + size_t getNumPosControls(); + size_t getNumNegControls(); + Value getInputQubit(size_t i); + Value getOutputQubit(size_t i); + Value getInputTarget(size_t i); + Value getOutputTarget(size_t i); + Value getInputPosControl(size_t i); + Value getOutputPosControl(size_t i); + Value getInputNegControl(size_t i); + Value getOutputNegControl(size_t i); + Value getInputForOutput(Value output); + Value getOutputForInput(Value input); + size_t getNumParams(); + Value getParameter(size_t i); + static StringRef getBaseSymbol() { return "inv"; } + }]; + + let builders = [ + OpBuilder<(ins "ValueRange":$targets), [{ + build($_builder, $_state, targets.getTypes(), targets); + }]>, + OpBuilder<(ins "ValueRange":$targets, "UnitaryOpInterface":$bodyUnitary)>, + OpBuilder<(ins "ValueRange":$targets, "const std::function&":$bodyBuilder)> + ]; + + let hasCanonicalizer = 1; + let hasVerifier = 1; +} + #endif // FluxOPS diff --git a/mlir/include/mlir/Dialect/Quartz/Builder/QuartzProgramBuilder.h b/mlir/include/mlir/Dialect/Quartz/Builder/QuartzProgramBuilder.h index d2367bbef..a578a2f94 100644 --- a/mlir/include/mlir/Dialect/Quartz/Builder/QuartzProgramBuilder.h +++ b/mlir/include/mlir/Dialect/Quartz/Builder/QuartzProgramBuilder.h @@ -824,6 +824,25 @@ class QuartzProgramBuilder final : public OpBuilder { QuartzProgramBuilder& ctrl(ValueRange controls, const std::function& body); + /** + * @brief Apply an inverse (i.e., adjoint) operation. + * + * @param body Function that builds the body containing the operation to + * invert + * @return QuartzProgramBuilder& Reference to this builder for method chaining + * + * @par Example: + * ```c++ + * builder.inv([&](auto& b) { b.s(q0); }); + * ``` + * ```mlir + * quartz.inv { + * quartz.s %q0 : !quartz.qubit + * } + * ``` + */ + QuartzProgramBuilder& inv(const std::function& body); + //===--------------------------------------------------------------------===// // Deallocation //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Quartz/IR/QuartzOps.td b/mlir/include/mlir/Dialect/Quartz/IR/QuartzOps.td index 20325e06a..d2521e067 100644 --- a/mlir/include/mlir/Dialect/Quartz/IR/QuartzOps.td +++ b/mlir/include/mlir/Dialect/Quartz/IR/QuartzOps.td @@ -978,4 +978,50 @@ def CtrlOp : QuartzOp<"ctrl", let hasVerifier = 1; } +def InvOp : QuartzOp<"inv", + traits = [ + UnitaryOpInterface, + SingleBlockImplicitTerminator<"::mlir::quartz::YieldOp"> + ]> { + let summary = "Invert a unitary operation"; + let description = [{ + A modifier operation that inverts the unitary operation defined in its body + region. + + Example: + ```mlir + quartz.inv { + quartz.s %q0 : !quartz.qubit + } + ``` + }]; + + let regions = (region SizedRegion<1>:$body); + let assemblyFormat = "$body attr-dict"; + + let extraClassDeclaration = [{ + [[nodiscard]] UnitaryOpInterface getBodyUnitary(); + size_t getNumQubits(); + size_t getNumTargets(); + size_t getNumControls(); + size_t getNumPosControls(); + size_t getNumNegControls(); + Value getQubit(size_t i); + Value getTarget(size_t i); + Value getPosControl(size_t i); + Value getNegControl(size_t i); + size_t getNumParams(); + Value getParameter(size_t i); + static StringRef getBaseSymbol() { return "inv"; } + }]; + + let builders = [ + OpBuilder<(ins "UnitaryOpInterface":$bodyUnitary)>, + OpBuilder<(ins "const std::function&":$bodyBuilder)> + ]; + + let hasCanonicalizer = 1; + let hasVerifier = 1; +} + #endif // QUARTZ_OPS diff --git a/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp b/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp index 97e5448ed..fd7f04bf4 100644 --- a/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp +++ b/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp @@ -798,6 +798,44 @@ struct ConvertFluxCtrlOp final : OpConversionPattern { } }; +/** + * @brief Converts flux.inv to quartz.inv + * + * @par Example: + * ```mlir + * %targets_out = flux.inv %targets_in { + * %targets_res = flux.s %targets_in : !flux.qubit -> !flux.qubit + * flux.yield %targets_res + * } : {!flux.qubit} -> {!flux.qubit} + * ``` + * is converted to + * ```mlir + * quartz.inv { + * quartz.s %q0 : !quartz.qubit + * quartz.yield + * } + * ``` + */ +struct ConvertFluxInvOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(flux::InvOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + // Create quartz.inv operation + auto quartzOp = rewriter.create(op.getLoc()); + + // Clone body region from Flux to Quartz + auto& dstRegion = quartzOp.getBody(); + rewriter.cloneRegionBefore(op.getBody(), dstRegion, dstRegion.end()); + + // Replace the output qubits with the same Quartz references + rewriter.replaceOp(op, adaptor.getOperands()); + + return success(); + } +}; + /** * @brief Converts flux.yield to quartz.yield * @@ -865,19 +903,18 @@ struct FluxToQuartz final : impl::FluxToQuartzBase { // Register operation conversion patterns // Note: No state tracking needed - OpAdaptors handle type conversion - patterns - .add( - typeConverter, context); + patterns.add< + ConvertFluxAllocOp, ConvertFluxDeallocOp, ConvertFluxStaticOp, + ConvertFluxMeasureOp, ConvertFluxResetOp, ConvertFluxGPhaseOp, + ConvertFluxIdOp, ConvertFluxXOp, ConvertFluxYOp, ConvertFluxZOp, + ConvertFluxHOp, ConvertFluxSOp, ConvertFluxSdgOp, ConvertFluxTOp, + ConvertFluxTdgOp, ConvertFluxSXOp, ConvertFluxSXdgOp, ConvertFluxRXOp, + ConvertFluxRYOp, ConvertFluxRZOp, ConvertFluxPOp, ConvertFluxROp, + ConvertFluxU2Op, ConvertFluxUOp, ConvertFluxSWAPOp, ConvertFluxiSWAPOp, + ConvertFluxDCXOp, ConvertFluxECROp, ConvertFluxRXXOp, ConvertFluxRYYOp, + ConvertFluxRZXOp, ConvertFluxRZZOp, ConvertFluxXXPlusYYOp, + ConvertFluxXXMinusYYOp, ConvertFluxBarrierOp, ConvertFluxCtrlOp, + ConvertFluxInvOp, ConvertFluxYieldOp>(typeConverter, context); // Conversion of flux types in func.func signatures // Note: This currently has limitations with signature changes diff --git a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp index f3345c53c..4b6d1ea32 100644 --- a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp +++ b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp @@ -1148,6 +1148,69 @@ struct ConvertQuartzCtrlOp final : StatefulOpConversionPattern { } }; +/** + * @brief Converts quartz.inv to flux.inv + * + * @par Example: + * ```mlir + * quartz.inv { + * quartz.s %q0 + * quartz.yield + * } + * ``` + * is converted to + * ```mlir + * %targets_out = flux.inv %targets_in { + * %targets_res = flux.s %targets_in : !flux.qubit -> !flux.qubit + * flux.yield %targets_res + * } : {!flux.qubit} -> {!flux.qubit} + * ``` + */ +struct ConvertQuartzInvOp final : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(quartz::InvOp op, OpAdaptor /*adaptor*/, + ConversionPatternRewriter& rewriter) const override { + auto& state = getState(); + auto& qubitMap = state.qubitMap; + + // Get Flux targets from state map + const auto numTargets = op.getNumTargets(); + SmallVector fluxTargets; + fluxTargets.reserve(numTargets); + for (size_t i = 0; i < numTargets; ++i) { + const auto& quartzTarget = op.getTarget(i); + assert(qubitMap.contains(quartzTarget) && "Quartz qubit not found"); + const auto& fluxTarget = qubitMap[quartzTarget]; + fluxTargets.push_back(fluxTarget); + } + + // Create flux.inv + auto fluxOp = rewriter.create(op.getLoc(), fluxTargets); + + // Update state map + if (state.inCtrlOp == 0) { + const auto targetsOut = fluxOp.getTargetsOut(); + for (size_t i = 0; i < numTargets; ++i) { + const auto& quartzTarget = op.getTarget(i); + qubitMap[quartzTarget] = targetsOut[i]; + } + } + + // Update modifier information + state.inCtrlOp++; + state.targetsIn.try_emplace(state.inCtrlOp, fluxTargets); + + // Clone body region from Quartz to Flux + auto& dstRegion = fluxOp.getBody(); + rewriter.cloneRegionBefore(op.getBody(), dstRegion, dstRegion.end()); + + rewriter.eraseOp(op); + return success(); + } +}; + /** * @brief Converts quartz.yield to flux.yield * @@ -1220,19 +1283,20 @@ struct QuartzToFlux final : impl::QuartzToFluxBase { // Register operation conversion patterns with state // tracking - patterns.add< - ConvertQuartzAllocOp, ConvertQuartzDeallocOp, ConvertQuartzStaticOp, - ConvertQuartzMeasureOp, ConvertQuartzResetOp, ConvertQuartzGPhaseOp, - ConvertQuartzIdOp, ConvertQuartzXOp, ConvertQuartzYOp, ConvertQuartzZOp, - ConvertQuartzHOp, ConvertQuartzSOp, ConvertQuartzSdgOp, - ConvertQuartzTOp, ConvertQuartzTdgOp, ConvertQuartzSXOp, - ConvertQuartzSXdgOp, ConvertQuartzRXOp, ConvertQuartzRYOp, - ConvertQuartzRZOp, ConvertQuartzPOp, ConvertQuartzROp, - ConvertQuartzU2Op, ConvertQuartzUOp, ConvertQuartzSWAPOp, - ConvertQuartziSWAPOp, ConvertQuartzDCXOp, ConvertQuartzECROp, - ConvertQuartzRXXOp, ConvertQuartzRYYOp, ConvertQuartzRZXOp, - ConvertQuartzRZZOp, ConvertQuartzXXPlusYYOp, ConvertQuartzXXMinusYYOp, - ConvertQuartzBarrierOp, ConvertQuartzCtrlOp, ConvertQuartzYieldOp>( + patterns.add( typeConverter, context, &state); // Conversion of quartz types in func.func signatures diff --git a/mlir/lib/Dialect/Flux/Builder/FluxProgramBuilder.cpp b/mlir/lib/Dialect/Flux/Builder/FluxProgramBuilder.cpp index 9c71837ae..e1c67bdb1 100644 --- a/mlir/lib/Dialect/Flux/Builder/FluxProgramBuilder.cpp +++ b/mlir/lib/Dialect/Flux/Builder/FluxProgramBuilder.cpp @@ -592,6 +592,22 @@ std::pair FluxProgramBuilder::ctrl( return {controlsOut, targetsOut}; } +ValueRange FluxProgramBuilder::inv( + ValueRange targets, + const std::function& body) { + checkFinalized(); + + auto invOp = create(loc, targets, body); + + // Update tracking + const auto& targetsOut = invOp.getTargetsOut(); + for (const auto& [target, targetOut] : llvm::zip(targets, targetsOut)) { + updateQubitTracking(target, targetOut); + } + + return targetsOut; +} + //===----------------------------------------------------------------------===// // Deallocation //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Flux/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/Flux/IR/Modifiers/InvOp.cpp new file mode 100644 index 000000000..4cb137ca4 --- /dev/null +++ b/mlir/lib/Dialect/Flux/IR/Modifiers/InvOp.cpp @@ -0,0 +1,207 @@ +/* + * Copyright (c) 2023 - 2025 Chair for Design Automation, TUM + * Copyright (c) 2025 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/Flux/IR/FluxDialect.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::flux; + +namespace { + +/** + * @brief Cancel nested inverse modifiers, i.e., `inv(inv(x)) => x`. + */ +struct CancelNestedInv final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(InvOp op, + PatternRewriter& rewriter) const override { + auto innerUnitary = op.getBodyUnitary(); + auto innerInvOp = llvm::dyn_cast(innerUnitary.getOperation()); + if (!innerInvOp) { + return failure(); + } + + // Remove both inverse operations + auto innerInnerUnitary = innerInvOp.getBodyUnitary(); + auto* clonedOp = rewriter.clone(*innerInnerUnitary.getOperation()); + rewriter.replaceOp(op, clonedOp->getResults()); + + return success(); + } +}; + +} // namespace + +UnitaryOpInterface InvOp::getBodyUnitary() { + return llvm::dyn_cast(&getBody().front().front()); +} + +size_t InvOp::getNumQubits() { return getNumTargets() + getNumControls(); } + +size_t InvOp::getNumTargets() { return getTargetsIn().size(); } + +size_t InvOp::getNumControls() { + return getNumPosControls() + getNumNegControls(); +} + +size_t InvOp::getNumPosControls() { + return getBodyUnitary().getNumPosControls(); +} + +size_t InvOp::getNumNegControls() { + return getBodyUnitary().getNumNegControls(); +} + +Value InvOp::getInputQubit(const size_t i) { + return getBodyUnitary().getInputQubit(i); +} + +Value InvOp::getOutputQubit(const size_t i) { + return getBodyUnitary().getOutputQubit(i); +} + +Value InvOp::getInputTarget(const size_t i) { + if (i >= getNumTargets()) { + llvm::reportFatalUsageError("Target index out of bounds"); + } + return getTargetsIn()[i]; +} + +Value InvOp::getOutputTarget(const size_t i) { + if (i >= getNumTargets()) { + llvm::reportFatalUsageError("Target index out of bounds"); + } + return getTargetsOut()[i]; +} + +Value InvOp::getInputPosControl(const size_t i) { + return getBodyUnitary().getInputPosControl(i); +} + +Value InvOp::getOutputPosControl(const size_t i) { + return getBodyUnitary().getOutputPosControl(i); +} + +Value InvOp::getInputNegControl(const size_t i) { + return getBodyUnitary().getInputNegControl(i); +} + +Value InvOp::getOutputNegControl(const size_t i) { + return getBodyUnitary().getOutputNegControl(i); +} + +Value InvOp::getInputForOutput(Value output) { + for (size_t i = 0; i < getNumTargets(); ++i) { + if (output == getTargetsOut()[i]) { + return getTargetsIn()[i]; + } + } + llvm::reportFatalUsageError("Given qubit is not an output of the operation"); +} + +Value InvOp::getOutputForInput(Value input) { + for (size_t i = 0; i < getNumTargets(); ++i) { + if (input == getTargetsIn()[i]) { + return getTargetsOut()[i]; + } + } + llvm::reportFatalUsageError("Given qubit is not an input of the operation"); +} + +size_t InvOp::getNumParams() { return getBodyUnitary().getNumParams(); } + +Value InvOp::getParameter(const size_t i) { + return getBodyUnitary().getParameter(i); +} + +void InvOp::build(OpBuilder& builder, OperationState& state, + const ValueRange targets, UnitaryOpInterface bodyUnitary) { + build(builder, state, targets); + auto& block = state.regions.front()->emplaceBlock(); + + // Move the unitary op into the block + const OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&block); + auto* op = builder.clone(*bodyUnitary.getOperation()); + builder.create(state.location, op->getResults()); +} + +void InvOp::build( + OpBuilder& builder, OperationState& state, const ValueRange targets, + const std::function& bodyBuilder) { + build(builder, state, targets); + auto& block = state.regions.front()->emplaceBlock(); + + // Move the unitary op into the block + const OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&block); + auto targetsOut = bodyBuilder(builder, targets); + builder.create(state.location, targetsOut); +} + +LogicalResult InvOp::verify() { + auto& block = getBody().front(); + if (block.getOperations().size() != 2) { + return emitOpError("body region must have exactly two operations"); + } + if (!llvm::isa(block.front())) { + return emitOpError( + "first operation in body region must be a unitary operation"); + } + if (!llvm::isa(block.back())) { + return emitOpError( + "second operation in body region must be a yield operation"); + } + if (block.back().getNumOperands() != getNumTargets()) { + return emitOpError("yield operation must yield ") + << getNumTargets() << " values, but found " + << block.back().getNumOperands(); + } + + SmallPtrSet uniqueQubitsIn; + auto bodyUnitary = getBodyUnitary(); + const auto numQubits = bodyUnitary.getNumQubits(); + for (size_t i = 0; i < numQubits; i++) { + if (!uniqueQubitsIn.insert(bodyUnitary.getInputQubit(i)).second) { + return emitOpError("duplicate qubit found"); + } + } + SmallPtrSet uniqueQubitsOut; + for (size_t i = 0; i < numQubits; i++) { + if (!uniqueQubitsOut.insert(bodyUnitary.getOutputQubit(i)).second) { + return emitOpError("duplicate qubit found"); + } + } + + if (llvm::isa(bodyUnitary.getOperation())) { + return emitOpError("BarrierOp cannot be inverted"); + } + + return success(); +} + +void InvOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { + results.add(context); +} diff --git a/mlir/lib/Dialect/Quartz/Builder/QuartzProgramBuilder.cpp b/mlir/lib/Dialect/Quartz/Builder/QuartzProgramBuilder.cpp index bdd32fe09..252886892 100644 --- a/mlir/lib/Dialect/Quartz/Builder/QuartzProgramBuilder.cpp +++ b/mlir/lib/Dialect/Quartz/Builder/QuartzProgramBuilder.cpp @@ -423,6 +423,13 @@ QuartzProgramBuilder::ctrl(ValueRange controls, return *this; } +QuartzProgramBuilder& +QuartzProgramBuilder::inv(const std::function& body) { + checkFinalized(); + create(loc, body); + return *this; +} + //===----------------------------------------------------------------------===// // Deallocation //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Quartz/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/Quartz/IR/Modifiers/InvOp.cpp new file mode 100644 index 000000000..3c86e17e1 --- /dev/null +++ b/mlir/lib/Dialect/Quartz/IR/Modifiers/InvOp.cpp @@ -0,0 +1,140 @@ +/* + * Copyright (c) 2023 - 2025 Chair for Design Automation, TUM + * Copyright (c) 2025 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/Quartz/IR/QuartzDialect.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::quartz; + +namespace { + +/** + * @brief Cancel nested inverse modifiers, i.e., `inv(inv(x)) => x`. + */ +struct CancelNestedInv final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(InvOp invOp, + PatternRewriter& rewriter) const override { + auto innerUnitary = invOp.getBodyUnitary(); + auto innerInvOp = llvm::dyn_cast(innerUnitary.getOperation()); + if (!innerInvOp) { + return failure(); + } + + auto innerInnerUnitary = innerInvOp.getBodyUnitary(); + auto* clonedOp = rewriter.clone(*innerInnerUnitary.getOperation()); + rewriter.replaceOp(invOp, clonedOp->getResults()); + + return success(); + } +}; + +} // namespace + +UnitaryOpInterface InvOp::getBodyUnitary() { + return llvm::dyn_cast(&getBody().front().front()); +} + +size_t InvOp::getNumQubits() { return getNumTargets() + getNumControls(); } + +size_t InvOp::getNumTargets() { return getBodyUnitary().getNumTargets(); } + +size_t InvOp::getNumControls() { + return getNumPosControls() + getNumNegControls(); +} + +size_t InvOp::getNumPosControls() { + return getBodyUnitary().getNumPosControls(); +} + +size_t InvOp::getNumNegControls() { + return getBodyUnitary().getNumNegControls(); +} + +Value InvOp::getQubit(const size_t i) { return getBodyUnitary().getQubit(i); } + +Value InvOp::getTarget(const size_t i) { return getBodyUnitary().getTarget(i); } + +Value InvOp::getPosControl(const size_t i) { + return getBodyUnitary().getPosControl(i); +} + +Value InvOp::getNegControl(const size_t i) { + return getBodyUnitary().getNegControl(i); +} + +size_t InvOp::getNumParams() { return getBodyUnitary().getNumParams(); } + +Value InvOp::getParameter(const size_t i) { + return getBodyUnitary().getParameter(i); +} + +void InvOp::build(OpBuilder& builder, OperationState& state, + UnitaryOpInterface bodyUnitary) { + const OpBuilder::InsertionGuard guard(builder); + auto* region = state.addRegion(); + auto& block = region->emplaceBlock(); + + // Move the unitary op into the block + builder.setInsertionPointToStart(&block); + builder.clone(*bodyUnitary.getOperation()); + builder.create(state.location); +} + +void InvOp::build(OpBuilder& builder, OperationState& state, + const std::function& bodyBuilder) { + const OpBuilder::InsertionGuard guard(builder); + auto* region = state.addRegion(); + auto& block = region->emplaceBlock(); + + builder.setInsertionPointToStart(&block); + bodyBuilder(builder); + builder.create(state.location); +} + +LogicalResult InvOp::verify() { + auto& block = getBody().front(); + if (block.getOperations().size() != 2) { + return emitOpError("body region must have exactly two operations"); + } + if (!llvm::isa(block.front())) { + return emitOpError( + "first operation in body region must be a unitary operation"); + } + if (!llvm::isa(block.back())) { + return emitOpError( + "second operation in body region must be a yield operation"); + } + llvm::SmallPtrSet uniqueQubits; + auto bodyUnitary = getBodyUnitary(); + const auto numQubits = bodyUnitary.getNumQubits(); + for (size_t i = 0; i < numQubits; i++) { + if (!uniqueQubits.insert(bodyUnitary.getQubit(i)).second) { + return emitOpError("duplicate qubit found"); + } + } + return success(); +} + +void InvOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { + results.add(context); +}