Skip to content

Commit

Permalink
[opt] Mapping pass: prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
boschmitt committed Sep 19, 2023
1 parent 2defe3d commit 6ead257
Show file tree
Hide file tree
Showing 11 changed files with 1,224 additions and 3 deletions.
119 changes: 119 additions & 0 deletions include/cudaq/ADT/GraphCSR.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*******************************************************************************
* Copyright (c) 2022 - 2023 NVIDIA Corporation & Affiliates. *
* All rights reserved. *
* *
* This source code and the accompanying materials are made available under *
* the terms of the Apache License 2.0 which accompanies this distribution. *
******************************************************************************/

#pragma once

#include "cudaq/Support/Handle.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Support/LLVM.h"

namespace cudaq {

/// Compressed Sparse Row Format (CSR) for Representing Graphs
///
/// This way of representing graphs is based on a technique that originated
/// in HPC as a way to represent sparse matrices. The format is more compact and
/// is laid out more contiguously in memory than other forms, e.g., as adjacency
/// lists, which eliminates most space overheads and reduces random memory
/// accesses.
///
/// The price payed for these advantages is reduced flexibility and complexity
/// (cognitive overhead):
/// * Adding new edges is inefficient (see `addEdgeImpl`method)
/// * The implementation is tricker than other forms.
///
/// Since adding new edges is inefficient, this class suitable for graphs whose
/// structure is fixed and given all at once.
class GraphCSR {
using Offset = unsigned;

public:
struct Node : Handle {
using Handle::Handle;
};

GraphCSR() = default;

/// Creates a new node in the graph and returns its unique identifier.
Node createNode() {
Node node(getNumNodes());
nodeOffsets.push_back(edges.size());
return node;
}

void addEdge(Node src, Node dst, bool undirected = true) {
assert(src.isValid() && "Invalid source node");
assert(dst.isValid() && "Invalid destination node");
addEdgeImpl(src, dst);
if (undirected)
addEdgeImpl(dst, src);
}

std::size_t getNumNodes() const { return nodeOffsets.size(); }

std::size_t getNumEdges() const { return edges.size(); }

mlir::ArrayRef<Node> getNeighbours(Node node) const {
assert(node.isValid() && "Invalid node");
auto begin = edges.begin() + nodeOffsets[node.index];
auto end = node == Node(getNumNodes() - 1)
? edges.end()
: edges.begin() + nodeOffsets[node.index + 1];
return mlir::ArrayRef<Node>(begin, end);
}

void dump(llvm::raw_ostream &os = llvm::errs()) const {
if (getNumNodes() == 0) {
os << "Empty graph.\n";
return;
}

os << "Number of edges: " << getNumEdges() << '\n';
std::size_t lastID = getNumNodes() - 1;
for (std::size_t id = 0; id < lastID; ++id) {
os << id << " --> {";
for (Offset j = nodeOffsets[id], end = nodeOffsets[id + 1]; j < end; ++j)
os << edges[j] << (j == end - 1 ? "" : ", ");
os << "}\n";
}

// Handle last node
os << lastID << " --> {";
for (Offset j = nodeOffsets[lastID], end = edges.size(); j < end; ++j)
os << edges[j] << (j == end - 1 ? "" : ", ");
os << "}\n";
}

private:
void addEdgeImpl(Node src, Node dst) {
// If the source node is the last node, we just need push-back edges.
if (src == Node(getNumNodes() - 1)) {
edges.push_back(dst);
return;
}

// Insert the destination node in the offset.
edges.insert(edges.begin() + nodeOffsets[src.index], dst);

// Update the offsets of all nodes that have an ID greater than `src`.
src.index += 1;
std::transform(nodeOffsets.begin() + src.index, nodeOffsets.end(),
nodeOffsets.begin() + src.index,
[](Offset offset) { return offset + 1; });
}

/// Each entry in this vector contains the starting index in the edge array
/// where the edges from that node are stored.
mlir::SmallVector<Offset> nodeOffsets;

// Stores the destination vertices of each edge.
mlir::SmallVector<Node> edges;
};

} // namespace cudaq
28 changes: 28 additions & 0 deletions include/cudaq/Optimizer/Dialect/Quake/QuakeInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,32 @@ def MeasurementInterface : OpInterface<"MeasurementInterface"> {
];
}

