Skip to content

Commit

Permalink
Fix various bugs in DiscordBitSet (#772)
Browse files Browse the repository at this point in the history
* WIDTH was defined to the width of a Byte not a Long, resulting in bugs
  in size, get and set

* binary just did some completely wrong things that I can't really
  describe

* hashCode could return different values for equal DiscordBitSets
  (when they had different amounts of trailing zeros in the data array),
  breaking its general contract

* equals was comparing this with this and not this with other, which
  resulted in any two DiscordBitSets considered equal

* contains would erroneously return false if other did have more
  trailing zeros but would otherwise be contained

* set wasn't growing data correctly and also didn't unset bits properly
  • Loading branch information
lukellmann authored Feb 26, 2023
1 parent f2e9c14 commit 3f5cd17
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 37 deletions.
55 changes: 31 additions & 24 deletions common/src/main/kotlin/DiscordBitSet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,20 @@ import kotlin.math.max
import kotlin.math.min

private const val SAFE_LENGTH = 19
private const val WIDTH = Byte.SIZE_BITS
private const val WIDTH = Long.SIZE_BITS

@Suppress("FunctionName")
public fun EmptyBitSet(): DiscordBitSet = DiscordBitSet(0)
public fun EmptyBitSet(): DiscordBitSet = DiscordBitSet()

@Serializable(with = DiscordBitSetSerializer::class)
public class DiscordBitSet(internal var data: LongArray) {
public class DiscordBitSet(internal var data: LongArray) { // data is in little-endian order

public val isEmpty: Boolean
get() = data.all { it == 0L }

public val value: String
get() {
// need to convert from little-endian data to big-endian expected by BigInteger
val buffer = ByteBuffer.allocate(data.size * Long.SIZE_BYTES)
buffer.asLongBuffer().put(data.reversedArray())
return BigInteger(buffer.array()).toString()
Expand All @@ -35,53 +36,64 @@ public class DiscordBitSet(internal var data: LongArray) {
get() = data.size * WIDTH

public val binary: String
get() = data.joinToString("") { it.toULong().toString(2) }.reversed().padEnd(8, '0')
get() = data.map { it.toULong().toString(radix = 2).padStart(length = ULong.SIZE_BITS, '0') }
.reversed()
.joinToString(separator = "")
.trimStart('0')
.ifEmpty { "0" }

override fun equals(other: Any?): Boolean {
if (other !is DiscordBitSet) return false
for (i in 0 until max(data.size, other.data.size)) {
if (getOrZero(i) != getOrZero(i)) return false
// trailing zeros are ignored -> getOrZero
for (i in 0 until max(this.data.size, other.data.size)) {
if (this.getOrZero(i) != other.getOrZero(i)) return false
}
return true
}

override fun hashCode(): Int {
var result = 1
// trailing zeros are ignored to have the same hashCode for equal bit sets
for (i in 0..(data.indexOfLast { it != 0L })) {
result = (31 * result) + data[i].hashCode()
}
return result
}

private fun getOrZero(i: Int) = data.getOrNull(i) ?: 0L

public operator fun get(index: Int): Boolean {
if (index !in 0 until size) return false
require(index >= 0)
if (index >= size) return false
val indexOfWidth = index / WIDTH
val bitIndex = index % WIDTH
return data[indexOfWidth] and (1L shl bitIndex) != 0L
}

public operator fun contains(other: DiscordBitSet): Boolean {
if (other.size > size) return false
for (i in other.data.indices) {
if (data[i] and other.data[i] != other.data[i]) return false
for ((index, value) in other.data.withIndex()) {
if ((this.getOrZero(index) and value) != value) return false
}
return true
}

public operator fun set(index: Int, value: Boolean) {
if (index !in 0 until size) data.copyOf((63 + index) / WIDTH)
require(index >= 0)
val indexOfWidth = index / WIDTH
if (index >= size) data = data.copyOf(indexOfWidth + 1)
val bitIndex = index % WIDTH
val bit = if (value) 1L else 0L
data[index] = data[indexOfWidth] or (bit shl bitIndex)
val prev = data[indexOfWidth]
data[indexOfWidth] = if (value) prev or (1L shl bitIndex) else prev and (1L shl bitIndex).inv()
}

public operator fun plus(another: DiscordBitSet): DiscordBitSet {
val dist = LongArray(data.size)
data.copyInto(dist)
val copy = DiscordBitSet(dist)
val copy = DiscordBitSet(data.copyOf())
copy.add(another)
return copy
}

public operator fun minus(another: DiscordBitSet): DiscordBitSet {
val dist = LongArray(data.size)
data.copyInto(dist)
val copy = DiscordBitSet(dist)
val copy = DiscordBitSet(data.copyOf())
copy.remove(another)
return copy
}
Expand All @@ -100,11 +112,6 @@ public class DiscordBitSet(internal var data: LongArray) {
}
}

override fun hashCode(): Int {
var result = data.contentHashCode()
result = 31 * result + size
return result
}

override fun toString(): String {
return "DiscordBitSet($binary)"
Expand Down
68 changes: 55 additions & 13 deletions common/src/test/kotlin/BitSetTests.kt
Original file line number Diff line number Diff line change
@@ -1,26 +1,58 @@
import dev.kord.common.DiscordBitSet
import org.junit.jupiter.api.Test
import dev.kord.common.EmptyBitSet
import kotlin.test.*

class BitSetTests {
@Test
fun `b contains a`() {
val a = DiscordBitSet(0b101)
val b = DiscordBitSet(0b111)
assert(a in b)
fun `a contains b and c`() {
val a = DiscordBitSet(0b111)
val b = DiscordBitSet(0b101)
val c = DiscordBitSet(0b101, 0)
assertTrue(b in a)
assertTrue(c in a)
}

@Test
fun `a equals b`() {
fun `a and b are equal and have the same hashCode`() {
val a = DiscordBitSet(0b111, 0)
val b = DiscordBitSet(0b111)
assert(a == b)
assertEquals(a, b)
assertEquals(a.hashCode(), b.hashCode())
}

@Test
fun `a does not equal b`() {
val a = DiscordBitSet(0b111, 0)
val b = DiscordBitSet(0b111, 0b1)
assertNotEquals(a, b)
}

@Test
fun `get a bit`() {
fun `get bits`() {
val a = DiscordBitSet(0b101, 0)
assert(!a[1])
assertTrue(a[0])
assertFalse(a[1])
assertTrue(a[2])
for (i in 3..64) assertFalse(a[i])

val b = DiscordBitSet(1L shl 63)
for (i in 0..62) assertFalse(b[i])
assertTrue(b[63])
}

@Test
fun `set bits`() {
val a = EmptyBitSet()
for (i in 0..64) a[i] = true
assertEquals(DiscordBitSet(ULong.MAX_VALUE.toLong(), 1), a)

val b = EmptyBitSet()
b[1] = true
b[2] = true
b[5] = true
assertEquals(DiscordBitSet(0b100110), b)
b[2] = false
assertEquals(DiscordBitSet(0b100010), b)
}

@Test
Expand All @@ -30,21 +62,31 @@ class BitSetTests {
}

@Test
fun `add and remove a bit`() {
fun `add and remove a bit`() {
val a = DiscordBitSet(0b101, 0)
a.add(DiscordBitSet(0b111))
assert(a.value == 0b111.toString())
a.remove(DiscordBitSet(0b001))
assert(a.value == 0b110.toString())

}

@Test
fun `remove a bit`() {
val a = DiscordBitSet(0b101, 0)
a.remove(DiscordBitSet(0b111))
assert(a.value == "0")

}

}
@Test
fun `binary works`() {
assertEquals("0", DiscordBitSet().binary)
assertEquals("0", DiscordBitSet(0).binary)
assertEquals("10011", DiscordBitSet(0b10011).binary)
assertEquals(
"110" +
"0000000000000000000000000000000000000000000000000000000000111001" +
"0000000000000000000000000000000000000000000000000000000000001011",
DiscordBitSet(0b1011, 0b111001, 0b110).binary,
)
}
}

0 comments on commit 3f5cd17

Please sign in to comment.