diff --git a/Cargo.toml b/Cargo.toml index 5caa2be8..11e9e50d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,12 +27,18 @@ harness = false [features] default = ["blake3/default", "std", "winter_crypto/default", "winter_math/default", "winter_utils/default"] std = ["blake3/std", "winter_crypto/std", "winter_math/std", "winter_utils/std"] +serde = ["winter_math/serde", "dep:serde", "serde/alloc"] [dependencies] blake3 = { version = "1.4", default-features = false } winter_crypto = { version = "0.6", package = "winter-crypto", default-features = false } winter_math = { version = "0.6", package = "winter-math", default-features = false } winter_utils = { version = "0.6", package = "winter-utils", default-features = false } +serde = { version = "1.0", features = [ "derive" ], optional = true, default-features = false } + +[patch.crates-io] +winter_math = { git = "https://github.com/hackaugusto/winterfell/", branch = "hacka-conditional-support-for-serde", package = "winter-math" } +winter_utils = { git = "https://github.com/hackaugusto/winterfell/", branch = "hacka-conditional-support-for-serde", package = "winter-utils" } [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } diff --git a/src/hash/blake/mod.rs b/src/hash/blake/mod.rs index 91c9bca2..9f02eec3 100644 --- a/src/hash/blake/mod.rs +++ b/src/hash/blake/mod.rs @@ -1,5 +1,8 @@ use super::{Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField}; -use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; +use crate::utils::{ + bytes_to_hex_string, hex_to_bytes, string::String, ByteReader, ByteWriter, Deserializable, + DeserializationError, HexParseError, Serializable, +}; use core::{ mem::{size_of, transmute, transmute_copy}, ops::Deref, @@ -24,6 +27,8 @@ const DIGEST20_BYTES: usize = 20; /// Note: `N` can't be greater than `32` because [`Digest::as_bytes`] currently supports only 32 /// bytes. #[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] +#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))] pub struct Blake3Digest([u8; N]); impl Default for Blake3Digest { @@ -52,6 +57,20 @@ impl From<[u8; N]> for Blake3Digest { } } +impl From> for String { + fn from(value: Blake3Digest) -> Self { + bytes_to_hex_string(value.as_bytes()) + } +} + +impl TryFrom<&str> for Blake3Digest { + type Error = HexParseError; + + fn try_from(value: &str) -> Result { + hex_to_bytes(value).map(|v| v.into()) + } +} + impl Serializable for Blake3Digest { fn write_into(&self, target: &mut W) { target.write_bytes(&self.0); diff --git a/src/hash/rpo/digest.rs b/src/hash/rpo/digest.rs index 0e6c3109..7368f65e 100644 --- a/src/hash/rpo/digest.rs +++ b/src/hash/rpo/digest.rs @@ -1,13 +1,19 @@ use super::{Digest, Felt, StarkField, DIGEST_SIZE, ZERO}; use crate::utils::{ - string::String, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, + bytes_to_hex_string, hex_to_bytes, string::String, ByteReader, ByteWriter, Deserializable, + DeserializationError, HexParseError, Serializable, }; use core::{cmp::Ordering, fmt::Display, ops::Deref}; +/// The number of bytes needed to encoded a digest +pub const DIGEST_BYTES: usize = 32; + // DIGEST TRAIT IMPLEMENTATIONS // ================================================================================================ #[derive(Debug, Default, Copy, Clone, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] +#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))] pub struct RpoDigest([Felt; DIGEST_SIZE]); impl RpoDigest { @@ -19,7 +25,7 @@ impl RpoDigest { self.as_ref() } - pub fn as_bytes(&self) -> [u8; 32] { + pub fn as_bytes(&self) -> [u8; DIGEST_BYTES] { ::as_bytes(self) } @@ -32,8 +38,8 @@ impl RpoDigest { } impl Digest for RpoDigest { - fn as_bytes(&self) -> [u8; 32] { - let mut result = [0; 32]; + fn as_bytes(&self) -> [u8; DIGEST_BYTES] { + let mut result = [0; DIGEST_BYTES]; result[..8].copy_from_slice(&self.0[0].as_int().to_le_bytes()); result[8..16].copy_from_slice(&self.0[1].as_int().to_le_bytes()); @@ -73,6 +79,29 @@ impl From<[Felt; DIGEST_SIZE]> for RpoDigest { } } +impl From<[u64; DIGEST_SIZE]> for RpoDigest { + fn from(value: [u64; DIGEST_SIZE]) -> Self { + Self([ + Felt::new(value[0]), + Felt::new(value[1]), + Felt::new(value[2]), + Felt::new(value[3]), + ]) + } +} + +impl From<[u8; DIGEST_BYTES]> for RpoDigest { + fn from(value: [u8; DIGEST_BYTES]) -> Self { + // Note: `unwrap`s below are safe since we know the length of the input + Self([ + value[0..8].try_into().unwrap(), + value[8..16].try_into().unwrap(), + value[16..24].try_into().unwrap(), + value[24..32].try_into().unwrap(), + ]) + } +} + impl From<&RpoDigest> for [Felt; DIGEST_SIZE] { fn from(value: &RpoDigest) -> Self { value.0 @@ -107,18 +136,54 @@ impl From for [u64; DIGEST_SIZE] { } } -impl From<&RpoDigest> for [u8; 32] { +impl From<&RpoDigest> for [u8; DIGEST_BYTES] { fn from(value: &RpoDigest) -> Self { value.as_bytes() } } -impl From for [u8; 32] { +impl From for [u8; DIGEST_BYTES] { fn from(value: RpoDigest) -> Self { value.as_bytes() } } +impl From for String { + fn from(value: RpoDigest) -> Self { + bytes_to_hex_string(value.as_bytes()) + } +} + +impl From<&RpoDigest> for String { + fn from(value: &RpoDigest) -> Self { + (*value).into() + } +} + +impl TryFrom<&str> for RpoDigest { + type Error = HexParseError; + + fn try_from(value: &str) -> Result { + hex_to_bytes(value).map(|v| v.into()) + } +} + +impl TryFrom for RpoDigest { + type Error = HexParseError; + + fn try_from(value: String) -> Result { + value.as_str().try_into() + } +} + +impl TryFrom<&String> for RpoDigest { + type Error = HexParseError; + + fn try_from(value: &String) -> Result { + value.as_str().try_into() + } +} + impl Deref for RpoDigest { type Target = [Felt; DIGEST_SIZE]; @@ -158,9 +223,8 @@ impl PartialOrd for RpoDigest { impl Display for RpoDigest { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - for byte in self.as_bytes() { - write!(f, "{byte:02x}")?; - } + let encoded: String = self.into(); + write!(f, "{}", encoded)?; Ok(()) } } @@ -170,8 +234,7 @@ impl Display for RpoDigest { #[cfg(test)] mod tests { - - use super::{Deserializable, Felt, RpoDigest, Serializable}; + use super::{Deserializable, Felt, RpoDigest, Serializable, DIGEST_BYTES}; use crate::utils::SliceReader; use rand_utils::rand_value; @@ -186,11 +249,27 @@ mod tests { let mut bytes = vec![]; d1.write_into(&mut bytes); - assert_eq!(32, bytes.len()); + assert_eq!(DIGEST_BYTES, bytes.len()); let mut reader = SliceReader::new(&bytes); let d2 = RpoDigest::read_from(&mut reader).unwrap(); assert_eq!(d1, d2); } + + #[cfg(feature = "std")] + #[test] + fn digest_encoding() { + let digest = RpoDigest([ + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + ]); + + let string: String = digest.into(); + let round_trip: RpoDigest = string.try_into().expect("decoding failed"); + + assert_eq!(digest, round_trip); + } } diff --git a/src/merkle/delta.rs b/src/merkle/delta.rs index 71b822a7..cf6d1b9d 100644 --- a/src/merkle/delta.rs +++ b/src/merkle/delta.rs @@ -13,6 +13,7 @@ use super::{empty_roots::EMPTY_WORD, Felt, SimpleSmt}; /// [RpoDigest] represents the root of the Merkle tree and [MerkleTreeDelta] represents the /// differences between the initial and final Merkle tree states. #[derive(Default, Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct MerkleStoreDelta(pub Vec<(RpoDigest, MerkleTreeDelta)>); // MERKLE TREE DELTA @@ -26,6 +27,7 @@ pub struct MerkleStoreDelta(pub Vec<(RpoDigest, MerkleTreeDelta)>); /// - updated_slots: index-value pairs of slots where values were set to non [ZERO; 4] values. #[cfg(not(test))] #[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct MerkleTreeDelta { depth: u8, cleared_slots: Vec, @@ -107,6 +109,7 @@ pub fn merkle_tree_delta>( // -------------------------------------------------------------------------------------------- #[cfg(test)] #[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct MerkleTreeDelta { pub depth: u8, pub cleared_slots: Vec, diff --git a/src/merkle/index.rs b/src/merkle/index.rs index 3a79ac07..25c9282d 100644 --- a/src/merkle/index.rs +++ b/src/merkle/index.rs @@ -21,6 +21,7 @@ use core::fmt::Display; /// The root is represented by the pair $(0, 0)$, its left child is $(1, 0)$ and its right child /// $(1, 1)$. #[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct NodeIndex { depth: u8, value: u64, diff --git a/src/merkle/merkle_tree.rs b/src/merkle/merkle_tree.rs index cfb61bc5..206543aa 100644 --- a/src/merkle/merkle_tree.rs +++ b/src/merkle/merkle_tree.rs @@ -8,6 +8,7 @@ use winter_math::log2; /// A fully-balanced binary Merkle tree (i.e., a tree where the number of leaves is a power of two). #[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct MerkleTree { nodes: Vec, } diff --git a/src/merkle/mmr/accumulator.rs b/src/merkle/mmr/accumulator.rs index 0729c940..a610fe72 100644 --- a/src/merkle/mmr/accumulator.rs +++ b/src/merkle/mmr/accumulator.rs @@ -4,6 +4,7 @@ use super::{ }; #[derive(Debug, Clone, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct MmrPeaks { /// The number of leaves is used to differentiate accumulators that have the same number of /// peaks. This happens because the number of peaks goes up-and-down as the structure is used diff --git a/src/merkle/mmr/full.rs b/src/merkle/mmr/full.rs index d2fbbeb2..c3dd3ac4 100644 --- a/src/merkle/mmr/full.rs +++ b/src/merkle/mmr/full.rs @@ -29,6 +29,7 @@ use std::error::Error; /// Since this is a full representation of the MMR, elements are never removed and the MMR will /// grow roughly `O(2n)` in number of leaf elements. #[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct Mmr { /// Refer to the `forest` method documentation for details of the semantics of this value. pub(super) forest: usize, diff --git a/src/merkle/mmr/proof.rs b/src/merkle/mmr/proof.rs index 0904b83c..d9b4bcfb 100644 --- a/src/merkle/mmr/proof.rs +++ b/src/merkle/mmr/proof.rs @@ -3,6 +3,7 @@ use super::super::MerklePath; use super::full::{high_bitmask, leaf_to_corresponding_tree}; #[derive(Debug, Clone, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct MmrProof { /// The state of the MMR when the MmrProof was created. pub forest: usize, diff --git a/src/merkle/node.rs b/src/merkle/node.rs index 8440af80..4305e7f7 100644 --- a/src/merkle/node.rs +++ b/src/merkle/node.rs @@ -2,6 +2,7 @@ use crate::hash::rpo::RpoDigest; /// Representation of a node with two children used for iterating over containers. #[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct InnerNodeInfo { pub value: RpoDigest, pub left: RpoDigest, diff --git a/src/merkle/partial_mt/mod.rs b/src/merkle/partial_mt/mod.rs index a615e18b..1231a049 100644 --- a/src/merkle/partial_mt/mod.rs +++ b/src/merkle/partial_mt/mod.rs @@ -28,6 +28,7 @@ const EMPTY_DIGEST: RpoDigest = RpoDigest::new([ZERO; 4]); /// /// The root of the tree is recomputed on each new leaf update. #[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct PartialMerkleTree { max_depth: u8, nodes: BTreeMap, diff --git a/src/merkle/path.rs b/src/merkle/path.rs index 975bc68f..86f66e44 100644 --- a/src/merkle/path.rs +++ b/src/merkle/path.rs @@ -6,6 +6,7 @@ use core::ops::{Deref, DerefMut}; /// A merkle path container, composed of a sequence of nodes of a Merkle tree. #[derive(Clone, Debug, Default, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct MerklePath { nodes: Vec, } diff --git a/src/merkle/simple_smt/mod.rs b/src/merkle/simple_smt/mod.rs index 542ab510..d9e80c17 100644 --- a/src/merkle/simple_smt/mod.rs +++ b/src/merkle/simple_smt/mod.rs @@ -13,6 +13,7 @@ mod tests; /// /// The root of the tree is recomputed on each new leaf update. #[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct SimpleSmt { depth: u8, root: RpoDigest, @@ -265,6 +266,7 @@ impl SimpleSmt { // ================================================================================================ #[derive(Debug, Default, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] struct BranchNode { left: RpoDigest, right: RpoDigest, diff --git a/src/merkle/store/mod.rs b/src/merkle/store/mod.rs index 8d8b80aa..8c236189 100644 --- a/src/merkle/store/mod.rs +++ b/src/merkle/store/mod.rs @@ -19,6 +19,7 @@ pub type DefaultMerkleStore = MerkleStore>; pub type RecordingMerkleStore = MerkleStore>; #[derive(Debug, Default, Copy, Clone, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct StoreNode { left: RpoDigest, right: RpoDigest, @@ -87,6 +88,7 @@ pub struct StoreNode { /// assert_eq!(store.num_internal_nodes() - 255, 10); /// ``` #[derive(Debug, Clone, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct MerkleStore = BTreeMap> { nodes: T, } diff --git a/src/merkle/tiered_smt/mod.rs b/src/merkle/tiered_smt/mod.rs index fbdf2d36..2cb87926 100644 --- a/src/merkle/tiered_smt/mod.rs +++ b/src/merkle/tiered_smt/mod.rs @@ -43,6 +43,7 @@ mod tests; /// - Leaf node at depths 16, 32, or 64: hash(key, value, domain=depth). /// - Leaf node at depth 64: hash([key_0, value_0, ..., key_n, value_n], domain=64). #[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct TieredSmt { root: RpoDigest, nodes: NodeStore, diff --git a/src/merkle/tiered_smt/nodes.rs b/src/merkle/tiered_smt/nodes.rs index 0d94091c..1bb34df1 100644 --- a/src/merkle/tiered_smt/nodes.rs +++ b/src/merkle/tiered_smt/nodes.rs @@ -24,6 +24,7 @@ const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH; /// represent leaf nodes in a Tiered Sparse Merkle tree. In the current implementation, [BTreeSet]s /// are used to determine the position of the leaves in the tree. #[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct NodeStore { nodes: BTreeMap, upper_leaves: BTreeSet, diff --git a/src/merkle/tiered_smt/values.rs b/src/merkle/tiered_smt/values.rs index ec2a4651..d41ee6b0 100644 --- a/src/merkle/tiered_smt/values.rs +++ b/src/merkle/tiered_smt/values.rs @@ -26,6 +26,7 @@ const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH; /// The store supports lookup by the full key (i.e. [RpoDigest]) as well as by the 64-bit key /// prefix. #[derive(Debug, Default, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct ValueStore { values: BTreeMap, } @@ -173,6 +174,7 @@ impl ValueStore { /// An entry can contain either a single key-value pair or a vector of key-value pairs sorted by /// key. #[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub enum StoreEntry { Single((RpoDigest, Word)), List(Vec<(RpoDigest, Word)>), diff --git a/src/utils/mod.rs b/src/utils/mod.rs index d71cd332..5713d465 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,5 +1,5 @@ use super::{utils::string::String, Word}; -use core::fmt::{self, Write}; +use core::fmt::{self, Display, Write}; #[cfg(not(feature = "std"))] pub use alloc::{format, vec}; @@ -36,3 +36,74 @@ pub fn word_to_hex(w: &Word) -> Result { Ok(s) } + +/// Renders an array of bytes as hex into a String. +pub fn bytes_to_hex_string(data: [u8; N]) -> String { + let mut s = String::with_capacity(N + 2); + + s.push_str("0x"); + for byte in data.iter() { + write!(s, "{byte:02x}").expect("formatting hex failed"); + } + + s +} + +#[derive(Debug)] +pub enum HexParseError { + InvalidLength { expected: usize, got: usize }, + MissingPrefix, + InvalidChar, +} + +impl Display for HexParseError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + HexParseError::InvalidLength { expected, got } => { + write!(f, "Hex encoded RpoDigest must have length 66, including the 0x prefix. expected {expected} got {got}") + } + HexParseError::MissingPrefix => { + write!(f, "Hex encoded RpoDigest must start with 0x prefix") + } + HexParseError::InvalidChar => { + write!(f, "Hex encoded RpoDigest must contain characters [a-zA-Z0-9]") + } + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for HexParseError {} + +/// Parses a hex string into an array of bytes of known size. +pub fn hex_to_bytes(value: &str) -> Result<[u8; N], HexParseError> { + let expected: usize = (N * 2) + 2; + if value.len() != expected { + return Err(HexParseError::InvalidLength { + expected, + got: value.len(), + }); + } + + if !value.starts_with("0x") { + return Err(HexParseError::MissingPrefix); + } + + let mut data = value.bytes().skip(2).map(|v| match v { + b'0'..=b'9' => Ok(v - b'0'), + b'a'..=b'f' => Ok(v - b'a' + 10), + b'A'..=b'F' => Ok(v - b'A' + 10), + _ => Err(HexParseError::InvalidChar), + }); + + let mut decoded = [0u8; N]; + #[allow(clippy::needless_range_loop)] + for pos in 0..N { + // These `unwrap` calls are okay because the length was checked above + let high: u8 = data.next().unwrap()?; + let low: u8 = data.next().unwrap()?; + decoded[pos] = (high << 4) + low; + } + + Ok(decoded) +}