diff --git a/client/doc.go b/client/doc.go new file mode 100644 index 0000000..c127b2f --- /dev/null +++ b/client/doc.go @@ -0,0 +1,8 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +/* +Package client provides infrastructure for build HTTP clients, including observability +and logging. +*/ +package client diff --git a/client/redirect.go b/client/redirect.go new file mode 100644 index 0000000..17e8867 --- /dev/null +++ b/client/redirect.go @@ -0,0 +1,108 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package client + +import ( + "fmt" + "net/http" +) + +// CheckRedirect is the type expected by http.Client.CheckRedirect. +// +// Closures of this type can also be chained together via the +// NewCheckRedirects function. +type CheckRedirect func(*http.Request, []*http.Request) error + +// CopyHeadersOnRedirect copies the headers from the most recent +// request into the next request. If no names are supplied, this +// function returns nil so that the default behavior will take over. +func CopyHeadersOnRedirect(names ...string) CheckRedirect { + if len(names) == 0 { + return nil + } + + names = append([]string{}, names...) + for i, n := range names { + names[i] = http.CanonicalHeaderKey(n) + } + + return func(request *http.Request, via []*http.Request) error { + previous := via[len(via)-1] // the most recent request + for _, n := range names { + // direct map access is faster, since we've already + // canonicalized everything + if values := previous.Header[n]; len(values) > 0 { + request.Header[n] = values + } + } + + return nil + } +} + +// MaxRedirects returns a CheckRedirect that returns an error if +// a maximum number of redirects has been reached. If the max +// value is 0 or negative, then no redirects are allowed. +func MaxRedirects(max int) CheckRedirect { + if max < 0 { + max = 0 + } + + // create the error once and reuse it + // this error text mimics the one used in net/http + err := fmt.Errorf("stopped after %d redirects", max) + return func(_ *http.Request, via []*http.Request) error { + if len(via) >= max { + return err + } + + return nil + } +} + +// NewCheckRedirects produces a CheckRedirect that is the logical AND +// of the given strategies. All the checks must pass, or the returned +// function halts early and returns the error from the failing check. +// +// Since a nil http.Request.CheckRedirect indicates that the internal +// default will be used, this function returns nil if no checks are +// supplied. Additionally, any nil checks are skipped. If all checks +// are nil, this function also returns nil. +func NewCheckRedirects(checks ...CheckRedirect) CheckRedirect { + // skip nils, but check first before making a copy + count := 0 + for _, c := range checks { + if c != nil { + count++ + } + } + + if count == 0 { + return nil + } + + // now make our safe copy. this avoids soft memory leaks, since + // this slice will be around a while. + clone := make([]CheckRedirect, 0, count) + for _, c := range checks { + if c != nil { + clone = append(clone, c) + } + } + + if len(clone) == 1 { + // optimization: use the sole non-nil check as is + return clone[0] + } + + return func(request *http.Request, via []*http.Request) error { + for _, c := range clone { + if err := c(request, via); err != nil { + return err + } + } + + return nil + } +} diff --git a/client/redirect_examples_test.go b/client/redirect_examples_test.go new file mode 100644 index 0000000..440bd20 --- /dev/null +++ b/client/redirect_examples_test.go @@ -0,0 +1,77 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package client + +import ( + "fmt" + "net/http" + "net/http/httptest" +) + +func ExampleCopyHeadersOnRedirect() { + request := httptest.NewRequest("GET", "/", nil) + previous := httptest.NewRequest("GET", "/", nil) + previous.Header.Set("Copy-Me", "copied value") + + client := http.Client{ + CheckRedirect: CopyHeadersOnRedirect("copy-me"), // canonicalization will happen + } + + if err := client.CheckRedirect(request, []*http.Request{previous}); err != nil { + // shouldn't be output + fmt.Println(err) + } + + fmt.Println(request.Header.Get("Copy-Me")) + + // Output: + // copied value +} + +func ExampleMaxRedirects() { + request := httptest.NewRequest("GET", "/", nil) + client := http.Client{ + CheckRedirect: MaxRedirects(5), + } + + if client.CheckRedirect(request, make([]*http.Request, 4)) == nil { + fmt.Println("fewer than 5 redirects") + } + + if client.CheckRedirect(request, make([]*http.Request, 6)) != nil { + fmt.Println("max redirects exceeded") + } + + // Output: + // fewer than 5 redirects + // max redirects exceeded +} + +func ExampleNewCheckRedirects() { + request := httptest.NewRequest("GET", "/", nil) + previous := httptest.NewRequest("GET", "/", nil) + previous.Header.Set("Copy-Me", "copied value") + + client := http.Client{ + CheckRedirect: NewCheckRedirects( + MaxRedirects(10), + CopyHeadersOnRedirect("Copy-Me"), + func(*http.Request, []*http.Request) error { + fmt.Println("custom check redirect") + return nil + }, + ), + } + + if err := client.CheckRedirect(request, []*http.Request{previous}); err != nil { + // shouldn't be output + fmt.Println(err) + } + + fmt.Println(request.Header.Get("Copy-Me")) + + // Output: + // custom check redirect + // copied value +} diff --git a/client/redirect_test.go b/client/redirect_test.go new file mode 100644 index 0000000..ba0053c --- /dev/null +++ b/client/redirect_test.go @@ -0,0 +1,268 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package client + +import ( + "errors" + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/suite" +) + +type RedirectTestSuite struct { + suite.Suite +} + +func (suite *RedirectTestSuite) testCopyHeadersOnRedirectWithNames() { + var ( + via = []*http.Request{ + &http.Request{ + Header: http.Header{ + "X-Original": []string{"should not be copied"}, + }, + }, + &http.Request{ + Header: http.Header{ + "X-Original": []string{"overwritten"}, + "Single-Value": []string{"value1"}, + "Multi-Value": []string{"value1", "value2", "value3"}, + "Empty": nil, // empty values, so it should be skipped + }, + }, + } + + request = &http.Request{ + Header: http.Header{ + "Next": []string{"NextValue"}, + }, + } + + copier = CopyHeadersOnRedirect( + // check that canonicalization is happening ... + "x-original", + "sINGLE-VAlue", + "Multi-Value", + "Empty", + ) + ) + + suite.NoError( + copier(request, via), + ) + + suite.Equal( + http.Header{ + "Next": []string{"NextValue"}, // should be untouched + "X-Original": []string{"overwritten"}, + "Single-Value": []string{"value1"}, + "Multi-Value": []string{"value1", "value2", "value3"}, + }, + request.Header, + ) +} + +func (suite *RedirectTestSuite) testCopyHeadersOnRedirectNoNames() { + suite.Nil( + CopyHeadersOnRedirect(), + ) +} + +func (suite *RedirectTestSuite) TestCopyHeadersOnRedirect() { + suite.Run("WithNames", suite.testCopyHeadersOnRedirectWithNames) + suite.Run("NoNames", suite.testCopyHeadersOnRedirectNoNames) +} + +// newVia is a helper function that creates a number of dummy requests +// to use as a via in a check redirect test. +func (suite *RedirectTestSuite) newVia(count int) (via []*http.Request) { + if count > 0 { + via = make([]*http.Request, 0, count) + for len(via) < cap(via) { + via = append( + via, + &http.Request{ + Header: http.Header{}, + }, + ) + } + } + + return +} + +func (suite *RedirectTestSuite) testMaxRedirectsNegative() { + suite.Error( + MaxRedirects(-1)(nil, suite.newVia(0)), + ) + + suite.Error( + MaxRedirects(-5)(nil, suite.newVia(1)), + ) + + suite.Error( + MaxRedirects(-34871)(nil, suite.newVia(3)), + ) +} + +func (suite *RedirectTestSuite) testMaxRedirectsZero() { + suite.Error( + MaxRedirects(0)(nil, suite.newVia(0)), + ) + + suite.Error( + MaxRedirects(0)(nil, suite.newVia(1)), + ) + + suite.Error( + MaxRedirects(0)(nil, suite.newVia(3)), + ) +} + +func (suite *RedirectTestSuite) testMaxRedirectsSuccess() { + suite.NoError( + MaxRedirects(1)(nil, suite.newVia(0)), + ) + + suite.NoError( + MaxRedirects(3)(nil, suite.newVia(1)), + ) + + suite.NoError( + MaxRedirects(20)(nil, suite.newVia(16)), + ) +} + +func (suite *RedirectTestSuite) testMaxRedirectsFail() { + suite.Error( + MaxRedirects(1)(nil, suite.newVia(1)), + ) + + suite.Error( + MaxRedirects(1)(nil, suite.newVia(2)), + ) + + suite.Error( + MaxRedirects(4)(nil, suite.newVia(4)), + ) + + suite.Error( + MaxRedirects(3)(nil, suite.newVia(6)), + ) + + suite.Error( + MaxRedirects(20)(nil, suite.newVia(22)), + ) +} + +func (suite *RedirectTestSuite) TestMaxRedirects() { + suite.Run("Negative", suite.testMaxRedirectsNegative) + suite.Run("Zero", suite.testMaxRedirectsZero) + suite.Run("Success", suite.testMaxRedirectsSuccess) + suite.Run("Fail", suite.testMaxRedirectsFail) +} + +func (suite *RedirectTestSuite) checkRedirectSuccess(*http.Request, []*http.Request) error { + return nil +} + +func (suite *RedirectTestSuite) checkRedirectSuccesses(count int) (checks []CheckRedirect) { + checks = make([]CheckRedirect, 0, count) + for len(checks) < cap(checks) { + checks = append(checks, suite.checkRedirectSuccess) + } + + return +} + +func (suite *RedirectTestSuite) checkRedirectFail(*http.Request, []*http.Request) error { + return errors.New("test error") +} + +func (suite *RedirectTestSuite) testNewCheckRedirectsNil() { + suite.Nil( + NewCheckRedirects(), + ) + + suite.Nil( + NewCheckRedirects(nil), + ) + + suite.Nil( + NewCheckRedirects(nil, nil, nil), + ) +} + +func (suite *RedirectTestSuite) testNewCheckRedirectsSuccess() { + suite.Run("NoNils", func() { + for _, count := range []int{1, 2, 5} { + suite.Run(fmt.Sprintf("count=%d", count), func() { + checkRedirect := NewCheckRedirects( + suite.checkRedirectSuccesses(count)..., + ) + + suite.Require().NotNil(checkRedirect) + + // won't matter what's passed, as our test functions don't use the parameters + suite.NoError(checkRedirect(nil, nil)) + }) + } + }) + + suite.Run("WithNils", func() { + checkRedirect := NewCheckRedirects( + suite.checkRedirectSuccess, + nil, + suite.checkRedirectSuccess, + ) + + suite.Require().NotNil(checkRedirect) + + // won't matter what's passed, as our test functions don't use the parameters + suite.NoError(checkRedirect(nil, nil)) + }) +} + +func (suite *RedirectTestSuite) testNewCheckRedirectsFail() { + suite.Run("NoNils", func() { + for _, count := range []int{1, 2, 5} { + suite.Run(fmt.Sprintf("count=%d", count), func() { + components := suite.checkRedirectSuccesses(count) + + // any fail will fail the entire check + components[len(components)/2] = suite.checkRedirectFail + checkRedirect := NewCheckRedirects(components...) + + suite.Require().NotNil(checkRedirect) + + // won't matter what's passed, as our test functions don't use the parameters + suite.Error(checkRedirect(nil, nil)) + }) + } + }) + + suite.Run("WithNils", func() { + checkRedirect := NewCheckRedirects( + suite.checkRedirectFail, + nil, + suite.checkRedirectSuccess, + ) + + suite.Require().NotNil(checkRedirect) + + // won't matter what's passed, as our test functions don't use the parameters + suite.Error(checkRedirect(nil, nil)) + }) +} + +func (suite *RedirectTestSuite) TestNewCheckRedirects() { + suite.Run("Nil", suite.testNewCheckRedirectsNil) + suite.Run("Success", suite.testNewCheckRedirectsSuccess) + suite.Run("Fail", suite.testNewCheckRedirectsFail) +} + +func TestRedirect(t *testing.T) { + suite.Run(t, new(RedirectTestSuite)) +}