Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support s3control using different xml error format than s3 #875

Merged
merged 4 commits into from
Nov 9, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SmithyGoDependency;
import software.amazon.smithy.go.codegen.SymbolUtils;
import software.amazon.smithy.go.codegen.SyntheticClone;
import software.amazon.smithy.go.codegen.integration.ProtocolGenerator;
import software.amazon.smithy.model.Model;
Expand Down Expand Up @@ -274,12 +275,28 @@ public static void writeXmlErrorMessageCodeDeserializer(ProtocolGenerator.Genera
ServiceShape service = context.getService();

if (requiresS3Customization(service)) {
writer.addUseImports(AwsCustomGoDependency.S3_SHARED_CUSTOMIZATION);
Symbol getErrorComponentFunction = SymbolUtils.createValueSymbolBuilder(
"GetErrorResponseComponents",
AwsCustomGoDependency.S3_SHARED_CUSTOMIZATION
).build();

Symbol errorOptions = SymbolUtils.createValueSymbolBuilder(
"ErrorResponseDeserializerOptions",
AwsCustomGoDependency.S3_SHARED_CUSTOMIZATION
).build();

if (isS3Service(service)){
writer.write("errorComponents, err := s3shared.GetS3ErrorResponseComponents(errorBody, response.StatusCode)");
// s3 service
writer.openBlock("errorComponents, err := $T(errorBody, $T{",
"})", getErrorComponentFunction, errorOptions, () -> {
writer.write("UseStatusCode : true, StatusCode : response.StatusCode,");
});
} else {
// s3 control
writer.write("errorComponents, err := s3shared.GetErrorResponseComponents(errorBody)");
writer.openBlock("errorComponents, err := $T(errorBody, $T{",
"})", getErrorComponentFunction, errorOptions, () -> {
writer.write("IsWrappedWithErrorTag: true,");
});
}

writer.write("if err != nil { return err }");
Expand Down
61 changes: 52 additions & 9 deletions service/internal/s3shared/xml_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,72 @@ type ErrorComponents struct {
HostID string `xml:"HostId"`
}

// GetErrorResponseComponents returns the error fields from an xml error response body
func GetErrorResponseComponents(r io.Reader) (ErrorComponents, error) {
// GetUnwrappedErrorResponseComponents returns the error fields from an xml error response body
func GetUnwrappedErrorResponseComponents(r io.Reader) (ErrorComponents, error) {
var errComponents ErrorComponents
if err := xml.NewDecoder(r).Decode(&errComponents); err != nil && err != io.EOF {
return ErrorComponents{}, fmt.Errorf("error while deserializing xml error response : %w", err)
}
return errComponents, nil
}

// GetS3ErrorResponseComponents returns the error fields from an S3 xml error response body.
// If an error code or message is not retrieved, it is derived from the http status code
func GetS3ErrorResponseComponents(r io.Reader, statusCode int) (ErrorComponents, error) {
errComponents, err := GetErrorResponseComponents(r)
// GetErrorResponseComponents returns the error fields from an xml error response body
// in which error code, and message are wrapped by a <Error> tag
func GetWrappedErrorResponseComponents(r io.Reader) (ErrorComponents, error) {
var errComponents struct {
Code string `xml:"Error>Code"`
Message string `xml:"Error>Message"`
RequestID string `xml:"RequestId"`
HostID string `xml:"HostId"`
}

if err := xml.NewDecoder(r).Decode(&errComponents); err != nil && err != io.EOF {
return ErrorComponents{}, fmt.Errorf("error while deserializing xml error response : %w", err)
}

return ErrorComponents{
Code: errComponents.Code,
Message: errComponents.Message,
RequestID: errComponents.RequestID,
HostID: errComponents.HostID,
}, nil
}

// GetErrorResponseComponents retrieves error components according to passed in options
func GetErrorResponseComponents(r io.Reader, options ErrorResponseDeserializerOptions) (ErrorComponents, error) {
var errComponents ErrorComponents
var err error

if options.IsWrappedWithErrorTag {
errComponents, err = GetWrappedErrorResponseComponents(r)
} else {
errComponents, err = GetUnwrappedErrorResponseComponents(r)
}

if err != nil {
return ErrorComponents{}, err
}

// for S3 service, we derive err code and message, if none is found
if len(errComponents.Code) == 0 && len(errComponents.Message) == 0 {
// If an error code or message is not retrieved, it is derived from the http status code
// eg, for S3 service, we derive err code and message, if none is found
if options.UseStatusCode && len(errComponents.Code) == 0 &&
len(errComponents.Message) == 0 {
// derive code and message from status code
statusText := http.StatusText(statusCode)
statusText := http.StatusText(options.StatusCode)
errComponents.Code = strings.Replace(statusText, " ", "", -1)
errComponents.Message = statusText
}
return errComponents, nil
}

type ErrorResponseDeserializerOptions struct {
// UseStatusCode denotes if status code should be used to retrieve error code, msg
UseStatusCode bool

// StatusCode is status code of error response
StatusCode int

//IsWrappedWithErrorTag represents if error response's code, msg is wrapped within an
// additional <Error> tag
IsWrappedWithErrorTag bool
}
36 changes: 33 additions & 3 deletions service/internal/s3shared/xml_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ func TestGetResponseErrorCode(t *testing.T) {
<RequestId>foo-id</RequestId>
</Error>`

const wrappedXmlErrorResponse = `<ErrorResponse><Error>
<Type>Sender</Type>
<Code>InvalidGreeting</Code>
<Message>Hi</Message>
</Error>
<HostId>bar-id</HostId>
<RequestId>foo-id</RequestId>
</ErrorResponse>`

cases := map[string]struct {
getErr func() (ErrorComponents, error)
expectedErrorCode string
Expand All @@ -24,7 +33,11 @@ func TestGetResponseErrorCode(t *testing.T) {
"standard xml error": {
getErr: func() (ErrorComponents, error) {
errResp := strings.NewReader(xmlErrorResponse)
return GetErrorResponseComponents(errResp)
return GetErrorResponseComponents(errResp, ErrorResponseDeserializerOptions{
UseStatusCode: false,
StatusCode: 0,
IsWrappedWithErrorTag: false,
})
},
expectedErrorCode: "InvalidGreeting",
expectedErrorMessage: "Hi",
Expand All @@ -35,17 +48,34 @@ func TestGetResponseErrorCode(t *testing.T) {
"s3 no response body": {
getErr: func() (ErrorComponents, error) {
errResp := strings.NewReader("")
return GetS3ErrorResponseComponents(errResp, 400)
return GetErrorResponseComponents(errResp, ErrorResponseDeserializerOptions{
UseStatusCode: true,
StatusCode: 400,
})
},
expectedErrorCode: "BadRequest",
expectedErrorMessage: "Bad Request",
},
"s3control no response body": {
getErr: func() (ErrorComponents, error) {
errResp := strings.NewReader("")
return GetErrorResponseComponents(errResp)
return GetErrorResponseComponents(errResp, ErrorResponseDeserializerOptions{
IsWrappedWithErrorTag: true,
})
},
},
"s3control standard response body": {
getErr: func() (ErrorComponents, error) {
errResp := strings.NewReader(wrappedXmlErrorResponse)
return GetErrorResponseComponents(errResp, ErrorResponseDeserializerOptions{
IsWrappedWithErrorTag: true,
})
},
expectedErrorCode: "InvalidGreeting",
expectedErrorMessage: "Hi",
expectedErrorRequestID: "foo-id",
expectedErrorHostID: "bar-id",
},
}

for name, c := range cases {
Expand Down
Loading