From 843b6f78b075c2129ca48beab5b4a83c4ba3cfa6 Mon Sep 17 00:00:00 2001 From: Owen Cabalceta Date: Fri, 27 Oct 2023 13:02:59 -0400 Subject: [PATCH 1/2] patch: DecodeEntity doesn't honor contentType --- wrpcontext/contents.go | 26 ++++++++++++ wrpcontext/{wrpcontext.go => context.go} | 15 +------ wrpcontext/message.go | 28 +++++++++++++ wrphttp/decoders.go | 53 ++++++++++-------------- 4 files changed, 78 insertions(+), 44 deletions(-) create mode 100644 wrpcontext/contents.go rename wrpcontext/{wrpcontext.go => context.go} (70%) create mode 100644 wrpcontext/message.go diff --git a/wrpcontext/contents.go b/wrpcontext/contents.go new file mode 100644 index 0000000..8ab9548 --- /dev/null +++ b/wrpcontext/contents.go @@ -0,0 +1,26 @@ +// SPDX-FileCopyrightText: 2022 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package wrpcontext + +import ( + "context" +) + +type contextContentsKey struct{} + +// Set provides a standard way to add a wrp message to a context.Context. This supports not only wrp.Message +// but also all the other message types, such as wrp.SimpleRequestResponse +func SetContents(ctx context.Context, b []byte) context.Context { + return context.WithValue(ctx, contextContentsKey{}, b) +} + +// Get a message from a context and return it as type T +func GetContents(ctx context.Context) ([]byte, bool) { + src := ctx.Value(contextContentsKey{}) + if src == nil { + return []byte{}, false + } + + return get[[]byte](ctx, src) +} diff --git a/wrpcontext/wrpcontext.go b/wrpcontext/context.go similarity index 70% rename from wrpcontext/wrpcontext.go rename to wrpcontext/context.go index 9b2ccca..e5f025d 100644 --- a/wrpcontext/wrpcontext.go +++ b/wrpcontext/context.go @@ -8,21 +8,8 @@ import ( "reflect" ) -type contextKey struct{} - -// Set provides a standard way to add a wrp message to a context.Context. This supports not only wrp.Message -// but also all the other message types, such as wrp.SimpleRequestResponse -func Set(ctx context.Context, msg any) context.Context { - return context.WithValue(ctx, contextKey{}, msg) -} - // Get a message from a context and return it as type T -func Get[T any](ctx context.Context) (dest T, ok bool) { - src := ctx.Value(contextKey{}) - if src == nil { - return - } - +func get[T any](ctx context.Context, src any) (dest T, ok bool) { // if src and dest are the exact same type if dest, ok = src.(T); ok { return diff --git a/wrpcontext/message.go b/wrpcontext/message.go new file mode 100644 index 0000000..d5358a6 --- /dev/null +++ b/wrpcontext/message.go @@ -0,0 +1,28 @@ +// SPDX-FileCopyrightText: 2022 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package wrpcontext + +import ( + "context" + + "github.com/xmidt-org/wrp-go/v3" +) + +type contextWRPMessageKey struct{} + +// Set provides a standard way to add a wrp message to a context.Context. This supports not only wrp.Message +// but also all the other message types, such as wrp.SimpleRequestResponse +func SetMessage(ctx context.Context, msg any) context.Context { + return context.WithValue(ctx, contextWRPMessageKey{}, msg) +} + +// Get a message from a context and return it as type T +func GetMessage(ctx context.Context) (*wrp.Message, bool) { + src := ctx.Value(contextWRPMessageKey{}) + if src == nil { + return nil, false + } + + return get[*wrp.Message](ctx, src) +} diff --git a/wrphttp/decoders.go b/wrphttp/decoders.go index fd1713d..1c9fdb9 100644 --- a/wrphttp/decoders.go +++ b/wrphttp/decoders.go @@ -5,7 +5,6 @@ package wrphttp import ( "context" - "encoding/json" "fmt" "io" "net/http" @@ -29,7 +28,6 @@ func DefaultDecoder() Decoder { func DecodeEntity(defaultFormat wrp.Format) Decoder { return func(ctx context.Context, original *http.Request) (*Entity, error) { - format, err := DetermineFormat(defaultFormat, original.Header, "Content-Type") if err != nil { return nil, fmt.Errorf("failed to determine format of Content-Type header: %v", err) @@ -40,35 +38,25 @@ func DecodeEntity(defaultFormat wrp.Format) Decoder { return nil, fmt.Errorf("failed to determine format of Accept header: %v", err) } - // Check if the context already contains a message - // If so, return the original request's message as an entity - msg, ok := wrpcontext.Get[*wrp.Message](original.Context()) - if ok { - jsonBytes, err := json.Marshal(msg) + entity := &Entity{Format: format} + if contents, ok := wrpcontext.GetContents(original.Context()); ok { + entity.Bytes = contents + } else { + contents, err := io.ReadAll(original.Body) if err != nil { - return nil, err - } - entity := &Entity{ - Message: *msg, - Format: format, - Bytes: jsonBytes, + return nil, fmt.Errorf("failed to read request body: %v", err) } - return entity, nil - } - - contents, err := io.ReadAll(original.Body) - if err != nil { - return nil, err - } - entity := &Entity{ - Format: format, - Bytes: contents, + entity.Bytes = contents } - err = wrp.NewDecoderBytes(contents, format).Decode(&entity.Message) - if err != nil { - return nil, fmt.Errorf("failed to decode wrp: %v", err) + if msg, ok := wrpcontext.GetMessage(original.Context()); ok { + entity.Message = *msg + } else { + err = wrp.NewDecoderBytes(entity.Bytes, format).Decode(&entity.Message) + if err != nil { + return nil, fmt.Errorf("failed to decode wrp: %v", err) + } } return entity, err @@ -110,7 +98,7 @@ func DecodeRequestHeaders(ctx context.Context, original *http.Request) (*Entity, // Can work for servers that don't use a wrp.Handler func DecodeRequest(r *http.Request, msg any) (*http.Request, error) { - if _, ok := wrpcontext.Get[*wrp.Message](r.Context()); ok { + if _, ok := wrpcontext.GetMessage(r.Context()); ok { // Context already contains a message, so just return the original request return r, nil } @@ -122,18 +110,23 @@ func DecodeRequest(r *http.Request, msg any) (*http.Request, error) { var decodedMessage wrp.Message + contents, err := io.ReadAll(r.Body) + if err != nil { + return nil, fmt.Errorf("failed to read request body: %v", err) + } + // Try to decode the message using the HTTP Request headers // If this doesn't work, decode the message as Msgpack or JSON format if err = SetMessageFromHeaders(r.Header, &decodedMessage); err != nil { // Msgpack or JSON Format - bodyReader := r.Body - err = wrp.NewDecoder(bodyReader, format).Decode(&decodedMessage) + err = wrp.NewDecoderBytes(contents, format).Decode(&decodedMessage) if err != nil { return nil, fmt.Errorf("failed to decode wrp message: %v", err) } } - ctx := wrpcontext.Set(r.Context(), &decodedMessage) + ctx := wrpcontext.SetMessage(r.Context(), &decodedMessage) + ctx = wrpcontext.SetContents(ctx, contents) // Return a new request with the new context, containing the decoded message return r.WithContext(ctx), nil From 8ade81d43b143c31784e614bc161a862a987a5b0 Mon Sep 17 00:00:00 2001 From: Owen Cabalceta Date: Fri, 27 Oct 2023 16:51:58 -0400 Subject: [PATCH 2/2] chore: patch decoder test --- wrphttp/decoders.go | 2 +- wrphttp/decoders_test.go | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/wrphttp/decoders.go b/wrphttp/decoders.go index 1c9fdb9..e4081f6 100644 --- a/wrphttp/decoders.go +++ b/wrphttp/decoders.go @@ -96,7 +96,7 @@ func DecodeRequestHeaders(ctx context.Context, original *http.Request) (*Entity, // DecodeRequest is a Decoder that provides lower-level way of decoding an *http.Request // Can work for servers that don't use a wrp.Handler -func DecodeRequest(r *http.Request, msg any) (*http.Request, error) { +func DecodeRequest(r *http.Request, _ any) (*http.Request, error) { if _, ok := wrpcontext.GetMessage(r.Context()); ok { // Context already contains a message, so just return the original request diff --git a/wrphttp/decoders_test.go b/wrphttp/decoders_test.go index 6020a65..04c596a 100644 --- a/wrphttp/decoders_test.go +++ b/wrphttp/decoders_test.go @@ -234,7 +234,7 @@ func testDecodeRequestSuccess(t *testing.T) { ) require.NoError( - wrp.NewEncoderBytes(&body, record.bodyFormat).Encode(&expected), + wrp.NewEncoderBytes(&body, record.bodyFormat).Encode(expected), ) request := httptest.NewRequest("POST", "/", bytes.NewBuffer(body)) @@ -246,15 +246,14 @@ func testDecodeRequestSuccess(t *testing.T) { expected.Type = record.msgType } - var msg wrp.Message - actual, err := DecodeRequest(request, &msg) - msg, ok := wrpcontext.Get[wrp.Message](actual.Context()) + actual, err := DecodeRequest(request, nil) + msg, ok := wrpcontext.GetMessage(actual.Context()) assert.True(ok) assert.Nil(err) require.NotNil(actual) require.NotNil(actual.Context()) - assert.Equal(expected, &msg) + assert.Equal(expected, msg) } }