generated from xmidt-org/.go-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #88 from xmidt-org/feature/composable-checkredirect
Feature/composable checkredirect
- Loading branch information
Showing
4 changed files
with
461 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
/* | ||
Package client provides infrastructure for build HTTP clients, including observability | ||
and logging. | ||
*/ | ||
package client |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package client | ||
|
||
import ( | ||
"fmt" | ||
"net/http" | ||
) | ||
|
||
// CheckRedirect is the type expected by http.Client.CheckRedirect. | ||
// | ||
// Closures of this type can also be chained together via the | ||
// NewCheckRedirects function. | ||
type CheckRedirect func(*http.Request, []*http.Request) error | ||
|
||
// CopyHeadersOnRedirect copies the headers from the most recent | ||
// request into the next request. If no names are supplied, this | ||
// function returns nil so that the default behavior will take over. | ||
func CopyHeadersOnRedirect(names ...string) CheckRedirect { | ||
if len(names) == 0 { | ||
return nil | ||
} | ||
|
||
names = append([]string{}, names...) | ||
for i, n := range names { | ||
names[i] = http.CanonicalHeaderKey(n) | ||
} | ||
|
||
return func(request *http.Request, via []*http.Request) error { | ||
previous := via[len(via)-1] // the most recent request | ||
for _, n := range names { | ||
// direct map access is faster, since we've already | ||
// canonicalized everything | ||
if values := previous.Header[n]; len(values) > 0 { | ||
request.Header[n] = values | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
} | ||
|
||
// MaxRedirects returns a CheckRedirect that returns an error if | ||
// a maximum number of redirects has been reached. If the max | ||
// value is 0 or negative, then no redirects are allowed. | ||
func MaxRedirects(max int) CheckRedirect { | ||
if max < 0 { | ||
max = 0 | ||
} | ||
|
||
// create the error once and reuse it | ||
// this error text mimics the one used in net/http | ||
err := fmt.Errorf("stopped after %d redirects", max) | ||
return func(_ *http.Request, via []*http.Request) error { | ||
if len(via) >= max { | ||
return err | ||
} | ||
|
||
return nil | ||
} | ||
} | ||
|
||
// NewCheckRedirects produces a CheckRedirect that is the logical AND | ||
// of the given strategies. All the checks must pass, or the returned | ||
// function halts early and returns the error from the failing check. | ||
// | ||
// Since a nil http.Request.CheckRedirect indicates that the internal | ||
// default will be used, this function returns nil if no checks are | ||
// supplied. Additionally, any nil checks are skipped. If all checks | ||
// are nil, this function also returns nil. | ||
func NewCheckRedirects(checks ...CheckRedirect) CheckRedirect { | ||
// skip nils, but check first before making a copy | ||
count := 0 | ||
for _, c := range checks { | ||
if c != nil { | ||
count++ | ||
} | ||
} | ||
|
||
if count == 0 { | ||
return nil | ||
} | ||
|
||
// now make our safe copy. this avoids soft memory leaks, since | ||
// this slice will be around a while. | ||
clone := make([]CheckRedirect, 0, count) | ||
for _, c := range checks { | ||
if c != nil { | ||
clone = append(clone, c) | ||
} | ||
} | ||
|
||
if len(clone) == 1 { | ||
// optimization: use the sole non-nil check as is | ||
return clone[0] | ||
} | ||
|
||
return func(request *http.Request, via []*http.Request) error { | ||
for _, c := range clone { | ||
if err := c(request, via); err != nil { | ||
return err | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package client | ||
|
||
import ( | ||
"fmt" | ||
"net/http" | ||
"net/http/httptest" | ||
) | ||
|
||
func ExampleCopyHeadersOnRedirect() { | ||
request := httptest.NewRequest("GET", "/", nil) | ||
previous := httptest.NewRequest("GET", "/", nil) | ||
previous.Header.Set("Copy-Me", "copied value") | ||
|
||
client := http.Client{ | ||
CheckRedirect: CopyHeadersOnRedirect("copy-me"), // canonicalization will happen | ||
} | ||
|
||
if err := client.CheckRedirect(request, []*http.Request{previous}); err != nil { | ||
// shouldn't be output | ||
fmt.Println(err) | ||
} | ||
|
||
fmt.Println(request.Header.Get("Copy-Me")) | ||
|
||
// Output: | ||
// copied value | ||
} | ||
|
||
func ExampleMaxRedirects() { | ||
request := httptest.NewRequest("GET", "/", nil) | ||
client := http.Client{ | ||
CheckRedirect: MaxRedirects(5), | ||
} | ||
|
||
if client.CheckRedirect(request, make([]*http.Request, 4)) == nil { | ||
fmt.Println("fewer than 5 redirects") | ||
} | ||
|
||
if client.CheckRedirect(request, make([]*http.Request, 6)) != nil { | ||
fmt.Println("max redirects exceeded") | ||
} | ||
|
||
// Output: | ||
// fewer than 5 redirects | ||
// max redirects exceeded | ||
} | ||
|
||
func ExampleNewCheckRedirects() { | ||
request := httptest.NewRequest("GET", "/", nil) | ||
previous := httptest.NewRequest("GET", "/", nil) | ||
previous.Header.Set("Copy-Me", "copied value") | ||
|
||
client := http.Client{ | ||
CheckRedirect: NewCheckRedirects( | ||
MaxRedirects(10), | ||
CopyHeadersOnRedirect("Copy-Me"), | ||
func(*http.Request, []*http.Request) error { | ||
fmt.Println("custom check redirect") | ||
return nil | ||
}, | ||
), | ||
} | ||
|
||
if err := client.CheckRedirect(request, []*http.Request{previous}); err != nil { | ||
// shouldn't be output | ||
fmt.Println(err) | ||
} | ||
|
||
fmt.Println(request.Header.Get("Copy-Me")) | ||
|
||
// Output: | ||
// custom check redirect | ||
// copied value | ||
} |
Oops, something went wrong.