Skip to content

Commit

Permalink
Fix descriptor.Table buffer growth calc (#2311)
Browse files Browse the repository at this point in the history
Signed-off-by: Edward McFarlane <emcfarlane@buf.build>
Co-authored-by: Takeshi Yoneda <t.y.mathetake@gmail.com>
  • Loading branch information
emcfarlane and mathetake authored Sep 25, 2024
1 parent dc058c0 commit 111c51a
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 48 deletions.
11 changes: 11 additions & 0 deletions internal/descriptor/export_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package descriptor

// Masks returns the masks of the table for testing purposes.
func Masks[Key ~int32, Item any](t *Table[Key, Item]) []uint64 {
return t.masks
}

// Items returns the items of the table for testing purposes.
func Items[Key ~int32, Item any](t *Table[Key, Item]) []Item {
return t.items
}
40 changes: 15 additions & 25 deletions internal/descriptor/table.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package descriptor

import "math/bits"
import (
"math/bits"
"slices"
)

// Table is a data structure mapping 32 bit descriptor to items.
//
Expand Down Expand Up @@ -37,23 +40,13 @@ func (t *Table[Key, Item]) Len() (n int) {
return n
}

// grow ensures that t has enough room for n items, potentially reallocating the
// internal buffers if their capacity was too small to hold this many items.
// grow grows the table by n * 64 items.
func (t *Table[Key, Item]) grow(n int) {
// Round up to a multiple of 64 since this is the smallest increment due to
// using 64 bits masks.
n = (n*64 + 63) / 64
total := len(t.masks) + n
t.masks = slices.Grow(t.masks, n)[:total]

if n > len(t.masks) {
masks := make([]uint64, n)
copy(masks, t.masks)

items := make([]Item, n*64)
copy(items, t.items)

t.masks = masks
t.items = items
}
total = len(t.items) + n*64
t.items = slices.Grow(t.items, n*64)[:total]
}

// Insert inserts the given item to the table, returning the key that it is
Expand All @@ -78,13 +71,9 @@ insert:
}
}

// No free slot found, grow the table and retry.
offset = len(t.masks)
n := 2 * len(t.masks)
if n == 0 {
n = 1
}

t.grow(n)
t.grow(1)
goto insert
}

