diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index e415e683d4f..1e8eb5b3f1a 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -19,8 +19,6 @@ package engine import ( "sync" - "vitess.io/vitess/go/sync2" - "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" @@ -30,12 +28,12 @@ import ( // Concatenate Primitive is used to concatenate results from multiple sources. var _ Primitive = (*Concatenate)(nil) -//Concatenate specified the parameter for concatenate primitive +// Concatenate specified the parameter for concatenate primitive type Concatenate struct { Sources []Primitive } -//RouteType returns a description of the query routing type used by the primitive +// RouteType returns a description of the query routing type used by the primitive func (c *Concatenate) RouteType() string { return "Concatenate" } @@ -146,13 +144,13 @@ func (c *Concatenate) execSources(vcursor VCursor, bindVars map[string]*querypb. // TryStreamExecute performs a streaming exec. func (c *Concatenate) TryStreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { var seenFields []*querypb.Field - var fieldset sync.WaitGroup - var cbMu sync.Mutex + var wg sync.WaitGroup + var cbMu, fieldsMu sync.Mutex g, restoreCtx := vcursor.ErrorGroupCancellableContext() defer restoreCtx() - var fieldsSent sync2.AtomicBool - fieldset.Add(1) + var fieldsSent bool + wg.Add(1) for i, source := range c.Sources { currIndex, currSource := i, source @@ -160,14 +158,19 @@ func (c *Concatenate) TryStreamExecute(vcursor VCursor, bindVars map[string]*que g.Go(func() error { err := vcursor.StreamExecutePrimitive(currSource, bindVars, wantfields, func(resultChunk *sqltypes.Result) error { // if we have fields to compare, make sure all the fields are all the same - if currIndex == 0 && !fieldsSent.Get() { - defer fieldset.Done() - seenFields = resultChunk.Fields - fieldsSent.Set(true) - // No other call can happen before this call. - return callback(resultChunk) + if currIndex == 0 { + fieldsMu.Lock() + if !fieldsSent { + defer wg.Done() + defer fieldsMu.Unlock() + seenFields = resultChunk.Fields + fieldsSent = true + // No other call can happen before this call. + return callback(resultChunk) + } + fieldsMu.Unlock() } - fieldset.Wait() + wg.Wait() if resultChunk.Fields != nil { err := compareFields(seenFields, resultChunk.Fields) if err != nil { @@ -185,9 +188,15 @@ func (c *Concatenate) TryStreamExecute(vcursor VCursor, bindVars map[string]*que } }) // This is to ensure other streams complete if the first stream failed to unlock the wait. - if currIndex == 0 && !fieldsSent.Get() { - fieldset.Done() + if currIndex == 0 { + fieldsMu.Lock() + if !fieldsSent { + fieldsSent = true + wg.Done() + } + fieldsMu.Unlock() } + return err }) @@ -218,7 +227,7 @@ func (c *Concatenate) GetFields(vcursor VCursor, bindVars map[string]*querypb.Bi return res, nil } -//NeedsTransaction returns whether a transaction is needed for this primitive +// NeedsTransaction returns whether a transaction is needed for this primitive func (c *Concatenate) NeedsTransaction() bool { for _, source := range c.Sources { if source.NeedsTransaction() { diff --git a/go/vt/vtgate/engine/limit.go b/go/vt/vtgate/engine/limit.go index 58d6b24b670..de048f6a827 100644 --- a/go/vt/vtgate/engine/limit.go +++ b/go/vt/vtgate/engine/limit.go @@ -88,6 +88,8 @@ func (l *Limit) TryStreamExecute(vcursor VCursor, bindVars map[string]*querypb.B return err } + bindVars = copyBindVars(bindVars) + // When offset is present, we hijack the limit value so we can calculate // the offset in memory from the result of the scatter query with count + offset. bindVars["__upper_limit"] = sqltypes.Int64BindVariable(int64(count + offset))