Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom backoff strategy option #302

Merged
merged 1 commit into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 34 additions & 13 deletions v2/workloadapi/backoff.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,51 @@ import (
"time"
)

// backoff defines an linear backoff policy.
type backoff struct {
InitialDelay time.Duration
MaxDelay time.Duration
// BackoffStrategy provides backoff facilities.
type BackoffStrategy interface {
// NewBackoff returns a new backoff for the strategy. The returned
// Backoff is in the same state that it would be in after a call to
// Reset().
NewBackoff() Backoff
}

// Backoff provides backoff for a workload API operation.
type Backoff interface {
// Next returns the next backoff period.
Next() time.Duration

// Reset() resets the backoff.
Reset()
}

type defaultBackoffStrategy struct{}

func (defaultBackoffStrategy) NewBackoff() Backoff {
return newLinearBackoff()
}

// linearBackoff defines an linear backoff policy.
type linearBackoff struct {
initialDelay time.Duration
maxDelay time.Duration
n int
}

func newBackoff() *backoff {
return &backoff{
InitialDelay: time.Second,
MaxDelay: 30 * time.Second,
func newLinearBackoff() *linearBackoff {
return &linearBackoff{
initialDelay: time.Second,
maxDelay: 30 * time.Second,
n: 0,
}
}

// Duration returns the next wait period for the backoff. Not goroutine-safe.
func (b *backoff) Duration() time.Duration {
func (b *linearBackoff) Next() time.Duration {
backoff := float64(b.n) + 1
d := math.Min(b.InitialDelay.Seconds()*backoff, b.MaxDelay.Seconds())
d := math.Min(b.initialDelay.Seconds()*backoff, b.maxDelay.Seconds())
b.n++
return time.Duration(d) * time.Second
}

// Reset resets the backoff's state.
func (b *backoff) Reset() {
func (b *linearBackoff) Reset() {
b.n = 0
}
23 changes: 8 additions & 15 deletions v2/workloadapi/backoff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,27 @@ import (
"github.com/stretchr/testify/require"
)

func TestBackoff(t *testing.T) {
new := func() *backoff { //nolint:all
b := newBackoff()
b.InitialDelay = time.Second
b.MaxDelay = 30 * time.Second
return b
}

testUntilMax := func(t *testing.T, b *backoff) {
func TestLinearBackoff(t *testing.T) {
testUntilMax := func(t *testing.T, b *linearBackoff) {
for i := 1; i < 30; i++ {
require.Equal(t, time.Duration(i)*time.Second, b.Duration())
require.Equal(t, time.Duration(i)*time.Second, b.Next())
}
require.Equal(t, 30*time.Second, b.Duration())
require.Equal(t, 30*time.Second, b.Duration())
require.Equal(t, 30*time.Second, b.Duration())
require.Equal(t, 30*time.Second, b.Next())
require.Equal(t, 30*time.Second, b.Next())
require.Equal(t, 30*time.Second, b.Next())
}

t.Run("test max", func(t *testing.T) {
t.Parallel()

b := new()
b := newLinearBackoff()
testUntilMax(t, b)
})

t.Run("test reset", func(t *testing.T) {
t.Parallel()

b := new()
b := newLinearBackoff()
testUntilMax(t, b)

b.Reset()
Expand Down
19 changes: 10 additions & 9 deletions v2/workloadapi/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func (c *Client) FetchX509Bundles(ctx context.Context) (*x509bundle.Set, error)
// WatchX509Bundles watches for changes to the X.509 bundles. The watcher receives
// the updated X.509 bundles.
func (c *Client) WatchX509Bundles(ctx context.Context, watcher X509BundleWatcher) error {
backoff := newBackoff()
backoff := c.config.backoffStrategy.NewBackoff()
for {
err := c.watchX509Bundles(ctx, watcher, backoff)
watcher.OnX509BundlesWatchError(err)
Expand Down Expand Up @@ -152,7 +152,7 @@ func (c *Client) FetchX509Context(ctx context.Context) (*X509Context, error) {
// WatchX509Context watches for updates to the X.509 context. The watcher
// receives the updated X.509 context.
func (c *Client) WatchX509Context(ctx context.Context, watcher X509ContextWatcher) error {
backoff := newBackoff()
backoff := c.config.backoffStrategy.NewBackoff()
for {
err := c.watchX509Context(ctx, watcher, backoff)
watcher.OnX509ContextWatchError(err)
Expand Down Expand Up @@ -224,7 +224,7 @@ func (c *Client) FetchJWTBundles(ctx context.Context) (*jwtbundle.Set, error) {
// WatchJWTBundles watches for changes to the JWT bundles. The watcher receives
// the updated JWT bundles.
func (c *Client) WatchJWTBundles(ctx context.Context, watcher JWTBundleWatcher) error {
backoff := newBackoff()
backoff := c.config.backoffStrategy.NewBackoff()
for {
err := c.watchJWTBundles(ctx, watcher, backoff)
watcher.OnJWTBundlesWatchError(err)
Expand Down Expand Up @@ -258,7 +258,7 @@ func (c *Client) newConn(ctx context.Context) (*grpc.ClientConn, error) {
return grpc.DialContext(ctx, c.config.address, c.config.dialOptions...) //nolint:staticcheck // preserve backcompat with WithDialOptions option
}

func (c *Client) handleWatchError(ctx context.Context, err error, backoff *backoff) error {
func (c *Client) handleWatchError(ctx context.Context, err error, backoff Backoff) error {
code := status.Code(err)
if code == codes.Canceled {
return err
Expand All @@ -270,7 +270,7 @@ func (c *Client) handleWatchError(ctx context.Context, err error, backoff *backo
}

c.config.log.Errorf("Failed to watch the Workload API: %v", err)
retryAfter := backoff.Duration()
retryAfter := backoff.Next()
c.config.log.Debugf("Retrying watch in %s", retryAfter)
select {
case <-time.After(retryAfter):
Expand All @@ -281,7 +281,7 @@ func (c *Client) handleWatchError(ctx context.Context, err error, backoff *backo
}
}

func (c *Client) watchX509Context(ctx context.Context, watcher X509ContextWatcher, backoff *backoff) error {
func (c *Client) watchX509Context(ctx context.Context, watcher X509ContextWatcher, backoff Backoff) error {
ctx, cancel := context.WithCancel(withHeader(ctx))
defer cancel()

Expand All @@ -308,7 +308,7 @@ func (c *Client) watchX509Context(ctx context.Context, watcher X509ContextWatche
}
}

func (c *Client) watchJWTBundles(ctx context.Context, watcher JWTBundleWatcher, backoff *backoff) error {
func (c *Client) watchJWTBundles(ctx context.Context, watcher JWTBundleWatcher, backoff Backoff) error {
ctx, cancel := context.WithCancel(withHeader(ctx))
defer cancel()

Expand All @@ -335,7 +335,7 @@ func (c *Client) watchJWTBundles(ctx context.Context, watcher JWTBundleWatcher,
}
}

func (c *Client) watchX509Bundles(ctx context.Context, watcher X509BundleWatcher, backoff *backoff) error {
func (c *Client) watchX509Bundles(ctx context.Context, watcher X509BundleWatcher, backoff Backoff) error {
ctx, cancel := context.WithCancel(withHeader(ctx))
defer cancel()

Expand Down Expand Up @@ -402,7 +402,8 @@ func withHeader(ctx context.Context) context.Context {

func defaultClientConfig() clientConfig {
return clientConfig{
log: logger.Null,
log: logger.Null,
backoffStrategy: defaultBackoffStrategy{},
}
}

Expand Down
48 changes: 45 additions & 3 deletions v2/workloadapi/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/x509"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -103,7 +104,10 @@ func TestFetchX509Bundles(t *testing.T) {
func TestWatchX509Bundles(t *testing.T) {
wl := fakeworkloadapi.New(t)
defer wl.Stop()
c, err := New(context.Background(), WithAddr(wl.Addr()))

backoffStrategy := &testBackoffStrategy{}

c, err := New(context.Background(), WithAddr(wl.Addr()), WithBackoffStrategy(backoffStrategy))
require.NoError(t, err)
defer c.Close()

Expand Down Expand Up @@ -149,6 +153,9 @@ func TestWatchX509Bundles(t *testing.T) {
wl.Stop()
tw.WaitForUpdates(1)
assert.Len(t, tw.Errors(), 2)

// Assert that there was the expected number of backoffs.
assert.Equal(t, 2, backoffStrategy.BackedOff())
}

func TestFetchX509Context(t *testing.T) {
Expand Down Expand Up @@ -213,7 +220,10 @@ func TestWatchX509Context(t *testing.T) {
federatedCA := test.NewCA(t, federatedTD)
wl := fakeworkloadapi.New(t)
defer wl.Stop()
c, err := New(context.Background(), WithAddr(wl.Addr()))

backoffStrategy := &testBackoffStrategy{}

c, err := New(context.Background(), WithAddr(wl.Addr()), WithBackoffStrategy(backoffStrategy))
require.NoError(t, err)
defer c.Close()

Expand Down Expand Up @@ -291,6 +301,9 @@ func TestWatchX509Context(t *testing.T) {

cancel()
wg.Wait()

// Assert that there was the expected number of backoffs.
assert.Equal(t, 2, backoffStrategy.BackedOff())
}

func TestFetchJWTSVID(t *testing.T) {
Expand Down Expand Up @@ -375,7 +388,10 @@ func TestFetchJWTBundles(t *testing.T) {
func TestWatchJWTBundles(t *testing.T) {
wl := fakeworkloadapi.New(t)
defer wl.Stop()
c, err := New(context.Background(), WithAddr(wl.Addr()))

backoffStrategy := &testBackoffStrategy{}

c, err := New(context.Background(), WithAddr(wl.Addr()), WithBackoffStrategy(backoffStrategy))
require.NoError(t, err)
defer c.Close()

Expand Down Expand Up @@ -421,6 +437,9 @@ func TestWatchJWTBundles(t *testing.T) {
wl.Stop()
tw.WaitForUpdates(1)
assert.Len(t, tw.Errors(), 2)

// Assert that there was the expected number of backoffs.
assert.Equal(t, 2, backoffStrategy.BackedOff())
}

func TestValidateJWTSVID(t *testing.T) {
Expand Down Expand Up @@ -605,3 +624,26 @@ func (w *testWatcher) WaitForUpdates(expectedNumUpdates int) {
}
}
}

type testBackoffStrategy struct {
backedOff int32
}

func (s *testBackoffStrategy) NewBackoff() Backoff {
return testBackoff{backedOff: &s.backedOff}
}

func (s *testBackoffStrategy) BackedOff() int {
return int(atomic.LoadInt32(&s.backedOff))
}

type testBackoff struct {
backedOff *int32
}

func (b testBackoff) Next() time.Duration {
atomic.AddInt32(b.backedOff, 1)
return time.Millisecond * 200
}

func (b testBackoff) Reset() {}
17 changes: 13 additions & 4 deletions v2/workloadapi/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ func WithLogger(logger logger.Logger) ClientOption {
})
}

// WithBackoff provides a custom backoff strategy that replaces the
// default backoff strategy (linear backoff).
func WithBackoffStrategy(backoffStrategy BackoffStrategy) ClientOption {
return clientOption(func(c *clientConfig) {
c.backoffStrategy = backoffStrategy
})
}

// SourceOption are options that are shared among all option types.
type SourceOption interface {
configureX509Source(*x509SourceConfig)
Expand Down Expand Up @@ -81,10 +89,11 @@ type BundleSourceOption interface {
}

type clientConfig struct {
address string
namedPipeName string
dialOptions []grpc.DialOption
log logger.Logger
address string
namedPipeName string
dialOptions []grpc.DialOption
log logger.Logger
backoffStrategy BackoffStrategy
}

type clientOption func(*clientConfig)
Expand Down