Skip to content

Commit

Permalink
Fix race in tests and introduce DaemonStatusManager
Browse files Browse the repository at this point in the history
We need DaemonStatusManager to manage Init->Started->Stopped transisions
- it is a common pattern we use in Cadence server.
Now I'm only taking this for debouncer + hashring. Will change other
usages later.
  • Loading branch information
dkrotx committed Oct 9, 2024
1 parent 243fd2d commit ab36cff
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 38 deletions.
17 changes: 17 additions & 0 deletions common/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

package common

import "sync/atomic"

const (
// used for background threads

Expand All @@ -39,3 +41,18 @@ type (
Stop()
}
)

// DaemonStatusManager wraps daemon status management
type DaemonStatusManager struct {
status atomic.Int32
}

// TransitionToStart returns true if daemon status transitioned from INITIAL to STARTED
func (m *DaemonStatusManager) TransitionToStart() bool {
return m.status.CompareAndSwap(DaemonStatusInitialized, DaemonStatusStarted)
}

// TransitionToStop returns true if daemon status transitioned from STARTED to STOPPED
func (m *DaemonStatusManager) TransitionToStop() bool {
return m.status.CompareAndSwap(DaemonStatusStarted, DaemonStatusStopped)
}
82 changes: 82 additions & 0 deletions common/daemon_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// The MIT License (MIT)

// Copyright (c) 2017-2020 Uber Technologies Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

package common

import (
"sync"
"sync/atomic"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestDaemonStatusManagerWorks(t *testing.T) {
var dm DaemonStatusManager

require.Equal(t, DaemonStatusInitialized, dm.status.Load())

assert.True(t, dm.TransitionToStart())
assert.False(t, dm.TransitionToStart(), "already started")
assert.True(t, dm.TransitionToStop())
assert.False(t, dm.TransitionToStop(), "already stopped")
}

func TestDaemonStatusManagerRequiresStartBeforeStop(t *testing.T) {
var dm DaemonStatusManager

assert.False(t, dm.TransitionToStop(), "never been started")
}

func TestDaemonStatusRaceCondition(t *testing.T) {
var dm DaemonStatusManager
var successes atomic.Int32
var wg sync.WaitGroup

// try to issue multiple Start-s at once
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if dm.TransitionToStart() {
successes.Add(1)
}
}()
}
wg.Wait()
assert.Equal(t, 1, int(successes.Load()), "only one Start call should succeed")

// now do the same for stops
successes.Store(0)
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if dm.TransitionToStop() {
successes.Add(1)
}
}()
}
wg.Wait()
assert.Equal(t, 1, int(successes.Load()), "only one Stop call should succeed")
}
10 changes: 10 additions & 0 deletions common/debounce/callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@ import (
"sync"
"time"

"github.com/uber/cadence/common"
"github.com/uber/cadence/common/clock"
)

