Skip to content

Commit

Permalink
Merge pull request #142 from xmidt-org/denopink/patch/wrp-DecodeEntit…
Browse files Browse the repository at this point in the history
…y-doesnt-honor-contentType

patch: DecodeEntity doesn't honor contentType
  • Loading branch information
denopink authored Oct 30, 2023
2 parents e3db47f + 8ade81d commit 3460c43
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 50 deletions.
26 changes: 26 additions & 0 deletions wrpcontext/contents.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// SPDX-FileCopyrightText: 2022 Comcast Cable Communications Management, LLC
// SPDX-License-Identifier: Apache-2.0

package wrpcontext

import (
"context"
)

type contextContentsKey struct{}

// Set provides a standard way to add a wrp message to a context.Context. This supports not only wrp.Message
// but also all the other message types, such as wrp.SimpleRequestResponse
func SetContents(ctx context.Context, b []byte) context.Context {
return context.WithValue(ctx, contextContentsKey{}, b)
}

// Get a message from a context and return it as type T
func GetContents(ctx context.Context) ([]byte, bool) {
src := ctx.Value(contextContentsKey{})
if src == nil {
return []byte{}, false
}

return get[[]byte](ctx, src)
}
15 changes: 1 addition & 14 deletions wrpcontext/wrpcontext.go → wrpcontext/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,8 @@ import (
"reflect"
)

type contextKey struct{}

// Set provides a standard way to add a wrp message to a context.Context. This supports not only wrp.Message
// but also all the other message types, such as wrp.SimpleRequestResponse
func Set(ctx context.Context, msg any) context.Context {
return context.WithValue(ctx, contextKey{}, msg)
}

