generated from xmidt-org/.go-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
composable CheckRedirect functions for clients
- Loading branch information
Showing
3 changed files
with
347 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
/* | ||
Package client provides infrastructure for build HTTP clients, including observability | ||
and logging. | ||
*/ | ||
package client |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package client | ||
|
||
import ( | ||
"fmt" | ||
"net/http" | ||
) | ||
|
||
// CheckRedirect is the type expected by http.Client.CheckRedirect. | ||
// | ||
// Closures of this type can also be chained together via the | ||
// NewCheckRedirects function. | ||
type CheckRedirect func(*http.Request, []*http.Request) error | ||
|
||
// CopyHeadersOnRedirect copies the headers from the most recent | ||
// request into the next request. If no names are supplied, this | ||
// function returns nil so that the default behavior will take over. | ||
func CopyHeadersOnRedirect(names ...string) CheckRedirect { | ||
if len(names) == 0 { | ||
return nil | ||
} | ||
|
||
names = append([]string{}, names...) | ||
for i, n := range names { | ||
names[i] = http.CanonicalHeaderKey(n) | ||
} | ||
|
||
return func(request *http.Request, via []*http.Request) error { | ||
previous := via[len(via)-1] // the most recent request | ||
for _, n := range names { | ||
// direct map access is faster, since we've already | ||
// canonicalized everything | ||
if values := previous.Header[n]; len(values) > 0 { | ||
request.Header[n] = values | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
} | ||
|
||
// MaxRedirects returns a CheckRedirect that returns an error if | ||
// a maximum number of redirects has been reached. If the max | ||
// value is 0 or negative, then no redirects are allowed. | ||
func MaxRedirects(max int) CheckRedirect { | ||
if max < 0 { | ||
max = 0 | ||
} | ||
|
||
// create the error once and reuse it | ||
// this error text mimics the one used in net/http | ||
err := fmt.Errorf("stopped after %d redirects", max) | ||
return func(_ *http.Request, via []*http.Request) error { | ||
if len(via) >= max { | ||
return err | ||
} | ||
|
||
return nil | ||
} | ||
} | ||
|
||
// NewCheckRedirects produces a CheckRedirect that is the logical AND | ||
// of the given strategies. All the checks must pass, or the returned | ||
// function halts early and returns the error from the failing check. | ||
// | ||
// Since a nil http.Request.CheckRedirect indicates that the internal | ||
// default will be used, this function returns nil if no checks are | ||
// supplied. Additionally, any nil checks are skipped. If all checks | ||
// are nil, this function also returns nil. | ||
func NewCheckRedirects(checks ...CheckRedirect) CheckRedirect { | ||
if len(checks) == 0 { | ||
return nil | ||
} | ||
|
||
// check nils before allocating a copy | ||
for _, c := range checks { | ||
if c == nil { | ||
return nil | ||
} | ||
} | ||
|
||
// optimization: if there's only (1) check, just use that | ||
if len(checks) == 1 { | ||
return checks[0] | ||
} | ||
|
||
checks = append([]CheckRedirect{}, checks...) | ||
return func(request *http.Request, via []*http.Request) (err error) { | ||
for i := 0; err == nil && i < len(checks); i++ { | ||
err = checks[i](request, via) | ||
} | ||
|
||
return | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,242 @@ | ||
// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package client | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"net/http" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/suite" | ||
) | ||
|
||
type RedirectTestSuite struct { | ||
suite.Suite | ||
} | ||
|
||
func (suite *RedirectTestSuite) testCopyHeadersOnRedirectWithNames() { | ||
var ( | ||
via = []*http.Request{ | ||
&http.Request{ | ||
Header: http.Header{ | ||
"X-Original": []string{"should not be copied"}, | ||
}, | ||
}, | ||
&http.Request{ | ||
Header: http.Header{ | ||
"X-Original": []string{"overwritten"}, | ||
"Single-Value": []string{"value1"}, | ||
"Multi-Value": []string{"value1", "value2", "value3"}, | ||
"Empty": nil, // empty values, so it should be skipped | ||
}, | ||
}, | ||
} | ||
|
||
request = &http.Request{ | ||
Header: http.Header{ | ||
"Next": []string{"NextValue"}, | ||
}, | ||
} | ||
|
||
copier = CopyHeadersOnRedirect( | ||
// check that canonicalization is happening ... | ||
"x-original", | ||
"sINGLE-VAlue", | ||
"Multi-Value", | ||
"Empty", | ||
) | ||
) | ||
|
||
suite.NoError( | ||
copier(request, via), | ||
) | ||
|
||
suite.Equal( | ||
http.Header{ | ||
"Next": []string{"NextValue"}, // should be untouched | ||
"X-Original": []string{"overwritten"}, | ||
"Single-Value": []string{"value1"}, | ||
"Multi-Value": []string{"value1", "value2", "value3"}, | ||
}, | ||
request.Header, | ||
) | ||
} | ||
|
||
func (suite *RedirectTestSuite) testCopyHeadersOnRedirectNoNames() { | ||
suite.Nil( | ||
CopyHeadersOnRedirect(), | ||
) | ||
} | ||
|
||
func (suite *RedirectTestSuite) TestCopyHeadersOnRedirect() { | ||
suite.Run("WithNames", suite.testCopyHeadersOnRedirectWithNames) | ||
suite.Run("NoNames", suite.testCopyHeadersOnRedirectNoNames) | ||
} | ||
|
||
// newVia is a helper function that creates a number of dummy requests | ||
// to use as a via in a check redirect test. | ||
func (suite *RedirectTestSuite) newVia(count int) (via []*http.Request) { | ||
if count > 0 { | ||
via = make([]*http.Request, 0, count) | ||
for len(via) < cap(via) { | ||
via = append( | ||
via, | ||
&http.Request{ | ||
Header: http.Header{}, | ||
}, | ||
) | ||
} | ||
} | ||
|
||
return | ||
} | ||
|
||
func (suite *RedirectTestSuite) testMaxRedirectsNegative() { | ||
suite.Error( | ||
MaxRedirects(-1)(nil, suite.newVia(0)), | ||
) | ||
|
||
suite.Error( | ||
MaxRedirects(-5)(nil, suite.newVia(1)), | ||
) | ||
|
||
suite.Error( | ||
MaxRedirects(-34871)(nil, suite.newVia(3)), | ||
) | ||
} | ||
|
||
func (suite *RedirectTestSuite) testMaxRedirectsZero() { | ||
suite.Error( | ||
MaxRedirects(0)(nil, suite.newVia(0)), | ||
) | ||
|
||
suite.Error( | ||
MaxRedirects(0)(nil, suite.newVia(1)), | ||
) | ||
|
||
suite.Error( | ||
MaxRedirects(0)(nil, suite.newVia(3)), | ||
) | ||
} | ||
|
||
func (suite *RedirectTestSuite) testMaxRedirectsSuccess() { | ||
suite.NoError( | ||
MaxRedirects(1)(nil, suite.newVia(0)), | ||
) | ||
|
||
suite.NoError( | ||
MaxRedirects(3)(nil, suite.newVia(1)), | ||
) | ||
|
||
suite.NoError( | ||
MaxRedirects(20)(nil, suite.newVia(16)), | ||
) | ||
} | ||
|
||
func (suite *RedirectTestSuite) testMaxRedirectsFail() { | ||
suite.Error( | ||
MaxRedirects(1)(nil, suite.newVia(1)), | ||
) | ||
|
||
suite.Error( | ||
MaxRedirects(1)(nil, suite.newVia(2)), | ||
) | ||
|
||
suite.Error( | ||
MaxRedirects(4)(nil, suite.newVia(4)), | ||
) | ||
|
||
suite.Error( | ||
MaxRedirects(3)(nil, suite.newVia(6)), | ||
) | ||
|
||
suite.Error( | ||
MaxRedirects(20)(nil, suite.newVia(22)), | ||
) | ||
} | ||
|
||
func (suite *RedirectTestSuite) TestMaxRedirects() { | ||
suite.Run("Negative", suite.testMaxRedirectsNegative) | ||
suite.Run("Zero", suite.testMaxRedirectsZero) | ||
suite.Run("Success", suite.testMaxRedirectsSuccess) | ||
suite.Run("Fail", suite.testMaxRedirectsFail) | ||
} | ||
|
||
func (suite *RedirectTestSuite) checkRedirectSuccess(*http.Request, []*http.Request) error { | ||
return nil | ||
} | ||
|
||
func (suite *RedirectTestSuite) checkRedirectSuccesses(count int) (checks []CheckRedirect) { | ||
checks = make([]CheckRedirect, 0, count) | ||
for len(checks) < cap(checks) { | ||
checks = append(checks, suite.checkRedirectSuccess) | ||
} | ||
|
||
return | ||
} | ||
|
||
func (suite *RedirectTestSuite) checkRedirectFail(*http.Request, []*http.Request) error { | ||
return errors.New("test error") | ||
} | ||
|
||
func (suite *RedirectTestSuite) testNewCheckRedirectsNil() { | ||
suite.Nil( | ||
NewCheckRedirects(), | ||
) | ||
|
||
suite.Nil( | ||
NewCheckRedirects(nil), | ||
) | ||
|
||
suite.Nil( | ||
NewCheckRedirects(nil, suite.checkRedirectSuccess), | ||
) | ||
|
||
suite.Nil( | ||
NewCheckRedirects(suite.checkRedirectSuccess, nil, suite.checkRedirectSuccess), | ||
) | ||
} | ||
|
||
func (suite *RedirectTestSuite) testNewCheckRedirectsSuccess() { | ||
for _, count := range []int{1, 2, 5} { | ||
suite.Run(fmt.Sprintf("count=%d", count), func() { | ||
checkRedirect := NewCheckRedirects( | ||
suite.checkRedirectSuccesses(count)..., | ||
) | ||
|
||
suite.Require().NotNil(checkRedirect) | ||
|
||
// won't matter what's passed, as our test functions don't use the parameters | ||
suite.NoError(checkRedirect(nil, nil)) | ||
}) | ||
} | ||
} | ||
|
||
func (suite *RedirectTestSuite) testNewCheckRedirectsFail() { | ||
for _, count := range []int{1, 2, 5} { | ||
suite.Run(fmt.Sprintf("count=%d", count), func() { | ||
components := suite.checkRedirectSuccesses(count) | ||
|
||
// any fail will fail the entire check | ||
components[len(components)/2] = suite.checkRedirectFail | ||
checkRedirect := NewCheckRedirects(components...) | ||
|
||
suite.Require().NotNil(checkRedirect) | ||
|
||
// won't matter what's passed, as our test functions don't use the parameters | ||
suite.Error(checkRedirect(nil, nil)) | ||
}) | ||
} | ||
} | ||
|
||
func (suite *RedirectTestSuite) TestNewCheckRedirects() { | ||
suite.Run("Nil", suite.testNewCheckRedirectsNil) | ||
suite.Run("Success", suite.testNewCheckRedirectsSuccess) | ||
suite.Run("Fail", suite.testNewCheckRedirectsFail) | ||
} | ||
|
||
func TestRedirect(t *testing.T) { | ||
suite.Run(t, new(RedirectTestSuite)) | ||
} |