Skip to content

Commit 92e691f

Browse files
committed
Implement Karatsuba's multiplication algorithm
Uses extra sign-bit to keep bit widths small.
1 parent b8287c2 commit 92e691f

File tree

2 files changed

+132
-1
lines changed

2 files changed

+132
-1
lines changed

src/solvers/flattening/bv_utils.cpp

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -925,7 +925,8 @@ bvt bv_utilst::dadda_tree(const std::vector<bvt> &pps)
925925
// with Dadda's reduction yields the most consistent performance improvement
926926
// while not regressing substantially in the matrix of different benchmarks and
927927
// CaDiCaL and MiniSat2 as solvers.
928-
#define RADIX_MULTIPLIER 8
928+
// #define RADIX_MULTIPLIER 8
929+
#define USE_KARATSUBA
929930
#ifdef RADIX_MULTIPLIER
930931
# define DADDA_TREE
931932
#endif
@@ -1828,6 +1829,125 @@ bvt bv_utilst::unsigned_multiplier(const bvt &_op0, const bvt &_op1)
18281829
}
18291830
}
18301831

1832+
bvt bv_utilst::unsigned_karatsuba_full_multiplier(
1833+
const bvt &op0,
1834+
const bvt &op1)
1835+
{
1836+
// We review symbolic encoding of multiplication in context of sw
1837+
// verification, bit width is 2^n, distinguish truncating (x mod 2^2^n) from
1838+
// double-output-width multiplication, truncating Karatsuba is 2 truncating
1839+
// half-width multiplication plus one double-output-width of half width, for
1840+
// double output width Karatsuba idea is challenge to avoid width extension,
1841+
// check Wikipedia edit history
1842+
1843+
PRECONDITION(op0.size() == op1.size());
1844+
const std::size_t op_size = op0.size();
1845+
PRECONDITION(op_size > 0);
1846+
PRECONDITION((op_size & (op_size - 1)) == 0);
1847+
1848+
if(op_size == 1)
1849+
return {prop.land(op0[0], op1[0]), const_literal(false)};
1850+
1851+
const std::size_t half_op_size = op_size >> 1;
1852+
1853+
bvt x0{op0.begin(), op0.begin() + half_op_size};
1854+
bvt x1{op0.begin() + half_op_size, op0.end()};
1855+
1856+
bvt y0{op1.begin(), op1.begin() + half_op_size};
1857+
bvt y1{op1.begin() + half_op_size, op1.end()};
1858+
1859+
bvt z0 = unsigned_karatsuba_full_multiplier(x0, y0);
1860+
bvt z2 = unsigned_karatsuba_full_multiplier(x1, y1);
1861+
1862+
bvt x0_sub = zero_extension(x0, half_op_size + 1);
1863+
bvt x1_sub = zero_extension(x1, half_op_size + 1);
1864+
1865+
bvt y0_sub = zero_extension(y0, half_op_size + 1);
1866+
bvt y1_sub = zero_extension(y1, half_op_size + 1);
1867+
1868+
bvt x1_minus_x0_ext = sub(x1_sub, x0_sub);
1869+
literalt x1_minus_x0_sign = sign_bit(x1_minus_x0_ext);
1870+
bvt x1_minus_x0_abs = absolute_value(x1_minus_x0_ext);
1871+
x1_minus_x0_abs.pop_back();
1872+
bvt y0_minus_y1_ext = sub(y0_sub, y1_sub);
1873+
literalt y0_minus_y1_sign = sign_bit(y0_minus_y1_ext);
1874+
bvt y0_minus_y1_abs = absolute_value(y0_minus_y1_ext);
1875+
y0_minus_y1_abs.pop_back();
1876+
bvt sub_mult =
1877+
unsigned_karatsuba_full_multiplier(x1_minus_x0_abs, y0_minus_y1_abs);
1878+
bvt sub_mult_ext = zero_extension(sub_mult, op_size + 1);
1879+
bvt z1_ext = add_sub(
1880+
zero_extension(add(z0, z2), op_size + 1),
1881+
sub_mult_ext,
1882+
prop.lxor(x1_minus_x0_sign, y0_minus_y1_sign));
1883+
1884+
bvt z0_full = zero_extension(z0, op_size << 1);
1885+
bvt z1_full =
1886+
zero_extension(concatenate(zeros(half_op_size), z1_ext), op_size << 1);
1887+
bvt z2_full = concatenate(zeros(op_size), z2);
1888+
1889+
return add(add(z0_full, z1_full), z2_full);
1890+
}
1891+
1892+
bvt bv_utilst::unsigned_karatsuba_multiplier(const bvt &_op0, const bvt &_op1)
1893+
{
1894+
if(_op0.size() != _op1.size())
1895+
return unsigned_multiplier(_op0, _op1);
1896+
1897+
const std::size_t op_size = _op0.size();
1898+
// only use this approach for powers of two
1899+
if(op_size == 0 || (op_size & (op_size - 1)) != 0)
1900+
return unsigned_multiplier(_op0, _op1);
1901+
1902+
if(op_size == 1)
1903+
return {prop.land(_op0[0], _op1[0])};
1904+
1905+
const std::size_t half_op_size = op_size >> 1;
1906+
1907+
// We split each of the operands in half and treat them as coefficients of a
1908+
// polynomial a * 2^half_op_size + b. Straightforward polynomial
1909+
// multiplication then yields
1910+
// a0 * a1 * 2^op_size + (a0 * b1 + a1 * b0) * 2^half_op_size + b0 * b1
1911+
// These would be four multiplications (the operands of which have half the
1912+
// original bit width):
1913+
// z0 = b0 * b1
1914+
// z1 = a0 * b1 + a1 * b0
1915+
// z2 = a0 * a1
1916+
// Karatsuba's insight is that these four multiplications can be expressed
1917+
// using just three multiplications:
1918+
// z1 = (a0 - b0) * (b1 - a1) + z0 + z2
1919+
//
1920+
// Worked 4-bit example, 4-bit result:
1921+
// abcd * efgh -> 4-bit result
1922+
// cd * gh -> 4-bit result
1923+
// cd * ef -> 2-bit result
1924+
// ab * gh -> 2-bit result
1925+
// d * h -> 2-bit result
1926+
// c * g -> 2-bit result
1927+
// (c - d) * (h - g) + dh + cg; use an extra sign bit for each of the
1928+
// subtractions, and conditionally negate the product by xor-ing those sign
1929+
// bits; dh + cg is a 2-bit addition (with possible results 0, 1, 2); the
1930+
// product has possible values (-1, 0, 1); the final sum cannot evaluate to -1
1931+
// as
1932+
// * c=1, d=0, h=0, g=1 (1 * -1) implies cg=1
1933+
// * c=0, d=1, h=1, g=0 (-1 * 1) implies dh=1
1934+
// Therefore, after adding (dh + cg) the multiplication can safely be added
1935+
// over just 2 bits.
1936+
1937+
bvt x0{_op0.begin(), _op0.begin() + half_op_size};
1938+
bvt x1{_op0.begin() + half_op_size, _op0.end()};
1939+
bvt y0{_op1.begin(), _op1.begin() + half_op_size};
1940+
bvt y1{_op1.begin() + half_op_size, _op1.end()};
1941+
1942+
bvt z0 = unsigned_karatsuba_full_multiplier(x0, y0);
1943+
bvt z1 = add(
1944+
unsigned_karatsuba_multiplier(x1, y0),
1945+
unsigned_karatsuba_multiplier(x0, y1));
1946+
bvt z1_full = concatenate(zeros(half_op_size), z1);
1947+
1948+
return add(z0, z1_full);
1949+
}
1950+
18311951
bvt bv_utilst::unsigned_multiplier_no_overflow(
18321952
const bvt &op0,
18331953
const bvt &op1)
@@ -1878,7 +1998,11 @@ bvt bv_utilst::signed_multiplier(const bvt &op0, const bvt &op1)
18781998
bvt neg0=cond_negate(op0, sign0);
18791999
bvt neg1=cond_negate(op1, sign1);
18802000

