diff --git a/CHANGELOG.md b/CHANGELOG.md index 413fa67059..ed1c5da691 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ #### Upcoming Changes +* feat: Make air public inputs deserializable [#1648](https://github.com/lambdaclass/cairo-vm/pull/1648) + * feat: Show only layout builtins in air private input [#1651](https://github.com/lambdaclass/cairo-vm/pull/1651) * feat: Sort builtin segment info upon serialization for Cairo PIE [#1654](https://github.com/lambdaclass/cairo-vm/pull/1654) diff --git a/vm/src/air_public_input.rs b/vm/src/air_public_input.rs index 692f391f2a..3a2bf8bc37 100644 --- a/vm/src/air_public_input.rs +++ b/vm/src/air_public_input.rs @@ -1,5 +1,5 @@ use crate::Felt252; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use thiserror_no_std::Error; use crate::{ @@ -14,18 +14,21 @@ use crate::{ }, }; -#[derive(Serialize, Debug)] +#[derive(Serialize, Deserialize, Debug, PartialEq)] pub struct PublicMemoryEntry { pub address: usize, #[serde(serialize_with = "mem_value_serde::serialize")] + #[serde(deserialize_with = "mem_value_serde::deserialize")] pub value: Option, pub page: usize, } mod mem_value_serde { + use core::fmt; + use super::*; - use serde::Serializer; + use serde::{de, Deserializer, Serializer}; pub(crate) fn serialize( value: &Option, @@ -37,9 +40,41 @@ mod mem_value_serde { serializer.serialize_none() } } + + pub(crate) fn deserialize<'de, D: Deserializer<'de>>( + d: D, + ) -> Result, D::Error> { + d.deserialize_str(Felt252OptionVisitor) + } + + struct Felt252OptionVisitor; + + impl<'de> de::Visitor<'de> for Felt252OptionVisitor { + type Value = Option; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("Could not deserialize hexadecimal string") + } + + fn visit_none(self) -> Result + where + E: de::Error, + { + Ok(None) + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + Felt252::from_hex(value) + .map_err(de::Error::custom) + .map(Some) + } + } } -#[derive(Serialize, Debug)] +#[derive(Serialize, Deserialize, Debug, PartialEq)] pub struct MemorySegmentAddresses { pub begin_addr: usize, pub stop_ptr: usize, @@ -55,7 +90,7 @@ impl From<(usize, usize)> for MemorySegmentAddresses { } } -#[derive(Serialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct PublicInput<'a> { pub layout: &'a str, pub rc_min: isize, @@ -64,6 +99,7 @@ pub struct PublicInput<'a> { pub memory_segments: HashMap<&'a str, MemorySegmentAddresses>, pub public_memory: Vec, #[serde(rename = "dynamic_params")] + #[serde(skip_deserializing)] // This is set to None by default so we can skip it layout_params: Option<&'a CairoLayout>, } @@ -139,3 +175,52 @@ pub enum PublicInputError { #[error(transparent)] Trace(#[from] TraceError), } +#[cfg(test)] +mod tests { + #[cfg(feature = "std")] + use super::*; + #[cfg(feature = "std")] + use rstest::rstest; + + #[cfg(feature = "std")] + #[rstest] + #[case(include_bytes!("../../cairo_programs/proof_programs/fibonacci.json"))] + #[case(include_bytes!("../../cairo_programs/proof_programs/bitwise_output.json"))] + #[case(include_bytes!("../../cairo_programs/proof_programs/keccak_builtin.json"))] + #[case(include_bytes!("../../cairo_programs/proof_programs/poseidon_builtin.json"))] + #[case(include_bytes!("../../cairo_programs/proof_programs/relocate_temporary_segment_append.json"))] + #[case(include_bytes!("../../cairo_programs/proof_programs/pedersen_test.json"))] + #[case(include_bytes!("../../cairo_programs/proof_programs/ec_op.json"))] + fn serialize_and_deserialize_air_public_input(#[case] program_content: &[u8]) { + let config = crate::cairo_run::CairoRunConfig { + proof_mode: true, + relocate_mem: true, + trace_enabled: true, + layout: "all_cairo", + ..Default::default() + }; + let (runner, vm) = crate::cairo_run::cairo_run(program_content, &config, &mut crate::hint_processor::builtin_hint_processor::builtin_hint_processor_definition::BuiltinHintProcessor::new_empty()).unwrap(); + let public_input = runner.get_air_public_input(&vm).unwrap(); + // We already know serialization works as expected due to the comparison against python VM + let serialized_public_input = public_input.serialize_json().unwrap(); + let deserialized_public_input: PublicInput = + serde_json::from_str(&serialized_public_input).unwrap(); + // Check that the deserialized public input is equal to the one we obtained from the vm first + assert_eq!(public_input.layout, deserialized_public_input.layout); + assert_eq!(public_input.rc_max, deserialized_public_input.rc_max); + assert_eq!(public_input.rc_min, deserialized_public_input.rc_min); + assert_eq!(public_input.n_steps, deserialized_public_input.n_steps); + assert_eq!( + public_input.memory_segments, + deserialized_public_input.memory_segments + ); + assert_eq!( + public_input.public_memory, + deserialized_public_input.public_memory + ); + assert!( + public_input.layout_params.is_none() + && deserialized_public_input.layout_params.is_none() + ); + } +}