diff --git a/pkg/frontend/querymiddleware/astmapper/astmapper_test.go b/pkg/frontend/querymiddleware/astmapper/astmapper_test.go index f6ff1c41c79..06b6b6948e1 100644 --- a/pkg/frontend/querymiddleware/astmapper/astmapper_test.go +++ b/pkg/frontend/querymiddleware/astmapper/astmapper_test.go @@ -138,8 +138,10 @@ func TestSharding_BinaryExpressionsDontTakeExponentialTime(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - mapper, err := NewSharding(ctx, 2, log.NewNopLogger(), NewMapperStats()) + + summer, err := NewQueryShardSummer(ctx, 2, VectorSquasher, log.NewNopLogger(), NewMapperStats()) require.NoError(t, err) + mapper := NewSharding(summer) _, err = mapper.Map(expr) require.NoError(t, err) diff --git a/pkg/frontend/querymiddleware/astmapper/embedded.go b/pkg/frontend/querymiddleware/astmapper/embedded.go index 66770c30b0b..22a51b11272 100644 --- a/pkg/frontend/querymiddleware/astmapper/embedded.go +++ b/pkg/frontend/querymiddleware/astmapper/embedded.go @@ -62,7 +62,7 @@ func (c jsonCodec) Decode(encoded string) (queries []string, err error) { // VectorSquash reduces an AST into a single vector query which can be hijacked by a Queryable impl. // It always uses a VectorSelector as the substitution expr. // This is important because logical/set binops can only be applied against vectors and not matrices. -func vectorSquasher(exprs ...parser.Expr) (parser.Expr, error) { +func VectorSquasher(exprs ...parser.Expr) (parser.Expr, error) { // concat OR legs strs := make([]string, 0, len(exprs)) for _, expr := range exprs { diff --git a/pkg/frontend/querymiddleware/astmapper/instant_splitting.go b/pkg/frontend/querymiddleware/astmapper/instant_splitting.go index 35196a6a1c3..cffd12b8c3b 100644 --- a/pkg/frontend/querymiddleware/astmapper/instant_splitting.go +++ b/pkg/frontend/querymiddleware/astmapper/instant_splitting.go @@ -380,7 +380,7 @@ func (i *instantSplitter) splitAndSquashCall(expr *parser.Call, rangeInterval ti embeddedQueries = append([]parser.Expr{splitExpr}, embeddedQueries...) } - squashExpr, err := vectorSquasher(embeddedQueries...) + squashExpr, err := VectorSquasher(embeddedQueries...) if err != nil { return nil, false, err } diff --git a/pkg/frontend/querymiddleware/astmapper/sharding.go b/pkg/frontend/querymiddleware/astmapper/sharding.go index b6ab1e0d3cb..a9fc037e40e 100644 --- a/pkg/frontend/querymiddleware/astmapper/sharding.go +++ b/pkg/frontend/querymiddleware/astmapper/sharding.go @@ -18,39 +18,71 @@ import ( ) // NewSharding creates a new query sharding mapper. -func NewSharding(ctx context.Context, shards int, logger log.Logger, stats *MapperStats) (ASTMapper, error) { - shardSummer, err := newShardSummer(ctx, shards, vectorSquasher, logger, stats) - if err != nil { - return nil, err - } +func NewSharding(shardSummer ASTMapper) ASTMapper { subtreeFolder := newSubtreeFolder() return NewMultiMapper( shardSummer, subtreeFolder, - ), nil + ) +} + +type Squasher = func(...parser.Expr) (parser.Expr, error) + +type ShardLabeller interface { + GetLabelName() string + GetLabelValue(shard int) string } -type squasher = func(...parser.Expr) (parser.Expr, error) +// queryShardLabeller implements ShardLabeller for query sharding. +type queryShardLabeller struct { + shards int +} + +func newQueryShardLabeller(shards int) ShardLabeller { + return &queryShardLabeller{shards: shards} +} + +func (lbl *queryShardLabeller) GetLabelName() string { + return sharding.ShardLabel +} + +func (lbl *queryShardLabeller) GetLabelValue(shard int) string { + return sharding.ShardSelector{ShardIndex: uint64(shard), ShardCount: uint64(lbl.shards)}.LabelValue() +} + +// NewQueryShardSummer instantiates an ASTMapper which will fan out sum queries by shard. +func NewQueryShardSummer(ctx context.Context, shards int, squasher Squasher, logger log.Logger, stats *MapperStats) (ASTMapper, error) { + return NewShardSummerWithLabeller(ctx, shards, squasher, logger, stats, newQueryShardLabeller(shards)) +} + +func NewShardSummerWithLabeller(ctx context.Context, shards int, squasher Squasher, logger log.Logger, stats *MapperStats, labeller ShardLabeller) (ASTMapper, error) { + summer, err := newShardSummer(ctx, shards, squasher, logger, stats, labeller) + if err != nil { + return nil, err + } + return NewASTExprMapper(summer), nil +} type shardSummer struct { ctx context.Context shards int currentShard *int - squash squasher + squash Squasher logger log.Logger stats *MapperStats + shardLabeller ShardLabeller + canShardAllVectorSelectorsCache map[string]bool } -// newShardSummer instantiates an ASTMapper which will fan out sum queries by shard -func newShardSummer(ctx context.Context, shards int, squasher squasher, logger log.Logger, stats *MapperStats) (ASTMapper, error) { +func newShardSummer(ctx context.Context, shards int, squasher Squasher, logger log.Logger, stats *MapperStats, shardLabeller ShardLabeller) (*shardSummer, error) { if squasher == nil { return nil, errors.Errorf("squasher required and not passed") } - return NewASTExprMapper(&shardSummer{ + return &shardSummer{ ctx: ctx, shards: shards, @@ -59,8 +91,10 @@ func newShardSummer(ctx context.Context, shards int, squasher squasher, logger l logger: logger, stats: stats, + shardLabeller: shardLabeller, + canShardAllVectorSelectorsCache: make(map[string]bool), - }), nil + }, nil } // Clone returns a clone of shardSummer with stats and current shard index reset to default. @@ -98,7 +132,7 @@ func (summer *shardSummer) MapExpr(expr parser.Expr) (mapped parser.Expr, finish case *parser.VectorSelector: if summer.currentShard != nil { - mapped, err := shardVectorSelector(*summer.currentShard, summer.shards, e) + mapped, err := summer.shardVectorSelector(e) return mapped, true, err } return e, true, nil @@ -522,8 +556,8 @@ func (summer *shardSummer) shardAndSquashBinOp(expr *parser.BinaryExpr) (parser. return summer.squash(children...) } -func shardVectorSelector(curshard, shards int, selector *parser.VectorSelector) (parser.Expr, error) { - shardMatcher, err := labels.NewMatcher(labels.MatchEqual, sharding.ShardLabel, sharding.ShardSelector{ShardIndex: uint64(curshard), ShardCount: uint64(shards)}.LabelValue()) +func (summer *shardSummer) shardVectorSelector(selector *parser.VectorSelector) (parser.Expr, error) { + shardMatcher, err := labels.NewMatcher(labels.MatchEqual, summer.shardLabeller.GetLabelName(), summer.shardLabeller.GetLabelValue(*summer.currentShard)) if err != nil { return nil, err } diff --git a/pkg/frontend/querymiddleware/astmapper/sharding_test.go b/pkg/frontend/querymiddleware/astmapper/sharding_test.go index 2ae900ff963..b0ef16ff393 100644 --- a/pkg/frontend/querymiddleware/astmapper/sharding_test.go +++ b/pkg/frontend/querymiddleware/astmapper/sharding_test.go @@ -523,8 +523,9 @@ func TestShardSummer(t *testing.T) { t.Run(tt.in, func(t *testing.T) { stats := NewMapperStats() - mapper, err := NewSharding(context.Background(), 3, log.NewNopLogger(), stats) + summer, err := NewQueryShardSummer(context.Background(), 3, VectorSquasher, log.NewNopLogger(), stats) require.NoError(t, err) + mapper := NewSharding(summer) expr, err := parser.ParseExpr(tt.in) require.NoError(t, err) out, err := parser.ParseExpr(tt.out) @@ -556,7 +557,7 @@ func concat(queries ...string) string { exprs = append(exprs, n) } - mapped, err := vectorSquasher(exprs...) + mapped, err := VectorSquasher(exprs...) if err != nil { panic(err) } @@ -577,7 +578,7 @@ func TestShardSummerWithEncoding(t *testing.T) { } { t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) { stats := NewMapperStats() - summer, err := newShardSummer(context.Background(), c.shards, vectorSquasher, log.NewNopLogger(), stats) + summer, err := NewQueryShardSummer(context.Background(), c.shards, VectorSquasher, log.NewNopLogger(), stats) require.Nil(t, err) expr, err := parser.ParseExpr(c.input) require.Nil(t, err) diff --git a/pkg/frontend/querymiddleware/astmapper/subtree_folder.go b/pkg/frontend/querymiddleware/astmapper/subtree_folder.go index f6b646b53ee..e0764d72434 100644 --- a/pkg/frontend/querymiddleware/astmapper/subtree_folder.go +++ b/pkg/frontend/querymiddleware/astmapper/subtree_folder.go @@ -39,7 +39,7 @@ func (f *subtreeFolder) MapExpr(expr parser.Expr) (mapped parser.Expr, finished // Change the expr if it contains vector selectors, as only those need to be embedded. if hasVectorSelector { - expr, err := vectorSquasher(expr) + expr, err := VectorSquasher(expr) return expr, true, err } return expr, false, nil diff --git a/pkg/frontend/querymiddleware/querysharding.go b/pkg/frontend/querymiddleware/querysharding.go index 3ac8456dccd..bea9986f30c 100644 --- a/pkg/frontend/querymiddleware/querysharding.go +++ b/pkg/frontend/querymiddleware/querysharding.go @@ -253,10 +253,11 @@ func (s *querySharding) shardQuery(ctx context.Context, query string, totalShard ctx, cancel := context.WithTimeout(ctx, shardingTimeout) defer cancel() - mapper, err := astmapper.NewSharding(ctx, totalShards, s.logger, stats) + summer, err := astmapper.NewQueryShardSummer(ctx, totalShards, astmapper.VectorSquasher, s.logger, stats) if err != nil { return "", nil, err } + mapper := astmapper.NewSharding(summer) // The mapper can modify the input expression in-place, so we must re-parse the original query // each time before passing it to the mapper.