Skip to content

Commit

Permalink
Merge pull request #88 from xmidt-org/feature/composable-checkredirect
Browse files Browse the repository at this point in the history
Feature/composable checkredirect
  • Loading branch information
johnabass authored Jul 17, 2024
2 parents 1b4539b + e5ccff9 commit b56c5cc
Show file tree
Hide file tree
Showing 4 changed files with 461 additions and 0 deletions.
8 changes: 8 additions & 0 deletions client/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC
// SPDX-License-Identifier: Apache-2.0

/*
Package client provides infrastructure for build HTTP clients, including observability
and logging.
*/
package client
108 changes: 108 additions & 0 deletions client/redirect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC
// SPDX-License-Identifier: Apache-2.0

package client

import (
"fmt"
"net/http"
)

// CheckRedirect is the type expected by http.Client.CheckRedirect.
//
// Closures of this type can also be chained together via the
// NewCheckRedirects function.
type CheckRedirect func(*http.Request, []*http.Request) error

// CopyHeadersOnRedirect copies the headers from the most recent
// request into the next request. If no names are supplied, this
// function returns nil so that the default behavior will take over.
func CopyHeadersOnRedirect(names ...string) CheckRedirect {
if len(names) == 0 {
return nil
}

names = append([]string{}, names...)
for i, n := range names {
names[i] = http.CanonicalHeaderKey(n)
}

return func(request *http.Request, via []*http.Request) error {
previous := via[len(via)-1] // the most recent request
for _, n := range names {
// direct map access is faster, since we've already
// canonicalized everything
if values := previous.Header[n]; len(values) > 0 {
request.Header[n] = values
}
}

return nil
}
}

// MaxRedirects returns a CheckRedirect that returns an error if
// a maximum number of redirects has been reached. If the max
// value is 0 or negative, then no redirects are allowed.
func MaxRedirects(max int) CheckRedirect {
if max < 0 {
max = 0
}

// create the error once and reuse it
// this error text mimics the one used in net/http
err := fmt.Errorf("stopped after %d redirects", max)
return func(_ *http.Request, via []*http.Request) error {
if len(via) >= max {
return err
}

return nil
}
}

// NewCheckRedirects produces a CheckRedirect that is the logical AND
// of the given strategies. All the checks must pass, or the returned
// function halts early and returns the error from the failing check.
//
// Since a nil http.Request.CheckRedirect indicates that the internal
// default will be used, this function returns nil if no checks are
// supplied. Additionally, any nil checks are skipped. If all checks
// are nil, this function also returns nil.
func NewCheckRedirects(checks ...CheckRedirect) CheckRedirect {
// skip nils, but check first before making a copy
count := 0
for _, c := range checks {
if c != nil {
count++
}
}

if count == 0 {
return nil
}

// now make our safe copy. this avoids soft memory leaks, since
// this slice will be around a while.
clone := make([]CheckRedirect, 0, count)
for _, c := range checks {
if c != nil {
clone = append(clone, c)
}
}

if len(clone) == 1 {
// optimization: use the sole non-nil check as is
return clone[0]
}

return func(request *http.Request, via []*http.Request) error {
for _, c := range clone {
if err := c(request, via); err != nil {
return err
}
}

return nil
}
}
77 changes: 77 additions & 0 deletions client/redirect_examples_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC
// SPDX-License-Identifier: Apache-2.0

package client

import (
"fmt"
"net/http"
"net/http/httptest"
)

func ExampleCopyHeadersOnRedirect() {
request := httptest.NewRequest("GET", "/", nil)
previous := httptest.NewRequest("GET", "/", nil)
previous.Header.Set("Copy-Me", "copied value")

client := http.Client{
CheckRedirect: CopyHeadersOnRedirect("copy-me"), // canonicalization will happen
}

if err := client.CheckRedirect(request, []*http.Request{previous}); err != nil {
// shouldn't be output
fmt.Println(err)
}

fmt.Println(request.Header.Get("Copy-Me"))

// Output:
// copied value
}

func ExampleMaxRedirects() {
request := httptest.NewRequest("GET", "/", nil)
client := http.Client{
CheckRedirect: MaxRedirects(5),
}

if client.CheckRedirect(request, make([]*http.Request, 4)) == nil {
fmt.Println("fewer than 5 redirects")
}

if client.CheckRedirect(request, make([]*http.Request, 6)) != nil {
fmt.Println("max redirects exceeded")
}

// Output:
// fewer than 5 redirects
// max redirects exceeded
}

func ExampleNewCheckRedirects() {
request := httptest.NewRequest("GET", "/", nil)
previous := httptest.NewRequest("GET", "/", nil)
previous.Header.Set("Copy-Me", "copied value")

client := http.Client{
CheckRedirect: NewCheckRedirects(
MaxRedirects(10),
CopyHeadersOnRedirect("Copy-Me"),
func(*http.Request, []*http.Request) error {
fmt.Println("custom check redirect")
return nil
},
),
}

if err := client.CheckRedirect(request, []*http.Request{previous}); err != nil {
// shouldn't be output
fmt.Println(err)
}

fmt.Println(request.Header.Get("Copy-Me"))

// Output:
// custom check redirect
// copied value
}
Loading

0 comments on commit b56c5cc

Please sign in to comment.