Skip to content

Commit

Permalink
Avoid overflow in BackOffDelay and CombineDelay
Browse files Browse the repository at this point in the history
Although we have MaxDelay option, that only kicks in after the
DelayTypeFunc calculation is done, and the calculation itself could
overflow.

For example, with initial delay of 1 second, the 34th retry would cause
the calculated back off delay to go negative, and the 35th retry will
cause it to go back to zero.

This is loosely ported from our implementation outside of the package:
reddit/baseplate.go@660b6c5
  • Loading branch information
fishy committed Sep 16, 2020
1 parent a8f6dc7 commit c0babf7
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 5 deletions.
29 changes: 25 additions & 4 deletions options.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package retry

import (
"math"
"math/rand"
"time"
)
Expand All @@ -23,6 +24,8 @@ type Config struct {
retryIf RetryIfFunc
delayType DelayTypeFunc
lastErrorOnly bool

maxBackOffN uint
}

// Option represents an option for retry.
Expand Down Expand Up @@ -77,7 +80,20 @@ func DelayType(delayType DelayTypeFunc) Option {

// BackOffDelay is a DelayType which increases delay between consecutive retries
func BackOffDelay(n uint, config *Config) time.Duration {
return config.delay * (1 << n)
// 1 << 63 would overflow signed int64 (time.Duration), thus 62.
const max uint = 62

if config.maxBackOffN == 0 {
if config.delay <= 0 {
config.delay = 1
}
config.maxBackOffN = max - uint(math.Floor(math.Log2(float64(config.delay))))
}

if n > config.maxBackOffN {
n = config.maxBackOffN
}
return config.delay << n
}

// FixedDelay is a DelayType which keeps delay the same through all iterations
Expand All @@ -92,12 +108,17 @@ func RandomDelay(_ uint, config *Config) time.Duration {

// CombineDelay is a DelayType the combines all of the specified delays into a new DelayTypeFunc
func CombineDelay(delays ...DelayTypeFunc) DelayTypeFunc {
const maxInt64 = uint64(math.MaxInt64)

return func(n uint, config *Config) time.Duration {
var total time.Duration
var total uint64
for _, delay := range delays {
total += delay(n, config)
total += uint64(delay(n, config))
if total > maxInt64 {
total = maxInt64
}
}
return total
return time.Duration(total)
}
}

Expand Down
100 changes: 99 additions & 1 deletion retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,104 @@ func TestMaxDelay(t *testing.T) {
)
dur := time.Since(start)
assert.Error(t, err)
assert.True(t, dur > 170*time.Millisecond, "5 times with maximum delay retry is longer than 70ms")
assert.True(t, dur > 170*time.Millisecond, "5 times with maximum delay retry is longer than 170ms")
assert.True(t, dur < 200*time.Millisecond, "5 times with maximum delay retry is shorter than 200ms")
}

func TestBackOffDelay(t *testing.T) {
for _, c := range []struct {
label string
delay time.Duration
expectedMaxN uint
n uint
expectedDelay time.Duration
}{
{
label: "negative-delay",
delay: -1,
expectedMaxN: 62,
n: 2,
expectedDelay: 4,
},
{
label: "zero-delay",
delay: 0,
expectedMaxN: 62,
n: 65,
expectedDelay: 1 << 62,
},
{
label: "one-second",
delay: time.Second,
expectedMaxN: 33,
n: 62,
expectedDelay: time.Second << 33,
},
} {
t.Run(
c.label,
func(t *testing.T) {
config := Config{
delay: c.delay,
}
delay := BackOffDelay(c.n, &config)
assert.Equal(t, c.expectedMaxN, config.maxBackOffN, "max n mismatch")
assert.Equal(t, c.expectedDelay, delay, "delay duration mismatch")
},
)
}
}

func TestCombineDelay(t *testing.T) {
f := func(d time.Duration) DelayTypeFunc {
return func(_ uint, _ *Config) time.Duration {
return d
}
}
const max = time.Duration(1<<63 - 1)
for _, c := range []struct {
label string
delays []time.Duration
expected time.Duration
}{
{
label: "empty",
},
{
label: "single",
delays: []time.Duration{
time.Second,
},
expected: time.Second,
},
{
label: "negative",
delays: []time.Duration{
time.Second,
-time.Millisecond,
},
expected: time.Second - time.Millisecond,
},
{
label: "overflow",
delays: []time.Duration{
max,
time.Second,
time.Millisecond,
},
expected: max,
},
} {
t.Run(
c.label,
func(t *testing.T) {
funcs := make([]DelayTypeFunc, len(c.delays))
for i, d := range c.delays {
funcs[i] = f(d)
}
actual := CombineDelay(funcs...)(0, nil)
assert.Equal(t, c.expected, actual, "delay duration mismatch")
},
)
}
}

0 comments on commit c0babf7

Please sign in to comment.