From 6d365fb33d3a9b6f7fd3bd0153beacc1a507649a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luis=20Madrigal=20=F0=9F=90=A7?= Date: Thu, 6 Jun 2024 15:20:03 -0400 Subject: [PATCH] address feedback from pr, mostly moving things around --- aws/middleware/middleware.go | 56 ---- aws/middleware/middleware_test.go | 240 ----------------- aws/retry/middleware.go | 30 ++- aws/retry/middleware_test.go | 4 +- aws/retry/retryable_error.go | 10 +- aws/retry/standard.go | 3 - .../aws/go/codegen/AwsGoDependency.java | 1 + .../aws/go/codegen/ClockSkewGenerator.java | 121 +++++---- internal/auth/smithy/v4signer_adapter.go | 4 +- internal/context/context.go | 13 + internal/middleware/middleware.go | 42 +++ internal/middleware/middleware_test.go | 254 ++++++++++++++++++ internal/v4a/smithy.go | 7 +- 13 files changed, 407 insertions(+), 378 deletions(-) create mode 100644 internal/middleware/middleware.go create mode 100644 internal/middleware/middleware_test.go diff --git a/aws/middleware/middleware.go b/aws/middleware/middleware.go index 60e88d7fb65..6d5f0079c2f 100644 --- a/aws/middleware/middleware.go +++ b/aws/middleware/middleware.go @@ -3,7 +3,6 @@ package middleware import ( "context" "fmt" - "sync/atomic" "time" "github.com/aws/aws-sdk-go-v2/internal/rand" @@ -125,19 +124,6 @@ func setAttemptSkew(metadata *middleware.Metadata, v time.Duration) { metadata.Set(attemptSkewKey{}, v) } -type clockSkew struct{} - -// SetAttemptSkewContext sets the clock skew value on the context -func SetAttemptSkewContext(ctx context.Context, v time.Duration) context.Context { - return middleware.WithStackValue(ctx, clockSkew{}, v) -} - -// GetAttemptSkewContext gets the clock skew value from the context -func GetAttemptSkewContext(ctx context.Context) time.Duration { - x, _ := middleware.GetStackValue(ctx, clockSkew{}).(time.Duration) - return x -} - // AddClientRequestIDMiddleware adds ClientRequestID to the middleware stack func AddClientRequestIDMiddleware(stack *middleware.Stack) error { return stack.Build.Add(&ClientRequestID{}, middleware.After) @@ -180,45 +166,3 @@ func AddRawResponseToMetadata(stack *middleware.Stack) error { func GetRawResponse(metadata middleware.Metadata) interface{} { return metadata.Get(rawResponseKey{}) } - -// AddTimeOffsetBuildMiddleware sets a value representing clock skew on the request context. -// This can be read by other operations (such as signing) to correct the date value they send -// on the request -type AddTimeOffsetBuildMiddleware struct { - Offset *atomic.Int64 -} - -// ID the identifier for AddTimeOffsetBuildMiddleware -func (m *AddTimeOffsetBuildMiddleware) ID() string { return "AddTimeOffsetMiddleware" } - -// HandleBuild sets a value for attemptSkew on the request context if one is set on the client. -func (m AddTimeOffsetBuildMiddleware) HandleBuild(ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler) ( - out middleware.BuildOutput, metadata middleware.Metadata, err error, -) { - if m.Offset != nil { - offset := time.Duration(m.Offset.Load()) - ctx = SetAttemptSkewContext(ctx, offset) - } - return next.HandleBuild(ctx, in) -} - -// AddTimeOffsetDeserializeMiddleware sets the clock skew on the client if it's present on the context -// at the end of the request -type AddTimeOffsetDeserializeMiddleware struct { - Offset *atomic.Int64 -} - -// ID the identifier for AddTimeOffsetDeserializeMiddleware -func (m *AddTimeOffsetDeserializeMiddleware) ID() string { return "AddTimeOffsetDeserializeMiddleware" } - -// HandleDeserialize gets the clock skew context from the context, and if set, sets it on the pointer -// held by AddTimeOffsetDeserializeMiddleware -func (m *AddTimeOffsetDeserializeMiddleware) HandleDeserialize(ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler) ( - out middleware.DeserializeOutput, metadata middleware.Metadata, err error, -) { - v := GetAttemptSkewContext(ctx) - if v != 0 { - m.Offset.Store(v.Nanoseconds()) - } - return next.HandleDeserialize(ctx, in) -} diff --git a/aws/middleware/middleware_test.go b/aws/middleware/middleware_test.go index 548671a339f..e4a69c9c22a 100644 --- a/aws/middleware/middleware_test.go +++ b/aws/middleware/middleware_test.go @@ -3,13 +3,9 @@ package middleware_test import ( "bytes" "context" - "fmt" - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/aws/retry" "net/http" "reflect" "strings" - "sync/atomic" "testing" "time" @@ -191,239 +187,3 @@ func TestAttemptClockSkewHandler(t *testing.T) { }) } } - -type HTTPClient interface { - Do(*http.Request) (*http.Response, error) -} - -type Options struct { - HTTPClient HTTPClient - RetryMode aws.RetryMode - Retryer aws.Retryer - Offset *atomic.Int64 -} - -type MockClient struct { - options Options -} - -func addRetry(stack *smithymiddleware.Stack, o Options) error { - attempt := retry.NewAttemptMiddleware(o.Retryer, smithyhttp.RequestCloner, func(m *retry.Attempt) { - m.LogAttempts = false - }) - return stack.Finalize.Add(attempt, smithymiddleware.After) -} - -func addOffset(stack *smithymiddleware.Stack, o Options) error { - buildOffset := middleware.AddTimeOffsetBuildMiddleware{Offset: o.Offset} - deserializeOffset := middleware.AddTimeOffsetDeserializeMiddleware{Offset: o.Offset} - err := stack.Build.Add(&buildOffset, smithymiddleware.After) - if err != nil { - return err - } - err = stack.Deserialize.Add(&deserializeOffset, smithymiddleware.Before) - if err != nil { - return err - } - return nil -} - -// Middleware to set a `Date` object that includes sdk time and offset -type MockAddDateHeader struct { -} - -func (l *MockAddDateHeader) ID() string { - return "MockAddDateHeader" -} - -func (l *MockAddDateHeader) HandleFinalize( - ctx context.Context, in smithymiddleware.FinalizeInput, next smithymiddleware.FinalizeHandler, -) ( - out smithymiddleware.FinalizeOutput, metadata smithymiddleware.Metadata, attemptError error, -) { - req := in.Request.(*smithyhttp.Request) - date := sdk.NowTime() - skew := middleware.GetAttemptSkewContext(ctx) - date = date.Add(skew) - req.Header.Set("Date", date.Format(time.RFC850)) - return next.HandleFinalize(ctx, in) -} - -// Middleware to deserialize the response which just says "OK" if the response is 200 -type DeserializeFailIfNotHTTP200 struct { -} - -func (*DeserializeFailIfNotHTTP200) ID() string { - return "DeserializeFailIfNotHTTP200" -} - -func (m *DeserializeFailIfNotHTTP200) HandleDeserialize(ctx context.Context, in smithymiddleware.DeserializeInput, next smithymiddleware.DeserializeHandler) ( - out smithymiddleware.DeserializeOutput, metadata smithymiddleware.Metadata, err error, -) { - out, metadata, err = next.HandleDeserialize(ctx, in) - if err != nil { - return out, metadata, err - } - response, ok := out.RawResponse.(*smithyhttp.Response) - if !ok { - return out, metadata, fmt.Errorf("expected raw response to be set on testing") - } - if response.StatusCode != 200 { - return out, metadata, mockRetryableError{true} - } - return out, metadata, err -} - -func (c *MockClient) setupMiddleware(stack *smithymiddleware.Stack) error { - err := error(nil) - if c.options.Retryer != nil { - err = addRetry(stack, c.options) - if err != nil { - return err - } - } - if c.options.Offset != nil { - err = addOffset(stack, c.options) - if err != nil { - return err - } - } - err = stack.Finalize.Add(&MockAddDateHeader{}, smithymiddleware.After) - if err != nil { - return err - } - err = middleware.AddRecordResponseTiming(stack) - if err != nil { - return err - } - err = stack.Deserialize.Add(&DeserializeFailIfNotHTTP200{}, smithymiddleware.After) - if err != nil { - return err - } - return nil -} - -func (c *MockClient) Do(ctx context.Context) (interface{}, error) { - // setup middlewares - ctx = smithymiddleware.ClearStackValues(ctx) - stack := smithymiddleware.NewStack("stack", smithyhttp.NewStackRequest) - err := c.setupMiddleware(stack) - if err != nil { - return nil, err - } - handler := smithymiddleware.DecorateHandler(smithyhttp.NewClientHandler(c.options.HTTPClient), stack) - result, _, err := handler.Handle(ctx, 1) - if err != nil { - return nil, err - } - return result, err -} - -type mockRetryableError struct{ b bool } - -func (m mockRetryableError) RetryableError() bool { return m.b } -func (m mockRetryableError) Error() string { - return fmt.Sprintf("mock retryable %t", m.b) -} - -func failRequestIfSkewed() smithyhttp.ClientDoFunc { - return func(req *http.Request) (*http.Response, error) { - dateHeader := req.Header.Get("Date") - if dateHeader == "" { - return nil, fmt.Errorf("expected `Date` header to be set") - } - reqDate, err := time.Parse(time.RFC850, dateHeader) - if err != nil { - return nil, err - } - parsedReqTime := time.Now().Sub(reqDate) - parsedReqTime = time.Duration.Abs(parsedReqTime) - thresholdForSkewError := 4 * time.Minute - if thresholdForSkewError-parsedReqTime <= 0 { - return &http.Response{ - StatusCode: 403, - Header: http.Header{ - "Date": {time.Now().Format(time.RFC850)}, - }, - }, nil - } - // else, return OK - return &http.Response{ - StatusCode: 200, - Header: http.Header{}, - }, nil - } -} - -func TestSdkOffsetIsSet(t *testing.T) { - nowTime := sdk.NowTime - defer func() { - sdk.NowTime = nowTime - }() - fiveMinuteSkew := func() time.Time { - return time.Now().Add(5 * time.Minute) - } - sdk.NowTime = fiveMinuteSkew - c := MockClient{ - Options{ - HTTPClient: failRequestIfSkewed(), - }, - } - resp, err := c.Do(context.Background()) - if err == nil { - t.Errorf("Expected first request to fail since clock skew logic has not run. Got %v and err %v", resp, err) - } -} - -func TestRetrySetsSkewInContext(t *testing.T) { - defer resetDefaults(sdk.TestingUseNopSleep()) - fiveMinuteSkew := func() time.Time { - return time.Now().Add(5 * time.Minute) - } - sdk.NowTime = fiveMinuteSkew - c := MockClient{ - Options{ - HTTPClient: failRequestIfSkewed(), - Retryer: retry.NewStandard(func(s *retry.StandardOptions) { - }), - }, - } - resp, err := c.Do(context.Background()) - if err != nil { - t.Errorf("Expected request to succeed on retry. Got %v and err %v", resp, err) - } -} - -func TestSkewIsSetOnTheWholeClient(t *testing.T) { - defer resetDefaults(sdk.TestingUseNopSleep()) - fiveMinuteSkew := func() time.Time { - return time.Now().Add(5 * time.Minute) - } - sdk.NowTime = fiveMinuteSkew - var offset atomic.Int64 - offset.Store(0) - c := MockClient{ - Options{ - HTTPClient: failRequestIfSkewed(), - Retryer: retry.NewStandard(func(s *retry.StandardOptions) { - }), - Offset: &offset, - }, - } - resp, err := c.Do(context.Background()) - if err != nil { - t.Errorf("Expected request to succeed on retry. Got %v and err %v", resp, err) - } - // Remove retryer so it has to succeed on first call - c.options.Retryer = nil - // same client, new request - resp, err = c.Do(context.Background()) - if err != nil { - t.Errorf("Expected second request to succeed since the skew should be set on the client. Got %v and err %v", resp, err) - } -} - -func resetDefaults(restoreSleepFunc func()) { - sdk.NowTime = time.Now - restoreSleepFunc() -} diff --git a/aws/retry/middleware.go b/aws/retry/middleware.go index 0a9e0291b78..c32d16f43a5 100644 --- a/aws/retry/middleware.go +++ b/aws/retry/middleware.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "github.com/aws/aws-sdk-go-v2/aws/middleware/private/metrics" + internalcontext "github.com/aws/aws-sdk-go-v2/internal/context" "strconv" "strings" "time" @@ -40,6 +41,10 @@ type Attempt struct { requestCloner RequestCloner } +// define the threshold at which we will consider certain kind of errors to be probably +// caused by clock skew +const skewThreshold = 4 * time.Minute + // NewAttemptMiddleware returns a new Attempt retry middleware. func NewAttemptMiddleware(retryer aws.Retryer, requestCloner RequestCloner, optFns ...func(*Attempt)) *Attempt { m := &Attempt{ @@ -88,7 +93,7 @@ func (r *Attempt) HandleFinalize(ctx context.Context, in smithymiddle.FinalizeIn }) // Setting clock skew to be used on other context (like signing) - ctx = awsmiddle.SetAttemptSkewContext(ctx, attemptClockSkew) + ctx = internalcontext.SetAttemptSkewContext(ctx, attemptClockSkew) var attemptResult AttemptResult out, attemptResult, releaseRetryToken, err = r.handleAttempt(attemptCtx, attemptInput, releaseRetryToken, next) @@ -253,14 +258,20 @@ func (r *Attempt) handleAttempt( return out, attemptResult, releaseRetryToken, err } -// Note that there are errors that are known to be definitely caused by clock -// skew, which are defined on the list of retryable errors +// errors that, if detected when we know there's a clock skew, +// can be retried and have a high chance of success var possibleSkewCodes = map[string]struct{}{ "InvalidSignatureException": {}, "SignatureDoesNotMatch": {}, "AuthFailure": {}, } +var definiteSkewCodes = map[string]struct{}{ + "RequestExpired": {}, + "RequestInTheFuture": {}, + "RequestTimeTooSkewed": {}, +} + // wrapAsClockSkew checks if this error could be related to a clock skew // error and if so, wrap the error. func wrapAsClockSkew(ctx context.Context, err error) error { @@ -268,15 +279,14 @@ func wrapAsClockSkew(ctx context.Context, err error) error { if !errors.As(err, &v) { return err } - _, ok := possibleSkewCodes[v.ErrorCode()] - if !ok { - return err + if _, ok := definiteSkewCodes[v.ErrorCode()]; ok { + return &retryableClockSkewError{Err: err} } - skew := awsmiddle.GetAttemptSkewContext(ctx) - if !(skew > 4*time.Minute) { - return err + _, isPossibleSkewCode := possibleSkewCodes[v.ErrorCode()] + if skew := internalcontext.GetAttemptSkewContext(ctx); skew > skewThreshold && isPossibleSkewCode { + return &retryableClockSkewError{Err: err} } - return &ProbClockSkewError{Err: err} + return err } // MetricsHeader attaches SDK request metric header for retries to the transport diff --git a/aws/retry/middleware_test.go b/aws/retry/middleware_test.go index 671f3d6805e..a8e320e97ca 100644 --- a/aws/retry/middleware_test.go +++ b/aws/retry/middleware_test.go @@ -4,8 +4,8 @@ import ( "context" "errors" "fmt" - awsmiddle "github.com/aws/aws-sdk-go-v2/aws/middleware" "github.com/aws/aws-sdk-go-v2/aws/middleware/private/metrics/testutils" + internalcontext "github.com/aws/aws-sdk-go-v2/internal/context" "net/http" "reflect" "strconv" @@ -530,7 +530,7 @@ func TestClockSkew(t *testing.T) { t.Run(name, func(t *testing.T) { am := NewAttemptMiddleware(NewStandard(func(s *StandardOptions) { }), testutils.NoopRequestCloner) - ctx := awsmiddle.SetAttemptSkewContext(context.Background(), tt.skew) + ctx := internalcontext.SetAttemptSkewContext(context.Background(), tt.skew) _, metadata, err := am.HandleFinalize(ctx, middleware.FinalizeInput{}, middleware.FinalizeHandlerFunc( func(ctx context.Context, in middleware.FinalizeInput) ( out middleware.FinalizeOutput, metadata middleware.Metadata, err error, diff --git a/aws/retry/retryable_error.go b/aws/retry/retryable_error.go index 37925ec20b7..3a7c5fb9711 100644 --- a/aws/retry/retryable_error.go +++ b/aws/retry/retryable_error.go @@ -200,22 +200,22 @@ func (r RetryableErrorCode) IsErrorRetryable(err error) aws.Ternary { return aws.TrueTernary } -// ProbClockSkewError marks errors that "could" be caused by clock skew +// retryableClockSkewError marks errors that can be caused by clock skew // (difference between server time and client time). // This is returned when there's certain confidence that adjusting the client time // could allow a retry to succeed -type ProbClockSkewError struct{ Err error } +type retryableClockSkewError struct{ Err error } -func (e *ProbClockSkewError) Error() string { +func (e *retryableClockSkewError) Error() string { return fmt.Sprintf("Probable clock skew error: %v", e.Err) } // Unwrap returns the wrapped error. -func (e *ProbClockSkewError) Unwrap() error { +func (e *retryableClockSkewError) Unwrap() error { return e.Err } // RetryableError allows the retryer to retry this request -func (e *ProbClockSkewError) RetryableError() bool { +func (e *retryableClockSkewError) RetryableError() bool { return true } diff --git a/aws/retry/standard.go b/aws/retry/standard.go index 0665cbcee0e..d5ea93222ed 100644 --- a/aws/retry/standard.go +++ b/aws/retry/standard.go @@ -51,11 +51,8 @@ var DefaultRetryableHTTPStatusCodes = map[int]struct{}{ // DefaultRetryableErrorCodes provides the set of API error codes that should // be retried. var DefaultRetryableErrorCodes = map[string]struct{}{ - "RequestExpired": {}, - "RequestInTheFuture": {}, "RequestTimeout": {}, "RequestTimeoutException": {}, - "RequestTimeTooSkewed": {}, } // DefaultThrottleErrorCodes provides the set of API error codes that are diff --git a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsGoDependency.java b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsGoDependency.java index d5c7b414a4e..c3801131bc1 100644 --- a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsGoDependency.java +++ b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsGoDependency.java @@ -45,6 +45,7 @@ public class AwsGoDependency { public static final GoDependency INTERNAL_AUTH = aws("internal/auth", "internalauth"); public static final GoDependency INTERNAL_AUTH_SMITHY = aws("internal/auth/smithy", "internalauthsmithy"); public static final GoDependency INTERNAL_CONTEXT = aws("internal/context", "internalcontext"); + public static final GoDependency INTERNAL_MIDDLEWARE = aws("internal/middleware", "internalmiddleware"); public static final GoDependency INTERNAL_ENDPOINTS_V2 = awsModuleDep("internal/endpoints/v2", null, Versions.INTERNAL_ENDPOINTS_V2, "endpoints"); diff --git a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/ClockSkewGenerator.java b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/ClockSkewGenerator.java index 8e06ce3da74..caac8ddee55 100644 --- a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/ClockSkewGenerator.java +++ b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/ClockSkewGenerator.java @@ -9,7 +9,10 @@ import software.amazon.smithy.utils.ListUtils; import java.util.List; +import java.util.Map; +import static software.amazon.smithy.aws.go.codegen.AwsGoDependency.INTERNAL_MIDDLEWARE; +import static software.amazon.smithy.go.codegen.GoWriter.goTemplate; import static software.amazon.smithy.go.codegen.SmithyGoDependency.ATOMIC; /** @@ -19,12 +22,65 @@ public class ClockSkewGenerator implements GoIntegration { private static final String TIME_OFFSET = "timeOffset"; private static final String ADD_CLOCK_SKEW_BUILD = "addTimeOffsetBuild"; - private static final String ADD_CLOCK_SKEW_BUILD_MIDDLEWARE = "AddTimeOffsetBuildMiddleware"; - private static final String ADD_CLOCK_SKEW_DESERIALIZER = "addTimeOffsetDeserializer"; - private static final String ADD_CLOCK_SKEW_DESERIALIZE_MIDDLEWARE = "AddTimeOffsetDeserializeMiddleware"; + private static final String ADD_CLOCK_SKEW_BUILD_MIDDLEWARE = "AddTimeOffsetMiddleware"; + private static final Symbol TIME_OFFSET_RESOLVER = SymbolUtils.createValueSymbolBuilder( "initializeTimeOffsetResolver").build(); + private static final GoWriter.Writable CLOCK_SKEW_INSERT_TEMPLATE = goTemplate(""" + $dep:D + func $fn:L(stack $stack:P, c *Client) error { + mw := $depalias:L.$middleware:L{Offset: c.$off:L} + if err := stack.Build.Add(&mw, middleware.After); err != nil { + return err + } + return stack.Deserialize.Insert(&mw, "$after:L", middleware.Before) + } + """, + Map.of( + "fn", ADD_CLOCK_SKEW_BUILD, + "stack", SmithyGoDependency.SMITHY_MIDDLEWARE.struct("Stack"), + "depalias", INTERNAL_MIDDLEWARE.getAlias(), + "middleware", ADD_CLOCK_SKEW_BUILD_MIDDLEWARE, + "after", "RecordResponseTiming", + "off", TIME_OFFSET, + "dep", INTERNAL_MIDDLEWARE + )); + private static final GoWriter.Writable TIME_OFFSET_RESOLVER_TEMPLATE = goTemplate( + """ + $import:D + func $fn:L(c *Client) { + c.$off:L = new(atomic.Int64) + } + """, + Map.of( + "import", ATOMIC, + "fn", TIME_OFFSET_RESOLVER, + "off", TIME_OFFSET + ) + ); + + private static final ClientMember TIME_OFFSET_MEMBER = ClientMember.builder() + .name(TIME_OFFSET) + .type(ATOMIC.struct("Int64")) + .documentation("Difference between the time reported by the server and the client") + .build(); + private static final ClientMemberResolver TIME_OFFSET_MEMBER_RESOLVER = ClientMemberResolver.builder() + .resolver(TIME_OFFSET_RESOLVER) + .build(); + private static final MiddlewareRegistrar MIDDLEWARE = MiddlewareRegistrar.builder() + .resolvedFunction(SymbolUtils.createValueSymbolBuilder(ADD_CLOCK_SKEW_BUILD).build()) + .functionArguments(ListUtils.of( + SymbolUtils.createValueSymbolBuilder("c").build() + )).build(); + private static final List CLIENT_PLUGINS = List.of( + RuntimeClientPlugin.builder() + .addClientMember(TIME_OFFSET_MEMBER) + .addClientMemberResolver(TIME_OFFSET_MEMBER_RESOLVER) + .registerMiddleware(MIDDLEWARE) + .build() + ); + @Override public void writeAdditionalFiles( GoSettings settings, @@ -37,66 +93,13 @@ public void writeAdditionalFiles( // generate code specific to service client goDelegator.useShapeWriter(service, writer -> { - generateClockSkewInsertMiddleware(writer); - generateClockSkewDeserializeMiddleware(writer); - generateTimeOffsetResolver(writer); + writer.write(CLOCK_SKEW_INSERT_TEMPLATE); + writer.write(TIME_OFFSET_RESOLVER_TEMPLATE); }); } @Override public List getClientPlugins() { - ClientMember timeOffset = ClientMember.builder() - .name(TIME_OFFSET) - .type(ATOMIC.struct("Int64")) - .documentation("Difference between the time reported by the server and the client") - .build(); - ClientMemberResolver resolver = ClientMemberResolver.builder() - .resolver(TIME_OFFSET_RESOLVER) - .build(); - MiddlewareRegistrar initializeMiddleware = MiddlewareRegistrar.builder() - .resolvedFunction(SymbolUtils.createValueSymbolBuilder(ADD_CLOCK_SKEW_BUILD).build()) - .functionArguments(ListUtils.of( - SymbolUtils.createValueSymbolBuilder("c").build() - )).build(); - MiddlewareRegistrar finalizeMiddleware = MiddlewareRegistrar.builder() - .resolvedFunction(SymbolUtils.createValueSymbolBuilder(ADD_CLOCK_SKEW_DESERIALIZER).build()) - .functionArguments(ListUtils.of( - SymbolUtils.createValueSymbolBuilder("c").build() - )).build(); - return List.of( - RuntimeClientPlugin.builder() - .addClientMember(timeOffset) - .addClientMemberResolver(resolver) - .registerMiddleware(initializeMiddleware) - .build(), - RuntimeClientPlugin.builder() - .registerMiddleware(finalizeMiddleware) - .build() - ); - } - - private void generateClockSkewInsertMiddleware(GoWriter writer) { - Symbol stackSymbol = SymbolUtils.createPointableSymbolBuilder("Stack", SmithyGoDependency.SMITHY_MIDDLEWARE) - .build(); - writer.openBlock("func $L(stack $P, c *Client) error{", "}", ADD_CLOCK_SKEW_BUILD, stackSymbol, () -> { - writer.write("return stack.Build.Add(&awsmiddleware.$L{Offset: c.$L}, middleware.After)", - ADD_CLOCK_SKEW_BUILD_MIDDLEWARE, TIME_OFFSET); - }); - } - - private void generateClockSkewDeserializeMiddleware(GoWriter writer) { - Symbol stackSymbol = SymbolUtils.createPointableSymbolBuilder("Stack", SmithyGoDependency.SMITHY_MIDDLEWARE) - .build(); - writer.openBlock("func $L(stack $P, c *Client) error{", "}", ADD_CLOCK_SKEW_DESERIALIZER, stackSymbol, () -> - writer.write("return stack.Deserialize.Insert(&awsmiddleware.$L{Offset: c.$L}, \"RecordResponseTiming\", middleware.Before)", - ADD_CLOCK_SKEW_DESERIALIZE_MIDDLEWARE, TIME_OFFSET) - ); - } - - private void generateTimeOffsetResolver(GoWriter writer) { - writer.openBlock("func $L(c *Client) {", "}", TIME_OFFSET_RESOLVER, () -> { - Symbol atomic = SymbolUtils.createValueSymbolBuilder("Int64", ATOMIC).build(); - writer.write("c.$L = new($P)", TIME_OFFSET, atomic); - }); + return CLIENT_PLUGINS; } } \ No newline at end of file diff --git a/internal/auth/smithy/v4signer_adapter.go b/internal/auth/smithy/v4signer_adapter.go index cf0fec878f2..fb08743bf48 100644 --- a/internal/auth/smithy/v4signer_adapter.go +++ b/internal/auth/smithy/v4signer_adapter.go @@ -3,8 +3,8 @@ package smithy import ( "context" "fmt" - "github.com/aws/aws-sdk-go-v2/aws/middleware" v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + internalcontext "github.com/aws/aws-sdk-go-v2/internal/context" "github.com/aws/aws-sdk-go-v2/internal/sdk" "github.com/aws/smithy-go" "github.com/aws/smithy-go/auth" @@ -40,7 +40,7 @@ func (v *V4SignerAdapter) SignRequest(ctx context.Context, r *smithyhttp.Request hash := v4.GetPayloadHash(ctx) signingTime := sdk.NowTime() - skew := middleware.GetAttemptSkewContext(ctx) + skew := internalcontext.GetAttemptSkewContext(ctx) signingTime = signingTime.Add(skew) err := v.Signer.SignHTTP(ctx, ca.Credentials, r.Request, hash, name, region, signingTime, func(o *v4.SignerOptions) { o.DisableURIPathEscaping, _ = smithyhttp.GetDisableDoubleEncoding(&props) diff --git a/internal/context/context.go b/internal/context/context.go index 15bf104772f..f0c283d3942 100644 --- a/internal/context/context.go +++ b/internal/context/context.go @@ -2,12 +2,14 @@ package context import ( "context" + "time" "github.com/aws/smithy-go/middleware" ) type s3BackendKey struct{} type checksumInputAlgorithmKey struct{} +type clockSkew struct{} const ( // S3BackendS3Express identifies the S3Express backend @@ -37,3 +39,14 @@ func GetChecksumInputAlgorithm(ctx context.Context) string { v, _ := middleware.GetStackValue(ctx, checksumInputAlgorithmKey{}).(string) return v } + +// SetAttemptSkewContext sets the clock skew value on the context +func SetAttemptSkewContext(ctx context.Context, v time.Duration) context.Context { + return middleware.WithStackValue(ctx, clockSkew{}, v) +} + +// GetAttemptSkewContext gets the clock skew value from the context +func GetAttemptSkewContext(ctx context.Context) time.Duration { + x, _ := middleware.GetStackValue(ctx, clockSkew{}).(time.Duration) + return x +} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go new file mode 100644 index 00000000000..8e24a3f0a47 --- /dev/null +++ b/internal/middleware/middleware.go @@ -0,0 +1,42 @@ +package middleware + +import ( + "context" + "sync/atomic" + "time" + + internalcontext "github.com/aws/aws-sdk-go-v2/internal/context" + "github.com/aws/smithy-go/middleware" +) + +// AddTimeOffsetMiddleware sets a value representing clock skew on the request context. +// This can be read by other operations (such as signing) to correct the date value they send +// on the request +type AddTimeOffsetMiddleware struct { + Offset *atomic.Int64 +} + +// ID the identifier for AddTimeOffsetMiddleware +func (m *AddTimeOffsetMiddleware) ID() string { return "AddTimeOffsetMiddleware" } + +// HandleBuild sets a value for attemptSkew on the request context if one is set on the client. +func (m AddTimeOffsetMiddleware) HandleBuild(ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler) ( + out middleware.BuildOutput, metadata middleware.Metadata, err error, +) { + if m.Offset != nil { + offset := time.Duration(m.Offset.Load()) + ctx = internalcontext.SetAttemptSkewContext(ctx, offset) + } + return next.HandleBuild(ctx, in) +} + +// HandleDeserialize gets the clock skew context from the context, and if set, sets it on the pointer +// held by AddTimeOffsetMiddleware +func (m *AddTimeOffsetMiddleware) HandleDeserialize(ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler) ( + out middleware.DeserializeOutput, metadata middleware.Metadata, err error, +) { + if v := internalcontext.GetAttemptSkewContext(ctx); v != 0 { + m.Offset.Store(v.Nanoseconds()) + } + return next.HandleDeserialize(ctx, in) +} diff --git a/internal/middleware/middleware_test.go b/internal/middleware/middleware_test.go new file mode 100644 index 00000000000..342abef7bce --- /dev/null +++ b/internal/middleware/middleware_test.go @@ -0,0 +1,254 @@ +package middleware_test + +import ( + "context" + "fmt" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/retry" + "net/http" + "sync/atomic" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws/middleware" + internalcontext "github.com/aws/aws-sdk-go-v2/internal/context" + internalmiddleware "github.com/aws/aws-sdk-go-v2/internal/middleware" + "github.com/aws/aws-sdk-go-v2/internal/sdk" + smithymiddleware "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +type httpClient interface { + Do(*http.Request) (*http.Response, error) +} + +type options struct { + HTTPClient httpClient + RetryMode aws.RetryMode + Retryer aws.Retryer + Offset *atomic.Int64 +} + +type MockClient struct { + options options +} + +func addRetry(stack *smithymiddleware.Stack, o options) error { + attempt := retry.NewAttemptMiddleware(o.Retryer, smithyhttp.RequestCloner, func(m *retry.Attempt) { + m.LogAttempts = false + }) + return stack.Finalize.Add(attempt, smithymiddleware.After) +} + +func addOffset(stack *smithymiddleware.Stack, o options) error { + offsetMiddleware := internalmiddleware.AddTimeOffsetMiddleware{Offset: o.Offset} + err := stack.Build.Add(&offsetMiddleware, smithymiddleware.After) + if err != nil { + return err + } + err = stack.Deserialize.Add(&offsetMiddleware, smithymiddleware.Before) + if err != nil { + return err + } + return nil +} + +// Middleware to set a `Date` object that includes sdk time and offset +type MockAddDateHeader struct { +} + +func (l *MockAddDateHeader) ID() string { + return "MockAddDateHeader" +} + +func (l *MockAddDateHeader) HandleFinalize( + ctx context.Context, in smithymiddleware.FinalizeInput, next smithymiddleware.FinalizeHandler, +) ( + out smithymiddleware.FinalizeOutput, metadata smithymiddleware.Metadata, attemptError error, +) { + req := in.Request.(*smithyhttp.Request) + date := sdk.NowTime() + skew := internalcontext.GetAttemptSkewContext(ctx) + date = date.Add(skew) + req.Header.Set("Date", date.Format(time.RFC850)) + return next.HandleFinalize(ctx, in) +} + +// Middleware to deserialize the response which just says "OK" if the response is 200 +type DeserializeFailIfNotHTTP200 struct { +} + +func (*DeserializeFailIfNotHTTP200) ID() string { + return "DeserializeFailIfNotHTTP200" +} + +func (m *DeserializeFailIfNotHTTP200) HandleDeserialize(ctx context.Context, in smithymiddleware.DeserializeInput, next smithymiddleware.DeserializeHandler) ( + out smithymiddleware.DeserializeOutput, metadata smithymiddleware.Metadata, err error, +) { + out, metadata, err = next.HandleDeserialize(ctx, in) + if err != nil { + return out, metadata, err + } + response, ok := out.RawResponse.(*smithyhttp.Response) + if !ok { + return out, metadata, fmt.Errorf("expected raw response to be set on testing") + } + if response.StatusCode != 200 { + return out, metadata, mockRetryableError{true} + } + return out, metadata, err +} + +func (c *MockClient) setupMiddleware(stack *smithymiddleware.Stack) error { + err := error(nil) + if c.options.Retryer != nil { + err = addRetry(stack, c.options) + if err != nil { + return err + } + } + if c.options.Offset != nil { + err = addOffset(stack, c.options) + if err != nil { + return err + } + } + err = stack.Finalize.Add(&MockAddDateHeader{}, smithymiddleware.After) + if err != nil { + return err + } + err = middleware.AddRecordResponseTiming(stack) + if err != nil { + return err + } + err = stack.Deserialize.Add(&DeserializeFailIfNotHTTP200{}, smithymiddleware.After) + if err != nil { + return err + } + return nil +} + +func (c *MockClient) Do(ctx context.Context) (interface{}, error) { + // setup middlewares + ctx = smithymiddleware.ClearStackValues(ctx) + stack := smithymiddleware.NewStack("stack", smithyhttp.NewStackRequest) + err := c.setupMiddleware(stack) + if err != nil { + return nil, err + } + handler := smithymiddleware.DecorateHandler(smithyhttp.NewClientHandler(c.options.HTTPClient), stack) + result, _, err := handler.Handle(ctx, 1) + if err != nil { + return nil, err + } + return result, err +} + +type mockRetryableError struct{ b bool } + +func (m mockRetryableError) RetryableError() bool { return m.b } +func (m mockRetryableError) Error() string { + return fmt.Sprintf("mock retryable %t", m.b) +} + +func failRequestIfSkewed() smithyhttp.ClientDoFunc { + return func(req *http.Request) (*http.Response, error) { + dateHeader := req.Header.Get("Date") + if dateHeader == "" { + return nil, fmt.Errorf("expected `Date` header to be set") + } + reqDate, err := time.Parse(time.RFC850, dateHeader) + if err != nil { + return nil, err + } + parsedReqTime := time.Now().Sub(reqDate) + parsedReqTime = time.Duration.Abs(parsedReqTime) + thresholdForSkewError := 4 * time.Minute + if thresholdForSkewError-parsedReqTime <= 0 { + return &http.Response{ + StatusCode: 403, + Header: http.Header{ + "Date": {time.Now().Format(time.RFC850)}, + }, + }, nil + } + // else, return OK + return &http.Response{ + StatusCode: 200, + Header: http.Header{}, + }, nil + } +} + +func TestSdkOffsetIsSet(t *testing.T) { + nowTime := sdk.NowTime + defer func() { + sdk.NowTime = nowTime + }() + fiveMinuteSkew := func() time.Time { + return time.Now().Add(5 * time.Minute) + } + sdk.NowTime = fiveMinuteSkew + c := MockClient{ + options{ + HTTPClient: failRequestIfSkewed(), + }, + } + resp, err := c.Do(context.Background()) + if err == nil { + t.Errorf("Expected first request to fail since clock skew logic has not run. Got %v and err %v", resp, err) + } +} + +func TestRetrySetsSkewInContext(t *testing.T) { + defer resetDefaults(sdk.TestingUseNopSleep()) + fiveMinuteSkew := func() time.Time { + return time.Now().Add(5 * time.Minute) + } + sdk.NowTime = fiveMinuteSkew + c := MockClient{ + options{ + HTTPClient: failRequestIfSkewed(), + Retryer: retry.NewStandard(func(s *retry.StandardOptions) { + }), + }, + } + resp, err := c.Do(context.Background()) + if err != nil { + t.Errorf("Expected request to succeed on retry. Got %v and err %v", resp, err) + } +} + +func TestSkewIsSetOnTheWholeClient(t *testing.T) { + defer resetDefaults(sdk.TestingUseNopSleep()) + fiveMinuteSkew := func() time.Time { + return time.Now().Add(5 * time.Minute) + } + sdk.NowTime = fiveMinuteSkew + var offset atomic.Int64 + offset.Store(0) + c := MockClient{ + options{ + HTTPClient: failRequestIfSkewed(), + Retryer: retry.NewStandard(func(s *retry.StandardOptions) { + }), + Offset: &offset, + }, + } + resp, err := c.Do(context.Background()) + if err != nil { + t.Errorf("Expected request to succeed on retry. Got %v and err %v", resp, err) + } + // Remove retryer so it has to succeed on first call + c.options.Retryer = nil + // same client, new request + resp, err = c.Do(context.Background()) + if err != nil { + t.Errorf("Expected second request to succeed since the skew should be set on the client. Got %v and err %v", resp, err) + } +} + +func resetDefaults(restoreSleepFunc func()) { + sdk.NowTime = time.Now + restoreSleepFunc() +} diff --git a/internal/v4a/smithy.go b/internal/v4a/smithy.go index 516d459d5dc..454f99dcf11 100644 --- a/internal/v4a/smithy.go +++ b/internal/v4a/smithy.go @@ -3,6 +3,7 @@ package v4a import ( "context" "fmt" + internalcontext "github.com/aws/aws-sdk-go-v2/internal/context" "time" v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" @@ -72,7 +73,11 @@ func (v *SignerAdapter) SignRequest(ctx context.Context, r *smithyhttp.Request, } hash := v4.GetPayloadHash(ctx) - err := v.Signer.SignHTTP(ctx, ca.Credentials, r.Request, hash, name, regions, sdk.NowTime(), func(o *SignerOptions) { + signingTime := sdk.NowTime() + if skew := internalcontext.GetAttemptSkewContext(ctx); skew != 0 { + signingTime.Add(skew) + } + err := v.Signer.SignHTTP(ctx, ca.Credentials, r.Request, hash, name, regions, signingTime, func(o *SignerOptions) { o.DisableURIPathEscaping, _ = smithyhttp.GetDisableDoubleEncoding(&props) o.Logger = v.Logger