diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/CMakeLists.txt b/barretenberg/cpp/src/barretenberg/smt_verification/CMakeLists.txt index 265a64887e6..35ce5bcd48d 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/CMakeLists.txt +++ b/barretenberg/cpp/src/barretenberg/smt_verification/CMakeLists.txt @@ -12,7 +12,7 @@ ExternalProject_Add( GIT_REPOSITORY "https://github.com/cvc5/cvc5.git" GIT_TAG main BUILD_IN_SOURCE YES - CONFIGURE_COMMAND ${SHELL} ./configure.sh production --auto-download --cocoa --cryptominisat --kissat -DCMAKE_C_COMPILER=/usr/bin/clang -DCMAKE_CXX_COMPILER=/usr/bin/clang++ --prefix=${CVC5_BUILD} + CONFIGURE_COMMAND ${SHELL} ./configure.sh production --gpl --auto-download --cocoa --cryptominisat --kissat -DCMAKE_C_COMPILER=/usr/bin/clang -DCMAKE_CXX_COMPILER=/usr/bin/clang++ --prefix=${CVC5_BUILD} BUILD_COMMAND make -C build INSTALL_COMMAND make -C build install UPDATE_COMMAND "" # No update step diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/README.md b/barretenberg/cpp/src/barretenberg/smt_verification/README.md index aa12cccf666..2db78c3853b 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/README.md +++ b/barretenberg/cpp/src/barretenberg/smt_verification/README.md @@ -184,11 +184,17 @@ Now you have the values of the specified terms, which resulted into `true` resul **!Note that the return values are decimal strings/binary strings**, so if you want to use them later you should use `FFConst` with base 10, etc. Also, there is a header file "barretenberg/smt_verification/utl/smt_util.hpp" that contains two useful functions: -- `default_model(verctor special_names, circuit1, circuit2, *solver, fname="witness.out")` -- `default_model_single(vector special_names, circuit, *solver, fname="witness.out)` +- `default_model(verctor special_names, circuit1, circuit2, *solver, fname="witness.out", bool pack = true)` +- `default_model_single(vector special_names, circuit, *solver, fname="witness.out, bool pack = true)` These functions will write witness variables in c-like array format into file named `fname`. The vector of `special_names` is the values that you want ot see in stdout. +`pack` argument tells this function to save an `msgpack` buffer of the witness on disk. Name of the file will be `fname`.pack + +You can then import the saved witness using one of the following functions: + +- `vec> import_witness(str fname)` +- `vec import_witness_single(str fname)` ## 4. Automated verification of a unique witness There's a static member of `StandardCircuit` and `UltraCircuit` @@ -211,6 +217,7 @@ Besides already mentioned `smt_timer`, `default_model` and `default_model_single - `pair, vector> base4(uint32_t el)` - that will return base4 accumulators - `void fix_range_lists(UltraCircuitBuilder& builder)` - Since we are not using the part of the witness, that contains range lists, they are set to 0 by the solver. We need to overwrite them to check the witness obtained by the solver. +- `bb::fr string_to_fr(str num, int base, size_t step)` - converts string of an arbitrary base into `bb::fr` value. $\max_{step}(base^{step} \le 2^{64})$ ```c++ UltraCircuitBuilder builder; diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit_base.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit_base.cpp index b4c6d1c872b..edde7e005cf 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit_base.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit_base.cpp @@ -11,14 +11,14 @@ CircuitBase::CircuitBase(std::unordered_map& variable_nam Solver* solver, TermType type, const std::string& tag, - bool optimizations) + bool enable_optimizations) : variables(variables) , public_inps(public_inps) , variable_names(variable_names) , real_variable_index(real_variable_index) , real_variable_tags(real_variable_tags) , range_tags(range_tags) - , optimizations(optimizations) + , enable_optimizations(enable_optimizations) , solver(solver) , type(type) , tag(tag) diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit_base.hpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit_base.hpp index 5d551406a4c..e39d46a88e0 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit_base.hpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit_base.hpp @@ -37,7 +37,7 @@ class CircuitBase { std::unordered_map range_tags; // ranges associated with a certain tag std::unordered_map optimized; // keeps track of the variables that were excluded from symbolic // circuit during optimizations - bool optimizations; // flags to turn on circuit optimizations + bool enable_optimizations; // flags to turn on circuit optimizations std::unordered_map> cached_subcircuits; // caches subcircuits during optimization // No need to recompute them each time @@ -58,7 +58,7 @@ class CircuitBase { Solver* solver, TermType type, const std::string& tag = "", - bool optimizations = true); + bool enable_optimizations = true); STerm operator[](const std::string& name); STerm operator[](const uint32_t& idx) { return this->symbolic_vars[this->real_variable_index[idx]]; }; diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/standard_circuit.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/standard_circuit.cpp index 1365618b02c..6d68d93333d 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/standard_circuit.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/standard_circuit.cpp @@ -10,7 +10,7 @@ namespace smt_circuit { * @param tag tag of the circuit. Empty by default. */ StandardCircuit::StandardCircuit( - CircuitSchema& circuit_info, Solver* solver, TermType type, const std::string& tag, bool optimizations) + CircuitSchema& circuit_info, Solver* solver, TermType type, const std::string& tag, bool enable_optimizations) : CircuitBase(circuit_info.vars_of_interest, circuit_info.variables, circuit_info.public_inps, @@ -20,7 +20,7 @@ StandardCircuit::StandardCircuit( solver, type, tag, - optimizations) + enable_optimizations) , selectors(circuit_info.selectors[0]) , wires_idxs(circuit_info.wires[0]) { @@ -46,34 +46,34 @@ StandardCircuit::StandardCircuit( */ size_t StandardCircuit::prepare_gates(size_t cursor) { - if (this->type == TermType::BVTerm && this->optimizations) { + if (this->type == TermType::BVTerm && this->enable_optimizations) { size_t res = handle_logic_constraint(cursor); if (res != static_cast(-1)) { return res; } } - if ((this->type == TermType::BVTerm || this->type == TermType::FFITerm) && this->optimizations) { + if ((this->type == TermType::BVTerm || this->type == TermType::FFITerm) && this->enable_optimizations) { size_t res = handle_range_constraint(cursor); if (res != static_cast(-1)) { return res; } } - if ((this->type == TermType::BVTerm) && this->optimizations) { + if ((this->type == TermType::BVTerm) && this->enable_optimizations) { size_t res = handle_ror_constraint(cursor); if (res != static_cast(-1)) { return res; } } - if ((this->type == TermType::BVTerm) && this->optimizations) { + if ((this->type == TermType::BVTerm) && this->enable_optimizations) { size_t res = handle_shl_constraint(cursor); if (res != static_cast(-1)) { return res; } } - if ((this->type == TermType::BVTerm) && this->optimizations) { + if ((this->type == TermType::BVTerm) && this->enable_optimizations) { size_t res = handle_shr_constraint(cursor); if (res != static_cast(-1)) { return res; @@ -182,6 +182,7 @@ void StandardCircuit::handle_univariate_constraint( } } +// TODO(alex): Optimized out variables should be filled with proper values... /** * @brief Relaxes logic constraints(AND/XOR). * @details This function is needed when we use bitwise compatible @@ -252,6 +253,20 @@ size_t StandardCircuit::handle_logic_constraint(size_t cursor) xor_flag &= xor_circuit.selectors[0][j + xor_props.start_gate] == this->selectors[cursor + j]; and_flag &= and_circuit.selectors[0][j + and_props.start_gate] == this->selectors[cursor + j]; + // Before this fix this routine simplified two consecutive n bit xors(ands) into one 2n bit xor(and) + // Now it checks out_accumulator_idx and new_out_accumulator_idx match + // 14 here is a size of one iteration of logic_gate for loop in term of gates + // 13 is the accumulator index relative to the beginning of the iteration + + size_t single_iteration_size = 14; + size_t relative_acc_idx = 13; + xor_flag &= + (j % single_iteration_size != relative_acc_idx) || (j == relative_acc_idx) || + (this->wires_idxs[j + cursor][0] == this->wires_idxs[j + cursor - single_iteration_size][2]); + and_flag &= + (j % single_iteration_size != relative_acc_index) || (j == relative_acc_index) || + (this->wires_idxs[j + cursor][0] == this->wires_idxs[j + cursor - single_iteration_size][2]); + if (!xor_flag && !and_flag) { // Won't match at any bit length if (j == 0) { @@ -411,6 +426,10 @@ size_t StandardCircuit::handle_range_constraint(size_t cursor) // preserving shifted values // we need this because even right shifts do not create // any additional gates and therefore are undetectible + + // TODO(alex): I think I should simulate the whole subcircuit at that point + // Otherwise optimized out variables are not correct in the final witness + // And I can't fix them by hand each time size_t num_accs = range_props.gate_idxs.size() - 1; for (size_t j = 1; j < num_accs + 1 && (this->type == TermType::BVTerm); j++) { size_t acc_gate = range_props.gate_idxs[j]; @@ -418,10 +437,9 @@ size_t StandardCircuit::handle_range_constraint(size_t cursor) uint32_t acc_idx = this->real_variable_index[this->wires_idxs[cursor + acc_gate][acc_gate_idx]]; - // TODO(alex): Is it better? Can't come up with why not right now - // STerm acc = this->symbolic_vars[acc_idx]; - // acc == (left >> static_cast(2 * j)); - this->symbolic_vars[acc_idx] = (left >> static_cast(2 * j)); + this->symbolic_vars[acc_idx] == (left >> static_cast(2 * j)); + // I think the following is worse. The name of the variable is lost after that + // this->symbolic_vars[acc_idx] = (left >> static_cast(2 * j)); } left <= (bb::fr(2).pow(res) - 1); @@ -812,10 +830,10 @@ std::pair StandardCircuit::unique_witness_ext( const std::vector& not_equal, const std::vector& equal_at_the_same_time, const std::vector& not_equal_at_the_same_time, - bool optimizations) + bool enable_optimizations) { - StandardCircuit c1(circuit_info, s, type, "circuit1", optimizations); - StandardCircuit c2(circuit_info, s, type, "circuit2", optimizations); + StandardCircuit c1(circuit_info, s, type, "circuit1", enable_optimizations); + StandardCircuit c2(circuit_info, s, type, "circuit2", enable_optimizations); for (const auto& term : equal) { c1[term] == c2[term]; @@ -863,11 +881,14 @@ std::pair StandardCircuit::unique_witness_ext( * @param equal The list of names of variables which should be equal in both circuits(each is equal) * @return std::pair */ -std::pair StandardCircuit::unique_witness( - CircuitSchema& circuit_info, Solver* s, TermType type, const std::vector& equal, bool optimizations) +std::pair StandardCircuit::unique_witness(CircuitSchema& circuit_info, + Solver* s, + TermType type, + const std::vector& equal, + bool enable_optimizations) { - StandardCircuit c1(circuit_info, s, type, "circuit1", optimizations); - StandardCircuit c2(circuit_info, s, type, "circuit2", optimizations); + StandardCircuit c1(circuit_info, s, type, "circuit1", enable_optimizations); + StandardCircuit c2(circuit_info, s, type, "circuit2", enable_optimizations); for (const auto& term : equal) { c1[term] == c2[term]; @@ -893,4 +914,4 @@ std::pair StandardCircuit::unique_witness( } return { c1, c2 }; } -}; // namespace smt_circuit \ No newline at end of file +}; // namespace smt_circuit diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/standard_circuit.hpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/standard_circuit.hpp index 03d0a4393f1..967c7ffdad8 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/standard_circuit.hpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/standard_circuit.hpp @@ -18,7 +18,7 @@ class StandardCircuit : public CircuitBase { Solver* solver, TermType type = TermType::FFTerm, const std::string& tag = "", - bool optimizations = true); + bool enable_optimizations = true); inline size_t get_num_gates() const { return selectors.size(); }; @@ -40,12 +40,12 @@ class StandardCircuit : public CircuitBase { const std::vector& not_equal = {}, const std::vector& equal_at_the_same_time = {}, const std::vector& not_equal_at_the_same_time = {}, - bool optimizations = false); + bool enable_optimizations = false); static std::pair unique_witness(CircuitSchema& circuit_info, Solver* s, TermType type, const std::vector& equal = {}, - bool optimizations = false); + bool enable_optimizations = false); }; }; // namespace smt_circuit \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/standard_circuit.test.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/standard_circuit.test.cpp index 4c885ba6cd3..ffded6864dd 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/standard_circuit.test.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/standard_circuit.test.cpp @@ -356,3 +356,23 @@ TEST(standard_circuit, shr_relaxation) StandardCircuit circuit(circuit_info, &s, TermType::BVTerm); } } + +TEST(standard_circuit, check_double_xor_bug) +{ + StandardCircuitBuilder builder; + uint_ct a = witness_t(&builder, 10); + uint_ct b = witness_t(&builder, 10); + + uint_ct c = a ^ b; + uint_ct d = a ^ b; + d = d ^ c; + + c = a & b; + d = a & b; + d = d & c; + + auto buf = builder.export_circuit(); + CircuitSchema circuit_info = unpack_from_buffer(buf); + Solver s(circuit_info.modulus, default_solver_config, 16, 64); + StandardCircuit circuit(circuit_info, &s, TermType::BVTerm); +} diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/ultra_circuit.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/ultra_circuit.cpp index c469ead82fc..01cb3728524 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/ultra_circuit.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/ultra_circuit.cpp @@ -138,6 +138,38 @@ size_t UltraCircuit::handle_arithmetic_relation(size_t cursor, size_t idx) return cursor + 1; } +void UltraCircuit::process_new_table(uint32_t table_idx) +{ + std::vector> new_table; + bool is_xor = true; + bool is_and = true; + + for (auto table_entry : this->lookup_tables[table_idx]) { + std::vector tmp_entry = { + STerm(table_entry[0], this->solver, this->type), + STerm(table_entry[1], this->solver, this->type), + STerm(table_entry[2], this->solver, this->type), + }; + new_table.push_back(tmp_entry); + + is_xor &= (static_cast(table_entry[0]) ^ static_cast(table_entry[1])) == + static_cast(table_entry[2]); + is_and &= (static_cast(table_entry[0]) & static_cast(table_entry[1])) == + static_cast(table_entry[2]); + } + this->cached_symbolic_tables.insert({ table_idx, this->solver->create_lookup_table(new_table) }); + if (is_xor) { + this->tables_types.insert({ table_idx, TableType::XOR }); + info("Encountered a XOR table"); + } else if (is_and) { + this->tables_types.insert({ table_idx, TableType::AND }); + info("Encountered an AND table"); + } else { + this->tables_types.insert({ table_idx, TableType::UNKNOWN }); + info("Encountered an UNKNOWN table"); + } +} + /** * @brief Adds all the lookup gate constraints to the solver. * Relaxes constraint system for non-ff solver engines @@ -175,31 +207,36 @@ size_t UltraCircuit::handle_lookup_relation(size_t cursor, size_t idx) auto table_idx = static_cast(q_o); if (!this->cached_symbolic_tables.contains(table_idx)) { - std::vector> new_table; - for (auto table_entry : this->lookup_tables[table_idx]) { - std::vector tmp_entry = { - STerm(table_entry[0], this->solver, this->type), - STerm(table_entry[1], this->solver, this->type), - STerm(table_entry[2], this->solver, this->type), - }; - new_table.push_back(tmp_entry); - } - this->cached_symbolic_tables.insert({ table_idx, this->solver->create_lookup_table(new_table) }); - } - - // Sort of an optimization. - // However if we don't do this, solver will find a unique witness that corresponds to overflowed value. - if (this->type == TermType::BVTerm && q_r == -64 && q_m == -64 && q_c == -64) { - this->symbolic_vars[w_l_shift_idx] = this->symbolic_vars[w_l_idx] >> 6; - this->symbolic_vars[w_r_shift_idx] = this->symbolic_vars[w_r_idx] >> 6; - this->symbolic_vars[w_o_shift_idx] = this->symbolic_vars[w_o_idx] >> 6; + this->process_new_table(table_idx); } STerm first_entry = this->symbolic_vars[w_l_idx] + q_r * this->symbolic_vars[w_l_shift_idx]; STerm second_entry = this->symbolic_vars[w_r_idx] + q_m * this->symbolic_vars[w_r_shift_idx]; STerm third_entry = this->symbolic_vars[w_o_idx] + q_c * this->symbolic_vars[w_o_shift_idx]; - std::vector entries = { first_entry, second_entry, third_entry }; + + if (this->type == TermType::BVTerm && this->enable_optimizations) { + // Sort of an optimization. + // However if we don't do this, solver will find a unique witness that corresponds to overflowed value. + if (q_r == -64 && q_m == -64 && q_c == -64) { + this->symbolic_vars[w_l_shift_idx] = this->symbolic_vars[w_l_idx] >> 6; + this->symbolic_vars[w_r_shift_idx] = this->symbolic_vars[w_r_idx] >> 6; + this->symbolic_vars[w_o_shift_idx] = this->symbolic_vars[w_o_idx] >> 6; + } + switch (this->tables_types[table_idx]) { + case TableType::XOR: + info("XOR optimization"); + (first_entry ^ second_entry) == third_entry; + return cursor + 1; + case TableType::AND: + info("AND optimization"); + (first_entry & second_entry) == third_entry; + return cursor + 1; + case TableType::UNKNOWN: + break; + } + } + info("Unknown Table"); STerm::in_table(entries, this->cached_symbolic_tables[table_idx]); return cursor + 1; } @@ -341,7 +378,7 @@ void UltraCircuit::handle_range_constraints() uint32_t tag = this->real_variable_tags[this->real_variable_index[i]]; if (tag != 0 && this->range_tags.contains(tag)) { uint64_t range = this->range_tags[tag]; - if (this->type == TermType::FFTerm || !this->optimizations) { + if (this->type == TermType::FFTerm || !this->enable_optimizations) { if (!this->cached_range_tables.contains(range)) { std::vector new_range_table; for (size_t entry = 0; entry < range; entry++) { @@ -401,10 +438,10 @@ std::pair UltraCircuit::unique_witness_ext( const std::vector& not_equal, const std::vector& equal_at_the_same_time, const std::vector& not_equal_at_the_same_time, - bool optimizations) + bool enable_optimizations) { - UltraCircuit c1(circuit_info, s, type, "circuit1", optimizations); - UltraCircuit c2(circuit_info, s, type, "circuit2", optimizations); + UltraCircuit c1(circuit_info, s, type, "circuit1", enable_optimizations); + UltraCircuit c2(circuit_info, s, type, "circuit2", enable_optimizations); for (const auto& term : equal) { c1[term] == c2[term]; @@ -452,11 +489,14 @@ std::pair UltraCircuit::unique_witness_ext( * @param equal The list of names of variables which should be equal in both circuits(each is equal) * @return std::pair */ -std::pair UltraCircuit::unique_witness( - CircuitSchema& circuit_info, Solver* s, TermType type, const std::vector& equal, bool optimizations) +std::pair UltraCircuit::unique_witness(CircuitSchema& circuit_info, + Solver* s, + TermType type, + const std::vector& equal, + bool enable_optimizations) { - UltraCircuit c1(circuit_info, s, type, "circuit1", optimizations); - UltraCircuit c2(circuit_info, s, type, "circuit2", optimizations); + UltraCircuit c1(circuit_info, s, type, "circuit1", enable_optimizations); + UltraCircuit c2(circuit_info, s, type, "circuit2", enable_optimizations); for (const auto& term : equal) { c1[term] == c2[term]; diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/ultra_circuit.hpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/ultra_circuit.hpp index c6ec07a849e..5fd3be1fd4d 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/ultra_circuit.hpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/ultra_circuit.hpp @@ -3,6 +3,8 @@ namespace smt_circuit { +enum class TableType : int32_t { XOR, AND, UNKNOWN }; + /** * @brief Symbolic Circuit class for Standard Circuit Builder. * @@ -22,13 +24,14 @@ class UltraCircuit : public CircuitBase { std::vector>> lookup_tables; std::unordered_map cached_symbolic_tables; + std::unordered_map tables_types; std::unordered_map cached_range_tables; explicit UltraCircuit(CircuitSchema& circuit_info, Solver* solver, TermType type = TermType::FFTerm, const std::string& tag = "", - bool optimizations = true); + bool enable_optimizations = true); UltraCircuit(const UltraCircuit& other) = default; UltraCircuit(UltraCircuit&& other) = default; UltraCircuit& operator=(const UltraCircuit& other) = default; @@ -42,6 +45,7 @@ class UltraCircuit : public CircuitBase { }; bool simulate_circuit_eval(std::vector& witness) const override; + void process_new_table(uint32_t table_idx); size_t handle_arithmetic_relation(size_t cursor, size_t idx); size_t handle_lookup_relation(size_t cursor, size_t idx); @@ -57,11 +61,11 @@ class UltraCircuit : public CircuitBase { const std::vector& not_equal = {}, const std::vector& equal_at_the_same_time = {}, const std::vector& not_equal_at_the_same_time = {}, - bool optimizations = false); + bool enable_optimizations = false); static std::pair unique_witness(CircuitSchema& circuit_info, Solver* s, TermType type, const std::vector& equal = {}, - bool optimizations = false); + bool enable_optimizations = false); }; }; // namespace smt_circuit \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/ultra_circuit.test.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/ultra_circuit.test.cpp index bdff9b10774..3c7912f116c 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/ultra_circuit.test.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/ultra_circuit.test.cpp @@ -239,4 +239,66 @@ TEST(ultra_circuit, lookup_tables) std::string c_solver_val = s.getValue(cir["c"]).getBitVectorValue(); std::string c_builder_val = STerm(c.get_value(), &s, TermType::BVTerm).term.getBitVectorValue(); ASSERT_EQ(c_solver_val, c_builder_val); -} \ No newline at end of file +} + +TEST(ultra_circuit, xor_optimization) +{ + UltraCircuitBuilder builder; + uint_t a(witness_t(&builder, static_cast(bb::fr::random_element()))); + builder.set_variable_name(a.get_witness_index(), "a"); + uint_t b(witness_t(&builder, static_cast(bb::fr::random_element()))); + builder.set_variable_name(b.get_witness_index(), "b"); + uint_t c = a ^ b; + builder.set_variable_name(c.get_witness_index(), "c"); + + CircuitSchema circuit_info = unpack_from_buffer(builder.export_circuit()); + uint32_t modulus_base = 16; + uint32_t bvsize = 35; + Solver s(circuit_info.modulus, ultra_solver_config, modulus_base, bvsize); + + UltraCircuit circuit(circuit_info, &s, TermType::BVTerm); + + circuit["a"] == a.get_value(); + circuit["b"] == b.get_value(); + + s.print_assertions(); + + bool res = smt_timer(&s); + ASSERT_TRUE(res); + std::vector to_model = { circuit["c"] }; + std::unordered_map model = s.model(to_model); + + bb::fr c_sym = string_to_fr(model["c"], 2); + ASSERT_EQ(c_sym, c.get_value()); +} + +TEST(ultra_circuit, and_optimization) +{ + UltraCircuitBuilder builder; + uint_t a(witness_t(&builder, static_cast(bb::fr::random_element()))); + builder.set_variable_name(a.get_witness_index(), "a"); + uint_t b(witness_t(&builder, static_cast(bb::fr::random_element()))); + builder.set_variable_name(b.get_witness_index(), "b"); + uint_t c = a & b; + builder.set_variable_name(c.get_witness_index(), "c"); + + CircuitSchema circuit_info = unpack_from_buffer(builder.export_circuit()); + uint32_t modulus_base = 16; + uint32_t bvsize = 35; + Solver s(circuit_info.modulus, ultra_solver_config, modulus_base, bvsize); + + UltraCircuit circuit(circuit_info, &s, TermType::BVTerm); + + circuit["a"] == a.get_value(); + circuit["b"] == b.get_value(); + + s.print_assertions(); + + bool res = smt_timer(&s); + ASSERT_TRUE(res); + std::vector to_model = { circuit["c"] }; + std::unordered_map model = s.model(to_model); + + bb::fr c_sym = string_to_fr(model["c"], 2); + ASSERT_EQ(c_sym, c.get_value()); +} diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/solver/solver.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/solver/solver.cpp index acd855ad69d..2219a05b0d4 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/solver/solver.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/solver/solver.cpp @@ -48,7 +48,7 @@ std::unordered_map Solver::model(std::unordered_map Solver::model(std::vectortype == TermType::FFTerm || this->type == TermType::FFITerm) { other != bb::fr(0); + // Random value added to the name to prevent collisions. This value is MD5('Aztec') STerm res = Var("df8b586e3fa7a1224ec95a886e17a7da_div_" + static_cast(*this) + "_" + static_cast(other), this->solver, @@ -176,6 +177,7 @@ void STerm::operator/=(const STerm& other) } if (this->type == TermType::FFTerm || this->type == TermType::FFITerm) { other != bb::fr(0); + // Random value added to the name to prevent collisions. This value is MD5('Aztec') STerm res = Var("df8b586e3fa7a1224ec95a886e17a7da_div_" + static_cast(*this) + "_" + static_cast(other), this->solver, diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.cpp index 91a59608f5f..2a2ec75c54b 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.cpp @@ -1,5 +1,41 @@ #include "smt_util.hpp" +/** + * @brief Converts a string of an arbitrary base to fr. + * Note: there should be no prefix + * + * @param number string to be converted + * @param base base representation of the string + * @param step power n such that base^n <= 2^64. If base = 2, 10, 16. May remain undeclared. + * @return bb::fr + */ +bb::fr string_to_fr(const std::string& number, int base, size_t step) +{ + bb::fr res = 0; + char* ptr = nullptr; + if (base == 2) { + step = 64; + } else if (base == 16) { + step = 4; + } else if (base == 10) { + step = 19; + } else if (step == 0) { + info("Step should be non zero"); + return 0; + } + + size_t i = number[0] == '-' ? 1 : 0; + bb::fr step_power = bb::fr(base).pow(step); + for (; i < number.size(); i += step) { + std::string slice = number.substr(i, step); + bb::fr cur_power = i + step > number.size() ? bb::fr(base).pow(number.size() - i) : step_power; + res *= cur_power; + res += std::strtoull(slice.data(), &ptr, base); + } + res = number[0] == '-' ? -res : res; + return res; +} + /** * @brief Get pretty formatted result of the solver work * @@ -13,11 +49,13 @@ * @param c2 the copy of the first circuit with changed tag * @param s solver * @param fname file to store the resulting witness if succeded + * @param pack flags out to pack the resulting witness using msgpack */ void default_model(const std::vector& special, smt_circuit::CircuitBase& c1, smt_circuit::CircuitBase& c2, - const std::string& fname) + const std::string& fname, + bool pack) { std::vector vterms1; std::vector vterms2; @@ -35,38 +73,42 @@ void default_model(const std::vector& special, std::fstream myfile; myfile.open(fname, std::ios::out | std::ios::trunc | std::ios::binary); myfile << "w12 = {" << std::endl; + + std::vector> packed_witness; + packed_witness.reserve(c1.get_num_vars()); + int base = c1.type == smt_terms::TermType::BVTerm ? 2 : 10; + for (uint32_t i = 0; i < c1.get_num_vars(); i++) { std::string vname1 = vterms1[i].toString(); std::string vname2 = vterms2[i].toString(); - if (c1.real_variable_index[i] == i) { - myfile << "{" << mmap1[vname1] << ", " << mmap2[vname2] << "}"; - myfile << ", // " << vname1 << ", " << vname2 << std::endl; - if (mmap1[vname1] != mmap2[vname2]) { - info(RED, "{", mmap1[vname1], ", ", mmap2[vname2], "}", ", // ", vname1, ", ", vname2, RESET); - } - } else { - myfile << "{" << mmap1[vname1] << ", " + mmap2[vname2] << "}"; - myfile << ", // " << vname1 << " ," << vname2 << " -> " << c1.real_variable_index[i] << std::endl; - if (mmap1[vname1] != mmap2[vname2]) { - info(RED, - "{", - mmap1[vname1], - ", ", - mmap2[vname2], - "}", - ", // ", - vname1, - ", ", - vname2, - " -> ", - c1.real_variable_index[i], - RESET); - } + std::string new_line = "{" + mmap1[vname1] + ", " + mmap2[vname2] + "}, // " + vname1 + ", " + vname2; + + if (c1.real_variable_index[i] != i) { + new_line += " -> " + std::to_string(c1.real_variable_index[i]); + } + + if (mmap1[vname1] != mmap2[vname2]) { + info(RED, new_line, RESET); } + myfile << new_line << std::endl; + ; + + packed_witness.push_back({ string_to_fr(mmap1[vname1], base), string_to_fr(mmap2[vname2], base) }); } myfile << "};"; myfile.close(); + if (pack) { + msgpack::sbuffer buffer; + msgpack::pack(buffer, packed_witness); + + std::fstream myfile; + myfile.open(fname + ".pack", std::ios::out | std::ios::trunc | std::ios::binary); + + myfile.write(buffer.data(), static_cast(buffer.size())); + myfile.close(); + } + std::unordered_map vterms; for (const auto& vname : special) { vterms.insert({ vname + "_1", c1[vname] }); @@ -91,10 +133,12 @@ void default_model(const std::vector& special, * @param c first circuit * @param s solver * @param fname file to store the resulting witness if succeded + * @param pack flags out to pack the resulting witness using msgpack */ void default_model_single(const std::vector& special, smt_circuit::CircuitBase& c, - const std::string& fname) + const std::string& fname, + bool pack) { std::vector vterms; vterms.reserve(c.get_num_vars()); @@ -108,17 +152,34 @@ void default_model_single(const std::vector& special, std::fstream myfile; myfile.open(fname, std::ios::out | std::ios::trunc | std::ios::binary); myfile << "w = {" << std::endl; + + std::vector packed_witness; + packed_witness.reserve(c.get_num_vars()); + int base = c.type == smt_terms::TermType::BVTerm ? 2 : 10; + for (size_t i = 0; i < c.get_num_vars(); i++) { std::string vname = vterms[i].toString(); - if (c.real_variable_index[i] == i) { - myfile << mmap[vname] << ", // " << vname << std::endl; - } else { - myfile << mmap[vname] << ", // " << vname << " -> " << c.real_variable_index[i] << std::endl; + std::string new_line = mmap[vname] + ", // " + vname; + if (c.real_variable_index[i] != i) { + new_line += " -> " + std::to_string(c.real_variable_index[i]); } + myfile << new_line << std::endl; + packed_witness.push_back(string_to_fr(mmap[vname], base)); } myfile << "};"; myfile.close(); + if (pack) { + msgpack::sbuffer buffer; + msgpack::pack(buffer, packed_witness); + + std::fstream myfile; + myfile.open(fname + ".pack", std::ios::out | std::ios::trunc | std::ios::binary); + + myfile.write(buffer.data(), static_cast(buffer.size())); + myfile.close(); + } + std::unordered_map vterms1; for (const auto& vname : special) { vterms1.insert({ vname, c[vname] }); @@ -130,6 +191,62 @@ void default_model_single(const std::vector& special, } } +/** + * @brief Import witness, obtained by solver, from file. + * @details Imports the witness, that was packed by default_model function + * + * @param fname + * @return std::vector> + */ +std::vector> import_witness(const std::string& fname) +{ + std::ifstream fin; + fin.open(fname, std::ios::ate | std::ios::binary); + if (!fin.is_open()) { + throw std::invalid_argument("file not found"); + } + if (fin.tellg() == -1) { + throw std::invalid_argument("something went wrong"); + } + + uint64_t fsize = static_cast(fin.tellg()); + fin.seekg(0, std::ios_base::beg); + + std::vector> res; + char* encoded_data = new char[fsize]; + fin.read(encoded_data, static_cast(fsize)); + msgpack::unpack(encoded_data, fsize).get().convert(res); + return res; +} + +/** + * @brief Import witness, obtained by solver, from file. + * @details Imports the witness, that was packed by default_model_single function + * + * @param fname + * @return std::vector> + */ +std::vector import_witness_single(const std::string& fname) +{ + std::ifstream fin; + fin.open(fname, std::ios::ate | std::ios::binary); + if (!fin.is_open()) { + throw std::invalid_argument("file not found"); + } + if (fin.tellg() == -1) { + throw std::invalid_argument("something went wrong"); + } + + uint64_t fsize = static_cast(fin.tellg()); + fin.seekg(0, std::ios_base::beg); + + std::vector res; + char* encoded_data = new char[fsize]; + fin.read(encoded_data, static_cast(fsize)); + msgpack::unpack(encoded_data, fsize).get().convert(res); + return res; +} + /** * @brief Get the solver result and amount of time * that it took to solve. diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.hpp b/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.hpp index cf0521b88f0..dcf08418028 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.hpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.hpp @@ -9,11 +9,16 @@ void default_model(const std::vector& special, smt_circuit::CircuitBase& c1, smt_circuit::CircuitBase& c2, - const std::string& fname = "witness.out"); + const std::string& fname = "witness.out", + bool pack = true); void default_model_single(const std::vector& special, smt_circuit::CircuitBase& c, - const std::string& fname = "witness.out"); + const std::string& fname = "witness.out", + bool pack = true); bool smt_timer(smt_solver::Solver* s); std::pair, std::vector> base4(uint32_t el); -void fix_range_lists(bb::UltraCircuitBuilder& builder); \ No newline at end of file +void fix_range_lists(bb::UltraCircuitBuilder& builder); +bb::fr string_to_fr(const std::string& number, int base, size_t step = 0); +std::vector> import_witness(const std::string& fname); +std::vector import_witness_single(const std::string& fname); \ No newline at end of file