From 2cb59f97eacc0a1ccc2d14484df2ab75acee5695 Mon Sep 17 00:00:00 2001 From: Luc Talatinian Date: Fri, 6 Oct 2023 11:58:56 -0400 Subject: [PATCH] feat: add http.WithHeaderComment middleware --- .../a33cd9f8b5d5438a8f8140ee4f39c2f7.json | 8 ++ transport/http/middleware_header_comment.go | 81 +++++++++++++ .../http/middleware_header_comment_test.go | 113 ++++++++++++++++++ 3 files changed, 202 insertions(+) create mode 100644 .changelog/a33cd9f8b5d5438a8f8140ee4f39c2f7.json create mode 100644 transport/http/middleware_header_comment.go create mode 100644 transport/http/middleware_header_comment_test.go diff --git a/.changelog/a33cd9f8b5d5438a8f8140ee4f39c2f7.json b/.changelog/a33cd9f8b5d5438a8f8140ee4f39c2f7.json new file mode 100644 index 000000000..4c2a9ee48 --- /dev/null +++ b/.changelog/a33cd9f8b5d5438a8f8140ee4f39c2f7.json @@ -0,0 +1,8 @@ +{ + "id": "a33cd9f8-b5d5-438a-8f81-40ee4f39c2f7", + "type": "feature", + "description": "Add `http.WithHeaderComment` middleware.", + "modules": [ + "." + ] +} \ No newline at end of file diff --git a/transport/http/middleware_header_comment.go b/transport/http/middleware_header_comment.go new file mode 100644 index 000000000..855c22720 --- /dev/null +++ b/transport/http/middleware_header_comment.go @@ -0,0 +1,81 @@ +package http + +import ( + "context" + "fmt" + "net/http" + + "github.com/aws/smithy-go/middleware" +) + +// WithHeaderComment instruments a middleware stack to append an HTTP field +// comment to the given header as specified in RFC 9110 +// (https://www.rfc-editor.org/rfc/rfc9110#name-comments). +// +// The header is case-insensitive. If the provided header exists when the +// middleware runs, the content will be inserted as-is enclosed in parentheses. +// +// Note that per the HTTP specification, comments are only allowed in fields +// containing "comment" as part of their field value definition, but this API +// will NOT verify whether the provided header is one of them. +// +// WithHeaderComment MAY be applied more than once to a middleware stack and/or +// more than once per header. +func WithHeaderComment(header, content string) func(*middleware.Stack) error { + return func(s *middleware.Stack) error { + m, err := getOrAddHeaderComment(s) + if err != nil { + return fmt.Errorf("get or add header comment: %v", err) + } + + m.values.Add(header, content) + return nil + } +} + +type headerCommentMiddleware struct { + values http.Header // hijack case-insensitive access APIs +} + +func (*headerCommentMiddleware) ID() string { + return "headerComment" +} + +func (m *headerCommentMiddleware) HandleBuild(ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler) ( + out middleware.BuildOutput, metadata middleware.Metadata, err error, +) { + r, ok := in.Request.(*Request) + if !ok { + return out, metadata, fmt.Errorf("unknown transport type %T", in.Request) + } + + for h, contents := range m.values { + for _, c := range contents { + if existing := r.Header.Get(h); existing != "" { + r.Header.Set(h, fmt.Sprintf("%s (%s)", existing, c)) + } + } + } + + return next.HandleBuild(ctx, in) +} + +func getOrAddHeaderComment(s *middleware.Stack) (*headerCommentMiddleware, error) { + id := (*headerCommentMiddleware)(nil).ID() + m, ok := s.Build.Get(id) + if !ok { + m := &headerCommentMiddleware{values: http.Header{}} + if err := s.Build.Add(m, middleware.After); err != nil { + return nil, fmt.Errorf("add build: %v", err) + } + + return m, nil + } + + hc, ok := m.(*headerCommentMiddleware) + if !ok { + return nil, fmt.Errorf("existing middleware w/ id %s is not *headerCommentMiddleware", id) + } + + return hc, nil +} diff --git a/transport/http/middleware_header_comment_test.go b/transport/http/middleware_header_comment_test.go new file mode 100644 index 000000000..6812c0b28 --- /dev/null +++ b/transport/http/middleware_header_comment_test.go @@ -0,0 +1,113 @@ +package http + +import ( + "context" + "net/http" + "testing" + + "github.com/aws/smithy-go/middleware" +) + +func TestWithHeaderComment_CaseInsensitive(t *testing.T) { + stack, err := newTestStack( + WithHeaderComment("foo", "bar"), + ) + if err != nil { + t.Errorf("expected no error on new stack, got %v", err) + } + + r := injectBuildRequest(stack) + r.Header.Set("Foo", "baz") + + if err := handle(stack); err != nil { + t.Errorf("expected no error on handle, got %v", err) + } + + expectHeader(t, r.Header, "Foo", "baz (bar)") +} + +func TestWithHeaderComment_Noop(t *testing.T) { + stack, err := newTestStack( + WithHeaderComment("foo", "bar"), + ) + if err != nil { + t.Errorf("expected no error on new stack, got %v", err) + } + + r := injectBuildRequest(stack) + + if err := handle(stack); err != nil { + t.Errorf("expected no error on handle, got %v", err) + } + + expectHeader(t, r.Header, "Foo", "") +} + +func TestWithHeaderComment_MultiCaseInsensitive(t *testing.T) { + stack, err := newTestStack( + WithHeaderComment("foo", "c1"), + WithHeaderComment("Foo", "c2"), + WithHeaderComment("baz", "c3"), + WithHeaderComment("Baz", "c4"), + ) + if err != nil { + t.Errorf("expected no error on new stack, got %v", err) + } + + r := injectBuildRequest(stack) + r.Header.Set("Foo", "1") + r.Header.Set("Baz", "2") + + if err := handle(stack); err != nil { + t.Errorf("expected no error on handle, got %v", err) + } + + expectHeader(t, r.Header, "Foo", "1 (c1) (c2)") + expectHeader(t, r.Header, "Baz", "2 (c3) (c4)") +} + +func newTestStack(fns ...func(*middleware.Stack) error) (*middleware.Stack, error) { + s := middleware.NewStack("", NewStackRequest) + for _, fn := range fns { + if err := fn(s); err != nil { + return nil, err + } + } + return s, nil +} + +func handle(stack *middleware.Stack) error { + _, _, err := middleware.DecorateHandler( + middleware.HandlerFunc( + func(ctx context.Context, input interface{}) ( + interface{}, middleware.Metadata, error, + ) { + return nil, middleware.Metadata{}, nil + }, + ), + stack, + ).Handle(context.Background(), nil) + return err +} + +func injectBuildRequest(s *middleware.Stack) *Request { + r := NewStackRequest() + s.Build.Add( + middleware.BuildMiddlewareFunc( + "injectBuildRequest", + func(ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler) ( + middleware.BuildOutput, middleware.Metadata, error, + ) { + return next.HandleBuild(ctx, middleware.BuildInput{Request: r}) + }, + ), + middleware.Before, + ) + return r.(*Request) +} + +func expectHeader(t *testing.T, header http.Header, h, ev string) { + if av := header.Get(h); ev != av { + t.Errorf("expected header '%s: %s', got '%s'", h, ev, av) + } +}