Skip to content

Commit

Permalink
composable CheckRedirect functions for clients
Browse files Browse the repository at this point in the history
  • Loading branch information
johnabass committed Jul 16, 2024
1 parent 1b4539b commit 0f0bd4b
Show file tree
Hide file tree
Showing 3 changed files with 347 additions and 0 deletions.
8 changes: 8 additions & 0 deletions client/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC
// SPDX-License-Identifier: Apache-2.0

/*
Package client provides infrastructure for build HTTP clients, including observability
and logging.
*/
package client
97 changes: 97 additions & 0 deletions client/redirect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC
// SPDX-License-Identifier: Apache-2.0

package client

import (
"fmt"
"net/http"
)

// CheckRedirect is the type expected by http.Client.CheckRedirect.
//
// Closures of this type can also be chained together via the
// NewCheckRedirects function.
type CheckRedirect func(*http.Request, []*http.Request) error

// CopyHeadersOnRedirect copies the headers from the most recent
// request into the next request. If no names are supplied, this
// function returns nil so that the default behavior will take over.
func CopyHeadersOnRedirect(names ...string) CheckRedirect {
if len(names) == 0 {
return nil
}

names = append([]string{}, names...)
for i, n := range names {
names[i] = http.CanonicalHeaderKey(n)
}

return func(request *http.Request, via []*http.Request) error {
previous := via[len(via)-1] // the most recent request
for _, n := range names {
// direct map access is faster, since we've already
// canonicalized everything
if values := previous.Header[n]; len(values) > 0 {
request.Header[n] = values
}
}

return nil
}
}

// MaxRedirects returns a CheckRedirect that returns an error if
// a maximum number of redirects has been reached. If the max
// value is 0 or negative, then no redirects are allowed.
func MaxRedirects(max int) CheckRedirect {
if max < 0 {
max = 0
}

// create the error once and reuse it
// this error text mimics the one used in net/http
err := fmt.Errorf("stopped after %d redirects", max)
return func(_ *http.Request, via []*http.Request) error {
if len(via) >= max {
return err
}

return nil
}
}

// NewCheckRedirects produces a CheckRedirect that is the logical AND
// of the given strategies. All the checks must pass, or the returned
// function halts early and returns the error from the failing check.
//
// Since a nil http.Request.CheckRedirect indicates that the internal
// default will be used, this function returns nil if no checks are
// supplied. Additionally, any nil checks are skipped. If all checks
// are nil, this function also returns nil.
func NewCheckRedirects(checks ...CheckRedirect) CheckRedirect {
if len(checks) == 0 {
return nil
}

// check nils before allocating a copy
for _, c := range checks {
if c == nil {
return nil
}
}

// optimization: if there's only (1) check, just use that
if len(checks) == 1 {
return checks[0]
}

checks = append([]CheckRedirect{}, checks...)
return func(request *http.Request, via []*http.Request) (err error) {
for i := 0; err == nil && i < len(checks); i++ {
err = checks[i](request, via)
}

return
}
}
242 changes: 242 additions & 0 deletions client/redirect_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC
// SPDX-License-Identifier: Apache-2.0

package client

import (
"errors"
"fmt"
"net/http"
"testing"

"github.com/stretchr/testify/suite"
)

type RedirectTestSuite struct {
suite.Suite
}

func (suite *RedirectTestSuite) testCopyHeadersOnRedirectWithNames() {
var (
via = []*http.Request{
&http.Request{
Header: http.Header{
"X-Original": []string{"should not be copied"},
},
},
&http.Request{
Header: http.Header{
"X-Original": []string{"overwritten"},
"Single-Value": []string{"value1"},
"Multi-Value": []string{"value1", "value2", "value3"},
"Empty": nil, // empty values, so it should be skipped
},
},
}

request = &http.Request{
Header: http.Header{
"Next": []string{"NextValue"},
},
}

copier = CopyHeadersOnRedirect(
// check that canonicalization is happening ...
"x-original",
"sINGLE-VAlue",
"Multi-Value",
"Empty",
)
)

suite.NoError(
copier(request, via),
)

suite.Equal(
http.Header{
"Next": []string{"NextValue"}, // should be untouched
"X-Original": []string{"overwritten"},
"Single-Value": []string{"value1"},
"Multi-Value": []string{"value1", "value2", "value3"},
},
request.Header,
)
}

func (suite *RedirectTestSuite) testCopyHeadersOnRedirectNoNames() {
suite.Nil(
CopyHeadersOnRedirect(),
)
}

func (suite *RedirectTestSuite) TestCopyHeadersOnRedirect() {
suite.Run("WithNames", suite.testCopyHeadersOnRedirectWithNames)
suite.Run("NoNames", suite.testCopyHeadersOnRedirectNoNames)
}

