Skip to content

Commit

Permalink
Improve multiply performance
Browse files Browse the repository at this point in the history
The main idea here is to do as much as possible with slices, instead of
allocating new BigUints (= heap allocations).

Current performance:

multiply_0:	7,137 ns/iter (+/- 620)
multiply_1:	2,104,565 ns/iter (+/- 255,208)
multiply_2:	51,620,572 ns/iter (+/- 2,707,818)

After this patch, we get:

multiply_0:	4,224 ns/iter (+/- 635)
multiply_1:	198,155 ns/iter (+/- 28,722)
multiply_2:	31,541,853 ns/iter (+/- 4,978,257)
  • Loading branch information
koverstreet committed Nov 16, 2015
1 parent 1c2f8bc commit 6e2f21e
Showing 1 changed file with 107 additions and 58 deletions.
165 changes: 107 additions & 58 deletions src/bigint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,34 @@ fn sub3(a: BigUintS, b: BigUintS) -> BigUint {
diff
}

fn sub_sign(a: BigUintS, b: BigUintS) -> BigInt {
fn cmp(a: BigUintS, b: BigUintS) -> Ordering {
/*
* Since we're working with slices here, it's _not_ guaranteed that the highest element in
* the slice is nonzero:
*/
let mut i = cmp::max(a.len(), b.len());

while {
let av = if i < a.len() { a[i] } else { 0 };
let bv = if i < b.len() { b[i] } else { 0 };

if av < bv { return Less; }
if av > bv { return Greater; }

i != 0
} { i = i - 1; }

return Equal;
}

match cmp(a, b) {
Less => BigInt::from_biguint(Plus, sub3(b, a)),
Greater => BigInt::from_biguint(Minus, sub3(a, b)),
_ => Zero::zero(),
}
}

impl<'a> Sub<&'a BigUint> for BigUint {
type Output = BigUint;

Expand All @@ -520,73 +548,94 @@ impl<'a, 'b> Sub<&'b BigUint> for &'a BigUint {
}
}

forward_all_binop!(impl Mul for BigUint, mul);
fn mul3(x: BigUintS, y: BigUintS) -> BigUint {
/*
* Karatsuba multiplication:
*
* x = x0 + x1 * b
* y = y0 + y1 * b
*
* p0 = x0 * y0
* p1 = (x1 - x0) * (y1 - y0)
* p2 = x1 * y1
*
* x * y = b^2 * p2
* + b * (p2 - p1 + p0)
* + p0
*/

fn mul_digit(a: BigUintS, n: BigDigit) -> BigUint {
if n == 0 { return Zero::zero(); }
if n == 1 { return BigUint::from_slice(a); }

impl<'a, 'b> Mul<&'b BigUint> for &'a BigUint {
type Output = BigUint;
let mut carry = 0;
let mut prod: Vec<BigDigit> = a.iter().map(|ai| {
let (hi, lo) = big_digit::from_doublebigdigit(
(*ai as DoubleBigDigit) * (n as DoubleBigDigit) + (carry as DoubleBigDigit)
);
carry = hi;
lo
}).collect();

fn mul(self, other: &BigUint) -> BigUint {
if self.is_zero() || other.is_zero() { return Zero::zero(); }
if carry != 0 {
prod.push(carry);
}

let (s_len, o_len) = (self.data.len(), other.data.len());
if s_len == 1 { return mul_digit(other, self.data[0]); }
if o_len == 1 { return mul_digit(self, other.data[0]); }

// Using Karatsuba multiplication
// (a1 * base + a0) * (b1 * base + b0)
// = a1*b1 * base^2 +
// (a1*b1 + a0*b0 - (a1-b0)*(b1-a0)) * base +
// a0*b0
let half_len = cmp::max(s_len, o_len) / 2;
let (s_hi, s_lo) = cut_at(self, half_len);
let (o_hi, o_lo) = cut_at(other, half_len);

let ll = &s_lo * &o_lo;
let hh = &s_hi * &o_hi;
let mm = {
let (s1, n1) = sub_sign(s_hi, s_lo);
let (s2, n2) = sub_sign(o_hi, o_lo);
match (s1, s2) {
(Equal, _) | (_, Equal) => &hh + &ll,
(Less, Greater) | (Greater, Less) => &hh + &ll + (n1 * n2),
(Less, Less) | (Greater, Greater) => &hh + &ll - (n1 * n2)
}
};
return BigUint::new(prod);
}

return ll + mm.shl_unit(half_len) + hh.shl_unit(half_len * 2);
let (x, y) = if x.len() < y.len() { (x, y) } else { (y, x) };

if x.len() == 0 { return Zero::zero(); }
if x.len() == 1 { return mul_digit(y, x[0]); }

fn mul_digit(a: &BigUint, n: BigDigit) -> BigUint {
if n == 0 { return Zero::zero(); }
if n == 1 { return a.clone(); }
/*
* When x is smaller than y, it's significantly faster to pick a midpoint that splits x in
* half, not y:
*/
let half_len = x.len() / 2;
let (x0, x1) = x.split_at(half_len);
let (y0, y1) = y.split_at(half_len);

let mut carry = 0;
let mut prod: Vec<BigDigit> = a.data.iter().map(|ai| {
let (hi, lo) = big_digit::from_doublebigdigit(
(*ai as DoubleBigDigit) * (n as DoubleBigDigit) + (carry as DoubleBigDigit)
);
carry = hi;
lo
}).collect();
if carry != 0 { prod.push(carry); }
return BigUint::new(prod);
}
let p0 = mul3(x0, y0);
let p1 = sub_sign(x1, x0) * sub_sign(y1, y0);
let p2 = mul3(x1, y1);

#[inline]
fn cut_at(a: &BigUint, n: usize) -> (BigUint, BigUint) {
let mid = cmp::min(a.data.len(), n);
(BigUint::from_slice(&a.data[mid ..]),
BigUint::from_slice(&a.data[.. mid]))
}
let len = cmp::max(p2.data.len() + half_len * 2,
cmp::max(p1.data.data.len() + half_len,
p0.data.len() + half_len)) + 1;

#[inline]
fn sub_sign(a: BigUint, b: BigUint) -> (Ordering, BigUint) {
match a.cmp(&b) {
Less => (Less, b - a),
Greater => (Greater, a - b),
_ => (Equal, Zero::zero())
}
}
let mut prod: BigUint = BigUint { data: Vec::with_capacity(len) };

// resize isn't stable yet:
//prod.data.resize(len, 0);
prod.data.extend(repeat(ZERO_BIG_DIGIT).take(len));

add2(&mut prod.data[half_len..], &p2.data[..]);
add2(&mut prod.data[half_len * 2..], &p2.data[..]);

add2(&mut prod.data[..], &p0.data[..]);
add2(&mut prod.data[half_len..], &p0.data[..]);

// Last, so we don't have a negative partial result:
match p1.sign {
Plus => sub2(&mut prod.data[half_len..], &p1.data.data[..]),
Minus => add2(&mut prod.data[half_len..], &p1.data.data[..]),
NoSign => (),
}

prod.normalize();
prod
}

forward_all_binop!(impl Mul for BigUint, mul);

impl<'a, 'b> Mul<&'b BigUint> for &'a BigUint {
type Output = BigUint;

#[inline]
fn mul(self, other: &BigUint) -> BigUint {
mul3(&self.data[..], &other.data[..])
}
}

Expand Down

0 comments on commit 6e2f21e

Please sign in to comment.