From 1305217ace3ea14caa32600a7a30076e6996a612 Mon Sep 17 00:00:00 2001 From: hengyoush Date: Tue, 28 Jan 2025 16:03:54 +0800 Subject: [PATCH 1/2] fix: fix memory leak fix: fix test_filter_by_remote_port test --- agent/conn/kern_event_handler.go | 48 ++++---- agent/conn/processor.go | 89 ++++++++------- common/ringbuffer.go | 181 +++++++++++++++++++++++++++++++ testdata/test_filter_by_l4.sh | 8 +- 4 files changed, 255 insertions(+), 71 deletions(-) create mode 100644 common/ringbuffer.go diff --git a/agent/conn/kern_event_handler.go b/agent/conn/kern_event_handler.go index b62dbab9..555f07de 100644 --- a/agent/conn/kern_event_handler.go +++ b/agent/conn/kern_event_handler.go @@ -15,7 +15,7 @@ import ( type KernEventStream struct { conn *Connection4 - kernEvents map[bpf.AgentStepT][]KernEvent + kernEvents map[bpf.AgentStepT]*common.RingBuffer kernEventsMu sync.RWMutex sslInEvents []SslEvent sslOutEvents []SslEvent @@ -32,7 +32,7 @@ type KernEventStream struct { func NewKernEventStream(conn *Connection4, maxLen int) *KernEventStream { stream := &KernEventStream{ conn: conn, - kernEvents: make(map[bpf.AgentStepT][]KernEvent), + kernEvents: make(map[bpf.AgentStepT]*common.RingBuffer), maxLen: maxLen, } monitor.RegisterMetricExporter(stream) @@ -93,18 +93,19 @@ func (s *KernEventStream) AddKernEvent(event *bpf.AgentKernEvt) bool { s.discardEventsIfNeeded() if event.Len > 0 { if _, ok := s.kernEvents[event.Step]; !ok { - s.kernEvents[event.Step] = make([]KernEvent, 0) + s.kernEvents[event.Step] = common.NewRingBuffer(s.maxLen) } - kernEvtSlice := s.kernEvents[event.Step] - index, found := slices.BinarySearchFunc(kernEvtSlice, KernEvent{seq: event.Seq}, func(i KernEvent, j KernEvent) int { - return cmp.Compare(i.seq, j.seq) + kernEvtRingBuffer := s.kernEvents[event.Step] + index, found := kernEvtRingBuffer.BinarySearch(KernEvent{seq: event.Seq}, func(i any, j any) int { + return cmp.Compare(i.(KernEvent).seq, j.(KernEvent).seq) }) isNicEvnt := event.Step == bpf.AgentStepTDEV_OUT || event.Step == bpf.AgentStepTDEV_IN var kernEvent *KernEvent if found { - oldKernEvent := &kernEvtSlice[index] + _oldKernEvent, _ := kernEvtRingBuffer.ReadIndex(index) + oldKernEvent := _oldKernEvent.(KernEvent) if oldKernEvent.startTs > event.Ts && !isNicEvnt { // this is a duplicate event which belongs to a future conn oldKernEvent.seq = event.Seq @@ -112,12 +113,11 @@ func (s *KernEventStream) AddKernEvent(event *bpf.AgentKernEvt) bool { oldKernEvent.startTs = event.Ts oldKernEvent.tsDelta = event.TsDelta oldKernEvent.step = event.Step - kernEvent = oldKernEvent + kernEvent = &oldKernEvent } else if !isNicEvnt { - kernEvent = &kernEvtSlice[index] return false } else { - kernEvent = &kernEvtSlice[index] + kernEvent = &oldKernEvent } } else { kernEvent = &KernEvent{ @@ -143,17 +143,11 @@ func (s *KernEventStream) AddKernEvent(event *bpf.AgentKernEvt) bool { } } if !found { - kernEvtSlice = slices.Insert(kernEvtSlice, index, *kernEvent) - } - if len(kernEvtSlice) > s.maxLen { - if common.ConntrackLog.Level >= logrus.DebugLevel { - common.ConntrackLog.Debugf("kern event stream size: %d exceed maxLen", len(kernEvtSlice)) + if err := kernEvtRingBuffer.Insert(index, *kernEvent); err != nil { + common.ConntrackLog.Debugf("kern event stream size: %d exceed maxLen", kernEvtRingBuffer.MaxCapacity()) + return false } } - for len(kernEvtSlice) > s.maxLen { - kernEvtSlice = kernEvtSlice[1:] - } - s.kernEvents[event.Step] = kernEvtSlice } return true } @@ -200,7 +194,8 @@ func (s *KernEventStream) FindEventsBySeqAndLen(step bpf.AgentStepT, seq uint64, start := seq end := start + uint64(len) result := make([]KernEvent, 0) - for _, each := range events { + events.ForEach(func(i any) bool { + each := i.(KernEvent) if each.seq <= start && each.seq+uint64(each.len) > start { result = append(result, each) } else if each.seq < end && each.seq+uint64(each.len) >= end { @@ -208,9 +203,10 @@ func (s *KernEventStream) FindEventsBySeqAndLen(step bpf.AgentStepT, seq uint64, } else if each.seq >= start && each.seq+uint64(each.len) <= end { result = append(result, each) } else if each.seq > end { - break + return false } - } + return true + }) return result } @@ -272,12 +268,12 @@ func (s *KernEventStream) discardEventsBySeq(seq uint64, egress bool) { if !egress && !bpf.IsIngressStep(step) { continue } - index, _ := slices.BinarySearchFunc(events, KernEvent{seq: seq}, func(i KernEvent, j KernEvent) int { - return cmp.Compare(i.seq, j.seq) + index, _ := events.BinarySearch(KernEvent{seq: seq}, func(i any, j any) int { + return cmp.Compare(i.(KernEvent).seq, j.(KernEvent).seq) }) discardIdx := index if discardIdx > 0 { - s.kernEvents[step] = events[discardIdx:] + events.DiscardBeforeIndex(discardIdx) // common.ConntrackLog.Debugf("Discarded kern events, step: %d(egress: %v) events num: %d, cur len: %d", step, egress, discardIdx, len(s.kernEvents[step])) } } @@ -407,7 +403,7 @@ var _ monitor.MetricExporter = &KernEventStream{} func (s *KernEventStream) ExportMetrics() monitor.MetricMap { allEventsNum := 0 for _, events := range s.kernEvents { - allEventsNum += len(events) + allEventsNum += events.Size() } return monitor.MetricMap{ "events_num": float64(allEventsNum), diff --git a/agent/conn/processor.go b/agent/conn/processor.go index 9af95854..df26efd4 100644 --- a/agent/conn/processor.go +++ b/agent/conn/processor.go @@ -103,10 +103,10 @@ type Processor struct { side common.SideEnum recordProcessor *RecordsProcessor conntrackCloseWaitTimeMills int - tempKernEvents []TimedEvent - tempSyscallEvents []TimedSyscallEvent - tempSslEvents []TimedSslEvent - tempFirstPacketEvents []TimedFirstPacketEvent + tempKernEvents *common.RingBuffer + tempSyscallEvents *common.RingBuffer + tempSslEvents *common.RingBuffer + tempFirstPacketEvents *common.RingBuffer } type TimedEvent struct { @@ -149,10 +149,10 @@ func initProcessor(name string, wg *sync.WaitGroup, ctx context.Context, connMan records: make([]RecordWithConn, 0), } p.conntrackCloseWaitTimeMills = conntrackCloseWaitTimeMills - p.tempKernEvents = make([]TimedEvent, 0, 100) // Preallocate with a capacity of 100 - p.tempSyscallEvents = make([]TimedSyscallEvent, 0, 100) // Preallocate with a capacity of 100 - p.tempFirstPacketEvents = make([]TimedFirstPacketEvent, 0, 100) - p.tempSslEvents = make([]TimedSslEvent, 0, 100) // Preallocate with a capacity of 100 + p.tempKernEvents = common.NewRingBuffer(1000) // Preallocate with a capacity of 100 + p.tempSyscallEvents = common.NewRingBuffer(1000) // Preallocate with a capacity of 100 + p.tempFirstPacketEvents = common.NewRingBuffer(100) + p.tempSslEvents = common.NewRingBuffer(100) // Preallocate with a capacity of 100 return p } @@ -319,7 +319,7 @@ func (p *Processor) run() { func (p *Processor) handleFirstPacketEvent(event *agentKernEvtWithConn, recordChannel chan RecordWithConn) { // Add event to the temporary queue - p.tempFirstPacketEvents = append(p.tempFirstPacketEvents, TimedFirstPacketEvent{event: event, timestamp: time.Now()}) + p.tempFirstPacketEvents.Write(TimedFirstPacketEvent{event: event, timestamp: time.Now()}) // Process events in the queue that have been there for more than 100ms p.processOldFirstPacketEvents(recordChannel) } @@ -330,16 +330,17 @@ func (p *Processor) processTimedFirstPacketEvents(recordChannel chan RecordWithC func (p *Processor) processOldFirstPacketEvents(recordChannel chan RecordWithConn) { now := time.Now() - lastIndex := 0 - for i := 0; i < len(p.tempFirstPacketEvents); i++ { - if now.Sub(p.tempFirstPacketEvents[i].timestamp) > 100*time.Millisecond { - p.processFirstPacketEvent(p.tempFirstPacketEvents[i].event, recordChannel) - lastIndex = i + 1 - } else { + for !p.tempFirstPacketEvents.IsEmpty() { + _event, err := p.tempFirstPacketEvents.Peek() + if err != nil { break } + event := _event.(TimedFirstPacketEvent) + if now.Sub(event.timestamp) > 100*time.Millisecond { + p.processFirstPacketEvent(event.event, recordChannel) + p.tempFirstPacketEvents.Read() + } } - p.tempFirstPacketEvents = p.tempFirstPacketEvents[lastIndex:] } func (p *Processor) processFirstPacketEvent(event *agentKernEvtWithConn, recordChannel chan RecordWithConn) { @@ -351,7 +352,7 @@ func (p *Processor) processFirstPacketEvent(event *agentKernEvtWithConn, recordC func (p *Processor) handleKernEvent(event *bpf.AgentKernEvt, recordChannel chan RecordWithConn) { // Add event to the temporary queue - p.tempKernEvents = append(p.tempKernEvents, TimedEvent{event: event, timestamp: time.Now()}) + p.tempKernEvents.Write(TimedEvent{event: event, timestamp: time.Now()}) // Process events in the queue that have been there for more than 100ms p.processOldKernEvents(recordChannel) @@ -363,16 +364,19 @@ func (p *Processor) processTimedKernEvents(recordChannel chan RecordWithConn) { func (p *Processor) processOldKernEvents(recordChannel chan RecordWithConn) { now := time.Now() - lastIndex := 0 - for i := 0; i < len(p.tempKernEvents); i++ { - if now.Sub(p.tempKernEvents[i].timestamp) > 100*time.Millisecond { - p.processKernEvent(p.tempKernEvents[i].event, recordChannel) - lastIndex = i + 1 + for !p.tempKernEvents.IsEmpty() { + _event, err := p.tempKernEvents.Peek() + if err != nil { + break + } + event := _event.(TimedEvent) + if now.Sub(event.timestamp) > 100*time.Millisecond { + p.processKernEvent(event.event, recordChannel) + p.tempKernEvents.Read() } else { break } } - p.tempKernEvents = p.tempKernEvents[lastIndex:] } func (p *Processor) processKernEvent(event *bpf.AgentKernEvt, recordChannel chan RecordWithConn) { @@ -426,7 +430,7 @@ func (p *Processor) processKernEvent(event *bpf.AgentKernEvt, recordChannel chan func (p *Processor) handleSyscallEvent(event *bpf.SyscallEventData, recordChannel chan RecordWithConn) { // Add event to the temporary queue - p.tempSyscallEvents = append(p.tempSyscallEvents, TimedSyscallEvent{event: event, timestamp: time.Now()}) + p.tempSyscallEvents.Write(TimedSyscallEvent{event: event, timestamp: time.Now()}) // Process events in the queue that have been there for more than 100ms p.processOldSyscallEvents(recordChannel) @@ -439,16 +443,19 @@ func (p *Processor) processTimedSyscallEvents(recordChannel chan RecordWithConn) func (p *Processor) processOldSyscallEvents(recordChannel chan RecordWithConn) { now := time.Now() - lastIndex := 0 - for i := 0; i < len(p.tempSyscallEvents); i++ { - if now.Sub(p.tempSyscallEvents[i].timestamp) > 100*time.Millisecond { - p.processSyscallEvent(p.tempSyscallEvents[i].event, recordChannel) - lastIndex = i + 1 + for !p.tempSyscallEvents.IsEmpty() { + _event, err := p.tempSyscallEvents.Peek() + if err != nil { + break + } + event := _event.(TimedSyscallEvent) + if now.Sub(event.timestamp) > 100*time.Millisecond { + p.processSyscallEvent(event.event, recordChannel) + p.tempSyscallEvents.Read() } else { break } } - p.tempSyscallEvents = p.tempSyscallEvents[lastIndex:] } func (p *Processor) processSyscallEvent(event *bpf.SyscallEventData, recordChannel chan RecordWithConn) { @@ -475,10 +482,7 @@ func (p *Processor) processSyscallEvent(event *bpf.SyscallEventData, recordChann common.BPFEventLog.Debugf("[syscall][len=%d][ts=%d][fn=%d]%s | %s", max(event.SyscallEvent.BufSize, event.SyscallEvent.Ke.Len), event.SyscallEvent.Ke.Ts, event.SyscallEvent.GetSourceFunction(), conn.ToString(), string(event.Buf)) } - addedToBuffer := conn.OnSyscallEvent(event.Buf, event, recordChannel) - if addedToBuffer { - conn.AddSyscallEvent(event) - } + conn.OnSyscallEvent(event.Buf, event, recordChannel) } else if conn.Protocol == bpf.AgentTrafficProtocolTKProtocolUnset { conn.AddSyscallEvent(event) if common.BPFEventLog.Level >= logrus.DebugLevel { @@ -498,7 +502,7 @@ func (p *Processor) processSyscallEvent(event *bpf.SyscallEventData, recordChann func (p *Processor) handleSslEvent(event *bpf.SslData, recordChannel chan RecordWithConn) { // Add event to the temporary queue - p.tempSslEvents = append(p.tempSslEvents, TimedSslEvent{event: event, timestamp: time.Now()}) + p.tempSslEvents.Write(TimedSslEvent{event: event, timestamp: time.Now()}) // Process events in the queue that have been there for more than 100ms p.processOldSslEvents(recordChannel) @@ -510,16 +514,19 @@ func (p *Processor) processTimedSslEvents(recordChannel chan RecordWithConn) { func (p *Processor) processOldSslEvents(recordChannel chan RecordWithConn) { now := time.Now() - lastIndex := 0 - for i := 0; i < len(p.tempSslEvents); i++ { - if now.Sub(p.tempSslEvents[i].timestamp) > 100*time.Millisecond { - p.processSslEvent(p.tempSslEvents[i].event, recordChannel) - lastIndex = i + 1 + for !p.tempSslEvents.IsEmpty() { + _event, err := p.tempSslEvents.Peek() + if err != nil { + break + } + event := _event.(TimedSslEvent) + if now.Sub(event.timestamp) > 100*time.Millisecond { + p.processSslEvent(event.event, recordChannel) + p.tempSslEvents.Read() } else { break } } - p.tempSslEvents = p.tempSslEvents[lastIndex:] } func (p *Processor) processSslEvent(event *bpf.SslData, recordChannel chan RecordWithConn) { diff --git a/common/ringbuffer.go b/common/ringbuffer.go new file mode 100644 index 00000000..ee115d8b --- /dev/null +++ b/common/ringbuffer.go @@ -0,0 +1,181 @@ +package common + +import "errors" + +var ( + ErrRingBufferFull = errors.New("ring buffer is full") + ErrRingBufferEmpty = errors.New("ring buffer is empty") +) + +// RingBuffer represents a ring buffer. +type RingBuffer struct { + data []any + size int + start, end int + isFull bool +} + +// NewRingBuffer creates a new ring buffer with the given size. +func NewRingBuffer(size int) *RingBuffer { + return &RingBuffer{ + data: make([]any, size), + size: size, + } +} + +// Write adds an element to the ring buffer. +func (rb *RingBuffer) Write(value any) error { + if rb.isFull { + return ErrRingBufferFull + } + rb.data[rb.end] = value + rb.end = (rb.end + 1) % rb.size + if rb.end == rb.start { + rb.isFull = true + } + return nil +} + +// Read removes and returns the oldest element from the ring buffer. +func (rb *RingBuffer) Read() (any, error) { + if rb.IsEmpty() { + return nil, ErrRingBufferEmpty + } + value := rb.data[rb.start] + rb.data[rb.start] = nil + rb.start = (rb.start + 1) % rb.size + rb.isFull = false + return value, nil +} + +// IsEmpty checks if the ring buffer is empty. +func (rb *RingBuffer) IsEmpty() bool { + return !rb.isFull && rb.start == rb.end +} + +// IsFull checks if the ring buffer is full. +func (rb *RingBuffer) IsFull() bool { + return rb.isFull +} + +// Size returns the number of elements in the ring buffer. +func (rb *RingBuffer) Size() int { + if rb.isFull { + return rb.size + } + if rb.end >= rb.start { + return rb.end - rb.start + } + return rb.size - rb.start + rb.end +} + +// Peek returns the oldest element without removing it from the ring buffer. +func (rb *RingBuffer) Peek() (any, error) { + if rb.IsEmpty() { + return nil, ErrRingBufferEmpty + } + return rb.data[rb.start], nil +} + +// ReadIndex retrieves the value at the specified index without removing it. +func (rb *RingBuffer) ReadIndex(index int) (any, error) { + if index < 0 || index >= rb.Size() { + return nil, errors.New("index out of range") + } + actualIndex := (rb.start + index) % rb.size + return rb.data[actualIndex], nil +} + +// Insert adds an element at the specified index in the ring buffer. +func (rb *RingBuffer) Insert(index int, value any) error { + if index < 0 || index > rb.Size() { + return errors.New("index out of range") + } + if rb.isFull { + return ErrRingBufferFull + } + + // Calculate the actual index in the underlying array + actualIndex := (rb.start + index) % rb.size + + // Shift elements to the right to make space for the new element + for i := rb.Size(); i > index; i-- { + rb.data[(rb.start+i)%rb.size] = rb.data[(rb.start+i-1)%rb.size] + } + + // Insert the new element + rb.data[actualIndex] = value + rb.end = (rb.end + 1) % rb.size + if rb.end == rb.start { + rb.isFull = true + } + return nil +} + +// BinarySearch performs a binary search on the ring buffer. +// It assumes that the buffer is sorted. +func (rb *RingBuffer) BinarySearch(target any, compare func(a, b any) int) (int, bool) { + if rb.IsEmpty() { + return 0, false + } + + low, high := 0, rb.Size()-1 + for low <= high { + mid := (low + high) / 2 + midValue := rb.data[(rb.start+mid)%rb.size] + comp := compare(midValue, target) + if comp == 0 { + return mid, true + } else if comp < 0 { + low = mid + 1 + } else { + high = mid - 1 + } + } + return low, false +} + +// MaxCapacity returns the maximum capacity of the ring buffer. +func (rb *RingBuffer) MaxCapacity() int { + return rb.size +} + +// Clear removes all elements from the ring buffer. +func (rb *RingBuffer) Clear() { + rb.data = make([]any, rb.size) + rb.start = 0 + rb.end = 0 + rb.isFull = false +} + +// ForEach iterates over all elements in the ring buffer and applies the given function. +// If the function returns false, the iteration stops. +func (rb *RingBuffer) ForEach(action func(any) bool) { + if rb.IsEmpty() { + return + } + for i := 0; i < rb.Size(); i++ { + index := (rb.start + i) % rb.size + if !action(rb.data[index]) { + break + } + } +} + +// DiscardBeforeIndex discards all elements before the specified index. +func (rb *RingBuffer) DiscardBeforeIndex(index int) error { + if index < 0 || index >= rb.Size() { + return errors.New("index out of range") + } + + // Calculate the actual index in the underlying array + actualIndex := (rb.start + index) % rb.size + + // Discard elements + for rb.start != actualIndex { + rb.data[rb.start] = nil + rb.start = (rb.start + 1) % rb.size + rb.isFull = false + } + return nil +} diff --git a/testdata/test_filter_by_l4.sh b/testdata/test_filter_by_l4.sh index 2694d578..dd03e53c 100755 --- a/testdata/test_filter_by_l4.sh +++ b/testdata/test_filter_by_l4.sh @@ -24,22 +24,22 @@ function test_filter_by_remote_port() { remote_port=88 timeout 20 ${CMD} watch --debug-output http --remote-ports "$remote_port" 2>&1 | tee "${LNAME_REMOTE_PORT}" & sleep 10 - curl http://example.com &>/dev/null || true + curl http://baidu.com &>/dev/null || true wait cat "${LNAME_REMOTE_PORT}" - if cat "${LNAME_REMOTE_PORT}" | grep "example.com"; then + if cat "${LNAME_REMOTE_PORT}" | grep "baidu.com"; then exit 1 fi remote_port=80 timeout 20 ${CMD} watch --debug-output http --remote-ports "$remote_port" 2>&1 | tee "${LNAME_REMOTE_PORT}" & sleep 10 - curl http://example.com &>/dev/null || true + curl http://baidu.com &>/dev/null || true wait cat "${LNAME_REMOTE_PORT}" - if ! cat "${LNAME_REMOTE_PORT}" | grep "example.com"; then + if ! cat "${LNAME_REMOTE_PORT}" | grep "baidu.com"; then exit 1 fi } From bf9676bfbbcfdfbf47316dcd80df28ea64ca34ad Mon Sep 17 00:00:00 2001 From: hengyoush Date: Tue, 28 Jan 2025 16:41:17 +0800 Subject: [PATCH 2/2] fix: fix index out of range error --- agent/analysis/stat.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/agent/analysis/stat.go b/agent/analysis/stat.go index cda56bf3..7635dfd4 100644 --- a/agent/analysis/stat.go +++ b/agent/analysis/stat.go @@ -233,7 +233,7 @@ func (s *StatRecorder) ReceiveRecord(r protocol.Record, connection *conn.Connect if !traceDevEvent { annotatedRecord.TotalDuration = annotatedRecord.BlackBoxDuration } - if !traceSocketEvent && hasNicInEvents && canCalculateReadPathTime { + if !traceSocketEvent && hasNicInEvents && canCalculateReadPathTime && hasReadSyscallEvents { if nicInTimestamp, _, ok := events.nicIngressEvents[0].GetMinIfItmestampAttr(); ok { annotatedRecord.ReadFromSocketBufferDuration = float64(events.readSyscallEvents[len(events.readSyscallEvents)-1].GetEndTs() - uint64(nicInTimestamp)) } @@ -305,7 +305,7 @@ func (s *StatRecorder) ReceiveRecord(r protocol.Record, connection *conn.Connect if hasTcpInEvents && hasNicInEvents && canCalculateReadPathTime { annotatedRecord.CopyToSocketBufferDuration = float64(events.tcpInEvents[len(events.tcpInEvents)-1].GetStartTs() - events.nicIngressEvents[0].GetStartTs()) } - if !traceSocketEvent && hasNicInEvents && canCalculateReadPathTime { + if !traceSocketEvent && hasNicInEvents && canCalculateReadPathTime && hasReadSyscallEvents { if _nicIngressTimestamp, _, ok := events.nicIngressEvents[0].GetMinIfItmestampAttr(); ok { annotatedRecord.ReadFromSocketBufferDuration = float64(events.readSyscallEvents[len(events.readSyscallEvents)-1].GetEndTs() - uint64(_nicIngressTimestamp)) }