Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: SMT Verification Module Update #6849

Merged
merged 15 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ template <typename FF_> class UltraArithmeticRelationImpl {
* at the next gate. Then we can treat (q_arith - 1) as a simulated q_6 selector and scale q_m to handle (q_arith -
* 3) at product.
*
* The The relation is
* The relation is
* defined as C(in(X)...) = q_arith * [ -1/2(q_arith - 3)(q_m * w_r * w_l) + (q_l * w_l) + (q_r * w_r) +
* (q_o * w_o) + (q_4 * w_4) + q_c + (q_arith - 1)w_4_shift ]
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ add_dependencies(cvc5-lib cvc5)
include_directories(${CVC5_INCLUDE})
set_target_properties(cvc5-lib PROPERTIES IMPORTED_LOCATION ${CVC5_LIB})

barretenberg_module(smt_verification common proof_system stdlib_primitives stdlib_sha256 circuit_checker transcript plonk cvc5-lib)
barretenberg_module(smt_verification common stdlib_primitives stdlib_sha256 circuit_checker transcript plonk cvc5-lib)
# We have no easy way to add a dependency to an external target, we list the built targets explicit. Could be cleaner.
add_dependencies(smt_verification cvc5)
add_dependencies(smt_verification_objects cvc5)
add_dependencies(smt_verification_tests cvc5)
add_dependencies(smt_verification_test_objects cvc5)
add_dependencies(smt_verification_test_objects cvc5)
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ CircuitSchema unpack_from_file(const std::string& filename)
return cir;
}

/**
* @brief Get the CircuitSchema object
* @details Initialize the CircuitSchema from the msgpack compatible buffer.
*
* @param buf
* @return CircuitSchema
*/
CircuitSchema unpack_from_buffer(const msgpack::sbuffer& buf)
{
CircuitSchema cir;
msgpack::unpack(buf.data(), buf.size()).get().convert(cir);
return cir;
}

/**
* @brief Translates the schema to python format
* @details Returns the contents of the .py file
Expand All @@ -42,6 +56,7 @@ CircuitSchema unpack_from_file(const std::string& filename)
* gates = [
* [[0x000...0, 0x000...1, 0x000...0, 0x000...0, 0x000...0], [0, 0, 0]], ...
* ]
* @todo UltraCircuitSchema output
*/
void print_schema_for_use_in_python(CircuitSchema& cir)
{
Expand All @@ -64,37 +79,23 @@ void print_schema_for_use_in_python(CircuitSchema& cir)
for (size_t i = 0; i < cir.selectors.size(); i++) {
info("[",
"[",
cir.selectors[i][0],
cir.selectors[0][i][0],
", ",
cir.selectors[i][1],
cir.selectors[0][i][1],
", ",
cir.selectors[i][2],
cir.selectors[0][i][2],
", ",
cir.selectors[i][3],
cir.selectors[0][i][3],
", ",
cir.selectors[i][4],
cir.selectors[0][i][4],
"], [",
cir.wires[i][0],
cir.wires[0][i][0],
", ",
cir.wires[i][1],
cir.wires[0][i][1],
", ",
cir.wires[i][2],
cir.wires[0][i][2],
"]],");
}
info("]");
}

/**
* @brief Get the CircuitSchema object
* @details Initialize the CircuitSchema from the msgpack compatible buffer.
*
* @param buf
* @return CircuitSchema
*/
CircuitSchema unpack_from_buffer(const msgpack::sbuffer& buf)
{
CircuitSchema cir;
msgpack::unpack(buf.data(), buf.size()).get().convert(cir);
return cir;
}
} // namespace smt_circuit_schema
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@ struct CircuitSchema {
std::vector<uint32_t> public_inps;
std::unordered_map<uint32_t, std::string> vars_of_interest;
std::vector<bb::fr> variables;
std::vector<std::vector<bb::fr>> selectors;
std::vector<std::vector<uint32_t>> wires;
std::vector<std::vector<std::vector<bb::fr>>> selectors;
std::vector<std::vector<std::vector<uint32_t>>> wires;
std::vector<uint32_t> real_variable_index;
MSGPACK_FIELDS(modulus, public_inps, vars_of_interest, variables, selectors, wires, real_variable_index);
std::vector<std::vector<std::vector<bb::fr>>> lookup_tables;
MSGPACK_FIELDS(
modulus, public_inps, vars_of_interest, variables, selectors, wires, real_variable_index, lookup_tables);
};

CircuitSchema unpack_from_buffer(const msgpack::sbuffer& buf);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include "solver.hpp"
#include <iostream>
#include "barretenberg/common/log.hpp"

