Skip to content

Commit

Permalink
Ensure Lookup and BatchLookup accept a slice by value
Browse files Browse the repository at this point in the history
Signed-off-by: Alun Evans <alun@badgerous.net>
  • Loading branch information
alxn committed Nov 13, 2023
1 parent 0acd95c commit a091dd1
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 18 deletions.
54 changes: 52 additions & 2 deletions map.go
Original file line number Diff line number Diff line change
Expand Up @@ -642,11 +642,15 @@ func (m *Map) LookupBytes(key interface{}) ([]byte, error) {
}

func (m *Map) lookupPerCPU(key, valueOut any, flags MapLookupFlags) error {
slice, err := ensurePerCPUSlice(valueOut, int(m.valueSize))
if err != nil {
return err
}
valueBytes := make([]byte, m.fullValueSize)
if err := m.lookup(key, sys.NewSlicePointer(valueBytes), flags); err != nil {
return err
}
return unmarshalPerCPUValue(valueOut, int(m.valueSize), valueBytes)
return unmarshalPerCPUValue(slice, int(m.valueSize), valueBytes)
}

func (m *Map) lookup(key interface{}, valueOut sys.Pointer, flags MapLookupFlags) error {
Expand All @@ -669,11 +673,57 @@ func (m *Map) lookup(key interface{}, valueOut sys.Pointer, flags MapLookupFlags
}

func (m *Map) lookupAndDeletePerCPU(key, valueOut any, flags MapLookupFlags) error {
slice, err := ensurePerCPUSlice(valueOut, int(m.valueSize))
if err != nil {
return err
}
valueBytes := make([]byte, m.fullValueSize)
if err := m.lookupAndDelete(key, sys.NewSlicePointer(valueBytes), flags); err != nil {
return err
}
return unmarshalPerCPUValue(valueOut, int(m.valueSize), valueBytes)
return unmarshalPerCPUValue(slice, int(m.valueSize), valueBytes)
}

func ensurePerCPUSlice(sliceOrPtr any, elemLength int) (any, error) {
possibleCPUs, err := internal.PossibleCPUs()
if err != nil {
return nil, err
}

sliceOrPtrType := reflect.TypeOf(sliceOrPtr)
if sliceOrPtrType.Kind() == reflect.Slice {
sliceValue := reflect.ValueOf(sliceOrPtr)
if sliceValue.Len() != possibleCPUs {
return nil, fmt.Errorf("per-cpu slice is incorrect length, expected %d, got %d",
possibleCPUs, sliceValue.Len())
}
return sliceValue.Interface(), nil
}

slicePtrType := sliceOrPtrType
if slicePtrType.Kind() != reflect.Ptr || slicePtrType.Elem().Kind() != reflect.Slice {
return nil, fmt.Errorf("per-cpu value requires a slice or a pointer to slice")
}

sliceType := slicePtrType.Elem()
slice := reflect.MakeSlice(sliceType, possibleCPUs, possibleCPUs)

sliceElemType := sliceType.Elem()
sliceElemIsPointer := sliceElemType.Kind() == reflect.Ptr
reflect.ValueOf(sliceOrPtr).Elem().Set(slice)
if !sliceElemIsPointer {
return slice.Interface(), nil
}
sliceElemType = sliceElemType.Elem()

for i := 0; i < possibleCPUs; i++ {
if sliceElemIsPointer {
newElem := reflect.New(sliceElemType)
slice.Index(i).Set(newElem)
}
}

return slice.Interface(), nil
}

func (m *Map) lookupAndDelete(key any, valuePtr sys.Pointer, flags MapLookupFlags) error {
Expand Down
11 changes: 11 additions & 0 deletions map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ func TestMap(t *testing.T) {
t.Error("Want value 42, got", v)
}

sliceVal := make([]uint32, 1)
qt.Assert(t, m.Lookup(uint32(0), sliceVal), qt.IsNil)
qt.Assert(t, sliceVal, qt.DeepEquals, []uint32{42})

var slice []byte
qt.Assert(t, m.Lookup(uint32(0), &slice), qt.IsNil)
qt.Assert(t, slice, qt.DeepEquals, internal.NativeEndian.AppendUint32(nil, 42))
Expand Down Expand Up @@ -1348,6 +1352,13 @@ func TestPerCPUMarshaling(t *testing.T) {
}

// Make sure unmarshaling works on slices containing pointers
retrievedVal := make([]*customEncoding, numCPU)
for i := range retrievedVal {
retrievedVal[i] = &customEncoding{}
}
if err := arr.Lookup(uint32(0), retrievedVal); err != nil {
t.Fatal("Can't retrieve key 0:", err)
}
var retrieved []*customEncoding
if err := arr.Lookup(uint32(0), &retrieved); err != nil {
t.Fatal("Can't retrieve key 0:", err)
Expand Down
24 changes: 8 additions & 16 deletions marshalers.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,37 +85,30 @@ func marshalPerCPUValue(slice any, elemLength int) (sys.Pointer, error) {
// possible CPU.
//
// slicePtr must be a pointer to a slice.
func unmarshalPerCPUValue(slicePtr any, elemLength int, buf []byte) error {
slicePtrType := reflect.TypeOf(slicePtr)
if slicePtrType.Kind() != reflect.Ptr || slicePtrType.Elem().Kind() != reflect.Slice {
return fmt.Errorf("per-cpu value requires pointer to slice")
func unmarshalPerCPUValue(slice any, elemLength int, buf []byte) error {
sliceType := reflect.TypeOf(slice)
if sliceType.Kind() != reflect.Slice {
return fmt.Errorf("per-cpu value requires a slice")
}

possibleCPUs, err := internal.PossibleCPUs()
if err != nil {
return err
}

sliceType := slicePtrType.Elem()
slice := reflect.MakeSlice(sliceType, possibleCPUs, possibleCPUs)
sliceValue := reflect.ValueOf(slice)

sliceElemType := sliceType.Elem()
sliceElemIsPointer := sliceElemType.Kind() == reflect.Ptr
if sliceElemIsPointer {
sliceElemType = sliceElemType.Elem()
}

stride := internal.Align(elemLength, 8)
for i := 0; i < possibleCPUs; i++ {
var elem any
v := sliceValue.Index(i)
if sliceElemIsPointer {
newElem := reflect.New(sliceElemType)
slice.Index(i).Set(newElem)
elem = newElem.Interface()
elem = v.Elem().Addr().Interface()
} else {
elem = slice.Index(i).Addr().Interface()
elem = v.Addr().Interface()
}

err := sysenc.Unmarshal(elem, buf[:elemLength])
if err != nil {
return fmt.Errorf("cpu %d: %w", i, err)
Expand All @@ -124,6 +117,5 @@ func unmarshalPerCPUValue(slicePtr any, elemLength int, buf []byte) error {
buf = buf[stride:]
}

reflect.ValueOf(slicePtr).Elem().Set(slice)
return nil
}

0 comments on commit a091dd1

Please sign in to comment.