Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add unsafe constructors to nodes for deserialization #1453

Merged
Merged
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@

## 0.11.0 (TBD)

#### Enhancements

- Updated `MastForest::read_from` to deserialize without computing node hashes unnecessarily (#1453).

#### Changes

- Added `new_unsafe()` constructors to MAST node types which do not compute node hashes (#1453).
- Consolidated `BasicBlockNode` constructors and converted assert flow to `MastForestError::EmptyBasicBlock` (#1453).

## 0.10.3 (2024-08-12)

#### Enhancements
Expand Down
8 changes: 2 additions & 6 deletions assembly/src/assembler/mast_forest_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,8 @@ impl MastForestBuilder {
operations: Vec<Operation>,
decorators: Option<DecoratorList>,
) -> Result<MastNodeId, AssemblyError> {
match decorators {
Some(decorators) => {
self.ensure_node(MastNode::new_basic_block_with_decorators(operations, decorators))
},
None => self.ensure_node(MastNode::new_basic_block(operations)),
}
let block = MastNode::new_basic_block(operations, decorators)?;
self.ensure_node(block)
}

/// Adds a join node to the forest, and returns the [`MastNodeId`] associated with it.
Expand Down
10 changes: 4 additions & 6 deletions core/src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,8 @@ impl MastForest {
operations: Vec<Operation>,
decorators: Option<DecoratorList>,
) -> Result<MastNodeId, MastForestError> {
match decorators {
Some(decorators) => {
self.add_node(MastNode::new_basic_block_with_decorators(operations, decorators))
},
None => self.add_node(MastNode::new_basic_block(operations)),
}
let block = MastNode::new_basic_block(operations, decorators)?;
self.add_node(block)
}

/// Adds a join node to the forest, and returns the [`MastNodeId`] associated with it.
Expand Down Expand Up @@ -271,4 +267,6 @@ pub enum MastForestError {
TooManyNodes,
#[error("node id: {0} is greater than or equal to forest length: {1}")]
NodeIdOverflow(MastNodeId, usize),
#[error("basic block cannot be created from an empty list of operations")]
EmptyBasicBlock,
}
79 changes: 47 additions & 32 deletions core/src/mast/node/basic_block_node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use miden_crypto::{hash::rpo::RpoDigest, Felt, ZERO};
use miden_formatting::prettier::PrettyPrint;
use winter_utils::flatten_slice_elements;

use crate::{chiplets::hasher, Decorator, DecoratorIterator, DecoratorList, Operation};
use crate::{
chiplets::hasher, mast::MastForestError, Decorator, DecoratorIterator, DecoratorList, Operation,
};

mod op_batch;
pub use op_batch::OpBatch;
Expand Down Expand Up @@ -77,31 +79,38 @@ impl BasicBlockNode {
// ------------------------------------------------------------------------------------------------
/// Constructors
impl BasicBlockNode {
/// Returns a new [`BasicBlockNode`] instantiated with the specified operations.
///
/// # Errors (TODO)
/// Returns an error if:
/// - `operations` vector is empty.
/// - `operations` vector contains any number of system operations.
pub fn new(operations: Vec<Operation>) -> Self {
assert!(!operations.is_empty()); // TODO: return error
Self::with_decorators(operations, DecoratorList::new())
}

/// Returns a new [`BasicBlockNode`] instantiated with the specified operations and decorators.
///
/// # Errors (TODO)
/// Returns an error if:
/// - `operations` vector is empty.
/// - `operations` vector contains any number of system operations.
pub fn with_decorators(operations: Vec<Operation>, decorators: DecoratorList) -> Self {
assert!(!operations.is_empty()); // TODO: return error
pub fn new(
operations: Vec<Operation>,
decorators: Option<DecoratorList>,
) -> Result<Self, MastForestError> {
if operations.is_empty() {
return Err(MastForestError::EmptyBasicBlock);
}

// validate decorators list (only in debug mode)
// None is equivalent to an empty list of decorators moving forward.
let decorators = decorators.unwrap_or_default();

// Validate decorators list (only in debug mode).
#[cfg(debug_assertions)]
validate_decorators(&operations, &decorators);

let (op_batches, digest) = batch_ops(operations);
let (op_batches, digest) = batch_and_hash_ops(operations);
Ok(Self { op_batches, digest, decorators })
}

/// Returns a new [`BasicBlockNode`] from values that are assumed to be correct.
/// Should only be used when the source of the inputs is trusted (e.g. deserialization).
pub fn new_unsafe(
bobbinth marked this conversation as resolved.
Show resolved Hide resolved
operations: Vec<Operation>,
decorators: DecoratorList,
digest: RpoDigest,
) -> Self {
assert!(!operations.is_empty());
let (op_batches, _) = batch_ops(operations);
bobbinth marked this conversation as resolved.
Show resolved Hide resolved
Self { op_batches, digest, decorators }
}
}
Expand Down Expand Up @@ -292,18 +301,29 @@ impl<'a> Iterator for OperationOrDecoratorIterator<'a> {
// HELPER FUNCTIONS
// ================================================================================================

/// Groups the provided operations into batches and computes the hash of the block.
fn batch_and_hash_ops(ops: Vec<Operation>) -> (Vec<OpBatch>, RpoDigest) {
// Group the operations into batches.
let (batches, batch_groups) = batch_ops(ops);

// Compute the hash of all operation groups.
let op_groups = &flatten_slice_elements(&batch_groups);
let hash = hasher::hash_elements(op_groups);

(batches, hash)
}

/// Groups the provided operations into batches as described in the docs for this module (i.e.,
/// up to 9 operations per group, and 8 groups per batch).
///
/// After the operations have been grouped, computes the hash of the block.
fn batch_ops(ops: Vec<Operation>) -> (Vec<OpBatch>, RpoDigest) {
let mut batch_acc = OpBatchAccumulator::new();
/// Returns a list of operation batches and a list of operation groups.
fn batch_ops(ops: Vec<Operation>) -> (Vec<OpBatch>, Vec<[Felt; BATCH_SIZE]>) {
let mut batches = Vec::<OpBatch>::new();
let mut batch_acc = OpBatchAccumulator::new();
let mut batch_groups = Vec::<[Felt; BATCH_SIZE]>::new();

for op in ops {
// if the operation cannot be accepted into the current accumulator, add the contents of
// the accumulator to the list of batches and start a new accumulator
// If the operation cannot be accepted into the current accumulator, add the contents of
// the accumulator to the list of batches and start a new accumulator.
if !batch_acc.can_accept_op(op) {
let batch = batch_acc.into_batch();
batch_acc = OpBatchAccumulator::new();
Expand All @@ -312,22 +332,17 @@ fn batch_ops(ops: Vec<Operation>) -> (Vec<OpBatch>, RpoDigest) {
batches.push(batch);
}

// add the operation to the accumulator
// Add the operation to the accumulator.
batch_acc.add_op(op);
}

// make sure we finished processing the last batch
// Make sure we finished processing the last batch.
if !batch_acc.is_empty() {
let batch = batch_acc.into_batch();
batch_groups.push(*batch.groups());
batches.push(batch);
}

// compute the hash of all operation groups
let op_groups = &flatten_slice_elements(&batch_groups);
let hash = hasher::hash_elements(op_groups);

(batches, hash)
(batches, batch_groups)
}

/// Checks if a given decorators list is valid (only checked in debug mode)
Expand Down
20 changes: 10 additions & 10 deletions core/src/mast/node/basic_block_node/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{Decorator, ONE};
fn batch_ops() {
// --- one operation ----------------------------------------------------------------------
let ops = vec![Operation::Add];
let (batches, hash) = super::batch_ops(ops.clone());
let (batches, hash) = super::batch_and_hash_ops(ops.clone());
assert_eq!(1, batches.len());

let batch = &batches[0];
Expand All @@ -21,7 +21,7 @@ fn batch_ops() {

// --- two operations ---------------------------------------------------------------------
let ops = vec![Operation::Add, Operation::Mul];
let (batches, hash) = super::batch_ops(ops.clone());
let (batches, hash) = super::batch_and_hash_ops(ops.clone());
assert_eq!(1, batches.len());

let batch = &batches[0];
Expand All @@ -37,7 +37,7 @@ fn batch_ops() {

// --- one group with one immediate value -------------------------------------------------
let ops = vec![Operation::Add, Operation::Push(Felt::new(12345678))];
let (batches, hash) = super::batch_ops(ops.clone());
let (batches, hash) = super::batch_and_hash_ops(ops.clone());
assert_eq!(1, batches.len());

let batch = &batches[0];
Expand All @@ -63,7 +63,7 @@ fn batch_ops() {
Operation::Push(Felt::new(7)),
Operation::Add,
];
let (batches, hash) = super::batch_ops(ops.clone());
let (batches, hash) = super::batch_and_hash_ops(ops.clone());
assert_eq!(1, batches.len());

let batch = &batches[0];
Expand Down Expand Up @@ -98,7 +98,7 @@ fn batch_ops() {
Operation::Add,
Operation::Push(Felt::new(7)),
];
let (batches, hash) = super::batch_ops(ops.clone());
let (batches, hash) = super::batch_and_hash_ops(ops.clone());
assert_eq!(2, batches.len());

let batch0 = &batches[0];
Expand Down Expand Up @@ -147,7 +147,7 @@ fn batch_ops() {
Operation::Add,
];

let (batches, hash) = super::batch_ops(ops.clone());
let (batches, hash) = super::batch_and_hash_ops(ops.clone());
assert_eq!(1, batches.len());

let batch = &batches[0];
Expand Down Expand Up @@ -181,7 +181,7 @@ fn batch_ops() {
Operation::Add,
Operation::Push(Felt::new(11)),
];
let (batches, hash) = super::batch_ops(ops.clone());
let (batches, hash) = super::batch_and_hash_ops(ops.clone());
assert_eq!(1, batches.len());

let batch = &batches[0];
Expand Down Expand Up @@ -215,7 +215,7 @@ fn batch_ops() {
Operation::Push(ONE),
Operation::Push(Felt::new(2)),
];
let (batches, hash) = super::batch_ops(ops.clone());
let (batches, hash) = super::batch_and_hash_ops(ops.clone());
assert_eq!(1, batches.len());

let batch = &batches[0];
Expand Down Expand Up @@ -260,7 +260,7 @@ fn batch_ops() {
Operation::Pad,
];

let (batches, hash) = super::batch_ops(ops.clone());
let (batches, hash) = super::batch_and_hash_ops(ops.clone());
assert_eq!(2, batches.len());

let batch0 = &batches[0];
Expand Down Expand Up @@ -306,7 +306,7 @@ fn operation_or_decorator_iterator() {
(4, Decorator::Event(4)),
];

let node = BasicBlockNode::with_decorators(operations, decorators);
let node = BasicBlockNode::new(operations, Some(decorators)).unwrap();

let mut iterator = node.iter();

Expand Down
12 changes: 12 additions & 0 deletions core/src/mast/node/call_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ impl CallNode {
Ok(Self { callee, is_syscall: false, digest })
}

/// Returns a new [`CallNode`] from values that are assumed to be correct.
/// Should only be used when the source of the inputs is trusted (e.g. deserialization).
pub fn new_unsafe(callee: MastNodeId, digest: RpoDigest) -> Self {
Self { callee, is_syscall: false, digest }
}

/// Returns a new [`CallNode`] instantiated with the specified callee and marked as a kernel
/// call.
pub fn new_syscall(
Expand All @@ -68,6 +74,12 @@ impl CallNode {

Ok(Self { callee, is_syscall: true, digest })
}

/// Returns a new syscall [`CallNode`] from values that are assumed to be correct.
/// Should only be used when the source of the inputs is trusted (e.g. deserialization).
pub fn new_syscall_unsafe(callee: MastNodeId, digest: RpoDigest) -> Self {
Self { callee, is_syscall: true, digest }
}
}

//-------------------------------------------------------------------------------------------------
Expand Down
5 changes: 3 additions & 2 deletions core/src/mast/node/join_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ impl JoinNode {
Ok(Self { children, digest })
}

#[cfg(test)]
pub fn new_test(children: [MastNodeId; 2], digest: RpoDigest) -> Self {
/// Returns a new [`JoinNode`] from values that are assumed to be correct.
/// Should only be used when the source of the inputs is trusted (e.g. deserialization).
pub fn new_unsafe(children: [MastNodeId; 2], digest: RpoDigest) -> Self {
Self { children, digest }
}
}
Expand Down
7 changes: 7 additions & 0 deletions core/src/mast/node/loop_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ impl LoopNode {

/// Constructors
impl LoopNode {
/// Returns a new [`LoopNode`] instantiated with the specified body node.
pub fn new(body: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
if body.as_usize() >= mast_forest.nodes.len() {
return Err(MastForestError::NodeIdOverflow(body, mast_forest.nodes.len()));
Expand All @@ -44,6 +45,12 @@ impl LoopNode {

Ok(Self { body, digest })
}

/// Returns a new [`LoopNode`] from values that are assumed to be correct.
/// Should only be used when the source of the inputs is trusted (e.g. deserialization).
pub fn new_unsafe(body: MastNodeId, digest: RpoDigest) -> Self {
Self { body, digest }
}
}

impl LoopNode {
Expand Down
13 changes: 5 additions & 8 deletions core/src/mast/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,12 @@ pub enum MastNode {
// ------------------------------------------------------------------------------------------------
/// Constructors
impl MastNode {
pub fn new_basic_block(operations: Vec<Operation>) -> Self {
Self::Block(BasicBlockNode::new(operations))
}

pub fn new_basic_block_with_decorators(
pub fn new_basic_block(
operations: Vec<Operation>,
decorators: DecoratorList,
) -> Self {
Self::Block(BasicBlockNode::with_decorators(operations, decorators))
decorators: Option<DecoratorList>,
) -> Result<Self, MastForestError> {
let block = BasicBlockNode::new(operations, decorators)?;
Ok(Self::Block(block))
}

pub fn new_join(
Expand Down
5 changes: 3 additions & 2 deletions core/src/mast/node/split_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ impl SplitNode {
Ok(Self { branches, digest })
}

#[cfg(test)]
pub fn new_test(branches: [MastNodeId; 2], digest: RpoDigest) -> Self {
/// Returns a new [`SplitNode`] from values that are assumed to be correct.
/// Should only be used when the source of the inputs is trusted (e.g. deserialization).
pub fn new_unsafe(branches: [MastNodeId; 2], digest: RpoDigest) -> Self {
Self { branches, digest }
}
}
Expand Down
Loading
Loading