From 6e2f21e6d44323436204afd50e987cd24344c2e8 Mon Sep 17 00:00:00 2001 From: Kent Overstreet Date: Mon, 16 Nov 2015 02:34:35 -0900 Subject: [PATCH] Improve multiply performance The main idea here is to do as much as possible with slices, instead of allocating new BigUints (= heap allocations). Current performance: multiply_0: 7,137 ns/iter (+/- 620) multiply_1: 2,104,565 ns/iter (+/- 255,208) multiply_2: 51,620,572 ns/iter (+/- 2,707,818) After this patch, we get: multiply_0: 4,224 ns/iter (+/- 635) multiply_1: 198,155 ns/iter (+/- 28,722) multiply_2: 31,541,853 ns/iter (+/- 4,978,257) --- src/bigint.rs | 165 ++++++++++++++++++++++++++++++++------------------ 1 file changed, 107 insertions(+), 58 deletions(-) diff --git a/src/bigint.rs b/src/bigint.rs index 16a52387c5..f5d4fe8871 100644 --- a/src/bigint.rs +++ b/src/bigint.rs @@ -501,6 +501,34 @@ fn sub3(a: BigUintS, b: BigUintS) -> BigUint { diff } +fn sub_sign(a: BigUintS, b: BigUintS) -> BigInt { + fn cmp(a: BigUintS, b: BigUintS) -> Ordering { + /* + * Since we're working with slices here, it's _not_ guaranteed that the highest element in + * the slice is nonzero: + */ + let mut i = cmp::max(a.len(), b.len()); + + while { + let av = if i < a.len() { a[i] } else { 0 }; + let bv = if i < b.len() { b[i] } else { 0 }; + + if av < bv { return Less; } + if av > bv { return Greater; } + + i != 0 + } { i = i - 1; } + + return Equal; + } + + match cmp(a, b) { + Less => BigInt::from_biguint(Plus, sub3(b, a)), + Greater => BigInt::from_biguint(Minus, sub3(a, b)), + _ => Zero::zero(), + } +} + impl<'a> Sub<&'a BigUint> for BigUint { type Output = BigUint; @@ -520,73 +548,94 @@ impl<'a, 'b> Sub<&'b BigUint> for &'a BigUint { } } -forward_all_binop!(impl Mul for BigUint, mul); +fn mul3(x: BigUintS, y: BigUintS) -> BigUint { + /* + * Karatsuba multiplication: + * + * x = x0 + x1 * b + * y = y0 + y1 * b + * + * p0 = x0 * y0 + * p1 = (x1 - x0) * (y1 - y0) + * p2 = x1 * y1 + * + * x * y = b^2 * p2 + * + b * (p2 - p1 + p0) + * + p0 + */ + + fn mul_digit(a: BigUintS, n: BigDigit) -> BigUint { + if n == 0 { return Zero::zero(); } + if n == 1 { return BigUint::from_slice(a); } -impl<'a, 'b> Mul<&'b BigUint> for &'a BigUint { - type Output = BigUint; + let mut carry = 0; + let mut prod: Vec = a.iter().map(|ai| { + let (hi, lo) = big_digit::from_doublebigdigit( + (*ai as DoubleBigDigit) * (n as DoubleBigDigit) + (carry as DoubleBigDigit) + ); + carry = hi; + lo + }).collect(); - fn mul(self, other: &BigUint) -> BigUint { - if self.is_zero() || other.is_zero() { return Zero::zero(); } + if carry != 0 { + prod.push(carry); + } - let (s_len, o_len) = (self.data.len(), other.data.len()); - if s_len == 1 { return mul_digit(other, self.data[0]); } - if o_len == 1 { return mul_digit(self, other.data[0]); } - - // Using Karatsuba multiplication - // (a1 * base + a0) * (b1 * base + b0) - // = a1*b1 * base^2 + - // (a1*b1 + a0*b0 - (a1-b0)*(b1-a0)) * base + - // a0*b0 - let half_len = cmp::max(s_len, o_len) / 2; - let (s_hi, s_lo) = cut_at(self, half_len); - let (o_hi, o_lo) = cut_at(other, half_len); - - let ll = &s_lo * &o_lo; - let hh = &s_hi * &o_hi; - let mm = { - let (s1, n1) = sub_sign(s_hi, s_lo); - let (s2, n2) = sub_sign(o_hi, o_lo); - match (s1, s2) { - (Equal, _) | (_, Equal) => &hh + &ll, - (Less, Greater) | (Greater, Less) => &hh + &ll + (n1 * n2), - (Less, Less) | (Greater, Greater) => &hh + &ll - (n1 * n2) - } - }; + return BigUint::new(prod); + } - return ll + mm.shl_unit(half_len) + hh.shl_unit(half_len * 2); + let (x, y) = if x.len() < y.len() { (x, y) } else { (y, x) }; + if x.len() == 0 { return Zero::zero(); } + if x.len() == 1 { return mul_digit(y, x[0]); } - fn mul_digit(a: &BigUint, n: BigDigit) -> BigUint { - if n == 0 { return Zero::zero(); } - if n == 1 { return a.clone(); } + /* + * When x is smaller than y, it's significantly faster to pick a midpoint that splits x in + * half, not y: + */ + let half_len = x.len() / 2; + let (x0, x1) = x.split_at(half_len); + let (y0, y1) = y.split_at(half_len); - let mut carry = 0; - let mut prod: Vec = a.data.iter().map(|ai| { - let (hi, lo) = big_digit::from_doublebigdigit( - (*ai as DoubleBigDigit) * (n as DoubleBigDigit) + (carry as DoubleBigDigit) - ); - carry = hi; - lo - }).collect(); - if carry != 0 { prod.push(carry); } - return BigUint::new(prod); - } + let p0 = mul3(x0, y0); + let p1 = sub_sign(x1, x0) * sub_sign(y1, y0); + let p2 = mul3(x1, y1); - #[inline] - fn cut_at(a: &BigUint, n: usize) -> (BigUint, BigUint) { - let mid = cmp::min(a.data.len(), n); - (BigUint::from_slice(&a.data[mid ..]), - BigUint::from_slice(&a.data[.. mid])) - } + let len = cmp::max(p2.data.len() + half_len * 2, + cmp::max(p1.data.data.len() + half_len, + p0.data.len() + half_len)) + 1; - #[inline] - fn sub_sign(a: BigUint, b: BigUint) -> (Ordering, BigUint) { - match a.cmp(&b) { - Less => (Less, b - a), - Greater => (Greater, a - b), - _ => (Equal, Zero::zero()) - } - } + let mut prod: BigUint = BigUint { data: Vec::with_capacity(len) }; + + // resize isn't stable yet: + //prod.data.resize(len, 0); + prod.data.extend(repeat(ZERO_BIG_DIGIT).take(len)); + + add2(&mut prod.data[half_len..], &p2.data[..]); + add2(&mut prod.data[half_len * 2..], &p2.data[..]); + + add2(&mut prod.data[..], &p0.data[..]); + add2(&mut prod.data[half_len..], &p0.data[..]); + + // Last, so we don't have a negative partial result: + match p1.sign { + Plus => sub2(&mut prod.data[half_len..], &p1.data.data[..]), + Minus => add2(&mut prod.data[half_len..], &p1.data.data[..]), + NoSign => (), + } + + prod.normalize(); + prod +} + +forward_all_binop!(impl Mul for BigUint, mul); + +impl<'a, 'b> Mul<&'b BigUint> for &'a BigUint { + type Output = BigUint; + + #[inline] + fn mul(self, other: &BigUint) -> BigUint { + mul3(&self.data[..], &other.data[..]) } }