Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ShiftLeft and ShiftRight #165

Merged
merged 7 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions bitset.go
Original file line number Diff line number Diff line change
Expand Up @@ -1182,3 +1182,121 @@ func (b *BitSet) Select(index uint) uint {
}
return b.length
}

// top locates the highest set bit of the bitset.
//
// It returns the index of that bit and true, or (0, false) when no bit is
// set at all. The words are scanned from the most significant one downwards,
// so the first non-zero word found holds the answer.
func (b *BitSet) top() (uint, bool) {
	panicIfNull(b)

	for i := len(b.set) - 1; i >= 0; i-- {
		word := b.set[i]
		if word == 0 {
			continue
		}
		// len64(word) is the 1-based position of the highest set bit in word.
		return uint(i)*wordSize + len64(word) - 1, true
	}

	// no set bits
	return 0, false
}

// ShiftLeft shifts the bitset like << operation would do: every set bit at
// index i ends up at index i+bits and the vacated low bits become zero.
//
// Left shift may require bitset size extension. We try to avoid the
// unnecessary memory operations by detecting the leftmost set bit.
// A zero shift, or a bitset without any set bit, is left untouched.
// The function will panic if shift causes excess of capacity.
func (b *BitSet) ShiftLeft(bits uint) {
	panicIfNull(b)

	if bits == 0 {
		return
	}

	top, ok := b.top()
	if !ok {
		return
	}

	// Capacity check. Written as a subtraction so that top+bits cannot wrap
	// around and slip past the comparison (top is the index of an existing
	// bit, so it is assumed not to exceed Cap()).
	if bits >= Cap()-top {
		panic("You are exceeding the capacity")
	}

	// index of the highest set bit once the shift is done
	newTop := top + bits

	// destination set
	dst := b.set

	// not using extendSet() to avoid unneeded data copying.
	// newTop is a bit index, so newTop+1 bits have to be addressable;
	// sizing by newTop alone is one word short whenever newTop is a
	// multiple of the word size.
	nsize := wordsNeeded(newTop + 1)
	if len(b.set) < nsize {
		dst = make([]uint64, nsize, 2*nsize)
	}
	if newTop >= b.length {
		b.length = newTop + 1
	}

	pad, idx := top%wordSize, top>>log2WordSize
	shift, pages := bits%wordSize, bits>>log2WordSize
	if shift == 0 { // happy case: just add pages
		// copy handles the overlapping case where dst aliases b.set
		copy(dst[pages:nsize], b.set)
	} else {
		// the top word spills into one extra word when its highest bit is
		// pushed past the word boundary
		if pad+shift >= wordSize {
			dst[idx+pages+1] = b.set[idx] >> (wordSize - shift)
		}

		// Walk from the top word down so that, when dst aliases b.set,
		// every source word is read before it gets overwritten.
		for i := int(idx); i >= 0; i-- {
			if i > 0 {
				dst[i+int(pages)] = (b.set[i] << shift) | (b.set[i-1] >> (wordSize - shift))
			} else {
				dst[i+int(pages)] = b.set[i] << shift
			}
		}
	}

	// zeroing extra pages
	for i := 0; i < int(pages); i++ {
		dst[i] = 0
	}

	b.set = dst
}

// ShiftRight shifts the bitset like >> operation would do: set bits move
// towards index 0 by the given amount and bits shifted below index 0 are
// dropped.
//
// A zero shift, or a bitset without any set bit, is left untouched. When the
// shift is a whole number of words the leading words are sliced away and the
// length shrinks accordingly; otherwise the words are rewritten in place and
// the length is kept.
func (b *BitSet) ShiftRight(bits uint) {
	panicIfNull(b)

	if bits == 0 {
		return
	}

	top, ok := b.top()
	if !ok {
		return
	}

	// NOTE(review): bits == top also empties the set, whereas a plain >>
	// would leave bit 0 set. TestShiftRight expects the same (it only counts
	// positions strictly greater than the shift) — confirm the intent before
	// changing either side.
	if bits >= top {
		b.set = make([]uint64, wordsNeeded(b.length))
		return
	}

	idx := top >> log2WordSize
	shift, pages := bits%wordSize, bits>>log2WordSize
	if shift == 0 { // happy case: just clear pages
		// Words above the top word are already zero, so nothing is left to
		// clear after re-slicing. Clearing by the pre-slice indices here
		// would run past the end of the shortened slice.
		b.set = b.set[pages:]
		b.length -= pages * wordSize
		return
	}

	pad := top % wordSize
	last, off := int(idx-pages), int(pages)
	for i := 0; i <= last; i++ {
		if i < last {
			// low part comes from the source word, high part from its upper neighbour
			b.set[i] = (b.set[i+off] >> shift) | (b.set[i+off+1] << (wordSize - shift))
		} else {
			b.set[i] = b.set[i+off] >> shift
		}
	}

	// the top word is emptied when its highest bit sits below the shift
	if pad < shift {
		b.set[last] = 0
	}

	// clear the words vacated at the top
	for i := last + 1; i <= int(idx); i++ {
		b.set[i] = 0
	}
}
72 changes: 72 additions & 0 deletions bitset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1963,3 +1963,75 @@ func TestSetAll(t *testing.T) {
test(fmt.Sprintf("length %d", length), New(length), length)
}
}

