Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes a race in the scheduling limits. #3417

Merged
merged 2 commits into from
Mar 2, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions pkg/querier/queryrange/downstreamer.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ type instance struct {
}

func (in instance) Downstream(ctx context.Context, queries []logql.DownstreamQuery) ([]logql.Result, error) {
return in.For(queries, func(qry logql.DownstreamQuery) (logql.Result, error) {
return in.For(ctx, queries, func(qry logql.DownstreamQuery) (logql.Result, error) {
req := ParamsToLokiRequest(qry.Params).WithShards(qry.Shards).WithQuery(qry.Expr.String()).(*LokiRequest)
logger, ctx := spanlogger.New(ctx, "DownstreamHandler.instance")
defer logger.Finish()
Expand All @@ -72,6 +72,7 @@ func (in instance) Downstream(ctx context.Context, queries []logql.DownstreamQue

// For runs a function against a list of queries, collecting the results or returning an error. The indices are preserved such that input[i] maps to output[i].
func (in instance) For(
ctx context.Context,
queries []logql.DownstreamQuery,
fn func(logql.DownstreamQuery) (logql.Result, error),
) ([]logql.Result, error) {
Expand All @@ -81,16 +82,15 @@ func (in instance) For(
err error
}

done := make(chan struct{})
defer close(done)

ctx, cancel := context.WithCancel(ctx)
defer cancel()
ch := make(chan resp)

// Make one goroutine to dispatch the other goroutines, bounded by instance parallelism
go func() {
for i := 0; i < len(queries); i++ {
select {
case <-done:
case <-ctx.Done():
break
case <-in.locks:
go func(i int) {
Expand All @@ -108,7 +108,7 @@ func (in instance) For(

// Feed the result into the channel unless the work has completed.
select {
case <-done:
case <-ctx.Done():
case ch <- response:
}
}(i)
Expand All @@ -125,7 +125,6 @@ func (in instance) For(
results[resp.i] = resp.res
}
return results, nil

}

// convert to matrix
Expand All @@ -136,7 +135,6 @@ func sampleStreamToMatrix(streams []queryrange.SampleStream) parser.Value {
x.Metric = make(labels.Labels, 0, len(stream.Labels))
for _, l := range stream.Labels {
x.Metric = append(x.Metric, labels.Label(l))

}

x.Points = make([]promql.Point, 0, len(stream.Samples))
Expand Down
9 changes: 3 additions & 6 deletions pkg/querier/queryrange/downstreamer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,6 @@ func TestResponseToResult(t *testing.T) {
}

func TestDownstreamHandler(t *testing.T) {

// Pretty poor test, but this is just a passthrough struct, so ensure we create locks
// and can consume them
h := DownstreamHandler{nil}
Expand Down Expand Up @@ -220,7 +219,7 @@ func TestInstanceFor(t *testing.T) {
var ct int

// ensure we can execute queries that number more than the parallelism parameter
_, err := in.For(queries, func(_ logql.DownstreamQuery) (logql.Result, error) {
_, err := in.For(context.TODO(), queries, func(_ logql.DownstreamQuery) (logql.Result, error) {
mtx.Lock()
defer mtx.Unlock()
ct++
Expand All @@ -233,7 +232,7 @@ func TestInstanceFor(t *testing.T) {
// ensure an early error abandons the other queues queries
in = mkIn()
ct = 0
_, err = in.For(queries, func(_ logql.DownstreamQuery) (logql.Result, error) {
_, err = in.For(context.TODO(), queries, func(_ logql.DownstreamQuery) (logql.Result, error) {
mtx.Lock()
defer mtx.Unlock()
ct++
Expand All @@ -250,6 +249,7 @@ func TestInstanceFor(t *testing.T) {

in = mkIn()
results, err := in.For(
context.TODO(),
[]logql.DownstreamQuery{
{
Shards: logql.Shards{
Expand All @@ -263,7 +263,6 @@ func TestInstanceFor(t *testing.T) {
},
},
func(qry logql.DownstreamQuery) (logql.Result, error) {

return logql.Result{
Data: logql.Streams{{
Labels: qry.Shards[0].String(),
Expand All @@ -285,7 +284,6 @@ func TestInstanceFor(t *testing.T) {
results,
)
ensureParallelism(t, in, in.parallelism)

}

func TestInstanceDownstream(t *testing.T) {
Expand Down Expand Up @@ -345,5 +343,4 @@ func TestInstanceDownstream(t *testing.T) {

require.Nil(t, err)
require.Equal(t, []logql.Result{expected}, results)

}
17 changes: 14 additions & 3 deletions pkg/querier/queryrange/limits.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,13 @@ func newWork(ctx context.Context, req queryrange.Request) work {
}

func (rt limitedRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
var wg sync.WaitGroup
intermediate := make(chan work)
defer func() {
wg.Wait()
close(intermediate)
}()

ctx, cancel := context.WithCancel(r.Context())
defer cancel()

Expand All @@ -203,8 +210,6 @@ func (rt limitedRoundTripper) RoundTrip(r *http.Request) (*http.Response, error)
}

parallelism := rt.limits.MaxQueryParallelism(userid)
intermediate := make(chan work)
defer close(intermediate)

for i := 0; i < parallelism; i++ {
go func() {
Expand All @@ -222,13 +227,19 @@ func (rt limitedRoundTripper) RoundTrip(r *http.Request) (*http.Response, error)

response, err := rt.middleware.Wrap(
queryrange.HandlerFunc(func(ctx context.Context, r queryrange.Request) (queryrange.Response, error) {
wg.Add(1)
defer wg.Done()

if ctx.Err() != nil {
return nil, ctx.Err()
}
w := newWork(ctx, r)
intermediate <- w
select {
case response := <-w.result:
return response.response, response.err
case <-ctx.Done():
return nil, err
return nil, ctx.Err()
}
})).Do(ctx, request)
if err != nil {
Expand Down
30 changes: 29 additions & 1 deletion pkg/querier/queryrange/limits_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func Test_seriesLimiter(t *testing.T) {
require.LessOrEqual(t, *c, 4)
}

func Test_MaxQueryPallelism(t *testing.T) {
func Test_MaxQueryParallelism(t *testing.T) {
maxQueryParallelism := 2
f, err := newfakeRoundTripper()
require.Nil(t, err)
Expand Down Expand Up @@ -186,3 +186,31 @@ func Test_MaxQueryPallelism(t *testing.T) {
maxFound := int(max.Load())
require.LessOrEqual(t, maxFound, maxQueryParallelism, "max query parallelism: ", maxFound, " went over the configured one:", maxQueryParallelism)
}

func Test_MaxQueryParallelismLateCancel(t *testing.T) {
maxQueryParallelism := 2
f, err := newfakeRoundTripper()
require.Nil(t, err)

f.setHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
// simulate some work
time.Sleep(20 * time.Millisecond)
}))
ctx := user.InjectOrgID(context.Background(), "foo")

r, err := http.NewRequestWithContext(ctx, "GET", "/query_range", http.NoBody)
require.Nil(t, err)

_, _ = NewLimitedRoundTripper(f, lokiCodec, fakeLimits{maxQueryParallelism: maxQueryParallelism},
queryrange.MiddlewareFunc(func(next queryrange.Handler) queryrange.Handler {
return queryrange.HandlerFunc(func(c context.Context, r queryrange.Request) (queryrange.Response, error) {
for i := 0; i < 10; i++ {
go func() {
_, _ = next.Do(c, &LokiRequest{})
}()
}
return nil, nil
})
}),
).RoundTrip(r)
}