Symbolic Toom-Cook multiplication
Implements the algorithm of Section 4 of "Further Steps Down The Wrong
Path: Improving the Bit-Blasting of Multiplication" (see
https://ceur-ws.org/Vol-2908/short16.pdf).
tautschnig committed Nov 17, 2023
1 parent 832578f commit ffe136e
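
As an editorial aside (not part of the commit), the small standalone C++ sketch below illustrates the identity that the new unsigned_toom_cook_multiplier encodes: split each n-bit operand into GROUP_SIZE-bit groups, view the product as a polynomial in those groups, and sum the shifted coefficients to recover the product modulo 2^n. The diff itself does not compute these coefficients directly; it introduces fresh variables d[i] for them and pins them down through evaluation constraints at 0 and ±2^(j-1), following Section 4 of the cited paper. The constants x and y are arbitrary example values.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main()
{
  const std::size_t GROUP_SIZE = 8;
  const std::uint64_t x = 0x1234567890abcdefULL;
  const std::uint64_t y = 0xfedcba0987654321ULL;

  // split both operands into 8-bit groups, least-significant group first
  std::vector<std::uint64_t> a, b;
  for(std::size_t i = 0; i < 64; i += GROUP_SIZE)
  {
    a.push_back((x >> i) & 0xff);
    b.push_back((y >> i) & 0xff);
  }

  // coefficients of the product polynomial: d[i] = sum over p+q==i of a[p]*b[q]
  std::vector<std::uint64_t> d(2 * a.size() - 1, 0);
  for(std::size_t p = 0; p < a.size(); ++p)
    for(std::size_t q = 0; q < b.size(); ++q)
      d[p + q] += a[p] * b[q];

  // recombine: shift each coefficient to its bit offset; coefficients whose
  // offset is >= 64 do not contribute to the product modulo 2^64
  std::uint64_t product = 0;
  for(std::size_t i = 0; i < d.size() && i * GROUP_SIZE < 64; ++i)
    product += d[i] << (i * GROUP_SIZE);

  assert(product == x * y); // both sides are reduced modulo 2^64
  return 0;
}
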
Showing 2 changed files with 159 additions and 2 deletions.
160 changes: 158 additions & 2 deletions src/solvers/flattening/bv_utils.cpp
@@ -8,6 +8,8 @@ Author: Daniel Kroening, kroening@kroening.com

#include "bv_utils.h"

#include <util/arith_tools.h>

#include <utility>

bvt bv_utilst::build_constant(const mp_integer &n, std::size_t width)
@@ -782,7 +784,7 @@ bvt bv_utilst::dadda_tree(const std::vector<bvt> &pps)
// trees (and also the default addition scheme), but isn't consistently more
// performant with simple partial-product generation. Only when using
// higher-radix multipliers does the combination appear to perform better.
-#define DADDA_TREE
+// #define DADDA_TREE

// The following examples demonstrate the performance differences (with a
// time-out of 7200 seconds):
@@ -924,7 +926,7 @@ bvt bv_utilst::dadda_tree(const std::vector<bvt> &pps)
// with Dadda's reduction yields the most consistent performance improvement
// while not regressing substantially across the matrix of different
// benchmarks and the CaDiCaL and MiniSat2 solvers.
-#define RADIX_MULTIPLIER 8
+// #define RADIX_MULTIPLIER 8

#ifdef RADIX_MULTIPLIER
static bvt unsigned_multiply_by_3(propt &prop, const bvt &op)
@@ -1861,6 +1863,155 @@ bvt bv_utilst::unsigned_karatsuba_multiplier(const bvt &_op0, const bvt &_op1)
return add(z0, z1);
}

bvt bv_utilst::unsigned_toom_cook_multiplier(const bvt &_op0, const bvt &_op1)
{
PRECONDITION(!_op0.empty());
PRECONDITION(!_op1.empty());

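// too narrow to split into groups: fall back to the basic multiplier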
if(_op1.size() == 1)
return unsigned_multiplier(_op0, _op1);

// break up _op0, _op1 in groups of at most GROUP_SIZE bits
PRECONDITION(_op0.size() == _op1.size());
#define GROUP_SIZE 8
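// number of bits per fresh product coefficient: a product of two
// GROUP_SIZE-bit groups needs 2*GROUP_SIZE bits, and the address_bits
// terms add headroom for the sums of such products formed below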
const std::size_t d_bits =
2 * GROUP_SIZE +
2 * address_bits((_op0.size() + GROUP_SIZE - 1) / GROUP_SIZE);
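// a, b: the GROUP_SIZE-bit groups of _op0 and _op1, least-significant
// group first; d: fresh variables for the coefficients of the product
// polynomial; c_ops: the coefficients placed at their bit offsets and
// brought to the operand width, ready for the final summation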
std::vector<bvt> a, b, c_ops, d;
for(std::size_t i = 0; i < _op0.size(); i += GROUP_SIZE)
{
std::size_t u = std::min(i + GROUP_SIZE, _op0.size());
a.emplace_back(_op0.begin() + i, _op0.begin() + u);
b.emplace_back(_op1.begin() + i, _op1.begin() + u);

c_ops.push_back(zeros(i));
d.push_back(prop.new_variables(d_bits));
c_ops.back().insert(c_ops.back().end(), d.back().begin(), d.back().end());
c_ops.back() = zero_extension(c_ops.back(), _op0.size());
}
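// the product polynomial has 2*a.size()-1 coefficients in total; the
// high-order ones lie beyond the truncated result, but fresh variables are
// still needed for them to state the evaluation constraints below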
for(std::size_t i = a.size(); i < 2 * a.size() - 1; ++i)
{
d.push_back(prop.new_variables(d_bits));
}

// r(0)
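// the constant coefficient d[0] is simply the product of the lowest groups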
bvt r_0 = d[0];
prop.l_set_to_true(equal(
r_0,
unsigned_multiplier(
zero_extension(a[0], r_0.size()), zero_extension(b[0], r_0.size()))));

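// constrain the coefficients by equating D(x) with A(x)*B(x) at the points
// x = 2^(j-1) and x = -2^(j-1) for j = 1, ..., a.size()-1; together with
// r(0) this yields 2*a.size()-1 equations, one per unknown coefficient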
for(std::size_t j = 1; j < a.size(); ++j)
{
// r(2^(j-1))
bvt r_j = zero_extension(
d[0], std::min(_op0.size(), d[0].size() + (j - 1) * (d.size() - 1)));
for(std::size_t i = 1; i < d.size(); ++i)
{
r_j = add(
r_j,
shift(
zero_extension(d[i], r_j.size()), shiftt::SHIFT_LEFT, (j - 1) * i));
}

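// evaluate A(x) at x = 2^(j-1), keeping even- and odd-degree contributions
// separate so that the odd part can be negated for the point -2^(j-1)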
bvt a_even = zero_extension(a[0], r_j.size());
for(std::size_t i = 2; i < a.size(); i += 2)
{
a_even = add(
a_even,
shift(
zero_extension(a[i], a_even.size()),
shiftt::SHIFT_LEFT,
(j - 1) * i));
}
bvt a_odd = zero_extension(a[1], r_j.size());
for(std::size_t i = 3; i < a.size(); i += 2)
{
a_odd = add(
a_odd,
shift(
zero_extension(a[i], a_odd.size()),
shiftt::SHIFT_LEFT,
(j - 1) * (i - 1)));
}
bvt b_even = zero_extension(b[0], r_j.size());
for(std::size_t i = 2; i < b.size(); i += 2)
{
b_even = add(
b_even,
shift(
zero_extension(b[i], b_even.size()),
shiftt::SHIFT_LEFT,
(j - 1) * i));
}
bvt b_odd = zero_extension(b[1], r_j.size());
for(std::size_t i = 3; i < b.size(); i += 2)
{
b_odd = add(
b_odd,
shift(
zero_extension(b[i], b_odd.size()),
shiftt::SHIFT_LEFT,
(j - 1) * (i - 1)));
}

prop.l_set_to_true(equal(
r_j,
unsigned_multiplier(
add(a_even, shift(a_odd, shiftt::SHIFT_LEFT, j - 1)),
add(b_even, shift(b_odd, shiftt::SHIFT_LEFT, j - 1)))));

// r(-2^(j-1))
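// at the negative point the odd-degree terms change sign, hence the
// alternation between sub and add below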
bvt r_minus_j = zero_extension(
d[0], std::min(_op0.size(), d[0].size() + (j - 1) * (d.size() - 1)));
for(std::size_t i = 1; i < d.size(); ++i)
{
if(i % 2 == 1)
{
r_minus_j = sub(
r_minus_j,
shift(
zero_extension(d[i], r_minus_j.size()),
shiftt::SHIFT_LEFT,
(j - 1) * i));
}
else
{
r_minus_j = add(
r_minus_j,
shift(
zero_extension(d[i], r_minus_j.size()),
shiftt::SHIFT_LEFT,
(j - 1) * i));
}
}

prop.l_set_to_true(equal(
r_minus_j,
unsigned_multiplier(
sub(a_even, shift(a_odd, shiftt::SHIFT_LEFT, j - 1)),
sub(b_even, shift(b_odd, shiftt::SHIFT_LEFT, j - 1)))));
}

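// sum up the coefficients, each already placed at its bit offset, using
// whichever reduction scheme is configured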
if(c_ops.empty())
return zeros(_op0.size());
else
{
#ifdef WALLACE_TREE
return wallace_tree(c_ops);
#elif defined(DADDA_TREE)
return dadda_tree(c_ops);
#else
bvt product = c_ops.front();

for(auto it = std::next(c_ops.begin()); it != c_ops.end(); ++it)
product = add(product, *it);

return product;
#endif
}
}

bvt bv_utilst::unsigned_multiplier_no_overflow(
const bvt &op0,
const bvt &op1)
@@ -1913,6 +2064,8 @@ bvt bv_utilst::signed_multiplier(const bvt &op0, const bvt &op1)

#ifdef USE_KARATSUBA
bvt result = unsigned_karatsuba_multiplier(neg0, neg1);
#elif defined(USE_TOOM_COOK)
bvt result = unsigned_toom_cook_multiplier(neg0, neg1);
#else
bvt result=unsigned_multiplier(neg0, neg1);
#endif
@@ -1986,6 +2139,9 @@ bvt bv_utilst::multiplier(
#ifdef USE_KARATSUBA
case representationt::UNSIGNED:
return unsigned_karatsuba_multiplier(op0, op1);
#elif defined(USE_TOOM_COOK)
case representationt::UNSIGNED:
return unsigned_toom_cook_multiplier(op0, op1);
#else
case representationt::UNSIGNED: return unsigned_multiplier(op0, op1);
#endif
1 change: 1 addition & 0 deletions src/solvers/flattening/bv_utils.h
@@ -81,6 +81,7 @@ class bv_utilst

bvt unsigned_multiplier(const bvt &op0, const bvt &op1);
bvt unsigned_karatsuba_multiplier(const bvt &op0, const bvt &op1);
bvt unsigned_toom_cook_multiplier(const bvt &op0, const bvt &op1);
bvt signed_multiplier(const bvt &op0, const bvt &op1);
bvt multiplier(const bvt &op0, const bvt &op1, representationt rep);
bvt multiplier_no_overflow(
