Skip to content

Commit

Permalink
Fix race in memberlist client when KV store keeps the value returned …
Browse files Browse the repository at this point in the history
…from CAS function.
  • Loading branch information
pstibrany committed Sep 23, 2024
1 parent 1f324b4 commit 323277b
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 10 deletions.
25 changes: 18 additions & 7 deletions kv/memberlist/memberlist_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1018,14 +1018,16 @@ func (m *KV) trySingleCas(key string, codec codec.Codec, f func(in interface{})
}

// Don't even try
r, ok := out.(Mergeable)
if !ok || r == nil {
incomingValue, ok := out.(Mergeable)
if !ok || incomingValue == nil {
return nil, 0, retry, fmt.Errorf("invalid type: %T, expected Mergeable", out)
}

// To support detection of removed items from value, we will only allow CAS operation to
// succeed if version hasn't changed, i.e. state hasn't changed since running 'f'.
change, newver, err := m.mergeValueForKey(key, r, ver, codec)
// The supplied function may have kept a reference to the returned "incoming value".
// If the KV store is going to keep this value as well, it needs to store a clone instead.
change, newver, err := m.mergeValueForKey(key, incomingValue, true, ver, codec)
if err == errVersionMismatch {
return nil, 0, retry, err
}
Expand Down Expand Up @@ -1379,14 +1381,15 @@ func (m *KV) mergeBytesValueForKey(key string, incomingData []byte, codec codec.
return nil, 0, fmt.Errorf("expected Mergeable, got: %T", decodedValue)
}

return m.mergeValueForKey(key, incomingValue, 0, codec)
// No need to clone this "incomingValue", since we have just decoded it from bytes, and won't be using it.
return m.mergeValueForKey(key, incomingValue, false, 0, codec)
}