namespace smt_solver {

Expand Down Expand Up @@ -47,6 +47,8 @@ std::unordered_map<std::string, std::string> Solver::model(std::unordered_map<st
str_val = val.getIntegerValue();
} else if (val.isFiniteFieldValue()) {
str_val = val.getFiniteFieldValue();
} else if (val.isBitVectorValue()) {
str_val = "0b" + val.getBitVectorValue();
} else {
throw std::invalid_argument("Expected Integer or FiniteField sorts. Got: " + val.getSort().toString());
}
Expand Down Expand Up @@ -82,6 +84,8 @@ std::unordered_map<std::string, std::string> Solver::model(std::vector<cvc5::Ter
str_val = val.getIntegerValue();
} else if (val.isFiniteFieldValue()) {
str_val = val.getFiniteFieldValue();
} else if (val.isBitVectorValue()) {
str_val = "0b" + val.getBitVectorValue();
} else {
throw std::invalid_argument("Expected Integer or FiniteField sorts. Got: " + val.getSort().toString());
}
Expand Down Expand Up @@ -181,11 +185,11 @@ std::string stringify_term(const cvc5::Term& term, bool parenthesis)
break;
case cvc5::Kind::BITVECTOR_SHL:
back = true;
op = " << " + term.getOp()[0].toString();
op = " << ";
break;
case cvc5::Kind::BITVECTOR_LSHR:
back = true;
op = " >> " + term.getOp()[0].toString();
op = " >> ";
break;
case cvc5::Kind::BITVECTOR_ROTATE_LEFT:
back = true;
Expand Down Expand Up @@ -237,8 +241,31 @@ std::string stringify_term(const cvc5::Term& term, bool parenthesis)
* */
void Solver::print_assertions() const
{
for (auto& t : this->solver.getAssertions()) {
for (const auto& t : this->solver.getAssertions()) {
info(stringify_term(t));
}
}

cvc5::Term Solver::create_lookup_table(std::vector<std::vector<cvc5::Term>>& table)
{
if (!lookup_enabled) {
this->solver.setLogic("ALL");
this->solver.setOption("finite-model-find", "true");
this->solver.setOption("sets-ext", "true");
lookup_enabled = true;
}

cvc5::Term tmp = table[0][0];
cvc5::Sort tuple_sort = this->term_manager.mkTupleSort({ tmp.getSort(), tmp.getSort(), tmp.getSort() });
cvc5::Sort relation = this->term_manager.mkSetSort(tuple_sort);
cvc5::Term resulting_table = this->term_manager.mkEmptySet(relation);

std::vector<cvc5::Term> children;
for (auto& table_entry : table) {
cvc5::Term entry = this->term_manager.mkTuple(table_entry);
children.push_back(entry);
}
children.push_back(resulting_table);
return this->term_manager.mkTerm(cvc5::Kind::SET_INSERT, children);
}
}; // namespace smt_solver
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct SolverConfiguration {
uint64_t timeout;
uint32_t debug;

bool ff_disjunctive_bit;
bool ff_elim_disjunctive_bit;
std::string ff_solver;
};

Expand Down Expand Up @@ -76,8 +76,8 @@ class Solver {
// Cause bit constraints are part of the split-gb optimization
// and without them it will probably perform less efficient
// TODO(alex): test this `probably` after finishing the pr sequence
if (config.ff_disjunctive_bit) {
solver.setOption("ff-disjunctive-bit", "true");
if (!config.ff_elim_disjunctive_bit) {
solver.setOption("ff-elim-disjunctive-bit", "false");
}
// split-gb is an updated version of gb ff-solver
// It basically SPLITS the polynomials in the system into subsets
Expand All @@ -96,6 +96,10 @@ class Solver {
Solver& operator=(const Solver& other) = delete;
Solver& operator=(Solver&& other) = delete;

bool lookup_enabled = false;

cvc5::Term create_lookup_table(std::vector<std::vector<cvc5::Term>>& table);

void assertFormula(const cvc5::Term& term) const { this->solver.assertFormula(term); }

cvc5::Term getValue(const cvc5::Term& term) const { return this->solver.getValue(term); }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,145 @@ TEST(BVTerm, rotl)
ASSERT_EQ(bvals, xvals);
}

// MUL, LSH, RSH, AND and OR are not tested, since they are not bijective
// non bijective operators
TEST(BVTerm, mul)
{
StandardCircuitBuilder builder;
uint_ct a = witness_ct(&builder, static_cast<uint32_t>(fr::random_element()));
uint_ct b = witness_ct(&builder, static_cast<uint32_t>(fr::random_element()));
uint_ct c = a * b;

uint32_t modulus_base = 16;
uint32_t bitvector_size = 32;
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001",
default_solver_config,
modulus_base,
bitvector_size);

STerm x = BVVar("x", &s);
STerm y = BVVar("y", &s);
STerm z = x * y;

x == a.get_value();
y == b.get_value();

ASSERT_TRUE(s.check());

std::string xvals = s.getValue(z.term).getBitVectorValue();
STerm bval = STerm(c.get_value(), &s, TermType::BVTerm);
std::string bvals = s.getValue(bval.term).getBitVectorValue();
ASSERT_EQ(bvals, xvals);
}

TEST(BVTerm, and)
{
StandardCircuitBuilder builder;
uint_ct a = witness_ct(&builder, static_cast<uint32_t>(fr::random_element()));
uint_ct b = witness_ct(&builder, static_cast<uint32_t>(fr::random_element()));
uint_ct c = a & b;

uint32_t modulus_base = 16;
uint32_t bitvector_size = 32;
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001",
default_solver_config,
modulus_base,
bitvector_size);

STerm x = BVVar("x", &s);
STerm y = BVVar("y", &s);
STerm z = x & y;

x == a.get_value();
y == b.get_value();

ASSERT_TRUE(s.check());

std::string xvals = s.getValue(z.term).getBitVectorValue();
STerm bval = STerm(c.get_value(), &s, TermType::BVTerm);
std::string bvals = s.getValue(bval.term).getBitVectorValue();
ASSERT_EQ(bvals, xvals);
}

TEST(BVTerm, or)
{
StandardCircuitBuilder builder;
uint_ct a = witness_ct(&builder, static_cast<uint32_t>(fr::random_element()));
uint_ct b = witness_ct(&builder, static_cast<uint32_t>(fr::random_element()));
uint_ct c = a | b;

uint32_t modulus_base = 16;
uint32_t bitvector_size = 32;
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001",
default_solver_config,
modulus_base,
bitvector_size);

STerm x = BVVar("x", &s);
STerm y = BVVar("y", &s);
STerm z = x | y;

x == a.get_value();
y == b.get_value();

ASSERT_TRUE(s.check());

std::string xvals = s.getValue(z.term).getBitVectorValue();
STerm bval = STerm(c.get_value(), &s, TermType::BVTerm);
std::string bvals = s.getValue(bval.term).getBitVectorValue();
ASSERT_EQ(bvals, xvals);
}