type DebouncedCallback struct {
sync.Mutex

status common.DaemonStatusManager
lastHandlerCall time.Time
callback func()
interval time.Duration
Expand Down Expand Up @@ -61,6 +63,10 @@ func NewDebouncedCallback(timeSource clock.TimeSource, interval time.Duration, c
}

func (d *DebouncedCallback) Start() {
if !d.status.TransitionToStart() {
return
}

d.waitGroup.Add(1)
go func() {
defer d.waitGroup.Done()
Expand All @@ -69,6 +75,10 @@ func (d *DebouncedCallback) Start() {
}

func (d *DebouncedCallback) Stop() {
if !d.status.TransitionToStop() {
return
}

d.cancelLoop()
d.waitGroup.Wait()
}
Expand Down
15 changes: 8 additions & 7 deletions common/debounce/callback_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
package debounce

import (
"sync/atomic"
"testing"
"time"

Expand All @@ -42,7 +43,7 @@ const (
type callbackTestData struct {
mockedTimeSource clock.MockedTimeSource
debouncedCallback *DebouncedCallback
calls int
calls atomic.Int32
}

func waitCondition(fn func() bool, duration time.Duration) bool {
Expand All @@ -69,7 +70,7 @@ func newCallbackTestData(t *testing.T) *callbackTestData {

td.mockedTimeSource = clock.NewMockedTimeSourceAt(time.Now())
callback := func() {
td.calls++
td.calls.Add(1)
}

td.debouncedCallback = NewDebouncedCallback(td.mockedTimeSource, testDebounceInterval, callback)
Expand All @@ -85,10 +86,10 @@ func TestDebouncedCallbackWorks(t *testing.T) {
td.debouncedCallback.Handler()
require.True(
t,
waitCondition(func() bool { return td.calls > 0 }, testTimeout),
waitCondition(func() bool { return td.calls.Load() > 0 }, testTimeout),
"first callback is expected to be issued immediately after handler",
)
assert.Equal(t, 1, td.calls, "should be just once call since handler() called once")
assert.Equal(t, 1, int(td.calls.Load()), "should be just once call since handler() called once")

// issue more calls to handler(); they all should be postponed to testDebounceInterval
for i := 0; i < 10; i++ {
Expand All @@ -97,7 +98,7 @@ func TestDebouncedCallbackWorks(t *testing.T) {

td.mockedTimeSource.Advance(testDebounceInterval)
time.Sleep(testSleepAmount)
assert.Equal(t, 2, td.calls)
assert.Equal(t, 2, int(td.calls.Load()))

// now call handler again, but advance time only by little - no callbacks are expected
for i := 0; i < 10; i++ {
Expand All @@ -106,15 +107,15 @@ func TestDebouncedCallbackWorks(t *testing.T) {

td.mockedTimeSource.Advance(testDebounceInterval / 2)
time.Sleep(testSleepAmount)
assert.Equal(t, 2, td.calls, "should not have new callbacks")
assert.Equal(t, 2, int(td.calls.Load()), "should not have new callbacks")
}

func TestDebouncedCallbackDoesntCallHandlerIfThereWereNoUpdates(t *testing.T) {
td := newCallbackTestData(t)

td.mockedTimeSource.Advance(2 * testDebounceInterval)
time.Sleep(testSleepAmount)
assert.Equal(t, 0, td.calls)
assert.Equal(t, 0, int(td.calls.Load()))
}

func TestDebouncedCallbackDoubleStopIsOK(t *testing.T) {
Expand Down
21 changes: 11 additions & 10 deletions common/debounce/channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ package debounce
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -62,9 +63,9 @@ func newChannelTestData(t *testing.T) *channelTestData {
return &td
}

func (td *channelTestData) countCalls(t *testing.T, calls *int) {
func (td *channelTestData) countCalls(t *testing.T, calls *atomic.Int32) {
// create a reader from channel which will save calls:
// this way it should be easier to write tests by just evaluating td.calls
// this way it should be easier to write tests by just evaluating `calls`
var wg sync.WaitGroup
ctx, cancel := context.WithCancel(context.Background())

Expand All @@ -75,7 +76,7 @@ func (td *channelTestData) countCalls(t *testing.T, calls *int) {
for {
select {
case <-td.debouncedChannel.Chan():
*calls++
calls.Add(1)
case <-ctx.Done():
return
}
Expand All @@ -97,16 +98,16 @@ func (td *channelTestData) countCalls(t *testing.T, calls *int) {
func TestDebouncedSignalWorks(t *testing.T) {
td := newChannelTestData(t)

var calls int
var calls atomic.Int32
td.countCalls(t, &calls)

td.debouncedChannel.Handler()
require.True(
t,
waitCondition(func() bool { return calls > 0 }, testTimeout),
waitCondition(func() bool { return calls.Load() > 0 }, testTimeout),
"first callback is expected to be issued immediately after handler",
)
assert.Equal(t, 1, calls, 1)
assert.Equal(t, 1, int(calls.Load()), 1)

// we call handler multiple times. There should be just one message in channel
for i := 0; i < 10; i++ {
Expand All @@ -115,7 +116,7 @@ func TestDebouncedSignalWorks(t *testing.T) {

td.mockedTimeSource.Advance(testDebounceInterval)
time.Sleep(testSleepAmount)
assert.Equal(t, 2, calls)
assert.Equal(t, 2, int(calls.Load()))

// now call handler again, but advance time only by little - no messages in channel are expected
for i := 0; i < 10; i++ {
Expand All @@ -124,7 +125,7 @@ func TestDebouncedSignalWorks(t *testing.T) {

td.mockedTimeSource.Advance(testDebounceInterval / 2)
time.Sleep(testSleepAmount)
assert.Equal(t, 2, calls, "should not have new messages in channel")
assert.Equal(t, 2, int(calls.Load()), "should not have new messages in channel")
}

func TestDebouncedSignalDoesntDuplicateIfWeDontReadChannel(t *testing.T) {
Expand All @@ -145,9 +146,9 @@ func TestDebouncedSignalDoesntDuplicateIfWeDontReadChannel(t *testing.T) {
time.Sleep(testSleepAmount)

// only now we start reading messages from channel
var calls int
var calls atomic.Int32
td.countCalls(t, &calls)
time.Sleep(testSleepAmount)

assert.Equal(t, 1, calls, "Only a single message is expected")
assert.Equal(t, 1, int(calls.Load()), "Only a single message is expected")
}
15 changes: 3 additions & 12 deletions common/membership/hashring.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ type PeerProvider interface {
type ring struct {
debounce *debounce.DebouncedChannel

status int32
status common.DaemonStatusManager
service string
peerProvider PeerProvider
shutdownCh chan struct{}
Expand Down Expand Up @@ -97,7 +97,6 @@ func newHashring(
scope metrics.Scope,
) *ring {
r := &ring{
status: common.DaemonStatusInitialized,
service: service,
peerProvider: provider,
shutdownCh: make(chan struct{}),
Expand All @@ -120,11 +119,7 @@ func emptyHashring() *hashring.HashRing {

// Start starts the hashring
func (r *ring) Start() {
if !atomic.CompareAndSwapInt32(
&r.status,
common.DaemonStatusInitialized,
common.DaemonStatusStarted,
) {
if !r.status.TransitionToStart() {
return
}

Expand All @@ -143,11 +138,7 @@ func (r *ring) Start() {

// Stop stops the resolver
func (r *ring) Stop() {
if !atomic.CompareAndSwapInt32(
&r.status,
common.DaemonStatusStarted,
common.DaemonStatusStopped,
) {
if !r.status.TransitionToStop() {
return
}

Expand Down
9 changes: 0 additions & 9 deletions common/membership/hashring_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,15 +430,6 @@ func TestErrorIsPropagatedWhenProviderFails(t *testing.T) {
assert.ErrorContains(t, td.hashRing.refresh(), "provider failure")
}

func TestStopWillStopProvider(t *testing.T) {
td := newHashringTestData(t)

td.mockPeerProvider.EXPECT().Stop().Times(1)

td.hashRing.status = common.DaemonStatusStarted
td.hashRing.Stop()
}

func TestLookupAndRefreshRaceCondition(t *testing.T) {
td := newHashringTestData(t)
var wg sync.WaitGroup
Expand Down

0 comments on commit ab36cff

Please sign in to comment.