// Merges incoming value with value we have in our store. Returns "a change" that can be sent to other
// cluster members to update their state, and new version of the value.
// If CAS version is specified, then merging will fail if state has changed already, and errVersionMismatch is reported.
// If no modification occurred, new version is 0.
func (m *KV) mergeValueForKey(key string, incomingValue Mergeable, casVersion uint, codec codec.Codec) (Mergeable, uint, error) {
func (m *KV) mergeValueForKey(key string, incomingValue Mergeable, incomingValueRequiresClone bool, casVersion uint, codec codec.Codec) (Mergeable, uint, error) {
m.storeMu.Lock()
defer m.storeMu.Unlock()

Expand All @@ -1398,7 +1401,7 @@ func (m *KV) mergeValueForKey(key string, incomingValue Mergeable, casVersion ui
if casVersion > 0 && curr.Version != casVersion {
return nil, 0, errVersionMismatch
}
result, change, err := computeNewValue(incomingValue, curr.value, casVersion > 0)
result, change, err := computeNewValue(incomingValue, incomingValueRequiresClone, curr.value, casVersion > 0)
if err != nil {
return nil, 0, err
}
Expand Down Expand Up @@ -1441,8 +1444,16 @@ func (m *KV) mergeValueForKey(key string, incomingValue Mergeable, casVersion ui
}

// returns [result, change, error]
func computeNewValue(incoming Mergeable, oldVal Mergeable, cas bool) (Mergeable, Mergeable, error) {
func computeNewValue(incoming Mergeable, incomingValueRequiresClone bool, oldVal Mergeable, cas bool) (Mergeable, Mergeable, error) {
if oldVal == nil {
// It's OK to return the same value twice (once as result, once as change), because "change" will be cloned
// in mergeValueForKey if needed.

if incomingValueRequiresClone {
clone := incoming.Clone()
return clone, clone, nil
}

return incoming, incoming, nil
}

Expand Down
93 changes: 90 additions & 3 deletions kv/memberlist/memberlist_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package memberlist
import (
"bytes"
"context"
"encoding/binary"
"encoding/gob"
"errors"
"fmt"
Expand Down Expand Up @@ -1567,12 +1568,18 @@ func decodeDataFromMarshalledKeyValuePair(t *testing.T, marshalledKVP []byte, ke
return d
}

func marshalKeyValuePair(t *testing.T, key string, codec codec.Codec, value interface{}) []byte {
func keyValuePair(t *testing.T, key string, codec codec.Codec, value interface{}) *KeyValuePair {
data, err := codec.Encode(value)
require.NoError(t, err)

kvp := KeyValuePair{Key: key, Codec: codec.CodecID(), Value: data}
data, err = kvp.Marshal()
return &KeyValuePair{Key: key, Codec: codec.CodecID(), Value: data}

}

func marshalKeyValuePair(t *testing.T, key string, codec codec.Codec, value interface{}) []byte {
kvp := keyValuePair(t, key, codec, value)

data, err := kvp.Marshal()
require.NoError(t, err)
return data
}
Expand Down Expand Up @@ -1710,3 +1717,83 @@ func getKey(t *testing.T, msg []byte) string {
require.NoError(t, err)
return kvPair.Key
}

func TestRaceBetweenStoringNewValueForKeyAndUpdatingIt(t *testing.T) {
codec := dataCodec{}

cfg := KVConfig{}
cfg.Codecs = append(cfg.Codecs, codec)
cfg.TCPTransport = TCPTransportConfig{
BindAddrs: getLocalhostAddrs(),
}

kv := NewKV(cfg, log.NewNopLogger(), &dnsProviderMock{}, prometheus.NewPedanticRegistry())

require.NoError(t, services.StartAndAwaitRunning(context.Background(), kv))
defer services.StopAndAwaitTerminated(context.Background(), kv)

Check failure on line 1733 in kv/memberlist/memberlist_client_test.go

View workflow job for this annotation

GitHub Actions / Check

Error return value of `services.StopAndAwaitTerminated` is not checked (errcheck)

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()

vals := make(chan int64, 10000)

go func() {
d := &data{Members: map[string]member{}}
for i := 0; i < 100; i++ {
d.Members[fmt.Sprintf("member_%d", i)] = member{Timestamp: time.Now().Unix(), State: i % 3}
}

kv.CAS(context.Background(), key, codec, func(in interface{}) (out interface{}, retry bool, err error) {

Check failure on line 1746 in kv/memberlist/memberlist_client_test.go

View workflow job for this annotation

GitHub Actions / Check

Error return value of `kv.CAS` is not checked (errcheck)
return d, true, nil
})

// keep iterating over d.Members. If other goroutine modifies same ring descriptor, we will see a race error.
for ctx.Err() == nil {
sum := int64(0)
for n, m := range d.Members {
sum += int64(len(n))
sum += m.Timestamp
sum += int64(len(m.Tokens))
}
vals <- sum
time.Sleep(10 * time.Millisecond)
}
}()

// Wait until CAS and iteration finishes before pushing remote state.
<-vals

s := 0
for ctx.Err() == nil {
s++
d := &data{Members: map[string]member{}}
for i := 0; i < 100; i++ {
d.Members[fmt.Sprintf("member_%d", i)] = member{Timestamp: time.Now().Unix(), State: (i + s) % 3}
}

kv.MergeRemoteState(marshalState(t, keyValuePair(t, key, codec, d)), false)
time.Sleep(10 * time.Millisecond)

drain:
select {
case <-vals:
goto drain
default:
// stop draining.
}
}
}

// marshalState serializes the given key-value pairs into one byte slice. Each pair is
// marshalled with KeyValuePair.Marshal and written to the output preceded by its length
// encoded as a big-endian uint32. Any marshalling or write error fails the test
// immediately.
func marshalState(t *testing.T, kvps ...*KeyValuePair) []byte {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	var buf bytes.Buffer
	for _, kvp := range kvps {
		d, err := kvp.Marshal()
		require.NoError(t, err)

		// Length prefix first, then the marshalled pair itself.
		require.NoError(t, binary.Write(&buf, binary.BigEndian, uint32(len(d))))
		buf.Write(d)
	}

	return buf.Bytes()
}

0 comments on commit 323277b

Please sign in to comment.