From 82345434b635e7f098f647ccaa8236696e07b87f Mon Sep 17 00:00:00 2001 From: Som-Som-CC <84196538+Som-Som-CC@users.noreply.github.com> Date: Wed, 25 Sep 2024 10:35:54 +0200 Subject: [PATCH] Retry to support forwarding server-made request. --- client.go | 12 ++++++++++++ client_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/client.go b/client.go index 2988414..55f9417 100644 --- a/client.go +++ b/client.go @@ -430,6 +430,18 @@ func (c *Client) setUA(req *http.Request) { func (c *Client) cloneBody(req *http.Request) io.ReadCloser { if c.retries > 0 && req.Body != nil && req.Body != http.NoBody { + if req.GetBody == nil { // Probably a server request body to be forwarded. + req.GetBody = func() (io.ReadCloser, error) { + recvdBuf, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + clonedBuf := make([]byte, len(recvdBuf)) + copy(clonedBuf, recvdBuf) + req.Body = io.NopCloser(bytes.NewReader(recvdBuf)) + return io.NopCloser(bytes.NewReader(clonedBuf)), nil + } + } clonedBody, _ := req.GetBody() return clonedBody } diff --git a/client_test.go b/client_test.go index d7b71bf..0f9355b 100644 --- a/client_test.go +++ b/client_test.go @@ -349,6 +349,41 @@ func TestTimeout(t *testing.T) { } } +func TestRetryWithServerForwarding(t *testing.T) { + assert := assert.New(t) + + // Server the client forwards requests to + srvCounter := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if srvCounter&1 == 0 { + w.WriteHeader(502) + } else { + w.WriteHeader(200) + } + srvCounter++ + })) + defer srv.Close() + + client := NewClient().Root(srv.URL).Retry(3, time.Millisecond, time.Second) + { // With body + req, _ := http.NewRequest("POST", "/hello", strings.NewReader(`{"hello": "world"}`)) + req.GetBody = nil // Simulating server behavior + resp, err := client.Do(req) + assert.NoError(err) + assert.Equal(200, resp.StatusCode) + } + + { // No body + req, _ := http.NewRequest("DELETE", "/hello", nil) + req.GetBody = nil // Simulating server behavior + resp, err := client.Do(req) + assert.NoError(err) + assert.Equal(200, resp.StatusCode) + } + + assert.Equal(4, srvCounter) +} + func TestMethodsError(t *testing.T) { assert := assert.New(t)