Skip to content

Commit

Permalink
feat: add GetOrCreate and GetOrCall methods
Browse files Browse the repository at this point in the history
Also add Len/Clear/Trunc utility methods.

Signed-off-by: Dmitriy Matrenichev <dmitry.matrenichev@siderolabs.com>
  • Loading branch information
DmitriyMV committed Dec 1, 2022
1 parent 7c7ccc3 commit 8e89b1e
Show file tree
Hide file tree
Showing 3 changed files with 243 additions and 1 deletion.
74 changes: 74 additions & 0 deletions containers/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,49 @@ func (m *ConcurrentMap[K, V]) Get(key K) (V, bool) {
return val, ok
}

// GetOrCreate returns the existing value for the key if present. Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
func (m *ConcurrentMap[K, V]) GetOrCreate(key K, val V) (V, bool) {
m.mx.Lock()
defer m.mx.Unlock()

if res, ok := m.m[key]; ok {
return res, true
}

if m.m == nil {
m.m = map[K]V{}
}

m.m[key] = val

return val, false
}

// GetOrCall returns the existing value for the key if present. Otherwise, it calls fn, stores the result and returns it.
// The loaded result is true if the value was loaded, false if it was created using fn.
//
// The main reason for this function is to avoid unnecessary allocations if you use pointer types as values, since
// compiler cannot prove that the value does not escape if it's not stored.
func (m *ConcurrentMap[K, V]) GetOrCall(key K, fn func() V) (V, bool) {
m.mx.Lock()
defer m.mx.Unlock()

if res, ok := m.m[key]; ok {
return res, true
}

if m.m == nil {
m.m = map[K]V{}
}

val := fn()

m.m[key] = val

return val, false
}

