diff --git a/bigquery/storage/managedwriter/connection.go b/bigquery/storage/managedwriter/connection.go index 5c3d81f1c59b..7e14fff2791e 100644 --- a/bigquery/storage/managedwriter/connection.go +++ b/bigquery/storage/managedwriter/connection.go @@ -353,7 +353,7 @@ func (co *connection) lockingAppend(pw *pendingWrite) error { return err } - var statsOnExit func() + var statsOnExit func(ctx context.Context) // critical section: Things that need to happen inside the critical section: // @@ -362,9 +362,10 @@ func (co *connection) lockingAppend(pw *pendingWrite) error { // * add the pending write to the channel for the connection (ordering for the response) co.mu.Lock() defer func() { + sCtx := co.ctx co.mu.Unlock() - if statsOnExit != nil { - statsOnExit() + if statsOnExit != nil && sCtx != nil { + statsOnExit(sCtx) } }() @@ -441,12 +442,12 @@ func (co *connection) lockingAppend(pw *pendingWrite) error { numRows = int64(len(pr.GetSerializedRows())) } } - statsOnExit = func() { + statsOnExit = func(ctx context.Context) { // these will get recorded once we exit the critical section. // TODO: resolve open questions around what labels should be attached (connection, streamID, etc) - recordStat(co.ctx, AppendRequestRows, numRows) - recordStat(co.ctx, AppendRequests, 1) - recordStat(co.ctx, AppendRequestBytes, int64(pw.reqSize)) + recordStat(ctx, AppendRequestRows, numRows) + recordStat(ctx, AppendRequests, 1) + recordStat(ctx, AppendRequestBytes, int64(pw.reqSize)) } ch <- pw return nil diff --git a/bigquery/storage/managedwriter/managed_stream_test.go b/bigquery/storage/managedwriter/managed_stream_test.go index 6ec2ca584a7f..b4bc04b44254 100644 --- a/bigquery/storage/managedwriter/managed_stream_test.go +++ b/bigquery/storage/managedwriter/managed_stream_test.go @@ -19,6 +19,7 @@ import ( "errors" "io" "runtime" + "sync" "testing" "time" @@ -643,3 +644,84 @@ func TestManagedStream_Closure(t *testing.T) { t.Errorf("expected writer ctx to be dead, is alive") } } + +// This test exists to try to surface data races by sharing +// a single writer with multiple goroutines. It doesn't assert +// anything about the behavior of the system. +func TestManagedStream_RaceFinder(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + var totalsMu sync.Mutex + totalSends := 0 + totalRecvs := 0 + pool := &connectionPool{ + ctx: ctx, + cancel: cancel, + baseFlowController: newFlowController(0, 0), + open: openTestArc(&testAppendRowsClient{}, + func(req *storagepb.AppendRowsRequest) error { + totalsMu.Lock() + totalSends = totalSends + 1 + curSends := totalSends + totalsMu.Unlock() + if curSends%25 == 0 { + //time.Sleep(10 * time.Millisecond) + return io.EOF + } + return nil + }, + func() (*storagepb.AppendRowsResponse, error) { + totalsMu.Lock() + totalRecvs = totalRecvs + 1 + curRecvs := totalRecvs + totalsMu.Unlock() + if curRecvs%15 == 0 { + return nil, io.EOF + } + return &storagepb.AppendRowsResponse{}, nil + }), + } + router := newSimpleRouter("") + if err := pool.activateRouter(router); err != nil { + t.Errorf("activateRouter: %v", err) + } + + ms := &ManagedStream{ + id: "foo", + streamSettings: defaultStreamSettings(), + retry: newStatelessRetryer(), + } + ms.retry.maxAttempts = 4 + ms.ctx, ms.cancel = context.WithCancel(pool.ctx) + ms.curTemplate = newVersionedTemplate().revise(reviseProtoSchema(&descriptorpb.DescriptorProto{})) + if err := pool.addWriter(ms); err != nil { + t.Errorf("addWriter A: %v", err) + } + + if router.conn == nil { + t.Errorf("expected non-nil connection") + } + + numWriters := 5 + numWrites := 50 + + var wg sync.WaitGroup + wg.Add(numWriters) + for i := 0; i < numWriters; i++ { + go func() { + for j := 0; j < numWrites; j++ { + result, err := ms.AppendRows(ctx, [][]byte{[]byte("foo")}) + if err != nil { + continue + } + _, err = result.GetResult(ctx) + if err != nil { + continue + } + } + wg.Done() + }() + } + wg.Wait() + cancel() +} diff --git a/bigquery/storage/managedwriter/retry.go b/bigquery/storage/managedwriter/retry.go index b16cca378731..c2983e84a797 100644 --- a/bigquery/storage/managedwriter/retry.go +++ b/bigquery/storage/managedwriter/retry.go @@ -19,6 +19,7 @@ import ( "io" "math/rand" "strings" + "sync" "time" "github.com/googleapis/gax-go/v2" @@ -82,7 +83,9 @@ func (ur *unaryRetryer) Retry(err error) (time.Duration, bool) { // from the receive side of the bidi stream. An individual item in that process has a notion of an attempt // count, and we use maximum retries as a way of evicting bad items. type statelessRetryer struct { - r *rand.Rand + mu sync.Mutex // guards r + r *rand.Rand + minBackoff time.Duration jitter time.Duration aggressiveFactor int @@ -101,7 +104,9 @@ func newStatelessRetryer() *statelessRetryer { func (sr *statelessRetryer) pause(aggressiveBackoff bool) time.Duration { jitter := sr.jitter.Nanoseconds() if jitter > 0 { + sr.mu.Lock() jitter = sr.r.Int63n(jitter) + sr.mu.Unlock() } pause := sr.minBackoff.Nanoseconds() + jitter if aggressiveBackoff {