Skip to content

Commit

Permalink
feat: added size based eviction to batcher + reflectutil adds sizeof …
Browse files Browse the repository at this point in the history
…method
  • Loading branch information
Ice3man543 committed Jul 1, 2024
1 parent 5bb2161 commit 1863395
Show file tree
Hide file tree
Showing 6 changed files with 279 additions and 1 deletion.
31 changes: 30 additions & 1 deletion batcher/batcher.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
package batcher

import (
"sync/atomic"
"time"

"github.com/DmitriyVTitov/size"
)

// FlushCallback is the function the batcher invokes with the accumulated
// items, either when the batcher is full or when the flush interval is reached.
type FlushCallback[T any] func(items []T)

// Batcher is a batcher for any type of data
type Batcher[T any] struct {
maxCapacity int
maxCapacity int
maxSize int32

currentSize atomic.Int32

flushInterval *time.Duration
flushCallback FlushCallback[T]

Expand All @@ -29,6 +36,13 @@ func WithMaxCapacity[T any](maxCapacity int) BatcherOption[T] {
}
}

// WithMaxSize returns an option that stores maxSize as the batcher's size
// limit. Append compares the running total of item sizes (as reported by
// size.Of) against this limit; a value of zero or less leaves the size check
// disabled.
func WithMaxSize[T any](maxSize int32) BatcherOption[T] {
	return func(batcher *Batcher[T]) {
		batcher.maxSize = maxSize
	}
}

