Skip to content

Commit

Permalink
Merge pull request #248 from coinbase/patrick/sharded-map
Browse files Browse the repository at this point in the history
[utils] Implement ShardedMap
  • Loading branch information
patrick-ogrady authored Nov 23, 2020
2 parents c55106b + 7b5cef3 commit a61aa5d
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 21 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ require (
github.com/lucasjones/reggen v0.0.0-20180717132126-cdb49ff09d77
github.com/mitchellh/mapstructure v1.3.3
github.com/pkg/errors v0.9.1 // indirect
github.com/segmentio/fasthash v1.0.3
github.com/stretchr/objx v0.1.1 // indirect
github.com/stretchr/testify v1.6.1
github.com/tidwall/gjson v1.6.3
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ github.com/rs/cors v0.0.0-20160617231935-a62a804a8a00/go.mod h1:gFx+x8UowdsKA9Ac
github.com/rs/xhandler v0.0.0-20160618193221-ed27b6fd6521/go.mod h1:RvLn4FgxWubrpZHtQLnOf6EwhN2hEMusxZOhcW9H3UQ=
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/segmentio/fasthash v1.0.3 h1:EI9+KE1EwvMLBWwjpRDc+fEM+prwxDYbslddQGtrmhM=
github.com/segmentio/fasthash v1.0.3/go.mod h1:waKX8l2N8yckOgmSsXJi7x1ZfdKZ4x7KRMzBtS3oedY=
github.com/shirou/gopsutil v2.20.5+incompatible h1:tYH07UPoQt0OCQdgWWMgYHy3/a9bcxNpBIysykNIP7I=
github.com/shirou/gopsutil v2.20.5+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
Expand Down
2 changes: 1 addition & 1 deletion storage/badger_storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ func TestBadgerTrain_Limit(t *testing.T) {
namespace,
newDir,
dictionaryPath,
10,
50,
[]*CompressorEntry{},
)
assert.NoError(t, err)
Expand Down
36 changes: 21 additions & 15 deletions utils/priority_mutex_map.go → utils/mutex_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ import (
"sync"
)

const (
unlockPriority = true
)

// MutexMap is a struct that allows for
// acquiring a *PriorityMutex via a string identifier
// or for acquiring a global mutex that blocks
Expand All @@ -26,8 +30,7 @@ import (
// This is useful for coordinating concurrent, non-overlapping
// writes in the storage package.
type MutexMap struct {
entries map[string]*mutexMapEntry
mutex sync.Mutex
entries *ShardedMap
globalMutex sync.RWMutex
}

Expand All @@ -39,9 +42,9 @@ type mutexMapEntry struct {
}

// NewMutexMap returns a new *MutexMap.
func NewMutexMap() *MutexMap {
func NewMutexMap(shards int) *MutexMap {
return &MutexMap{
entries: map[string]*mutexMapEntry{},
entries: NewShardedMap(shards),
}
}

Expand Down Expand Up @@ -70,20 +73,23 @@ func (m *MutexMap) Lock(identifier string, priority bool) {
// We hold the shard lock for identifier while adding
// items to m.entries so that we don't accidentally
// overwrite a lock created by another goroutine.
m.mutex.Lock()
l, ok := m.entries[identifier]
data := m.entries.Lock(identifier, priority)
raw, ok := data[identifier]
var entry *mutexMapEntry
if !ok {
l = &mutexMapEntry{
entry = &mutexMapEntry{
lock: new(PriorityMutex),
}
m.entries[identifier] = l
data[identifier] = entry
} else {
entry = raw.(*mutexMapEntry)
}
l.count++
m.mutex.Unlock()
entry.count++
m.entries.Unlock(identifier)

// Once we have a m.globalMutex.RLock, it is
// safe to acquire an identifier lock.
l.lock.Lock(priority)
entry.lock.Lock(priority)
}

// Unlock releases a lock held for a particular identifier.
Expand All @@ -92,15 +98,15 @@ func (m *MutexMap) Unlock(identifier string) {
// exist by the time we unlock, otherwise
// it would not have been possible to get
// the lock to begin with.
m.mutex.Lock()
entry := m.entries[identifier]
data := m.entries.Lock(identifier, unlockPriority)
entry := data[identifier].(*mutexMapEntry)
if entry.count <= 1 { // this should never be < 0
delete(m.entries, identifier)
delete(data, identifier)
} else {
entry.count--
entry.lock.Unlock()
}
m.mutex.Unlock()
m.entries.Unlock(identifier)

// We release the globalMutex after unlocking
// the identifier lock, otherwise it would be possible
Expand Down
16 changes: 11 additions & 5 deletions utils/priority_mutex_map_test.go → utils/mutex_map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (

func TestMutexMap(t *testing.T) {
arr := []string{}
m := NewMutexMap()
m := NewMutexMap(DefaultShards)
g, _ := errgroup.WithContext(context.Background())

// Lock while adding all locks
Expand All @@ -47,7 +47,8 @@ func TestMutexMap(t *testing.T) {

g.Go(func() error {
m.Lock("a", false)
assert.Equal(t, m.entries["a"].count, 1)
entry := m.entries.shards[m.entries.shardIndex("a")].entries["a"].(*mutexMapEntry)
assert.Equal(t, entry.count, 1)
<-a
arr = append(arr, "a")
close(b)
Expand All @@ -57,7 +58,8 @@ func TestMutexMap(t *testing.T) {

g.Go(func() error {
m.Lock("b", false)
assert.Equal(t, m.entries["b"].count, 1)
entry := m.entries.shards[m.entries.shardIndex("b")].entries["b"].(*mutexMapEntry)
assert.Equal(t, entry.count, 1)
close(a)
<-b
arr = append(arr, "b")
Expand All @@ -68,7 +70,9 @@ func TestMutexMap(t *testing.T) {
time.Sleep(1 * time.Second)

// Ensure number of expected locks is correct
assert.Len(t, m.entries, 0)
totalKeys := len(m.entries.shards[m.entries.shardIndex("a")].entries) +
len(m.entries.shards[m.entries.shardIndex("b")].entries)
assert.Equal(t, totalKeys, 0)
arr = append(arr, "global-a")
m.GUnlock()
assert.NoError(t, g.Wait())
Expand All @@ -83,5 +87,7 @@ func TestMutexMap(t *testing.T) {
}, arr)

// Ensure lock is no longer occupied
assert.Len(t, m.entries, 0)
totalKeys = len(m.entries.shards[m.entries.shardIndex("a")].entries) +
len(m.entries.shards[m.entries.shardIndex("b")].entries)
assert.Equal(t, totalKeys, 0)
}
85 changes: 85 additions & 0 deletions utils/sharded_map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright 2020 Coinbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package utils

import (
"github.com/segmentio/fasthash/fnv1a"
)

const (
// DefaultShards is the default number of shards
// to use in ShardedMap. It is intended to be passed
// as the shards argument when the caller has no more
// specific shard count in mind.
DefaultShards = 256
)

// shardMapEntry governs access to the shard of
// the map contained at a particular index. It pairs
// the shard's entries with the mutex that guards them.
type shardMapEntry struct {
// mutex is acquired by ShardedMap.Lock and released by
// ShardedMap.Unlock for any key that maps to this shard.
mutex *PriorityMutex
// entries holds the key/value pairs stored in this shard.
// ShardedMap.Lock hands this map directly to the caller.
entries map[string]interface{}
}

// ShardedMap allows concurrent writes
// to a map by sharding the map into some
// number of independently locked subsections.
//
// A key is assigned to a shard by hashing it (see shardIndex);
// Lock and Unlock then operate on that shard's mutex only.
type ShardedMap struct {
// shards holds one *shardMapEntry per shard and is
// allocated and populated by NewShardedMap.
shards []*shardMapEntry
}

// NewShardedMap creates a new *ShardedMap
// with some number of shards. The larger the
// number provided for shards, the less lock
// contention there will be.
//
// As a rule of thumb, shards should usually
// be set to the concurrency of the caller.
//
// If shards is not positive, DefaultShards is used
// instead. A negative value would otherwise panic when
// allocating the shard slice, and zero would produce a
// map whose shardIndex divides by zero on first use.
func NewShardedMap(shards int) *ShardedMap {
	if shards <= 0 {
		shards = DefaultShards
	}

	m := &ShardedMap{
		shards: make([]*shardMapEntry, shards),
	}

	// Every shard gets its own map and its own mutex so
	// that shards can be locked independently of each other.
	for i := range m.shards {
		m.shards[i] = &shardMapEntry{
			entries: map[string]interface{}{},
			mutex:   new(PriorityMutex),
		}
	}

	return m
}

// shardIndex maps key to the position in m.shards
// of the shard responsible for it. The key is hashed
// with 32-bit FNV-1a and reduced modulo the shard count.
func (m *ShardedMap) shardIndex(key string) int {
	shardCount := uint32(len(m.shards))
	hashed := fnv1a.HashString32(key)

	return int(hashed % shardCount)
}

// Lock takes the mutex of the shard that key maps to
// (forwarding priority to that mutex) and returns the
// shard's underlying map. Because the map itself is
// returned, the caller can perform several operations on
// the shard while holding its lock; the caller must call
// Unlock with the same key once finished.
func (m *ShardedMap) Lock(key string, priority bool) map[string]interface{} {
	target := m.shards[m.shardIndex(key)]
	target.mutex.Lock(priority)

	return target.entries
}

// Unlock gives up the mutex of the shard that key maps
// to. It is the counterpart of Lock and should be called
// with the same key that was passed to Lock.
func (m *ShardedMap) Unlock(key string) {
	m.shards[m.shardIndex(key)].mutex.Unlock()
}
69 changes: 69 additions & 0 deletions utils/sharded_map_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright 2020 Coinbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package utils

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
"golang.org/x/sync/errgroup"
)

// TestShardedMap checks that two keys living in different
// shards can be locked at the same time and that values
// written under one shard's lock stay in that shard.
//
// NOTE: the test relies on "a" and "b" mapping to different
// shards when the map has 2 shards. If they shared a shard,
// the second goroutine could never acquire its lock while
// the first is parked on its channel, and the final
// per-shard assertions would not hold either.
func TestShardedMap(t *testing.T) {
	m := NewShardedMap(2)
	g, _ := errgroup.WithContext(context.Background())

	// To test locking, we use channels
	// that will cause deadlock if not executed
	// concurrently.
	a := make(chan struct{})
	b := make(chan struct{})

	g.Go(func() error {
		s := m.Lock("a", false)
		assert.Len(t, s, 0)
		s["test"] = "a"
		// Hold the shard lock until the other goroutine
		// has acquired its own shard lock and closed a.
		<-a
		close(b)
		m.Unlock("a")
		return nil
	})

	g.Go(func() error {
		s := m.Lock("b", false)
		assert.Len(t, s, 0)
		s["test"] = "b"
		close(a)
		// Hold the shard lock until the other goroutine
		// has been released and closed b.
		<-b
		m.Unlock("b")
		return nil
	})

	time.Sleep(1 * time.Second)
	assert.NoError(t, g.Wait())

	// Ensure keys set correctly. assert.Equal takes the
	// expected value first and the actual value second.
	s := m.Lock("a", false)
	assert.Len(t, s, 1)
	assert.Equal(t, "a", s["test"])
	m.Unlock("a")

	s = m.Lock("b", false)
	assert.Len(t, s, 1)
	assert.Equal(t, "b", s["test"])
	m.Unlock("b")
}

0 comments on commit a61aa5d

Please sign in to comment.