Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: keccak in noir using a permutation opcode #3726

Closed
wants to merge 12 commits into from
2 changes: 1 addition & 1 deletion barretenberg/cpp/src/barretenberg/bb/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ void acvm_info(const std::string& output_path)
"width" : 3
},
"opcodes_supported" : ["arithmetic", "directive", "brillig", "memory_init", "memory_op"],
"black_box_functions_supported" : ["and", "xor", "range", "sha256", "blake2s", "keccak256", "schnorr_verify", "pedersen", "pedersen_hash", "ecdsa_secp256k1", "ecdsa_secp256r1", "fixed_base_scalar_mul", "recursive_aggregation"]
"black_box_functions_supported" : ["and", "xor", "range", "sha256", "blake2s", "keccak256", "keccak_f1600", "schnorr_verify", "pedersen", "pedersen_hash", "ecdsa_secp256k1", "ecdsa_secp256r1", "fixed_base_scalar_mul", "recursive_aggregation"]
})";

size_t length = strlen(jsonData);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ void build_constraints(Builder& builder, acir_format const& constraint_system, b
for (const auto& constraint : constraint_system.keccak_var_constraints) {
create_keccak_var_constraints(builder, constraint);
}
for (const auto& constraint : constraint_system.keccak_permutations) {
create_keccak_permutations(builder, constraint);
}

// Add pedersen constraints
for (const auto& constraint : constraint_system.pedersen_constraints) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ struct acir_format {
std::vector<Blake2sConstraint> blake2s_constraints;
std::vector<KeccakConstraint> keccak_constraints;
std::vector<KeccakVarConstraint> keccak_var_constraints;
std::vector<Keccakf1600> keccak_permutations;
std::vector<PedersenConstraint> pedersen_constraints;
std::vector<PedersenHashConstraint> pedersen_hash_constraints;
std::vector<FixedBaseScalarMul> fixed_base_scalar_mul_constraints;
Expand All @@ -57,6 +58,7 @@ struct acir_format {
blake2s_constraints,
keccak_constraints,
keccak_var_constraints,
keccak_permutations,
pedersen_constraints,
pedersen_hash_constraints,
fixed_base_scalar_mul_constraints,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ TEST_F(AcirFormatTests, TestASingleConstraintNoPubInputs)
.blake2s_constraints = {},
.keccak_constraints = {},
.keccak_var_constraints = {},
.keccak_permutations = {},
.pedersen_constraints = {},
.pedersen_hash_constraints = {},
.fixed_base_scalar_mul_constraints = {},
Expand Down Expand Up @@ -145,6 +146,7 @@ TEST_F(AcirFormatTests, TestLogicGateFromNoirCircuit)
.blake2s_constraints = {},
.keccak_constraints = {},
.keccak_var_constraints = {},
.keccak_permutations = {},
.pedersen_constraints = {},
.pedersen_hash_constraints = {},
.fixed_base_scalar_mul_constraints = {},
Expand Down Expand Up @@ -209,6 +211,7 @@ TEST_F(AcirFormatTests, TestSchnorrVerifyPass)
.blake2s_constraints = {},
.keccak_constraints = {},
.keccak_var_constraints = {},
.keccak_permutations = {},
.pedersen_constraints = {},
.pedersen_hash_constraints = {},
.fixed_base_scalar_mul_constraints = {},
Expand Down Expand Up @@ -296,6 +299,7 @@ TEST_F(AcirFormatTests, TestSchnorrVerifySmallRange)
.blake2s_constraints = {},
.keccak_constraints = {},
.keccak_var_constraints = {},
.keccak_permutations = {},
.pedersen_constraints = {},
.pedersen_hash_constraints = {},
.fixed_base_scalar_mul_constraints = {},
Expand Down Expand Up @@ -402,6 +406,7 @@ TEST_F(AcirFormatTests, TestVarKeccak)
.blake2s_constraints = {},
.keccak_constraints = {},
.keccak_var_constraints = { keccak },
.keccak_permutations = {},
.pedersen_constraints = {},
.pedersen_hash_constraints = {},
.fixed_base_scalar_mul_constraints = {},
Expand All @@ -419,4 +424,48 @@ TEST_F(AcirFormatTests, TestVarKeccak)
EXPECT_EQ(verifier.verify_proof(proof), true);
}

TEST_F(AcirFormatTests, TestKeccakPermutation)
{
Keccakf1600
keccak_permutation{
.state = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 },
.result = { 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50 },
};

acir_format constraint_system{ .varnum = 51,
.public_inputs = {},
.logic_constraints = {},
.range_constraints = {},
.sha256_constraints = {},
.schnorr_constraints = {},
.ecdsa_k1_constraints = {},
.ecdsa_r1_constraints = {},
.blake2s_constraints = {},
.keccak_constraints = {},
.keccak_var_constraints = {},
.keccak_permutations = { keccak_permutation },
.pedersen_constraints = {},
.pedersen_hash_constraints = {},
.hash_to_field_constraints = {},
.fixed_base_scalar_mul_constraints = {},
.recursion_constraints = {},
.constraints = {},
.block_constraints = {} };

WitnessVector witness{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50 };

auto builder = create_circuit_with_witness(constraint_system, witness);

auto composer = Composer();
auto prover = composer.create_ultra_with_keccak_prover(builder);
auto proof = prover.construct_proof();

auto verifier = composer.create_ultra_with_keccak_verifier(builder);

EXPECT_EQ(verifier.verify_proof(proof), true);
}

} // namespace acir_format::tests
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,11 @@ void handle_blackbox_func_call(Circuit::Opcode::BlackBoxFuncCall const& arg, aci
.result = map(arg.outputs, [](auto& e) { return e.value; }),
.var_message_size = arg.var_message_size.witness.value,
});
} else if constexpr (std::is_same_v<T, Circuit::BlackBoxFuncCall::Keccakf1600>) {
af.keccak_permutations.push_back(Keccakf1600{
.state = map(arg.inputs, [](auto& e) { return e.witness.value; }),
.result = map(arg.outputs, [](auto& e) { return e.value; }),
});
} else if constexpr (std::is_same_v<T, Circuit::BlackBoxFuncCall::RecursiveAggregation>) {
auto c = RecursionConstraint{
.key = map(arg.verification_key, [](auto& e) { return e.witness.value; }),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "keccak_constraint.hpp"
#include "barretenberg/stdlib/hash/keccak/keccak.hpp"
#include "barretenberg/stdlib/primitives/circuit_builders/circuit_builders_fwd.hpp"
#include "round.hpp"

namespace acir_format {
Expand Down Expand Up @@ -73,13 +74,40 @@ template <typename Builder> void create_keccak_var_constraints(Builder& builder,
}
}

template <typename Builder> void create_keccak_permutations(Builder& builder, const Keccakf1600& constraint)
{
using field_ct = proof_system::plonk::stdlib::field_t<Builder>;

// Create the array containing the permuted state
std::array<field_ct, proof_system::plonk::stdlib::keccak<Builder>::NUM_KECCAK_LANES> state;

// Get the witness assignment for each witness index
// Write the witness assignment to the byte_array
for (size_t i = 0; i < constraint.state.size(); ++i) {
info(constraint.state[i]);
state[i] = field_ct::from_witness_index(&builder, constraint.state[i]);
}

std::array<field_ct, 25> output_state =
proof_system::plonk::stdlib::keccak<Builder>::permutation_opcode(state, &builder);

for (size_t i = 0; i < output_state.size(); ++i) {
builder.assert_equal(output_state[i].normalize().witness_index, constraint.result[i]);
}
}
template void create_keccak_constraints<UltraCircuitBuilder>(UltraCircuitBuilder& builder,
const KeccakConstraint& constraint);
template void create_keccak_var_constraints<UltraCircuitBuilder>(UltraCircuitBuilder& builder,
const KeccakVarConstraint& constraint);
template void create_keccak_permutations<UltraCircuitBuilder>(UltraCircuitBuilder& builder,
const Keccakf1600& constraint);

template void create_keccak_constraints<GoblinUltraCircuitBuilder>(GoblinUltraCircuitBuilder& builder,
const KeccakConstraint& constraint);
template void create_keccak_var_constraints<GoblinUltraCircuitBuilder>(GoblinUltraCircuitBuilder& builder,
const KeccakVarConstraint& constraint);

template void create_keccak_permutations<GoblinUltraCircuitBuilder>(GoblinUltraCircuitBuilder& builder,
const Keccakf1600& constraint);

} // namespace acir_format
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ struct HashInput {
friend bool operator==(HashInput const& lhs, HashInput const& rhs) = default;
};

struct Keccakf1600 {
std::vector<uint32_t> state;
std::vector<uint32_t> result;

// For serialization, update with any new fields
MSGPACK_FIELDS(state, result);
friend bool operator==(Keccakf1600 const& lhs, Keccakf1600 const& rhs) = default;
};

struct KeccakConstraint {
std::vector<HashInput> inputs;
std::vector<uint32_t> result;
Expand All @@ -36,5 +45,6 @@ struct KeccakVarConstraint {

template <typename Builder> void create_keccak_constraints(Builder& builder, const KeccakConstraint& constraint);
template <typename Builder> void create_keccak_var_constraints(Builder& builder, const KeccakVarConstraint& constraint);
template <typename Builder> void create_keccak_permutations(Builder& builder, const Keccakf1600& constraint);

} // namespace acir_format
62 changes: 62 additions & 0 deletions barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,15 @@ struct BlackBoxFuncCall {
static Keccak256VariableLength bincodeDeserialize(std::vector<uint8_t>);
};

struct Keccakf1600 {
std::vector<Circuit::FunctionInput> inputs;
std::vector<Circuit::Witness> outputs;

friend bool operator==(const Keccakf1600&, const Keccakf1600&);
std::vector<uint8_t> bincodeSerialize() const;
static Keccakf1600 bincodeDeserialize(std::vector<uint8_t>);
};

struct RecursiveAggregation {
std::vector<Circuit::FunctionInput> verification_key;
std::vector<Circuit::FunctionInput> proof;
Expand All @@ -181,6 +190,7 @@ struct BlackBoxFuncCall {
FixedBaseScalarMul,
Keccak256,
Keccak256VariableLength,
Keccakf1600,
RecursiveAggregation>
value;

Expand Down Expand Up @@ -2520,6 +2530,58 @@ Circuit::BlackBoxFuncCall::Keccak256VariableLength serde::Deserializable<

namespace Circuit {

inline bool operator==(const BlackBoxFuncCall::Keccakf1600& lhs, const BlackBoxFuncCall::Keccakf1600& rhs)
{
if (!(lhs.inputs == rhs.inputs)) {
return false;
}
if (!(lhs.outputs == rhs.outputs)) {
return false;
}
return true;
}

inline std::vector<uint8_t> BlackBoxFuncCall::Keccakf1600::bincodeSerialize() const
{
auto serializer = serde::BincodeSerializer();
serde::Serializable<BlackBoxFuncCall::Keccakf1600>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BlackBoxFuncCall::Keccakf1600 BlackBoxFuncCall::Keccakf1600::bincodeDeserialize(std::vector<uint8_t> input)
{
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BlackBoxFuncCall::Keccakf1600>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw_or_abort("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BlackBoxFuncCall::Keccakf1600>::serialize(
const Circuit::BlackBoxFuncCall::Keccakf1600& obj, Serializer& serializer)
{
serde::Serializable<decltype(obj.inputs)>::serialize(obj.inputs, serializer);
serde::Serializable<decltype(obj.outputs)>::serialize(obj.outputs, serializer);
}

template <>
template <typename Deserializer>
Circuit::BlackBoxFuncCall::Keccakf1600 serde::Deserializable<Circuit::BlackBoxFuncCall::Keccakf1600>::deserialize(
Deserializer& deserializer)
{
Circuit::BlackBoxFuncCall::Keccakf1600 obj;
obj.inputs = serde::Deserializable<decltype(obj.inputs)>::deserialize(deserializer);
obj.outputs = serde::Deserializable<decltype(obj.outputs)>::deserialize(deserializer);
return obj;
}

namespace Circuit {

inline bool operator==(const BlackBoxFuncCall::RecursiveAggregation& lhs,
const BlackBoxFuncCall::RecursiveAggregation& rhs)
{
Expand Down
Loading
Loading