// newVia is a helper function that creates a number of dummy requests
// to use as a via in a check redirect test.
func (suite *RedirectTestSuite) newVia(count int) (via []*http.Request) {
if count > 0 {
via = make([]*http.Request, 0, count)
for len(via) < cap(via) {
via = append(
via,
&http.Request{
Header: http.Header{},
},
)
}
}

return
}

func (suite *RedirectTestSuite) testMaxRedirectsNegative() {
suite.Error(
MaxRedirects(-1)(nil, suite.newVia(0)),
)

suite.Error(
MaxRedirects(-5)(nil, suite.newVia(1)),
)

suite.Error(
MaxRedirects(-34871)(nil, suite.newVia(3)),
)
}

func (suite *RedirectTestSuite) testMaxRedirectsZero() {
suite.Error(
MaxRedirects(0)(nil, suite.newVia(0)),
)

suite.Error(
MaxRedirects(0)(nil, suite.newVia(1)),
)

suite.Error(
MaxRedirects(0)(nil, suite.newVia(3)),
)
}

func (suite *RedirectTestSuite) testMaxRedirectsSuccess() {
suite.NoError(
MaxRedirects(1)(nil, suite.newVia(0)),
)

suite.NoError(
MaxRedirects(3)(nil, suite.newVia(1)),
)

suite.NoError(
MaxRedirects(20)(nil, suite.newVia(16)),
)
}

func (suite *RedirectTestSuite) testMaxRedirectsFail() {
suite.Error(
MaxRedirects(1)(nil, suite.newVia(1)),
)

suite.Error(
MaxRedirects(1)(nil, suite.newVia(2)),
)

suite.Error(
MaxRedirects(4)(nil, suite.newVia(4)),
)

suite.Error(
MaxRedirects(3)(nil, suite.newVia(6)),
)

suite.Error(
MaxRedirects(20)(nil, suite.newVia(22)),
)
}

func (suite *RedirectTestSuite) TestMaxRedirects() {
suite.Run("Negative", suite.testMaxRedirectsNegative)
suite.Run("Zero", suite.testMaxRedirectsZero)
suite.Run("Success", suite.testMaxRedirectsSuccess)
suite.Run("Fail", suite.testMaxRedirectsFail)
}

func (suite *RedirectTestSuite) checkRedirectSuccess(*http.Request, []*http.Request) error {
return nil
}

func (suite *RedirectTestSuite) checkRedirectSuccesses(count int) (checks []CheckRedirect) {
checks = make([]CheckRedirect, 0, count)
for len(checks) < cap(checks) {
checks = append(checks, suite.checkRedirectSuccess)
}

return
}

func (suite *RedirectTestSuite) checkRedirectFail(*http.Request, []*http.Request) error {
return errors.New("test error")
}

func (suite *RedirectTestSuite) testNewCheckRedirectsNil() {
suite.Nil(
NewCheckRedirects(),
)

suite.Nil(
NewCheckRedirects(nil),
)

suite.Nil(
NewCheckRedirects(nil, suite.checkRedirectSuccess),
)

suite.Nil(
NewCheckRedirects(suite.checkRedirectSuccess, nil, suite.checkRedirectSuccess),
)
}

func (suite *RedirectTestSuite) testNewCheckRedirectsSuccess() {
for _, count := range []int{1, 2, 5} {
suite.Run(fmt.Sprintf("count=%d", count), func() {
checkRedirect := NewCheckRedirects(
suite.checkRedirectSuccesses(count)...,
)

suite.Require().NotNil(checkRedirect)

// won't matter what's passed, as our test functions don't use the parameters
suite.NoError(checkRedirect(nil, nil))
})
}
}

func (suite *RedirectTestSuite) testNewCheckRedirectsFail() {
for _, count := range []int{1, 2, 5} {
suite.Run(fmt.Sprintf("count=%d", count), func() {
components := suite.checkRedirectSuccesses(count)

// any fail will fail the entire check
components[len(components)/2] = suite.checkRedirectFail
checkRedirect := NewCheckRedirects(components...)

suite.Require().NotNil(checkRedirect)

// won't matter what's passed, as our test functions don't use the parameters
suite.Error(checkRedirect(nil, nil))
})
}
}

func (suite *RedirectTestSuite) TestNewCheckRedirects() {
suite.Run("Nil", suite.testNewCheckRedirectsNil)
suite.Run("Success", suite.testNewCheckRedirectsSuccess)
suite.Run("Fail", suite.testNewCheckRedirectsFail)
}

func TestRedirect(t *testing.T) {
suite.Run(t, new(RedirectTestSuite))
}

0 comments on commit 0f0bd4b

Please sign in to comment.