diff --git a/corelib/src/integer.cairo b/corelib/src/integer.cairo index 4cec3b961f7..630fc8dbe5f 100644 --- a/corelib/src/integer.cairo +++ b/corelib/src/integer.cairo @@ -1102,7 +1102,7 @@ pub fn u256_wide_mul(a: u256, b: u256) -> u512 nopanic { /// Helper function for implementation of `u256_wide_mul`. /// Used for adding two u128s and receiving a BoundedInt for the carry result. -fn u128_add_with_bounded_int_carry( +pub(crate) fn u128_add_with_bounded_int_carry( a: u128, b: u128 ) -> (u128, core::internal::bounded_int::BoundedInt<0, 1>) nopanic { match u128_overflowing_add(a, b) { diff --git a/corelib/src/num/traits.cairo b/corelib/src/num/traits.cairo index d07f9e697ee..8a202b41a9f 100644 --- a/corelib/src/num/traits.cairo +++ b/corelib/src/num/traits.cairo @@ -17,4 +17,5 @@ pub use ops::wrapping::{WrappingAdd, WrappingSub, WrappingMul}; pub use ops::checked::{CheckedAdd, CheckedSub, CheckedMul}; pub use ops::saturating::{SaturatingAdd, SaturatingSub, SaturatingMul}; pub use ops::widemul::WideMul; +pub use ops::widesquare::WideSquare; pub use ops::sqrt::Sqrt; diff --git a/corelib/src/num/traits/ops.cairo b/corelib/src/num/traits/ops.cairo index 8c9dcadd1e0..eda726a270c 100644 --- a/corelib/src/num/traits/ops.cairo +++ b/corelib/src/num/traits/ops.cairo @@ -4,3 +4,4 @@ pub mod checked; pub mod saturating; pub(crate) mod sqrt; pub(crate) mod widemul; +pub(crate) mod widesquare; diff --git a/corelib/src/num/traits/ops/widesquare.cairo b/corelib/src/num/traits/ops/widesquare.cairo new file mode 100644 index 00000000000..d0c707965d6 --- /dev/null +++ b/corelib/src/num/traits/ops/widesquare.cairo @@ -0,0 +1,61 @@ +use core::num::traits::WideMul; + +/// A trait for a type that can be squared to produce a wider type. +pub trait WideSquare { + /// The type of the result of the square. + type Target; + /// Calculates the square, producing a wider type. + fn wide_square(self: T) -> Self::Target; +} + +mod wide_mul_based { + pub impl TWideSquare, +Copy> of super::WideSquare { + type Target = TWideMul::Target; + fn wide_square(self: T) -> Self::Target { + TWideMul::wide_mul(self, self) + } + } +} + +impl WideSquareI8 = wide_mul_based::TWideSquare; +impl WideSquareI16 = wide_mul_based::TWideSquare; +impl WideSquareI32 = wide_mul_based::TWideSquare; +impl WideSquareI64 = wide_mul_based::TWideSquare; +impl WideSquareU8 = wide_mul_based::TWideSquare; +impl WideSquareU16 = wide_mul_based::TWideSquare; +impl WideSquareU32 = wide_mul_based::TWideSquare; +impl WideSquareU64 = wide_mul_based::TWideSquare; +impl WideSquareU128 = wide_mul_based::TWideSquare; +impl WideSquareU256 of WideSquare { + type Target = core::integer::u512; + fn wide_square(self: u256) -> Self::Target { + inner::u256_wide_square(self) + } +} + +mod inner { + use core::integer::{u512, u128_add_with_bounded_int_carry, upcast}; + use core::internal::bounded_int; + use core::num::traits::{WideSquare, WideMul, WrappingAdd}; + + pub fn u256_wide_square(value: u256) -> u512 { + let u256 { high: limb1, low: limb0 } = value.low.wide_square(); + let u256 { high: limb2, low: limb1_part } = value.low.wide_mul(value.high); + let (limb1, limb1_overflow0) = u128_add_with_bounded_int_carry(limb1, limb1_part); + let (limb1, limb1_overflow1) = u128_add_with_bounded_int_carry(limb1, limb1_part); + let (limb2, limb2_overflow0) = u128_add_with_bounded_int_carry(limb2, limb2); + let u256 { high: limb3, low: limb2_part } = value.high.wide_square(); + let (limb2, limb2_overflow1) = u128_add_with_bounded_int_carry(limb2, limb2_part); + // Packing together the overflow bits, making a cheaper addition into limb2. + let limb1_overflow = bounded_int::add(limb1_overflow0, limb1_overflow1); + let (limb2, limb2_overflow2) = u128_add_with_bounded_int_carry( + limb2, upcast(limb1_overflow) + ); + // Packing together the overflow bits, making a cheaper addition into limb3. + let limb2_overflow = bounded_int::add(limb2_overflow0, limb2_overflow1); + let limb2_overflow = bounded_int::add(limb2_overflow, limb2_overflow2); + // No overflow since no limb4. + let limb3 = limb3.wrapping_add(upcast(limb2_overflow)); + u512 { limb0, limb1, limb2, limb3 } + } +} diff --git a/corelib/src/test/integer_test.cairo b/corelib/src/test/integer_test.cairo index a204152deba..af10baf72c1 100644 --- a/corelib/src/test/integer_test.cairo +++ b/corelib/src/test/integer_test.cairo @@ -1,7 +1,7 @@ #[feature("deprecated-bounded-int-trait")] use core::{integer, integer::{u512_safe_div_rem_by_u256, u512}}; use core::test::test_utils::{assert_eq, assert_ne, assert_le, assert_lt, assert_gt, assert_ge}; -use core::num::traits::{Bounded, Sqrt, WideMul, WrappingSub}; +use core::num::traits::{Bounded, Sqrt, WideMul, WideSquare, WrappingSub}; #[test] fn test_u8_operators() { @@ -706,6 +706,29 @@ fn test_u256_wide_mul() { ); } +#[test] +fn test_u256_wide_square() { + assert!(0_u256.wide_square() == u512 { limb0: 0, limb1: 0, limb2: 0, limb3: 0 }); + assert!( + 0x1001001001001001001001001001001001001001001001001001_u256 + .wide_square() == u512 { + limb0: 0x0b00a009008007006005004003002001, + limb1: 0xe00f01001101201101000f00e00d00c0, + limb2: 0x00400500600700800900a00b00c00d00, + limb3: 0x1002003 + } + ); + assert!( + 0x1000100010001000100010001000100010001000100010001000100010001_u256 + .wide_square() == u512 { + limb0: 0x00080007000600050004000300020001, + limb1: 0x0010000f000e000d000c000b000a0009, + limb2: 0x00080009000a000b000c000d000e000f, + limb3: 0x1000200030004000500060007 + } + ); +} + #[test] fn test_u512_safe_div_rem_by_u256() { let zero = u512 { limb0: 0, limb1: 0, limb2: 0, limb3: 0 }; @@ -846,7 +869,7 @@ fn test_u256_sqrt() { assert!(1_u256.sqrt() == 1); assert!(0_u256.sqrt() == 0); assert!(Bounded::::MAX.sqrt() == Bounded::::MAX); - assert!(Bounded::::MAX.wide_mul(Bounded::::MAX).sqrt() == Bounded::::MAX); + assert!(Bounded::::MAX.wide_square().sqrt() == Bounded::::MAX); } #[test]