Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: StandardCircuitBuilder create_logic_constraint and uint logic_operator #4530

Merged
merged 6 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,15 @@ std::vector<uint32_t> StandardCircuitBuilder_<FF>::decompose_into_base4_accumula
return accumulators;
}

/**
* @brief Create an AND or an XOR constraint
*
* @param a The first argument variable index
* @param b The second argument variable index
* @param num_bits The width of arguments. Has to be even
* @param is_xor_gate If true, create an xor gate, otherwise an and gate
* @return accumulator_triple_<FF> Accumulated witnesses (steps) for input arguments and output
*/
template <typename FF>
accumulator_triple_<FF> StandardCircuitBuilder_<FF>::create_logic_constraint(const uint32_t a,
const uint32_t b,
Expand All @@ -311,9 +320,14 @@ accumulator_triple_<FF> StandardCircuitBuilder_<FF>::create_logic_constraint(con

accumulator_triple_<FF> accumulators;

// Get the values of inputs
const uint256_t left_witness_value(this->get_variable(a));
const uint256_t right_witness_value(this->get_variable(b));

ASSERT(left_witness_value < (uint256_t(1) << num_bits));
ASSERT(right_witness_value < (uint256_t(1) << num_bits));

// We are starting accumulation with zeros
FF left_accumulator = FF::zero();
FF right_accumulator = FF::zero();
FF out_accumulator = FF::zero();
Expand All @@ -323,23 +337,34 @@ accumulator_triple_<FF> StandardCircuitBuilder_<FF>::create_logic_constraint(con
uint32_t out_accumulator_idx = this->zero_idx;
constexpr FF four = FF(4);
constexpr FF neg_two = -FF(2);
constexpr FF two = FF(2);

// Num bits is expected to be even
ASSERT(num_bits % 2 == 0);

// Accumulate the bits in quads starting from the high ones
for (size_t i = num_bits - 1; i < num_bits; i -= 2) {
// Get bit values of arguments
bool left_hi_val = left_witness_value.get_bit(i);
bool left_lo_val = left_witness_value.get_bit(i - 1);
bool right_hi_val = right_witness_value.get_bit((i));
bool right_lo_val = right_witness_value.get_bit(i - 1);

// Convert to wintesses
uint32_t left_hi_idx = this->add_variable(left_hi_val ? FF::one() : FF::zero());
uint32_t left_lo_idx = this->add_variable(left_lo_val ? FF::one() : FF::zero());
uint32_t right_hi_idx = this->add_variable(right_hi_val ? FF::one() : FF::zero());
uint32_t right_lo_idx = this->add_variable(right_lo_val ? FF::one() : FF::zero());

// Compute resulting bits
bool out_hi_val = is_xor_gate ? left_hi_val ^ right_hi_val : left_hi_val & right_hi_val;
bool out_lo_val = is_xor_gate ? left_lo_val ^ right_lo_val : left_lo_val & right_lo_val;

// Convert to witnesses
uint32_t out_hi_idx = this->add_variable(out_hi_val ? FF::one() : FF::zero());
uint32_t out_lo_idx = this->add_variable(out_lo_val ? FF::one() : FF::zero());

// Constrain all individual bit witnesses to be boolean
create_bool_gate(left_hi_idx);
create_bool_gate(right_hi_idx);
create_bool_gate(out_hi_idx);
Expand All @@ -348,6 +373,7 @@ accumulator_triple_<FF> StandardCircuitBuilder_<FF>::create_logic_constraint(con
create_bool_gate(right_lo_idx);
create_bool_gate(out_lo_idx);

// Create 2 individual xor or and gates
// a & b = ab
// a ^ b = a + b - 2ab
create_poly_gate({ left_hi_idx,
Expand All @@ -368,21 +394,36 @@ accumulator_triple_<FF> StandardCircuitBuilder_<FF>::create_logic_constraint(con
FF::neg_one(),
FF::zero() });

// Reconstruct the value of the left quad and add as witness
FF left_quad =
this->get_variable(left_lo_idx) + this->get_variable(left_hi_idx) + this->get_variable(left_hi_idx);
uint32_t left_quad_idx = this->add_variable(left_quad);

// Connect the left bits to the left quad
create_add_gate({ left_hi_idx, left_lo_idx, left_quad_idx, two, FF::one(), FF::neg_one(), FF::zero() });

// Reconstruct the value of the right quad and add as witness
FF right_quad =
this->get_variable(right_lo_idx) + this->get_variable(right_hi_idx) + this->get_variable(right_hi_idx);
FF out_quad = this->get_variable(out_lo_idx) + this->get_variable(out_hi_idx) + this->get_variable(out_hi_idx);

uint32_t left_quad_idx = this->add_variable(left_quad);
uint32_t right_quad_idx = this->add_variable(right_quad);

// Connect the left bits to the left quad
create_add_gate({ right_hi_idx, right_lo_idx, right_quad_idx, two, FF::one(), FF::neg_one(), FF::zero() });

// Reconstruct the value of the output quad and add as witness
FF out_quad = this->get_variable(out_lo_idx) + this->get_variable(out_hi_idx) + this->get_variable(out_hi_idx);
uint32_t out_quad_idx = this->add_variable(out_quad);

// Connect the out bits to the out quad
create_add_gate({ out_hi_idx, out_lo_idx, out_quad_idx, two, FF::one(), FF::neg_one(), FF::zero() });

// Compute the value of the left accumulator and add as witness
FF new_left_accumulator = left_accumulator + left_accumulator;
new_left_accumulator = new_left_accumulator + new_left_accumulator;
new_left_accumulator = new_left_accumulator + left_quad;
uint32_t new_left_accumulator_idx = this->add_variable(new_left_accumulator);

// Connect the left quad, previous accumulator and current accumulator
create_add_gate({ left_accumulator_idx,
left_quad_idx,
new_left_accumulator_idx,
Expand All @@ -391,11 +432,13 @@ accumulator_triple_<FF> StandardCircuitBuilder_<FF>::create_logic_constraint(con
FF::neg_one(),
FF::zero() });

// Compute the value of the right accumulator and add as witness
FF new_right_accumulator = right_accumulator + right_accumulator;
new_right_accumulator = new_right_accumulator + new_right_accumulator;
new_right_accumulator = new_right_accumulator + right_quad;
uint32_t new_right_accumulator_idx = this->add_variable(new_right_accumulator);

// Connect the right quad, previous accumulator and current accumulator
create_add_gate({ right_accumulator_idx,
right_quad_idx,
new_right_accumulator_idx,
Expand All @@ -404,18 +447,21 @@ accumulator_triple_<FF> StandardCircuitBuilder_<FF>::create_logic_constraint(con
FF::neg_one(),
FF::zero() });

// Compute the value of the out accumulator and add as witness
FF new_out_accumulator = out_accumulator + out_accumulator;
new_out_accumulator = new_out_accumulator + new_out_accumulator;
new_out_accumulator = new_out_accumulator + out_quad;
uint32_t new_out_accumulator_idx = this->add_variable(new_out_accumulator);

// Connect the out quad, previous accumulator and current accumulator
create_add_gate(
{ out_accumulator_idx, out_quad_idx, new_out_accumulator_idx, four, FF::one(), FF::neg_one(), FF::zero() });

accumulators.left.emplace_back(new_left_accumulator_idx);
accumulators.right.emplace_back(new_right_accumulator_idx);
accumulators.out.emplace_back(new_out_accumulator_idx);

// Update current accumulators
left_accumulator = new_left_accumulator;
left_accumulator_idx = new_left_accumulator_idx;

Expand All @@ -425,6 +471,9 @@ accumulator_triple_<FF> StandardCircuitBuilder_<FF>::create_logic_constraint(con
out_accumulator = new_out_accumulator;
out_accumulator_idx = new_out_accumulator_idx;
}
// Connect the accumulators to inputs
this->assert_equal(accumulators.left.back(), a);
this->assert_equal(accumulators.right.back(), b);
return accumulators;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ uint<Builder, Native> uint<Builder, Native>::logic_operator(const uint& other, c
const uint256_t rhs = other.get_value();
uint256_t out = 0;

// Compute the value of the result
switch (op_type) {
case AND: {
out = lhs & rhs;
Expand All @@ -473,11 +474,14 @@ uint<Builder, Native> uint<Builder, Native>::logic_operator(const uint& other, c
}
}

// If both inputs are constants, just output a new constant uint with the result
if (is_constant() && other.is_constant()) {
// returns a constant uint.
return uint<Builder, Native>(ctx, out);
}

// If one of the inputs is a constant, we need to create a witness from it, because we can only perform logical
// constraints between witnesses
const uint32_t lhs_idx = is_constant() ? ctx->add_variable(lhs) : witness_index;
const uint32_t rhs_idx = other.is_constant() ? ctx->add_variable(rhs) : other.witness_index;

Expand Down
Loading