From 242223ec3f071b91c2bccce365ddbf4106a6f282 Mon Sep 17 00:00:00 2001 From: Jeanette Tan Date: Mon, 23 Sep 2024 01:59:23 +0800 Subject: [PATCH] Generalise some of Mimir's query sharding code to be more reusable --- pkg/frontend/querymiddleware/codec.go | 229 +++++++++++++++++- pkg/frontend/querymiddleware/codec_json.go | 32 ++- .../querymiddleware/codec_protobuf.go | 16 ++ pkg/frontend/querymiddleware/limits.go | 4 +- pkg/frontend/querymiddleware/limits_test.go | 8 +- pkg/frontend/querymiddleware/model_extra.go | 189 ++++++++++++++- pkg/frontend/querymiddleware/prune.go | 2 +- pkg/frontend/querymiddleware/querysharding.go | 38 ++- .../querymiddleware/querysharding_test.go | 6 +- pkg/frontend/querymiddleware/remote_read.go | 2 +- .../querymiddleware/remote_read_test.go | 6 +- pkg/frontend/querymiddleware/roundtrip.go | 30 ++- .../querymiddleware/sharded_queryable.go | 60 +++-- .../querymiddleware/sharded_queryable_test.go | 6 +- .../querymiddleware/split_and_cache.go | 2 +- .../split_by_instant_interval.go | 6 +- .../merge_exemplar_queryable.go | 2 +- .../tenantfederation/merge_queryable.go | 6 +- .../tenantfederation/tenant_federation.go | 4 +- 19 files changed, 568 insertions(+), 80 deletions(-) diff --git a/pkg/frontend/querymiddleware/codec.go b/pkg/frontend/querymiddleware/codec.go index 70224131ecb..7675a942a5b 100644 --- a/pkg/frontend/querymiddleware/codec.go +++ b/pkg/frontend/querymiddleware/codec.go @@ -49,7 +49,9 @@ var ( allFormats = []string{formatJSON, formatProtobuf} // List of HTTP headers to propagate when a Prometheus request is encoded into a HTTP request. 
- prometheusCodecPropagateHeaders = []string{compat.ForceFallbackHeaderName, chunkinfologger.ChunkInfoLoggingHeader, api.ReadConsistencyOffsetsHeader} + clusterNameHeader = "X-Cluster-Name" + prometheusCodecPropagateHeaders = []string{compat.ForceFallbackHeaderName, chunkinfologger.ChunkInfoLoggingHeader, api.ReadConsistencyOffsetsHeader, clusterNameHeader} + prometheusCodecLabelsPropagateHeaders = []string{clusterNameHeader} ) const ( @@ -81,12 +83,18 @@ type Codec interface { // The original request is also passed as a parameter this is useful for implementation that needs the request // to merge result or build the result correctly. DecodeResponse(context.Context, *http.Response, MetricsQueryRequest, log.Logger) (Response, error) + // DecodeLabelsResponse decodes a Response from an http response. + // The original request is also passed as a parameter this is useful for implementation that needs the request + // to merge result or build the result correctly. + DecodeLabelsResponse(context.Context, *http.Response, LabelsQueryRequest, log.Logger) (Response, error) // EncodeMetricsQueryRequest encodes a MetricsQueryRequest into an http request. EncodeMetricsQueryRequest(context.Context, MetricsQueryRequest) (*http.Request, error) // EncodeLabelsQueryRequest encodes a LabelsQueryRequest into an http request. EncodeLabelsQueryRequest(context.Context, LabelsQueryRequest) (*http.Request, error) // EncodeResponse encodes a Response into an http response. EncodeResponse(context.Context, *http.Request, Response) (*http.Response, error) + // EncodeLabelsResponse encodes a Response into an http response. + EncodeLabelsResponse(context.Context, *http.Request, Response, LabelsQueryRequest) (*http.Response, error) } // Merger is used by middlewares making multiple requests to merge back all responses into a single one. @@ -166,6 +174,14 @@ type LabelsQueryRequest interface { GetLabelMatcherSets() []string // GetLimit returns the limit of the number of items in the response. 
GetLimit() uint64 + // GetHeaders returns the HTTP headers in the request. + GetHeaders() []*PrometheusHeader + // WithLabelName clones the current request with a different label name param. + WithLabelName(string) (LabelsQueryRequest, error) + // WithLabelMatcherSets clones the current request with different label matchers. + WithLabelMatcherSets([]string) (LabelsQueryRequest, error) + // WithHeaders clones the current request with different headers. + WithHeaders([]*PrometheusHeader) (LabelsQueryRequest, error) // AddSpanTags writes information about this request to an OpenTracing span AddSpanTags(opentracing.Span) } @@ -211,7 +227,11 @@ type prometheusCodec struct { type formatter interface { EncodeResponse(resp *PrometheusResponse) ([]byte, error) + EncodeLabelsResponse(resp *PrometheusLabelsResponse) ([]byte, error) + EncodeSeriesResponse(resp *PrometheusSeriesResponse) ([]byte, error) DecodeResponse([]byte) (*PrometheusResponse, error) + DecodeLabelsResponse([]byte) (*PrometheusLabelsResponse, error) + DecodeSeriesResponse([]byte) (*PrometheusSeriesResponse, error) Name() string ContentType() v1.MIMEType } @@ -319,7 +339,7 @@ func (c prometheusCodec) decodeRangeQueryRequest(r *http.Request) (MetricsQueryR query := reqValues.Get("query") queryExpr, err := parser.ParseExpr(query) if err != nil { - return nil, decorateWithParamName(err, "query") + return nil, DecorateWithParamName(err, "query") } var options Options @@ -345,13 +365,13 @@ func (c prometheusCodec) decodeInstantQueryRequest(r *http.Request) (MetricsQuer time, err := DecodeInstantQueryTimeParams(&reqValues, time.Now) if err != nil { - return nil, decorateWithParamName(err, "time") + return nil, DecorateWithParamName(err, "time") } query := reqValues.Get("query") queryExpr, err := parser.ParseExpr(query) if err != nil { - return nil, decorateWithParamName(err, "query") + return nil, DecorateWithParamName(err, "query") } var options Options @@ -364,7 +384,7 @@ func (c prometheusCodec) 
decodeInstantQueryRequest(r *http.Request) (MetricsQuer } func (prometheusCodec) DecodeLabelsQueryRequest(_ context.Context, r *http.Request) (LabelsQueryRequest, error) { - if !IsLabelsQuery(r.URL.Path) { + if !IsLabelsQuery(r.URL.Path) && !IsSeriesQuery(r.URL.Path) { return nil, fmt.Errorf("unknown labels query API endpoint %s", r.URL.Path) } @@ -387,6 +407,15 @@ func (prometheusCodec) DecodeLabelsQueryRequest(_ context.Context, r *http.Reque } } + if IsSeriesQuery(r.URL.Path) { + return &PrometheusSeriesQueryRequest{ + Path: r.URL.Path, + Start: start, + End: end, + LabelMatcherSets: labelMatcherSets, + Limit: limit, + }, nil + } if IsLabelNamesQuery(r.URL.Path) { return &PrometheusLabelNamesQueryRequest{ Path: r.URL.Path, @@ -412,12 +441,12 @@ func (prometheusCodec) DecodeLabelsQueryRequest(_ context.Context, r *http.Reque func DecodeRangeQueryTimeParams(reqValues *url.Values) (start, end, step int64, err error) { start, err = util.ParseTime(reqValues.Get("start")) if err != nil { - return 0, 0, 0, decorateWithParamName(err, "start") + return 0, 0, 0, DecorateWithParamName(err, "start") } end, err = util.ParseTime(reqValues.Get("end")) if err != nil { - return 0, 0, 0, decorateWithParamName(err, "end") + return 0, 0, 0, DecorateWithParamName(err, "end") } if end < start { @@ -426,7 +455,7 @@ func DecodeRangeQueryTimeParams(reqValues *url.Values) (start, end, step int64, step, err = parseDurationMs(reqValues.Get("step")) if err != nil { - return 0, 0, 0, decorateWithParamName(err, "step") + return 0, 0, 0, DecorateWithParamName(err, "step") } if step <= 0 { @@ -451,7 +480,7 @@ func DecodeInstantQueryTimeParams(reqValues *url.Values, defaultNow func() time. 
} else { time, err = util.ParseTime(timeVal) if err != nil { - return 0, decorateWithParamName(err, "time") + return 0, DecorateWithParamName(err, "time") } } @@ -476,7 +505,7 @@ func DecodeLabelsQueryTimeParams(reqValues *url.Values, usePromDefaults bool) (s } else { start, err = util.ParseTime(startVal) if err != nil { - return 0, 0, decorateWithParamName(err, "start") + return 0, 0, DecorateWithParamName(err, "start") } } @@ -486,7 +515,7 @@ func DecodeLabelsQueryTimeParams(reqValues *url.Values, usePromDefaults bool) (s } else { end, err = util.ParseTime(endVal) if err != nil { - return 0, 0, decorateWithParamName(err, "end") + return 0, 0, DecorateWithParamName(err, "end") } } @@ -652,6 +681,24 @@ func (c prometheusCodec) EncodeLabelsQueryRequest(ctx context.Context, req Label Path: req.Path, // path still contains label name RawQuery: urlValues.Encode(), } + case *PrometheusSeriesQueryRequest: + urlValues := url.Values{} + if req.GetStart() != 0 { + urlValues["start"] = []string{encodeTime(req.Start)} + } + if req.GetEnd() != 0 { + urlValues["end"] = []string{encodeTime(req.End)} + } + if len(req.GetLabelMatcherSets()) > 0 { + urlValues["match[]"] = req.GetLabelMatcherSets() + } + if req.GetLimit() > 0 { + urlValues["limit"] = []string{strconv.FormatUint(req.GetLimit(), 10)} + } + u = &url.URL{ + Path: req.Path, + RawQuery: urlValues.Encode(), + } default: return nil, fmt.Errorf("unsupported request type %T", req) @@ -678,6 +725,18 @@ func (c prometheusCodec) EncodeLabelsQueryRequest(ctx context.Context, req Label r.Header.Add(api.ReadConsistencyHeader, level) } + // Propagate allowed HTTP headers. + for _, h := range req.GetHeaders() { + if !slices.Contains(prometheusCodecLabelsPropagateHeaders, h.Name) { + continue + } + + for _, v := range h.Values { + // There should only be one value, but add all of them for completeness. 
+ r.Header.Add(h.Name, v) + } + } + return r.WithContext(ctx), nil } @@ -755,6 +814,90 @@ func (c prometheusCodec) DecodeResponse(ctx context.Context, r *http.Response, _ return resp, nil } +func (c prometheusCodec) DecodeLabelsResponse(ctx context.Context, r *http.Response, lr LabelsQueryRequest, logger log.Logger) (Response, error) { + spanlog := spanlogger.FromContext(ctx, logger) + buf, err := readResponseBody(r) + if err != nil { + return nil, spanlog.Error(err) + } + + spanlog.LogFields(otlog.String("message", "ParseQueryRangeResponse"), + otlog.Int("status_code", r.StatusCode), + otlog.Int("bytes", len(buf))) + + // Before attempting to decode a response based on the content type, check if the + // Content-Type header was even set. When the scheduler returns gRPC errors, they + // are encoded as httpgrpc.HTTPResponse objects with an HTTP status code and the + // error message as the body of the response with no content type. We need to handle + // that case here before we decode well-formed success or error responses. 
+ contentType := r.Header.Get("Content-Type") + if contentType == "" { + switch r.StatusCode { + case http.StatusServiceUnavailable: + return nil, apierror.New(apierror.TypeUnavailable, string(buf)) + case http.StatusTooManyRequests: + return nil, apierror.New(apierror.TypeTooManyRequests, string(buf)) + case http.StatusRequestEntityTooLarge: + return nil, apierror.New(apierror.TypeTooLargeEntry, string(buf)) + default: + if r.StatusCode/100 == 5 { + return nil, apierror.New(apierror.TypeInternal, string(buf)) + } + } + } + + formatter := findFormatter(contentType) + if formatter == nil { + return nil, apierror.Newf(apierror.TypeInternal, "unknown response content type '%v'", contentType) + } + + start := time.Now() + + var response Response + + switch lr.(type) { + case *PrometheusLabelNamesQueryRequest, *PrometheusLabelValuesQueryRequest: + resp, err := formatter.DecodeLabelsResponse(buf) + if err != nil { + return nil, apierror.Newf(apierror.TypeInternal, "error decoding response: %v", err) + } + + c.metrics.duration.WithLabelValues(operationDecode, formatter.Name()).Observe(time.Since(start).Seconds()) + c.metrics.size.WithLabelValues(operationDecode, formatter.Name()).Observe(float64(len(buf))) + + if resp.Status == statusError { + return nil, apierror.New(apierror.Type(resp.ErrorType), resp.Error) + } + + for h, hv := range r.Header { + resp.Headers = append(resp.Headers, &PrometheusHeader{Name: h, Values: hv}) + } + + response = resp + case *PrometheusSeriesQueryRequest: + resp, err := formatter.DecodeSeriesResponse(buf) + if err != nil { + return nil, apierror.Newf(apierror.TypeInternal, "error decoding response: %v", err) + } + + c.metrics.duration.WithLabelValues(operationDecode, formatter.Name()).Observe(time.Since(start).Seconds()) + c.metrics.size.WithLabelValues(operationDecode, formatter.Name()).Observe(float64(len(buf))) + + if resp.Status == statusError { + return nil, apierror.New(apierror.Type(resp.ErrorType), resp.Error) + } + + for h, hv := 
range r.Header { + resp.Headers = append(resp.Headers, &PrometheusHeader{Name: h, Values: hv}) + } + + response = resp + default: + return nil, apierror.Newf(apierror.TypeInternal, "unsupported request type %T", lr) + } + return response, nil +} + func findFormatter(contentType string) formatter { for _, f := range knownFormats { if f.ContentType().String() == contentType { @@ -807,6 +950,68 @@ func (c prometheusCodec) EncodeResponse(ctx context.Context, req *http.Request, return &resp, nil } +func (c prometheusCodec) EncodeLabelsResponse(ctx context.Context, req *http.Request, res Response, lr LabelsQueryRequest) (*http.Response, error) { + sp, _ := opentracing.StartSpanFromContext(ctx, "APIResponse.ToHTTPResponse") + defer sp.Finish() + + selectedContentType, formatter := c.negotiateContentType(req.Header.Get("Accept")) + if formatter == nil { + return nil, apierror.New(apierror.TypeNotAcceptable, "none of the content types in the Accept header are supported") + } + + var start time.Time + var b []byte + + switch lr.(type) { + case *PrometheusLabelNamesQueryRequest, *PrometheusLabelValuesQueryRequest: + a, ok := res.(*PrometheusLabelsResponse) + if !ok { + return nil, apierror.Newf(apierror.TypeInternal, "invalid response format") + } + if a.Data != nil { + sp.LogFields(otlog.Int("labels", len(a.Data))) + } + + start = time.Now() + var err error + b, err = formatter.EncodeLabelsResponse(a) + if err != nil { + return nil, apierror.Newf(apierror.TypeInternal, "error encoding response: %v", err) + } + case *PrometheusSeriesQueryRequest: + a, ok := res.(*PrometheusSeriesResponse) + if !ok { + return nil, apierror.Newf(apierror.TypeInternal, "invalid response format") + } + if a.Data != nil { + sp.LogFields(otlog.Int("labels", len(a.Data))) + } + + start = time.Now() + var err error + b, err = formatter.EncodeSeriesResponse(a) + if err != nil { + return nil, apierror.Newf(apierror.TypeInternal, "error encoding response: %v", err) + } + default: + return nil, 
apierror.Newf(apierror.TypeInternal, "unsupported request type %T", lr) + } + + c.metrics.duration.WithLabelValues(operationEncode, formatter.Name()).Observe(time.Since(start).Seconds()) + c.metrics.size.WithLabelValues(operationEncode, formatter.Name()).Observe(float64(len(b))) + sp.LogFields(otlog.Int("bytes", len(b))) + + resp := http.Response{ + Header: http.Header{ + "Content-Type": []string{selectedContentType}, + }, + Body: io.NopCloser(bytes.NewBuffer(b)), + StatusCode: http.StatusOK, + ContentLength: int64(len(b)), + } + return &resp, nil +} + func (prometheusCodec) negotiateContentType(acceptHeader string) (string, formatter) { if acceptHeader == "" { return jsonMimeType, jsonFormatterInstance @@ -967,7 +1172,7 @@ func encodeDurationMs(d int64) string { return strconv.FormatFloat(float64(d)/float64(time.Second/time.Millisecond), 'f', -1, 64) } -func decorateWithParamName(err error, field string) error { +func DecorateWithParamName(err error, field string) error { errTmpl := "invalid parameter %q: %v" if status, ok := grpcutil.ErrorToStatus(err); ok { return apierror.Newf(apierror.TypeBadData, errTmpl, field, status.Message()) diff --git a/pkg/frontend/querymiddleware/codec_json.go b/pkg/frontend/querymiddleware/codec_json.go index 5b484474f16..9f36ff10833 100644 --- a/pkg/frontend/querymiddleware/codec_json.go +++ b/pkg/frontend/querymiddleware/codec_json.go @@ -5,7 +5,9 @@ package querymiddleware -import v1 "github.com/prometheus/prometheus/web/api/v1" +import ( + v1 "github.com/prometheus/prometheus/web/api/v1" +) const jsonMimeType = "application/json" @@ -25,6 +27,34 @@ func (j jsonFormatter) DecodeResponse(buf []byte) (*PrometheusResponse, error) { return &resp, nil } +func (j jsonFormatter) EncodeLabelsResponse(resp *PrometheusLabelsResponse) ([]byte, error) { + return json.Marshal(resp) +} + +func (j jsonFormatter) DecodeLabelsResponse(buf []byte) (*PrometheusLabelsResponse, error) { + var resp PrometheusLabelsResponse + + if err := 
json.Unmarshal(buf, &resp); err != nil { + return nil, err + } + + return &resp, nil +} + +func (j jsonFormatter) EncodeSeriesResponse(resp *PrometheusSeriesResponse) ([]byte, error) { + return json.Marshal(resp) +} + +func (j jsonFormatter) DecodeSeriesResponse(buf []byte) (*PrometheusSeriesResponse, error) { + var resp PrometheusSeriesResponse + + if err := json.Unmarshal(buf, &resp); err != nil { + return nil, err + } + + return &resp, nil +} + func (j jsonFormatter) Name() string { return formatJSON } diff --git a/pkg/frontend/querymiddleware/codec_protobuf.go b/pkg/frontend/querymiddleware/codec_protobuf.go index ab1fea61ccc..89e3715b125 100644 --- a/pkg/frontend/querymiddleware/codec_protobuf.go +++ b/pkg/frontend/querymiddleware/codec_protobuf.go @@ -326,6 +326,22 @@ func (f protobufFormatter) decodeMatrixData(data *mimirpb.MatrixData) (*Promethe }, nil } +func (f protobufFormatter) EncodeLabelsResponse(*PrometheusLabelsResponse) ([]byte, error) { + return nil, errors.New("protobuf labels encoding is not supported") +} + +func (f protobufFormatter) DecodeLabelsResponse([]byte) (*PrometheusLabelsResponse, error) { + return nil, errors.New("protobuf labels decoding is not supported") +} + +func (f protobufFormatter) EncodeSeriesResponse(*PrometheusSeriesResponse) ([]byte, error) { + return nil, errors.New("protobuf series encoding is not supported") +} + +func (f protobufFormatter) DecodeSeriesResponse([]byte) (*PrometheusSeriesResponse, error) { + return nil, errors.New("protobuf series decoding is not supported") +} + func labelsFromStringArray(s []string) ([]mimirpb.LabelAdapter, error) { if len(s)%2 != 0 { return nil, fmt.Errorf("metric is malformed: expected even number of symbols, but got %v", len(s)) diff --git a/pkg/frontend/querymiddleware/limits.go b/pkg/frontend/querymiddleware/limits.go index 80d3e1bf05b..55a5141934a 100644 --- a/pkg/frontend/querymiddleware/limits.go +++ b/pkg/frontend/querymiddleware/limits.go @@ -195,8 +195,8 @@ type 
limitedParallelismRoundTripper struct { middleware MetricsQueryMiddleware } -// newLimitedParallelismRoundTripper creates a new roundtripper that enforces MaxQueryParallelism to the `next` roundtripper across `middlewares`. -func newLimitedParallelismRoundTripper(next http.RoundTripper, codec Codec, limits Limits, middlewares ...MetricsQueryMiddleware) http.RoundTripper { +// NewLimitedParallelismRoundTripper creates a new roundtripper that enforces MaxQueryParallelism to the `next` roundtripper across `middlewares`. +func NewLimitedParallelismRoundTripper(next http.RoundTripper, codec Codec, limits Limits, middlewares ...MetricsQueryMiddleware) http.RoundTripper { return limitedParallelismRoundTripper{ downstream: roundTripperHandler{ next: next, diff --git a/pkg/frontend/querymiddleware/limits_test.go b/pkg/frontend/querymiddleware/limits_test.go index fe4e80cc2c6..6163570c938 100644 --- a/pkg/frontend/querymiddleware/limits_test.go +++ b/pkg/frontend/querymiddleware/limits_test.go @@ -753,7 +753,7 @@ func TestLimitedRoundTripper_MaxQueryParallelism(t *testing.T) { }) require.Nil(t, err) - _, err = newLimitedParallelismRoundTripper(downstream, codec, mockLimits{maxQueryParallelism: maxQueryParallelism}, + _, err = NewLimitedParallelismRoundTripper(downstream, codec, mockLimits{maxQueryParallelism: maxQueryParallelism}, MetricsQueryMiddlewareFunc(func(next MetricsQueryHandler) MetricsQueryHandler { return HandlerFunc(func(c context.Context, _ MetricsQueryRequest) (Response, error) { var wg sync.WaitGroup @@ -797,7 +797,7 @@ func TestLimitedRoundTripper_MaxQueryParallelismLateScheduling(t *testing.T) { }) require.Nil(t, err) - _, err = newLimitedParallelismRoundTripper(downstream, codec, mockLimits{maxQueryParallelism: maxQueryParallelism}, + _, err = NewLimitedParallelismRoundTripper(downstream, codec, mockLimits{maxQueryParallelism: maxQueryParallelism}, MetricsQueryMiddlewareFunc(func(next MetricsQueryHandler) MetricsQueryHandler { return HandlerFunc(func(c 
context.Context, _ MetricsQueryRequest) (Response, error) { // fire up work and we don't wait. @@ -838,7 +838,7 @@ func TestLimitedRoundTripper_OriginalRequestContextCancellation(t *testing.T) { }) require.Nil(t, err) - _, err = newLimitedParallelismRoundTripper(downstream, codec, mockLimits{maxQueryParallelism: maxQueryParallelism}, + _, err = NewLimitedParallelismRoundTripper(downstream, codec, mockLimits{maxQueryParallelism: maxQueryParallelism}, MetricsQueryMiddlewareFunc(func(next MetricsQueryHandler) MetricsQueryHandler { return HandlerFunc(func(c context.Context, _ MetricsQueryRequest) (Response, error) { var wg sync.WaitGroup @@ -897,7 +897,7 @@ func BenchmarkLimitedParallelismRoundTripper(b *testing.B) { for _, concurrentRequestCount := range []int{1, 10, 100} { for _, subRequestCount := range []int{1, 2, 5, 10, 20, 50, 100} { - tripper := newLimitedParallelismRoundTripper(downstream, codec, mockLimits{maxQueryParallelism: maxParallelism}, + tripper := NewLimitedParallelismRoundTripper(downstream, codec, mockLimits{maxQueryParallelism: maxParallelism}, MetricsQueryMiddlewareFunc(func(next MetricsQueryHandler) MetricsQueryHandler { return HandlerFunc(func(c context.Context, _ MetricsQueryRequest) (Response, error) { wg := sync.WaitGroup{} diff --git a/pkg/frontend/querymiddleware/model_extra.go b/pkg/frontend/querymiddleware/model_extra.go index 7f20bce5169..6332c111de9 100644 --- a/pkg/frontend/querymiddleware/model_extra.go +++ b/pkg/frontend/querymiddleware/model_extra.go @@ -451,6 +451,10 @@ func (r *PrometheusLabelNamesQueryRequest) GetLabelName() string { return "" } +func (r *PrometheusSeriesQueryRequest) GetLabelName() string { + return "" +} + func (r *PrometheusLabelNamesQueryRequest) GetStartOrDefault() int64 { if r.GetStart() == 0 { return v1.MinTime.UnixMilli() @@ -479,6 +483,95 @@ func (r *PrometheusLabelValuesQueryRequest) GetEndOrDefault() int64 { return r.GetEnd() } +func (r *PrometheusSeriesQueryRequest) GetStartOrDefault() int64 { + if 
r.GetStart() == 0 { + return v1.MinTime.UnixMilli() + } + return r.GetStart() +} + +func (r *PrometheusSeriesQueryRequest) GetEndOrDefault() int64 { + if r.GetEnd() == 0 { + return v1.MaxTime.UnixMilli() + } + return r.GetEnd() +} + +func (r *PrometheusLabelNamesQueryRequest) GetHeaders() []*PrometheusHeader { + return r.Headers +} + +func (r *PrometheusLabelValuesQueryRequest) GetHeaders() []*PrometheusHeader { + return r.Headers +} + +func (r *PrometheusSeriesQueryRequest) GetHeaders() []*PrometheusHeader { + return r.Headers +} + +// WithLabelName clones the current `PrometheusLabelNamesQueryRequest` with a new label name param. +func (r *PrometheusLabelNamesQueryRequest) WithLabelName(string) (LabelsQueryRequest, error) { + return nil, fmt.Errorf("not implemented") +} + +// WithLabelName clones the current `PrometheusLabelValuesQueryRequest` with a new label name param. +func (r *PrometheusLabelValuesQueryRequest) WithLabelName(name string) (LabelsQueryRequest, error) { + newRequest := *r + newRequest.Path = labelValuesPathSuffix.ReplaceAllString(r.Path, `/api/v1/label/`+name+`/values`) + newRequest.LabelName = name + return &newRequest, nil +} + +// WithLabelName clones the current `PrometheusSeriesQueryRequest` with a new label name param. +func (r *PrometheusSeriesQueryRequest) WithLabelName(string) (LabelsQueryRequest, error) { + return nil, fmt.Errorf("not implemented") +} + +// WithLabelMatcherSets clones the current `PrometheusLabelNamesQueryRequest` with new label matcher sets. +func (r *PrometheusLabelNamesQueryRequest) WithLabelMatcherSets(labelMatcherSets []string) (LabelsQueryRequest, error) { + newRequest := *r + newRequest.LabelMatcherSets = make([]string, len(labelMatcherSets)) + copy(newRequest.LabelMatcherSets, labelMatcherSets) + return &newRequest, nil +} + +// WithLabelMatcherSets clones the current `PrometheusLabelValuesQueryRequest` with new label matcher sets. 
+func (r *PrometheusLabelValuesQueryRequest) WithLabelMatcherSets(labelMatcherSets []string) (LabelsQueryRequest, error) { + newRequest := *r + newRequest.LabelMatcherSets = make([]string, len(labelMatcherSets)) + copy(newRequest.LabelMatcherSets, labelMatcherSets) + return &newRequest, nil +} + +// WithLabelMatcherSets clones the current `PrometheusSeriesQueryRequest` with new label matcher sets. +func (r *PrometheusSeriesQueryRequest) WithLabelMatcherSets(labelMatcherSets []string) (LabelsQueryRequest, error) { + newRequest := *r + newRequest.LabelMatcherSets = make([]string, len(labelMatcherSets)) + copy(newRequest.LabelMatcherSets, labelMatcherSets) + return &newRequest, nil +} + +// WithHeaders clones the current `PrometheusLabelNamesQueryRequest` with new headers. +func (r *PrometheusLabelNamesQueryRequest) WithHeaders(headers []*PrometheusHeader) (LabelsQueryRequest, error) { + newRequest := *r + newRequest.Headers = cloneHeaders(headers) + return &newRequest, nil +} + +// WithHeaders clones the current `PrometheusLabelValuesQueryRequest` with new headers. +func (r *PrometheusLabelValuesQueryRequest) WithHeaders(headers []*PrometheusHeader) (LabelsQueryRequest, error) { + newRequest := *r + newRequest.Headers = cloneHeaders(headers) + return &newRequest, nil +} + +// WithHeaders clones the current `PrometheusSeriesQueryRequest` with new headers. +func (r *PrometheusSeriesQueryRequest) WithHeaders(headers []*PrometheusHeader) (LabelsQueryRequest, error) { + newRequest := *r + newRequest.Headers = cloneHeaders(headers) + return &newRequest, nil +} + // AddSpanTags writes query information about the current `PrometheusLabelNamesQueryRequest` // to a span's tag ("attributes" in OpenTelemetry parlance). 
func (r *PrometheusLabelNamesQueryRequest) AddSpanTags(sp opentracing.Span) { @@ -496,6 +589,14 @@ func (r *PrometheusLabelValuesQueryRequest) AddSpanTags(sp opentracing.Span) { sp.SetTag("end", timestamp.Time(r.GetEnd()).String()) } +// AddSpanTags writes query information about the current `PrometheusSeriesQueryRequest` +// to a span's tag ("attributes" in OpenTelemetry parlance). +func (r *PrometheusSeriesQueryRequest) AddSpanTags(sp opentracing.Span) { + sp.SetTag("matchers", fmt.Sprintf("%v", r.GetLabelMatcherSets())) + sp.SetTag("start", timestamp.Time(r.GetStart()).String()) + sp.SetTag("end", timestamp.Time(r.GetEnd()).String()) +} + type PrometheusLabelNamesQueryRequest struct { Path string Start int64 @@ -508,7 +609,8 @@ type PrometheusLabelNamesQueryRequest struct { // ID of the request used to correlate downstream requests and responses. ID int64 // Limit the number of label names returned. A value of 0 means no limit - Limit uint64 + Limit uint64 + Headers []*PrometheusHeader } func (r *PrometheusLabelNamesQueryRequest) GetPath() string { @@ -548,7 +650,8 @@ type PrometheusLabelValuesQueryRequest struct { // ID of the request used to correlate downstream requests and responses. ID int64 // Limit the number of label values returned. A value of 0 means no limit. 
- Limit uint64 + Limit uint64 + Headers []*PrometheusHeader } func (r *PrometheusLabelValuesQueryRequest) GetLabelName() string { @@ -576,6 +679,88 @@ func (r *PrometheusLabelValuesQueryRequest) GetLimit() uint64 { return r.Limit } +type PrometheusSeriesQueryRequest struct { + Path string + Start int64 + End int64 + // labelMatcherSets is a repeated field here in order to enable the representation + // of labels queries which have not yet been split; the prometheus querier code + // will eventually split requests like `?match[]=up&match[]=process_start_time_seconds{job="prometheus"}` + // into separate queries, one for each matcher set + LabelMatcherSets []string + // ID of the request used to correlate downstream requests and responses. + ID int64 + // Limit the number of label names returned. A value of 0 means no limit + Limit uint64 + Headers []*PrometheusHeader +} + +func (r *PrometheusSeriesQueryRequest) GetPath() string { + return r.Path +} + +func (r *PrometheusSeriesQueryRequest) GetStart() int64 { + return r.Start +} + +func (r *PrometheusSeriesQueryRequest) GetEnd() int64 { + return r.End +} + +func (r *PrometheusSeriesQueryRequest) GetLabelMatcherSets() []string { + return r.LabelMatcherSets +} + +func (r *PrometheusSeriesQueryRequest) GetID() int64 { + return r.ID +} + +func (r *PrometheusSeriesQueryRequest) GetLimit() uint64 { + return r.Limit +} + +type PrometheusLabelsResponse struct { + Status string `json:"status"` + Data []string `json:"data"` + ErrorType string `json:"errorType,omitempty"` + Error string `json:"error,omitempty"` + Headers []*PrometheusHeader `json:"-"` +} + +func (m *PrometheusLabelsResponse) GetHeaders() []*PrometheusHeader { + if m != nil { + return m.Headers + } + return nil +} + +func (m *PrometheusLabelsResponse) Reset() { *m = PrometheusLabelsResponse{} } +func (*PrometheusLabelsResponse) ProtoMessage() {} +func (m *PrometheusLabelsResponse) String() string { return fmt.Sprintf("%+v", *m) } + +type SeriesData 
map[string]string + +func (d *SeriesData) String() string { return fmt.Sprintf("%+v", *d) } + +type PrometheusSeriesResponse struct { + Status string `json:"status"` + Data []SeriesData `json:"data"` + ErrorType string `json:"errorType,omitempty"` + Error string `json:"error,omitempty"` + Headers []*PrometheusHeader `json:"-"` +} + +func (m *PrometheusSeriesResponse) GetHeaders() []*PrometheusHeader { + if m != nil { + return m.Headers + } + return nil +} + +func (m *PrometheusSeriesResponse) Reset() { *m = PrometheusSeriesResponse{} } +func (*PrometheusSeriesResponse) ProtoMessage() {} +func (m *PrometheusSeriesResponse) String() string { return fmt.Sprintf("%+v", *m) } + func (d *PrometheusData) UnmarshalJSON(b []byte) error { v := struct { Type model.ValueType `json:"resultType"` diff --git a/pkg/frontend/querymiddleware/prune.go b/pkg/frontend/querymiddleware/prune.go index daf4180bf66..b079aa6edf1 100644 --- a/pkg/frontend/querymiddleware/prune.go +++ b/pkg/frontend/querymiddleware/prune.go @@ -59,7 +59,7 @@ func (p *pruneMiddleware) pruneQuery(ctx context.Context, query string) (string, // Parse the query. expr, err := parser.ParseExpr(query) if err != nil { - return "", false, apierror.New(apierror.TypeBadData, decorateWithParamName(err, "query").Error()) + return "", false, apierror.New(apierror.TypeBadData, DecorateWithParamName(err, "query").Error()) } origQueryString := expr.String() diff --git a/pkg/frontend/querymiddleware/querysharding.go b/pkg/frontend/querymiddleware/querysharding.go index 10f838cf8db..a1954c07ecb 100644 --- a/pkg/frontend/querymiddleware/querysharding.go +++ b/pkg/frontend/querymiddleware/querysharding.go @@ -109,7 +109,7 @@ func (s *querySharding) Do(ctx context.Context, r MetricsQueryRequest) (Response // Parse the query. 
queryExpr, err := parser.ParseExpr(r.GetQuery()) if err != nil { - return nil, apierror.New(apierror.TypeBadData, decorateWithParamName(err, "query").Error()) + return nil, apierror.New(apierror.TypeBadData, DecorateWithParamName(err, "query").Error()) } totalShards := s.getShardsForQuery(ctx, tenantIDs, r, queryExpr, log) @@ -151,10 +151,14 @@ func (s *querySharding) Do(ctx context.Context, r MetricsQueryRequest) (Response return nil, apierror.New(apierror.TypeBadData, err.Error()) } - annotationAccumulator := newAnnotationAccumulator() - shardedQueryable := newShardedQueryable(r, annotationAccumulator, s.next) + annotationAccumulator := NewAnnotationAccumulator() + shardedQueryable := NewShardedQueryable(r, annotationAccumulator, s.next, nil) - qry, err := newQuery(ctx, r, s.engine, lazyquery.NewLazyQueryable(shardedQueryable)) + return ExecuteQueryOnQueryable(ctx, r, s.engine, shardedQueryable, annotationAccumulator) +} + +func ExecuteQueryOnQueryable(ctx context.Context, r MetricsQueryRequest, engine *promql.Engine, queryable storage.Queryable, annotationAccumulator *annotationAccumulator) (Response, error) { + qry, err := newQuery(ctx, r, engine, lazyquery.NewLazyQueryable(queryable)) if err != nil { return nil, apierror.New(apierror.TypeBadData, err.Error()) } @@ -169,12 +173,20 @@ func (s *querySharding) Do(ctx context.Context, r MetricsQueryRequest) (Response // query, so we pass in an empty string as the query so the positions will be hidden. warn, info := res.Warnings.AsStrings("", 0, 0) - // Add any annotations returned by the sharded queries, and remove any duplicates. - accumulatedWarnings, accumulatedInfos := annotationAccumulator.getAll() - warn = append(warn, accumulatedWarnings...) - info = append(info, accumulatedInfos...) - warn = removeDuplicates(warn) - info = removeDuplicates(info) + if annotationAccumulator != nil { + // Add any annotations returned by the sharded queries, and remove any duplicates. 
+ accumulatedWarnings, accumulatedInfos := annotationAccumulator.getAll() + warn = append(warn, accumulatedWarnings...) + info = append(info, accumulatedInfos...) + warn = removeDuplicates(warn) + info = removeDuplicates(info) + } + + var headers []*PrometheusHeader + shardedQueryable, ok := queryable.(*shardedQueryable) + if ok { + headers = shardedQueryable.getResponseHeaders() + } return &PrometheusResponse{ Status: statusSuccess, @@ -182,7 +194,7 @@ func (s *querySharding) Do(ctx context.Context, r MetricsQueryRequest) (Response ResultType: string(res.Value.Type()), Result: extracted, }, - Headers: shardedQueryable.getResponseHeaders(), + Headers: headers, Warnings: warn, Infos: info, }, nil @@ -275,7 +287,7 @@ func (s *querySharding) shardQuery(ctx context.Context, query string, totalShard // each time before passing it to the mapper. expr, err := parser.ParseExpr(query) if err != nil { - return "", nil, apierror.New(apierror.TypeBadData, decorateWithParamName(err, "query").Error()) + return "", nil, apierror.New(apierror.TypeBadData, DecorateWithParamName(err, "query").Error()) } shardedQuery, err := mapper.Map(expr) @@ -496,7 +508,7 @@ type annotationAccumulator struct { infos *sync.Map } -func newAnnotationAccumulator() *annotationAccumulator { +func NewAnnotationAccumulator() *annotationAccumulator { //nolint:revive return &annotationAccumulator{ warnings: &sync.Map{}, infos: &sync.Map{}, diff --git a/pkg/frontend/querymiddleware/querysharding_test.go b/pkg/frontend/querymiddleware/querysharding_test.go index e97ff7b0575..ef6b7cda60c 100644 --- a/pkg/frontend/querymiddleware/querysharding_test.go +++ b/pkg/frontend/querymiddleware/querysharding_test.go @@ -75,9 +75,9 @@ func approximatelyEqualsSamples(t *testing.T, a, b *PrometheusResponse) { require.Equal(t, statusSuccess, a.Status) require.Equal(t, statusSuccess, b.Status) - as, err := responseToSamples(a) + as, err := ResponseToSamples(a) require.Nil(t, err) - bs, err := responseToSamples(b) + bs, err 
:= ResponseToSamples(b) require.Nil(t, err) require.Equalf(t, len(as), len(bs), "expected same number of series: one contains %v, other %v", sampleStreamsStrings(as), sampleStreamsStrings(bs)) @@ -923,7 +923,7 @@ func TestQueryshardingDeterminism(t *testing.T) { shardedPrometheusRes := shardedRes.(*PrometheusResponse) - sampleStreams, err := responseToSamples(shardedPrometheusRes) + sampleStreams, err := ResponseToSamples(shardedPrometheusRes) require.NoError(t, err) require.Lenf(t, sampleStreams, 1, "There should be 1 samples stream (query %d)", i) diff --git a/pkg/frontend/querymiddleware/remote_read.go b/pkg/frontend/querymiddleware/remote_read.go index 463b20c052d..11070ce18d5 100644 --- a/pkg/frontend/querymiddleware/remote_read.go +++ b/pkg/frontend/querymiddleware/remote_read.go @@ -37,7 +37,7 @@ type remoteReadRoundTripper struct { middleware MetricsQueryMiddleware } -func newRemoteReadRoundTripper(next http.RoundTripper, middlewares ...MetricsQueryMiddleware) http.RoundTripper { +func NewRemoteReadRoundTripper(next http.RoundTripper, middlewares ...MetricsQueryMiddleware) http.RoundTripper { return &remoteReadRoundTripper{ next: next, middleware: MergeMetricsQueryMiddlewares(middlewares...), diff --git a/pkg/frontend/querymiddleware/remote_read_test.go b/pkg/frontend/querymiddleware/remote_read_test.go index ec87b3b020b..658f38cd0c5 100644 --- a/pkg/frontend/querymiddleware/remote_read_test.go +++ b/pkg/frontend/querymiddleware/remote_read_test.go @@ -135,7 +135,7 @@ func TestRemoteReadRoundTripperCallsDownstreamOnAll(t *testing.T) { actualMiddleWareCalls++ return tc.handler }) - rr := newRemoteReadRoundTripper(roundTripper, middleware) + rr := NewRemoteReadRoundTripper(roundTripper, middleware) _, err := rr.RoundTrip(makeTestHTTPRequestFromRemoteRead(makeTestRemoteReadRequest())) if tc.expectError != "" { require.Error(t, err) @@ -195,7 +195,7 @@ func TestRemoteReadRoundTripper_ShouldAllowMiddlewaresToManipulateRequest(t *tes }, } - rr := 
newRemoteReadRoundTripper(downstream, middleware) + rr := NewRemoteReadRoundTripper(downstream, middleware) _, err := rr.RoundTrip(makeTestHTTPRequestFromRemoteRead(origRemoteReadReq)) require.NoError(t, err) require.NotNil(t, downstreamReq) @@ -255,7 +255,7 @@ func TestRemoteReadRoundTripper_ShouldAllowMiddlewaresToReturnEmptyResponse(t *t }, } - rr := newRemoteReadRoundTripper(downstream, middleware) + rr := NewRemoteReadRoundTripper(downstream, middleware) origRemoteReadReq := makeTestRemoteReadRequest() _, err := rr.RoundTrip(makeTestHTTPRequestFromRemoteRead(origRemoteReadReq)) diff --git a/pkg/frontend/querymiddleware/roundtrip.go b/pkg/frontend/querymiddleware/roundtrip.go index f6d2ed82bb6..3f29ea7aefb 100644 --- a/pkg/frontend/querymiddleware/roundtrip.go +++ b/pkg/frontend/querymiddleware/roundtrip.go @@ -38,6 +38,7 @@ const ( cardinalityActiveNativeHistogramMetricsPathSuffix = "/api/v1/cardinality/active_native_histogram_metrics" labelNamesPathSuffix = "/api/v1/labels" remoteReadPathSuffix = "/api/v1/read" + seriesPathSuffix = "/api/v1/series" queryTypeInstant = "query" queryTypeRange = "query_range" @@ -134,6 +135,19 @@ type MetricsQueryHandler interface { Do(context.Context, MetricsQueryRequest) (Response, error) } +// LabelsHandlerFunc is like http.HandlerFunc, but for LabelsQueryHandler. +type LabelsHandlerFunc func(context.Context, LabelsQueryRequest) (Response, error) + +// Do implements LabelsQueryHandler. +func (q LabelsHandlerFunc) Do(ctx context.Context, req LabelsQueryRequest) (Response, error) { + return q(ctx, req) +} + +// LabelsQueryHandler is like http.Handle, but specifically for Prometheus label names and values calls. +type LabelsQueryHandler interface { + Do(context.Context, LabelsQueryRequest) (Response, error) +} + // MetricsQueryMiddlewareFunc is like http.HandlerFunc, but for MetricsQueryMiddleware. 
type MetricsQueryMiddlewareFunc func(MetricsQueryHandler) MetricsQueryHandler @@ -243,14 +257,14 @@ func newQueryTripperware( // It means that the first roundtrippers defined in this function will be the last to be // executed. - queryrange := newLimitedParallelismRoundTripper(next, codec, limits, queryRangeMiddleware...) - instant := newLimitedParallelismRoundTripper(next, codec, limits, queryInstantMiddleware...) - remoteRead := newRemoteReadRoundTripper(next, remoteReadMiddleware...) + queryrange := NewLimitedParallelismRoundTripper(next, codec, limits, queryRangeMiddleware...) + instant := NewLimitedParallelismRoundTripper(next, codec, limits, queryInstantMiddleware...) + remoteRead := NewRemoteReadRoundTripper(next, remoteReadMiddleware...) // Wrap next for cardinality, labels queries and all other queries. // That attempts to parse "start" and "end" from the HTTP request and set them in the request's QueryDetails. // range and instant queries have more accurate logic for query details. - next = newQueryDetailsStartEndRoundTripper(next) + next = NewQueryDetailsStartEndRoundTripper(next) cardinality := next activeSeries := next activeNativeHistogramMetrics := next @@ -432,8 +446,8 @@ func newQueryMiddlewares( return } -// newQueryDetailsStartEndRoundTripper parses "start" and "end" parameters from the query and sets same fields in the QueryDetails in the context. -func newQueryDetailsStartEndRoundTripper(next http.RoundTripper) http.RoundTripper { +// NewQueryDetailsStartEndRoundTripper parses "start" and "end" parameters from the query and sets same fields in the QueryDetails in the context. 
+func NewQueryDetailsStartEndRoundTripper(next http.RoundTripper) http.RoundTripper { return RoundTripFunc(func(req *http.Request) (*http.Response, error) { params, _ := util.ParseRequestFormWithoutConsumingBody(req) if details := QueryDetailsFromContext(req.Context()); details != nil { @@ -523,6 +537,10 @@ func IsLabelsQuery(path string) bool { return IsLabelNamesQuery(path) || IsLabelValuesQuery(path) } +func IsSeriesQuery(path string) bool { + return strings.HasSuffix(path, seriesPathSuffix) +} + func IsActiveSeriesQuery(path string) bool { return strings.HasSuffix(path, cardinalityActiveSeriesPathSuffix) } diff --git a/pkg/frontend/querymiddleware/sharded_queryable.go b/pkg/frontend/querymiddleware/sharded_queryable.go index 7f65a2e3e91..f50d9cdc346 100644 --- a/pkg/frontend/querymiddleware/sharded_queryable.go +++ b/pkg/frontend/querymiddleware/sharded_queryable.go @@ -32,29 +32,36 @@ var ( errNotImplemented = errors.New("not implemented") ) +type HandleEmbeddedQueryFunc func(ctx context.Context, queryString string, query MetricsQueryRequest, handler MetricsQueryHandler) ([]SampleStream, *PrometheusResponse, error) + // shardedQueryable is an implementor of the Queryable interface. type shardedQueryable struct { req MetricsQueryRequest annotationAccumulator *annotationAccumulator handler MetricsQueryHandler responseHeaders *responseHeadersTracker + handleEmbeddedQuery HandleEmbeddedQueryFunc } -// newShardedQueryable makes a new shardedQueryable. We expect a new queryable is created for each +// NewShardedQueryable makes a new shardedQueryable. We expect a new queryable is created for each // query, otherwise the response headers tracker doesn't work as expected, because it merges the // headers for all queries run through the queryable and never reset them. 
-func newShardedQueryable(req MetricsQueryRequest, annotationAccumulator *annotationAccumulator, next MetricsQueryHandler) *shardedQueryable { +func NewShardedQueryable(req MetricsQueryRequest, annotationAccumulator *annotationAccumulator, next MetricsQueryHandler, handleEmbeddedQuery HandleEmbeddedQueryFunc) *shardedQueryable { //nolint:revive + if handleEmbeddedQuery == nil { + handleEmbeddedQuery = defaultHandleEmbeddedQueryFunc() + } return &shardedQueryable{ req: req, annotationAccumulator: annotationAccumulator, handler: next, responseHeaders: newResponseHeadersTracker(), + handleEmbeddedQuery: handleEmbeddedQuery, } } // Querier implements storage.Queryable. func (q *shardedQueryable) Querier(_, _ int64) (storage.Querier, error) { - return &shardedQuerier{req: q.req, annotationAccumulator: q.annotationAccumulator, handler: q.handler, responseHeaders: q.responseHeaders}, nil + return &shardedQuerier{req: q.req, annotationAccumulator: q.annotationAccumulator, handler: q.handler, responseHeaders: q.responseHeaders, handleEmbeddedQuery: q.handleEmbeddedQuery}, nil } // getResponseHeaders returns the merged response headers received by the downstream @@ -73,6 +80,8 @@ type shardedQuerier struct { // Keep track of response headers received when running embedded queries. responseHeaders *responseHeadersTracker + + handleEmbeddedQuery HandleEmbeddedQueryFunc } // Select implements storage.Querier. @@ -106,33 +115,46 @@ func (q *shardedQuerier) Select(ctx context.Context, _ bool, hints *storage.Sele return q.handleEmbeddedQueries(ctx, queries, hints) } -// handleEmbeddedQueries concurrently executes the provided queries through the downstream handler. -// The returned storage.SeriesSet contains sorted series. -func (q *shardedQuerier) handleEmbeddedQueries(ctx context.Context, queries []string, hints *storage.SelectHints) storage.SeriesSet { - streams := make([][]SampleStream, len(queries)) - - // Concurrently run each query. 
It breaks and cancels each worker context on first error. - err := concurrency.ForEachJob(ctx, len(queries), len(queries), func(ctx context.Context, idx int) error { - query, err := q.req.WithQuery(queries[idx]) +func defaultHandleEmbeddedQueryFunc() HandleEmbeddedQueryFunc { + return func(ctx context.Context, queryString string, query MetricsQueryRequest, handler MetricsQueryHandler) ([]SampleStream, *PrometheusResponse, error) { + query, err := query.WithQuery(queryString) if err != nil { - return err + return nil, nil, err } - resp, err := q.handler.Do(ctx, query) + + resp, err := handler.Do(ctx, query) if err != nil { - return err + return nil, nil, err } promRes, ok := resp.(*PrometheusResponse) if !ok { - return errors.Errorf("error invalid response type: %T, expected: %T", resp, &PrometheusResponse{}) + return nil, nil, errors.Errorf("error invalid response type: %T, expected: %T", resp, &PrometheusResponse{}) } - resStreams, err := responseToSamples(promRes) + resStreams, err := ResponseToSamples(promRes) + if err != nil { + return nil, nil, err + } + + return resStreams, promRes, nil + } +} + +// handleEmbeddedQueries concurrently executes the provided queries through the downstream handler. +// The returned storage.SeriesSet contains sorted series. +func (q *shardedQuerier) handleEmbeddedQueries(ctx context.Context, queries []string, hints *storage.SelectHints) storage.SeriesSet { + streams := make([][]SampleStream, len(queries)) + + // Concurrently run each query. It breaks and cancels each worker context on first error. + err := concurrency.ForEachJob(ctx, len(queries), len(queries), func(ctx context.Context, idx int) error { + resStreams, promRes, err := q.handleEmbeddedQuery(ctx, queries[idx], q.req, q.handler) if err != nil { return err } + streams[idx] = resStreams // No mutex is needed since each job writes its own index. This is like writing separate variables. 
- q.responseHeaders.mergeHeaders(resp.(*PrometheusResponse).Headers) + q.responseHeaders.mergeHeaders(promRes.Headers) q.annotationAccumulator.addInfos(promRes.Infos) q.annotationAccumulator.addWarnings(promRes.Warnings) @@ -298,8 +320,8 @@ func newSeriesSetFromEmbeddedQueriesResults(results [][]SampleStream, hints *sto return series.NewConcreteSeriesSetFromUnsortedSeries(set) } -// responseToSamples is needed to map back from api response to the underlying series data -func responseToSamples(resp *PrometheusResponse) ([]SampleStream, error) { +// ResponseToSamples is needed to map back from api response to the underlying series data +func ResponseToSamples(resp *PrometheusResponse) ([]SampleStream, error) { if resp.Error != "" { return nil, errors.New(resp.Error) } diff --git a/pkg/frontend/querymiddleware/sharded_queryable_test.go b/pkg/frontend/querymiddleware/sharded_queryable_test.go index 273b3db6cc0..506b3b43796 100644 --- a/pkg/frontend/querymiddleware/sharded_queryable_test.go +++ b/pkg/frontend/querymiddleware/sharded_queryable_test.go @@ -257,7 +257,7 @@ func TestShardedQuerier_Select_ShouldConcurrentlyRunEmbeddedQueries(t *testing.T } func TestShardedQueryable_GetResponseHeaders(t *testing.T) { - queryable := newShardedQueryable(&PrometheusRangeQueryRequest{}, nil, nil) + queryable := NewShardedQueryable(&PrometheusRangeQueryRequest{}, nil, nil, nil) assert.Empty(t, queryable.getResponseHeaders()) // Merge some response headers from the 1st querier. 
@@ -288,7 +288,7 @@ func TestShardedQueryable_GetResponseHeaders(t *testing.T) { } func mkShardedQuerier(handler MetricsQueryHandler) *shardedQuerier { - return &shardedQuerier{req: &PrometheusRangeQueryRequest{}, handler: handler, responseHeaders: newResponseHeadersTracker()} + return &shardedQuerier{req: &PrometheusRangeQueryRequest{}, handler: handler, responseHeaders: newResponseHeadersTracker(), handleEmbeddedQuery: defaultHandleEmbeddedQueryFunc()} } func TestNewSeriesSetFromEmbeddedQueriesResults(t *testing.T) { @@ -418,7 +418,7 @@ func TestResponseToSamples(t *testing.T) { }, } - streams, err := responseToSamples(input) + streams, err := ResponseToSamples(input) require.NoError(t, err) assertEqualSampleStream(t, input.Data.Result, streams) } diff --git a/pkg/frontend/querymiddleware/split_and_cache.go b/pkg/frontend/querymiddleware/split_and_cache.go index 12b3c82d4b4..e888f4a3889 100644 --- a/pkg/frontend/querymiddleware/split_and_cache.go +++ b/pkg/frontend/querymiddleware/split_and_cache.go @@ -678,7 +678,7 @@ func splitQueryByInterval(req MetricsQueryRequest, interval time.Duration) ([]Me func evaluateAtModifierFunction(query string, start, end int64) (string, error) { expr, err := parser.ParseExpr(query) if err != nil { - return "", apierror.New(apierror.TypeBadData, decorateWithParamName(err, "query").Error()) + return "", apierror.New(apierror.TypeBadData, DecorateWithParamName(err, "query").Error()) } parser.Inspect(expr, func(n parser.Node, _ []parser.Node) error { switch exprAt := n.(type) { diff --git a/pkg/frontend/querymiddleware/split_by_instant_interval.go b/pkg/frontend/querymiddleware/split_by_instant_interval.go index 0d3687362d3..59f930f20bc 100644 --- a/pkg/frontend/querymiddleware/split_by_instant_interval.go +++ b/pkg/frontend/querymiddleware/split_by_instant_interval.go @@ -130,7 +130,7 @@ func (s *splitInstantQueryByIntervalMiddleware) Do(ctx context.Context, req Metr if err != nil { level.Warn(spanLog).Log("msg", "failed to parse 
query", "err", err) s.metrics.splittingSkipped.WithLabelValues(skippedReasonParsingFailed).Inc() - return nil, apierror.New(apierror.TypeBadData, decorateWithParamName(err, "query").Error()) + return nil, apierror.New(apierror.TypeBadData, DecorateWithParamName(err, "query").Error()) } instantSplitQuery, err := mapper.Map(expr) @@ -180,8 +180,8 @@ func (s *splitInstantQueryByIntervalMiddleware) Do(ctx context.Context, req Metr return nil, err } - annotationAccumulator := newAnnotationAccumulator() - shardedQueryable := newShardedQueryable(req, annotationAccumulator, s.next) + annotationAccumulator := NewAnnotationAccumulator() + shardedQueryable := NewShardedQueryable(req, annotationAccumulator, s.next, nil) qry, err := newQuery(ctx, req, s.engine, lazyquery.NewLazyQueryable(shardedQueryable)) if err != nil { diff --git a/pkg/querier/tenantfederation/merge_exemplar_queryable.go b/pkg/querier/tenantfederation/merge_exemplar_queryable.go index 51c441c0d35..4288752f05f 100644 --- a/pkg/querier/tenantfederation/merge_exemplar_queryable.go +++ b/pkg/querier/tenantfederation/merge_exemplar_queryable.go @@ -225,7 +225,7 @@ func filterTenantsAndRewriteMatchers(idLabelName string, ids []string, allMatche // In order to support that, we start with a set of 0 tenant IDs and add any tenant IDs that remain // after filtering (based on the inner slice of matchers), for each outer slice. for i, matchers := range allMatchers { - filteredIDs, unrelatedMatchers := filterValuesByMatchers(idLabelName, ids, matchers...) + filteredIDs, unrelatedMatchers := FilterValuesByMatchers(idLabelName, ids, matchers...) 
for k := range filteredIDs { outIDs[k] = struct{}{} } diff --git a/pkg/querier/tenantfederation/merge_queryable.go b/pkg/querier/tenantfederation/merge_queryable.go index e3b350280a6..c01da5e2871 100644 --- a/pkg/querier/tenantfederation/merge_queryable.go +++ b/pkg/querier/tenantfederation/merge_queryable.go @@ -198,7 +198,7 @@ func (m *mergeQuerier) LabelValues(ctx context.Context, name string, hints *stor spanlog, ctx := spanlogger.NewWithLogger(ctx, m.logger, "mergeQuerier.LabelValues") defer spanlog.Finish() - matchedIDs, filteredMatchers := filterValuesByMatchers(m.idLabelName, ids, matchers...) + matchedIDs, filteredMatchers := FilterValuesByMatchers(m.idLabelName, ids, matchers...) if name == m.idLabelName { labelValues := make([]string, 0, len(matchedIDs)) @@ -237,7 +237,7 @@ func (m *mergeQuerier) LabelNames(ctx context.Context, hints *storage.LabelHints spanlog, ctx := spanlogger.NewWithLogger(ctx, m.logger, "mergeQuerier.LabelNames") defer spanlog.Finish() - matchedIDs, filteredMatchers := filterValuesByMatchers(m.idLabelName, ids, matchers...) + matchedIDs, filteredMatchers := FilterValuesByMatchers(m.idLabelName, ids, matchers...) labelNames, warnings, err := m.mergeDistinctStringSliceWithTenants(ctx, matchedIDs, func(ctx context.Context, id string) ([]string, annotations.Annotations, error) { return m.upstream.LabelNames(ctx, id, hints, filteredMatchers...) @@ -349,7 +349,7 @@ func (m *mergeQuerier) Select(ctx context.Context, sortSeries bool, hints *stora spanlog, ctx := spanlogger.NewWithLogger(ctx, m.logger, "mergeQuerier.Select") defer spanlog.Finish() - matchedIDs, filteredMatchers := filterValuesByMatchers(m.idLabelName, ids, matchers...) + matchedIDs, filteredMatchers := FilterValuesByMatchers(m.idLabelName, ids, matchers...) 
 	jobs := make([]string, 0, len(matchedIDs))
 	seriesSets := make([]storage.SeriesSet, len(matchedIDs))
diff --git a/pkg/querier/tenantfederation/tenant_federation.go b/pkg/querier/tenantfederation/tenant_federation.go
index 85abdc7b85b..50826eef68d 100644
--- a/pkg/querier/tenantfederation/tenant_federation.go
+++ b/pkg/querier/tenantfederation/tenant_federation.go
@@ -31,7 +31,7 @@ func (cfg *Config) RegisterFlags(f *flag.FlagSet) {
 	f.IntVar(&cfg.MaxTenants, "tenant-federation.max-tenants", defaultMaxTenants, "The max number of tenant IDs that may be supplied for a federated query if enabled. 0 to disable the limit.")
 }
 
-// filterValuesByMatchers applies matchers to inputed `idLabelName` and
+// FilterValuesByMatchers applies matchers to inputted `idLabelName` and
 // `ids`. A set of matched IDs is returned and also all label matchers not
 // targeting the `idLabelName` label.
 //
@@ -40,7 +40,7 @@ func (cfg *Config) RegisterFlags(f *flag.FlagSet) {
 // to as part of Select in the mergeQueryable, to ensure only relevant queries
 // are considered and the forwarded matchers do not contain matchers on the
 // `idLabelName`.
-func filterValuesByMatchers(idLabelName string, ids []string, matchers ...*labels.Matcher) (matchedIDs map[string]struct{}, unrelatedMatchers []*labels.Matcher) {
+func FilterValuesByMatchers(idLabelName string, ids []string, matchers ...*labels.Matcher) (matchedIDs map[string]struct{}, unrelatedMatchers []*labels.Matcher) {
 	// this contains the matchers which are not related to idLabelName
 	unrelatedMatchers = make([]*labels.Matcher, 0, len(matchers))