Skip to content

Commit

Permalink
GODRIVER-3095 Add moving STD to RTT Stats (mongodb#1845)
Browse files Browse the repository at this point in the history
  • Loading branch information
joyjwang authored Oct 8, 2024
1 parent 7446373 commit 24153e5
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 9 deletions.
48 changes: 39 additions & 9 deletions x/mongo/driver/topology/rtt_monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,14 @@ type rttMonitor struct {
// connMu guards connecting and disconnecting. This is necessary since
// disconnecting will await the cancellation of a started connection. The
// use case for rttMonitor.connect needs to be goroutine safe.
connMu sync.Mutex
averageRTT time.Duration
averageRTTSet bool
movingMin *list.List
minRTT time.Duration
connMu sync.Mutex
averageRTT time.Duration
averageRTTSet bool
movingMin *list.List
minRTT time.Duration
stddevRTT time.Duration
stddevSum float64
callsToAppendMovingMin int

closeWg sync.WaitGroup
cfg *rttConfig
Expand Down Expand Up @@ -179,20 +182,24 @@ func (r *rttMonitor) runHellos(conn *connection) {
}
}

// reset sets the average and min RTT to 0. This should only be called from the server monitor when an error
// occurs during a server check. Errors in the RTT monitor should not reset the RTTs.
// reset sets the average, min, and stddev RTT to 0. This should only be called from the server monitor
// when an error occurs during a server check. Errors in the RTT monitor should not reset the RTTs.
func (r *rttMonitor) reset() {
r.mu.Lock()
defer r.mu.Unlock()

r.movingMin = list.New()
r.averageRTT = 0
r.averageRTTSet = false
r.stddevSum = 0
r.callsToAppendMovingMin = 0
}

// appendMovingMin will append the RTT to the movingMin list which tracks a
// minimum RTT within the last "minRTTSamplesForMovingMin" RTT samples.
func (r *rttMonitor) appendMovingMin(rtt time.Duration) {
r.callsToAppendMovingMin++

if r.movingMin == nil || rtt < 0 {
return
}
Expand All @@ -202,6 +209,12 @@ func (r *rttMonitor) appendMovingMin(rtt time.Duration) {
}

r.movingMin.PushBack(rtt)

// Collect a sum of stddevs over maxRTTSamplesForMovingMin calls, ignore if calls are less than max
if r.callsToAppendMovingMin >= maxRTTSamplesForMovingMin {
stddev := standardDeviationList(r.movingMin)
r.stddevSum += stddev
}
}

// min will return the minimum value in the movingMin list.
Expand All @@ -222,6 +235,21 @@ func (r *rttMonitor) min() time.Duration {
return min
}

// stddev will return the current moving stddev.
func (r *rttMonitor) stddev() time.Duration {
var stddev time.Duration

if r.callsToAppendMovingMin < maxRTTSamplesForMovingMin {
return 0
}

// Get the number of times stddev was updated and calculate the average stddev
frequency := (r.callsToAppendMovingMin + 1) - maxRTTSamplesForMovingMin
stddev = time.Duration(r.stddevSum / float64(frequency))

return stddev
}

func (r *rttMonitor) addSample(rtt time.Duration) {
// Lock for the duration of this method. We're doing compuationally inexpensive work very infrequently, so lock
// contention isn't expected.
Expand All @@ -230,6 +258,7 @@ func (r *rttMonitor) addSample(rtt time.Duration) {

r.appendMovingMin(rtt)
r.minRTT = r.min()
r.stddevRTT = r.stddev()

if !r.averageRTTSet {
r.averageRTT = rtt
Expand Down Expand Up @@ -262,7 +291,8 @@ func (r *rttMonitor) Stats() string {
defer r.mu.RUnlock()

return fmt.Sprintf(
"network round-trip time stats: moving avg: %v, min: %v",
"network round-trip time stats: moving avg: %v, min: %v, moving stddev: %v",
r.averageRTT,
r.minRTT)
r.minRTT,
r.stddevRTT)
}
68 changes: 68 additions & 0 deletions x/mongo/driver/topology/rtt_monitor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,71 @@ func TestRTTMonitor_min(t *testing.T) {
})
}
}

func TestRTTMonitor_stddev(t *testing.T) {
t.Parallel()

tests := []struct {
name string
samples []time.Duration
want float64
}{
{
name: "empty",
samples: []time.Duration{},
want: 0,
},
{
name: "one",
samples: makeArithmeticSamples(1, 1),
want: 0,
},
{
name: "below maxRTTSamples",
samples: makeArithmeticSamples(1, 5),
want: 0,
},
{
name: "equal maxRTTSamples",
samples: makeArithmeticSamples(1, 10),
want: 2.872281e+06,
},
{
name: "exceed maxRTTSamples",
samples: makeArithmeticSamples(1, 15),
want: 2.872281e+06,
},
{
name: "non-sequential",
samples: []time.Duration{
2 * time.Millisecond,
1 * time.Millisecond,
4 * time.Millisecond,
3 * time.Millisecond,
7 * time.Millisecond,
12 * time.Millisecond,
6 * time.Millisecond,
8 * time.Millisecond,
5 * time.Millisecond,
13 * time.Millisecond,
},
want: 3.806573e+06,
},
}

for _, test := range tests {
test := test // capture the range variable

t.Run(test.name, func(t *testing.T) {
t.Parallel()

rtt := &rttMonitor{
movingMin: list.New(),
}
for _, sample := range test.samples {
rtt.appendMovingMin(sample)
}
assert.Equal(t, test.want, float64(rtt.stddev()))
})
}
}
33 changes: 33 additions & 0 deletions x/mongo/driver/topology/stats.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package topology

import (
"container/list"
"math"
"time"
)

func standardDeviationList(l *list.List) float64 {
if l.Len() == 0 {
return 0
}

var mean, variance float64
count := 0.0

for el := l.Front(); el != nil; el = el.Next() {
count++
sample := float64(el.Value.(time.Duration))

delta := sample - mean
mean += delta / count
variance += delta * (sample - mean)
}

return math.Sqrt(variance / count)
}
51 changes: 51 additions & 0 deletions x/mongo/driver/topology/stats_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package topology

import (
"container/list"
"testing"
"time"

"go.mongodb.org/mongo-driver/v2/internal/assert"
)

func TestStandardDeviationList_Duration(t *testing.T) {
tests := []struct {
name string
data []time.Duration
want float64
}{
{
name: "empty",
data: []time.Duration{},
want: 0,
},
{
name: "multiple",
data: []time.Duration{
time.Millisecond,
2 * time.Millisecond,
time.Microsecond,
},
want: 816088.36667497,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
l := list.New()
for _, d := range test.data {
l.PushBack(d)
}

got := standardDeviationList(l)

assert.InDelta(t, test.want, got, 1e-6)
})
}
}

0 comments on commit 24153e5

Please sign in to comment.