Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SMT] Add Z3 lowering for set_logic op #7930

Merged
merged 3 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions integration_test/Dialect/SMT/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,17 @@ func.func @entry() {
smt.yield
}

// CHECK: unknown
// CHECK: Res: 0
smt.solver () : () -> () {
smt.set_logic "HORN"
%c = smt.declare_fun : !smt.int
%c4 = smt.int.constant 4
%eq = smt.eq %c, %c4 : !smt.int
func.call @check(%eq) : (!smt.bool) -> ()
smt.yield
}

return
}

Expand Down
53 changes: 51 additions & 2 deletions lib/Conversion/SMTToZ3LLVM/LowerSMTToZ3LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,7 @@ struct SolverOpLowering : public SMTLoweringPattern<SolverOp> {
auto ptrTy = LLVM::LLVMPointerType::get(getContext());
auto voidTy = LLVM::LLVMVoidType::get(getContext());
auto ptrToPtrFunc = LLVM::LLVMFunctionType::get(ptrTy, ptrTy);
auto ptrPtrToPtrFunc = LLVM::LLVMFunctionType::get(ptrTy, {ptrTy, ptrTy});
auto ptrToVoidFunc = LLVM::LLVMFunctionType::get(voidTy, ptrTy);
auto ptrPtrToVoidFunc = LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy});

Expand All @@ -579,6 +580,17 @@ struct SolverOpLowering : public SMTLoweringPattern<SolverOp> {
{config, paramKey, paramValue});
}

// Check if the logic is set anywhere within the solver
std::optional<StringRef> logic = std::nullopt;
auto setLogicOps = op.getBodyRegion().getOps<smt::SetLogicOp>();
if (!setLogicOps.empty()) {
// We know from before patterns were applied that there is only one
// set_logic op
auto setLogicOp = *setLogicOps.begin();
logic = setLogicOp.getLogic();
rewriter.eraseOp(setLogicOp);
}

