diff --git a/.changes/next-release/sdk-feature-1615946574682127000.json b/.changes/next-release/sdk-feature-1615946574682127000.json new file mode 100644 index 00000000000..1e0ae796a3a --- /dev/null +++ b/.changes/next-release/sdk-feature-1615946574682127000.json @@ -0,0 +1,9 @@ +{ + "ID": "sdk-feature-1615946574682127000", + "SchemaVersion": 1, + "Module": "/", + "Type": "feature", + "Description": "Add helper to V4 signer package to swap compute payload hash middleware with unsigned payload middleware", + "MinVersion": "", + "AffectedModules": null +} \ No newline at end of file diff --git a/aws/signer/v4/middleware.go b/aws/signer/v4/middleware.go index ffa297668e1..b6e28b4bd46 100644 --- a/aws/signer/v4/middleware.go +++ b/aws/signer/v4/middleware.go @@ -156,6 +156,16 @@ func (m *computePayloadSHA256) HandleBuild( return next.HandleBuild(ctx, in) } +// SwapComputePayloadSHA256ForUnsignedPayloadMiddleware replaces the +// ComputePayloadSHA256 middleware with the UnsignedPayload middleware. +// +// Use this to disable computing the Payload SHA256 checksum and instead use +// UNSIGNED-PAYLOAD for the SHA256 value. +func SwapComputePayloadSHA256ForUnsignedPayloadMiddleware(stack *middleware.Stack) error { + _, err := stack.Build.Swap(computePayloadHashMiddlewareID, &unsignedPayload{}) + return err +} + // contentSHA256Header sets the X-Amz-Content-Sha256 header value to // the Payload hash stored in the context. type contentSHA256Header struct{} diff --git a/aws/signer/v4/middleware_test.go b/aws/signer/v4/middleware_test.go index a56058a86c0..acbc3d827b3 100644 --- a/aws/signer/v4/middleware_test.go +++ b/aws/signer/v4/middleware_test.go @@ -18,6 +18,7 @@ import ( "github.com/aws/smithy-go/logging" "github.com/aws/smithy-go/middleware" smithyhttp "github.com/aws/smithy-go/transport/http" + "github.com/google/go-cmp/cmp" ) func TestComputePayloadHashMiddleware(t *testing.T) { @@ -205,6 +206,106 @@ func TestSignHTTPRequestMiddleware(t *testing.T) { } } +func TestSwapComputePayloadSHA256ForUnsignedPayloadMiddleware(t *testing.T) { + cases := map[string]struct { + InitStep func(*middleware.Stack) error + Mutator func(*middleware.Stack) error + ExpectErr string + ExpectIDs []string + }{ + "swap in place": { + InitStep: func(s *middleware.Stack) (err error) { + err = s.Build.Add(middleware.BuildMiddlewareFunc("before", nil), middleware.After) + if err != nil { + return err + } + err = AddComputePayloadSHA256Middleware(s) + if err != nil { + return err + } + err = s.Build.Add(middleware.BuildMiddlewareFunc("after", nil), middleware.After) + if err != nil { + return err + } + return nil + }, + Mutator: SwapComputePayloadSHA256ForUnsignedPayloadMiddleware, + ExpectIDs: []string{ + "before", + computePayloadHashMiddlewareID, + "after", + }, + }, + + "already unsigned payload exists": { + InitStep: func(s *middleware.Stack) (err error) { + err = s.Build.Add(middleware.BuildMiddlewareFunc("before", nil), middleware.After) + if err != nil { + return err + } + err = AddUnsignedPayloadMiddleware(s) + if err != nil { + return err + } + err = s.Build.Add(middleware.BuildMiddlewareFunc("after", nil), middleware.After) + if err != nil { + return err + } + return nil + }, + Mutator: SwapComputePayloadSHA256ForUnsignedPayloadMiddleware, + ExpectIDs: []string{ + "before", + computePayloadHashMiddlewareID, + "after", + }, + }, + + "no compute payload": { + InitStep: func(s *middleware.Stack) (err error) { + err = s.Build.Add(middleware.BuildMiddlewareFunc("before", nil), middleware.After) + if err != nil { + return err + } + err = s.Build.Add(middleware.BuildMiddlewareFunc("after", nil), middleware.After) + if err != nil { + return err + } + return nil + }, + Mutator: SwapComputePayloadSHA256ForUnsignedPayloadMiddleware, + ExpectErr: "not found, " + computePayloadHashMiddlewareID, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + stack := middleware.NewStack(t.Name(), smithyhttp.NewStackRequest) + if err := c.InitStep(stack); err != nil { + t.Fatalf("expect no error, got %v", err) + } + + err := c.Mutator(stack) + if len(c.ExpectErr) != 0 { + if err == nil { + t.Fatalf("expect error, got none") + } + if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect error to contain %v, got %v", e, a) + } + return + } + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if diff := cmp.Diff(c.ExpectIDs, stack.Build.List()); len(diff) != 0 { + t.Errorf("expect match\n%v", diff) + } + }) + } +} + type nonSeeker struct{} func (nonSeeker) Read(p []byte) (n int, err error) {