TEST(BVTerm, shr)
{
StandardCircuitBuilder builder;
uint_ct a = witness_ct(&builder, static_cast<uint32_t>(fr::random_element()));
uint_ct b = a >> 5;

uint32_t modulus_base = 16;
uint32_t bitvector_size = 32;
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001",
default_solver_config,
modulus_base,
bitvector_size);

STerm x = BVVar("x", &s);
STerm y = x >> 5;

x == a.get_value();

ASSERT_TRUE(s.check());

std::string xvals = s.getValue(y.term).getBitVectorValue();
STerm bval = STerm(b.get_value(), &s, TermType::BVTerm);
std::string bvals = s.getValue(bval.term).getBitVectorValue();
ASSERT_EQ(bvals, xvals);
}

TEST(BVTerm, shl)
{
StandardCircuitBuilder builder;
uint_ct a = witness_ct(&builder, static_cast<uint32_t>(fr::random_element()));
uint_ct b = a << 5;

uint32_t modulus_base = 16;
uint32_t bitvector_size = 32;
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001",
default_solver_config,
modulus_base,
bitvector_size);

STerm x = BVVar("x", &s);
STerm y = x << 5;

x == a.get_value();

ASSERT_TRUE(s.check());

std::string xvals = s.getValue(y.term).getBitVectorValue();
STerm bval = STerm(b.get_value(), &s, TermType::BVTerm);
std::string bvals = s.getValue(bval.term).getBitVectorValue();
ASSERT_EQ(bvals, xvals);
}

// This test aims to check for the absence of unintended
// behavior. If an unsupported operator is called, an info message appears in stderr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,31 @@ TEST(FFTerm, division)
ASSERT_EQ(bvals, yvals);
}

TEST(FFTerm, set_inclusion)
{
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001");

std::vector<std::vector<cvc5::Term>> table = { { FFConst("1", &s), FFConst("2", &s), FFConst("3", &s) },
{ FFConst("4", &s), FFConst("5", &s), FFConst("6", &s) } };
cvc5::Term symbolic_table = s.create_lookup_table(table);

STerm x = FFVar("x", &s);
STerm y = FFVar("y", &s);
STerm z = FFVar("z", &s);
std::vector<STerm> tmp_vec = { x, y, z };
STerm::in_table(tmp_vec, symbolic_table);
x != 4;

ASSERT_TRUE(s.check());

std::string xval = s.getValue(x).getFiniteFieldValue();
ASSERT_EQ(xval, "1");
std::string yval = s.getValue(y).getFiniteFieldValue();
ASSERT_EQ(yval, "2");
std::string zval = s.getValue(z).getFiniteFieldValue();
ASSERT_EQ(zval, "3");
}

// This test aims to check for the absence of unintended
// behavior. If an unsupported operator is called, an info message appears in stderr
// and the value is supposed to remain unchanged.
Expand Down
Loading
Loading