diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4708e57..efb51c7 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -27,4 +27,4 @@ jobs: - name: "Build all feature combinations" run: | - cargo hack build --package ${{ matrix.package }} --feature-powerset + cargo hack build --package ${{ matrix.package }} --feature-powerset --no-dev-deps --depth 3 diff --git a/crates/starknet-types-core/Cargo.toml b/crates/starknet-types-core/Cargo.toml index 7642300..4ba946e 100644 --- a/crates/starknet-types-core/Cargo.toml +++ b/crates/starknet-types-core/Cargo.toml @@ -53,6 +53,7 @@ prime-bigint = ["dep:lazy_static"] num-traits = [] papyrus-serialization = ["std"] secret-felt = ["alloc", "dep:zeroize", "dep:subtle", "subtle/const-generics", "subtle/core_hint_black_box", "dep:rand", "rand/alloc"] +devnet = ["alloc"] [dev-dependencies] proptest = { version = "1.5", default-features = false, features = [ diff --git a/crates/starknet-types-core/src/chain_id/alloc_impls.rs b/crates/starknet-types-core/src/chain_id/alloc_impls.rs new file mode 100644 index 0000000..2bf5625 --- /dev/null +++ b/crates/starknet-types-core/src/chain_id/alloc_impls.rs @@ -0,0 +1,191 @@ +#[cfg(not(feature = "std"))] +pub extern crate alloc; +#[cfg(not(feature = "std"))] +use alloc::string::{String, ToString}; + +use crate::short_string; +use crate::short_string::ShortString; + +use super::{ChainId, SN_MAIN_STR, SN_SEPOLIA_STR}; + +impl From for ShortString { + fn from(value: ChainId) -> Self { + match value { + ChainId::Mainnet => short_string!("SN_MAIN"), + ChainId::Sepolia => short_string!("SN_SEPOLIA"), + #[cfg(feature = "devnet")] + ChainId::Devnet(ss) => ss, + } + } +} + +#[cfg(feature = "devnet")] +impl From for ChainId { + fn from(value: ShortString) -> Self { + if value.as_ref() == SN_MAIN_STR { + ChainId::Mainnet + } else if value.as_ref() == SN_SEPOLIA_STR { + ChainId::Sepolia + } else { + ChainId::Devnet(value) + } + } +} + +#[cfg(not(feature = "devnet"))] +mod try_chain_id_from_short_string { + use crate::chain_id::{SN_MAIN_STR, SN_SEPOLIA_STR}; + + use super::*; + + #[derive(Debug, Clone)] + pub struct TryChainIdFromShortStringError(ShortString); + + impl core::fmt::Display for TryChainIdFromShortStringError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "unknown chain id: {}", self.0) + } + } + + #[cfg(feature = "std")] + impl std::error::Error for TryChainIdFromShortStringError {} + + impl TryFrom for ChainId { + type Error = TryChainIdFromShortStringError; + + fn try_from(value: ShortString) -> Result { + if value.as_ref() == SN_MAIN_STR { + Ok(ChainId::Mainnet) + } else if value.as_ref() == SN_SEPOLIA_STR { + Ok(ChainId::Sepolia) + } else { + Err(TryChainIdFromShortStringError(value)) + } + } + } +} +#[cfg(not(feature = "devnet"))] +pub use try_chain_id_from_short_string::*; + +// String + +impl From for String { + fn from(value: ChainId) -> Self { + match value { + ChainId::Mainnet => SN_MAIN_STR.to_string(), + ChainId::Sepolia => SN_SEPOLIA_STR.to_string(), + #[cfg(feature = "devnet")] + ChainId::Devnet(ss) => ss.to_string(), + } + } +} + +#[cfg(not(feature = "devnet"))] +impl From for &str { + fn from(value: ChainId) -> Self { + match value { + ChainId::Mainnet => SN_MAIN_STR, + ChainId::Sepolia => SN_SEPOLIA_STR, + } + } +} + +#[derive(Debug, Clone, Copy)] +#[cfg(feature = "devnet")] +pub struct TryChainIdFromStringError(pub(super) crate::short_string::TryShortStringFromStringError); + +#[derive(Debug, Clone)] +#[cfg(not(feature = "devnet"))] +pub struct TryChainIdFromStringError(pub(super) String); + +impl core::fmt::Display for TryChainIdFromStringError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + #[cfg(feature = "devnet")] + write!(f, "failed to parse string as ShortString: {}", self.0)?; + + #[cfg(not(feature = "devnet"))] + write!(f, "unknown chain id: {}", self.0)?; + + Ok(()) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for TryChainIdFromStringError {} + +impl TryFrom for ChainId { + type Error = TryChainIdFromStringError; + + fn try_from(value: String) -> Result { + if value == SN_MAIN_STR { + return Ok(ChainId::Mainnet); + } else if value == SN_SEPOLIA_STR { + return Ok(ChainId::Sepolia); + } + + #[cfg(feature = "devnet")] + match ShortString::try_from(value) { + Ok(ss) => Ok(ChainId::Devnet(ss)), + Err(e) => Err(TryChainIdFromStringError(e)), + } + + #[cfg(not(feature = "devnet"))] + Err(TryChainIdFromStringError(value)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn short_string_and_chain_id_round_trip() { + let ss = short_string!("SN_MAIN"); + let chain_id = ChainId::try_from(ss.clone()).unwrap(); + assert_eq!(chain_id.to_string(), ss.to_string()); + + let ss = short_string!("SN_SEPOLIA"); + let chain_id = ChainId::try_from(ss.clone()).unwrap(); + assert_eq!(chain_id.to_string(), ss.to_string()); + + #[cfg(not(feature = "devnet"))] + { + let ss = short_string!("SN_DEVNET"); + assert!(ChainId::try_from(ss).is_err()); + } + #[cfg(feature = "devnet")] + { + let ss = short_string!("SN_DEVNET"); + let chain_id = ChainId::try_from(ss.clone()).unwrap(); + assert_eq!(ss.to_string(), chain_id.to_string()); + } + } + + #[test] + fn string_and_chain_id_round_trip() { + let s = String::from(SN_MAIN_STR); + let chain_id = ChainId::try_from(s.clone()).unwrap(); + assert_eq!(chain_id.to_string(), s.to_string()); + + let s = String::from(SN_SEPOLIA_STR); + let chain_id = ChainId::try_from(s.clone()).unwrap(); + assert_eq!(chain_id.to_string(), s.to_string()); + + #[cfg(not(feature = "devnet"))] + { + let s = String::from("SN_DEVNET"); + assert!(ChainId::try_from(s).is_err()); + } + #[cfg(feature = "devnet")] + { + let s = String::from("SN_DEVNET"); + let chain_id = ChainId::try_from(s.clone()).unwrap(); + assert_eq!(s, chain_id.to_string()); + + let s = String::from("SN_DEVNET_LOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOONG"); + assert!(ChainId::try_from(s).is_err()); + let s = String::from("SN_DEVNET_🌟"); + assert!(ChainId::try_from(s).is_err()); + } + } +} diff --git a/crates/starknet-types-core/src/chain_id/mod.rs b/crates/starknet-types-core/src/chain_id/mod.rs new file mode 100644 index 0000000..85a065b --- /dev/null +++ b/crates/starknet-types-core/src/chain_id/mod.rs @@ -0,0 +1,241 @@ +#[cfg(feature = "alloc")] +pub extern crate alloc; +use core::str::FromStr; + +use crate::felt::Felt; + +#[cfg(feature = "alloc")] +mod alloc_impls; +#[cfg(feature = "devnet")] +use crate::short_string::ShortString; +#[cfg(feature = "alloc")] +pub use alloc_impls::*; + +#[derive(Debug, Clone)] +pub enum ChainId { + Mainnet, + Sepolia, + #[cfg(feature = "devnet")] + Devnet(ShortString), +} + +pub const SN_MAIN_STR: &str = "SN_MAIN"; +pub const SN_MAIN: Felt = Felt::from_raw([ + 502562008147966918, + 18446744073709551615, + 18446744073709551615, + 17696389056366564951, +]); + +pub const SN_SEPOLIA_STR: &str = "SN_SEPOLIA"; +pub const SN_SEPOLIA: Felt = Felt::from_raw([ + 507980251676163170, + 18446744073709551615, + 18446744073708869172, + 1555806712078248243, +]); + +impl core::fmt::Display for ChainId { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + ChainId::Mainnet => core::fmt::Display::fmt(SN_MAIN_STR, f), + ChainId::Sepolia => core::fmt::Display::fmt(SN_SEPOLIA_STR, f), + #[cfg(feature = "devnet")] + ChainId::Devnet(ss) => core::fmt::Display::fmt(ss, f), + } + } +} + +impl AsRef for ChainId { + fn as_ref(&self) -> &str { + match self { + ChainId::Mainnet => SN_MAIN_STR, + ChainId::Sepolia => SN_SEPOLIA_STR, + #[cfg(feature = "devnet")] + ChainId::Devnet(ss) => ss.as_ref(), + } + } +} + +// Felt + +impl From for Felt { + fn from(value: ChainId) -> Self { + match value { + ChainId::Mainnet => SN_MAIN, + ChainId::Sepolia => SN_SEPOLIA, + #[cfg(feature = "devnet")] + ChainId::Devnet(id) => id.into(), + } + } +} + +#[derive(Debug, Clone, Copy)] +#[cfg(not(feature = "devnet"))] +pub struct TryChainIdFormFeltError(Felt); + +#[derive(Debug, Clone, Copy)] +#[cfg(feature = "devnet")] +pub struct TryChainIdFormFeltError(crate::short_string::TryShortStringFromFeltError, Felt); + +impl core::fmt::Display for TryChainIdFormFeltError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + #[cfg(not(feature = "devnet"))] + write!(f, "unknown chain id `{:#}`", self.0)?; + + #[cfg(feature = "devnet")] + write!( + f, + "invalid felt for chain id `{}`. Must be a valid ShortString: {}", + self.1, self.0 + )?; + + Ok(()) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for TryChainIdFormFeltError {} + +impl TryFrom for ChainId { + type Error = TryChainIdFormFeltError; + + fn try_from(value: Felt) -> Result { + if value == SN_MAIN { + return Ok(ChainId::Mainnet); + } + if value == SN_SEPOLIA { + return Ok(ChainId::Sepolia); + } + + #[cfg(feature = "devnet")] + match ShortString::try_from(value) { + Ok(ss) => Ok(ChainId::Devnet(ss)), + Err(e) => Err(TryChainIdFormFeltError(e, value)), + } + + #[cfg(not(feature = "devnet"))] + Err(TryChainIdFormFeltError(value)) + } +} + +// str + +#[derive(Debug, Clone, Copy)] +#[cfg(feature = "devnet")] +pub struct TryChainIdFromStrError(crate::short_string::TryShortStringFromStringError); + +#[derive(Debug, Clone)] +#[cfg(all(not(feature = "devnet"), feature = "alloc"))] +pub struct TryChainIdFromStrError(alloc::string::String); + +#[derive(Debug, Clone)] +#[cfg(all(not(feature = "devnet"), not(feature = "alloc")))] +pub struct TryChainIdFromStrError; + +impl core::fmt::Display for TryChainIdFromStrError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + #[cfg(feature = "devnet")] + write!(f, "failed to parse string as ShortString: {}", self.0)?; + + #[cfg(all(not(feature = "devnet"), feature = "alloc"))] + write!(f, "unknown chain id: {}", self.0)?; + + #[cfg(all(not(feature = "devnet"), not(feature = "alloc")))] + write!(f, "unknown chain id")?; + + Ok(()) + } +} + +impl FromStr for ChainId { + type Err = TryChainIdFromStrError; + + fn from_str(value: &str) -> Result { + if value == SN_MAIN_STR { + return Ok(ChainId::Mainnet); + } else if value == SN_SEPOLIA_STR { + return Ok(ChainId::Sepolia); + } + + #[cfg(feature = "devnet")] + return match ShortString::from_str(value) { + Ok(ss) => Ok(ChainId::Devnet(ss)), + Err(e) => Err(TryChainIdFromStrError(e)), + }; + + #[cfg(all(not(feature = "devnet"), feature = "alloc"))] + return Err(TryChainIdFromStrError(alloc::string::ToString::to_string( + value, + ))); + + #[cfg(all(not(feature = "devnet"), not(feature = "alloc")))] + return Err(TryChainIdFromStrError); + } +} + +#[cfg(test)] +mod tests { + use core::str::FromStr; + + use crate::{ + chain_id::{SN_MAIN_STR, SN_SEPOLIA_STR}, + felt::Felt, + }; + + use super::{ChainId, SN_MAIN, SN_SEPOLIA}; + + #[test] + fn felt_and_chain_id_round_trip() { + let felt = SN_MAIN; + let chain_id = ChainId::try_from(felt).unwrap(); + assert_eq!(Felt::from(chain_id), felt); + + let felt = SN_SEPOLIA; + let chain_id = ChainId::try_from(felt).unwrap(); + assert_eq!(Felt::from(chain_id), felt); + + #[cfg(feature = "devnet")] + { + let felt = Felt::from_hex_unwrap("0x63616665"); + let chain_id = ChainId::try_from(felt).unwrap(); + assert_eq!(Felt::from(chain_id), felt); + + // Non ascii + let felt = Felt::from_hex_unwrap("0x1234567890"); + assert!(ChainId::try_from(felt).is_err()); + // Non too long + let felt = Felt::from_hex_unwrap( + "0x6363636363636363636363636363636363636363636363636363636363636363", + ); + assert!(ChainId::try_from(felt).is_err()); + } + } + + #[test] + fn str_and_chain_id_round_trip() { + let s = SN_MAIN_STR; + let chain_id = ChainId::from_str(s).unwrap(); + assert_eq!(chain_id.as_ref(), s); + + let s = SN_SEPOLIA_STR; + let chain_id = ChainId::from_str(s).unwrap(); + assert_eq!(chain_id.as_ref(), s); + + #[cfg(not(feature = "devnet"))] + { + let s = "SN_DEVNET"; + assert!(ChainId::from_str(s).is_err()); + } + #[cfg(feature = "devnet")] + { + let s = "SN_DEVNET"; + let chain_id = ChainId::from_str(s).unwrap(); + assert_eq!(s, chain_id.as_ref()); + let s = "SN_DEVNET_LOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOONG"; + assert!(ChainId::from_str(s).is_err()); + let s = "SN_DEVNET_🌟"; + assert!(ChainId::from_str(s).is_err()); + } + } +} diff --git a/crates/starknet-types-core/src/felt/secret_felt.rs b/crates/starknet-types-core/src/felt/secret_felt.rs index a9f929c..d9a2d0c 100644 --- a/crates/starknet-types-core/src/felt/secret_felt.rs +++ b/crates/starknet-types-core/src/felt/secret_felt.rs @@ -4,7 +4,7 @@ use subtle::ConstantTimeEq; use zeroize::{Zeroize, Zeroizing}; #[cfg(not(feature = "std"))] -use super::alloc::{boxed::Box, string::String, vec::Vec}; +use super::alloc::{boxed::Box, string::String}; /// A wrapper for a [Felt] that ensures the value is securely zeroized when dropped. /// diff --git a/crates/starknet-types-core/src/lib.rs b/crates/starknet-types-core/src/lib.rs index bda7d42..8190b5d 100644 --- a/crates/starknet-types-core/src/lib.rs +++ b/crates/starknet-types-core/src/lib.rs @@ -8,6 +8,8 @@ pub mod hash; pub mod felt; pub mod qm31; -#[cfg(any(feature = "std", feature = "alloc"))] +#[cfg(feature = "alloc")] pub mod short_string; pub mod u256; + +pub mod chain_id; diff --git a/crates/starknet-types-core/src/short_string/mod.rs b/crates/starknet-types-core/src/short_string/mod.rs index 53ead6f..a8ade64 100644 --- a/crates/starknet-types-core/src/short_string/mod.rs +++ b/crates/starknet-types-core/src/short_string/mod.rs @@ -9,6 +9,8 @@ //! //! The convesion to `Felt` is done by using the internal ascii short string as bytes and parse those as a big endian number. +use core::str::FromStr; + #[cfg(not(feature = "std"))] use crate::felt::alloc::string::{String, ToString}; use crate::felt::Felt; @@ -27,6 +29,12 @@ impl core::fmt::Display for ShortString { } } +impl AsRef for ShortString { + fn as_ref(&self) -> &str { + &self.0 + } +} + impl From for Felt { fn from(ss: ShortString) -> Self { let bytes = ss.0.as_bytes(); @@ -42,7 +50,45 @@ impl From for Felt { } } -#[derive(Debug, Clone)] +#[derive(Debug, Copy, Clone)] +pub enum TryShortStringFromFeltError { + TooLong, + NonAscii, +} + +impl core::fmt::Display for TryShortStringFromFeltError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + TryShortStringFromFeltError::TooLong => "string to long", + TryShortStringFromFeltError::NonAscii => "string contains non ascii characters", + } + .fmt(f) + } +} + +impl TryFrom for ShortString { + type Error = TryShortStringFromFeltError; + + fn try_from(value: Felt) -> Result { + let bytes = value.to_bytes_be(); + if bytes[0] != 0 { + return Err(TryShortStringFromFeltError::TooLong); + } + let first_non_zero_byte = match bytes.iter().position(|&v| v != 0) { + Some(i) => i, + None => return Ok(ShortString(String::new())), + }; + if !bytes[first_non_zero_byte..].is_ascii() { + return Err(TryShortStringFromFeltError::NonAscii); + } + + let s = unsafe { str::from_utf8_unchecked(&bytes[first_non_zero_byte..]) }; + + Ok(ShortString(s.to_string())) + } +} + +#[derive(Debug, Copy, Clone)] #[cfg_attr(test, derive(PartialEq, Eq))] pub enum TryShortStringFromStringError { TooLong, @@ -101,10 +147,10 @@ impl Felt { } } -impl TryFrom<&str> for ShortString { - type Error = TryShortStringFromStringError; +impl FromStr for ShortString { + type Err = TryShortStringFromStringError; - fn try_from(value: &str) -> Result { + fn from_str(value: &str) -> Result { if value.len() > 31 { return Err(TryShortStringFromStringError::TooLong); } @@ -116,10 +162,88 @@ impl TryFrom<&str> for ShortString { } } +/// Create a `ShortString` at compile time from a string literal. +/// +/// This macro validates at compile time that the string: +/// - Contains only ASCII characters +/// - Is no longer than 31 characters +/// +/// # Panics +/// +/// Panics at compile time if the string is invalid. +/// +/// # Examples +/// +/// ``` +/// use starknet_types_core::{short_string}; +/// +/// let ss = short_string!("Hello, Cairo!"); +/// assert_eq!(ss.to_string(), "Hello, Cairo!"); +/// +/// // This would fail to compile: +/// // let ss = short_string!("This string is way too long for a Cairo short string"); +/// ``` +#[macro_export] +macro_rules! short_string { + ($s:literal) => {{ + const _: () = { + let bytes = $s.as_bytes(); + assert!( + bytes.len() <= 31, + "Short string must be at most 31 characters" + ); + assert!( + bytes.is_ascii(), + "Short string must contain only ASCII characters" + ); + }; + + // Safety: We've validated the string at compile time + match <$crate::short_string::ShortString as core::str::FromStr>::from_str($s) { + Ok(ss) => ss, + Err(_) => unreachable!("compile-time validation should prevent this"), + } + }}; +} + #[cfg(test)] mod tests { + use crate::chain_id::{SN_MAIN, SN_MAIN_STR, SN_SEPOLIA, SN_SEPOLIA_STR}; + use super::*; + #[test] + fn test_short_string_macro() { + let ss = short_string!("test"); + assert_eq!(ss.to_string(), "test"); + + let ss = short_string!("SN_MAIN"); + assert_eq!(ss.to_string(), SN_MAIN_STR); + + let ss = short_string!("This is a 31 characters string."); + assert_eq!(ss.to_string(), "This is a 31 characters string."); + + let ss = short_string!(""); + assert_eq!(ss.to_string(), ""); + } + + #[test] + fn short_string_and_felt_full_round() { + let ss1 = ShortString::from_str("A short string").unwrap(); + let f = Felt::from(ss1.clone()); + let ss2 = ShortString::try_from(f).unwrap(); + + assert_eq!(ss1, ss2); + } + + #[test] + fn chain_ids() { + let ss = ShortString::try_from(SN_MAIN).unwrap(); + assert_eq!(ss.to_string(), SN_MAIN_STR.to_string()); + let ss = ShortString::try_from(SN_SEPOLIA).unwrap(); + assert_eq!(ss.to_string(), SN_SEPOLIA_STR.to_string()); + } + #[test] fn ok() { for (string, expected_felt) in [ @@ -130,7 +254,7 @@ mod tests { Felt::from_hex_unwrap("0x617070726f7665"), ), ( - String::from("SN_SEPOLIA"), + String::from(SN_SEPOLIA_STR), Felt::from_raw([ 507980251676163170, 18446744073709551615,