diff --git a/pkg/promclient/labelfilter.go b/pkg/promclient/labelfilter.go
new file mode 100644
index 000000000..32493aa11
--- /dev/null
+++ b/pkg/promclient/labelfilter.go
@@ -0,0 +1,329 @@
+package promclient
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	v1 "github.com/prometheus/client_golang/api/prometheus/v1"
+	"github.com/prometheus/client_golang/prometheus"
+	"github.com/prometheus/common/model"
+	"github.com/prometheus/prometheus/model/labels"
+	"github.com/prometheus/prometheus/promql/parser"
+	"github.com/sirupsen/logrus"
+)
+
+// Metrics exported by the label filter
+var (
+	syncCount = prometheus.NewCounterVec(prometheus.CounterOpts{
+		Name: "promxy_label_filter_sync_count_total",
+		Help: "How many syncs completed from a promxy label_filter, partitioned by success",
+	}, []string{"status"})
+	syncSummary = prometheus.NewSummaryVec(prometheus.SummaryOpts{
+		Name: "promxy_label_filter_sync_duration_seconds",
+		Help: "Latency of sync process from a promxy label_filter",
+	}, []string{"status"})
+	filteredCount = prometheus.NewCounterVec(prometheus.CounterOpts{
+		Name: "promxy_label_filter_filtered_count_total",
+		Help: "How many requests have been filtered from the downstream, partitioned by query type",
+	}, []string{"type"})
+)
+
+func init() {
+	// Every metric declared above has to be registered here, otherwise it is updated but never exported.
+	prometheus.MustRegister(
+		syncCount,
+		syncSummary,
+		filteredCount,
+	)
+}
+
+// LabelFilterConfig is the configuration for the LabelFilterClient
+type LabelFilterConfig struct {
+	// DynamicLabels is a list of labels whose values are polled from the downstream to maintain the filter
+	DynamicLabels []string `yaml:"dynamic_labels"`
+	// SyncInterval defines how frequently to update the dynamic label filter
+	SyncInterval time.Duration `yaml:"sync_interval"`
+	// StaticLabelsInclude is a set of labels to always add to the downstream filter
+	// this allows you to define some metrics to be included statically if you want to
+	// avoid polling the downstream.
+	// NOTE: this is not a "secure" measure as this entire label_filter is based on matchers
+	// and as such doesn't restrict which metrics they touch (e.g. if you restrict by `__name__`
+	// they could just query by another label).
+	StaticLabelsInclude map[string][]string `yaml:"static_labels_include"`
+	// StaticLabelsExclude is a set of labels to always exclude from the filter. This is done last
+	// so it will apply after the dynamic and static lists are added to the filter.
+	StaticLabelsExclude map[string][]string `yaml:"static_labels_exclude"`
+}
+
+// Validate checks that every entry of DynamicLabels is a valid label name and that
+// SyncInterval is only set when there are dynamic labels to re-sync.
+func (c *LabelFilterConfig) Validate() error {
+	for _, l := range c.DynamicLabels {
+		if !model.LabelName(l).IsValid() {
+			return fmt.Errorf("%s is not a valid label name", l)
+		}
+	}
+
+	if c.SyncInterval > 0 && len(c.DynamicLabels) == 0 {
+		return fmt.Errorf("sync_interval requires `dynamic_labels` to be set")
+	}
+
+	return nil
+}
+
+// UnmarshalYAML implements the yaml.Unmarshaler interface.
+func (c *LabelFilterConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
+	// The local alias has no UnmarshalYAML method, which avoids recursing back into this function.
+	type plain LabelFilterConfig
+	if err := unmarshal((*plain)(c)); err != nil {
+		return err
+	}
+
+	return c.Validate()
+}
+
+// NewLabelFilterClient returns a LabelFilterClient which will filter the queries sent downstream based
+// on a filter of labels maintained in memory from the downstream API.
+func NewLabelFilterClient(ctx context.Context, a API, cfg *LabelFilterConfig) (*LabelFilterClient, error) {
+	c := &LabelFilterClient{
+		API: a,
+		ctx: ctx,
+		cfg: cfg,
+	}
+
+	// Do an initial sync so the filter is populated before the client is handed out
+	if err := c.Sync(ctx); err != nil {
+		return nil, err
+	}
+
+	if cfg.SyncInterval > 0 {
+		go func() {
+			ticker := time.NewTicker(cfg.SyncInterval)
+			for {
+				select {
+				case <-ticker.C:
+					start := time.Now()
+					err := c.Sync(ctx)
+					took := time.Since(start)
+					status := "success"
+					if err != nil {
+						logrus.Errorf("error syncing in label_filter from downstream: %v", err)
+						status = "error"
+					}
+					syncCount.WithLabelValues(status).Inc()
+					syncSummary.WithLabelValues(status).Observe(took.Seconds())
+
+				case <-ctx.Done():
+					ticker.Stop()
+					return
+				}
+			}
+		}()
+	}
+
+	return c, nil
+}
+
+// LabelFilterClient filters out calls to the downstream based on a label filter
+// which is pulled and maintained from the downstream API.
+type LabelFilterClient struct {
+	API
+
+	// filter is an atomic to hold the LabelFilter which is a map of labelName -> labelValue -> nothing (for quick lookups)
+	filter atomic.Value
+
+	// Used as the background context for this client
+	ctx context.Context
+
+	// cfg is a pointer to the config for this client
+	cfg *LabelFilterConfig
+}
+
+// LabelFilter returns the currently stored label filter, or nil if none has been stored yet
+func (c *LabelFilterClient) LabelFilter() map[string]map[string]struct{} {
+	tmp := c.filter.Load()
+	if ret, ok := tmp.(map[string]map[string]struct{}); ok {
+		return ret
+	}
+	return nil
+}
+
+// Sync rebuilds the filter (dynamic labels, then static includes, then static excludes) and stores it.
+func (c *LabelFilterClient) Sync(ctx context.Context) error {
+	filter := make(map[string]map[string]struct{})
+
+	for _, label := range c.cfg.DynamicLabels {
+		labelFilter := make(map[string]struct{})
+		// TODO: warn? (the warnings returned by LabelValues are currently dropped)
+		vals, _, err := c.LabelValues(ctx, label, nil, model.Time(0).Time(), model.Now().Time())
+		if err != nil {
+			return err
+		}
+		for _, v := range vals {
+			labelFilter[string(v)] = struct{}{}
+		}
+		filter[label] = labelFilter
+	}
+
+	// Apply static include list
+	for k, vList := range c.cfg.StaticLabelsInclude {
+		filterMap, ok := filter[k]
+		if !ok {
+			filterMap = make(map[string]struct{})
+		}
+		for _, item := range vList {
+			filterMap[item] = struct{}{}
+		}
+		filter[k] = filterMap
+	}
+
+	// Apply exclude list (deleting through filterMap mutates the map already held by filter)
+	for k, vList := range c.cfg.StaticLabelsExclude {
+		if filterMap, ok := filter[k]; ok {
+			for _, item := range vList {
+				delete(filterMap, item)
+			}
+		}
+	}
+
+	c.filter.Store(filter)
+
+	return nil
+}
+
+// Query performs a query for the given time.
+func (c *LabelFilterClient) Query(ctx context.Context, query string, ts time.Time) (model.Value, v1.Warnings, error) {
+	// Parse out the promql query into expressions etc.
+	e, err := parser.ParseExpr(query)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	filterVisitor := NewFilterLabelVisitor(c.LabelFilter())
+	if _, err := parser.Walk(ctx, filterVisitor, &parser.EvalStmt{Expr: e}, e, nil, nil); err != nil {
+		return nil, nil, err
+	}
+	if !filterVisitor.filterMatch {
+		filteredCount.WithLabelValues("Query").Inc()
+		return nil, nil, nil
+	}
+
+	return c.API.Query(ctx, query, ts)
+}
+
+// QueryRange performs a query for the given range.
+func (c *LabelFilterClient) QueryRange(ctx context.Context, query string, r v1.Range) (model.Value, v1.Warnings, error) {
+	// Parse out the promql query into expressions etc.
+	e, err := parser.ParseExpr(query)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	filterVisitor := NewFilterLabelVisitor(c.LabelFilter())
+	if _, err := parser.Walk(ctx, filterVisitor, &parser.EvalStmt{Expr: e}, e, nil, nil); err != nil {
+		return nil, nil, err
+	}
+	if !filterVisitor.filterMatch {
+		filteredCount.WithLabelValues("QueryRange").Inc()
+		return nil, nil, nil
+	}
+
+	return c.API.QueryRange(ctx, query, r)
+}
+
+// Series finds series by label matchers.
+func (c *LabelFilterClient) Series(ctx context.Context, matches []string, startTime time.Time, endTime time.Time) ([]model.LabelSet, v1.Warnings, error) {
+	for _, m := range matches {
+		matchers, err := parser.ParseMetricSelector(m)
+		if err != nil {
+			return nil, nil, err
+		}
+		// check if the matcher is excluded by our filter
+		for _, matcher := range matchers {
+			if !FilterLabelMatchers(c.LabelFilter(), matcher) {
+				filteredCount.WithLabelValues("Series").Inc()
+				return nil, nil, nil
+			}
+		}
+	}
+	return c.API.Series(ctx, matches, startTime, endTime)
+}
+
+// GetValue loads the raw data for a given set of matchers in the time range
+func (c *LabelFilterClient) GetValue(ctx context.Context, start, end time.Time, matchers []*labels.Matcher) (model.Value, v1.Warnings, error) {
+	// check if the matcher is excluded by our filter
+	for _, matcher := range matchers {
+		if !FilterLabelMatchers(c.LabelFilter(), matcher) {
+			filteredCount.WithLabelValues("GetValue").Inc()
+			return nil, nil, nil
+		}
+	}
+	return c.API.GetValue(ctx, start, end, matchers)
+}
+
+// Metadata returns metadata about metrics currently scraped by the metric name.
+func (c *LabelFilterClient) Metadata(ctx context.Context, metric, limit string) (map[string][]v1.Metadata, error) {
+	// An empty metric requests metadata for every metric, so there is no name to check against the filter.
+	if metric != "" {
+		matcher, err := labels.NewMatcher(labels.MatchEqual, labels.MetricName, metric)
+		if err != nil {
+			return nil, err
+		}
+		if !FilterLabelMatchers(c.LabelFilter(), matcher) {
+			filteredCount.WithLabelValues("Metadata").Inc()
+			return nil, nil
+		}
+	}
+	return c.API.Metadata(ctx, metric, limit)
+}
+
+// NewFilterLabelVisitor returns a FilterLabelVisitor for the given filter; it reports a match until a selector fails.
+func NewFilterLabelVisitor(filter map[string]map[string]struct{}) *FilterLabelVisitor {
+	return &FilterLabelVisitor{
+		labelFilter: filter,
+		filterMatch: true,
+	}
+}
+
+// FilterLabelVisitor implements the parser.Visitor interface to filter selectors based on a labelset
+type FilterLabelVisitor struct {
+	l           sync.Mutex
+	labelFilter map[string]map[string]struct{}
+	filterMatch bool
+}
+
+// Visit checks if the given node matches the labels in the filter
+func (l *FilterLabelVisitor) Visit(node parser.Node, path []parser.Node) (w parser.Visitor, err error) {
+	switch nodeTyped := node.(type) {
+	case *parser.VectorSelector:
+		for _, matcher := range nodeTyped.LabelMatchers {
+			if !FilterLabelMatchers(l.labelFilter, matcher) {
+				l.l.Lock()
+				l.filterMatch = false
+				l.l.Unlock()
+				return nil, nil
+			}
+		}
+	}
+
+	return l, nil
+}
+
+// FilterLabelMatchers reports whether matcher matches at least one value the filter holds
+// for the matcher's label name. Labels that are not in the filter always pass.
+func FilterLabelMatchers(filter map[string]map[string]struct{}, matcher *labels.Matcher) bool {
+	labelFilter, ok := filter[matcher.Name]
+	if !ok {
+		return true
+	}
+
+	// Check that there is a match somewhere!
+	for v := range labelFilter {
+		if matcher.Matches(v) {
+			return true
+		}
+	}
+	return false
+}
diff --git a/pkg/promclient/labelfilter_test.go b/pkg/promclient/labelfilter_test.go
new file mode 100644
index 000000000..ded5d9fb7
--- /dev/null
+++ b/pkg/promclient/labelfilter_test.go
@@ -0,0 +1,240 @@
+package promclient
+
+import (
+	"context"
+	"strconv"
+	"testing"
+	"time"
+
+	v1 "github.com/prometheus/client_golang/api/prometheus/v1"
+	"github.com/prometheus/common/model"
+	"github.com/prometheus/prometheus/model/labels"
+	"github.com/prometheus/prometheus/promql/parser"
+)
+
+func newCountAPI(a API) *countAPI {
+	return &countAPI{
+		API: a,
+		callCount: map[string]int{
+			"LabelNames":  0,
+			"LabelValues": 0,
+			"Query":       0,
+			"QueryRange":  0,
+			"Series":      0,
+			"GetValue":    0,
+			"Metadata":    0,
+		},
+	}
+}
+
+type countAPI struct {
+	API
+	callCount map[string]int
+}
+
+// LabelNames returns all the unique label names present in the block in sorted order.
+func (s *countAPI) LabelNames(ctx context.Context, matchers []string, startTime time.Time, endTime time.Time) ([]string, v1.Warnings, error) {
+	s.callCount["LabelNames"]++
+	return s.API.LabelNames(ctx, matchers, startTime, endTime)
+}
+
+// LabelValues performs a query for the values of the given label.
+func (s *countAPI) LabelValues(ctx context.Context, label string, matchers []string, startTime time.Time, endTime time.Time) (model.LabelValues, v1.Warnings, error) {
+	s.callCount["LabelValues"]++
+	return s.API.LabelValues(ctx, label, matchers, startTime, endTime)
+}
+
+// Query performs a query for the given time.
+func (s *countAPI) Query(ctx context.Context, query string, ts time.Time) (model.Value, v1.Warnings, error) {
+	s.callCount["Query"]++
+	return s.API.Query(ctx, query, ts)
+}
+
+// QueryRange performs a query for the given range.
+func (s *countAPI) QueryRange(ctx context.Context, query string, r v1.Range) (model.Value, v1.Warnings, error) {
+	s.callCount["QueryRange"]++
+	return s.API.QueryRange(ctx, query, r)
+}
+
+// Series counts the call and then finds series by label matchers via the wrapped API.
+func (s *countAPI) Series(ctx context.Context, matches []string, startTime time.Time, endTime time.Time) ([]model.LabelSet, v1.Warnings, error) {
+	s.callCount["Series"]++
+	return s.API.Series(ctx, matches, startTime, endTime)
+}
+
+// GetValue counts the call and then loads the raw data for a given set of matchers in the time range
+func (s *countAPI) GetValue(ctx context.Context, start, end time.Time, matchers []*labels.Matcher) (model.Value, v1.Warnings, error) {
+	s.callCount["GetValue"]++
+	return s.API.GetValue(ctx, start, end, matchers)
+}
+
+// Metadata counts the call and then returns metadata for the given metric name via the wrapped API.
+func (s *countAPI) Metadata(ctx context.Context, metric, limit string) (map[string][]v1.Metadata, error) {
+	s.callCount["Metadata"]++
+	return s.API.Metadata(ctx, metric, limit)
+}
+
+func TestLabelFilter(t *testing.T) {
+	/*
+
+		The idea here is that the datasource has the following data:
+
+		up{filterlabel="a"}
+		up{filterlabel="b"}
+		testmetric{filterlabel="a"}
+		testmetric{filterlabel="b"}
+
+	*/
+
+	stub := &stubAPI{
+		// Override the LabelValues endpoint (which is the one that LabelFilter uses to determine its filter)
+		labelValues: func(label string) model.LabelValues {
+			switch label {
+			case "__name__":
+				return model.LabelValues{
+					"up",
+					"testmetric",
+				}
+			case "filterlabel":
+				return model.LabelValues{
+					"a",
+					"b",
+				}
+			}
+			return model.LabelValues{}
+		},
+	}
+
+	// Wrap the stub in a counter so we can tell whether a call reached the downstream
+	countAPI := newCountAPI(stub)
+
+	// Set up some vars
+	ctx := context.TODO() // no SyncInterval is configured below, so no background goroutine depends on this ctx
+
+	// Create the LabelFilter client: "up" is synced dynamically but statically excluded, "staticinclude" is only static
+	cfg := &LabelFilterConfig{
+		DynamicLabels: []string{"__name__", "filterlabel"},
+		StaticLabelsInclude: map[string][]string{
+			"__name__": {"staticinclude"},
+		},
+		StaticLabelsExclude: map[string][]string{
+			"__name__": {"up"},
+		},
+	}
+
+	filterClient, err := NewLabelFilterClient(ctx, countAPI, cfg)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	tests := []struct {
+		query     string // query (a single selector) to run
+		callCount int    // how many downstream calls are expected (0 means the request was filtered)
+	}{
+		{query: "notametric"},                            // A metric name that isn't in the filter
+		{query: "testmetric", callCount: 1},              // A metric that was synced from the downstream
+		{query: "staticinclude", callCount: 1},           // A metric that only exists through the static include list
+		{query: "up"},                                    // A metric that was synced, but is statically excluded
+		{query: `{filterlabel="notavalue"}`},             // A filtered label with a value that isn't in the filter
+		{query: `{notalabel="notavalue"}`, callCount: 1}, // A label that isn't part of the filter, so it is never filtered
+		{query: `{filterlabel="a"}`, callCount: 1},       // A filtered label with a value that is in the filter
+		{query: `{filterlabel="b"}`, callCount: 1},       // A filtered label with a value that is in the filter
+	}
+
+	t.Run("Query", func(t *testing.T) {
+		for i, test := range tests {
+			t.Run(strconv.Itoa(i), func(t *testing.T) {
+				beforeCount := countAPI.callCount["Query"]
+				_, _, err := filterClient.Query(ctx, test.query, model.Time(100).Time())
+				if err != nil {
+					t.Fatal(err)
+				}
+				callCount := countAPI.callCount["Query"] - beforeCount
+				if test.callCount != callCount {
+					t.Fatalf("mismatch in callCount when running %s expected=%d actual=%d", test.query, test.callCount, callCount)
+				}
+			})
+		}
+	})
+
+	t.Run("QueryRange", func(t *testing.T) {
+		for i, test := range tests {
+			t.Run(strconv.Itoa(i), func(t *testing.T) {
+				beforeCount := countAPI.callCount["QueryRange"]
+				_, _, err := filterClient.QueryRange(ctx, test.query, v1.Range{Start: model.Time(0).Time(), End: model.Time(100).Time(), Step: time.Millisecond})
+				if err != nil {
+					t.Fatal(err)
+				}
+				callCount := countAPI.callCount["QueryRange"] - beforeCount
+				if test.callCount != callCount {
+					t.Fatalf("mismatch in callCount when running %s expected=%d actual=%d", test.query, test.callCount, callCount)
+				}
+			})
+		}
+	})
+
+	t.Run("Series", func(t *testing.T) {
+		for i, test := range tests {
+			t.Run(strconv.Itoa(i), func(t *testing.T) {
+				beforeCount := countAPI.callCount["Series"]
+				_, _, err := filterClient.Series(ctx, []string{test.query}, model.Time(0).Time(), model.Time(100).Time())
+				if err != nil {
+					t.Fatal(err)
+				}
+				callCount := countAPI.callCount["Series"] - beforeCount
+				if test.callCount != callCount {
+					t.Fatalf("mismatch in callCount when running %s expected=%d actual=%d", test.query, test.callCount, callCount)
+				}
+			})
+		}
+	})
+
+	t.Run("GetValue", func(t *testing.T) {
+		for i, test := range tests {
+			t.Run(strconv.Itoa(i), func(t *testing.T) {
+				beforeCount := countAPI.callCount["GetValue"]
+
+				// GetValue takes matchers rather than a query string, so parse the selector first
+				matchers, err := parser.ParseMetricSelector(test.query)
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				_, _, err = filterClient.GetValue(ctx, model.Time(0).Time(), model.Time(100).Time(), matchers)
+				if err != nil {
+					t.Fatal(err)
+				}
+				callCount := countAPI.callCount["GetValue"] - beforeCount
+				if test.callCount != callCount {
+					t.Fatalf("mismatch in callCount when running %s expected=%d actual=%d", test.query, test.callCount, callCount)
+				}
+			})
+		}
+	})
+
+	t.Run("Metadata", func(t *testing.T) {
+		tests := []struct {
+			metric    string // metric name to request metadata for
+			callCount int    // how many downstream calls are expected
+		}{
+			{metric: "notametric"},               // A metric name that isn't in the filter
+			{metric: "testmetric", callCount: 1}, // A metric that was synced from the downstream
+			{metric: "up"},                       // A metric that was synced, but is statically excluded
+		}
+
+		for i, test := range tests {
+			t.Run(strconv.Itoa(i), func(t *testing.T) {
+				beforeCount := countAPI.callCount["Metadata"]
+				_, err := filterClient.Metadata(ctx, test.metric, "")
+				if err != nil {
+					t.Fatal(err)
+				}
+				callCount := countAPI.callCount["Metadata"] - beforeCount
+				if test.callCount != callCount {
+					t.Fatalf("mismatch in callCount when running %s expected=%d actual=%d", test.metric, test.callCount, callCount)
+				}
+			})
+		}
+	})
+
+}
diff --git
a/pkg/promclient/multi_api_test.go b/pkg/promclient/multi_api_test.go index 8c8338673..7440fc5f8 100644 --- a/pkg/promclient/multi_api_test.go +++ b/pkg/promclient/multi_api_test.go @@ -14,7 +14,7 @@ import ( type stubAPI struct { labelNames func() []string - labelValues func() model.LabelValues + labelValues func(label string) model.LabelValues query func() model.Value queryRange func() model.Value series func() []model.LabelSet @@ -24,36 +24,57 @@ type stubAPI struct { // LabelNames returns all the unique label names present in the block in sorted order. func (s *stubAPI) LabelNames(ctx context.Context, matchers []string, startTime time.Time, endTime time.Time) ([]string, v1.Warnings, error) { + if s.labelNames == nil { + return nil, nil, nil + } return s.labelNames(), nil, nil } // LabelValues performs a query for the values of the given label. func (s *stubAPI) LabelValues(ctx context.Context, label string, matchers []string, startTime time.Time, endTime time.Time) (model.LabelValues, v1.Warnings, error) { - return s.labelValues(), nil, nil + if s.labelValues == nil { + return nil, nil, nil + } + return s.labelValues(label), nil, nil } // Query performs a query for the given time. func (s *stubAPI) Query(ctx context.Context, query string, ts time.Time) (model.Value, v1.Warnings, error) { + if s.query == nil { + return nil, nil, nil + } return s.query(), nil, nil } // QueryRange performs a query for the given range. func (s *stubAPI) QueryRange(ctx context.Context, query string, r v1.Range) (model.Value, v1.Warnings, error) { + if s.queryRange == nil { + return nil, nil, nil + } return s.queryRange(), nil, nil } // Series finds series by label matchers. 
func (s *stubAPI) Series(ctx context.Context, matches []string, startTime time.Time, endTime time.Time) ([]model.LabelSet, v1.Warnings, error) { + if s.series == nil { + return nil, nil, nil + } return s.series(), nil, nil } // GetValue loads the raw data for a given set of matchers in the time range func (s *stubAPI) GetValue(ctx context.Context, start, end time.Time, matchers []*labels.Matcher) (model.Value, v1.Warnings, error) { + if s.getValue == nil { + return nil, nil, nil + } return s.getValue(), nil, nil } // Metadata returns metadata about metrics currently scraped by the metric name. func (s *stubAPI) Metadata(ctx context.Context, metric, limit string) (map[string][]v1.Metadata, error) { + if s.metadata == nil { + return nil, nil + } return s.metadata(), nil } @@ -130,7 +151,7 @@ func TestMultiAPIMerging(t *testing.T) { return []string{"a"} }, - labelValues: func() model.LabelValues { + labelValues: func(_ string) model.LabelValues { return model.LabelValues{} }, query: func() model.Value { diff --git a/pkg/server/api.go b/pkg/server/api.go index 4b0f9e2cb..3fae3cf9c 100644 --- a/pkg/server/api.go +++ b/pkg/server/api.go @@ -3,6 +3,7 @@ package server import ( "crypto/tls" "io" + "net" "net/http" "os" "time" @@ -17,17 +18,21 @@ import ( func CreateAndStart(bindAddr string, logFormat string, webReadTimeout time.Duration, accessLogOut io.Writer, router http.Handler, tlsConfigFile string) (*http.Server, error) { handler := createHandler(accessLogOut, router, logFormat) + ln, err := net.Listen("tcp", bindAddr) + if err != nil { + return nil, err + } srv := &http.Server{ - Addr: bindAddr, + Addr: ln.Addr().String(), Handler: handler, ReadTimeout: webReadTimeout, } if tlsConfigFile == "" { - return createAndStartHTTP(srv) + return createAndStartHTTP(srv, ln) } - return createAndStartHTTPS(srv, tlsConfigFile) + return createAndStartHTTPS(srv, ln, tlsConfigFile) } func createHandler(accessLogOut io.Writer, router http.Handler, logFormat string) http.Handler { @@ 
-46,12 +51,12 @@ func createHandler(accessLogOut io.Writer, router http.Handler, logFormat string return handler } -func createAndStartHTTP(srv *http.Server) (*http.Server, error) { +func createAndStartHTTP(srv *http.Server, ln net.Listener) (*http.Server, error) { srv.TLSConfig = nil go func() { logrus.Infof("promxy starting with HTTP...") - if err := srv.ListenAndServe(); err != nil { + if err := srv.Serve(ln); err != nil { if err == http.ErrServerClosed { return } @@ -61,7 +66,7 @@ func createAndStartHTTP(srv *http.Server) (*http.Server, error) { return srv, nil } -func createAndStartHTTPS(srv *http.Server, tlsConfigFile string) (*http.Server, error) { +func createAndStartHTTPS(srv *http.Server, ln net.Listener, tlsConfigFile string) (*http.Server, error) { tlsConfig, err := parseConfigFile(tlsConfigFile) if err != nil { return nil, err @@ -71,7 +76,7 @@ func createAndStartHTTPS(srv *http.Server, tlsConfigFile string) (*http.Server, go func() { logrus.Infof("promxy starting with TLS...") - if err := srv.ListenAndServeTLS("", ""); err != nil { + if err := srv.ServeTLS(ln, "", ""); err != nil { if err == http.ErrServerClosed { return } diff --git a/pkg/servergroup/config.go b/pkg/servergroup/config.go index e5b86f48a..088c6c4f7 100644 --- a/pkg/servergroup/config.go +++ b/pkg/servergroup/config.go @@ -160,6 +160,32 @@ type Config struct { // An example use-case would be if a specific servergroup was was "deprecated" and wasn't getting // any new data after a specific given point in time AbsoluteTimeRangeConfig *AbsoluteTimeRangeConfig `yaml:"absolute_time_range"` + + // LabelFilterConfig is a mechanism to restrict which queries are sent to the particular downstream. + // This is done by maintaining a "filter" of labels that are downstream and ensuring that the + // matchers for any particular query match that in-memory filter. This can be defined both + // statically and dynamically. 
+ // NOTE: this is not a "secure" mechanism as it is relying on the query's matchers. So it is trivial + // for a malicious actor to work around this filter by changing matchers. + // Example: + // + // label_filter: + // # This will dynamically query the downstream for the values of `__name__` and `job` + // dynamic_labels: + // - __name__ + // - job + // # (optional) this will define a re-sync interval for dynamic labels from the downstream + // sync_interval: 5m + // # This will statically define a filter of labels + // static_labels_include: + // instance: + // - instance1 + // # This will statically define an exclusion list (removed from the filter)| + // static_labels_exclude: + // __name__: + // - up + + LabelFilterConfig *promclient.LabelFilterConfig `yaml:"label_filter"` } // GetScheme returns the scheme for this servergroup diff --git a/pkg/servergroup/servergroup.go b/pkg/servergroup/servergroup.go index 66a3db405..946e535c0 100644 --- a/pkg/servergroup/servergroup.go +++ b/pkg/servergroup/servergroup.go @@ -73,6 +73,9 @@ type ServerGroupState struct { // Targets is the list of target URLs for this discovery round Targets []string apiClient promclient.API + + ctx context.Context + ctxCancel context.CancelFunc } // ServerGroup encapsulates a set of prometheus downstreams to query/aggregate @@ -132,10 +135,17 @@ func (s *ServerGroup) Sync() { } } -func (s *ServerGroup) loadTargetGroupMap(targetGroupMap map[string][]*targetgroup.Group) error { +func (s *ServerGroup) loadTargetGroupMap(targetGroupMap map[string][]*targetgroup.Group) (err error) { targets := make([]string, 0) apiClients := make([]promclient.API, 0) + ctx, ctxCancel := context.WithCancel(context.Background()) + defer func() { + if err != nil { + ctxCancel() + } + }() + for _, targetGroupList := range targetGroupMap { for _, targetGroup := range targetGroupList { for _, target := range targetGroup.Targets { @@ -252,6 +262,14 @@ func (s *ServerGroup) loadTargetGroupMap(targetGroupMap 
map[string][]*targetgrou apiClient = &promclient.DebugAPI{apiClient, u.String()} } + // Add LabelFilter if configured + if s.Cfg.LabelFilterConfig != nil { + apiClient, err = promclient.NewLabelFilterClient(ctx, apiClient, s.Cfg.LabelFilterConfig) + if err != nil { + return err + } + } + apiClients = append(apiClients, apiClient) } } @@ -266,16 +284,23 @@ func (s *ServerGroup) loadTargetGroupMap(targetGroupMap map[string][]*targetgrou if err != nil { return err } + newState := &ServerGroupState{ Targets: targets, apiClient: apiClient, + ctx: ctx, + ctxCancel: ctxCancel, } if s.Cfg.IgnoreError { newState.apiClient = &promclient.IgnoreErrorAPI{newState.apiClient} } - s.state.Store(newState) + oldState := s.State() // Fetch the current state (so we can stop it) + s.state.Store(newState) // Store new state + if oldState != nil { + oldState.ctxCancel() // Cancel the old state + } if !s.loaded { s.loaded = true