Skip to content

Commit

Permalink
feat: implementing updates and simple insertions into TSMT
Browse files Browse the repository at this point in the history
  • Loading branch information
bobbinth committed Aug 10, 2023
1 parent 32598b8 commit 5b3b93b
Show file tree
Hide file tree
Showing 10 changed files with 745 additions and 64 deletions.
22 changes: 14 additions & 8 deletions assembly/src/ast/nodes/advice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub enum AdviceInjectorNode {
PushU64div,
PushExt2intt,
PushSmtGet,
PushSmtInsert,
PushMapVal,
PushMapValImm { offset: u8 },
PushMapValN,
Expand All @@ -35,6 +36,7 @@ impl From<&AdviceInjectorNode> for AdviceInjector {
PushU64div => Self::DivU64,
PushExt2intt => Self::Ext2Intt,
PushSmtGet => Self::SmtGet,
PushSmtInsert => Self::SmtInsert,
PushMapVal => Self::MapValueToStack {
include_len: false,
key_offset: 0,
Expand Down Expand Up @@ -68,6 +70,7 @@ impl fmt::Display for AdviceInjectorNode {
PushU64div => write!(f, "push_u64div"),
PushExt2intt => write!(f, "push_ext2intt"),
PushSmtGet => write!(f, "push_smtget"),
PushSmtInsert => write!(f, "push_smtinsert"),
PushMapVal => write!(f, "push_mapval"),
PushMapValImm { offset } => write!(f, "push_mapval.{offset}"),
PushMapValN => write!(f, "push_mapvaln"),
Expand All @@ -86,14 +89,15 @@ impl fmt::Display for AdviceInjectorNode {
const PUSH_U64DIV: u8 = 0;
const PUSH_EXT2INTT: u8 = 1;
const PUSH_SMTGET: u8 = 2;
const PUSH_MAPVAL: u8 = 3;
const PUSH_MAPVAL_IMM: u8 = 4;
const PUSH_MAPVALN: u8 = 5;
const PUSH_MAPVALN_IMM: u8 = 6;
const PUSH_MTNODE: u8 = 7;
const INSERT_MEM: u8 = 8;
const INSERT_HDWORD: u8 = 9;
const INSERT_HDWORD_IMM: u8 = 10;
const PUSH_SMTINSERT: u8 = 3;
const PUSH_MAPVAL: u8 = 4;
const PUSH_MAPVAL_IMM: u8 = 5;
const PUSH_MAPVALN: u8 = 6;
const PUSH_MAPVALN_IMM: u8 = 7;
const PUSH_MTNODE: u8 = 8;
const INSERT_MEM: u8 = 9;
const INSERT_HDWORD: u8 = 10;
const INSERT_HDWORD_IMM: u8 = 11;

impl Serializable for AdviceInjectorNode {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
Expand All @@ -102,6 +106,7 @@ impl Serializable for AdviceInjectorNode {
PushU64div => target.write_u8(PUSH_U64DIV),
PushExt2intt => target.write_u8(PUSH_EXT2INTT),
PushSmtGet => target.write_u8(PUSH_SMTGET),
PushSmtInsert => target.write_u8(PUSH_SMTINSERT),
PushMapVal => target.write_u8(PUSH_MAPVAL),
PushMapValImm { offset } => {
target.write_u8(PUSH_MAPVAL_IMM);
Expand Down Expand Up @@ -129,6 +134,7 @@ impl Deserializable for AdviceInjectorNode {
PUSH_U64DIV => Ok(AdviceInjectorNode::PushU64div),
PUSH_EXT2INTT => Ok(AdviceInjectorNode::PushExt2intt),
PUSH_SMTGET => Ok(AdviceInjectorNode::PushSmtGet),
PUSH_SMTINSERT => Ok(AdviceInjectorNode::PushSmtInsert),
PUSH_MAPVAL => Ok(AdviceInjectorNode::PushMapVal),
PUSH_MAPVAL_IMM => {
let offset = source.read_u8()?;
Expand Down
4 changes: 4 additions & 0 deletions assembly/src/ast/parsers/adv_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ pub fn parse_adv_inject(op: &Token) -> Result<Node, ParsingError> {
2 => AdvInject(PushSmtGet),
_ => return Err(ParsingError::extra_param(op)),
},
"push_smtinsert" => match op.num_parts() {
2 => AdvInject(PushSmtInsert),
_ => return Err(ParsingError::extra_param(op)),
},
"push_mapval" => match op.num_parts() {
2 => AdvInject(PushMapVal),
3 => {
Expand Down
2 changes: 1 addition & 1 deletion core/src/operations/decorators/advice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ pub enum AdviceInjector {
/// Where KEY is computed as hash(A || B, domain), where domain is provided via the immediate
/// value.
HdwordToMap { domain: Felt },

/// TODO: add docs
SmtInsert,
}
Expand Down
114 changes: 114 additions & 0 deletions processor/src/decorators/adv_stack_injectors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,120 @@ where

Ok(())
}

/// Pushes values onto the advice stack which are required for successful insertion of a
/// key-value pair into a Sparse Merkle Tree data structure.
///
/// The Sparse Merkle Tree is tiered, meaning it will have leaf depths in `{16, 32, 48, 64}`.
///
/// Inputs:
/// Operand stack: [VALUE, KEY, ROOT, ...]
/// Advice stack: [...]
///
/// Outputs:
/// Operand stack: [OLD_VALUE, NEW_ROOT, ...]
/// Advice stack, depends on the type of insert:
/// - Simple insert at depth 16: [d0, d1, ONE (is_simple_insert), ZERO (is_update)]
/// - Simple insert at depth 32 or 48: [d0, d1, ONE (is_simple_insert), ZERO (is_update), P_NODE]
/// - Update of an existing leaf: [ZERO (padding), d0, d1, ONE (is_update), OLD_VALUE]
///
/// Where:
/// - d0 is a boolean flag set to `1` if the depth is `16` or `48`.
/// - d1 is a boolean flag set to `1` if the depth is `16` or `32`.
/// - P_NODE is an internal node located at the tier above the insert tier.
/// - VALUE is the value to be inserted.
/// - OLD_VALUE is the value previously associated with the specified KEY.
/// - ROOT and NEW_ROOT are the roots of the TSMT prior and post the insert respectively.
///
/// # Errors
/// Will return an error if the provided Merkle root doesn't exist on the advice provider.
///
/// # Panics
/// Will panic as unimplemented if the target depth is `64`.
pub(super) fn push_smtinsert_inputs(&mut self) -> Result<(), ExecutionError> {
// get the key and tree root from the stack
let key = [self.stack.get(7), self.stack.get(6), self.stack.get(5), self.stack.get(4)];
let root = [self.stack.get(11), self.stack.get(10), self.stack.get(9), self.stack.get(8)];

// determine the depth of the first leaf or an empty tree node
let index = &key[3];
let depth = self.advice_provider.get_leaf_depth(root, &SMT_MAX_TREE_DEPTH, index)?;
debug_assert!(depth < 65);

// map the depth value to its tier; this rounds up depth to 16, 32, 48, or 64
let depth = SMT_NORMALIZED_DEPTHS[depth as usize];
if depth == 64 {
unimplemented!("handling of depth=64 tier hasn't been implemented yet");
}

// get the value of the node a this index/depth
let index = index.as_int() >> (64 - depth);
let index = Felt::new(index);
let node = self.advice_provider.get_tree_node(root, &Felt::new(depth as u64), &index)?;

// figure out what kind of insert we are doing; possible options are:
// - if the node is a root of an empty subtree, this is a simple insert.
// - if the node is a leaf, this could be either an update (for the same key), or a
// complex insert (i.e., the existing leaf needs to be moved to a lower tier).
let empty = EmptySubtreeRoots::empty_hashes(64)[depth as usize];
let (is_update, is_simple_insert) = if node == Word::from(empty) {
// handle simple insert case
if depth == 32 || depth == 48 {
// for depth 32 and 48, we need to provide the internal node located on the tier
// above the insert tier
let p_index = Felt::from(index.as_int() >> 16);
let p_depth = Felt::from(depth - 16);
let p_node = self.advice_provider.get_tree_node(root, &p_depth, &p_index)?;
for &element in p_node.iter().rev() {
self.advice_provider.push_stack(AdviceSource::Value(element))?;
}
}

// return is_update = ZERO, is_simple_insert = ONE
(ZERO, ONE)
} else {
// if the node is a leaf node, push the elements mapped to this node onto the advice
// stack; the elements should be [KEY, VALUE], with key located at the top of the
// advice stack.
self.advice_provider.push_stack(AdviceSource::Map {
key: node,
include_len: false,
})?;

// remove the KEY from the advice stack, leaving only the VALUE on the stack
let leaf_key = self.advice_provider.pop_stack_word()?;

// if the key for the value to be inserted is the same as the leaf's key, we are
// dealing with a simple update. otherwise, we are dealing with a complex insert
// (i.e., the leaf needs to be moved to a lower tier).
if leaf_key == key {
// return is_update = ONE, is_simple_insert = ZERO
(ONE, ZERO)
} else {
// return is_update = ZERO, is_simple_insert = ZERO
(ZERO, ZERO)
}
};

// set the flags used to determine which tier the insert is happening at
let is_16_or_32 = if depth == 16 || depth == 32 { ONE } else { ZERO };
let is_16_or_48 = if depth == 16 || depth == 48 { ONE } else { ZERO };

self.advice_provider.push_stack(AdviceSource::Value(is_update))?;
if is_update == ONE {
// for update we don't need to specify whether we are dealing with an insert; but we
// insert an extra ONE at the end so that we can read 4 values from the advice stack
// regardless of which branch is taken.
self.advice_provider.push_stack(AdviceSource::Value(is_16_or_32))?;
self.advice_provider.push_stack(AdviceSource::Value(is_16_or_48))?;
self.advice_provider.push_stack(AdviceSource::Value(ZERO))?;
} else {
self.advice_provider.push_stack(AdviceSource::Value(is_simple_insert))?;
self.advice_provider.push_stack(AdviceSource::Value(is_16_or_32))?;
self.advice_provider.push_stack(AdviceSource::Value(is_16_or_48))?;
}
Ok(())
}
}

// HELPER FUNCTIONS
Expand Down
2 changes: 1 addition & 1 deletion processor/src/decorators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ where
AdviceInjector::Ext2Inv => self.push_ext2_inv_result(),
AdviceInjector::Ext2Intt => self.push_ext2_intt_result(),
AdviceInjector::SmtGet => self.push_smtget_inputs(),
AdviceInjector::SmtInsert => todo!(),
AdviceInjector::SmtInsert => self.push_smtinsert_inputs(),
AdviceInjector::MemToMap => self.insert_mem_values_into_adv_map(),
AdviceInjector::HdwordToMap { domain } => self.insert_hdword_into_adv_map(*domain),
}
Expand Down
80 changes: 50 additions & 30 deletions processor/src/decorators/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use super::{
Process,
};
use crate::{MemAdviceProvider, StackInputs, Word};
use test_utils::{crypto::get_smt_remaining_key, rand::seeded_word};
use test_utils::rand::seeded_word;
use vm_core::{
crypto::{
hash::{Rpo256, RpoDigest},
Expand Down Expand Up @@ -74,13 +74,10 @@ fn push_smtget() {

// check leaves on empty trees
for depth in [16, 32, 48] {
// compute the remaining key
let remaining = get_smt_remaining_key(key, depth);

// compute node value
let depth_element = Felt::from(depth);
let store = MerkleStore::new();
let node = Rpo256::merge_in_domain(&[remaining.into(), value.into()], depth_element);
let node = Rpo256::merge_in_domain(&[key.into(), value.into()], depth_element);

// expect absent value with constant depth 16
let expected = [ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ONE, ONE];
Expand All @@ -89,9 +86,6 @@ fn push_smtget() {

// check leaves inserted on all tiers
for depth in [16, 32, 48] {
// compute the remaining key
let remaining = get_smt_remaining_key(key, depth);

// set depth flags
let is_16_or_32 = (depth == 16 || depth == 32).then_some(ONE).unwrap_or(ZERO);
let is_16_or_48 = (depth == 16 || depth == 48).then_some(ONE).unwrap_or(ZERO);
Expand All @@ -100,7 +94,7 @@ fn push_smtget() {
let index = key[3].as_int() >> 64 - depth;
let index = NodeIndex::new(depth, index).unwrap();
let depth_element = Felt::from(depth);
let node = Rpo256::merge_in_domain(&[remaining.into(), value.into()], depth_element);
let node = Rpo256::merge_in_domain(&[key.into(), value.into()], depth_element);

// set tier node value and expect the value from the injector
let mut store = MerkleStore::new();
Expand All @@ -111,10 +105,10 @@ fn push_smtget() {
value[2],
value[1],
value[0],
remaining[3],
remaining[2],
remaining[1],
remaining[0],
key[3],
key[2],
key[1],
key[0],
is_16_or_32,
is_16_or_48,
];
Expand Down Expand Up @@ -158,23 +152,41 @@ fn inject_smtinsert() {

let raw_a = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
let key_a = build_key(raw_a);
let val_a = [ONE, ZERO, ZERO, ZERO];

// insertion should happen at depth 16 and thus 16_or_32 and 16_or_48 flags should be set to ONE;
// since we are replacing a node which is an empty subtree, the is_empty flag should also be ONE
let expected_stack = [ONE, ONE, ONE];
let process = prepare_smt_insert(key_a, val_a, &smt, expected_stack.len());
let val_a = [Felt::new(3), Felt::new(5), Felt::new(7), Felt::new(9)];

// this is a simple insertion at depth 16, and thus the flags should look as follows:
let is_update = ZERO;
let is_simple_insert = ONE;
let is_16_or_32 = ONE;
let is_16_or_48 = ONE;
let expected_stack = [is_update, is_simple_insert, is_16_or_32, is_16_or_48];
let process = prepare_smt_insert(key_a, val_a, &smt, expected_stack.len(), Vec::new());
assert_eq!(build_expected(&expected_stack), process.stack.trace_state());

// --- update same key with different value -------------------------------

// insert val_a into the tree so that val_b overwrites it
smt.insert(key_a.into(), val_a);
let val_b = [ONE, ONE, ZERO, ZERO];
smt.insert(key_a.into(), val_b);

// we are updating a node at depth 16 and thus 16_or_32 and 16_or_48 flags should be set to ONE;
// since we are updating an existing leaf, the is_empty flag should be set to ZERO
let expected_stack = [ZERO, ONE, ONE];
let process = prepare_smt_insert(key_a, val_b, &smt, expected_stack.len());
// this is a simple update, and thus the flags should look as follows:
let is_update = ONE;
let is_16_or_32 = ONE;
let is_16_or_48 = ONE;

// also, the old value should be present in the advice stack:
let expected_stack = [
val_a[3],
val_a[2],
val_a[1],
val_a[0],
is_update,
is_16_or_32,
is_16_or_48,
ZERO,
];
let adv_map = vec![build_adv_map_entry(key_a, val_a, 16)];
let process = prepare_smt_insert(key_a, val_b, &smt, expected_stack.len(), adv_map);
assert_eq!(build_expected(&expected_stack), process.stack.trace_state());
}

Expand All @@ -183,12 +195,13 @@ fn prepare_smt_insert(
value: Word,
smt: &TieredSmt,
adv_stack_depth: usize,
adv_map: Vec<([u8; 32], Vec<Felt>)>,
) -> Process<MemAdviceProvider> {
let root: Word = smt.root().into();
let store = MerkleStore::from(smt);

let stack_inputs = build_stack_inputs(value, key, root);
let advice_inputs = AdviceInputs::default().with_merkle_store(store);
let advice_inputs = AdviceInputs::default().with_merkle_store(store).with_map(adv_map);
let mut process = build_process(stack_inputs, advice_inputs);

process.execute_op(Operation::Noop).unwrap();
Expand Down Expand Up @@ -217,7 +230,7 @@ fn build_expected(values: &[Felt]) -> [Felt; 16] {
}

fn assert_case_smtget(
depth: u8,
_depth: u8,
key: Word,
value: Word,
node: RpoDigest,
Expand All @@ -226,9 +239,8 @@ fn assert_case_smtget(
expected_stack: &[Felt],
) {
// build the process
let stack_inputs = build_stack_inputs(key, root, Word::default());
let remaining = get_smt_remaining_key(key, depth);
let mapped = remaining.into_iter().chain(value.into_iter()).collect();
let stack_inputs = build_stack_inputs(key, root.into(), Word::default());
let mapped = key.into_iter().chain(value.into_iter()).collect();
let advice_inputs = AdviceInputs::default()
.with_merkle_store(store)
.with_map([(node.into_bytes(), mapped)]);
Expand All @@ -251,7 +263,7 @@ fn build_process(
adv_inputs: AdviceInputs,
) -> Process<MemAdviceProvider> {
let advice_provider = MemAdviceProvider::from(adv_inputs);
Process::new(Kernel::default(), stack_inputs, advice_provider)
Process::new(Kernel::default(), stack_inputs, advice_provider, ExecutionOptions::default())
}

fn build_stack_inputs(w0: Word, w1: Word, w2: Word) -> StackInputs {
Expand Down Expand Up @@ -288,3 +300,11 @@ fn move_adv_to_stack(process: &mut Process<MemAdviceProvider>, adv_stack_depth:
process.execute_op(Operation::AdvPop).unwrap();
}
}

fn build_adv_map_entry(key: Word, val: Word, depth: u8) -> ([u8; 32], Vec<Felt>) {
let node = Rpo256::merge_in_domain(&[key.into(), val.into()], Felt::from(depth));
let mut elements = Vec::new();
elements.extend_from_slice(&key);
elements.extend_from_slice(&val);
(node.into(), elements)
}
Loading

0 comments on commit 5b3b93b

Please sign in to comment.