Skip to content

Commit

Permalink
Add interface RequestBody to encapsulate bodies that can be replayed …
Browse files Browse the repository at this point in the history
…and those that can't (#695)
  • Loading branch information
bmoylan authored Nov 19, 2024
1 parent 1cd3256 commit 74beb11
Show file tree
Hide file tree
Showing 7 changed files with 432 additions and 88 deletions.
6 changes: 6 additions & 0 deletions changelog/@unreleased/pr-695.v2.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
type: improvement
improvement:
description: Add interface RequestBody to encapsulate bodies that can be replayed
and those that can't
links:
- https://github.com/palantir/conjure-go-runtime/pull/695
64 changes: 28 additions & 36 deletions conjure-go-client/httpclient/body_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@
package httpclient

import (
"bytes"
"io"
"io/ioutil"
"fmt"
"net/http"

"github.com/palantir/conjure-go-runtime/v2/conjure-go-contract/codecs"
Expand Down Expand Up @@ -57,47 +55,41 @@ func (b *bodyMiddleware) RoundTrip(req *http.Request, next http.RoundTripper) (*
// setRequestBody returns a function that should be called once the request has been completed.
func (b *bodyMiddleware) setRequestBody(req *http.Request) (func(), error) {
cleanup := func() {}
var requestBody RequestBody

if b.requestInput == nil {
return cleanup, nil
}

// Special case: if the requestInput is an io.ReadCloser and the requestEncoder is nil,
// use the provided input directly as the request body.
if bodyReadCloser, ok := b.requestInput.(io.ReadCloser); ok && b.requestEncoder == nil {
req.Body = bodyReadCloser
// Use the same heuristic as http.NewRequest to generate the "GetBody" function.
if newReq, err := http.NewRequest("", "", bodyReadCloser); err == nil {
req.GetBody = newReq.GetBody
}
return cleanup, nil
}

var buf *bytes.Buffer
if b.bufferPool != nil {
buf = b.bufferPool.Get()
cleanup = func() {
b.bufferPool.Put(buf)
requestBody = RequestBodyEmpty()
} else if b.requestEncoder != nil {
if b.bufferPool != nil {
// If buffer pool is set, use it with Encode and return a func to return the buffer to the pool.
buf := b.bufferPool.Get()
cleanup = func() {
b.bufferPool.Put(buf)
}
requestBody = RequestBodyEncoderObjectBuffer(b.requestInput, b.requestEncoder, buf)
} else {
// If buffer pool is not set, let Marshal allocate memory for the serialized object.
requestBody = RequestBodyEncoderObject(b.requestInput, b.requestEncoder)
}
} else if body, ok := b.requestInput.(RequestBody); ok {
// Special case: if the requestInput is a RequestBody and the requestEncoder is nil,
// use the provided input directly as the request body.
requestBody = body
} else {
buf = new(bytes.Buffer)
return nil, werror.ErrorWithContextParams(req.Context(), "requestEncoder is nil but requestInput is not RequestBody",
werror.SafeParam("requestInputType", fmt.Sprintf("%T", b.requestInput)))
}

if err := b.requestEncoder.Encode(buf, b.requestInput); err != nil {
return cleanup, werror.Wrap(err, "failed to encode request object")
}
return cleanup, requestBody.setRequestBody(req)
}

if buf.Len() != 0 {
req.Body = ioutil.NopCloser(buf)
req.ContentLength = int64(buf.Len())
req.GetBody = func() (io.ReadCloser, error) {
return ioutil.NopCloser(bytes.NewReader(buf.Bytes())), nil
}
} else {
req.Body = http.NoBody
req.GetBody = func() (io.ReadCloser, error) { return http.NoBody, nil }
// returns true if the request body is a noRetriesRequestBody
func (b *bodyMiddleware) noRetriesRequestBody() bool {
if b.requestEncoder == nil && b.requestInput != nil {
_, ok := b.requestInput.(noRetriesRequestBody)
return ok
}
return cleanup, nil
return false
}

func (b *bodyMiddleware) readResponse(resp *http.Response, respErr error) error {
Expand Down
149 changes: 125 additions & 24 deletions conjure-go-client/httpclient/body_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ package httpclient_test
import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/palantir/conjure-go-runtime/v2/conjure-go-client/httpclient"
Expand Down Expand Up @@ -70,7 +71,7 @@ func TestRawBody(t *testing.T) {

server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
assert.Equal(t, "TestNewRequest", req.Header.Get("User-Agent"))
gotReqBytes, err := ioutil.ReadAll(req.Body)
gotReqBytes, err := io.ReadAll(req.Body)
assert.NoError(t, err)
assert.Equal(t, gotReqBytes, reqVar)
_, err = rw.Write(respVar)
Expand All @@ -86,14 +87,12 @@ func TestRawBody(t *testing.T) {

resp, err := client.Do(context.Background(),
httpclient.WithRequestMethod(http.MethodPost),
httpclient.WithRawRequestBodyProvider(func() io.ReadCloser {
return ioutil.NopCloser(bytes.NewBuffer(reqVar))
}),
httpclient.WithBinaryRequestBody(httpclient.RequestBodyInMemory(bytes.NewBuffer(reqVar))),
httpclient.WithRawResponseBody(),
)
assert.NoError(t, err)

gotRespBytes, err := ioutil.ReadAll(resp.Body)
gotRespBytes, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
defer func() {
_ = resp.Body.Close()
Expand All @@ -108,7 +107,7 @@ func TestRawRequestRetry(t *testing.T) {
requestBytes := []byte{12, 13}

server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
gotReqBytes, err := ioutil.ReadAll(req.Body)
gotReqBytes, err := io.ReadAll(req.Body)
assert.NoError(t, err)
assert.Equal(t, requestBytes, gotReqBytes)
if count == 0 {
Expand All @@ -124,9 +123,7 @@ func TestRawRequestRetry(t *testing.T) {

_, err = client.Do(
context.Background(),
httpclient.WithRawRequestBodyProvider(func() io.ReadCloser {
return ioutil.NopCloser(bytes.NewReader(requestBytes))
}),
httpclient.WithBinaryRequestBody(httpclient.RequestBodyInMemory(bytes.NewReader(requestBytes))),
httpclient.WithRequestMethod(http.MethodPost))
assert.NoError(t, err)
assert.Equal(t, 2, count)
Expand All @@ -139,15 +136,22 @@ func TestRedirectWithBodyAndBytesBuffer(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
var actualReqVar map[string]string
err := codecs.JSON.Decode(req.Body, &actualReqVar)
assert.NoError(t, err)
assert.Equal(t, reqVar, actualReqVar)
if assert.NoError(t, err) {
assert.Equal(t, reqVar, actualReqVar)
} else {
t.Log(err)
rw.WriteHeader(http.StatusInternalServerError)
return
}

switch req.URL.Path {
case "/redirect":
rw.Header().Add("Location", "/location")
rw.WriteHeader(302)
rw.WriteHeader(307)
case "/location":
assert.NoError(t, codecs.JSON.Encode(rw, respVar))
default:
rw.WriteHeader(http.StatusNotFound)
}
}))
defer server.Close()
Expand All @@ -159,16 +163,113 @@ func TestRedirectWithBodyAndBytesBuffer(t *testing.T) {
)
require.NoError(t, err)

var actualRespVar map[string]string
resp, err := client.Do(context.Background(),
httpclient.WithRequestMethod(http.MethodPost),
httpclient.WithPath("/redirect"),
httpclient.WithJSONRequest(&reqVar),
httpclient.WithJSONResponse(&actualRespVar),
)
t.Run("WithJSONRequest", func(t *testing.T) {
var actualRespVar map[string]string
resp, err := client.Do(context.Background(),
httpclient.WithRequestMethod(http.MethodPost),
httpclient.WithPath("/redirect"),
httpclient.WithJSONRequest(&reqVar),
httpclient.WithJSONResponse(&actualRespVar),
)

assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, resp.StatusCode, 200)
assert.Equal(t, respVar, actualRespVar)
require.NoError(t, err)
assert.NotNil(t, resp)
if assert.NotNil(t, resp) {
assert.Equal(t, resp.StatusCode, 200)
}
assert.Equal(t, respVar, actualRespVar)
})

t.Run("RequestBodyInMemory[*strings.Reader]", func(t *testing.T) {
var actualRespVar map[string]string
resp, err := client.Do(context.Background(),
httpclient.WithRequestMethod(http.MethodPost),
httpclient.WithPath("/redirect"),
httpclient.WithBinaryRequestBody(httpclient.RequestBodyInMemory(strings.NewReader(`{"1":"2"}`))),
httpclient.WithJSONResponse(&actualRespVar),
)

require.NoError(t, err)
assert.NotNil(t, resp)
if assert.NotNil(t, resp) {
assert.Equal(t, resp.StatusCode, 200)
}
assert.Equal(t, respVar, actualRespVar)
})

t.Run("RequestBodyStreamWithReplay", func(t *testing.T) {
var actualRespVar map[string]string
resp, err := client.Do(context.Background(),
httpclient.WithRequestMethod(http.MethodPost),
httpclient.WithPath("/redirect"),
httpclient.WithBinaryRequestBody(httpclient.RequestBodyStreamWithReplay(func() (io.ReadCloser, error) {
return io.NopCloser(strings.NewReader(`{"1":"2"}`)), nil
})),
httpclient.WithJSONResponse(&actualRespVar),
)

require.NoError(t, err)
if assert.NotNil(t, resp) {
assert.Equal(t, resp.StatusCode, 200)
}
assert.Equal(t, respVar, actualRespVar)
})

t.Run("RequestBodyStreamOnce posts body", func(t *testing.T) {
var actualRespVar map[string]string
resp, err := client.Do(context.Background(),
httpclient.WithRequestMethod(http.MethodPost),
httpclient.WithPath("/location"),
httpclient.WithBinaryRequestBody(httpclient.RequestBodyStreamOnce(func() (io.ReadCloser, error) {
return io.NopCloser(strings.NewReader(`{"1":"2"}`)), nil
})),
httpclient.WithJSONResponse(&actualRespVar),
)

require.NoError(t, err)
if assert.NotNil(t, resp) {
assert.Equal(t, resp.StatusCode, 200)
}
assert.Equal(t, respVar, actualRespVar)
})

t.Run("RequestBodyStreamOnce does not follow redirect", func(t *testing.T) {
var readOnce bool
var actualRespVar map[string]string
resp, err := client.Do(context.Background(),
httpclient.WithRequestMethod(http.MethodPost),
httpclient.WithPath("/redirect"),
httpclient.WithBinaryRequestBody(httpclient.RequestBodyStreamOnce(func() (io.ReadCloser, error) {
if readOnce {
return nil, fmt.Errorf("readOnce is true")
}
readOnce = true
return io.NopCloser(strings.NewReader(`{"1":"2"}`)), nil
})),
httpclient.WithJSONResponse(&actualRespVar),
)

require.EqualError(t, err, "httpclient request failed: 307 Temporary Redirect")
assert.Nil(t, resp)
})

t.Run("RequestBodyStreamOnce does not retry on 404", func(t *testing.T) {
var readOnce bool
var actualRespVar map[string]string
resp, err := client.Do(context.Background(),
httpclient.WithRequestMethod(http.MethodPost),
httpclient.WithPath("/invalid"),
httpclient.WithBinaryRequestBody(httpclient.RequestBodyStreamOnce(func() (io.ReadCloser, error) {
if readOnce {
return nil, fmt.Errorf("readOnce is true")
}
readOnce = true
return io.NopCloser(strings.NewReader(`{"1":"2"}`)), nil
})),
httpclient.WithJSONResponse(&actualRespVar),
)

require.EqualError(t, err, "httpclient request failed: 404 Not Found")
assert.Nil(t, resp)
})
}
Loading

0 comments on commit 74beb11

Please sign in to comment.