Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

variable-length keccak #441

Merged
merged 15 commits into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 166 additions & 28 deletions cpp/src/barretenberg/stdlib/hash/keccak/keccak.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -487,13 +487,16 @@ template <typename Composer> void keccak<Composer>::keccakf1600(keccak_state& in
template <typename Composer>
void keccak<Composer>::sponge_absorb(keccak_state& internal,
const std::vector<field_ct>& input_buffer,
const std::vector<field_ct>& msb_buffer)
const std::vector<field_ct>& msb_buffer,
const field_ct& num_blocks_with_data)
{
const size_t l = input_buffer.size();

const size_t num_blocks = l / (BLOCK_SIZE / 8);

for (size_t i = 0; i < num_blocks; ++i) {
// create a copy of our keccak state in case we need to revert this hash block application
keccak_state previous = internal;
if (i == 0) {
for (size_t j = 0; j < LIMBS_PER_BLOCK; ++j) {
internal.state[j] = input_buffer[j];
Expand All @@ -506,13 +509,27 @@ void keccak<Composer>::sponge_absorb(keccak_state& internal,
} else {
for (size_t j = 0; j < LIMBS_PER_BLOCK; ++j) {
internal.state[j] += input_buffer[i * LIMBS_PER_BLOCK + j];

internal.state[j] = normalize_and_rotate<0>(internal.state[j], internal.state_msb[j]);
}
}

compute_twisted_state(internal);
keccakf1600(internal);

// if `i >= num_blocks_with_data` then we want to revert the effects of this block and set `internal_state` to
// equal `previous`.
// This can happen for circuits where the input hash size is not known at circuit-compile time (only the maximum
// hash size).
// For example, a circuit that hashes up to 544 bytes (but maybe less depending on the witness assignment)
bool_ct block_predicate = field_ct(i).template ranged_less_than<8>(num_blocks_with_data);

for (size_t j = 0; j < 25; ++j) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this 25? This might be confusing to anyone updating this code. It's better to introduce a constant

internal.state[j] = field_ct::conditional_assign(block_predicate, internal.state[j], previous.state[j]);
internal.state_msb[j] =
field_ct::conditional_assign(block_predicate, internal.state_msb[j], previous.state_msb[j]);
internal.twisted_state[j] =
field_ct::conditional_assign(block_predicate, internal.twisted_state[j], previous.twisted_state[j]);
}
}
}

Expand All @@ -538,37 +555,53 @@ template <typename Composer> byte_array<Composer> keccak<Composer>::sponge_squee
return result;
}

