From 104355171938e99ad3567aebbfad50d1ce3b5af3 Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Sun, 20 Aug 2023 17:42:06 +0200 Subject: [PATCH] refactor: optimize trace identifiers --- crates/chisel/src/dispatcher.rs | 14 +- crates/cli/src/utils/cmd.rs | 13 +- crates/evm/src/trace/decoder.rs | 74 +++--- crates/evm/src/trace/identifier/etherscan.rs | 10 +- crates/evm/src/trace/identifier/local.rs | 49 ++-- crates/evm/src/trace/identifier/mod.rs | 8 +- crates/evm/src/trace/mod.rs | 63 +++-- crates/evm/src/trace/utils.rs | 6 +- crates/forge/bin/cmd/script/mod.rs | 11 +- crates/forge/bin/cmd/test/mod.rs | 237 +++++++++---------- crates/forge/src/result.rs | 22 +- 11 files changed, 275 insertions(+), 232 deletions(-) diff --git a/crates/chisel/src/dispatcher.rs b/crates/chisel/src/dispatcher.rs index 1115514749b9..9aa4509f7884 100644 --- a/crates/chisel/src/dispatcher.rs +++ b/crates/chisel/src/dispatcher.rs @@ -958,13 +958,13 @@ impl ChiselDispatcher { session_config.evm_opts.get_remote_chain_id(), )?; - let mut decoder = - CallTraceDecoderBuilder::new().with_labels(result.labeled_addresses.clone()).build(); - - decoder.add_signature_identifier(SignaturesIdentifier::new( - Config::foundry_cache_dir(), - session_config.foundry_config.offline, - )?); + let mut decoder = CallTraceDecoderBuilder::new() + .with_labels(result.labeled_addresses.iter().map(|(a, s)| (*a, s.clone()))) + .with_signature_identifier(SignaturesIdentifier::new( + Config::foundry_cache_dir(), + session_config.foundry_config.offline, + )?) + .build(); for (_, trace) in &mut result.traces { // decoder.identify(trace, &mut local_identifier); diff --git a/crates/cli/src/utils/cmd.rs b/crates/cli/src/utils/cmd.rs index e41790a27fcf..4323ddb2c646 100644 --- a/crates/cli/src/utils/cmd.rs +++ b/crates/cli/src/utils/cmd.rs @@ -378,12 +378,13 @@ pub async fn handle_traces( None }); - let mut decoder = CallTraceDecoderBuilder::new().with_labels(labeled_addresses).build(); - - decoder.add_signature_identifier(SignaturesIdentifier::new( - Config::foundry_cache_dir(), - config.offline, - )?); + let mut decoder = CallTraceDecoderBuilder::new() + .with_labels(labeled_addresses) + .with_signature_identifier(SignaturesIdentifier::new( + Config::foundry_cache_dir(), + config.offline, + )?) + .build(); for (_, trace) in &mut result.traces { decoder.identify(trace, &mut etherscan_identifier); diff --git a/crates/evm/src/trace/decoder.rs b/crates/evm/src/trace/decoder.rs index c2a51c6dad77..39c00e38a4d6 100644 --- a/crates/evm/src/trace/decoder.rs +++ b/crates/evm/src/trace/decoder.rs @@ -1,5 +1,5 @@ use super::{ - identifier::{SingleSignaturesIdentifier, TraceIdentifier}, + identifier::{AddressIdentity, SingleSignaturesIdentifier, TraceIdentifier}, CallTraceArena, RawOrDecodedCall, RawOrDecodedLog, RawOrDecodedReturnData, }; use crate::{ @@ -20,22 +20,27 @@ use std::collections::{BTreeMap, HashMap}; /// Build a new [CallTraceDecoder]. #[derive(Default)] +#[must_use = "builders do nothing unless you call `build` on them"] pub struct CallTraceDecoderBuilder { decoder: CallTraceDecoder, } impl CallTraceDecoderBuilder { + /// Create a new builder. + #[inline] pub fn new() -> Self { Self { decoder: CallTraceDecoder::new().clone() } } /// Add known labels to the decoder. + #[inline] pub fn with_labels(mut self, labels: impl IntoIterator) -> Self { self.decoder.labels.extend(labels); self } /// Add known events to the decoder. + #[inline] pub fn with_events(mut self, events: impl IntoIterator) -> Self { for event in events { self.decoder @@ -48,12 +53,21 @@ impl CallTraceDecoderBuilder { } /// Sets the verbosity level of the decoder. + #[inline] pub fn with_verbosity(mut self, level: u8) -> Self { self.decoder.verbosity = level; self } + /// Sets the signature identifier for events and functions. + #[inline] + pub fn with_signature_identifier(mut self, identifier: SingleSignaturesIdentifier) -> Self { + self.decoder.signature_identifier = Some(identifier); + self + } + /// Build the decoder. + #[inline] pub fn build(self) -> CallTraceDecoder { self.decoder } @@ -168,23 +182,26 @@ impl CallTraceDecoder { } } - pub fn add_signature_identifier(&mut self, identifier: SingleSignaturesIdentifier) { - self.signature_identifier = Some(identifier); - } - /// Identify unknown addresses in the specified call trace using the specified identifier. /// /// Unknown contracts are contracts that either lack a label or an ABI. + #[inline] pub fn identify(&mut self, trace: &CallTraceArena, identifier: &mut impl TraceIdentifier) { - let unidentified_addresses = trace - .addresses() - .into_iter() - .filter(|(address, _)| { - !self.labels.contains_key(address) || !self.contracts.contains_key(address) - }) - .collect(); - - identifier.identify_addresses(unidentified_addresses).iter().for_each(|identity| { + self.collect_identities(identifier.identify_addresses(self.addresses(trace))); + } + + #[inline(always)] + fn addresses<'a>( + &'a self, + trace: &'a CallTraceArena, + ) -> impl Iterator)> + 'a { + trace.addresses().into_iter().filter(|&(address, _)| { + !self.labels.contains_key(address) || !self.contracts.contains_key(address) + }) + } + + fn collect_identities(&mut self, identities: Vec) { + for identity in identities { let address = identity.address; if let Some(contract) = &identity.contract { @@ -197,30 +214,31 @@ impl CallTraceDecoder { if let Some(abi) = &identity.abi { // Store known functions for the address - abi.functions() - .map(|func| (func.short_signature(), func.clone())) - .for_each(|(sig, func)| self.functions.entry(sig).or_default().push(func)); + for function in abi.functions() { + self.functions + .entry(function.short_signature()) + .or_default() + .push(function.clone()) + } // Flatten events from all ABIs - abi.events() - .map(|event| ((event.signature(), indexed_inputs(event)), event.clone())) - .for_each(|(sig, event)| { - self.events.entry(sig).or_default().push(event); - }); + for event in abi.events() { + let sig = (event.signature(), indexed_inputs(event)); + self.events.entry(sig).or_default().push(event.clone()); + } // Flatten errors from all ABIs - abi.errors().for_each(|error| { - let entry = self.errors.errors.entry(error.name.clone()).or_default(); - entry.push(error.clone()); - }); + for error in abi.errors() { + self.errors.errors.entry(error.name.clone()).or_default().push(error.clone()); + } self.receive_contracts.entry(address).or_insert(abi.receive); } - }); + } } pub async fn decode(&self, traces: &mut CallTraceArena) { - for node in traces.arena.iter_mut() { + for node in &mut traces.arena { // Set contract name if let Some(contract) = self.contracts.get(&node.trace.address).cloned() { node.trace.contract = Some(contract); diff --git a/crates/evm/src/trace/identifier/etherscan.rs b/crates/evm/src/trace/identifier/etherscan.rs index e13570aa2448..b9bdf43feb10 100644 --- a/crates/evm/src/trace/identifier/etherscan.rs +++ b/crates/evm/src/trace/identifier/etherscan.rs @@ -100,11 +100,11 @@ impl EtherscanIdentifier { } impl TraceIdentifier for EtherscanIdentifier { - fn identify_addresses( - &mut self, - addresses: Vec<(&Address, Option<&[u8]>)>, - ) -> Vec { - trace!(target: "etherscanidentifier", "identify {} addresses", addresses.len()); + fn identify_addresses<'a, A>(&mut self, addresses: A) -> Vec + where + A: Iterator)>, + { + trace!(target: "etherscanidentifier", "identify {:?} addresses", addresses.size_hint().1); let Some(client) = self.client.clone() else { // no client was configured diff --git a/crates/evm/src/trace/identifier/local.rs b/crates/evm/src/trace/identifier/local.rs index 9f07720ab145..ec8fb0f668ca 100644 --- a/crates/evm/src/trace/identifier/local.rs +++ b/crates/evm/src/trace/identifier/local.rs @@ -1,56 +1,45 @@ use super::{AddressIdentity, TraceIdentifier}; -use ethers::{ - abi::{Abi, Address, Event}, - prelude::ArtifactId, -}; +use ethers::abi::{Address, Event}; use foundry_common::contracts::{diff_score, ContractsByArtifact}; -use itertools::Itertools; use ordered_float::OrderedFloat; -use std::{borrow::Cow, collections::BTreeMap}; +use std::borrow::Cow; /// A trace identifier that tries to identify addresses using local contracts. -pub struct LocalTraceIdentifier { - local_contracts: BTreeMap, (ArtifactId, Abi)>, +pub struct LocalTraceIdentifier<'a> { + known_contracts: &'a ContractsByArtifact, } -impl LocalTraceIdentifier { - pub fn new(known_contracts: &ContractsByArtifact) -> Self { - Self { - local_contracts: known_contracts - .iter() - .map(|(id, (abi, runtime_code))| (runtime_code.clone(), (id.clone(), abi.clone()))) - .collect(), - } +impl<'a> LocalTraceIdentifier<'a> { + pub fn new(known_contracts: &'a ContractsByArtifact) -> Self { + Self { known_contracts } } /// Get all the events of the local contracts. pub fn events(&self) -> impl Iterator { - self.local_contracts.iter().flat_map(|(_, (_, abi))| abi.events()) + self.known_contracts.iter().flat_map(|(_, (abi, _))| abi.events()) } } -impl TraceIdentifier for LocalTraceIdentifier { - fn identify_addresses( - &mut self, - addresses: Vec<(&Address, Option<&[u8]>)>, - ) -> Vec { +impl TraceIdentifier for LocalTraceIdentifier<'_> { + fn identify_addresses<'a, A>(&mut self, addresses: A) -> Vec + where + A: Iterator)>, + { addresses - .into_iter() .filter_map(|(address, code)| { let code = code?; - let (_, (_, (id, abi))) = self - .local_contracts + let (_, id, abi) = self + .known_contracts .iter() - .filter_map(|entry| { - let score = diff_score(entry.0, code); + .filter_map(|(id, (abi, known_code))| { + let score = diff_score(known_code, code); if score < 0.1 { - Some((OrderedFloat(score), entry)) + Some((OrderedFloat(score), id, abi)) } else { None } }) - .sorted_by_key(|(score, _)| *score) - .next()?; + .min_by_key(|(score, _, _)| *score)?; Some(AddressIdentity { address: *address, diff --git a/crates/evm/src/trace/identifier/mod.rs b/crates/evm/src/trace/identifier/mod.rs index 6fca5d10a274..dd755ee9082d 100644 --- a/crates/evm/src/trace/identifier/mod.rs +++ b/crates/evm/src/trace/identifier/mod.rs @@ -33,9 +33,7 @@ pub struct AddressIdentity<'a> { pub trait TraceIdentifier { // TODO: Update docs /// Attempts to identify an address in one or more call traces. - #[allow(clippy::type_complexity)] - fn identify_addresses( - &mut self, - addresses: Vec<(&Address, Option<&[u8]>)>, - ) -> Vec; + fn identify_addresses<'a, A>(&mut self, addresses: A) -> Vec + where + A: Iterator)>; } diff --git a/crates/evm/src/trace/mod.rs b/crates/evm/src/trace/mod.rs index 23b5d1fa47a6..af490b6448c8 100644 --- a/crates/evm/src/trace/mod.rs +++ b/crates/evm/src/trace/mod.rs @@ -561,13 +561,39 @@ impl fmt::Display for CallTrace { } /// Specifies the kind of trace. -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum TraceKind { Deployment, Setup, Execution, } +impl TraceKind { + /// Returns `true` if the trace kind is [`Deployment`]. + /// + /// [`Deployment`]: TraceKind::Deployment + #[must_use] + pub fn is_deployment(self) -> bool { + matches!(self, Self::Deployment) + } + + /// Returns `true` if the trace kind is [`Setup`]. + /// + /// [`Setup`]: TraceKind::Setup + #[must_use] + pub fn is_setup(self) -> bool { + matches!(self, Self::Setup) + } + + /// Returns `true` if the trace kind is [`Execution`]. + /// + /// [`Execution`]: TraceKind::Execution + #[must_use] + pub fn is_execution(self) -> bool { + matches!(self, Self::Execution) + } +} + /// Chooses the color of the trace depending on the destination address and status of the call. fn trace_color(trace: &CallTrace) -> Color { if trace.address == CHEATCODE_ADDRESS { @@ -584,26 +610,23 @@ pub fn load_contracts( traces: Traces, known_contracts: Option<&ContractsByArtifact>, ) -> ContractsByAddress { - if let Some(contracts) = known_contracts { - let mut local_identifier = LocalTraceIdentifier::new(contracts); - let mut decoder = CallTraceDecoderBuilder::new().build(); - for (_, trace) in &traces { - decoder.identify(trace, &mut local_identifier); - } - - decoder - .contracts - .iter() - .filter_map(|(addr, name)| { - if let Ok(Some((_, (abi, _)))) = contracts.find_by_name_or_identifier(name) { - return Some((*addr, (name.clone(), abi.clone()))) - } - None - }) - .collect() - } else { - BTreeMap::new() + let Some(contracts) = known_contracts else { return BTreeMap::new() }; + let mut local_identifier = LocalTraceIdentifier::new(contracts); + let mut decoder = CallTraceDecoderBuilder::new().build(); + for (_, trace) in &traces { + decoder.identify(trace, &mut local_identifier); } + + decoder + .contracts + .iter() + .filter_map(|(addr, name)| { + if let Ok(Some((_, (abi, _)))) = contracts.find_by_name_or_identifier(name) { + return Some((*addr, (name.clone(), abi.clone()))) + } + None + }) + .collect() } /// creates the memory data in 32byte chunks diff --git a/crates/evm/src/trace/utils.rs b/crates/evm/src/trace/utils.rs index e65d6c30346e..14d4f245d2f3 100644 --- a/crates/evm/src/trace/utils.rs +++ b/crates/evm/src/trace/utils.rs @@ -79,7 +79,7 @@ pub(crate) fn decode_cheatcode_inputs( "serializeBytes32" | "serializeString" | "serializeBytes" => { - if verbosity == 5 { + if verbosity >= 5 { None } else { let mut decoded = func.decode_input(&data[SELECTOR_LEN..]).ok()?; @@ -111,10 +111,10 @@ pub(crate) fn decode_cheatcode_outputs( // redacts derived private key return Some("".to_string()) } - if func.name == "parseJson" && verbosity != 5 { + if func.name == "parseJson" && verbosity < 5 { return Some("".to_string()) } - if func.name == "readFile" && verbosity != 5 { + if func.name == "readFile" && verbosity < 5 { return Some("".to_string()) } None diff --git a/crates/forge/bin/cmd/script/mod.rs b/crates/forge/bin/cmd/script/mod.rs index 1825644c3129..4e10066332f3 100644 --- a/crates/forge/bin/cmd/script/mod.rs +++ b/crates/forge/bin/cmd/script/mod.rs @@ -237,15 +237,14 @@ impl ScriptArgs { let mut local_identifier = LocalTraceIdentifier::new(known_contracts); let mut decoder = CallTraceDecoderBuilder::new() - .with_labels(result.labeled_addresses.clone()) + .with_labels(result.labeled_addresses.iter().map(|(a, s)| (*a, s.clone()))) .with_verbosity(verbosity) + .with_signature_identifier(SignaturesIdentifier::new( + Config::foundry_cache_dir(), + script_config.config.offline, + )?) .build(); - decoder.add_signature_identifier(SignaturesIdentifier::new( - Config::foundry_cache_dir(), - script_config.config.offline, - )?); - // Decoding traces using etherscan is costly as we run into rate limits, // causing scripts to run for a very long time unnecesarily. // Therefore, we only try and use etherscan if the user has provided an API key. diff --git a/crates/forge/bin/cmd/test/mod.rs b/crates/forge/bin/cmd/test/mod.rs index e5a0b560acbe..3771eec91762 100644 --- a/crates/forge/bin/cmd/test/mod.rs +++ b/crates/forge/bin/cmd/test/mod.rs @@ -537,151 +537,146 @@ async fn test( if json { let results = runner.test(filter, None, test_options).await; println!("{}", serde_json::to_string(&results)?); - Ok(TestOutcome::new(results, allow_failure)) - } else { - // Set up identifiers - let mut local_identifier = LocalTraceIdentifier::new(&runner.known_contracts); - let remote_chain_id = runner.evm_opts.get_remote_chain_id(); - // Do not re-query etherscan for contracts that you've already queried today. - let mut etherscan_identifier = EtherscanIdentifier::new(&config, remote_chain_id)?; + return Ok(TestOutcome::new(results, allow_failure)) + } - // Set up test reporter channel - let (tx, rx) = channel::<(String, SuiteResult)>(); + // Set up identifiers + let known_contracts = runner.known_contracts.clone(); + let mut local_identifier = LocalTraceIdentifier::new(&known_contracts); + let remote_chain_id = runner.evm_opts.get_remote_chain_id(); + // Do not re-query etherscan for contracts that you've already queried today. + let mut etherscan_identifier = EtherscanIdentifier::new(&config, remote_chain_id)?; - // Run tests - let handle = - tokio::task::spawn(async move { runner.test(filter, Some(tx), test_options).await }); + // Set up test reporter channel + let (tx, rx) = channel::<(String, SuiteResult)>(); - let mut results: BTreeMap = BTreeMap::new(); - let mut gas_report = GasReport::new(config.gas_reports, config.gas_reports_ignore); - let sig_identifier = - SignaturesIdentifier::new(Config::foundry_cache_dir(), config.offline)?; + // Run tests + let handle = + tokio::task::spawn(async move { runner.test(filter, Some(tx), test_options).await }); - let mut total_passed = 0; - let mut total_failed = 0; - let mut total_skipped = 0; + let mut results: BTreeMap = BTreeMap::new(); + let mut gas_report = GasReport::new(config.gas_reports, config.gas_reports_ignore); + let sig_identifier = SignaturesIdentifier::new(Config::foundry_cache_dir(), config.offline)?; - 'outer: for (contract_name, suite_result) in rx { - results.insert(contract_name.clone(), suite_result.clone()); + let mut total_passed = 0; + let mut total_failed = 0; + let mut total_skipped = 0; - let mut tests = suite_result.test_results.clone(); - println!(); - for warning in suite_result.warnings.iter() { - eprintln!("{} {warning}", Paint::yellow("Warning:").bold()); - } - if !tests.is_empty() { - let term = if tests.len() > 1 { "tests" } else { "test" }; - println!("Running {} {term} for {contract_name}", tests.len()); - } - for (name, result) in &mut tests { - short_test_result(name, result); + 'outer: for (contract_name, suite_result) in rx { + results.insert(contract_name.clone(), suite_result.clone()); - // If the test failed, we want to stop processing the rest of the tests - if fail_fast && result.status == TestStatus::Failure { - break 'outer - } + let mut tests = suite_result.test_results.clone(); + println!(); + for warning in suite_result.warnings.iter() { + eprintln!("{} {warning}", Paint::yellow("Warning:").bold()); + } + if !tests.is_empty() { + let term = if tests.len() > 1 { "tests" } else { "test" }; + println!("Running {} {term} for {contract_name}", tests.len()); + } + for (name, result) in &mut tests { + short_test_result(name, result); - // We only display logs at level 2 and above - if verbosity >= 2 { - // We only decode logs from Hardhat and DS-style console events - let console_logs = decode_console_logs(&result.logs); - if !console_logs.is_empty() { - println!("Logs:"); - for log in console_logs { - println!(" {log}"); - } - println!(); + // If the test failed, we want to stop processing the rest of the tests + if fail_fast && result.status == TestStatus::Failure { + break 'outer + } + + // We only display logs at level 2 and above + if verbosity >= 2 { + // We only decode logs from Hardhat and DS-style console events + let console_logs = decode_console_logs(&result.logs); + if !console_logs.is_empty() { + println!("Logs:"); + for log in console_logs { + println!(" {log}"); } + println!(); } + } - if !result.traces.is_empty() { - // Identify addresses in each trace - let mut decoder = CallTraceDecoderBuilder::new() - .with_labels(result.labeled_addresses.clone()) - .with_events(local_identifier.events().cloned()) - .with_verbosity(verbosity) - .build(); - - // Signatures are of no value for gas reports - if !gas_reporting { - decoder.add_signature_identifier(sig_identifier.clone()); - } + if result.traces.is_empty() { + continue + } - // Decode the traces - let mut decoded_traces = Vec::new(); - for (kind, trace) in &mut result.traces { - decoder.identify(trace, &mut local_identifier); - decoder.identify(trace, &mut etherscan_identifier); - - let should_include = match kind { - // At verbosity level 3, we only display traces for failed tests - // At verbosity level 4, we also display the setup trace for failed - // tests At verbosity level 5, we display - // all traces for all tests - TraceKind::Setup => { - (verbosity >= 5) || - (verbosity == 4 && result.status == TestStatus::Failure) - } - TraceKind::Execution => { - verbosity > 3 || - (verbosity == 3 && result.status == TestStatus::Failure) - } - _ => false, - }; - - // We decode the trace if we either need to build a gas report or we need - // to print it - if should_include || gas_reporting { - decoder.decode(trace).await; - } - - if should_include { - decoded_traces.push(trace.to_string()); - } - } + // Identify addresses in each trace + let mut builder = CallTraceDecoderBuilder::new() + .with_labels(result.labeled_addresses.iter().map(|(a, s)| (*a, s.clone()))) + .with_events(local_identifier.events().cloned()) + .with_verbosity(verbosity); - if !decoded_traces.is_empty() { - println!("Traces:"); - decoded_traces.into_iter().for_each(|trace| println!("{trace}")); - } + // Signatures are of no value for gas reports + if !gas_reporting { + builder = builder.with_signature_identifier(sig_identifier.clone()); + } - if gas_reporting { - gas_report.analyze(&result.traces); + let mut decoder = builder.build(); + + // Decode the traces + let mut decoded_traces = Vec::with_capacity(result.traces.len()); + for (kind, trace) in &mut result.traces { + decoder.identify(trace, &mut local_identifier); + decoder.identify(trace, &mut etherscan_identifier); + + // verbosity: + // - 0..3: nothing + // - 3: only display traces for failed tests + // - 4: also display the setup trace for failed tests + // - 5..: display all traces for all tests + let should_include = match kind { + TraceKind::Execution => { + (verbosity == 3 && result.status.is_failure()) || verbosity >= 4 + } + TraceKind::Setup => { + (verbosity == 4 && result.status.is_failure()) || verbosity >= 5 } + TraceKind::Deployment => false, + }; + + // Decode the trace if we either need to build a gas report or we need to print it + if should_include || gas_reporting { + decoder.decode(trace).await; + } + + if should_include { + decoded_traces.push(trace.to_string()); } } - let block_outcome = - TestOutcome::new([(contract_name, suite_result)].into(), allow_failure); - total_passed += block_outcome.successes().count(); - total_failed += block_outcome.failures().count(); - total_skipped += block_outcome.skips().count(); + if !decoded_traces.is_empty() { + println!("Traces:"); + decoded_traces.into_iter().for_each(|trace| println!("{trace}")); + } - println!("{}", block_outcome.summary()); + if gas_reporting { + gas_report.analyze(&result.traces); + } } + let block_outcome = TestOutcome::new([(contract_name, suite_result)].into(), allow_failure); - if gas_reporting { - println!("{}", gas_report.finalize()); - } + total_passed += block_outcome.successes().count(); + total_failed += block_outcome.failures().count(); + total_skipped += block_outcome.skips().count(); - let num_test_suites = results.len(); + println!("{}", block_outcome.summary()); + } - if num_test_suites > 0 { - println!( - "{}", - format_aggregated_summary( - num_test_suites, - total_passed, - total_failed, - total_skipped - ) - ); - } + if gas_reporting { + println!("{}", gas_report.finalize()); + } - // reattach the thread - let _results = handle.await?; + let num_test_suites = results.len(); - trace!(target: "forge::test", "received {} results", results.len()); - Ok(TestOutcome::new(results, allow_failure)) + if num_test_suites > 0 { + println!( + "{}", + format_aggregated_summary(num_test_suites, total_passed, total_failed, total_skipped) + ); } + + // reattach the thread + let _results = handle.await?; + + trace!(target: "forge::test", "received {} results", results.len()); + Ok(TestOutcome::new(results, allow_failure)) } diff --git a/crates/forge/src/result.rs b/crates/forge/src/result.rs index 1199a4d28f55..c6363e0695ce 100644 --- a/crates/forge/src/result.rs +++ b/crates/forge/src/result.rs @@ -58,7 +58,7 @@ impl SuiteResult { } } -#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] pub enum TestStatus { Success, #[default] @@ -66,6 +66,26 @@ pub enum TestStatus { Skipped, } +impl TestStatus { + /// Returns `true` if the test was successful. + #[inline] + pub fn is_success(self) -> bool { + matches!(self, Self::Success) + } + + /// Returns `true` if the test failed. + #[inline] + pub fn is_failure(self) -> bool { + matches!(self, Self::Failure) + } + + /// Returns `true` if the test was skipped. + #[inline] + pub fn is_skipped(self) -> bool { + matches!(self, Self::Skipped) + } +} + /// The result of an executed solidity test #[derive(Clone, Debug, Default, Serialize, Deserialize)] pub struct TestResult {