From ef74e5084386f82b4285fd7d3630cf1088eebb3f Mon Sep 17 00:00:00 2001 From: Neutron3529 Date: Wed, 15 Jul 2020 22:39:39 +0800 Subject: [PATCH] Rearrange the pipeline of `pow` to gain efficiency The check of the `exp` parameter seems useless if we execute the while-loop more than once. The original implementation of `pow` function using one more comparison if the `exp==0` and may break the pipeline of the cpu, which may generate a slower code. The performance gap between the old and the new implementation may be small, but IMO, at least the newer one looks more beautiful. --- bench prog: ``` extern crate test; ($a:expr)=>{let time=std::time::Instant::now();{$a;}print!("{:?} ",time.elapsed())}; ($a:expr,$b:literal)=>{let time=std::time::Instant::now();let mut a=0;for _ in 0..$b{a^=$a;}print!("{:?} {} ",time.elapsed(),a)} } pub fn pow_rust(x:i64, mut exp: u32) -> i64 { let mut base = x; let mut acc = 1; while exp > 1 { if (exp & 1) == 1 { acc = acc * base; } exp /= 2; base = base * base; } if exp == 1 { acc = acc * base; } acc } pub fn pow_new(x:i64, mut exp: u32) -> i64 { if exp==0{ 1 }else{ let mut base = x; let mut acc = 1; while exp > 1 { if (exp & 1) == 1 { acc = acc * base; } exp >>= 1; base = base * base; } acc * base } } fn main(){ let a=2i64; let b=1_u32; println!(); timing!(test::black_box(a).pow(test::black_box(b)),100000000); timing!(pow_new(test::black_box(a),test::black_box(b)),100000000); timing!(pow_rust(test::black_box(a),test::black_box(b)),100000000); println!(); timing!(test::black_box(a).pow(test::black_box(b)),100000000); timing!(pow_new(test::black_box(a),test::black_box(b)),100000000); timing!(pow_rust(test::black_box(a),test::black_box(b)),100000000); println!(); timing!(test::black_box(a).pow(test::black_box(b)),100000000); timing!(pow_new(test::black_box(a),test::black_box(b)),100000000); timing!(pow_rust(test::black_box(a),test::black_box(b)),100000000); println!(); timing!(test::black_box(a).pow(test::black_box(b)),100000000); timing!(pow_new(test::black_box(a),test::black_box(b)),100000000); timing!(pow_rust(test::black_box(a),test::black_box(b)),100000000); println!(); timing!(test::black_box(a).pow(test::black_box(b)),100000000); timing!(pow_new(test::black_box(a),test::black_box(b)),100000000); timing!(pow_rust(test::black_box(a),test::black_box(b)),100000000); println!(); timing!(test::black_box(a).pow(test::black_box(b)),100000000); timing!(pow_new(test::black_box(a),test::black_box(b)),100000000); timing!(pow_rust(test::black_box(a),test::black_box(b)),100000000); println!(); timing!(test::black_box(a).pow(test::black_box(b)),100000000); timing!(pow_new(test::black_box(a),test::black_box(b)),100000000); timing!(pow_rust(test::black_box(a),test::black_box(b)),100000000); println!(); timing!(test::black_box(a).pow(test::black_box(b)),100000000); timing!(pow_new(test::black_box(a),test::black_box(b)),100000000); timing!(pow_rust(test::black_box(a),test::black_box(b)),100000000); println!(); } ``` bench in my laptop: ``` neutron@Neutron:/me/rust$ rc commit.rs rustc commit.rs && ./commit 3.978419716s 0 4.079765171s 0 3.964630622s 0 3.997127013s 0 4.260304804s 0 3.997638211s 0 3.963195544s 0 4.11657718s 0 4.176054164s 0 3.830128579s 0 3.980396122s 0 3.937258567s 0 3.986055948s 0 4.127804162s 0 4.018943411s 0 4.185568857s 0 4.217512517s 0 3.98313603s 0 3.863018225s 0 4.030447988s 0 3.694878237s 0 4.206987927s 0 4.137608047s 0 4.115564664s 0 neutron@Neutron:/me/rust$ rc commit.rs -O rustc commit.rs -O && ./commit 162.111993ms 0 165.107125ms 0 166.26924ms 0 175.20479ms 0 205.062565ms 0 176.278791ms 0 174.408975ms 0 166.526899ms 0 201.857604ms 0 146.190062ms 0 168.592821ms 0 154.61411ms 0 199.678912ms 0 168.411598ms 0 162.129996ms 0 147.420765ms 0 209.759326ms 0 154.807907ms 0 165.507134ms 0 188.476239ms 0 157.351524ms 0 121.320123ms 0 126.401229ms 0 114.86428ms 0 ``` delete an unnecessary semicolon... Sorry for the typo. delete trailing whitespace Sorry, too.. Sorry for the missing... I checked all the implementations, and finally found that there is one function that does not check whether `exp == 0` add extra tests add extra tests. finished adding the extra tests to prevent further typo add pow(2) to negative exp add whitespace. add whitespace add whitespace delete extra line --- src/libcore/num/mod.rs | 87 ++++++++++++++-------------- src/libcore/tests/num/int_macros.rs | 33 ++++++++++- src/libcore/tests/num/uint_macros.rs | 25 ++++++++ 3 files changed, 101 insertions(+), 44 deletions(-) diff --git a/src/libcore/num/mod.rs b/src/libcore/num/mod.rs index f4a1afd436adb..c576465c622fe 100644 --- a/src/libcore/num/mod.rs +++ b/src/libcore/num/mod.rs @@ -1103,6 +1103,9 @@ $EndFeature, " without modifying the original"] #[inline] pub const fn checked_pow(self, mut exp: u32) -> Option { + if exp == 0 { + return Some(1); + } let mut base = self; let mut acc: Self = 1; @@ -1113,15 +1116,11 @@ $EndFeature, " exp /= 2; base = try_opt!(base.checked_mul(base)); } - + // since exp!=0, finally the exp must be 1. // Deal with the final bit of the exponent separately, since // squaring the base afterwards is not necessary and may cause a // needless overflow. - if exp == 1 { - acc = try_opt!(acc.checked_mul(base)); - } - - Some(acc) + Some(try_opt!(acc.checked_mul(base))) } } @@ -1631,6 +1630,9 @@ $EndFeature, " without modifying the original"] #[inline] pub const fn wrapping_pow(self, mut exp: u32) -> Self { + if exp == 0 { + return 1; + } let mut base = self; let mut acc: Self = 1; @@ -1642,14 +1644,11 @@ $EndFeature, " base = base.wrapping_mul(base); } + // since exp!=0, finally the exp must be 1. // Deal with the final bit of the exponent separately, since // squaring the base afterwards is not necessary and may cause a // needless overflow. - if exp == 1 { - acc = acc.wrapping_mul(base); - } - - acc + acc.wrapping_mul(base) } } @@ -1999,6 +1998,9 @@ $EndFeature, " without modifying the original"] #[inline] pub const fn overflowing_pow(self, mut exp: u32) -> (Self, bool) { + if exp == 0 { + return (1,false); + } let mut base = self; let mut acc: Self = 1; let mut overflown = false; @@ -2017,16 +2019,13 @@ $EndFeature, " overflown |= r.1; } + // since exp!=0, finally the exp must be 1. // Deal with the final bit of the exponent separately, since // squaring the base afterwards is not necessary and may cause a // needless overflow. - if exp == 1 { - r = acc.overflowing_mul(base); - acc = r.0; - overflown |= r.1; - } - - (acc, overflown) + r = acc.overflowing_mul(base); + r.1 |= overflown; + r } } @@ -2050,6 +2049,9 @@ $EndFeature, " #[inline] #[rustc_inherit_overflow_checks] pub const fn pow(self, mut exp: u32) -> Self { + if exp == 0 { + return 1; + } let mut base = self; let mut acc = 1; @@ -2061,14 +2063,11 @@ $EndFeature, " base = base * base; } + // since exp!=0, finally the exp must be 1. // Deal with the final bit of the exponent separately, since // squaring the base afterwards is not necessary and may cause a // needless overflow. - if exp == 1 { - acc = acc * base; - } - - acc + acc * base } } @@ -3306,6 +3305,9 @@ assert_eq!(", stringify!($SelfT), "::MAX.checked_pow(2), None);", $EndFeature, " without modifying the original"] #[inline] pub const fn checked_pow(self, mut exp: u32) -> Option { + if exp == 0 { + return Some(1); + } let mut base = self; let mut acc: Self = 1; @@ -3317,14 +3319,12 @@ assert_eq!(", stringify!($SelfT), "::MAX.checked_pow(2), None);", $EndFeature, " base = try_opt!(base.checked_mul(base)); } + // since exp!=0, finally the exp must be 1. // Deal with the final bit of the exponent separately, since // squaring the base afterwards is not necessary and may cause a // needless overflow. - if exp == 1 { - acc = try_opt!(acc.checked_mul(base)); - } - Some(acc) + Some(try_opt!(acc.checked_mul(base))) } } @@ -3715,6 +3715,9 @@ assert_eq!(3u8.wrapping_pow(6), 217);", $EndFeature, " without modifying the original"] #[inline] pub const fn wrapping_pow(self, mut exp: u32) -> Self { + if exp == 0 { + return 1; + } let mut base = self; let mut acc: Self = 1; @@ -3726,14 +3729,11 @@ assert_eq!(3u8.wrapping_pow(6), 217);", $EndFeature, " base = base.wrapping_mul(base); } + // since exp!=0, finally the exp must be 1. // Deal with the final bit of the exponent separately, since // squaring the base afterwards is not necessary and may cause a // needless overflow. - if exp == 1 { - acc = acc.wrapping_mul(base); - } - - acc + acc.wrapping_mul(base) } } @@ -4040,6 +4040,9 @@ assert_eq!(3u8.overflowing_pow(6), (217, true));", $EndFeature, " without modifying the original"] #[inline] pub const fn overflowing_pow(self, mut exp: u32) -> (Self, bool) { + if exp == 0{ + return (1,false); + } let mut base = self; let mut acc: Self = 1; let mut overflown = false; @@ -4058,16 +4061,14 @@ assert_eq!(3u8.overflowing_pow(6), (217, true));", $EndFeature, " overflown |= r.1; } + // since exp!=0, finally the exp must be 1. // Deal with the final bit of the exponent separately, since // squaring the base afterwards is not necessary and may cause a // needless overflow. - if exp == 1 { - r = acc.overflowing_mul(base); - acc = r.0; - overflown |= r.1; - } + r = acc.overflowing_mul(base); + r.1 |= overflown; - (acc, overflown) + r } } @@ -4088,6 +4089,9 @@ Basic usage: #[inline] #[rustc_inherit_overflow_checks] pub const fn pow(self, mut exp: u32) -> Self { + if exp == 0 { + return 1; + } let mut base = self; let mut acc = 1; @@ -4099,14 +4103,11 @@ Basic usage: base = base * base; } + // since exp!=0, finally the exp must be 1. // Deal with the final bit of the exponent separately, since // squaring the base afterwards is not necessary and may cause a // needless overflow. - if exp == 1 { - acc = acc * base; - } - - acc + acc * base } } diff --git a/src/libcore/tests/num/int_macros.rs b/src/libcore/tests/num/int_macros.rs index 8396a0dd62db9..58a585669122c 100644 --- a/src/libcore/tests/num/int_macros.rs +++ b/src/libcore/tests/num/int_macros.rs @@ -255,12 +255,43 @@ macro_rules! int_module { #[test] fn test_pow() { let mut r = 2 as $T; - assert_eq!(r.pow(2), 4 as $T); assert_eq!(r.pow(0), 1 as $T); + assert_eq!(r.wrapping_pow(2), 4 as $T); + assert_eq!(r.wrapping_pow(0), 1 as $T); + assert_eq!(r.checked_pow(2), Some(4 as $T)); + assert_eq!(r.checked_pow(0), Some(1 as $T)); + assert_eq!(r.overflowing_pow(2), (4 as $T, false)); + assert_eq!(r.overflowing_pow(0), (1 as $T, false)); + assert_eq!(r.saturating_pow(2), 4 as $T); + assert_eq!(r.saturating_pow(0), 1 as $T); + + r = MAX; + // use `^` to represent .pow() with no overflow. + // if itest::MAX == 2^j-1, then itest is a `j` bit int, + // so that `itest::MAX*itest::MAX == 2^(2*j)-2^(j+1)+1`, + // thussaturating_pow the overflowing result is exactly 1. + assert_eq!(r.wrapping_pow(2), 1 as $T); + assert_eq!(r.checked_pow(2), None); + assert_eq!(r.overflowing_pow(2), (1 as $T, true)); + assert_eq!(r.saturating_pow(2), MAX); + //test for negative exponent. r = -2 as $T; assert_eq!(r.pow(2), 4 as $T); assert_eq!(r.pow(3), -8 as $T); + assert_eq!(r.pow(0), 1 as $T); + assert_eq!(r.wrapping_pow(2), 4 as $T); + assert_eq!(r.wrapping_pow(3), -8 as $T); + assert_eq!(r.wrapping_pow(0), 1 as $T); + assert_eq!(r.checked_pow(2), Some(4 as $T)); + assert_eq!(r.checked_pow(3), Some(-8 as $T)); + assert_eq!(r.checked_pow(0), Some(1 as $T)); + assert_eq!(r.overflowing_pow(2), (4 as $T, false)); + assert_eq!(r.overflowing_pow(3), (-8 as $T, false)); + assert_eq!(r.overflowing_pow(0), (1 as $T, false)); + assert_eq!(r.saturating_pow(2), 4 as $T); + assert_eq!(r.saturating_pow(3), -8 as $T); + assert_eq!(r.saturating_pow(0), 1 as $T); } } }; diff --git a/src/libcore/tests/num/uint_macros.rs b/src/libcore/tests/num/uint_macros.rs index 8f1ca8e6fac2c..b84a8a7d9f88b 100644 --- a/src/libcore/tests/num/uint_macros.rs +++ b/src/libcore/tests/num/uint_macros.rs @@ -184,6 +184,31 @@ macro_rules! uint_module { assert_eq!($T::from_str_radix("Z", 10).ok(), None::<$T>); assert_eq!($T::from_str_radix("_", 2).ok(), None::<$T>); } + + #[test] + fn test_pow() { + let mut r = 2 as $T; + assert_eq!(r.pow(2), 4 as $T); + assert_eq!(r.pow(0), 1 as $T); + assert_eq!(r.wrapping_pow(2), 4 as $T); + assert_eq!(r.wrapping_pow(0), 1 as $T); + assert_eq!(r.checked_pow(2), Some(4 as $T)); + assert_eq!(r.checked_pow(0), Some(1 as $T)); + assert_eq!(r.overflowing_pow(2), (4 as $T, false)); + assert_eq!(r.overflowing_pow(0), (1 as $T, false)); + assert_eq!(r.saturating_pow(2), 4 as $T); + assert_eq!(r.saturating_pow(0), 1 as $T); + + r = MAX; + // use `^` to represent .pow() with no overflow. + // if itest::MAX == 2^j-1, then itest is a `j` bit int, + // so that `itest::MAX*itest::MAX == 2^(2*j)-2^(j+1)+1`, + // thussaturating_pow the overflowing result is exactly 1. + assert_eq!(r.wrapping_pow(2), 1 as $T); + assert_eq!(r.checked_pow(2), None); + assert_eq!(r.overflowing_pow(2), (1 as $T, true)); + assert_eq!(r.saturating_pow(2), MAX); + } } }; }