Skip to content

Commit

Permalink
Retry to support forwarding server-made request.
Browse files Browse the repository at this point in the history
  • Loading branch information
Som-Som-CC committed Sep 25, 2024
1 parent 4408cc1 commit 8234543
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
12 changes: 12 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,18 @@ func (c *Client) setUA(req *http.Request) {

func (c *Client) cloneBody(req *http.Request) io.ReadCloser {
if c.retries > 0 && req.Body != nil && req.Body != http.NoBody {
if req.GetBody == nil { // Probably a server request body to be forwarded.
req.GetBody = func() (io.ReadCloser, error) {
recvdBuf, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
}
clonedBuf := make([]byte, len(recvdBuf))
copy(clonedBuf, recvdBuf)
req.Body = io.NopCloser(bytes.NewReader(recvdBuf))
return io.NopCloser(bytes.NewReader(clonedBuf)), nil
}
}
clonedBody, _ := req.GetBody()
return clonedBody
}
Expand Down
35 changes: 35 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,41 @@ func TestTimeout(t *testing.T) {
}
}

func TestRetryWithServerForwarding(t *testing.T) {
assert := assert.New(t)

// Server the client forwards requests to
srvCounter := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if srvCounter&1 == 0 {
w.WriteHeader(502)
} else {
w.WriteHeader(200)
}
srvCounter++
}))
defer srv.Close()

client := NewClient().Root(srv.URL).Retry(3, time.Millisecond, time.Second)
{ // With body
req, _ := http.NewRequest("POST", "/hello", strings.NewReader(`{"hello": "world"}`))
req.GetBody = nil // Simulating server behavior
resp, err := client.Do(req)
assert.NoError(err)
assert.Equal(200, resp.StatusCode)
}

{ // No body
req, _ := http.NewRequest("DELETE", "/hello", nil)
req.GetBody = nil // Simulating server behavior
resp, err := client.Do(req)
assert.NoError(err)
assert.Equal(200, resp.StatusCode)
}

assert.Equal(4, srvCounter)
}

func TestMethodsError(t *testing.T) {
assert := assert.New(t)

Expand Down

0 comments on commit 8234543

Please sign in to comment.