Skip to content

Commit

Permalink
Wrap os_input in Arc
Browse files Browse the repository at this point in the history
  • Loading branch information
notlesh committed Oct 15, 2024
1 parent bf7f291 commit 3ddd5ef
Show file tree
Hide file tree
Showing 12 changed files with 53 additions and 37 deletions.
5 changes: 3 additions & 2 deletions crates/bin/prove_block/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::sync::Arc;

use blockifier::state::cached_state::CachedState;
use cairo_vm::types::layout_name::LayoutName;
Expand Down Expand Up @@ -307,7 +308,7 @@ pub async fn prove_block(

let contract_class_commitment_info = compute_class_commitment(&previous_class_proofs, &class_proofs);

let os_input = StarknetOsInput {
let os_input = Arc::new(StarknetOsInput {
contract_state_commitment_info,
contract_class_commitment_info,
deprecated_compiled_classes,
Expand All @@ -322,7 +323,7 @@ pub async fn prove_block(
new_block_hash: block_with_txs.block_hash,
prev_block_hash: previous_block.block_hash,
full_output,
};
});
let execution_helper = ExecutionHelperWrapper::<ProverPerContractStorage>::new(
contract_storages,
tx_execution_infos,
Expand Down
5 changes: 3 additions & 2 deletions crates/starknet-os/src/execution/helper.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::HashMap;
use std::ops::Deref;
use std::rc::Rc;
use std::sync::Arc;
use std::vec::IntoIter;

use blockifier::context::BlockContext;
Expand Down Expand Up @@ -29,7 +30,7 @@ where
PCS: PerContractStorage,
{
pub _prev_block_context: Option<BlockContext>,
pub os_input: Option<StarknetOsInput>,
pub os_input: Option<Arc<StarknetOsInput>>,
pub kzg_manager: KzgManager,
// Pointer tx execution info
pub tx_execution_info_iter: IntoIter<TransactionExecutionInfo>,
Expand Down Expand Up @@ -114,7 +115,7 @@ where
contract_storage_map: ContractStorageMap<PCS>,
tx_execution_infos: Vec<TransactionExecutionInfo>,
block_context: &BlockContext,
os_input: Option<StarknetOsInput>,
os_input: Option<Arc<StarknetOsInput>>,
old_block_number_and_hash: (Felt252, Felt252),
) -> Self {
// Block number and block hash (current_block_number - buffer) block buffer=STORED_BLOCK_HASH_BUFFER
Expand Down
15 changes: 8 additions & 7 deletions crates/starknet-os/src/hints/block_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use core::panic;
use std::any::Any;
use std::collections::hash_map::IntoIter;
use std::collections::HashMap;
use std::sync::Arc;

use blockifier::context::BlockContext;
use cairo_vm::hint_processor::builtin_hint_processor::dict_manager::Dictionary;
Expand Down Expand Up @@ -41,7 +42,7 @@ pub fn load_class_facts(
ap_tracking: &ApTracking,
_constants: &HashMap<String, Felt252>,
) -> Result<(), HintError> {
let os_input = exec_scopes.get::<StarknetOsInput>(vars::scopes::OS_INPUT)?;
let os_input: Arc<StarknetOsInput> = exec_scopes.get::<Arc<StarknetOsInput>>(vars::scopes::OS_INPUT)?.clone();
let compiled_class_facts_ptr = vm.add_memory_segment();
insert_value_from_var_name(vars::ids::COMPILED_CLASS_FACTS, compiled_class_facts_ptr, vm, ids_data, ap_tracking)?;

Expand All @@ -53,8 +54,8 @@ pub fn load_class_facts(
ap_tracking,
)?;

let compiled_class_facts: Box<dyn Any> = Box::new(os_input.compiled_classes.into_iter());
let compiled_class_visited_pcs: Box<dyn Any> = Box::new(os_input.compiled_class_visited_pcs);
let compiled_class_facts: Box<dyn Any> = Box::new(os_input.compiled_classes.clone().into_iter());
let compiled_class_visited_pcs: Box<dyn Any> = Box::new(os_input.compiled_class_visited_pcs.clone());
exec_scopes.enter_scope(HashMap::from([
(String::from(vars::scopes::COMPILED_CLASS_FACTS), compiled_class_facts),
(String::from(vars::scopes::COMPILED_CLASS_VISITED_PCS), compiled_class_visited_pcs),
Expand Down Expand Up @@ -218,7 +219,7 @@ pub fn chain_id(
_ap_tracking: &ApTracking,
_constants: &HashMap<String, Felt252>,
) -> Result<(), HintError> {
let os_input = exec_scopes.get::<StarknetOsInput>(vars::scopes::OS_INPUT)?;
let os_input = exec_scopes.get::<Arc<StarknetOsInput>>(vars::scopes::OS_INPUT)?;
let chain_id = chain_id_to_felt(&os_input.general_config.starknet_os_config.chain_id);
insert_value_into_ap(vm, chain_id)
}
Expand All @@ -231,7 +232,7 @@ pub fn fee_token_address(
_ap_tracking: &ApTracking,
_constants: &HashMap<String, Felt252>,
) -> Result<(), HintError> {
let os_input = exec_scopes.get::<StarknetOsInput>(vars::scopes::OS_INPUT)?;
let os_input = exec_scopes.get::<Arc<StarknetOsInput>>(vars::scopes::OS_INPUT)?;
let fee_token_address = *os_input.general_config.starknet_os_config.fee_token_address.0.key();
log::debug!("fee_token_address: {}", fee_token_address);
insert_value_into_ap(vm, fee_token_address)
Expand All @@ -246,7 +247,7 @@ pub fn deprecated_fee_token_address(
_ap_tracking: &ApTracking,
_constants: &HashMap<String, Felt252>,
) -> Result<(), HintError> {
let os_input = exec_scopes.get::<StarknetOsInput>(vars::scopes::OS_INPUT)?;
let os_input = exec_scopes.get::<Arc<StarknetOsInput>>(vars::scopes::OS_INPUT)?;
let deprecated_fee_token_address = *os_input.general_config.starknet_os_config.deprecated_fee_token_address.0.key();
log::debug!("deprecated_fee_token_address: {}", deprecated_fee_token_address);
insert_value_into_ap(vm, deprecated_fee_token_address)
Expand Down Expand Up @@ -357,7 +358,7 @@ pub fn write_use_kzg_da_to_mem(
let block_context = exec_scopes.get_ref::<BlockContext>(vars::scopes::BLOCK_CONTEXT)?;
let use_kzg_da = block_context.block_info().use_kzg_da;

let os_input: &StarknetOsInput = exec_scopes.get_ref(vars::scopes::OS_INPUT)?;
let os_input = exec_scopes.get::<Arc<StarknetOsInput>>(vars::scopes::OS_INPUT)?;
let full_output = os_input.full_output;

let use_kzg_da_felt = if use_kzg_da && !full_output { Felt252::ONE } else { Felt252::ZERO };
Expand Down
5 changes: 3 additions & 2 deletions crates/starknet-os/src/hints/deprecated_compiled_class.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::any::Any;
use std::collections::hash_map::IntoIter;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use cairo_vm::hint_processor::builtin_hint_processor::hint_utils::{get_ptr_from_var_name, insert_value_from_var_name};
use cairo_vm::hint_processor::hint_processor_definition::{HintExtension, HintProcessor, HintReference};
Expand Down Expand Up @@ -35,7 +36,7 @@ pub fn load_deprecated_class_facts(
ap_tracking: &ApTracking,
_constants: &HashMap<String, Felt252>,
) -> Result<(), HintError> {
let os_input = exec_scopes.get::<StarknetOsInput>(vars::scopes::OS_INPUT)?;
let os_input = exec_scopes.get::<Arc<StarknetOsInput>>(vars::scopes::OS_INPUT)?;
let deprecated_class_hashes: HashSet<Felt252> =
HashSet::from_iter(os_input.deprecated_compiled_classes.keys().cloned());
exec_scopes.insert_value(vars::scopes::DEPRECATED_CLASS_HASHES, deprecated_class_hashes);
Expand All @@ -48,7 +49,7 @@ pub fn load_deprecated_class_facts(
ids_data,
ap_tracking,
)?;
let scoped_classes: Box<dyn Any> = Box::new(os_input.deprecated_compiled_classes.into_iter());
let scoped_classes: Box<dyn Any> = Box::new(os_input.deprecated_compiled_classes.clone().into_iter());
exec_scopes.enter_scope(HashMap::from([(String::from(vars::scopes::COMPILED_CLASS_FACTS), scoped_classes)]));

Ok(())
Expand Down
7 changes: 4 additions & 3 deletions crates/starknet-os/src/hints/execution.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::any::Any;
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::vec::IntoIter;

use cairo_vm::hint_processor::builtin_hint_processor::dict_manager::Dictionary;
Expand Down Expand Up @@ -435,11 +436,11 @@ pub fn enter_syscall_scopes<PCS>(
where
PCS: PerContractStorage + 'static,
{
let os_input = exec_scopes.get::<StarknetOsInput>(vars::scopes::OS_INPUT)?;
let os_input = exec_scopes.get::<Arc<StarknetOsInput>>(vars::scopes::OS_INPUT)?;
let deprecated_class_hashes: Box<dyn Any> =
Box::new(exec_scopes.get::<HashSet<Felt252>>(vars::scopes::DEPRECATED_CLASS_HASHES)?);
let transactions: Box<dyn Any> = Box::new(os_input.transactions.into_iter());
let component_hashes: Box<dyn Any> = Box::new(os_input.declared_class_hash_to_component_hashes);
let transactions: Box<dyn Any> = Box::new(os_input.transactions.clone().into_iter());
let component_hashes: Box<dyn Any> = Box::new(os_input.declared_class_hash_to_component_hashes.clone());
let execution_helper: Box<dyn Any> =
Box::new(exec_scopes.get::<ExecutionHelperWrapper<PCS>>(vars::scopes::EXECUTION_HELPER)?);
let deprecated_syscall_handler: Box<dyn Any> =
Expand Down
11 changes: 6 additions & 5 deletions crates/starknet-os/src/hints/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::{HashMap, HashSet};
use std::marker::PhantomData;
use std::sync::Arc;

use cairo_lang_casm::hints::{Hint, StarknetHint};
use cairo_lang_casm::operand::{BinOpOperand, DerefOrImmediate, Operation, Register, ResOperand};
Expand Down Expand Up @@ -471,9 +472,9 @@ pub fn initialize_state_changes(
_ap_tracking: &ApTracking,
_constants: &HashMap<String, Felt252>,
) -> Result<(), HintError> {
let os_input = exec_scopes.get::<StarknetOsInput>(vars::scopes::OS_INPUT)?;
let os_input = exec_scopes.get::<Arc<StarknetOsInput>>(vars::scopes::OS_INPUT)?;
let mut state_dict: HashMap<MaybeRelocatable, MaybeRelocatable> = HashMap::new();
for (addr, contract_state) in os_input.contracts {
for (addr, contract_state) in &os_input.contracts {
let change_base = vm.add_memory_segment();
vm.insert_value(change_base, Felt252::from_bytes_be_slice(&contract_state.contract_hash))?;
let storage_commitment_base = vm.add_memory_segment();
Expand All @@ -496,9 +497,9 @@ pub fn initialize_class_hashes(
_ap_tracking: &ApTracking,
_constants: &HashMap<String, Felt252>,
) -> Result<(), HintError> {
let os_input = exec_scopes.get::<StarknetOsInput>(vars::scopes::OS_INPUT)?;
let os_input = exec_scopes.get::<Arc<StarknetOsInput>>(vars::scopes::OS_INPUT)?;
let mut class_dict: HashMap<MaybeRelocatable, MaybeRelocatable> = HashMap::new();
for (class_hash, compiled_class_hash) in os_input.class_hash_to_compiled_class_hash {
for (class_hash, compiled_class_hash) in &os_input.class_hash_to_compiled_class_hash {
class_dict.insert(MaybeRelocatable::from(class_hash), MaybeRelocatable::from(compiled_class_hash));
}

Expand Down Expand Up @@ -699,7 +700,7 @@ pub fn os_input_transactions(
_ap_tracking: &ApTracking,
_constants: &HashMap<String, Felt252>,
) -> Result<(), HintError> {
let os_input = exec_scopes.get::<StarknetOsInput>(vars::scopes::OS_INPUT)?;
let os_input = exec_scopes.get::<Arc<StarknetOsInput>>(vars::scopes::OS_INPUT)?;
let num_txns = os_input.transactions.len();
vm.insert_value((vm.get_fp() + 12)?, num_txns).map_err(HintError::Memory)
}
7 changes: 4 additions & 3 deletions crates/starknet-os/src/hints/os.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::sync::Arc;

use cairo_vm::hint_processor::builtin_hint_processor::hint_utils::insert_value_into_ap;
use cairo_vm::hint_processor::hint_processor_definition::HintReference;
Expand All @@ -22,7 +23,7 @@ pub fn write_full_output_to_mem(
_ap_tracking: &ApTracking,
_constants: &HashMap<String, Felt252>,
) -> Result<(), HintError> {
let os_input: &StarknetOsInput = exec_scopes.get_ref(vars::scopes::OS_INPUT)?;
let os_input: Arc<StarknetOsInput> = exec_scopes.get(vars::scopes::OS_INPUT)?;
let full_output = os_input.full_output;

vm.insert_value((vm.get_fp() + 19)?, Felt252::from(full_output)).map_err(HintError::Memory)
Expand Down Expand Up @@ -63,7 +64,7 @@ pub fn set_ap_to_prev_block_hash(
_ap_tracking: &ApTracking,
_constants: &HashMap<String, Felt252>,
) -> Result<(), HintError> {
let os_input: &StarknetOsInput = exec_scopes.get_ref(vars::scopes::OS_INPUT)?;
let os_input: Arc<StarknetOsInput> = exec_scopes.get(vars::scopes::OS_INPUT)?;
insert_value_into_ap(vm, os_input.prev_block_hash)?;

Ok(())
Expand All @@ -78,7 +79,7 @@ pub fn set_ap_to_new_block_hash(
_ap_tracking: &ApTracking,
_constants: &HashMap<String, Felt252>,
) -> Result<(), HintError> {
let os_input: &StarknetOsInput = exec_scopes.get_ref(vars::scopes::OS_INPUT)?;
let os_input: Arc<StarknetOsInput> = exec_scopes.get(vars::scopes::OS_INPUT)?;
insert_value_into_ap(vm, os_input.new_block_hash)?;

Ok(())
Expand Down
19 changes: 12 additions & 7 deletions crates/starknet-os/src/hints/state.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::sync::Arc;

use cairo_vm::hint_processor::builtin_hint_processor::hint_utils::{
get_integer_from_var_name, get_ptr_from_var_name, get_relocatable_from_var_name, insert_value_from_var_name,
Expand Down Expand Up @@ -51,7 +52,7 @@ pub fn set_preimage_for_state_commitments(
ap_tracking: &ApTracking,
constants: &HashMap<String, Felt252>,
) -> Result<(), HintError> {
let os_input = exec_scopes.get::<StarknetOsInput>(vars::scopes::OS_INPUT)?;
let os_input = exec_scopes.get::<Arc<StarknetOsInput>>(vars::scopes::OS_INPUT)?;
insert_value_from_var_name(
vars::ids::INITIAL_ROOT,
os_input.contract_state_commitment_info.previous_root,
Expand All @@ -67,7 +68,8 @@ pub fn set_preimage_for_state_commitments(
ap_tracking,
)?;

let preimage = os_input.contract_state_commitment_info.commitment_facts;
// TODO: can we avoid this clone?
let preimage = os_input.contract_state_commitment_info.commitment_facts.clone();
exec_scopes.insert_value(vars::scopes::PREIMAGE, preimage);

let merkle_height = get_constant(vars::constants::MERKLE_HEIGHT, constants)?;
Expand All @@ -94,7 +96,7 @@ pub fn set_preimage_for_class_commitments(
ap_tracking: &ApTracking,
constants: &HashMap<String, Felt252>,
) -> Result<(), HintError> {
let os_input = exec_scopes.get::<StarknetOsInput>(vars::scopes::OS_INPUT)?;
let os_input = exec_scopes.get::<Arc<StarknetOsInput>>(vars::scopes::OS_INPUT)?;
insert_value_from_var_name(
vars::ids::INITIAL_ROOT,
os_input.contract_class_commitment_info.previous_root,
Expand All @@ -113,7 +115,8 @@ pub fn set_preimage_for_class_commitments(
log::debug!("Setting class trie mode");
exec_scopes.data[0].insert(vars::scopes::PATRICIA_TREE_MODE.to_string(), any_box!(PatriciaTreeMode::Class));

let preimage = os_input.contract_class_commitment_info.commitment_facts;
// TODO: can we avoid this clone?
let preimage = os_input.contract_class_commitment_info.commitment_facts.clone();
exec_scopes.insert_value(vars::scopes::PREIMAGE, preimage);

let merkle_height = get_constant(vars::constants::MERKLE_HEIGHT, constants)?;
Expand Down Expand Up @@ -163,6 +166,7 @@ pub fn set_preimage_for_current_commitment_info(
ap_tracking,
)?;

// TODO: can we avoid this clone?
let preimage = commitment_info.commitment_facts.clone();

let merkle_height = get_constant(vars::constants::MERKLE_HEIGHT, constants)?;
Expand Down Expand Up @@ -247,6 +251,7 @@ pub fn load_bottom(
let edge = get_relocatable_from_var_name(vars::ids::EDGE, vm, ids_data, ap_tracking)?;
let edge_bottom = vm.get_integer((edge + NodeEdge::bottom_offset())?)?;

// TODO: avoid clone here
let preimage: Preimage = exec_scopes.get(vars::scopes::PREIMAGE)?;
let preimage_vec = preimage
.get(&edge_bottom)
Expand Down Expand Up @@ -320,7 +325,7 @@ where
PCS: PerContractStorage + 'static,
{
let execution_helper: ExecutionHelperWrapper<PCS> = exec_scopes.get(vars::scopes::EXECUTION_HELPER)?;
let os_input: StarknetOsInput = exec_scopes.get(vars::scopes::OS_INPUT)?;
let os_input = exec_scopes.get::<Arc<StarknetOsInput>>(vars::scopes::OS_INPUT)?;

let commitment_info_by_address = execute_coroutine(execution_helper.compute_storage_commitments())??;

Expand Down Expand Up @@ -411,7 +416,7 @@ mod tests {
]);

let mut exec_scopes: ExecutionScopes = Default::default();
exec_scopes.insert_value(vars::scopes::OS_INPUT, os_input);
exec_scopes.insert_value(vars::scopes::OS_INPUT, Arc::new(os_input));

set_preimage_for_state_commitments(&mut vm, &mut exec_scopes, &ids_data, &ap_tracking, &constants).unwrap();

Expand Down Expand Up @@ -443,7 +448,7 @@ mod tests {
]);

let mut exec_scopes: ExecutionScopes = Default::default();
exec_scopes.insert_value(vars::scopes::OS_INPUT, os_input);
exec_scopes.insert_value(vars::scopes::OS_INPUT, Arc::new(os_input));

set_preimage_for_class_commitments(&mut vm, &mut exec_scopes, &ids_data, &ap_tracking, &constants).unwrap();

Expand Down
2 changes: 1 addition & 1 deletion crates/starknet-os/src/io/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::starknet::business_logic::fact_state::contract_state_objects::Contrac
use crate::starknet::starknet_storage::CommitmentInfo;
use crate::utils::Felt252HexNoPrefix;

#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Deserialize, Serialize)]
pub struct StarknetOsInput {
pub contract_state_commitment_info: CommitmentInfo,
pub contract_class_commitment_info: CommitmentInfo,
Expand Down
4 changes: 3 additions & 1 deletion crates/starknet-os/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use blockifier::context::BlockContext;
use cairo_vm::cairo_run::CairoRunConfig;
use cairo_vm::types::layout_name::LayoutName;
Expand Down Expand Up @@ -32,7 +34,7 @@ pub mod utils;
pub fn run_os<PCS>(
compiled_os: &[u8],
layout: LayoutName,
os_input: StarknetOsInput,
os_input: Arc<StarknetOsInput>,
block_context: BlockContext,
execution_helper: ExecutionHelperWrapper<PCS>,
) -> Result<(CairoPie, StarknetOsOutput), SnOsError>
Expand Down
7 changes: 4 additions & 3 deletions tests/integration/common/block_utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use blockifier::context::BlockContext;
use blockifier::execution::contract_class::ContractClass::{V0, V1};
Expand Down Expand Up @@ -33,7 +34,7 @@ pub async fn os_hints<S>(
deprecated_compiled_classes: HashMap<ClassHash, GenericDeprecatedCompiledClass>,
compiled_classes: HashMap<ClassHash, GenericCasmContractClass>,
declared_class_hash_to_component_hashes: HashMap<ClassHash, ContractClassComponentHashes>,
) -> (StarknetOsInput, ExecutionHelperWrapper<OsSingleStarknetStorage<S, PedersenHash>>)
) -> (Arc<StarknetOsInput>, ExecutionHelperWrapper<OsSingleStarknetStorage<S, PedersenHash>>)
where
S: Storage,
{
Expand Down Expand Up @@ -198,7 +199,7 @@ where
.map(|(class_hash, components)| (class_hash.0, components.to_vec()))
.collect();

let os_input = StarknetOsInput {
let os_input = Arc::new(StarknetOsInput {
contract_state_commitment_info,
contract_class_commitment_info,
deprecated_compiled_classes,
Expand All @@ -213,7 +214,7 @@ where
new_block_hash: Default::default(),
prev_block_hash: Default::default(),
full_output: false,
};
});

let execution_helper = ExecutionHelperWrapper::new(
contract_storage_map,
Expand Down
3 changes: 2 additions & 1 deletion tests/integration/common/transaction_utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::sync::Arc;

use blockifier::abi::abi_utils::selector_from_name;
use blockifier::context::BlockContext;
Expand Down Expand Up @@ -830,7 +831,7 @@ async fn execute_txs<S>(
deprecated_contract_classes: HashMap<ClassHash, GenericDeprecatedCompiledClass>,
contract_classes: HashMap<ClassHash, GenericCasmContractClass>,
declared_class_hash_to_component_hashes: HashMap<ClassHash, ContractClassComponentHashes>,
) -> (StarknetOsInput, ExecutionHelperWrapper<OsSingleStarknetStorage<S, PedersenHash>>)
) -> (Arc<StarknetOsInput>, ExecutionHelperWrapper<OsSingleStarknetStorage<S, PedersenHash>>)
where
S: Storage,
{
Expand Down

0 comments on commit 3ddd5ef

Please sign in to comment.