Skip to content

Commit

Permalink
[SMT] Add Z3 lowering for set_logic op (#7930)
Browse files Browse the repository at this point in the history
  • Loading branch information
TaoBi22 authored Dec 6, 2024
1 parent ac73775 commit 61d5eae
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 2 deletions.
11 changes: 11 additions & 0 deletions integration_test/Dialect/SMT/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,17 @@ func.func @entry() {
smt.yield
}

// CHECK: unknown
// CHECK: Res: 0
smt.solver () : () -> () {
smt.set_logic "HORN"
%c = smt.declare_fun : !smt.int
%c4 = smt.int.constant 4
%eq = smt.eq %c, %c4 : !smt.int
func.call @check(%eq) : (!smt.bool) -> ()
smt.yield
}

return
}

Expand Down
53 changes: 51 additions & 2 deletions lib/Conversion/SMTToZ3LLVM/LowerSMTToZ3LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,7 @@ struct SolverOpLowering : public SMTLoweringPattern<SolverOp> {
auto ptrTy = LLVM::LLVMPointerType::get(getContext());
auto voidTy = LLVM::LLVMVoidType::get(getContext());
auto ptrToPtrFunc = LLVM::LLVMFunctionType::get(ptrTy, ptrTy);
auto ptrPtrToPtrFunc = LLVM::LLVMFunctionType::get(ptrTy, {ptrTy, ptrTy});
auto ptrToVoidFunc = LLVM::LLVMFunctionType::get(voidTy, ptrTy);
auto ptrPtrToVoidFunc = LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy});

Expand All @@ -579,6 +580,17 @@ struct SolverOpLowering : public SMTLoweringPattern<SolverOp> {
{config, paramKey, paramValue});
}

// Check if the logic is set anywhere within the solver
std::optional<StringRef> logic = std::nullopt;
auto setLogicOps = op.getBodyRegion().getOps<smt::SetLogicOp>();
if (!setLogicOps.empty()) {
// We know from before patterns were applied that there is only one
// set_logic op
auto setLogicOp = *setLogicOps.begin();
logic = setLogicOp.getLogic();
rewriter.eraseOp(setLogicOp);
}