2001+
#ifdef USE_KARATSUBA
2002+
bvt result = unsigned_karatsuba_multiplier(neg0, neg1);
2003+
#else
18812004
bvt result=unsigned_multiplier(neg0, neg1);
2005+
#endif
18822006

18832007
literalt result_sign=prop.lxor(sign0, sign1);
18842008

@@ -1946,7 +2070,12 @@ bvt bv_utilst::multiplier(
19462070
switch(rep)
19472071
{
19482072
case representationt::SIGNED: return signed_multiplier(op0, op1);
2073+
#ifdef USE_KARATSUBA
2074+
case representationt::UNSIGNED:
2075+
return unsigned_karatsuba_multiplier(op0, op1);
2076+
#else
19492077
case representationt::UNSIGNED: return unsigned_multiplier(op0, op1);
2078+
#endif
19502079
}
19512080

19522081
UNREACHABLE;

src/solvers/flattening/bv_utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ class bv_utilst
7979
bvt shift(const bvt &op, const shiftt shift, const bvt &distance);
8080

8181
bvt unsigned_multiplier(const bvt &op0, const bvt &op1);
82+
bvt unsigned_karatsuba_multiplier(const bvt &op0, const bvt &op1);
83+
bvt unsigned_karatsuba_full_multiplier(const bvt &op0, const bvt &op1);
8284
bvt signed_multiplier(const bvt &op0, const bvt &op1);
8385
bvt multiplier(const bvt &op0, const bvt &op1, representationt rep);
8486
bvt multiplier_no_overflow(

0 commit comments

Comments
 (0)