From 39c350b93f1cc0a16fc4050a2db52d52fcf44083 Mon Sep 17 00:00:00 2001 From: auntan Date: Fri, 16 Sep 2022 18:51:31 +0500 Subject: [PATCH] fix: prevent metrics leak with cleanup Keep track of all component metrics and unregister them on close --- async_producer.go | 24 +++++++++------ broker.go | 62 ++++++++++++++------------------------ client_test.go | 24 +++++++++++++++ consumer.go | 17 ++++++----- consumer_group.go | 21 ++++++++----- metrics.go | 77 +++++++++++++++++++++++++++++++++++++++++++++++ produce_set.go | 2 +- 7 files changed, 161 insertions(+), 66 deletions(-) diff --git a/async_producer.go b/async_producer.go index 07141e045..7574665bc 100644 --- a/async_producer.go +++ b/async_producer.go @@ -10,6 +10,7 @@ import ( "github.com/eapache/go-resiliency/breaker" "github.com/eapache/queue" + "github.com/rcrowley/go-metrics" ) // AsyncProducer publishes Kafka messages using a non-blocking API. It routes messages @@ -122,6 +123,8 @@ type asyncProducer struct { brokerLock sync.Mutex txnmgr *transactionManager + + metricsRegistry metrics.Registry } // NewAsyncProducer creates a new AsyncProducer using the given broker addresses and configuration. 
@@ -154,15 +157,16 @@ func newAsyncProducer(client Client) (AsyncProducer, error) { } p := &asyncProducer{ - client: client, - conf: client.Config(), - errors: make(chan *ProducerError), - input: make(chan *ProducerMessage), - successes: make(chan *ProducerMessage), - retries: make(chan *ProducerMessage), - brokers: make(map[*Broker]*brokerProducer), - brokerRefs: make(map[*brokerProducer]int), - txnmgr: txnmgr, + client: client, + conf: client.Config(), + errors: make(chan *ProducerError), + input: make(chan *ProducerMessage), + successes: make(chan *ProducerMessage), + retries: make(chan *ProducerMessage), + brokers: make(map[*Broker]*brokerProducer), + brokerRefs: make(map[*brokerProducer]int), + txnmgr: txnmgr, + metricsRegistry: newCleanupRegistry(client.Config().MetricRegistry), } // launch our singleton dispatchers @@ -1134,6 +1138,8 @@ func (p *asyncProducer) shutdown() { close(p.retries) close(p.errors) close(p.successes) + + p.metricsRegistry.UnregisterAll() } func (p *asyncProducer) bumpIdempotentProducerEpoch() { diff --git a/broker.go b/broker.go index d857c7048..d1ac71deb 100644 --- a/broker.go +++ b/broker.go @@ -33,8 +33,7 @@ type Broker struct { responses chan *responsePromise done chan bool - registeredMetrics map[string]struct{} - + metricRegistry metrics.Registry incomingByteRate metrics.Meter requestRate metrics.Meter fetchRate metrics.Meter @@ -174,6 +173,8 @@ func (b *Broker) Open(conf *Config) error { b.lock.Lock() + b.metricRegistry = newCleanupRegistry(conf.MetricRegistry) + go withRecover(func() { defer func() { b.lock.Unlock() @@ -208,15 +209,15 @@ func (b *Broker) Open(conf *Config) error { b.conf = conf // Create or reuse the global metrics shared between brokers - b.incomingByteRate = metrics.GetOrRegisterMeter("incoming-byte-rate", conf.MetricRegistry) - b.requestRate = metrics.GetOrRegisterMeter("request-rate", conf.MetricRegistry) - b.fetchRate = metrics.GetOrRegisterMeter("consumer-fetch-rate", conf.MetricRegistry) - b.requestSize 
= getOrRegisterHistogram("request-size", conf.MetricRegistry) - b.requestLatency = getOrRegisterHistogram("request-latency-in-ms", conf.MetricRegistry) - b.outgoingByteRate = metrics.GetOrRegisterMeter("outgoing-byte-rate", conf.MetricRegistry) - b.responseRate = metrics.GetOrRegisterMeter("response-rate", conf.MetricRegistry) - b.responseSize = getOrRegisterHistogram("response-size", conf.MetricRegistry) - b.requestsInFlight = metrics.GetOrRegisterCounter("requests-in-flight", conf.MetricRegistry) + b.incomingByteRate = metrics.GetOrRegisterMeter("incoming-byte-rate", b.metricRegistry) + b.requestRate = metrics.GetOrRegisterMeter("request-rate", b.metricRegistry) + b.fetchRate = metrics.GetOrRegisterMeter("consumer-fetch-rate", b.metricRegistry) + b.requestSize = getOrRegisterHistogram("request-size", b.metricRegistry) + b.requestLatency = getOrRegisterHistogram("request-latency-in-ms", b.metricRegistry) + b.outgoingByteRate = metrics.GetOrRegisterMeter("outgoing-byte-rate", b.metricRegistry) + b.responseRate = metrics.GetOrRegisterMeter("response-rate", b.metricRegistry) + b.responseSize = getOrRegisterHistogram("response-size", b.metricRegistry) + b.requestsInFlight = metrics.GetOrRegisterCounter("requests-in-flight", b.metricRegistry) // Do not gather metrics for seeded broker (only used during bootstrap) because they share // the same id (-1) and are already exposed through the global metrics above if b.id >= 0 && !metrics.UseNilMetrics { @@ -319,7 +320,7 @@ func (b *Broker) Close() error { b.done = nil b.responses = nil - b.unregisterMetrics() + b.metricRegistry.UnregisterAll() if err == nil { DebugLogger.Printf("Closed connection to broker %s\n", b.addr) @@ -435,7 +436,7 @@ func (b *Broker) AsyncProduce(request *ProduceRequest, cb ProduceCallback) error return } - if err := versionedDecode(packets, res, request.version(), b.conf.MetricRegistry); err != nil { + if err := versionedDecode(packets, res, request.version(), b.metricRegistry); err != nil { // 
Malformed response cb(nil, err) return @@ -979,7 +980,7 @@ func (b *Broker) sendInternal(rb protocolBody, promise *responsePromise) error { } req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb} - buf, err := encode(req, b.conf.MetricRegistry) + buf, err := encode(req, b.metricRegistry) if err != nil { return err } @@ -1029,7 +1030,7 @@ func (b *Broker) sendAndReceive(req protocolBody, res protocolBody) error { func (b *Broker) handleResponsePromise(req protocolBody, res protocolBody, promise *responsePromise) error { select { case buf := <-promise.packets: - return versionedDecode(buf, res, req.version(), b.conf.MetricRegistry) + return versionedDecode(buf, res, req.version(), b.metricRegistry) case err := <-promise.errors: return err } @@ -1121,7 +1122,7 @@ func (b *Broker) responseReceiver() { } decodedHeader := responseHeader{} - err = versionedDecode(header, &decodedHeader, response.headerVersion, b.conf.MetricRegistry) + err = versionedDecode(header, &decodedHeader, response.headerVersion, b.metricRegistry) if err != nil { b.updateIncomingCommunicationMetrics(bytesReadHeader, requestLatency) dead = err @@ -1243,7 +1244,7 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int rb := &SaslHandshakeRequest{Mechanism: string(saslType), Version: version} req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb} - buf, err := encode(req, b.conf.MetricRegistry) + buf, err := encode(req, b.metricRegistry) if err != nil { return err } @@ -1280,7 +1281,7 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int b.updateIncomingCommunicationMetrics(n+8, time.Since(requestTime)) res := &SaslHandshakeResponse{} - err = versionedDecode(payload, res, 0, b.conf.MetricRegistry) + err = versionedDecode(payload, res, 0, b.metricRegistry) if err != nil { Logger.Printf("Failed to parse SASL handshake : %s\n", err.Error()) return err @@ -1622,38 +1623,19 @@ func (b 
*Broker) registerMetrics() { b.brokerThrottleTime = b.registerHistogram("throttle-time-in-ms") } -func (b *Broker) unregisterMetrics() { - for name := range b.registeredMetrics { - b.conf.MetricRegistry.Unregister(name) - } - b.registeredMetrics = nil -} - func (b *Broker) registerMeter(name string) metrics.Meter { nameForBroker := getMetricNameForBroker(name, b) - if b.registeredMetrics == nil { - b.registeredMetrics = map[string]struct{}{} - } - b.registeredMetrics[nameForBroker] = struct{}{} - return metrics.GetOrRegisterMeter(nameForBroker, b.conf.MetricRegistry) + return metrics.GetOrRegisterMeter(nameForBroker, b.metricRegistry) } func (b *Broker) registerHistogram(name string) metrics.Histogram { nameForBroker := getMetricNameForBroker(name, b) - if b.registeredMetrics == nil { - b.registeredMetrics = map[string]struct{}{} - } - b.registeredMetrics[nameForBroker] = struct{}{} - return getOrRegisterHistogram(nameForBroker, b.conf.MetricRegistry) + return getOrRegisterHistogram(nameForBroker, b.metricRegistry) } func (b *Broker) registerCounter(name string) metrics.Counter { nameForBroker := getMetricNameForBroker(name, b) - if b.registeredMetrics == nil { - b.registeredMetrics = map[string]struct{}{} - } - b.registeredMetrics[nameForBroker] = struct{}{} - return metrics.GetOrRegisterCounter(nameForBroker, b.conf.MetricRegistry) + return metrics.GetOrRegisterCounter(nameForBroker, b.metricRegistry) } func validServerNameTLS(addr string, cfg *tls.Config) *tls.Config { diff --git a/client_test.go b/client_test.go index 79f1e85dd..c0b1d1ab3 100644 --- a/client_test.go +++ b/client_test.go @@ -8,6 +8,8 @@ import ( "syscall" "testing" "time" + + "github.com/rcrowley/go-metrics" ) func safeClose(t testing.TB, c io.Closer) { @@ -1096,3 +1098,25 @@ func TestInitProducerIDConnectionRefused(t *testing.T) { safeClose(t, client) } + +func TestMetricsCleanup(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + seedBroker.Returns(new(MetadataResponse)) + + config := 
NewTestConfig() + metrics.GetOrRegisterMeter("a", config.MetricRegistry) + + client, err := NewClient([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + safeClose(t, client) + + // Wait async close + time.Sleep(10 * time.Millisecond) + + all := config.MetricRegistry.GetAll() + if len(all) != 1 || all["a"] == nil { + t.Errorf("excepted 1 metric, found: %v", all) + } +} diff --git a/consumer.go b/consumer.go index 0a9a6c31b..216343c73 100644 --- a/consumer.go +++ b/consumer.go @@ -104,6 +104,7 @@ type consumer struct { children map[string]map[int32]*partitionConsumer brokerConsumers map[*Broker]*brokerConsumer client Client + metricRegistry metrics.Registry lock sync.Mutex } @@ -136,12 +137,14 @@ func newConsumer(client Client) (Consumer, error) { conf: client.Config(), children: make(map[string]map[int32]*partitionConsumer), brokerConsumers: make(map[*Broker]*brokerConsumer), + metricRegistry: newCleanupRegistry(client.Config().MetricRegistry), } return c, nil } func (c *consumer) Close() error { + c.metricRegistry.UnregisterAll() return c.client.Close() } @@ -678,13 +681,9 @@ func (child *partitionConsumer) parseRecords(batch *RecordBatch) ([]*ConsumerMes } func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*ConsumerMessage, error) { - var ( - metricRegistry = child.conf.MetricRegistry - consumerBatchSizeMetric metrics.Histogram - ) - - if metricRegistry != nil { - consumerBatchSizeMetric = getOrRegisterHistogram("consumer-batch-size", metricRegistry) + var consumerBatchSizeMetric metrics.Histogram + if child.consumer != nil && child.consumer.metricRegistry != nil { + consumerBatchSizeMetric = getOrRegisterHistogram("consumer-batch-size", child.consumer.metricRegistry) } // If request was throttled and empty we log and return without error @@ -709,7 +708,9 @@ func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*Consu return nil, err } - consumerBatchSizeMetric.Update(int64(nRecs)) + if 
consumerBatchSizeMetric != nil { + consumerBatchSizeMetric.Update(int64(nRecs)) + } if block.PreferredReadReplica != invalidPreferredReplicaID { child.preferredReadReplica = block.PreferredReadReplica diff --git a/consumer_group.go b/consumer_group.go index b20edd978..f0a8333e7 100644 --- a/consumer_group.go +++ b/consumer_group.go @@ -91,6 +91,8 @@ type consumerGroup struct { closeOnce sync.Once userData []byte + + metricRegistry metrics.Registry } // NewConsumerGroup creates a new consumer group the given broker addresses and configuration. @@ -129,13 +131,14 @@ func newConsumerGroup(groupID string, client Client) (ConsumerGroup, error) { } cg := &consumerGroup{ - client: client, - consumer: consumer, - config: config, - groupID: groupID, - errors: make(chan error, config.ChannelBufferSize), - closed: make(chan none), - userData: config.Consumer.Group.Member.UserData, + client: client, + consumer: consumer, + config: config, + groupID: groupID, + errors: make(chan error, config.ChannelBufferSize), + closed: make(chan none), + userData: config.Consumer.Group.Member.UserData, + metricRegistry: newCleanupRegistry(config.MetricRegistry), } if client.Config().Consumer.Group.InstanceId != "" && config.Version.IsAtLeast(V2_3_0_0) { cg.groupInstanceId = &client.Config().Consumer.Group.InstanceId @@ -167,6 +170,8 @@ func (c *consumerGroup) Close() (err error) { if e := c.client.Close(); e != nil { err = e } + + c.metricRegistry.UnregisterAll() }) return } @@ -261,7 +266,7 @@ func (c *consumerGroup) newSession(ctx context.Context, topics []string, handler } var ( - metricRegistry = c.config.MetricRegistry + metricRegistry = c.metricRegistry consumerGroupJoinTotal metrics.Counter consumerGroupJoinFailed metrics.Counter consumerGroupSyncTotal metrics.Counter diff --git a/metrics.go b/metrics.go index 90e5a87f4..7b7705f2e 100644 --- a/metrics.go +++ b/metrics.go @@ -3,6 +3,7 @@ package sarama import ( "fmt" "strings" + "sync" "github.com/rcrowley/go-metrics" ) @@ -41,3 +42,79 
@@ func getOrRegisterTopicMeter(name string, topic string, r metrics.Registry) metr func getOrRegisterTopicHistogram(name string, topic string, r metrics.Registry) metrics.Histogram { return getOrRegisterHistogram(getMetricNameForTopic(name, topic), r) } + +// cleanupRegistry is an implementation of metrics.Registry that allows +// to unregister from the parent registry only those metrics +// that have been registered in cleanupRegistry +type cleanupRegistry struct { + parent metrics.Registry + metrics map[string]struct{} + mutex sync.RWMutex +} + +func newCleanupRegistry(parent metrics.Registry) metrics.Registry { + return &cleanupRegistry{ + parent: parent, + metrics: map[string]struct{}{}, + } +} + +func (r *cleanupRegistry) Each(fn func(string, interface{})) { + r.mutex.RLock() + defer r.mutex.RUnlock() + wrappedFn := func(name string, iface interface{}) { + if _, ok := r.metrics[name]; ok { + fn(name, iface) + } + } + r.parent.Each(wrappedFn) +} + +func (r *cleanupRegistry) Get(name string) interface{} { + r.mutex.RLock() + defer r.mutex.RUnlock() + if _, ok := r.metrics[name]; ok { + return r.parent.Get(name) + } + return nil +} + +func (r *cleanupRegistry) GetOrRegister(name string, metric interface{}) interface{} { + r.mutex.Lock() + defer r.mutex.Unlock() + r.metrics[name] = struct{}{} + return r.parent.GetOrRegister(name, metric) +} + +func (r *cleanupRegistry) Register(name string, metric interface{}) error { + r.mutex.Lock() + defer r.mutex.Unlock() + r.metrics[name] = struct{}{} + return r.parent.Register(name, metric) +} + +func (r *cleanupRegistry) RunHealthchecks() { + r.parent.RunHealthchecks() +} + +func (r *cleanupRegistry) GetAll() map[string]map[string]interface{} { + return r.parent.GetAll() +} + +func (r *cleanupRegistry) Unregister(name string) { + r.mutex.Lock() + defer r.mutex.Unlock() + if _, ok := r.metrics[name]; ok { + delete(r.metrics, name) + r.parent.Unregister(name) + } +} + +func (r *cleanupRegistry) UnregisterAll() { + 
r.mutex.Lock() + defer r.mutex.Unlock() + for name := range r.metrics { + delete(r.metrics, name) + r.parent.Unregister(name) + } +} diff --git a/produce_set.go b/produce_set.go index c3ba78f89..b8cc5ceb7 100644 --- a/produce_set.go +++ b/produce_set.go @@ -181,7 +181,7 @@ func (ps *produceSet) buildRequest() *ProduceRequest { msg.Offset = int64(i) } } - payload, err := encode(set.recordsToSend.MsgSet, ps.parent.conf.MetricRegistry) + payload, err := encode(set.recordsToSend.MsgSet, ps.parent.metricsRegistry) if err != nil { Logger.Println(err) // if this happens, it's basically our fault. panic(err)