def MapperInterface : OpInterface<"MapperInterface"> {
let description = [{
}];
let cppNamespace = "quake";

let methods = [
InterfaceMethod<
/*desc=*/ "Set the set of wire operands",
/*retType=*/ "void",
/*methodName=*/ "setWireOperands",
/*args=*/ (ins "mlir::ValueRange":$wires)
>,
InterfaceMethod<
/*desc=*/ "Returns the set of wire operands",
/*retType=*/ "mlir::ValueRange",
/*methodName=*/ "getWireOperands",
/*args=*/ (ins)
>,
InterfaceMethod<
/*desc=*/ "Returns the set of new wires (results)",
/*retType=*/ "mlir::ValueRange",
/*methodName=*/ "getWireResults",
/*args=*/ (ins),
/*methodBody=*/ "return $_op.getWires();"
>,
];
}

#endif // CUDAQ_OPTIMIZER_DIALECT_QUAKE_IR_QUAKE_INTERFACES
55 changes: 52 additions & 3 deletions include/cudaq/Optimizer/Dialect/Quake/QuakeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def quake_ResetOp : QuakeOp<"reset", [QuantumGate,
// Sink
//===----------------------------------------------------------------------===//

def quake_SinkOp : QuakeOp<"sink"> {
def quake_SinkOp : QuakeOp<"sink", [MapperInterface]> {
let summary = "Sink for a qubit that is used but not measured.";
let description = [{
The `quake.sink` operation is used to mark a particular wire in the value
Expand All @@ -554,14 +554,34 @@ def quake_SinkOp : QuakeOp<"sink"> {
let assemblyFormat = [{
$target `:` qualified(type(operands)) attr-dict
}];

let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
// MapperInterface
//===------------------------------------------------------------------===//

void setWireOperands(mlir::ValueRange wires) {
assert(wires.size() == 1 && "SinkOp has only one wire operand");
getTargetMutable().assign(wires[0]);
}

mlir::ValueRange getWireOperands() {
return getTarget();
}

mlir::ValueRange getWires() {
return {};
}
}];
}

//===----------------------------------------------------------------------===//
// Measurements
//===----------------------------------------------------------------------===//

class Measurement<string mnemonic> : QuakeOp<mnemonic, [MeasurementInterface,
QuantumMeasure, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
QuantumMeasure, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
MapperInterface]> {
let arguments = (ins
Variadic<AnyQType>:$targets,
OptionalAttr<StrAttr>:$registerName
Expand All @@ -581,6 +601,18 @@ class Measurement<string mnemonic> : QuakeOp<mnemonic, [MeasurementInterface,
mlir::MemoryEffects::Effect>> &effects) {
quake::getMeasurementEffectsImpl(effects, getTargets());
}

//===------------------------------------------------------------------===//
// MapperInterface
//===------------------------------------------------------------------===//

void setWireOperands(mlir::ValueRange wires) {
getTargetsMutable().assign(wires);
}

mlir::ValueRange getWireOperands() {
return getTargets();
}
}];

let hasVerifier = 1;
Expand Down Expand Up @@ -636,7 +668,8 @@ def MzOp : Measurement<"mz"> {
class QuakeOperator<string mnemonic, list<Trait> traits = []>
: QuakeOp<mnemonic,
!listconcat([QuantumGate, AttrSizedOperandSegments, OperatorInterface,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>], traits)> {
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, MapperInterface
], traits)> {

let arguments = (ins
UnitAttr:$is_adj,
Expand Down Expand Up @@ -714,6 +747,22 @@ class QuakeOperator<string mnemonic, list<Trait> traits = []>
quake::getOperatorEffectsImpl(effects, getControls(), getTargets());
}

//===------------------------------------------------------------------===//
// MapperInterface
//===------------------------------------------------------------------===//

void setWireOperands(mlir::ValueRange wires) {
// Controls and targets operands come after the parameters.
auto i = getParameters().size();
for (auto wire : wires)
setOperand(i++, wire);
}

mlir::ValueRange getWireOperands() {
auto numParameters = getParameters().size();
return getOperands().drop_front(numParameters);
}

//===------------------------------------------------------------------===//
// Properties
//===------------------------------------------------------------------===//
Expand Down
9 changes: 9 additions & 0 deletions include/cudaq/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -453,4 +453,13 @@ def MultiControlDecompositionPass: Pass<"quake-multicontrol-decomposition",
}];
}

def MappingPass: Pass<"quake-mapping", "mlir::func::FuncOp"> {
let summary = "TODO";
let description = [{ TODO }];
let options = [
Option<"device", "device", "std::string", /*default=*/"\"-\"",
"Defines device topology.">,
];
}

#endif // CUDAQ_OPT_OPTIMIZER_TRANSFORMS_PASSES
Loading

0 comments on commit 6ead257

Please sign in to comment.