Skip to content

Commit

Permalink
fix(bigquery/storage/managedwriter): resolve data races
Browse files Browse the repository at this point in the history
This PR adds a non-assertive test which helps expose data
races by doing a lot of concurrent write operations on a single
ManagedStream instance.

As a byproduct, this cleans up two possible races:  In the first,
a deferred function may incorrectly access a retained context.
We change this to grab a reference to the context in the defer
where we still retain the lock.

In the second, the retry mechanism leverages math/rand and retry
processing can yield concurrent usage of the random number source.
PR adds a mutex guard to the source.

Fixes: googleapis#9301
  • Loading branch information
shollyman committed Feb 3, 2024
1 parent 9ce3d5b commit 28af7d3
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 8 deletions.
15 changes: 8 additions & 7 deletions bigquery/storage/managedwriter/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ func (co *connection) lockingAppend(pw *pendingWrite) error {
return err
}

var statsOnExit func()
var statsOnExit func(ctx context.Context)

// critical section: Things that need to happen inside the critical section:
//
Expand All @@ -362,9 +362,10 @@ func (co *connection) lockingAppend(pw *pendingWrite) error {
// * add the pending write to the channel for the connection (ordering for the response)
co.mu.Lock()
defer func() {
sCtx := co.ctx
co.mu.Unlock()
if statsOnExit != nil {
statsOnExit()
if statsOnExit != nil && sCtx != nil {
statsOnExit(sCtx)
}
}()

Expand Down Expand Up @@ -441,12 +442,12 @@ func (co *connection) lockingAppend(pw *pendingWrite) error {
numRows = int64(len(pr.GetSerializedRows()))
}
}
statsOnExit = func() {
statsOnExit = func(ctx context.Context) {
// these will get recorded once we exit the critical section.
// TODO: resolve open questions around what labels should be attached (connection, streamID, etc)
recordStat(co.ctx, AppendRequestRows, numRows)
recordStat(co.ctx, AppendRequests, 1)
recordStat(co.ctx, AppendRequestBytes, int64(pw.reqSize))
recordStat(ctx, AppendRequestRows, numRows)
recordStat(ctx, AppendRequests, 1)
recordStat(ctx, AppendRequestBytes, int64(pw.reqSize))
}
ch <- pw
return nil
Expand Down
82 changes: 82 additions & 0 deletions bigquery/storage/managedwriter/managed_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"errors"
"io"
"runtime"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -643,3 +644,84 @@ func TestManagedStream_Closure(t *testing.T) {
t.Errorf("expected writer ctx to be dead, is alive")
}
}

// This test exists to try to surface data races by sharing
// a single writer with multiple goroutines. It doesn't assert
// anything about the behavior of the system.
func TestManagedStream_RaceFinder(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())

var totalsMu sync.Mutex
totalSends := 0
totalRecvs := 0
pool := &connectionPool{
ctx: ctx,
cancel: cancel,
baseFlowController: newFlowController(0, 0),
open: openTestArc(&testAppendRowsClient{},
func(req *storagepb.AppendRowsRequest) error {
totalsMu.Lock()
totalSends = totalSends + 1
curSends := totalSends
totalsMu.Unlock()
if curSends%25 == 0 {
//time.Sleep(10 * time.Millisecond)
return io.EOF
}
return nil
},
func() (*storagepb.AppendRowsResponse, error) {
totalsMu.Lock()
totalRecvs = totalRecvs + 1
curRecvs := totalRecvs
totalsMu.Unlock()
if curRecvs%15 == 0 {
return nil, io.EOF
}
return &storagepb.AppendRowsResponse{}, nil
}),
}
router := newSimpleRouter("")
if err := pool.activateRouter(router); err != nil {
t.Errorf("activateRouter: %v", err)
}

ms := &ManagedStream{
id: "foo",
streamSettings: defaultStreamSettings(),
retry: newStatelessRetryer(),
}
ms.retry.maxAttempts = 4
ms.ctx, ms.cancel = context.WithCancel(pool.ctx)
ms.curTemplate = newVersionedTemplate().revise(reviseProtoSchema(&descriptorpb.DescriptorProto{}))
if err := pool.addWriter(ms); err != nil {
t.Errorf("addWriter A: %v", err)
}

if router.conn == nil {
t.Errorf("expected non-nil connection")
}

numWriters := 5
numWrites := 50

var wg sync.WaitGroup
wg.Add(numWriters)
for i := 0; i < numWriters; i++ {
go func() {
for j := 0; j < numWrites; j++ {
result, err := ms.AppendRows(ctx, [][]byte{[]byte("foo")})
if err != nil {
continue
}
_, err = result.GetResult(ctx)
if err != nil {
continue
}
}
wg.Done()
}()
}
wg.Wait()
cancel()
}
7 changes: 6 additions & 1 deletion bigquery/storage/managedwriter/retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"io"
"math/rand"
"strings"
"sync"
"time"

"github.com/googleapis/gax-go/v2"
Expand Down Expand Up @@ -82,7 +83,9 @@ func (ur *unaryRetryer) Retry(err error) (time.Duration, bool) {
// from the receive side of the bidi stream. An individual item in that process has a notion of an attempt
// count, and we use maximum retries as a way of evicting bad items.
type statelessRetryer struct {
r *rand.Rand
mu sync.Mutex // guards r
r *rand.Rand

minBackoff time.Duration
jitter time.Duration
aggressiveFactor int
Expand All @@ -101,7 +104,9 @@ func newStatelessRetryer() *statelessRetryer {
func (sr *statelessRetryer) pause(aggressiveBackoff bool) time.Duration {
jitter := sr.jitter.Nanoseconds()
if jitter > 0 {
sr.mu.Lock()
jitter = sr.r.Int63n(jitter)
sr.mu.Unlock()
}
pause := sr.minBackoff.Nanoseconds() + jitter
if aggressiveBackoff {
Expand Down

0 comments on commit 28af7d3

Please sign in to comment.