From e6e5dca3f08e55a89732326b00559d2f068df42c Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Mon, 9 Dec 2024 12:45:19 +0000 Subject: [PATCH] [RTG] Add ElaborationPass (#7876) --- docs/Dialects/RTG.md | 4 + include/circt/Dialect/RTG/CMakeLists.txt | 1 + .../Dialect/RTG/Transforms/CMakeLists.txt | 6 + .../circt/Dialect/RTG/Transforms/RTGPasses.h | 31 + .../circt/Dialect/RTG/Transforms/RTGPasses.td | 35 + .../circt/Dialect/RTGTest/IR/RTGTestOps.td | 10 + include/circt/InitAllPasses.h | 2 + lib/Dialect/RTG/CMakeLists.txt | 1 + lib/Dialect/RTG/Transforms/CMakeLists.txt | 16 + .../RTG/Transforms/ElaborationPass.cpp | 624 ++++++++++++++++++ lib/Dialect/RTGTest/IR/RTGTestOps.cpp | 8 + test/Dialect/RTG/Transform/elaboration.mlir | 80 +++ test/Dialect/RTGTest/IR/basic.mlir | 5 +- 13 files changed, 822 insertions(+), 1 deletion(-) create mode 100644 include/circt/Dialect/RTG/Transforms/CMakeLists.txt create mode 100644 include/circt/Dialect/RTG/Transforms/RTGPasses.h create mode 100644 include/circt/Dialect/RTG/Transforms/RTGPasses.td create mode 100644 lib/Dialect/RTG/Transforms/CMakeLists.txt create mode 100644 lib/Dialect/RTG/Transforms/ElaborationPass.cpp create mode 100644 test/Dialect/RTG/Transform/elaboration.mlir diff --git a/docs/Dialects/RTG.md b/docs/Dialects/RTG.md index b217daeb5ac1..0fc053728be8 100644 --- a/docs/Dialects/RTG.md +++ b/docs/Dialects/RTG.md @@ -273,3 +273,7 @@ companion dialect to define any backends. ## Types [include "Dialects/RTGTypes.md"] + +## Passes + +[include "RTGPasses.md"] diff --git a/include/circt/Dialect/RTG/CMakeLists.txt b/include/circt/Dialect/RTG/CMakeLists.txt index f33061b2d87c..9f57627c321f 100644 --- a/include/circt/Dialect/RTG/CMakeLists.txt +++ b/include/circt/Dialect/RTG/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/include/circt/Dialect/RTG/Transforms/CMakeLists.txt b/include/circt/Dialect/RTG/Transforms/CMakeLists.txt new file mode 100644 index 000000000000..06f8b8112747 --- /dev/null +++ b/include/circt/Dialect/RTG/Transforms/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS RTGPasses.td) +mlir_tablegen(RTGPasses.h.inc -gen-pass-decls) +add_public_tablegen_target(CIRCTRTGTransformsIncGen) + +# Generate Pass documentation. +add_circt_doc(RTGPasses RTGPasses -gen-pass-doc) diff --git a/include/circt/Dialect/RTG/Transforms/RTGPasses.h b/include/circt/Dialect/RTG/Transforms/RTGPasses.h new file mode 100644 index 000000000000..a2a31021db0b --- /dev/null +++ b/include/circt/Dialect/RTG/Transforms/RTGPasses.h @@ -0,0 +1,31 @@ +//===- RTGPasses.h - RTG pass entry points ----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines prototypes that expose pass constructors. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_RTG_TRANSFORMS_RTGPASSES_H +#define CIRCT_DIALECT_RTG_TRANSFORMS_RTGPASSES_H + +#include "mlir/Pass/Pass.h" +#include + +namespace circt { +namespace rtg { + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#define GEN_PASS_DECL +#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc" +#undef GEN_PASS_REGISTRATION + +} // namespace rtg +} // namespace circt + +#endif // CIRCT_DIALECT_RTG_TRANSFORMS_RTGPASSES_H diff --git a/include/circt/Dialect/RTG/Transforms/RTGPasses.td b/include/circt/Dialect/RTG/Transforms/RTGPasses.td new file mode 100644 index 000000000000..6eff4b5be034 --- /dev/null +++ b/include/circt/Dialect/RTG/Transforms/RTGPasses.td @@ -0,0 +1,35 @@ +//===-- RTGPasses.td - RTG pass definition file ------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the passes that operate on the RTG dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_RTG_TRANSFORMS_RTGPASSES_TD +#define CIRCT_DIALECT_RTG_TRANSFORMS_RTGPASSES_TD + +include "mlir/Pass/PassBase.td" + +def ElaborationPass : Pass<"rtg-elaborate", "mlir::ModuleOp"> { + let summary = "elaborate the randomization parts"; + let description = [{ + This pass interprets most RTG operations to perform the represented + randomization and in the process get rid of those operations. This means, + after this pass the IR does not contain any random constructs within tests + anymore. + }]; + + let options = [ + Option<"seed", "seed", "unsigned", /*default=*/"", + "The seed for any RNG constructs used in the pass.">, + ]; + + let dependentDialects = ["mlir::arith::ArithDialect"]; +} + +#endif // CIRCT_DIALECT_RTG_TRANSFORMS_RTGPASSES_TD diff --git a/include/circt/Dialect/RTGTest/IR/RTGTestOps.td b/include/circt/Dialect/RTGTest/IR/RTGTestOps.td index 697cf6c38571..f73b4c37acab 100644 --- a/include/circt/Dialect/RTGTest/IR/RTGTestOps.td +++ b/include/circt/Dialect/RTGTest/IR/RTGTestOps.td @@ -33,3 +33,13 @@ def CPUDeclOp : RTGTestOp<"cpu_decl", [ let assemblyFormat = "$id attr-dict"; } + +def ConstantTestOp : RTGTestOp<"constant_test", [ + Pure, ConstantLike, +]> { + let arguments = (ins AnyAttr:$value); + let results = (outs AnyType:$result); + + let assemblyFormat = "type($result) attr-dict"; + let hasFolder = 1; +} diff --git a/include/circt/InitAllPasses.h b/include/circt/InitAllPasses.h index 48ed141491c8..9f3c1d7f695e 100644 --- a/include/circt/InitAllPasses.h +++ b/include/circt/InitAllPasses.h @@ -32,6 +32,7 @@ #include "circt/Dialect/Moore/MoorePasses.h" #include "circt/Dialect/OM/OMPasses.h" #include "circt/Dialect/Pipeline/PipelinePasses.h" +#include "circt/Dialect/RTG/Transforms/RTGPasses.h" #include "circt/Dialect/SSP/SSPPasses.h" #include "circt/Dialect/SV/SVPasses.h" #include "circt/Dialect/Seq/SeqPasses.h" @@ -73,6 +74,7 @@ inline void registerAllPasses() { sv::registerPasses(); handshake::registerPasses(); kanagawa::registerPasses(); + rtg::registerPasses(); hw::registerPasses(); pipeline::registerPasses(); sim::registerPasses(); diff --git a/lib/Dialect/RTG/CMakeLists.txt b/lib/Dialect/RTG/CMakeLists.txt index f33061b2d87c..9f57627c321f 100644 --- a/lib/Dialect/RTG/CMakeLists.txt +++ b/lib/Dialect/RTG/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/lib/Dialect/RTG/Transforms/CMakeLists.txt b/lib/Dialect/RTG/Transforms/CMakeLists.txt new file mode 100644 index 000000000000..70dcd46cb5f9 --- /dev/null +++ b/lib/Dialect/RTG/Transforms/CMakeLists.txt @@ -0,0 +1,16 @@ +add_circt_dialect_library(CIRCTRTGTransforms + ElaborationPass.cpp + + DEPENDS + CIRCTRTGTransformsIncGen + + LINK_COMPONENTS + Support + + LINK_LIBS PRIVATE + CIRCTRTGDialect + MLIRArithDialect + MLIRIR + MLIRPass +) + diff --git a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp new file mode 100644 index 000000000000..16a30d59cacf --- /dev/null +++ b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp @@ -0,0 +1,624 @@ +//===- ElaborationPass.cpp - RTG ElaborationPass implementation -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass elaborates the random parts of the RTG dialect. +// It performs randomization top-down, i.e., random constructs in a sequence +// that is invoked multiple times can yield different randomization results +// for each invokation. +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/RTG/IR/RTGOps.h" +#include "circt/Dialect/RTG/IR/RTGVisitors.h" +#include "circt/Dialect/RTG/Transforms/RTGPasses.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" +#include +#include + +namespace circt { +namespace rtg { +#define GEN_PASS_DEF_ELABORATIONPASS +#include "circt/Dialect/RTG/Transforms/RTGPasses.h.inc" +} // namespace rtg +} // namespace circt + +using namespace mlir; +using namespace circt; +using namespace circt::rtg; + +#define DEBUG_TYPE "rtg-elaboration" + +//===----------------------------------------------------------------------===// +// Uniform Distribution Helper +// +// Simplified version of +// https://github.com/llvm/llvm-project/blob/main/libcxx/include/__random/uniform_int_distribution.h +// We use our custom version here to get the same results when compiled with +// different compiler versions and standard libraries. +//===----------------------------------------------------------------------===// + +static uint32_t computeMask(size_t w) { + size_t n = w / 32 + (w % 32 != 0); + size_t w0 = w / n; + return w0 > 0 ? uint32_t(~0) >> (32 - w0) : 0; +} + +/// Get a number uniformly at random in the in specified range. +static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b) { + const uint32_t diff = b - a + 1; + if (diff == 1) + return a; + + const uint32_t digits = std::numeric_limits::digits; + if (diff == 0) + return rng(); + + uint32_t width = digits - llvm::countl_zero(diff) - 1; + if ((diff & (std::numeric_limits::max() >> (digits - width))) != 0) + ++width; + + uint32_t mask = computeMask(diff); + uint32_t u; + do { + u = rng() & mask; + } while (u >= diff); + + return u + a; +} + +//===----------------------------------------------------------------------===// +// Elaborator Values +//===----------------------------------------------------------------------===// + +namespace { + +/// The abstract base class for elaborated values. +struct ElaboratorValue { +public: + enum class ValueKind { Attribute, Set }; + + ElaboratorValue(ValueKind kind) : kind(kind) {} + virtual ~ElaboratorValue() {} + + virtual llvm::hash_code getHashValue() const = 0; + virtual bool isEqual(const ElaboratorValue &other) const = 0; + +#ifndef NDEBUG + virtual void print(llvm::raw_ostream &os) const = 0; +#endif + + ValueKind getKind() const { return kind; } + +private: + const ValueKind kind; +}; + +/// Holds any typed attribute. Wrapping around an MLIR `Attribute` allows us to +/// use this elaborator value class for any values that have a corresponding +/// MLIR attribute rather than one per kind of attribute. We only support typed +/// attributes because for materialization we need to provide the type to the +/// dialect's materializer. +class AttributeValue : public ElaboratorValue { +public: + AttributeValue(TypedAttr attr) + : ElaboratorValue(ValueKind::Attribute), attr(attr) { + assert(attr && "null attributes not allowed"); + } + + // Implement LLVMs RTTI + static bool classof(const ElaboratorValue *val) { + return val->getKind() == ValueKind::Attribute; + } + + llvm::hash_code getHashValue() const override { + return llvm::hash_combine(attr); + } + + bool isEqual(const ElaboratorValue &other) const override { + auto *attrValue = dyn_cast(&other); + if (!attrValue) + return false; + + return attr == attrValue->attr; + } + +#ifndef NDEBUG + void print(llvm::raw_ostream &os) const override { + os << ""; + } +#endif + + TypedAttr getAttr() const { return attr; } + +private: + const TypedAttr attr; +}; + +/// Holds an evaluated value of a `SetType`'d value. +class SetValue : public ElaboratorValue { +public: + SetValue(SetVector &&set, Type type) + : ElaboratorValue(ValueKind::Set), set(std::move(set)), type(type), + cachedHash(llvm::hash_combine( + llvm::hash_combine_range(set.begin(), set.end()), type)) {} + + // Implement LLVMs RTTI + static bool classof(const ElaboratorValue *val) { + return val->getKind() == ValueKind::Set; + } + + llvm::hash_code getHashValue() const override { return cachedHash; } + + bool isEqual(const ElaboratorValue &other) const override { + auto *otherSet = dyn_cast(&other); + if (!otherSet) + return false; + + if (cachedHash != otherSet->cachedHash) + return false; + + // Make sure empty sets of different types are not considered equal + return set == otherSet->set && type == otherSet->type; + } + +#ifndef NDEBUG + void print(llvm::raw_ostream &os) const override { + os << "print(os); }); + os << "} at " << this << ">"; + } +#endif + + const SetVector &getSet() const { return set; } + + Type getType() const { return type; } + +private: + // We currently use a sorted vector to represent sets. Note that it is sorted + // by the pointer value and thus non-deterministic. + // We probably want to do some profiling in the future to see if a DenseSet or + // other representation is better suited. + const SetVector set; + + // Store the set type such that we can materialize this evaluated value + // also in the case where the set is empty. + const Type type; + + // Compute the hash only once at constructor time. + const llvm::hash_code cachedHash; +}; +} // namespace + +#ifndef NDEBUG +static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const ElaboratorValue &value) { + value.print(os); + return os; +} +#endif + +//===----------------------------------------------------------------------===// +// Hash Map Helpers +//===----------------------------------------------------------------------===// + +// NOLINTNEXTLINE(readability-identifier-naming) +static llvm::hash_code hash_value(const ElaboratorValue &val) { + return val.getHashValue(); +} + +namespace { +struct InternMapInfo : public DenseMapInfo { + static unsigned getHashValue(const ElaboratorValue *value) { + assert(value != getTombstoneKey() && value != getEmptyKey()); + return hash_value(*value); + } + + static bool isEqual(const ElaboratorValue *lhs, const ElaboratorValue *rhs) { + if (lhs == rhs) + return true; + + auto *tk = getTombstoneKey(); + auto *ek = getEmptyKey(); + if (lhs == tk || rhs == tk || lhs == ek || rhs == ek) + return false; + + return lhs->isEqual(*rhs); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Main Elaborator Implementation +//===----------------------------------------------------------------------===// + +namespace { + +/// Construct an SSA value from a given elaborated value. +class Materializer { +public: + Value materialize(ElaboratorValue *val, Block *block, Location loc, + function_ref emitError) { + auto iter = materializedValues.find({val, block}); + if (iter != materializedValues.end()) + return iter->second; + + auto [builderIter, _] = + builderPerBlock.insert({block, OpBuilder::atBlockBegin(block)}); + OpBuilder builder = builderIter->second; + + return TypeSwitch(val) + .Case( + [&](auto val) { return visit(val, builder, loc, emitError); }) + .Default([](auto val) { + assert(false && "all cases must be covered above"); + return Value(); + }); + } + +private: + Value visit(AttributeValue *val, OpBuilder &builder, Location loc, + function_ref emitError) { + auto attr = val->getAttr(); + + // For integer attributes (and arithmetic operations on them) we use the + // arith dialect. + if (isa(attr)) { + Value res = builder.getContext() + ->getLoadedDialect() + ->materializeConstant(builder, attr, attr.getType(), loc) + ->getResult(0); + materializedValues[{val, builder.getBlock()}] = res; + return res; + } + + // For any other attribute, we just call the materializer of the dialect + // defining that attribute. + auto *op = attr.getDialect().materializeConstant(builder, attr, + attr.getType(), loc); + if (!op) { + emitError() << "materializer of dialect '" + << attr.getDialect().getNamespace() + << "' unable to materialize value for attribute '" << attr + << "'"; + return Value(); + } + + Value res = op->getResult(0); + materializedValues[{val, builder.getBlock()}] = res; + return res; + } + + Value visit(SetValue *val, OpBuilder &builder, Location loc, + function_ref emitError) { + SmallVector elements; + elements.reserve(val->getSet().size()); + for (auto *el : val->getSet()) { + auto materialized = materialize(el, builder.getBlock(), loc, emitError); + if (!materialized) + return Value(); + + elements.push_back(materialized); + } + + auto res = builder.create(loc, val->getType(), elements); + materializedValues[{val, builder.getBlock()}] = res; + return res; + } + +private: + /// Cache values we have already materialized to reuse them later. We start + /// with an insertion point at the start of the block and cache the (updated) + /// insertion point such that future materializations can also reuse previous + /// materializations without running into dominance issues (or requiring + /// additional checks to avoid them). + DenseMap, Value> materializedValues; + + /// Cache the builders to continue insertions at their current insertion point + /// for the reason stated above. + DenseMap builderPerBlock; +}; + +/// Used to signal to the elaboration driver whether the operation should be +/// removed. +enum class DeletionKind { Keep, Delete }; + +/// Interprets the IR to perform and lower the represented randomizations. +class Elaborator : public RTGOpVisitor, + function_ref> { +public: + using RTGBase = RTGOpVisitor, + function_ref>; + using RTGBase::visitOp; + using RTGBase::visitRegisterOp; + + Elaborator(SymbolTable &table, std::mt19937 &rng) : rng(rng) {} + + /// Helper to perform internalization and keep track of interpreted value for + /// the given SSA value. + template + void internalizeResult(Value val, Args &&...args) { + // TODO: this isn't the most efficient way to internalize + auto ptr = std::make_unique(std::forward(args)...); + auto *e = ptr.get(); + auto [iter, _] = interned.insert({e, std::move(ptr)}); + state[val] = iter->second.get(); + } + + /// Print a nice error message for operations we don't support yet. + FailureOr + visitUnhandledOp(Operation *op, + function_ref addToWorklist) { + return op->emitOpError("elaboration not supported"); + } + + FailureOr + visitExternalOp(Operation *op, + function_ref addToWorklist) { + // TODO: we only have this to be able to write tests for this pass without + // having to add support for more operations for now, so it should be + // removed once it is not necessary anymore for writing tests + if (op->use_empty()) { + for (auto &operand : op->getOpOperands()) { + auto emitError = [&]() { + auto diag = op->emitError(); + diag.attachNote(op->getLoc()) + << "while materializing value for operand#" + << operand.getOperandNumber(); + return diag; + }; + Value val = materializer.materialize( + state.at(operand.get()), op->getBlock(), op->getLoc(), emitError); + if (!val) + return failure(); + operand.set(val); + } + return DeletionKind::Keep; + } + + return visitUnhandledOp(op, addToWorklist); + } + + FailureOr + visitOp(SetCreateOp op, function_ref addToWorklist) { + SetVector set; + for (auto val : op.getElements()) + set.insert(state.at(val)); + + internalizeResult(op.getSet(), std::move(set), + op.getSet().getType()); + return DeletionKind::Delete; + } + + FailureOr + visitOp(SetSelectRandomOp op, function_ref addToWorklist) { + auto *set = cast(state.at(op.getSet())); + + size_t selected; + if (auto intAttr = + op->getAttrOfType("rtg.elaboration_custom_seed")) { + std::mt19937 customRng(intAttr.getInt()); + selected = getUniformlyInRange(customRng, 0, set->getSet().size() - 1); + } else { + selected = getUniformlyInRange(rng, 0, set->getSet().size() - 1); + } + + state[op.getResult()] = set->getSet()[selected]; + return DeletionKind::Delete; + } + + FailureOr + visitOp(SetDifferenceOp op, function_ref addToWorklist) { + auto original = cast(state.at(op.getOriginal()))->getSet(); + auto diff = cast(state.at(op.getDiff()))->getSet(); + + SetVector result(original); + result.set_subtract(diff); + + internalizeResult(op.getResult(), std::move(result), + op.getResult().getType()); + return DeletionKind::Delete; + } + + FailureOr + dispatchOpVisitor(Operation *op, + function_ref addToWorklist) { + if (op->hasTrait()) { + SmallVector result; + auto foldResult = op->fold(result); + (void)foldResult; // Make sure there is a user when assertions are off. + assert(succeeded(foldResult) && + "constant folder of a constant-like must always succeed"); + auto attr = dyn_cast(result[0].dyn_cast()); + if (!attr) + return op->emitError( + "only typed attributes supported for constant-like operations"); + + internalizeResult(op->getResult(0), attr); + return DeletionKind::Delete; + } + + return RTGBase::dispatchOpVisitor(op, addToWorklist); + } + + LogicalResult elaborate(TestOp testOp) { + LLVM_DEBUG(llvm::dbgs() + << "\n=== Elaborating Test @" << testOp.getSymName() << "\n\n"); + + DenseSet visited; + std::deque worklist; + DenseSet toDelete; + for (auto &op : *testOp.getBody()) + if (op.use_empty()) + worklist.push_back(&op); + + while (!worklist.empty()) { + auto *curr = worklist.back(); + if (visited.contains(curr)) { + worklist.pop_back(); + continue; + } + + if (curr->getNumRegions() != 0) + return curr->emitOpError("nested regions not supported"); + + bool addedSomething = false; + for (auto val : curr->getOperands()) { + if (state.contains(val)) + continue; + + auto *defOp = val.getDefiningOp(); + assert(defOp && "cannot be a BlockArgument here"); + if (!visited.contains(defOp)) { + worklist.push_back(defOp); + addedSomething = true; + } + } + + if (addedSomething) + continue; + + auto addToWorklist = [&](Operation *op) { + if (op->use_empty()) + worklist.push_front(op); + }; + auto result = dispatchOpVisitor(curr, addToWorklist); + if (failed(result)) + return failure(); + + LLVM_DEBUG({ + llvm::dbgs() << "Elaborating " << *curr << " to\n["; + + llvm::interleaveComma(curr->getResults(), llvm::dbgs(), [&](auto res) { + if (state.contains(res)) + llvm::dbgs() << *state.at(res); + else + llvm::dbgs() << "unknown"; + }); + + llvm::dbgs() << "]\n\n"; + }); + + if (*result == DeletionKind::Delete) + toDelete.insert(curr); + + visited.insert(curr); + worklist.pop_back(); + } + + // FIXME: this assumes that we didn't query the opaque value from an + // interpreted elaborator value in a way that it can remain used in the IR. + for (auto *op : toDelete) { + op->dropAllUses(); + op->erase(); + } + + // Reduce max memory consumption and make sure the values cannot be accessed + // anymore because we deleted the ops above. + state.clear(); + interned.clear(); + + return success(); + } + +private: + std::mt19937 rng; + + // A map used to intern elaborator values. We do this such that we can + // compare pointers when, e.g., computing set differences, uniquing the + // elements in a set, etc. Otherwise, we'd need to do a deep value comparison + // in those situations. + // Use a pointer as the key with custom MapInfo because of object slicing when + // inserting an object of a derived class of ElaboratorValue. + // The custom MapInfo makes sure that we do a value comparison instead of + // comparing the pointers. + DenseMap, InternMapInfo> + interned; + + // A map from SSA values to a pointer of an interned elaborator value. + DenseMap state; + + // Allows us to materialize ElaboratorValues to the IR operations necessary to + // obtain an SSA value representing that elaborated value. + Materializer materializer; +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Elaborator Pass +//===----------------------------------------------------------------------===// + +namespace { +struct ElaborationPass + : public rtg::impl::ElaborationPassBase { + using Base::Base; + + void runOnOperation() override; + void cloneTargetsIntoTests(SymbolTable &table); +}; +} // namespace + +void ElaborationPass::runOnOperation() { + auto moduleOp = getOperation(); + SymbolTable table(moduleOp); + + cloneTargetsIntoTests(table); + + std::mt19937 rng(seed); + Elaborator elaborator(table, rng); + for (auto testOp : moduleOp.getOps()) + if (failed(elaborator.elaborate(testOp))) + return signalPassFailure(); +} + +void ElaborationPass::cloneTargetsIntoTests(SymbolTable &table) { + auto moduleOp = getOperation(); + for (auto target : llvm::make_early_inc_range(moduleOp.getOps())) { + for (auto test : moduleOp.getOps()) { + // If the test requires nothing from a target, we can always run it. + if (test.getTarget().getEntries().empty()) + continue; + + // If the target requirements do not match, skip this test + // TODO: allow target refinements, just not coarsening + if (target.getTarget() != test.getTarget()) + continue; + + IRRewriter rewriter(test); + // Create a new test for the matched target + auto newTest = cast(test->clone()); + newTest.setSymName(test.getSymName().str() + "_" + + target.getSymName().str()); + table.insert(newTest, rewriter.getInsertionPoint()); + + // Copy the target body into the newly created test + IRMapping mapping; + rewriter.setInsertionPointToStart(newTest.getBody()); + for (auto &op : target.getBody()->without_terminator()) + rewriter.clone(op, mapping); + + for (auto [returnVal, result] : + llvm::zip(target.getBody()->getTerminator()->getOperands(), + newTest.getBody()->getArguments())) + result.replaceAllUsesWith(mapping.lookup(returnVal)); + + newTest.getBody()->eraseArguments(0, + newTest.getBody()->getNumArguments()); + newTest.setTarget(DictType::get(&getContext(), {})); + } + + target->erase(); + } + + // Erase all remaining non-matched tests. + for (auto test : llvm::make_early_inc_range(moduleOp.getOps())) + if (!test.getTarget().getEntries().empty()) + test->erase(); +} diff --git a/lib/Dialect/RTGTest/IR/RTGTestOps.cpp b/lib/Dialect/RTGTest/IR/RTGTestOps.cpp index a4a584ba6b56..0487ed297a51 100644 --- a/lib/Dialect/RTGTest/IR/RTGTestOps.cpp +++ b/lib/Dialect/RTGTest/IR/RTGTestOps.cpp @@ -23,6 +23,14 @@ using namespace rtgtest; size_t CPUDeclOp::getIdentifier(size_t idx) { return getId().getZExtValue(); } +//===----------------------------------------------------------------------===// +// ConstantTestOp +//===----------------------------------------------------------------------===// + +mlir::OpFoldResult ConstantTestOp::fold(FoldAdaptor adaptor) { + return getValueAttr(); +} + //===----------------------------------------------------------------------===// // TableGen generated logic. //===----------------------------------------------------------------------===// diff --git a/test/Dialect/RTG/Transform/elaboration.mlir b/test/Dialect/RTG/Transform/elaboration.mlir new file mode 100644 index 000000000000..69000aa3714c --- /dev/null +++ b/test/Dialect/RTG/Transform/elaboration.mlir @@ -0,0 +1,80 @@ +// RUN: circt-opt --rtg-elaborate=seed=0 --split-input-file --verify-diagnostics %s | FileCheck %s + +func.func @dummy1(%arg0: i32, %arg1: i32, %arg2: !rtg.set) -> () {return} +func.func @dummy2(%arg0: i32) -> () {return} +func.func @dummy3(%arg0: i64) -> () {return} + +// Test the set operations and passing a sequence to another one via argument +// CHECK-LABEL: rtg.test @setOperations +rtg.test @setOperations : !rtg.dict<> { + // CHECK-NEXT: [[V0:%.+]] = arith.constant 2 : i32 + // CHECK-NEXT: [[V1:%.+]] = arith.constant 3 : i32 + // CHECK-NEXT: [[V2:%.+]] = arith.constant 4 : i32 + // CHECK-NEXT: [[V3:%.+]] = rtg.set_create [[V1]], [[V2]] : i32 + // CHECK-NEXT: func.call @dummy1([[V0]], [[V1]], [[V3]]) : + // CHECK-NEXT: } + %0 = arith.constant 2 : i32 + %1 = arith.constant 3 : i32 + %2 = arith.constant 4 : i32 + %3 = arith.constant 5 : i32 + %set = rtg.set_create %0, %1, %2, %0 : i32 + %4 = rtg.set_select_random %set : !rtg.set {rtg.elaboration_custom_seed = 1} + %new_set = rtg.set_create %3, %4 : i32 + %diff = rtg.set_difference %set, %new_set : !rtg.set + %5 = rtg.set_select_random %diff : !rtg.set {rtg.elaboration_custom_seed = 2} + func.call @dummy1(%4, %5, %diff) : (i32, i32, !rtg.set) -> () +} + +// CHECK-LABEL: @targetTest_target0 +// CHECK: [[V0:%.+]] = arith.constant 0 +// CHECK: func.call @dummy2([[V0]]) : + +// CHECK-LABEL: @targetTest_target1 +// CHECK: [[V0:%.+]] = arith.constant 1 +// CHECK: func.call @dummy2([[V0]]) : +rtg.test @targetTest : !rtg.dict { +^bb0(%arg0: i32): + func.call @dummy2(%arg0) : (i32) -> () +} + +// CHECK-NOT: @unmatchedTest +rtg.test @unmatchedTest : !rtg.dict { +^bb0(%arg0: i64): + func.call @dummy3(%arg0) : (i64) -> () +} + +rtg.target @target0 : !rtg.dict { + %0 = arith.constant 0 : i32 + rtg.yield %0 : i32 +} + +rtg.target @target1 : !rtg.dict { + %0 = arith.constant 1 : i32 + rtg.yield %0 : i32 +} + +// ----- + +rtg.test @nestedRegionsNotSupported : !rtg.dict<> { + %cond = arith.constant false + // expected-error @below {{nested regions not supported}} + scf.if %cond { } +} + +// ----- + +rtg.test @untypedAttributes : !rtg.dict<> { + // expected-error @below {{only typed attributes supported for constant-like operations}} + %0 = rtgtest.constant_test i32 {value = [10 : i32]} +} + +// ----- + +func.func @dummy(%arg0: i32) {return} + +rtg.test @untypedAttributes : !rtg.dict<> { + %0 = rtgtest.constant_test i32 {value = "str"} + // expected-error @below {{materializer of dialect 'builtin' unable to materialize value for attribute '"str"'}} + // expected-note @below {{while materializing value for operand#0}} + func.call @dummy(%0) : (i32) -> () +} diff --git a/test/Dialect/RTGTest/IR/basic.mlir b/test/Dialect/RTGTest/IR/basic.mlir index d46aeac3e08c..62120bf9ca3b 100644 --- a/test/Dialect/RTGTest/IR/basic.mlir +++ b/test/Dialect/RTGTest/IR/basic.mlir @@ -3,7 +3,10 @@ // CHECK-LABEL: @cpus // CHECK-SAME: !rtgtest.cpu rtg.target @cpus : !rtg.dict { - // CHECK: %0 = rtgtest.cpu_decl 0 + // CHECK: rtgtest.cpu_decl 0 %0 = rtgtest.cpu_decl 0 rtg.yield %0 : !rtgtest.cpu } + +// CHECK: rtgtest.constant_test i32 {value = "str"} +%1 = rtgtest.constant_test i32 {value = "str"}