// WithFlushInterval sets the optional flush interval of the batcher
func WithFlushInterval[T any](flushInterval time.Duration) BatcherOption[T] {
return func(b *Batcher[T]) {
Expand All @@ -53,6 +67,9 @@ func New[T any](opts ...BatcherOption[T]) *Batcher[T] {
for _, opt := range opts {
opt(batcher)
}
if batcher.maxSize > 0 {
batcher.currentSize = atomic.Int32{}
}
batcher.incomingData = make(chan T, batcher.maxCapacity)
if batcher.flushCallback == nil {
panic("batcher: flush callback is required")
Expand All @@ -66,11 +83,22 @@ func New[T any](opts ...BatcherOption[T]) *Batcher[T] {
// Append adds each item of d to the batcher, one at a time.
//
// For every item its size is measured with size.Of. When a size limit is
// configured (maxSize > 0) and adding the item would push the tracked size
// past that limit, or when put reports that the item could not be queued,
// the batcher is signalled as full and the item is then sent on the incoming
// channel, waiting until space is available. In all cases the item's size is
// added to the tracked total afterwards.
func (b *Batcher[T]) Append(d ...T) {
	for _, item := range d {
		itemSize := int32(size.Of(item))
		overLimit := b.maxSize > 0 && b.currentSize.Load()+itemSize > b.maxSize

		// Short-circuit matters here: put is only attempted when the item
		// fits within the size limit, so an over-limit item always goes
		// through the "signal full, then blocking send" path.
		if overLimit || !b.put(item) {
			b.full <- true
			b.incomingData <- item
		}
		b.currentSize.Add(itemSize)
	}
}

Expand Down Expand Up @@ -148,6 +176,7 @@ func (b *Batcher[T]) doCallback() {
for item := range b.incomingData {
items[k] = item
k++
b.currentSize.Add(-int32(size.Of(item)))
if k >= n {
break
}
Expand Down
43 changes: 43 additions & 0 deletions batcher/batcher_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package batcher

import (
"crypto/rand"
"testing"
"time"

Expand Down Expand Up @@ -74,3 +75,45 @@ func TestBatcherWithInterval(t *testing.T) {
require.Equal(t, wanted, got)
require.True(t, minWantedBatches <= gotBatches)
}

// exampleBatcherStruct is the item type fed to the batcher by
// TestBatcherWithSizeLimit.
type exampleBatcherStruct struct {
// Value is the payload; the test fills it with 200 random bytes.
Value []byte
}

// TestBatcherWithSizeLimit appends ten 200-byte payloads to a batcher that is
// configured with a capacity of 100 items and a size limit of 1000, then
// asserts that the flush callback ran exactly twice and that every batch it
// received held exactly five items.
func TestBatcherWithSizeLimit(t *testing.T) {
	const (
		capacity  = 100
		sizeLimit = 1000
		itemCount = 10
	)
	var (
		batches  int
		badBatch bool
	)

	// Record how many batches arrive and whether any has an unexpected length.
	onFlush := func(batch []exampleBatcherStruct) {
		batches++
		if len(batch) != 5 {
			badBatch = true
		}
	}

	b := New[exampleBatcherStruct](
		WithMaxCapacity[exampleBatcherStruct](capacity),
		WithMaxSize[exampleBatcherStruct](int32(sizeLimit)),
		WithFlushCallback[exampleBatcherStruct](onFlush),
	)
	b.Run()

	for i := 0; i < itemCount; i++ {
		payload := make([]byte, 200)
		_, _ = rand.Read(payload)
		b.Append(exampleBatcherStruct{Value: payload})
	}

	b.Stop()
	b.WaitDone()

	require.Equal(t, 2, batches)
	require.False(t, badBatch)
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/projectdiscovery/utils
go 1.21

require (
github.com/DmitriyVTitov/size v1.5.0
github.com/Masterminds/semver/v3 v3.2.1
github.com/Mzack9999/gcache v0.0.0-20230410081825-519e28eab057
github.com/andybalholm/brotli v1.0.6
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
aead.dev/minisign v0.2.0 h1:kAWrq/hBRu4AARY6AlciO83xhNnW9UaC8YipS2uhLPk=
aead.dev/minisign v0.2.0/go.mod h1:zdq6LdSd9TbuSxchxwhpA9zEb9YXcVGoE8JakuiGaIQ=
cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
github.com/DmitriyVTitov/size v1.5.0 h1:/PzqxYrOyOUX1BXj6J9OuVRVGe+66VL4D9FlUaW515g=
github.com/DmitriyVTitov/size v1.5.0/go.mod h1:le6rNI4CoLQV1b9gzp1+3d7hMAD/uu2QcJ+aYbNgiU0=
github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0=
github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
github.com/Mzack9999/gcache v0.0.0-20230410081825-519e28eab057 h1:KFac3SiGbId8ub47e7kd2PLZeACxc1LkiiNoDOFRClE=
Expand Down Expand Up @@ -71,6 +73,8 @@ github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiU
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/gofrs/uuid v3.3.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
Expand Down
135 changes: 135 additions & 0 deletions reflect/reflectutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,138 @@ func setUnexportedField(field reflect.Value, value interface{}) {
Elem().
Set(reflect.ValueOf(value))
}

// SizeOf returns an estimate of the number of bytes occupied by v, including
// the data reachable through pointers, slices, strings, maps and interfaces.
// Memory that is reachable through more than one path is counted only once.
// When v itself is a pointer it is dereferenced first, so only the pointee is
// measured. Channels and funcs contribute only the size of their own value.
//
// SizeOf returns -1 if the size cannot be determined: when v is nil or a nil
// pointer, or when it contains a value of an unsupported kind (for example
// unsafe.Pointer).
//
// Implementation is adapted from https://github.com/DmitriyVTitov/size/blob/v1.5.0/size.go#L14 which
// in turn is inspired by binary.Size of the standard library.
func SizeOf(v any) int {
	// Addresses already visited, so memory shared by several references is
	// not counted twice and cyclic structures terminate.
	visited := make(map[uintptr]bool)
	return sizeOf(reflect.Indirect(reflect.ValueOf(v)), visited)
}

// sizeOf returns the number of bytes the data represented by v occupies in
// memory, or -1 if v (or anything it contains) is of an unsupported kind.
//
// cache holds the addresses of backing storage that has already been counted.
// A reference to already-counted (or nil) storage still contributes the size
// of the reference itself (slice header, string header, pointer word, ...),
// because that header lives in the containing value regardless.
func sizeOf(v reflect.Value, cache map[uintptr]bool) int {
	switch v.Kind() {

	case reflect.Array:
		// Elements are stored inline, so the array is just the sum of them.
		sum := 0
		for i := 0; i < v.Len(); i++ {
			s := sizeOf(v.Index(i), cache)
			if s < 0 {
				return -1
			}
			sum += s
		}
		return sum

	case reflect.Slice:
		header := int(v.Type().Size())
		// A nil slice has no backing array; do not record address 0 in the
		// cache, otherwise every later nil slice/map would look "visited".
		if v.IsNil() {
			return header
		}
		ptr := v.Pointer()
		if cache[ptr] {
			return header
		}
		cache[ptr] = true

		sum := header
		for i := 0; i < v.Len(); i++ {
			s := sizeOf(v.Index(i), cache)
			if s < 0 {
				return -1
			}
			sum += s
		}
		// Account for allocated but unused capacity.
		sum += (v.Cap() - v.Len()) * int(v.Type().Elem().Size())
		return sum

	case reflect.Struct:
		// Start from the struct's own footprint (fields plus padding) and
		// add, per field, whatever that field occupies beyond its inline size.
		sum := int(v.Type().Size())
		for i, n := 0, v.NumField(); i < n; i++ {
			field := v.Field(i)
			s := sizeOf(field, cache)
			if s < 0 {
				return -1
			}
			sum += s - int(field.Type().Size())
		}
		return sum

	case reflect.String:
		header := int(v.Type().Size())
		s := v.String()
		// The data pointer of an empty string is unspecified; it owns no
		// bytes, so there is nothing to deduplicate.
		if len(s) == 0 {
			return header
		}
		data := uintptr(unsafe.Pointer(unsafe.StringData(s)))
		if cache[data] {
			return header
		}
		cache[data] = true
		return len(s) + header

	case reflect.Pointer:
		header := int(v.Type().Size())
		if v.IsNil() {
			return header
		}
		// Only the pointer word is counted when the pointee was already
		// visited; this also breaks cycles.
		ptr := v.Pointer()
		if cache[ptr] {
			return header
		}
		cache[ptr] = true
		s := sizeOf(v.Elem(), cache)
		if s < 0 {
			return -1
		}
		return s + header

	case reflect.Bool,
		reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
		reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Int, reflect.Uint,
		reflect.Chan,
		reflect.Uintptr,
		reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128,
		reflect.Func:
		return int(v.Type().Size())

	case reflect.Map:
		header := int(v.Type().Size())
		if v.IsNil() {
			return header
		}
		ptr := v.Pointer()
		if cache[ptr] {
			return header
		}
		cache[ptr] = true

		sum := header
		keys := v.MapKeys()
		for _, key := range keys {
			// Size of value and key are calculated separately.
			sv := sizeOf(v.MapIndex(key), cache)
			if sv < 0 {
				return -1
			}
			sum += sv
			sk := sizeOf(key, cache)
			if sk < 0 {
				return -1
			}
			sum += sk
		}
		// Include overhead due to unused map buckets. 10.79 comes
		// from https://golang.org/src/runtime/map.go.
		return sum + int(float64(len(keys))*10.79)

	case reflect.Interface:
		header := int(v.Type().Size())
		// A nil interface holds no dynamic value; only its own words count.
		if v.IsNil() {
			return header
		}
		s := sizeOf(v.Elem(), cache)
		if s < 0 {
			return -1
		}
		return s + header

	}

	return -1
}
66 changes: 66 additions & 0 deletions reflect/reflectutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,69 @@ func TestUnexportedField(t *testing.T) {
value := GetUnexportedField(testStruct, "unexported")
require.Equal(t, value, "test")
}

// TestSizeOf checks SizeOf against hand-computed sizes for an array, slice,
// string, map, nested struct and pointer. The expected values assume a
// 64-bit platform (8-byte words, 16-byte string headers, 24-byte slice headers).
//
// Test taken from https://github.com/DmitriyVTitov/size/blob/v1.5.0/size_test.go
func TestSizeOf(t *testing.T) {
	tests := []struct {
		name string
		v    any
		want int
	}{
		{
			name: "Array",
			v:    [3]int32{1, 2, 3}, // 3 * 4 = 12
			want: 12,
		},
		{
			name: "Slice",
			v:    make([]int64, 2, 5), // 5 * 8 + 24 = 64
			want: 64,
		},
		{
			name: "String",
			v:    "ABCdef", // 6 + 16 = 22
			want: 22,
		},
		{
			name: "Map",
			// (8 + 3 + 16) + (8 + 4 + 16) = 55
			// 55 + 8 + 10.79 * 2 = 84
			v:    map[int64]string{0: "ABC", 1: "DEFG"},
			want: 84,
		},
		{
			name: "Struct",
			v: struct {
				slice     []int64
				array     [2]bool
				structure struct {
					i int8
					s string
				}
			}{
				slice: []int64{12345, 67890}, // 2 * 8 + 24 = 40
				array: [2]bool{true, false},  // 2 * 1 = 2
				structure: struct {
					i int8
					s string
				}{
					i: 5,     // 1
					s: "abc", // 3 * 1 + 16 = 19
				}, // 20 + 7 (padding) = 27
			}, // 40 + 2 + 27 = 69 + 6 (padding) = 75
			want: 75,
		},
		{
			name: "Pointer",
			v:    new(int64), // 8
			want: 8,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Report the function under its actual name so failures are
			// easy to trace back to SizeOf.
			if got := SizeOf(tt.v); got != tt.want {
				t.Errorf("SizeOf() = %v, want %v", got, tt.want)
			}
		})
	}
}

0 comments on commit 1863395

Please sign in to comment.