// TestShiftLeft sets a fixed group of bits in a 200-bit set, shifts it left
// by various amounts and checks that the number of set bits is unchanged and
// that every original bit shows up at its position plus the shift.
func TestShiftLeft(t *testing.T) {
	positions := []uint{5, 28, 45, 72, 89}

	cases := []struct {
		name  string
		shift uint
	}{
		{"zero", 0},
		{"no page change", 19},
		{"shift to full page", 38},
		{"full page shift", 64},
		{"no page split", 80},
		{"with page split", 114},
		{"with extension", 242},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			bs := New(200)
			for _, p := range positions {
				bs.Set(p)
			}

			bs.ShiftLeft(tc.shift)

			if len(positions) != int(bs.Count()) {
				t.Error("bad bits count")
			}

			for _, p := range positions {
				if moved := p + tc.shift; !bs.Test(moved) {
					t.Errorf("bit %v is not set", moved)
				}
			}
		})
	}
}

// TestShiftRight sets a fixed group of bits in a 200-bit set, shifts it right
// by various amounts and checks that each bit whose position is strictly
// greater than the shift shows up at its position minus the shift, and that
// no other bit is set. A bit sitting exactly at the shift distance is
// expected to be dropped.
func TestShiftRight(t *testing.T) {
	positions := []uint{5, 28, 45, 72, 89}

	cases := []struct {
		name  string
		shift uint
	}{
		{"zero", 0},
		{"no page change", 3},
		{"no page split", 20},
		{"with page split", 40},
		{"full page shift", 64},
		{"with extension", 70},
		{"full shift", 89},
		{"remove all", 242},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			bs := New(200)
			for _, p := range positions {
				bs.Set(p)
			}

			bs.ShiftRight(tc.shift)

			survivors := 0
			for _, p := range positions {
				if p <= tc.shift {
					continue
				}
				survivors++

				if moved := p - tc.shift; !bs.Test(moved) {
					t.Errorf("bit %v is not set", moved)
				}
			}

			if survivors != int(bs.Count()) {
				t.Error("bad bits count")
			}
		})
	}
}
43 changes: 43 additions & 0 deletions leading_zeros_18.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//go:build !go1.9
// +build !go1.9

package bitset

// len8tab is a 256-entry lookup table indexed by a byte value; entry x holds
// the minimum number of bits required to represent x (0 for x == 0).
var len8tab = "" +
Copy link
Member

@lemire lemire Aug 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about adding a quick test? Like...

for k := uint(0); k < 64; k++ {
  if len64(uint64(1)<<k) != k+1 {
    // error!
  }
}

"\x00\x01\x02\x02\x03\x03\x03\x03\x04\x04\x04\x04\x04\x04\x04\x04" +
"\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05" +
"\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06" +
"\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06" +
"\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07" +
"\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07" +
"\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07" +
"\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07" +
"\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" +
"\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" +
"\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" +
"\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" +
"\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" +
"\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" +
"\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" +
"\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08"

// len64 returns the minimum number of bits required to represent x; the
// result is 0 for x == 0.
func len64(x uint64) (n uint) {
	// Narrow x down to a single byte: whenever anything is set above the
	// current width, drop the low part and account for the skipped bits.
	for _, width := range [...]uint{32, 16, 8} {
		if x>>width != 0 {
			x >>= width
			n += width
		}
	}
	// the remaining byte is resolved through the lookup table
	return n + uint(len8tab[x])
}

// leadingZeroes64 returns the number of leading zero bits in v; the result
// is 64 for v == 0. It is derived from len64, which yields the number of
// bits needed to represent v.
func leadingZeroes64(v uint64) uint {
	return 64 - len64(v)
}
14 changes: 14 additions & 0 deletions leading_zeros_19.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//go:build go1.9
// +build go1.9

package bitset

import "math/bits"

// len64 returns the minimum number of bits required to represent v; the
// result is 0 for v == 0. It delegates to math/bits.
func len64(v uint64) uint {
	width := bits.Len64(v)
	return uint(width)
}

// leadingZeroes64 returns the number of leading zero bits in v; the result
// is 64 for v == 0.
func leadingZeroes64(v uint64) uint {
	// the leading zeros are whatever is left of the 64 bits above the
	// highest set bit
	used := uint(bits.Len64(v))
	return 64 - used
}
Loading