// Create the context and store a pointer to it in the global variable.
Value ctx = buildCall(rewriter, loc, "Z3_mk_context", ptrToPtrFunc, config)
.getResult();
Expand All @@ -591,8 +603,16 @@ struct SolverOpLowering : public SMTLoweringPattern<SolverOp> {

// Create a solver instance, increase its reference counter, and store a
// pointer to it in the global variable.
Value solver = buildCall(rewriter, loc, "Z3_mk_solver", ptrToPtrFunc, ctx)
->getResult(0);
Value solver;
if (logic) {
auto logicStr = buildString(rewriter, loc, logic.value());
solver = buildCall(rewriter, loc, "Z3_mk_solver_for_logic",
ptrPtrToPtrFunc, {ctx, logicStr})
->getResult(0);
} else {
solver = buildCall(rewriter, loc, "Z3_mk_solver", ptrToPtrFunc, ctx)
->getResult(0);
}
buildCall(rewriter, loc, "Z3_solver_inc_ref", ptrPtrToVoidFunc,
{ctx, solver});
Value solverAddr =
Expand Down Expand Up @@ -1450,6 +1470,35 @@ void LowerSMTToZ3LLVMPass::runOnOperation() {
LowerSMTToZ3LLVMOptions options;
options.debug = debug;

// Check that the lowering is possible
// Specifically, check that the use of set-logic ops is valid for z3
auto setLogicCheck = getOperation().walk([&](SolverOp solverOp)
-> WalkResult {
// Check that solver ops only contain one set-logic op and that they're at
// the start of the body
auto setLogicOps = solverOp.getBodyRegion().getOps<smt::SetLogicOp>();
auto numSetLogicOps = std::distance(setLogicOps.begin(), setLogicOps.end());
if (numSetLogicOps > 1) {
return solverOp.emitError(
"multiple set-logic operations found in one solver operation - Z3 "
"only supports setting the logic once");
}
if (numSetLogicOps == 1)
// Check the only ops before the set-logic op are ConstantLike
for (auto &blockOp : solverOp.getBodyRegion().getOps()) {
if (isa<smt::SetLogicOp>(blockOp))
break;
if (!blockOp.hasTrait<OpTrait::ConstantLike>()) {
return solverOp.emitError("set-logic operation must be the first "
"non-constant operation in a solver "
"operation");
}
}
return WalkResult::advance();
});
if (setLogicCheck.wasInterrupted())
return signalPassFailure();

// Set up the type converter
LLVMTypeConverter converter(&getContext());
populateSMTToZ3LLVMTypeConverter(converter);
Expand Down
23 changes: 23 additions & 0 deletions test/Conversion/SMTToZ3LLVM/smt-to-z3-llvm-errors.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: circt-opt %s --lower-smt-to-z3-llvm --split-input-file --verify-diagnostics

func.func @multiple_set_logics() {
// expected-error @below {{multiple set-logic operations found in one solver operation - Z3 only supports setting the logic once}}
smt.solver () : () -> () {
smt.set_logic "HORN"
smt.set_logic "AUFLIA"
smt.yield
}
func.return
}

// -----

func.func @multiple_set_logics() {
// expected-error @below {{set-logic operation must be the first non-constant operation in a solver operation}}
smt.solver () : () -> () {
smt.check sat {} unknown {} unsat {}
smt.set_logic "HORN"
smt.yield
}
func.return
}
30 changes: 30 additions & 0 deletions test/Conversion/SMTToZ3LLVM/smt-to-z3-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,25 @@ llvm.mlir.global internal @solver() {alignment = 8 : i64} : !llvm.ptr {
// CHECK: llvm.call @Z3_del_context([[CTX]]) : (!llvm.ptr) -> ()
// CHECK: llvm.return

// CHECK-LABEL: llvm.func @test_logic
// CHECK: [[CONFIG1:%.+]] = llvm.call @Z3_mk_config() : () -> !llvm.ptr
// CHECK-DEBUG: [[PROOF_STR1:%.+]] = llvm.mlir.addressof @str{{.*}} : !llvm.ptr
// CHECK-DEBUG: [[TRUE_STR1:%.+]] = llvm.mlir.addressof @str{{.*}} : !llvm.ptr
// CHECK-DEBUG: llvm.call @Z3_set_param_value({{.*}}, [[PROOF_STR1]], [[TRUE_STR1]]) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
// CHECK: [[CTX1:%.+]] = llvm.call @Z3_mk_context([[CONFIG1]]) : (!llvm.ptr) -> !llvm.ptr
// CHECK: [[CTX_ADDR1:%.+]] = llvm.mlir.addressof @ctx_0 : !llvm.ptr
// CHECK: llvm.store [[CTX1]], [[CTX_ADDR1]] : !llvm.ptr, !llvm.ptr
// CHECK: llvm.call @Z3_del_config([[CONFIG1]]) : (!llvm.ptr) -> ()
// CHECK: [[LOGICADDR:%.+]] = llvm.mlir.addressof [[LOGICSTR:@.+]] : !llvm.ptr
// CHECK: [[SOLVER1:%.+]] = llvm.call @Z3_mk_solver_for_logic([[CTX1]], [[LOGICADDR]]) : (!llvm.ptr, !llvm.ptr) -> !llvm.ptr
// CHECK: llvm.call @Z3_solver_inc_ref([[CTX1]], [[SOLVER1]]) : (!llvm.ptr, !llvm.ptr) -> ()
// CHECK: [[SOLVER_ADDR1:%.+]] = llvm.mlir.addressof @solver_0 : !llvm.ptr
// CHECK: llvm.store [[SOLVER1]], [[SOLVER_ADDR1]] : !llvm.ptr, !llvm.ptr
// CHECK: llvm.call @solver
// CHECK: llvm.call @Z3_solver_dec_ref([[CTX1]], [[SOLVER1]]) : (!llvm.ptr, !llvm.ptr) -> ()
// CHECK: llvm.call @Z3_del_context([[CTX1]]) : (!llvm.ptr) -> ()
// CHECK: llvm.return

// CHECK-LABEL: llvm.func @solver
func.func @test(%arg0: i32) {
%0 = smt.solver (%arg0) : (i32) -> (i32) {
Expand Down Expand Up @@ -423,3 +442,14 @@ func.func @test(%arg0: i32) {
// CHECK: llvm.return
return
}

// CHECK-LABEL: llvm.func @solver
func.func @test_logic() {
smt.solver () : () -> () {
%c0_bv4 = smt.bv.constant #smt.bv<0> : !smt.bv<4>
smt.set_logic "HORN"
smt.check sat {} unknown {} unsat {}
smt.yield
}
func.return
}
Loading