template <typename Composer> stdlib::byte_array<Composer> keccak<Composer>::hash(byte_array_ct& input)
/**
* @brief Convert the input buffer into 8-bit keccak lanes in little-endian form.
* Additionally, insert padding bytes if required,
* and add the keccak terminating bytes 0x1/0x80
* (0x1 inserted after the final byte of input data)
* (0x80 inserted at the end of the final block)
*
* @tparam Composer
* @param input
* @param num_bytes
* @return std::vector<field_t<Composer>>
*/
template <typename Composer>
std::vector<field_t<Composer>> keccak<Composer>::format_input_lanes(byte_array_ct& input, const uint32_ct& num_bytes)
{
auto ctx = input.get_context();

if (ctx == nullptr) {
// if buffer is constant compute hash and return w/o creating constraints
byte_array_ct output(nullptr, 32);
const std::vector<uint8_t> result = hash_native(input.get_value());
for (size_t i = 0; i < 32; ++i) {
output.set_byte(i, result[i]);
}
return output;
}
auto* ctx = input.get_context();

// We require that `num_bytes` does not exceed the size of our input byte array.
// (can be less if the hash size is not known at circuit-compile time, only the maximum)
ASSERT(input.size() >= static_cast<size_t>(num_bytes.get_value()));
field_ct(num_bytes > uint32_ct(static_cast<uint32_t>(input.size()))).assert_equal(0);
const size_t input_size = input.size();

// copy input into buffer and pad
const size_t blocks = input_size / BLOCK_SIZE;
const size_t blocks_length = (BLOCK_SIZE * (blocks + 1));
// max_blocks_length = maximum number of bytes to hash
const size_t max_blocks = (input_size + BLOCK_SIZE) / BLOCK_SIZE;
const size_t max_blocks_length = (BLOCK_SIZE * (max_blocks));

byte_array_ct block_bytes(input);

const size_t byte_difference = blocks_length - input_size;
const size_t byte_difference = max_blocks_length - input_size;
byte_array_ct padding_bytes(ctx, byte_difference);
for (size_t i = 0; i < byte_difference; ++i) {
padding_bytes.set_byte(i, witness_ct::create_constant_witness(ctx, 0));
}

block_bytes.write(padding_bytes);
block_bytes.set_byte(input_size, witness_ct::create_constant_witness(ctx, 0x1));
block_bytes.set_byte(block_bytes.size() - 1, witness_ct::create_constant_witness(ctx, 0x80));

uint32_ct num_real_blocks = (num_bytes + BLOCK_SIZE) / BLOCK_SIZE;
uint32_ct num_real_blocks_bytes = num_real_blocks * BLOCK_SIZE;

// Keccak requires that 0x1 is appended after the final byte of input data.
// Similarly, the final byte of the final padded block must be 0x80.
// If `num_bytes` is constant then we know where to write these values at circuit-compile time
if (num_bytes.is_constant()) {
const auto terminating_byte = static_cast<size_t>(num_bytes.get_value());
const auto terminating_block_byte = static_cast<size_t>(num_real_blocks_bytes.get_value()) - 1;
block_bytes.set_byte(terminating_byte, witness_ct::create_constant_witness(ctx, 0x1));
block_bytes.set_byte(terminating_block_byte, witness_ct::create_constant_witness(ctx, 0x80));
}

// keccak lanes interpret memory as little-endian integers,
// means we need to swap our byte ordering...
Expand All @@ -587,13 +620,11 @@ template <typename Composer> stdlib::byte_array<Composer> keccak<Composer>::hash
block_bytes.set_byte(i + 7, temp[0]);
}
const size_t byte_size = block_bytes.size();
keccak_state internal;
internal.context = ctx;

const size_t num_limbs = byte_size / WORD_SIZE;
std::vector<field_ct> converted_buffer(num_limbs);
std::vector<field_ct> msb_buffer(num_limbs);
std::vector<field_ct> sliced_buffer;

// populate a vector of 64-bit limbs from our byte array
for (size_t i = 0; i < num_limbs; ++i) {
field_ct sliced;
if (i * WORD_SIZE + WORD_SIZE > byte_size) {
Expand All @@ -604,12 +635,119 @@ template <typename Composer> stdlib::byte_array<Composer> keccak<Composer>::hash
} else {
sliced = field_ct(block_bytes.slice(i * WORD_SIZE, WORD_SIZE));
}
const auto accumulators = plookup_read::get_lookup_accumulators(KECCAK_FORMAT_INPUT, sliced);
sliced_buffer.emplace_back(sliced);
}

// If the input preimage size is known at circuit-compile time, nothing more to do.
if (num_bytes.is_constant()) {
return sliced_buffer;
}

// If we do *not* know the preimage size at circuit-compile time, we have several steps we must execute:
// 1. Validate that `input[num_bytes], input[num_bytes + 1], ..., input[input.size() - 1]` are all ZERO.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to do this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the keccak algorithm takes the input message and concatenates with the byte 0x1. This string is then zero-padded to fill a multiple of the keccak block length. The final byte of this padded string is set to 0x80.

We have a need to compute keccak hashes in a circuit where the length of the preimage is not known, only the maximum length. This happens when computing keccak for Ethereum storage proofs (the hash preimage is the RLP-encoded node data. The RLP encoding does not produce a fixed length)

