Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(tree): Return membership-proof when getting tree values #136

Merged
merged 10 commits into from
Oct 9, 2024
9 changes: 7 additions & 2 deletions crates/common/src/test_utils.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use crate::{
hashchain::Hashchain,
operation::{Operation, ServiceChallenge, SigningKey, VerifyingKey},
tree::{InsertProof, KeyDirectoryTree, Proof, SnarkableTree, UpdateProof},
tree::{
HashchainResponse::*, InsertProof, KeyDirectoryTree, Proof, SnarkableTree, UpdateProof,
},
};
use anyhow::{anyhow, Result};
#[cfg(not(feature = "secp256k1"))]
Expand Down Expand Up @@ -170,7 +172,10 @@ pub fn create_random_update(state: &mut TestTreeState, rng: &mut StdRng) -> Upda
.iter()
.nth(rng.gen_range(0..state.inserted_keys.len()))
.unwrap();
let mut hc = state.tree.get(key).unwrap().unwrap();

let Found(mut hc, _) = state.tree.get(key).unwrap() else {
panic!("No response found for key. Cannot perform update.");
};

let signing_key = create_mock_signing_key();
let verifying_key = signing_key.verifying_key();
Expand Down
129 changes: 85 additions & 44 deletions crates/common/src/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ use crate::{
},
};

use HashchainResponse::*;

pub const SPARSE_MERKLE_PLACEHOLDER_HASH: Digest =
Digest::new(*b"SPARSE_MERKLE_PLACEHOLDER_HASH__");

