From a091dd13ba755a1beb7953f2b46bbafec35a6d96 Mon Sep 17 00:00:00 2001 From: Alun Evans Date: Sun, 12 Nov 2023 17:34:31 -0800 Subject: [PATCH] Ensure Lookup and BatchLookup accept a slice by value Signed-off-by: Alun Evans --- map.go | 54 +++++++++++++++++++++++++++++++++++++++++++++++++-- map_test.go | 11 +++++++++++ marshalers.go | 24 ++++++++--------------- 3 files changed, 71 insertions(+), 18 deletions(-) diff --git a/map.go b/map.go index ce945ace0..cfedc4d12 100644 --- a/map.go +++ b/map.go @@ -642,11 +642,15 @@ func (m *Map) LookupBytes(key interface{}) ([]byte, error) { } func (m *Map) lookupPerCPU(key, valueOut any, flags MapLookupFlags) error { + slice, err := ensurePerCPUSlice(valueOut, int(m.valueSize)) + if err != nil { + return err + } valueBytes := make([]byte, m.fullValueSize) if err := m.lookup(key, sys.NewSlicePointer(valueBytes), flags); err != nil { return err } - return unmarshalPerCPUValue(valueOut, int(m.valueSize), valueBytes) + return unmarshalPerCPUValue(slice, int(m.valueSize), valueBytes) } func (m *Map) lookup(key interface{}, valueOut sys.Pointer, flags MapLookupFlags) error { @@ -669,11 +673,57 @@ func (m *Map) lookup(key interface{}, valueOut sys.Pointer, flags MapLookupFlags } func (m *Map) lookupAndDeletePerCPU(key, valueOut any, flags MapLookupFlags) error { + slice, err := ensurePerCPUSlice(valueOut, int(m.valueSize)) + if err != nil { + return err + } valueBytes := make([]byte, m.fullValueSize) if err := m.lookupAndDelete(key, sys.NewSlicePointer(valueBytes), flags); err != nil { return err } - return unmarshalPerCPUValue(valueOut, int(m.valueSize), valueBytes) + return unmarshalPerCPUValue(slice, int(m.valueSize), valueBytes) +} + +func ensurePerCPUSlice(sliceOrPtr any, elemLength int) (any, error) { + possibleCPUs, err := internal.PossibleCPUs() + if err != nil { + return nil, err + } + + sliceOrPtrType := reflect.TypeOf(sliceOrPtr) + if sliceOrPtrType.Kind() == reflect.Slice { + sliceValue := reflect.ValueOf(sliceOrPtr) + if sliceValue.Len() != possibleCPUs { + return nil, fmt.Errorf("per-cpu slice is incorrect length, expected %d, got %d", + possibleCPUs, sliceValue.Len()) + } + return sliceValue.Interface(), nil + } + + slicePtrType := sliceOrPtrType + if slicePtrType.Kind() != reflect.Ptr || slicePtrType.Elem().Kind() != reflect.Slice { + return nil, fmt.Errorf("per-cpu value requires a slice or a pointer to slice") + } + + sliceType := slicePtrType.Elem() + slice := reflect.MakeSlice(sliceType, possibleCPUs, possibleCPUs) + + sliceElemType := sliceType.Elem() + sliceElemIsPointer := sliceElemType.Kind() == reflect.Ptr + reflect.ValueOf(sliceOrPtr).Elem().Set(slice) + if !sliceElemIsPointer { + return slice.Interface(), nil + } + sliceElemType = sliceElemType.Elem() + + for i := 0; i < possibleCPUs; i++ { + if sliceElemIsPointer { + newElem := reflect.New(sliceElemType) + slice.Index(i).Set(newElem) + } + } + + return slice.Interface(), nil } func (m *Map) lookupAndDelete(key any, valuePtr sys.Pointer, flags MapLookupFlags) error { diff --git a/map_test.go b/map_test.go index 7bd7194c2..c41b9f4f2 100644 --- a/map_test.go +++ b/map_test.go @@ -75,6 +75,10 @@ func TestMap(t *testing.T) { t.Error("Want value 42, got", v) } + sliceVal := make([]uint32, 1) + qt.Assert(t, m.Lookup(uint32(0), sliceVal), qt.IsNil) + qt.Assert(t, sliceVal, qt.DeepEquals, []uint32{42}) + var slice []byte qt.Assert(t, m.Lookup(uint32(0), &slice), qt.IsNil) qt.Assert(t, slice, qt.DeepEquals, internal.NativeEndian.AppendUint32(nil, 42)) @@ -1348,6 +1352,13 @@ func TestPerCPUMarshaling(t *testing.T) { } // Make sure unmarshaling works on slices containing pointers + retrievedVal := make([]*customEncoding, numCPU) + for i := range retrievedVal { + retrievedVal[i] = &customEncoding{} + } + if err := arr.Lookup(uint32(0), retrievedVal); err != nil { + t.Fatal("Can't retrieve key 0:", err) + } var retrieved []*customEncoding if err := arr.Lookup(uint32(0), &retrieved); err != nil { t.Fatal("Can't retrieve key 0:", err) diff --git a/marshalers.go b/marshalers.go index e89a12f0f..474436e87 100644 --- a/marshalers.go +++ b/marshalers.go @@ -85,10 +85,10 @@ func marshalPerCPUValue(slice any, elemLength int) (sys.Pointer, error) { // possible CPU. // // slicePtr must be a pointer to a slice. -func unmarshalPerCPUValue(slicePtr any, elemLength int, buf []byte) error { - slicePtrType := reflect.TypeOf(slicePtr) - if slicePtrType.Kind() != reflect.Ptr || slicePtrType.Elem().Kind() != reflect.Slice { - return fmt.Errorf("per-cpu value requires pointer to slice") +func unmarshalPerCPUValue(slice any, elemLength int, buf []byte) error { + sliceType := reflect.TypeOf(slice) + if sliceType.Kind() != reflect.Slice { + return fmt.Errorf("per-cpu value requires a slice") } possibleCPUs, err := internal.PossibleCPUs() @@ -96,26 +96,19 @@ func unmarshalPerCPUValue(slicePtr any, elemLength int, buf []byte) error { return err } - sliceType := slicePtrType.Elem() - slice := reflect.MakeSlice(sliceType, possibleCPUs, possibleCPUs) + sliceValue := reflect.ValueOf(slice) sliceElemType := sliceType.Elem() sliceElemIsPointer := sliceElemType.Kind() == reflect.Ptr - if sliceElemIsPointer { - sliceElemType = sliceElemType.Elem() - } - stride := internal.Align(elemLength, 8) for i := 0; i < possibleCPUs; i++ { var elem any + v := sliceValue.Index(i) if sliceElemIsPointer { - newElem := reflect.New(sliceElemType) - slice.Index(i).Set(newElem) - elem = newElem.Interface() + elem = v.Elem().Addr().Interface() } else { - elem = slice.Index(i).Addr().Interface() + elem = v.Addr().Interface() } - err := sysenc.Unmarshal(elem, buf[:elemLength]) if err != nil { return fmt.Errorf("cpu %d: %w", i, err) @@ -124,6 +117,5 @@ func unmarshalPerCPUValue(slicePtr any, elemLength int, buf []byte) error { buf = buf[stride:] } - reflect.ValueOf(slicePtr).Elem().Set(slice) return nil }