When the input is variable-length (up to a max), we compute a zero-padded byte array and dynamically insert the two formatting bytes depending on the witness value of the length parameter

// 2. Insert the keccak input terminating byte `0x1` at `input[num_bytes]`
// 3. Insert the keccak block terminating byte `0x80` at `input[num_real_block_bytes - 1]`
// We do these steps after we have converted into 64 bit lanes as we have fewer elements to iterate over (is
// cheaper)
std::vector<field_ct> lanes = sliced_buffer;

// compute the lane index of the terminating input byte
field_ct num_bytes_as_field(num_bytes);
field_ct terminating_index = field_ct(uint32_ct((num_bytes) / WORD_SIZE));

// compute the value we must add to limbs[terminating_index] to insert 0x1 at the correct byte index (accounting for
// the previous little-endian conversion)
field_ct terminating_index_bytes_shift = (num_bytes_as_field) - (terminating_index * WORD_SIZE);
field_ct terminating_index_limb_addition = field_ct(256).pow(terminating_index_bytes_shift);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pow is quite inefficient here, I think. Usually we expect 32-bit index. Here it will obviously be less. Maybe we should parametrize pow?


// compute the lane index of the terminating block byte
field_ct terminating_block_index = field_ct((num_real_blocks_bytes - 1) / WORD_SIZE);
field_ct terminating_block_bytes_shift =
field_ct(num_real_blocks_bytes - 1) - (terminating_block_index * WORD_SIZE);
// compute the value we must add to limbs[terminating_index] to insert 0x1 at the correct byte index (accounting for
// the previous little-endian conversion)
field_ct terminating_block_limb_addition = field_ct(0x80ULL) * field_ct(256).pow(terminating_block_bytes_shift);

// validate the number of lanes is less than the default plookup size (we use the default size to do a cheap `<`
// check later on. Should be fine as this translates to ~2MB of input data)
ASSERT(uint256_t(sliced_buffer.size()) < (uint256_t(1ULL) << Composer::DEFAULT_PLOOKUP_RANGE_BITNUM));

// If the terminating input byte index matches the terminating block byte index, we set the byte to 0x80.
// If we trigger this case, set `terminating_index_limb_addition` to 0 so that we do not write `0x01 + 0x80`
terminating_index_limb_addition = field_ct::conditional_assign(
field_ct(num_bytes) == field_ct(num_real_blocks_bytes) - 1, 0, terminating_index_limb_addition);
field_ct terminating_limb;

// iterate over our lanes to perform the above listed checks
for (size_t i = 0; i < sliced_buffer.size(); ++i) {
// If i > terminating_index, limb must be 0
bool_ct limb_must_be_zeroes =
terminating_index.template ranged_less_than<Composer::DEFAULT_PLOOKUP_RANGE_BITNUM>(field_ct(i));
// Is i == terminating_limb_index?
bool_ct is_terminating_limb = terminating_index == field_ct(i);

// Is i == terminating_block_limb?
bool_ct is_terminating_block_limb = terminating_block_index == field_ct(i);

(lanes[i] * limb_must_be_zeroes).assert_equal(0);

// If i == terminating_limb_index, *some* of the limb must be zero.
// Assign to `terminating_limb` that we will check later.
terminating_limb = lanes[i].madd(is_terminating_limb, terminating_limb);

// conditionally insert terminating_index_limb_addition and/or terminating_block_limb_addition into limb
// (addition is as good as "insertion" as we check the original byte value at this position is 0)
lanes[i] = terminating_index_limb_addition.madd(is_terminating_limb, lanes[i]);
lanes[i] = terminating_block_limb_addition.madd(is_terminating_block_limb, lanes[i]);
}

