Skip to content

Commit

Permalink
Use time values instead of pointers
Browse files Browse the repository at this point in the history
  • Loading branch information
Micah Parks committed Nov 13, 2021
1 parent 846f636 commit b671333
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 41 deletions.
6 changes: 3 additions & 3 deletions examples/aws_cognito/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ func main() {
RefreshErrorHandler: func(err error) {
log.Printf("There was an error with the jwt.Keyfunc\nError: %s", err.Error())
},
RefreshInterval: &refreshInterval,
RefreshRateLimit: &refreshRateLimit,
RefreshTimeout: &refreshTimeout,
RefreshInterval: refreshInterval,
RefreshRateLimit: refreshRateLimit,
RefreshTimeout: refreshTimeout,
RefreshUnknownKID: &refreshUnknownKID,
}

Expand Down
2 changes: 1 addition & 1 deletion examples/ctx/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func main() {
// seconds. This timeout is also used to create the initial context.Context for keyfunc.Get.
refreshTimeout := time.Second * 10
options := keyfunc.Options{
RefreshTimeout: &refreshTimeout,
RefreshTimeout: refreshTimeout,
RefreshErrorHandler: func(err error) {
log.Printf("There was an error with the jwt.Keyfunc\nError: %s", err.Error())
},
Expand Down
6 changes: 3 additions & 3 deletions examples/given/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ func main() {
RefreshErrorHandler: func(err error) {
log.Printf("There was an error with the jwt.Keyfunc\nError: %s", err.Error())
},
RefreshInterval: &refreshInterval,
RefreshRateLimit: &refreshRateLimit,
RefreshTimeout: &refreshTimeout,
RefreshInterval: refreshInterval,
RefreshRateLimit: refreshRateLimit,
RefreshTimeout: refreshTimeout,
RefreshUnknownKID: &refreshUnknownKID,
}

Expand Down
4 changes: 2 additions & 2 deletions examples/interval/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ func main() {
refreshInterval := time.Hour
refreshTimeout := time.Second * 10
options := keyfunc.Options{
RefreshInterval: &refreshInterval,
RefreshTimeout: &refreshTimeout,
RefreshInterval: refreshInterval,
RefreshTimeout: refreshTimeout,
RefreshErrorHandler: func(err error) {
log.Printf("There was an error with the jwt.Keyfunc\nError: %s", err.Error())
},
Expand Down
6 changes: 3 additions & 3 deletions examples/recommended_options/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ func main() {
RefreshErrorHandler: func(err error) {
log.Printf("There was an error with the jwt.Keyfunc\nError: %s", err.Error())
},
RefreshInterval: &refreshInterval,
RefreshRateLimit: &refreshRateLimit,
RefreshTimeout: &refreshTimeout,
RefreshInterval: refreshInterval,
RefreshRateLimit: refreshRateLimit,
RefreshTimeout: refreshTimeout,
RefreshUnknownKID: &refreshUnknownKID,
}

Expand Down
22 changes: 11 additions & 11 deletions get.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ func Get(jwksURL string, options ...Options) (jwks *JWKs, err error) {
if jwks.client == nil {
jwks.client = http.DefaultClient
}
if jwks.refreshTimeout == nil {
jwks.refreshTimeout = &defaultRefreshTimeout
if jwks.refreshTimeout == 0 {
jwks.refreshTimeout = defaultRefreshTimeout
}

// Get the keys for the JWKs.
Expand All @@ -43,7 +43,7 @@ func Get(jwksURL string, options ...Options) (jwks *JWKs, err error) {
}

// Check to see if a background refresh of the JWKs should happen.
if jwks.refreshInterval != nil || jwks.refreshUnknownKID {
if jwks.refreshInterval != 0 || jwks.refreshUnknownKID {

// Attach a context used to end the background goroutine.
jwks.ctx, jwks.cancel = context.WithCancel(context.Background())
Expand All @@ -66,8 +66,8 @@ func (j *JWKs) backgroundRefresh() {
var lastRefresh time.Time
var queueOnce sync.Once
var refreshMux sync.Mutex
if j.refreshRateLimit != nil {
lastRefresh = time.Now().Add(-*j.refreshRateLimit)
if j.refreshRateLimit != 0 {
lastRefresh = time.Now().Add(-j.refreshRateLimit)
}

// Create a channel that will never send anything unless there is a refresh interval.
Expand All @@ -77,8 +77,8 @@ func (j *JWKs) backgroundRefresh() {
for {

// If there is a refresh interval, create the channel for it.
if j.refreshInterval != nil {
refreshInterval = time.After(*j.refreshInterval)
if j.refreshInterval != 0 {
refreshInterval = time.After(j.refreshInterval)
}

// Wait for a refresh to occur or the background to end.
Expand All @@ -98,7 +98,7 @@ func (j *JWKs) backgroundRefresh() {

// Rate limit, if needed.
refreshMux.Lock()
if j.refreshRateLimit != nil && lastRefresh.Add(*j.refreshRateLimit).After(time.Now()) {
if j.refreshRateLimit != 0 && lastRefresh.Add(j.refreshRateLimit).After(time.Now()) {

// Don't make the JWT parsing goroutine wait for the JWKs to refresh.
cancel()
Expand All @@ -111,7 +111,7 @@ func (j *JWKs) backgroundRefresh() {

// Wait for the next time to refresh.
refreshMux.Lock()
wait := time.Until(lastRefresh.Add(*j.refreshRateLimit))
wait := time.Until(lastRefresh.Add(j.refreshRateLimit))
refreshMux.Unlock()
select {
case <-j.ctx.Done():
Expand Down Expand Up @@ -162,9 +162,9 @@ func (j *JWKs) refresh() (err error) {
var ctx context.Context
var cancel context.CancelFunc
if j.ctx != nil {
ctx, cancel = context.WithTimeout(j.ctx, *j.refreshTimeout)
ctx, cancel = context.WithTimeout(j.ctx, j.refreshTimeout)
} else {
ctx, cancel = context.WithTimeout(context.Background(), *j.refreshTimeout)
ctx, cancel = context.WithTimeout(context.Background(), j.refreshTimeout)
}
defer cancel()

Expand Down
6 changes: 3 additions & 3 deletions jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ type JWKs struct {
keys map[string]*jsonKey
mux sync.RWMutex
refreshErrorHandler ErrorHandler
refreshInterval *time.Duration
refreshRateLimit *time.Duration
refreshInterval time.Duration
refreshRateLimit time.Duration
refreshRequests chan context.CancelFunc
refreshTimeout *time.Duration
refreshTimeout time.Duration
refreshUnknownKID bool
}

Expand Down
18 changes: 9 additions & 9 deletions jwks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestInvalidServer(t *testing.T) {
// Set the options to refresh KID when unknown.
refreshInterval := time.Second
options := keyfunc.Options{
RefreshInterval: &refreshInterval,
RefreshInterval: refreshInterval,
RefreshErrorHandler: testingRefreshErrorHandler,
}

Expand Down Expand Up @@ -117,13 +117,13 @@ func TestJWKs(t *testing.T) {
RefreshErrorHandler: testingRefreshErrorHandler,
},
{
RefreshInterval: &testingRefreshInterval,
RefreshInterval: testingRefreshInterval,
},
{
RefreshRateLimit: &testingRateLimit,
RefreshRateLimit: testingRateLimit,
},
{
RefreshTimeout: &testingRefreshTimeout,
RefreshTimeout: testingRefreshTimeout,
},
}

Expand Down Expand Up @@ -156,8 +156,8 @@ func TestJWKs(t *testing.T) {
}

// Wait for the interval to pass, if required.
if opts.RefreshInterval != nil {
time.Sleep(*opts.RefreshInterval)
if opts.RefreshInterval != 0 {
time.Sleep(opts.RefreshInterval)
}

// Iterate through the test cases.
Expand Down Expand Up @@ -280,9 +280,9 @@ func TestRateLimit(t *testing.T) {
RefreshErrorHandler: func(err error) {
t.Errorf("The package itself had an error.\nError: %s", err.Error())
},
RefreshInterval: &refreshInterval,
RefreshRateLimit: &refreshRateLimit,
RefreshTimeout: &refreshTimeout,
RefreshInterval: refreshInterval,
RefreshRateLimit: refreshRateLimit,
RefreshTimeout: refreshTimeout,
RefreshUnknownKID: &refreshUnknownKID,
}

Expand Down
12 changes: 6 additions & 6 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,16 @@ type Options struct {
// RefreshInterval is the duration to refresh the JWKs in the background via a new HTTP request. If this is not nil,
// then a background goroutine will be used to refresh the JWKs once per the given interval. Make sure to call the
// JWKs.EndBackground method to end this goroutine when it's no longer needed.
RefreshInterval *time.Duration
RefreshInterval time.Duration

// RefreshRateLimit limits the rate at which refresh requests are granted. Only one refresh request can be queued
// at a time any refresh requests received while there is already a queue are ignored. It does not make sense to
// have RefreshInterval's value shorter than this.
RefreshRateLimit *time.Duration
RefreshRateLimit time.Duration

// RefreshTimeout is the duration for the context timeout used to create the HTTP request for a refresh of the JWKs.
// This defaults to one minute. This is only effectual if RefreshInterval is not nil.
RefreshTimeout *time.Duration
RefreshTimeout time.Duration

// RefreshUnknownKID indicates that the JWKs refresh request will occur every time a kid that isn't cached is seen.
// This is done through a background goroutine. Without specifying a RefreshInterval a malicious client could
Expand Down Expand Up @@ -79,13 +79,13 @@ func applyOptions(jwks *JWKs, options Options) {
if options.RefreshErrorHandler != nil {
jwks.refreshErrorHandler = options.RefreshErrorHandler
}
if options.RefreshInterval != nil {
if options.RefreshInterval != 0 {
jwks.refreshInterval = options.RefreshInterval
}
if options.RefreshRateLimit != nil {
if options.RefreshRateLimit != 0 {
jwks.refreshRateLimit = options.RefreshRateLimit
}
if options.RefreshTimeout != nil {
if options.RefreshTimeout != 0 {
jwks.refreshTimeout = options.RefreshTimeout
}
if options.RefreshUnknownKID != nil {
Expand Down

0 comments on commit b671333

Please sign in to comment.