Expand All @@ -109,10 +98,10 @@ func (t *Table[Key, Item]) InsertAt(item Item, key Key) bool {
if key < 0 {
return false
}
if diff := int(key) - t.Len(); diff > 0 {
index := uint(key) / 64
if diff := int(index) - len(t.masks) + 1; diff > 0 {
t.grow(diff)
}
index := uint(key) / 64
shift := uint(key) % 64
t.masks[index] |= 1 << shift
t.items[key] = item
Expand All @@ -124,7 +113,8 @@ func (t *Table[Key, Item]) Delete(key Key) {
if key < 0 { // invalid key
return
}
if index, shift := key/64, key%64; int(index) < len(t.masks) {
if index := uint(key) / 64; int(index) < len(t.masks) {
shift := uint(key) % 64
mask := t.masks[index]
if (mask & (1 << shift)) != 0 {
var zero Item
Expand Down
115 changes: 92 additions & 23 deletions internal/descriptor/table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@ package descriptor_test
import (
"testing"

"github.com/tetratelabs/wazero/internal/descriptor"
"github.com/tetratelabs/wazero/internal/sys"
"github.com/tetratelabs/wazero/internal/testing/require"
)

func TestFileTable(t *testing.T) {
table := new(sys.FileTable)

if n := table.Len(); n != 0 {
t.Errorf("new table is not empty: length=%d", n)
}
n := table.Len()
require.Equal(t, 0, n, "new table is not empty: length=%d", n)

// The id field is used as a sentinel value.
v0 := &sys.FileEntry{Name: "1"}
Expand All @@ -38,16 +38,12 @@ func TestFileTable(t *testing.T) {
{key: k1, val: v1},
{key: k2, val: v2},
} {
if v, ok := table.Lookup(lookup.key); !ok {
t.Errorf("value not found for key '%v'", lookup.key)
} else if v.Name != lookup.val.Name {
t.Errorf("wrong value returned for key '%v': want=%v got=%v", lookup.key, lookup.val.Name, v.Name)
}
v, ok := table.Lookup(lookup.key)
require.True(t, ok, "value not found for key '%v'", lookup.key)
require.Equal(t, lookup.val.Name, v.Name, "wrong value returned for key '%v'", lookup.key)
}

if n := table.Len(); n != 3 {
t.Errorf("wrong table length: want=3 got=%d", n)
}
require.Equal(t, 3, table.Len(), "wrong table length: want=3 got=%d", table.Len())

k0Found := false
k1Found := false
Expand All @@ -62,9 +58,7 @@ func TestFileTable(t *testing.T) {
case k2:
k2Found, want = true, v2
}
if v.Name != want.Name {
t.Errorf("wrong value found ranging over '%v': want=%v got=%v", k, want.Name, v.Name)
}
require.Equal(t, want.Name, v.Name, "wrong value found ranging over table")
return true
})

Expand All @@ -76,9 +70,7 @@ func TestFileTable(t *testing.T) {
{key: k1, ok: k1Found},
{key: k2, ok: k2Found},
} {
if !found.ok {
t.Errorf("key not found while ranging over table: %v", found.key)
}
require.True(t, found.ok, "key not found while ranging over table: %v", found.key)
}

for i, deletion := range []struct {
Expand All @@ -89,12 +81,10 @@ func TestFileTable(t *testing.T) {
{key: k2},
} {
table.Delete(deletion.key)
if _, ok := table.Lookup(deletion.key); ok {
t.Errorf("item found after deletion of '%v'", deletion.key)
}
if n, want := table.Len(), 3-(i+1); n != want {
t.Errorf("wrong table length after deletion: want=%d got=%d", want, n)
}
_, ok := table.Lookup(deletion.key)
require.False(t, ok, "item found after deletion of '%v'", deletion.key)
n, want := table.Len(), 3-(i+1)
require.Equal(t, want, n, "wrong table length after deletion: want=%d got=%d", want, n)
}
}

Expand Down Expand Up @@ -134,3 +124,82 @@ func BenchmarkFileTableLookup(b *testing.B) {
b.Error("wrong file returned by lookup")
}
}

func Test_sizeOfTable(t *testing.T) {
tests := []struct {
name string
operation func(*descriptor.Table[int32, string])
expectedSize int
}{
{
name: "empty table",
operation: func(table *descriptor.Table[int32, string]) {},
expectedSize: 0,
},
{
name: "1 insert",
operation: func(table *descriptor.Table[int32, string]) {
table.Insert("a")
},
expectedSize: 1,
},
{
name: "32 inserts",
operation: func(table *descriptor.Table[int32, string]) {
for i := 0; i < 32; i++ {
table.Insert("a")
}
},
expectedSize: 1,
},
{
name: "257 inserts",
operation: func(table *descriptor.Table[int32, string]) {
for i := 0; i < 257; i++ {
table.Insert("a")
}
},
expectedSize: 5,
},
{
name: "1 insert at 63",
operation: func(table *descriptor.Table[int32, string]) {
table.InsertAt("a", 63)
},
expectedSize: 1,
},
{
name: "1 insert at 64",
operation: func(table *descriptor.Table[int32, string]) {
table.InsertAt("a", 64)
},
expectedSize: 2,
},
{
name: "1 insert at 257",
operation: func(table *descriptor.Table[int32, string]) {
table.InsertAt("a", 257)
},
expectedSize: 5,
},
{
name: "insert at until 320",
operation: func(table *descriptor.Table[int32, string]) {
for i := int32(0); i < 320; i++ {
table.InsertAt("a", i)
}
},
expectedSize: 5,
},
}
for _, tt := range tests {
tc := tt

t.Run(tc.name, func(t *testing.T) {
table := new(descriptor.Table[int32, string])
tc.operation(table)
require.Equal(t, tc.expectedSize, len(descriptor.Masks(table)))
require.Equal(t, tc.expectedSize*64, len(descriptor.Items(table)))
})
}
}

0 comments on commit 111c51a

Please sign in to comment.