Expand Down Expand Up @@ -187,6 +189,22 @@ pub enum Proof {
Insert(InsertProof),
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MembershipProof {
pub root: Digest,
pub proof: SparseMerkleProof<Hasher>,
pub key: KeyHash,
pub value: Hashchain,
}

impl MembershipProof {
pub fn verify(&self) -> Result<()> {
let value = bincode::serialize(&self.value)?;
self.proof
.verify_existence(self.root.into(), self.key, value)
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NonMembershipProof {
pub root: Digest,
Expand Down Expand Up @@ -268,11 +286,21 @@ impl UpdateProof {
}
}

/// Enumerates possible responses when fetching tree values
#[derive(Debug)]
pub enum HashchainResponse {
/// When a hashchain was found, provides the value and its corresponding membership-proof
Found(Hashchain, MembershipProof),

/// When no hashchain was found for a specific key, provides the corresponding non-membership-proof
NotFound(NonMembershipProof),
}

pub trait SnarkableTree {
fn process_operation(&mut self, operation: &Operation) -> Result<Proof>;
fn insert(&mut self, key: KeyHash, value: Hashchain) -> Result<InsertProof>;
fn update(&mut self, key: KeyHash, value: HashchainEntry) -> Result<UpdateProof>;
fn get(&self, key: KeyHash) -> Result<Result<Hashchain, NonMembershipProof>>;
fn get(&self, key: KeyHash) -> Result<HashchainResponse>;
}

pub struct KeyDirectoryTree<S>
Expand Down Expand Up @@ -351,16 +379,19 @@ where
let hashed_id = Digest::hash(id);
let key_hash = KeyHash::with::<Hasher>(hashed_id);

let mut current_chain = self
.get(key_hash)?
.map_err(|_| anyhow!("Failed to get hashchain for ID {}", id))?;

let new_entry = current_chain.perform_operation(operation.clone())?;
match self.get(key_hash)? {
Found(mut current_chain, _) => {
let new_entry = current_chain.perform_operation(operation.clone())?;

debug!("updating hashchain for user id {}", id.clone());
let proof = self.update(key_hash, new_entry.clone())?;
debug!("updating hashchain for user id {}", id.clone());
let proof = self.update(key_hash, new_entry.clone())?;
distractedm1nd marked this conversation as resolved.
Show resolved Hide resolved

Ok(Proof::Update(proof))
Ok(Proof::Update(proof))
}
NotFound(_) => {
bail!("Failed to get hashchain for ID {}", id)
}
}
}
Operation::CreateAccount(CreateAccountArgs {
id,
Expand All @@ -373,17 +404,18 @@ where
let account_key_hash = KeyHash::with::<Hasher>(hashed_id);

// Verify that the account doesn't already exist
if self.get(account_key_hash)?.is_ok() {
if matches!(self.get(account_key_hash)?, Found(_, _)) {
bail!(DatabaseError::NotFoundError(format!(
"Account already exists for ID {}",
id
)));
}

let service_key_hash = KeyHash::with::<Hasher>(Digest::hash(service_id.as_bytes()));
let service_hashchain = self.get(service_key_hash)?.map_err(|_| {
anyhow!("Failed to get hashchain for service ID {}", service_id)
})?;

let Found(service_hashchain, _) = self.get(service_key_hash)? else {
bail!("Failed to get hashchain for service ID {}", service_id);
};
jns-ps marked this conversation as resolved.
Show resolved Hide resolved

let service_last_entry = service_hashchain.last().ok_or(anyhow!(
"Service hashchain is empty, could not retrieve challenge key"
Expand Down Expand Up @@ -436,12 +468,12 @@ where
let key_hash = KeyHash::with::<Hasher>(hashed_id);

// hashchain should not already exist
if self.get(key_hash)?.is_ok() {
let NotFound(_) = self.get(key_hash)? else {
bail!(DatabaseError::NotFoundError(format!(
"empty slot for ID {}",
id
)));
}
};

debug!("creating new hashchain for service id {}", id);
let chain = Hashchain::register_service(id.clone(), creation_gate.clone())?;
Expand Down Expand Up @@ -518,19 +550,25 @@ where
})
}

fn get(&self, key: KeyHash) -> Result<Result<Hashchain, NonMembershipProof>> {
fn get(&self, key: KeyHash) -> Result<HashchainResponse> {
let root = self.get_current_root()?.into();
let (value, proof) = self.jmt.get_with_proof(key, self.epoch)?;

match value {
Some(serialized_value) => {
let deserialized_value = Self::deserialize_value(&serialized_value)?;
Ok(Ok(deserialized_value))
let membership_proof = MembershipProof {
root,
proof,
key,
value: deserialized_value.clone(),
};
Ok(Found(deserialized_value, membership_proof))
}
None => {
let non_membership_proof = NonMembershipProof { root, proof, key };
Ok(NotFound(non_membership_proof))
}
None => Ok(Err(NonMembershipProof {
root: self.get_current_root()?.into(),
proof,
key,
})),
}
}
}
Expand All @@ -554,8 +592,13 @@ mod tests {
let insert_proof = tree_state.insert_account(account.clone()).unwrap();
assert!(insert_proof.verify().is_ok());

let get_result = tree_state.tree.get(account.key_hash).unwrap().unwrap();
assert_eq!(get_result, account.hashchain);
let Found(hashchain, membership_proof) = tree_state.tree.get(account.key_hash).unwrap()
else {
panic!("Expected hashchain to be found, but was not found.")
};

assert_eq!(hashchain, account.hashchain);
assert!(membership_proof.verify().is_ok());
}

#[test]
Expand Down Expand Up @@ -622,8 +665,8 @@ mod tests {
let update_proof = tree_state.update_account(account.clone()).unwrap();
assert!(update_proof.verify().is_ok());

let get_result = tree_state.tree.get(account.key_hash).unwrap().unwrap();
assert_eq!(get_result, account.hashchain);
let get_result = tree_state.tree.get(account.key_hash);
assert!(matches!(get_result.unwrap(), Found(hc, _) if hc == account.hashchain));
}

#[test]
Expand All @@ -645,11 +688,12 @@ mod tests {
let key = KeyHash::with::<Hasher>(b"non_existing_key");

let result = tree_state.tree.get(key).unwrap();
assert!(result.is_err());

if let Err(non_membership_proof) = result {
assert!(non_membership_proof.verify().is_ok());
}
let NotFound(non_membership_proof) = result else {
panic!("Hashchain found for key while it was expected to be missing");
};

assert!(non_membership_proof.verify().is_ok());
}

#[test]
Expand All @@ -672,11 +716,11 @@ mod tests {
tree_state.update_account(account1.clone()).unwrap();
tree_state.update_account(account2.clone()).unwrap();

let tree_hashchain1 = tree_state.tree.get(account1.key_hash).unwrap().unwrap();
let tree_hashchain2 = tree_state.tree.get(account2.key_hash).unwrap().unwrap();
let get_result1 = tree_state.tree.get(account1.key_hash);
let get_result2 = tree_state.tree.get(account2.key_hash);

assert_eq!(tree_hashchain1, account1.hashchain);
assert_eq!(tree_hashchain2, account2.hashchain);
assert!(matches!(get_result1.unwrap(), Found(hc, _) if hc == account1.hashchain));
assert!(matches!(get_result2.unwrap(), Found(hc, _) if hc == account2.hashchain));
}

#[test]
Expand All @@ -702,14 +746,11 @@ mod tests {
// Update account_2 using the correct key index
let last_proof = test_tree.update_account(account_2.clone()).unwrap();

assert_eq!(
test_tree.tree.get(account_1.key_hash).unwrap().unwrap(),
account_1.hashchain
);
assert_eq!(
test_tree.tree.get(account_2.key_hash).unwrap().unwrap(),
account_2.hashchain
);
let get_result1 = test_tree.tree.get(account_1.key_hash);
let get_result2 = test_tree.tree.get(account_2.key_hash);

assert!(matches!(get_result1.unwrap(), Found(hc, _) if hc == account_1.hashchain));
assert!(matches!(get_result2.unwrap(), Found(hc, _) if hc == account_2.hashchain));
assert_eq!(
last_proof.new_root,
test_tree.tree.get_current_root().unwrap()
Expand Down Expand Up @@ -778,7 +819,7 @@ mod tests {
println!("Final get result for key1: {:?}", get_result1);
println!("Final get result for key2: {:?}", get_result2);

assert_eq!(get_result1.unwrap().unwrap(), account1.hashchain);
assert_eq!(get_result2.unwrap().unwrap(), account2.hashchain);
assert!(matches!(get_result1.unwrap(), Found(hc, _) if hc == account1.hashchain));
assert!(matches!(get_result2.unwrap(), Found(hc, _) if hc == account2.hashchain));
}
}
29 changes: 12 additions & 17 deletions crates/prism/src/node_types/sequencer.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use anyhow::{anyhow, Context, Result};
use anyhow::{anyhow, bail, Context, Result};
use async_trait::async_trait;
use ed25519_dalek::SigningKey;
use jmt::KeyHash;
use prism_common::{
hashchain::Hashchain,
tree::{Batch, Digest, Hasher, KeyDirectoryTree, NonMembershipProof, Proof, SnarkableTree},
use prism_common::tree::{
Batch, Digest, HashchainResponse, HashchainResponse::*, Hasher, KeyDirectoryTree, Proof,
SnarkableTree,
};
use prism_errors::DataAvailabilityError;
use std::{self, collections::VecDeque, sync::Arc};
Expand Down Expand Up @@ -363,10 +363,7 @@ impl Sequencer {
tree.get_commitment().context("Failed to get commitment")
}

pub async fn get_hashchain(
&self,
id: &String,
) -> Result<Result<Hashchain, NonMembershipProof>> {
pub async fn get_hashchain(&self, id: &String) -> Result<HashchainResponse> {
let tree = self.tree.read().await;
let hashed_id = Digest::hash(id);
let key_hash = KeyHash::with::<Hasher>(hashed_id);
Expand All @@ -393,15 +390,13 @@ impl Sequencer {
Operation::RegisterService(_) => (),
Operation::CreateAccount(_) => (),
Operation::AddKey(_) | Operation::RevokeKey(_) => {
let hc = self.get_hashchain(&incoming_operation.id()).await?;
if let Ok(mut hc) = hc {
hc.perform_operation(incoming_operation.clone())?;
} else {
return Err(anyhow!(
"Hashchain not found for id: {}",
incoming_operation.id()
));
}
let hc_response = self.get_hashchain(&incoming_operation.id()).await?;

let Found(mut hc, _) = hc_response else {
bail!("Hashchain not found for id: {}", incoming_operation.id())
};

hc.perform_operation(incoming_operation.clone())?;
}
};

Expand Down
49 changes: 35 additions & 14 deletions crates/prism/src/webserver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@ use indexed_merkle_tree::{
tree::{Proof, UpdateProof},
Hash as TreeHash,
};
use prism_common::{hashchain::Hashchain, operation::Operation};
use jmt::proof::SparseMerkleProof;
use prism_common::{
hashchain::Hashchain,
operation::Operation,
tree::{HashchainResponse, Hasher},
};
use serde::{Deserialize, Serialize};
use std::{self, sync::Arc};
use tower_http::cors::CorsLayer;
Expand Down Expand Up @@ -46,11 +51,10 @@ pub struct UserKeyRequest {
pub id: String,
}

// TODO: Retrieve Merkle proof of current epoch
#[derive(Serialize, Deserialize, ToSchema)]
pub struct UserKeyResponse {
pub hashchain: Hashchain,
// pub proof: MerkleProof
pub hashchain: Option<Hashchain>,
pub proof: SparseMerkleProof<Hasher>,
}

#[derive(OpenApi)]
Expand Down Expand Up @@ -142,16 +146,33 @@ async fn get_hashchain(
State(session): State<Arc<Sequencer>>,
Json(request): Json<UserKeyRequest>,
) -> impl IntoResponse {
match session.get_hashchain(&request.id).await {
Ok(hashchain_or_proof) => match hashchain_or_proof {
Ok(hashchain) => (StatusCode::OK, Json(UserKeyResponse { hashchain })).into_response(),
Err(non_inclusion_proof) => {
(StatusCode::BAD_REQUEST, Json(non_inclusion_proof)).into_response()
}
},
Err(err) => (
StatusCode::BAD_REQUEST,
format!("Couldn't get hashchain: {}", err),
let get_hashchain_result = session.get_hashchain(&request.id).await;
let Ok(hashchain_response) = get_hashchain_result else {
distractedm1nd marked this conversation as resolved.
Show resolved Hide resolved
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!(
"Failed to retrieve hashchain or non-membership-proof: {}",
get_hashchain_result.unwrap_err()
),
)
.into_response();
};

match hashchain_response {
HashchainResponse::Found(hashchain, membership_proof) => (
StatusCode::OK,
Json(UserKeyResponse {
hashchain: Some(hashchain),
proof: membership_proof.proof,
}),
)
.into_response(),
HashchainResponse::NotFound(non_membership_proof) => (
StatusCode::OK,
Json(UserKeyResponse {
hashchain: None,
proof: non_membership_proof.proof,
}),
)
.into_response(),
}
Expand Down
Loading