// Create the context and store a pointer to it in the global variable.
Value ctx = buildCall(rewriter, loc, "Z3_mk_context", ptrToPtrFunc, config)
.getResult();
Expand All @@ -591,8 +603,16 @@ struct SolverOpLowering : public SMTLoweringPattern<SolverOp> {

// Create a solver instance, increase its reference counter, and store a
// pointer to it in the global variable.
Value solver = buildCall(rewriter, loc, "Z3_mk_solver", ptrToPtrFunc, ctx)
->getResult(0);
Value solver;
if (logic) {
auto logicStr = buildString(rewriter, loc, logic.value());
solver = buildCall(rewriter, loc, "Z3_mk_solver_for_logic",
ptrPtrToPtrFunc, {ctx, logicStr})
->getResult(0);
} else {
solver = buildCall(rewriter, loc, "Z3_mk_solver", ptrToPtrFunc, ctx)
->getResult(0);
}
buildCall(rewriter, loc, "Z3_solver_inc_ref", ptrPtrToVoidFunc,
{ctx, solver});
Value solverAddr =
Expand Down Expand Up @@ -1450,6 +1470,35 @@ void LowerSMTToZ3LLVMPass::runOnOperation() {
LowerSMTToZ3LLVMOptions options;
options.debug = debug;

// Check that the lowering is possible
// Specifically, check that the use of set-logic ops is valid for z3
auto setLogicCheck = getOperation().walk([&](SolverOp solverOp)
-> WalkResult {
// Check that solver ops only contain one set-logic op and that they're at
// the start of the body
auto setLogicOps = solverOp.getBodyRegion().getOps<smt::SetLogicOp>();
auto numSetLogicOps = std::distance(setLogicOps.begin(), setLogicOps.end());
if (numSetLogicOps > 1) {
return solverOp.emitError(
"multiple set-logic operations found in one solver operation - Z3 "
"only supports setting the logic once");
}
if (numSetLogicOps == 1)
// Check the only ops before the set-logic op are ConstantLike
for (auto &blockOp : solverOp.getBodyRegion().getOps()) {
if (isa<smt::SetLogicOp>(blockOp))
break;
if (!blockOp.hasTrait<OpTrait::ConstantLike>()) {
return solverOp.emitError("set-logic operation must be the first "
"non-constant operation in a solver "
"operation");
}
}
return WalkResult::advance();
});
if (setLogicCheck.wasInterrupted())
return signalPassFailure();

// Set up the type converter
LLVMTypeConverter converter(&getContext());
populateSMTToZ3LLVMTypeConverter(converter);
Expand Down
23 changes: 23 additions & 0 deletions test/Conversion/SMTToZ3LLVM/smt-to-z3-llvm-errors.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: circt-opt %s --lower-smt-to-z3-llvm --split-input-file --verify-diagnostics

func.func @multiple_set_logics() {
// expected-error @below {{multiple set-logic operations found in one solver operation - Z3 only supports setting the logic once}}
smt.solver () : () -> () {
smt.set_logic "HORN"
smt.set_logic "AUFLIA"
smt.yield
}
func.return
}

// -----

func.func @multiple_set_logics() {
// expected-error @below {{set-logic operation must be the first non-constant operation in a solver operation}}
smt.solver () : () -> () {
smt.check sat {} unknown {} unsat {}
smt.set_logic "HORN"
smt.yield
}
func.return
}
30 changes: 30 additions & 0 deletions test/Conversion/SMTToZ3LLVM/smt-to-z3-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,25 @@ llvm.mlir.global internal @solver() {alignment = 8 : i64} : !llvm.ptr {
// CHECK: llvm.call @Z3_del_context([[CTX]]) : (!llvm.ptr) -> ()
// CHECK: llvm.return

// CHECK-LABEL: llvm.func @test_logic
// CHECK: [[CONFIG1:%.+]] = llvm.call @Z3_mk_config() : () -> !llvm.ptr
// CHECK-DEBUG: [[PROOF_STR1:%.+]] = llvm.mlir.addressof @str{{.*}} : !llvm.ptr
// CHECK-DEBUG: [[TRUE_STR1:%.+]] = llvm.mlir.addressof @str{{.*}} : !llvm.ptr
// CHECK-DEBUG: llvm.call @Z3_set_param_value({{.*}}, [[PROOF_STR1]], [[TRUE_STR1]]) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
// CHECK: [[CTX1:%.+]] = llvm.call @Z3_mk_context([[CONFIG1]]) : (!llvm.ptr) -> !llvm.ptr
// CHECK: [[CTX_ADDR1:%.+]] = llvm.mlir.addressof @ctx_0 : !llvm.ptr
// CHECK: llvm.store [[CTX1]], [[CTX_ADDR1]] : !llvm.ptr, !llvm.ptr
// CHECK: llvm.call @Z3_del_config([[CONFIG1]]) : (!llvm.ptr) -> ()
// CHECK: [[LOGICADDR:%.+]] = llvm.mlir.addressof [[LOGICSTR:@.+]] : !llvm.ptr
// CHECK: [[SOLVER1:%.+]] = llvm.call @Z3_mk_solver_for_logic([[CTX1]], [[LOGICADDR]]) : (!llvm.ptr, !llvm.ptr) -> !llvm.ptr
// CHECK: llvm.call @Z3_solver_inc_ref([[CTX1]], [[SOLVER1]]) : (!llvm.ptr, !llvm.ptr) -> ()
// CHECK: [[SOLVER_ADDR1:%.+]] = llvm.mlir.addressof @solver_0 : !llvm.ptr
// CHECK: llvm.store [[SOLVER1]], [[SOLVER_ADDR1]] : !llvm.ptr, !llvm.ptr
// CHECK: llvm.call @solver
// CHECK: llvm.call @Z3_solver_dec_ref([[CTX1]], [[SOLVER1]]) : (!llvm.ptr, !llvm.ptr) -> ()
// CHECK: llvm.call @Z3_del_context([[CTX1]]) : (!llvm.ptr) -> ()
// CHECK: llvm.return

// CHECK-LABEL: llvm.func @solver
func.func @test(%arg0: i32) {
%0 = smt.solver (%arg0) : (i32) -> (i32) {
Expand Down Expand Up @@ -423,3 +442,14 @@ func.func @test(%arg0: i32) {
// CHECK: llvm.return
return
}

// CHECK-LABEL: llvm.func @solver
func.func @test_logic() {
smt.solver () : () -> () {
%c0_bv4 = smt.bv.constant #smt.bv<0> : !smt.bv<4>
smt.set_logic "HORN"
smt.check sat {} unknown {} unsat {}
smt.yield
}
func.return
}

0 comments on commit 61d5eae

Please sign in to comment.