Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

patch: DecodeEntity doesn't honor contentType #142

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions wrpcontext/contents.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// SPDX-FileCopyrightText: 2022 Comcast Cable Communications Management, LLC
// SPDX-License-Identifier: Apache-2.0

package wrpcontext

import (
"context"
)

type contextContentsKey struct{}

// Set provides a standard way to add a wrp message to a context.Context. This supports not only wrp.Message
// but also all the other message types, such as wrp.SimpleRequestResponse
func SetContents(ctx context.Context, b []byte) context.Context {
return context.WithValue(ctx, contextContentsKey{}, b)
}

// Get a message from a context and return it as type T
func GetContents(ctx context.Context) ([]byte, bool) {
src := ctx.Value(contextContentsKey{})
if src == nil {
return []byte{}, false
}

return get[[]byte](ctx, src)
}
15 changes: 1 addition & 14 deletions wrpcontext/wrpcontext.go → wrpcontext/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,8 @@ import (
"reflect"
)

type contextKey struct{}

// Set provides a standard way to add a wrp message to a context.Context. This supports not only wrp.Message
// but also all the other message types, such as wrp.SimpleRequestResponse
func Set(ctx context.Context, msg any) context.Context {
return context.WithValue(ctx, contextKey{}, msg)
}

// Get a message from a context and return it as type T
func Get[T any](ctx context.Context) (dest T, ok bool) {
src := ctx.Value(contextKey{})
if src == nil {
return
}

func get[T any](ctx context.Context, src any) (dest T, ok bool) {
// if src and dest are the exact same type
if dest, ok = src.(T); ok {
return
Expand Down
28 changes: 28 additions & 0 deletions wrpcontext/message.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// SPDX-FileCopyrightText: 2022 Comcast Cable Communications Management, LLC
// SPDX-License-Identifier: Apache-2.0

package wrpcontext

import (
"context"

"github.com/xmidt-org/wrp-go/v3"
)

type contextWRPMessageKey struct{}

// Set provides a standard way to add a wrp message to a context.Context. This supports not only wrp.Message
// but also all the other message types, such as wrp.SimpleRequestResponse
func SetMessage(ctx context.Context, msg any) context.Context {
return context.WithValue(ctx, contextWRPMessageKey{}, msg)
}

// Get a message from a context and return it as type T
func GetMessage(ctx context.Context) (*wrp.Message, bool) {
src := ctx.Value(contextWRPMessageKey{})
if src == nil {
return nil, false
}

return get[*wrp.Message](ctx, src)
}
55 changes: 24 additions & 31 deletions wrphttp/decoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
Expand All @@ -29,7 +28,6 @@

func DecodeEntity(defaultFormat wrp.Format) Decoder {
return func(ctx context.Context, original *http.Request) (*Entity, error) {

format, err := DetermineFormat(defaultFormat, original.Header, "Content-Type")
if err != nil {
return nil, fmt.Errorf("failed to determine format of Content-Type header: %v", err)
Expand All @@ -40,35 +38,25 @@
return nil, fmt.Errorf("failed to determine format of Accept header: %v", err)
}

// Check if the context already contains a message
// If so, return the original request's message as an entity
msg, ok := wrpcontext.Get[*wrp.Message](original.Context())
if ok {
jsonBytes, err := json.Marshal(msg)
entity := &Entity{Format: format}
if contents, ok := wrpcontext.GetContents(original.Context()); ok {
entity.Bytes = contents

Check warning on line 43 in wrphttp/decoders.go

View check run for this annotation

Codecov / codecov/patch

wrphttp/decoders.go#L43

Added line #L43 was not covered by tests
} else {
contents, err := io.ReadAll(original.Body)
if err != nil {
return nil, err
}
entity := &Entity{
Message: *msg,
Format: format,
Bytes: jsonBytes,
return nil, fmt.Errorf("failed to read request body: %v", err)

Check warning on line 47 in wrphttp/decoders.go

View check run for this annotation

Codecov / codecov/patch

wrphttp/decoders.go#L47

Added line #L47 was not covered by tests
}
return entity, nil
}

contents, err := io.ReadAll(original.Body)
if err != nil {
return nil, err
}

entity := &Entity{
Format: format,
Bytes: contents,
entity.Bytes = contents
}

err = wrp.NewDecoderBytes(contents, format).Decode(&entity.Message)
if err != nil {
return nil, fmt.Errorf("failed to decode wrp: %v", err)
if msg, ok := wrpcontext.GetMessage(original.Context()); ok {
entity.Message = *msg

Check warning on line 54 in wrphttp/decoders.go

View check run for this annotation

Codecov / codecov/patch

wrphttp/decoders.go#L54

Added line #L54 was not covered by tests
} else {
err = wrp.NewDecoderBytes(entity.Bytes, format).Decode(&entity.Message)
if err != nil {
return nil, fmt.Errorf("failed to decode wrp: %v", err)
}
}

return entity, err
Expand Down Expand Up @@ -108,9 +96,9 @@

// DecodeRequest is a Decoder that provides lower-level way of decoding an *http.Request
// Can work for servers that don't use a wrp.Handler
func DecodeRequest(r *http.Request, msg any) (*http.Request, error) {
func DecodeRequest(r *http.Request, _ any) (*http.Request, error) {

if _, ok := wrpcontext.Get[*wrp.Message](r.Context()); ok {
if _, ok := wrpcontext.GetMessage(r.Context()); ok {
// Context already contains a message, so just return the original request
return r, nil
}
Expand All @@ -122,18 +110,23 @@

var decodedMessage wrp.Message

contents, err := io.ReadAll(r.Body)
if err != nil {
return nil, fmt.Errorf("failed to read request body: %v", err)
}

Check warning on line 116 in wrphttp/decoders.go

View check run for this annotation

Codecov / codecov/patch

wrphttp/decoders.go#L115-L116

Added lines #L115 - L116 were not covered by tests

// Try to decode the message using the HTTP Request headers
// If this doesn't work, decode the message as Msgpack or JSON format
if err = SetMessageFromHeaders(r.Header, &decodedMessage); err != nil {
// Msgpack or JSON Format
bodyReader := r.Body
err = wrp.NewDecoder(bodyReader, format).Decode(&decodedMessage)
err = wrp.NewDecoderBytes(contents, format).Decode(&decodedMessage)
if err != nil {
return nil, fmt.Errorf("failed to decode wrp message: %v", err)
}
}

ctx := wrpcontext.Set(r.Context(), &decodedMessage)
ctx := wrpcontext.SetMessage(r.Context(), &decodedMessage)
ctx = wrpcontext.SetContents(ctx, contents)

// Return a new request with the new context, containing the decoded message
return r.WithContext(ctx), nil
Expand Down
9 changes: 4 additions & 5 deletions wrphttp/decoders_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ func testDecodeRequestSuccess(t *testing.T) {
)

require.NoError(
wrp.NewEncoderBytes(&body, record.bodyFormat).Encode(&expected),
wrp.NewEncoderBytes(&body, record.bodyFormat).Encode(expected),
)

request := httptest.NewRequest("POST", "/", bytes.NewBuffer(body))
Expand All @@ -246,15 +246,14 @@ func testDecodeRequestSuccess(t *testing.T) {
expected.Type = record.msgType
}

var msg wrp.Message
actual, err := DecodeRequest(request, &msg)
msg, ok := wrpcontext.Get[wrp.Message](actual.Context())
actual, err := DecodeRequest(request, nil)
msg, ok := wrpcontext.GetMessage(actual.Context())

assert.True(ok)
assert.Nil(err)
require.NotNil(actual)
require.NotNil(actual.Context())
assert.Equal(expected, &msg)
assert.Equal(expected, msg)

}
}
Expand Down