// check terminating_limb has correct number of zeroes
{
// we know terminating_limb < 2^64
// offset of first zero byte = (num_bytes % 8)
// i.e. in our 8-byte limb, bytes[(8 - offset), ..., 7] are zeroes in little-endian form
// i.e. we multiply the limb by the above, the result should still be < 2^64 (but only if excess bytes are 0)
field_ct limb_shift = field_ct(256).pow(field_ct(8) - terminating_index_bytes_shift);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we expect the maximum power to be 8, which is an accumulation of 3 squares maximum. We will be doing 32 instead.

field_ct to_constrain = terminating_limb * limb_shift;
to_constrain.create_range_constraint(WORD_SIZE * 8);
}
return lanes;
}

template <typename Composer>
stdlib::byte_array<Composer> keccak<Composer>::hash(byte_array_ct& input, const uint32_ct& num_bytes)
{
auto ctx = input.get_context();

ASSERT(uint256_t(num_bytes.get_value()) <= input.size());

if (ctx == nullptr) {
// if buffer is constant compute hash and return w/o creating constraints
byte_array_ct output(nullptr, 32);
const std::vector<uint8_t> result = hash_native(input.get_value());
for (size_t i = 0; i < 32; ++i) {
output.set_byte(i, result[i]);
}
return output;
}

// convert the input byte array into 64-bit keccak lanes (+ apply padding)
auto formatted_slices = format_input_lanes(input, num_bytes);

std::vector<field_ct> converted_buffer(formatted_slices.size());
std::vector<field_ct> msb_buffer(formatted_slices.size());

// populate keccak_state, convert our 64-bit lanes into an extended base-11 representation
keccak_state internal;
internal.context = ctx;
for (size_t i = 0; i < formatted_slices.size(); ++i) {
const auto accumulators = plookup_read::get_lookup_accumulators(KECCAK_FORMAT_INPUT, formatted_slices[i]);
converted_buffer[i] = accumulators[ColumnIdx::C2][0];
msb_buffer[i] = accumulators[ColumnIdx::C3][accumulators[ColumnIdx::C3].size() - 1];
}

sponge_absorb(internal, converted_buffer, msb_buffer);
uint32_ct num_blocks_with_data = (num_bytes + BLOCK_SIZE) / BLOCK_SIZE;
sponge_absorb(internal, converted_buffer, msb_buffer, field_ct(num_blocks_with_data));

auto result = sponge_squeeze(internal);

Expand Down
10 changes: 8 additions & 2 deletions cpp/src/barretenberg/stdlib/hash/keccak/keccak.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ template <typename Composer> class keccak {
public:
using witness_ct = stdlib::witness_t<Composer>;
using field_ct = stdlib::field_t<Composer>;
using bool_ct = stdlib::bool_t<Composer>;
using byte_array_ct = stdlib::byte_array<Composer>;
using uint32_ct = stdlib::uint32<Composer>;

// base of extended representation we use for efficient logic operations
static constexpr uint256_t BASE = 11;
Expand Down Expand Up @@ -166,10 +168,14 @@ template <typename Composer> class keccak {
static void iota(keccak_state& state, size_t round);
static void sponge_absorb(keccak_state& internal,
const std::vector<field_ct>& input_buffer,
const std::vector<field_ct>& msb_buffer);
const std::vector<field_ct>& msb_buffer,
const field_ct& num_blocks_with_data);
static byte_array_ct sponge_squeeze(keccak_state& internal);
static void keccakf1600(keccak_state& state);
static byte_array_ct hash(byte_array_ct& input);
static byte_array_ct hash(byte_array_ct& input, const uint32_ct& num_bytes);
static byte_array_ct hash(byte_array_ct& input) { return hash(input, static_cast<uint32_t>(input.size())); };

static std::vector<field_ct> format_input_lanes(byte_array_ct& input, const uint32_ct& num_bytes);

static std::vector<uint8_t> hash_native(const std::vector<uint8_t>& data)
{
Expand Down
84 changes: 84 additions & 0 deletions cpp/src/barretenberg/stdlib/hash/keccak/keccak.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ typedef stdlib::byte_array<Composer> byte_array;
typedef stdlib::public_witness_t<Composer> public_witness_t;
typedef stdlib::field_t<Composer> field_ct;
typedef stdlib::witness_t<Composer> witness_ct;
typedef stdlib::uint32<Composer> uint32_ct;

namespace {
auto& engine = numeric::random::get_debug_engine();
Expand Down Expand Up @@ -156,6 +157,55 @@ TEST(stdlib_keccak, keccak_chi_output_table)
EXPECT_EQ(proof_result, true);
}

TEST(stdlib_keccak, test_format_input_lanes)
{
Composer composer = Composer();

for (size_t i = 543; i < 544; ++i) {
std::cout << "i = " << i << std::endl;
std::string input;
for (size_t j = 0; j < i; ++j) {
input += "a";
}

// std::string input = "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz01";
std::vector<uint8_t> input_v(input.begin(), input.end());
const size_t excess_zeroes = i % 543;
std::vector<uint8_t> input_padded_v(input.begin(), input.end());
for (size_t k = 0; k < excess_zeroes; ++k) {
input_padded_v.push_back(0);
}
byte_array input_arr(&composer, input_v);
byte_array input_padded_arr(&composer, input_padded_v);

auto num_bytes_native = static_cast<uint32_t>(i);
uint32_ct num_bytes(witness_ct(&composer, num_bytes_native));
std::vector<field_ct> result = stdlib::keccak<Composer>::format_input_lanes(input_padded_arr, num_bytes);
std::vector<field_ct> expected = stdlib::keccak<Composer>::format_input_lanes(input_arr, num_bytes_native);

EXPECT_GT(result.size(), expected.size() - 1);

for (size_t j = 0; j < expected.size(); ++j) {
// std::cout << "i = " << i << std::endl;
zac-williamson marked this conversation as resolved.
Show resolved Hide resolved
EXPECT_EQ(result[j].get_value(), expected[j].get_value());
}
for (size_t j = expected.size(); j < result.size(); ++j) {
EXPECT_EQ(result[j].get_value(), 0);
}
}

composer.print_num_gates();
zac-williamson marked this conversation as resolved.
Show resolved Hide resolved

auto prover = composer.create_prover();
std::cout << "prover circuit_size = " << prover.key->circuit_size << std::endl;
zac-williamson marked this conversation as resolved.
Show resolved Hide resolved
auto verifier = composer.create_verifier();

auto proof = prover.construct_proof();

bool proof_result = verifier.verify_proof(proof);
EXPECT_EQ(proof_result, true);
}

TEST(stdlib_keccak, test_single_block)
{
Composer composer = Composer();
Expand Down Expand Up @@ -207,3 +257,37 @@ TEST(stdlib_keccak, test_double_block)
bool proof_result = verifier.verify_proof(proof);
EXPECT_EQ(proof_result, true);
}

TEST(stdlib_keccak, test_double_block_variable_length)
{
Composer composer = Composer();
std::string input = "";
for (size_t i = 0; i < 200; ++i) {
input += "a";
}
std::vector<uint8_t> input_v(input.begin(), input.end());

// add zero padding
std::vector<uint8_t> input_v_padded(input_v);
for (size_t i = 0; i < 51; ++i) {
input_v_padded.push_back(0);
}
byte_array input_arr(&composer, input_v_padded);

uint32_ct length(witness_ct(&composer, 200));
byte_array output = stdlib::keccak<Composer>::hash(input_arr, length);

std::vector<uint8_t> expected = stdlib::keccak<Composer>::hash_native(input_v);

EXPECT_EQ(output.get_value(), expected);

composer.print_num_gates();
zac-williamson marked this conversation as resolved.
Show resolved Hide resolved

auto prover = composer.create_prover();
auto verifier = composer.create_verifier();

auto proof = prover.construct_proof();

bool proof_result = verifier.verify_proof(proof);
EXPECT_EQ(proof_result, true);
}
Loading