Skip to content

Commit

Permalink
simplify implementation of copy_n() (remove special cases)
Browse files Browse the repository at this point in the history
  • Loading branch information
ilanschnell committed Nov 29, 2023
1 parent 8a03c21 commit d2d6fd5
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 228 deletions.
1 change: 1 addition & 0 deletions CHANGE_LOG
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
2023-XX-XX 2.8.4:
-------------------
* simplify implementation of `copy_n()` (remove special cases)
* improve documentation and testing


Expand Down
90 changes: 32 additions & 58 deletions bitarray/_bitarray.c
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ shift_r8(bitarrayobject *self, Py_ssize_t a, Py_ssize_t b, int n, int bebr)
return;

/* as the big-endian representation has reversed bit order in each
byte, we reverse each byte, and (re-) reverse again below */
byte, we reverse each byte, and (re-) reverse again at the end */
if (bebr && IS_BE(self))
bytereverse(self, a, b);

Expand Down Expand Up @@ -268,16 +268,17 @@ shift_r8(bitarrayobject *self, Py_ssize_t a, Py_ssize_t b, int n, int bebr)
}

/* copy n bits from other (starting at b) onto self (starting at a),
please find details about how this function works in copy_n.txt */
see also: examples/copy_n.py */
static void
copy_n(bitarrayobject *self, Py_ssize_t a,
bitarrayobject *other, Py_ssize_t b, Py_ssize_t n)
{
Py_ssize_t p1 = a / 8;
Py_ssize_t p2 = (a + n - 1) / 8;
Py_ssize_t p3 = b / 8;
int sa = a % 8;
int sb = -(b % 8);
Py_ssize_t p1 = a / 8; /* first byte to copied to */
Py_ssize_t p2 = (a + n - 1) / 8; /* last byte to be copied to */
Py_ssize_t p3 = b / 8; /* first byte to be copied from */
int sa = a % 8, sb = -(b % 8), be = IS_BE(self);
char t3 = 0; /* silence uninitialized warning on some compilers */
Py_ssize_t i;

assert(0 <= n && n <= self->nbits && n <= other->nbits);
assert(0 <= a && a <= self->nbits - n);
Expand All @@ -286,59 +287,32 @@ copy_n(bitarrayobject *self, Py_ssize_t a,
if (n == 0 || (self == other && a == b))
return;

if (sa == 0 && sb == 0) { /***** aligned case *****/
char *cp2 = self->ob_item + p2;
char m2 = ones_table[IS_BE(self)][(a + n) % 8];
char t2 = *cp2;

assert(p1 + BYTES(n) == p2 + 1 && p1 <= p2);

memmove(self->ob_item + p1, other->ob_item + p3, (size_t) BYTES(n));
if (sa + sb < 0) {
t3 = other->ob_item[p3++]; /* store byte in case other == self */
sb += 8;
}
assert(a - sa == 8 * p1 && b + sb == 8 * p3);
assert(p1 <= p2 && 8 * p2 < a + n && a + n <= 8 * (p2 + 1));
if (n > sb) {
Py_ssize_t m = BYTES(n - sb);
char *cp1 = self->ob_item + p1, m1 = ones_table[be][sa];
char *cp2 = self->ob_item + p2, m2 = ones_table[be][(a + n) % 8];
char t1 = *cp1, t2 = *cp2;

assert(p1 + m == p2 || p1 + m == p2 + 1);
assert(p1 + m <= Py_SIZE(self) && p3 + m <= Py_SIZE(other));
memmove(self->ob_item + p1, other->ob_item + p3, (size_t) m);
if (self->endian != other->endian)
bytereverse(self, p1, p2 + 1);

if (m2) /* restore bits overwritten by highest copied byte */
*cp2 = (*cp2 & m2) | (t2 & ~m2);
}
else if (n < 8) { /***** small n case *****/
Py_ssize_t i;

if (a <= b) { /* loop forward (delete) */
for (i = 0; i < n; i++)
setbit(self, i + a, getbit(other, i + b));
}
else { /* loop backwards (insert) */
for (i = n - 1; i >= 0; i--)
setbit(self, i + a, getbit(other, i + b));
}
}
else { /***** general case *****/
char *cp1 = self->ob_item + p1;
char *cp2 = self->ob_item + p2;
char m1 = ones_table[IS_BE(self)][sa];
char m2 = ones_table[IS_BE(self)][(a + n) % 8];
char t1 = *cp1, t2 = *cp2, t3 = other->ob_item[p3];
Py_ssize_t i;

assert(n >= 8 && cp1 <= cp2);
assert(a - sa == 8 * p1); /* useful equations */
assert(b + sb == 8 * p3);
assert(a + n > 8 * p2);

if (sa + sb < 0)
sb += 8;
copy_n(self, a - sa, other, b + sb, n - sb); /* aligned copy */
shift_r8(self, p1, p2 + 1, sa + sb, 1); /* right shift */

if (m1) /* restore bits at p1 */
*cp1 = (*cp1 & ~m1) | (t1 & m1);

if (m2 && sa + sb) /* if shifted, restore bits at p2 */
*cp2 = (*cp2 & m2) | (t2 & ~m2);
bytereverse(self, p1, p1 + m);

for (i = 0; i < sb; i++) /* copy first bits missed by copy_n() */
setbit(self, i + a, t3 & BITMASK(other, i + b));
shift_r8(self, p1, p2 + 1, sa + sb, 1); /* right shift by sa + sb */
if (m1)
*cp1 = (*cp1 & ~m1) | (t1 & m1); /* restore bits at p1 */
if (m2)
*cp2 = (*cp2 & m2) | (t2 & ~m2); /* restore bits at p2 */
}
for (i = 0; i < sb && i < n; i++) /* copy first sb bits */
setbit(self, i + a, t3 & BITMASK(other, i + b));
}

/* starting at start, delete n bits from self */
Expand Down Expand Up @@ -705,7 +679,7 @@ extend_bytes01(bitarrayobject *self, PyObject *bytes)
const Py_ssize_t original_nbits = self->nbits;
unsigned char c;
char *data;
int vi = 0; /* to avoid uninitialized warning for some compilers */
int vi = 0; /* silence uninitialized warning on some compilers */

assert(PyBytes_Check(bytes));
data = PyBytes_AS_STRING(bytes);
Expand Down
100 changes: 0 additions & 100 deletions bitarray/copy_n.txt

This file was deleted.

2 changes: 1 addition & 1 deletion bitarray/test_bitarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ def test_copy_n_explicit(self):
self.assertEqual(x, y)

def test_copy_n_example(self):
# example givin in bitarray/copy_n.txt
# example given in examples/copy_n.py
y = bitarray(
'00101110 11111001 01011101 11001011 10110000 01011110 011')
x = bitarray(
Expand Down
Loading

0 comments on commit d2d6fd5

Please sign in to comment.