// Set sets the value for the given key.
func (m *ConcurrentMap[K, V]) Set(key K, val V) {
m.mx.Lock()
Expand All @@ -46,6 +89,21 @@ func (m *ConcurrentMap[K, V]) Remove(key K) {
delete(m.m, key)
}

// RemoveAndGet removes the value for the given key and returns it if it exists.
func (m *ConcurrentMap[K, V]) RemoveAndGet(key K) (V, bool) {
m.mx.Lock()
defer m.mx.Unlock()

if m.m == nil {
return *new(V), false //nolint:gocritic
}

val, ok := m.m[key]
delete(m.m, key)

return val, ok
}

// ForEach calls the given function for each key-value pair.
func (m *ConcurrentMap[K, V]) ForEach(f func(K, V)) {
m.mx.Lock()
Expand All @@ -56,6 +114,14 @@ func (m *ConcurrentMap[K, V]) ForEach(f func(K, V)) {
}
}

// Len returns the number of elements in the map.
func (m *ConcurrentMap[K, V]) Len() int {
m.mx.Lock()
defer m.mx.Unlock()

return len(m.m)
}

// Clear removes all key-value pairs.
func (m *ConcurrentMap[K, V]) Clear() {
m.mx.Lock()
Expand All @@ -65,3 +131,11 @@ func (m *ConcurrentMap[K, V]) Clear() {
delete(m.m, k)
}
}

// Reset resets the underlying map.
func (m *ConcurrentMap[K, V]) Reset() {
m.mx.Lock()
defer m.mx.Unlock()

m.m = nil
}
166 changes: 166 additions & 0 deletions containers/map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,37 @@
package containers_test

import (
"fmt"
"math/rand"
"testing"

"github.com/stretchr/testify/require"

"github.com/siderolabs/gen/containers"
"github.com/siderolabs/gen/xsync"
)

func TestConcurrentMap(t *testing.T) {
t.Parallel()

t.Run("should return nothing if key doesnt exist", func(t *testing.T) {
t.Parallel()

m := containers.ConcurrentMap[int, int]{}
_, ok := m.Get(0)
require.False(t, ok)
})

t.Run("should remove nothing if map is empty", func(t *testing.T) {
t.Parallel()

m := containers.ConcurrentMap[int, int]{}
m.Remove(0)
})

t.Run("should return setted value", func(t *testing.T) {
t.Parallel()

m := containers.ConcurrentMap[int, int]{}
m.Set(1, 1)
val, ok := m.Get(1)
Expand All @@ -33,14 +44,34 @@ func TestConcurrentMap(t *testing.T) {
})

t.Run("should remove value", func(t *testing.T) {
t.Parallel()

m := containers.ConcurrentMap[int, int]{}
m.Set(1, 1)
m.Remove(1)
_, ok := m.Get(1)
require.False(t, ok)

m.Set(2, 2)
got, ok := m.RemoveAndGet(2)
require.True(t, ok)
require.Equal(t, 2, got)

got, ok = m.RemoveAndGet(2)
require.False(t, ok)
require.Zero(t, got)

m.Reset()
got, ok = m.RemoveAndGet(2)
require.False(t, ok)
require.Zero(t, got)

require.False(t, ok)
})

t.Run("should call fn for every key", func(t *testing.T) {
t.Parallel()

m := containers.ConcurrentMap[int, int]{}
m.Set(1, 1)
m.Set(2, 2)
Expand All @@ -52,4 +83,139 @@ func TestConcurrentMap(t *testing.T) {
})
require.Equal(t, 3, count)
})

t.Run("should clear the map", func(t *testing.T) {
t.Parallel()

m := containers.ConcurrentMap[int, int]{}
m.Set(1, 1)

require.Equal(t, 1, m.Len())

m.Clear()

require.Equal(t, 0, m.Len())
})

t.Run("should trunc the map", func(t *testing.T) {
t.Parallel()

m := containers.ConcurrentMap[int, int]{}
m.Set(1, 1)

require.Equal(t, 1, m.Len())

m.Reset()

require.Equal(t, 0, m.Len())
})
}

func TestConcurrentMap_GetOrCall(t *testing.T) {
var m containers.ConcurrentMap[int, int]

t.Run("group", func(t *testing.T) {
t.Run("try to insert value", func(t *testing.T) {
parallelGetOrCall(t, &m, 100, 1000)
})

t.Run("try to insert value #2", func(t *testing.T) {
parallelGetOrCall(t, &m, 1000, 100)
})
})
}

func parallelGetOrCall(t *testing.T, m *containers.ConcurrentMap[int, int], our, another int) {
t.Parallel()

oneAnotherGet := false

for i := 0; i < 10000; i++ {
key := int(rand.Int63n(10000))

res, ok := m.GetOrCall(key, func() int { return key * our })
if ok {
switch res {
case key * our:
case key * another:
oneAnotherGet = true
default:
t.Fatalf("unexpected value %d", res)
}
}
}

require.True(t, oneAnotherGet)
}

func TestConcurrentMap_GetOrCreate(t *testing.T) {
var m containers.ConcurrentMap[int, int]

t.Run("group", func(t *testing.T) {
t.Run("try to insert value", func(t *testing.T) {
parallelGetOrCreate(t, &m, 100, 1000)
})

t.Run("try to insert value #2", func(t *testing.T) {
parallelGetOrCreate(t, &m, 1000, 100)
})
})
}

func parallelGetOrCreate(t *testing.T, m *containers.ConcurrentMap[int, int], our, another int) {
t.Parallel()

oneAnotherGet := false

for i := 0; i < 10000; i++ {
key := int(rand.Int63n(10000))

res, ok := m.GetOrCreate(key, key*our)
if ok {
switch res {
case key * our:
case key * another:
oneAnotherGet = true
default:
t.Fatalf("unexpected value %d", res)
}
}
}

require.True(t, oneAnotherGet)
}

func Example_benchConcurrentMap() {
var sink int

benchResult := testing.Benchmark(func(b *testing.B) {
b.ReportAllocs()

var m containers.ConcurrentMap[int, *xsync.Once[int]]

for i := 0; i < b.N; i++ {
variable := 0

res, _ := m.GetOrCall(10, func() *xsync.Once[int] {
return &xsync.Once[int]{}
})

sink = res.Do(func() int {
variable++

return variable
})
}
})

if benchResult.AllocsPerOp() > 0 {
fmt.Println("this benchmark should not allocate memory")
}

fmt.Println("ok")

// Output:
// ok

_ = sink
}
4 changes: 3 additions & 1 deletion xsync/once.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
// Package xsync contains the additions to std sync package.
package xsync

import "sync"
import (
"sync"
)

// Once is small wrapper around [sync.Once]. It stores the result inside.
type Once[T any] struct {
Expand Down

0 comments on commit 8e89b1e

Please sign in to comment.