Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Several updates in SMT verification module #7105

Merged
merged 13 commits into from
Jun 19, 2024
166 changes: 121 additions & 45 deletions barretenberg/cpp/src/barretenberg/smt_verification/README.md

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,6 @@ std::pair<StandardCircuit, StandardCircuit> StandardCircuit::unique_witness_ext(
const std::vector<std::string>& not_equal_at_the_same_time,
bool optimizations)
{
// TODO(alex): set optimizations to be true once they are confirmed
StandardCircuit c1(circuit_info, s, type, "circuit1", optimizations);
StandardCircuit c2(circuit_info, s, type, "circuit2", optimizations);

Expand Down Expand Up @@ -867,7 +866,6 @@ std::pair<StandardCircuit, StandardCircuit> StandardCircuit::unique_witness_ext(
std::pair<StandardCircuit, StandardCircuit> StandardCircuit::unique_witness(
CircuitSchema& circuit_info, Solver* s, TermType type, const std::vector<std::string>& equal, bool optimizations)
{
// TODO(alex): set optimizations to be true once they are confirmed
StandardCircuit c1(circuit_info, s, type, "circuit1", optimizations);
StandardCircuit c2(circuit_info, s, type, "circuit2", optimizations);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,6 @@ UltraCircuit::UltraCircuit(
, wires_idxs(circuit_info.wires)
, lookup_tables(circuit_info.lookup_tables)
{
info("Arithmetic gates: ", selectors[1].size());
info("Delta Range gates: ", selectors[2].size());
info("Elliptic gates: ", selectors[3].size());
info("Aux gates: ", selectors[4].size());
info("Lookup gates: ", selectors[5].size());

// Perform all relaxations for gates or
// add gate in its normal state to solver

Expand Down Expand Up @@ -193,6 +187,14 @@ size_t UltraCircuit::handle_lookup_relation(size_t cursor, size_t idx)
this->cached_symbolic_tables.insert({ table_idx, this->solver->create_lookup_table(new_table) });
}

// Sort of an optimization.
// However if we don't do this, solver will find a unique witness that corresponds to overflowed value.
if (this->type == TermType::BVTerm && q_r == -64 && q_m == -64 && q_c == -64) {
this->symbolic_vars[w_l_shift_idx] = this->symbolic_vars[w_l_idx] >> 6;
this->symbolic_vars[w_r_shift_idx] = this->symbolic_vars[w_r_idx] >> 6;
this->symbolic_vars[w_o_shift_idx] = this->symbolic_vars[w_o_idx] >> 6;
}

STerm first_entry = this->symbolic_vars[w_l_idx] + q_r * this->symbolic_vars[w_l_shift_idx];
STerm second_entry = this->symbolic_vars[w_r_idx] + q_m * this->symbolic_vars[w_r_shift_idx];
STerm third_entry = this->symbolic_vars[w_o_idx] + q_c * this->symbolic_vars[w_o_shift_idx];
Expand Down Expand Up @@ -339,7 +341,7 @@ void UltraCircuit::handle_range_constraints()
uint32_t tag = this->real_variable_tags[this->real_variable_index[i]];
if (tag != 0 && this->range_tags.contains(tag)) {
uint64_t range = this->range_tags[tag];
if ((this->type != TermType::FFITerm) && (this->type != TermType::BVTerm)) {
if (this->type == TermType::FFTerm || !this->optimizations) {
if (!this->cached_range_tables.contains(range)) {
std::vector<cvc5::Term> new_range_table;
for (size_t entry = 0; entry < range; entry++) {
Expand All @@ -352,6 +354,7 @@ void UltraCircuit::handle_range_constraints()
} else {
this->symbolic_vars[i] <= range;
}
optimized[i] = false;
}
}
}
Expand Down Expand Up @@ -397,11 +400,11 @@ std::pair<UltraCircuit, UltraCircuit> UltraCircuit::unique_witness_ext(
const std::vector<std::string>& equal,
const std::vector<std::string>& not_equal,
const std::vector<std::string>& equal_at_the_same_time,
const std::vector<std::string>& not_equal_at_the_same_time)
const std::vector<std::string>& not_equal_at_the_same_time,
bool optimizations)
{
// TODO(alex): set optimizations to be true once they are confirmed
UltraCircuit c1(circuit_info, s, type, "circuit1", false);
UltraCircuit c2(circuit_info, s, type, "circuit2", false);
UltraCircuit c1(circuit_info, s, type, "circuit1", optimizations);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rename to "enable_optimisations", since this is a flag

UltraCircuit c2(circuit_info, s, type, "circuit2", optimizations);

for (const auto& term : equal) {
c1[term] == c2[term];
Expand Down Expand Up @@ -449,14 +452,11 @@ std::pair<UltraCircuit, UltraCircuit> UltraCircuit::unique_witness_ext(
* @param equal The list of names of variables which should be equal in both circuits(each is equal)
* @return std::pair<Circuit, Circuit>
*/
std::pair<UltraCircuit, UltraCircuit> UltraCircuit::unique_witness(CircuitSchema& circuit_info,
Solver* s,
TermType type,
const std::vector<std::string>& equal)
std::pair<UltraCircuit, UltraCircuit> UltraCircuit::unique_witness(
CircuitSchema& circuit_info, Solver* s, TermType type, const std::vector<std::string>& equal, bool optimizations)
{
// TODO(alex): set optimizations to be true once they are confirmed
UltraCircuit c1(circuit_info, s, type, "circuit1", false);
UltraCircuit c2(circuit_info, s, type, "circuit2", false);
UltraCircuit c1(circuit_info, s, type, "circuit1", optimizations);
UltraCircuit c2(circuit_info, s, type, "circuit2", optimizations);

for (const auto& term : equal) {
c1[term] == c2[term];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@ class UltraCircuit : public CircuitBase {
const std::vector<std::string>& equal = {},
const std::vector<std::string>& not_equal = {},
const std::vector<std::string>& equal_at_the_same_time = {},
const std::vector<std::string>& not_equal_at_the_same_time = {});
const std::vector<std::string>& not_equal_at_the_same_time = {},
bool optimizations = false);
static std::pair<UltraCircuit, UltraCircuit> unique_witness(CircuitSchema& circuit_info,
Solver* s,
TermType type,
const std::vector<std::string>& equal = {});
const std::vector<std::string>& equal = {},
bool optimizations = false);
};
}; // namespace smt_circuit
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,13 @@ std::string Solver::stringify_term(const cvc5::Term& term, bool parenthesis)
return res + ")";
}
if (term.getKind() == cvc5::Kind::INTERNAL_KIND) {
return "";
}
if (term.getKind() == cvc5::Kind::SET_INSERT) {
return "set_" + std::to_string(this->tables[term]);
}
if (term.getKind() == cvc5::Kind::SET_INSERT || term.getKind() == cvc5::Kind::SET_EMPTY) {
return "";
if (term.getKind() == cvc5::Kind::SET_EMPTY) {
return "{}";
}

std::string res;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,12 @@ TEST(BVTerm, or)
ASSERT_EQ(bvals, xvals);
}

TEST(BVTerm, shr)
TEST(BVTerm, div)
{
StandardCircuitBuilder builder;
uint_ct a = witness_ct(&builder, static_cast<uint32_t>(fr::random_element()));
uint_ct b = a >> 5;
uint_ct b = witness_ct(&builder, static_cast<uint32_t>(fr::random_element()));
uint_ct c = a / b;

uint32_t modulus_base = 16;
uint32_t bitvector_size = 32;
Expand All @@ -256,23 +257,25 @@ TEST(BVTerm, shr)
bitvector_size);

STerm x = BVVar("x", &s);
STerm y = x >> 5;
STerm y = BVVar("y", &s);
STerm z = x / y;

x == a.get_value();
y == b.get_value();

ASSERT_TRUE(s.check());

std::string xvals = s.getValue(y.term).getBitVectorValue();
STerm bval = STerm(b.get_value(), &s, TermType::BVTerm);
std::string xvals = s.getValue(z.term).getBitVectorValue();
STerm bval = STerm(c.get_value(), &s, TermType::BVTerm);
std::string bvals = s.getValue(bval.term).getBitVectorValue();
ASSERT_EQ(bvals, xvals);
}

TEST(BVTerm, shl)
TEST(BVTerm, shr)
{
StandardCircuitBuilder builder;
uint_ct a = witness_ct(&builder, static_cast<uint32_t>(fr::random_element()));
uint_ct b = a << 5;
uint_ct b = a >> 5;

uint32_t modulus_base = 16;
uint32_t bitvector_size = 32;
Expand All @@ -282,7 +285,7 @@ TEST(BVTerm, shl)
bitvector_size);

STerm x = BVVar("x", &s);
STerm y = x << 5;
STerm y = x >> 5;

x == a.get_value();

Expand All @@ -294,11 +297,11 @@ TEST(BVTerm, shl)
ASSERT_EQ(bvals, xvals);
}

// This test aims to check for the absence of unintended
// behavior. If an unsupported operator is called, an info message appears in stderr
// and the value is supposed to remain unchanged.
TEST(BVTerm, unsupported_operations)
TEST(BVTerm, shl)
{
StandardCircuitBuilder builder;
uint_ct a = witness_ct(&builder, static_cast<uint32_t>(fr::random_element()));
uint_ct b = a << 5;

uint32_t modulus_base = 16;
uint32_t bitvector_size = 32;
Expand All @@ -308,8 +311,14 @@ TEST(BVTerm, unsupported_operations)
bitvector_size);

STerm x = BVVar("x", &s);
STerm y = BVVar("y", &s);
STerm y = x << 5;

STerm z = x / y;
ASSERT_EQ(z.term, x.term);
x == a.get_value();

ASSERT_TRUE(s.check());

std::string xvals = s.getValue(y.term).getBitVectorValue();
STerm bval = STerm(b.get_value(), &s, TermType::BVTerm);
std::string bvals = s.getValue(bval.term).getBitVectorValue();
ASSERT_EQ(bvals, xvals);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
#include <unordered_map>

#include "barretenberg/stdlib/primitives/uint/uint.hpp"
#include "term.hpp"

#include <gtest/gtest.h>

namespace {
auto& engine = bb::numeric::get_debug_randomness();
}

using namespace bb;
using witness_ct = stdlib::witness_t<StandardCircuitBuilder>;

using namespace smt_terms;

TEST(ITerm, addition)
{
StandardCircuitBuilder builder;
uint64_t a = static_cast<uint32_t>(fr::random_element()) % (static_cast<uint32_t>(1) << 31);
uint64_t b = static_cast<uint32_t>(fr::random_element()) % (static_cast<uint32_t>(1) << 31);
uint64_t c = a + b;

Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", default_solver_config);

STerm x = IVar("x", &s);
STerm y = IVar("y", &s);
STerm z = x + y;

z == c;
x == a;
ASSERT_TRUE(s.check());

std::string yvals = s.getValue(y.term).getIntegerValue();

STerm bval = STerm(b, &s, TermType::ITerm);
std::string bvals = s.getValue(bval.term).getIntegerValue();
ASSERT_EQ(bvals, yvals);
}

TEST(ITerm, subtraction)
{
StandardCircuitBuilder builder;
uint64_t c = static_cast<uint32_t>(fr::random_element()) % (static_cast<uint32_t>(1) << 31);
uint64_t b = static_cast<uint32_t>(fr::random_element()) % (static_cast<uint32_t>(1) << 31);
uint64_t a = c + b;

Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", default_solver_config);

STerm x = IVar("x", &s);
STerm y = IVar("y", &s);
STerm z = x - y;

x == a;
z == c;
ASSERT_TRUE(s.check());

std::string yvals = s.getValue(y.term).getIntegerValue();

STerm bval = STerm(b, &s, TermType::ITerm);
std::string bvals = s.getValue(bval.term).getIntegerValue();
ASSERT_EQ(bvals, yvals);
}

TEST(ITerm, mul)
{
StandardCircuitBuilder builder;
uint64_t a = static_cast<uint32_t>(fr::random_element()) % (static_cast<uint32_t>(1) << 31);
uint64_t b = static_cast<uint32_t>(fr::random_element()) % (static_cast<uint32_t>(1) << 31);
uint64_t c = a * b;

Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", default_solver_config);

STerm x = IVar("x", &s);
STerm y = IVar("y", &s);
STerm z = x * y;

x == a;
y == b;

ASSERT_TRUE(s.check());

std::string xvals = s.getValue(z.term).getIntegerValue();
STerm bval = STerm(c, &s, TermType::ITerm);
std::string bvals = s.getValue(bval.term).getIntegerValue();
ASSERT_EQ(bvals, xvals);
}

TEST(ITerm, div)
{
StandardCircuitBuilder builder;
uint64_t a = static_cast<uint32_t>(fr::random_element()) % (static_cast<uint32_t>(1) << 31);
uint64_t b = static_cast<uint32_t>(fr::random_element()) % (static_cast<uint32_t>(1) << 31) + 1;
uint64_t c = a / b;

Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", default_solver_config);

STerm x = IVar("x", &s);
STerm y = IVar("y", &s);
STerm z = x / y;

x == a;
y == b;

ASSERT_TRUE(s.check());

std::string xvals = s.getValue(z.term).getIntegerValue();
STerm bval = STerm(c, &s, TermType::ITerm);
std::string bvals = s.getValue(bval.term).getIntegerValue();
ASSERT_EQ(bvals, xvals);
}

// This test aims to check for the absence of unintended
// behavior. If an unsupported operator is called, an info message appears in stderr
// and the value is supposed to remain unchanged.
TEST(ITerm, unsupported_operations)
{
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001");

STerm x = IVar("x", &s);
STerm y = IVar("y", &s);

STerm z = x ^ y;
ASSERT_EQ(z.term, x.term);
z = x & y;
ASSERT_EQ(z.term, x.term);
z = x | y;
ASSERT_EQ(z.term, x.term);
z = x >> 10;
ASSERT_EQ(z.term, x.term);
z = x << 10;
ASSERT_EQ(z.term, x.term);
z = x.rotr(10);
ASSERT_EQ(z.term, x.term);
z = x.rotl(10);
ASSERT_EQ(z.term, x.term);

cvc5::Term before_term = x.term;
x ^= y;
ASSERT_EQ(x.term, before_term);
x &= y;
ASSERT_EQ(x.term, before_term);
x |= y;
ASSERT_EQ(x.term, before_term);
x >>= 10;
ASSERT_EQ(x.term, before_term);
x <<= 10;
ASSERT_EQ(x.term, before_term);
}
Loading
Loading