Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V13] fix: make concatenate and limit concurrent safe #9981

Merged
merged 1 commit into from
Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 27 additions & 18 deletions go/vt/vtgate/engine/concatenate.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ package engine
import (
"sync"

"vitess.io/vitess/go/sync2"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
Expand All @@ -30,12 +28,12 @@ import (
// Concatenate Primitive is used to concatenate results from multiple sources.
var _ Primitive = (*Concatenate)(nil)

//Concatenate specified the parameter for concatenate primitive
// Concatenate specified the parameter for concatenate primitive
type Concatenate struct {
Sources []Primitive
}

//RouteType returns a description of the query routing type used by the primitive
// RouteType returns a description of the query routing type used by the primitive
func (c *Concatenate) RouteType() string {
return "Concatenate"
}
Expand Down Expand Up @@ -146,28 +144,33 @@ func (c *Concatenate) execSources(vcursor VCursor, bindVars map[string]*querypb.
// TryStreamExecute performs a streaming exec.
func (c *Concatenate) TryStreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
var seenFields []*querypb.Field
var fieldset sync.WaitGroup
var cbMu sync.Mutex
var wg sync.WaitGroup
var cbMu, fieldsMu sync.Mutex

g, restoreCtx := vcursor.ErrorGroupCancellableContext()
defer restoreCtx()
var fieldsSent sync2.AtomicBool
fieldset.Add(1)
var fieldsSent bool
wg.Add(1)

for i, source := range c.Sources {
currIndex, currSource := i, source

g.Go(func() error {
err := vcursor.StreamExecutePrimitive(currSource, bindVars, wantfields, func(resultChunk *sqltypes.Result) error {
// if we have fields to compare, make sure all the fields are all the same
if currIndex == 0 && !fieldsSent.Get() {
defer fieldset.Done()
seenFields = resultChunk.Fields
fieldsSent.Set(true)
// No other call can happen before this call.
return callback(resultChunk)
if currIndex == 0 {
fieldsMu.Lock()
if !fieldsSent {
defer wg.Done()
defer fieldsMu.Unlock()
seenFields = resultChunk.Fields
fieldsSent = true
// No other call can happen before this call.
return callback(resultChunk)
}
fieldsMu.Unlock()
}
fieldset.Wait()
wg.Wait()
if resultChunk.Fields != nil {
err := compareFields(seenFields, resultChunk.Fields)
if err != nil {
Expand All @@ -185,9 +188,15 @@ func (c *Concatenate) TryStreamExecute(vcursor VCursor, bindVars map[string]*que
}
})
// This is to ensure other streams complete if the first stream failed to unlock the wait.
if currIndex == 0 && !fieldsSent.Get() {
fieldset.Done()
if currIndex == 0 {
fieldsMu.Lock()
if !fieldsSent {
fieldsSent = true
wg.Done()
}
fieldsMu.Unlock()
}

return err
})

Expand Down Expand Up @@ -218,7 +227,7 @@ func (c *Concatenate) GetFields(vcursor VCursor, bindVars map[string]*querypb.Bi
return res, nil
}

//NeedsTransaction returns whether a transaction is needed for this primitive
// NeedsTransaction returns whether a transaction is needed for this primitive
func (c *Concatenate) NeedsTransaction() bool {
for _, source := range c.Sources {
if source.NeedsTransaction() {
Expand Down
2 changes: 2 additions & 0 deletions go/vt/vtgate/engine/limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ func (l *Limit) TryStreamExecute(vcursor VCursor, bindVars map[string]*querypb.B
return err
}

bindVars = copyBindVars(bindVars)

// When offset is present, we hijack the limit value so we can calculate
// the offset in memory from the result of the scatter query with count + offset.
bindVars["__upper_limit"] = sqltypes.Int64BindVariable(int64(count + offset))
Expand Down