// Get a message from a context and return it as type T
func Get[T any](ctx context.Context) (dest T, ok bool) {
src := ctx.Value(contextKey{})
if src == nil {
return
}

func get[T any](ctx context.Context, src any) (dest T, ok bool) {
// if src and dest are the exact same type
if dest, ok = src.(T); ok {
return
Expand Down
28 changes: 28 additions & 0 deletions wrpcontext/message.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// SPDX-FileCopyrightText: 2022 Comcast Cable Communications Management, LLC
// SPDX-License-Identifier: Apache-2.0

package wrpcontext

import (
"context"

"github.com/xmidt-org/wrp-go/v3"
)

type contextWRPMessageKey struct{}

// Set provides a standard way to add a wrp message to a context.Context. This supports not only wrp.Message
// but also all the other message types, such as wrp.SimpleRequestResponse
func SetMessage(ctx context.Context, msg any) context.Context {
return context.WithValue(ctx, contextWRPMessageKey{}, msg)
}

// Get a message from a context and return it as type T
func GetMessage(ctx context.Context) (*wrp.Message, bool) {
src := ctx.Value(contextWRPMessageKey{})
if src == nil {
return nil, false
}

return get[*wrp.Message](ctx, src)
}
55 changes: 24 additions & 31 deletions wrphttp/decoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package wrphttp

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
Expand All @@ -29,7 +28,6 @@ func DefaultDecoder() Decoder {

func DecodeEntity(defaultFormat wrp.Format) Decoder {
return func(ctx context.Context, original *http.Request) (*Entity, error) {

format, err := DetermineFormat(defaultFormat, original.Header, "Content-Type")
if err != nil {
return nil, fmt.Errorf("failed to determine format of Content-Type header: %v", err)
Expand All @@ -40,35 +38,25 @@ func DecodeEntity(defaultFormat wrp.Format) Decoder {
return nil, fmt.Errorf("failed to determine format of Accept header: %v", err)
}

// Check if the context already contains a message
// If so, return the original request's message as an entity
msg, ok := wrpcontext.Get[*wrp.Message](original.Context())
if ok {
jsonBytes, err := json.Marshal(msg)
entity := &Entity{Format: format}
if contents, ok := wrpcontext.GetContents(original.Context()); ok {
entity.Bytes = contents
} else {
contents, err := io.ReadAll(original.Body)
if err != nil {
return nil, err
}
entity := &Entity{
Message: *msg,
Format: format,
Bytes: jsonBytes,
return nil, fmt.Errorf("failed to read request body: %v", err)
}
return entity, nil
}

contents, err := io.ReadAll(original.Body)
if err != nil {
return nil, err
}

entity := &Entity{
Format: format,
Bytes: contents,
entity.Bytes = contents
}

err = wrp.NewDecoderBytes(contents, format).Decode(&entity.Message)
if err != nil {
return nil, fmt.Errorf("failed to decode wrp: %v", err)
if msg, ok := wrpcontext.GetMessage(original.Context()); ok {
entity.Message = *msg
} else {
err = wrp.NewDecoderBytes(entity.Bytes, format).Decode(&entity.Message)
if err != nil {
return nil, fmt.Errorf("failed to decode wrp: %v", err)
}
}

return entity, err
Expand Down Expand Up @@ -108,9 +96,9 @@ func DecodeRequestHeaders(ctx context.Context, original *http.Request) (*Entity,

// DecodeRequest is a Decoder that provides lower-level way of decoding an *http.Request
// Can work for servers that don't use a wrp.Handler
func DecodeRequest(r *http.Request, msg any) (*http.Request, error) {
func DecodeRequest(r *http.Request, _ any) (*http.Request, error) {

if _, ok := wrpcontext.Get[*wrp.Message](r.Context()); ok {
if _, ok := wrpcontext.GetMessage(r.Context()); ok {
// Context already contains a message, so just return the original request
return r, nil
}
Expand All @@ -122,18 +110,23 @@ func DecodeRequest(r *http.Request, msg any) (*http.Request, error) {

var decodedMessage wrp.Message

contents, err := io.ReadAll(r.Body)
if err != nil {
return nil, fmt.Errorf("failed to read request body: %v", err)
}

// Try to decode the message using the HTTP Request headers
// If this doesn't work, decode the message as Msgpack or JSON format
if err = SetMessageFromHeaders(r.Header, &decodedMessage); err != nil {
// Msgpack or JSON Format
bodyReader := r.Body
err = wrp.NewDecoder(bodyReader, format).Decode(&decodedMessage)
err = wrp.NewDecoderBytes(contents, format).Decode(&decodedMessage)
if err != nil {
return nil, fmt.Errorf("failed to decode wrp message: %v", err)
}
}

ctx := wrpcontext.Set(r.Context(), &decodedMessage)
ctx := wrpcontext.SetMessage(r.Context(), &decodedMessage)
ctx = wrpcontext.SetContents(ctx, contents)

// Return a new request with the new context, containing the decoded message
return r.WithContext(ctx), nil
Expand Down
9 changes: 4 additions & 5 deletions wrphttp/decoders_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ func testDecodeRequestSuccess(t *testing.T) {
)

require.NoError(
wrp.NewEncoderBytes(&body, record.bodyFormat).Encode(&expected),
wrp.NewEncoderBytes(&body, record.bodyFormat).Encode(expected),
)

request := httptest.NewRequest("POST", "/", bytes.NewBuffer(body))
Expand All @@ -246,15 +246,14 @@ func testDecodeRequestSuccess(t *testing.T) {
expected.Type = record.msgType
}

var msg wrp.Message
actual, err := DecodeRequest(request, &msg)
msg, ok := wrpcontext.Get[wrp.Message](actual.Context())
actual, err := DecodeRequest(request, nil)
msg, ok := wrpcontext.GetMessage(actual.Context())

assert.True(ok)
assert.Nil(err)
require.NotNil(actual)
require.NotNil(actual.Context())
assert.Equal(expected, &msg)
assert.Equal(expected, msg)

}
}
Expand Down

0 comments on commit 3460c43

Please sign in to comment.