From b79d8023fd4a9ef0008e0a005eabd5e6c2f383f7 Mon Sep 17 00:00:00 2001 From: skotambkar Date: Mon, 1 Jun 2020 13:23:08 -0700 Subject: [PATCH 1/4] adds support for restjson deserializer middleware, json deserializers for output, error shapes --- aws/protocol/json/decoder_util.go | 108 +++ .../go/codegen/RestJsonProtocolGenerator.java | 771 +++++++++++++++++- 2 files changed, 872 insertions(+), 7 deletions(-) create mode 100644 aws/protocol/json/decoder_util.go diff --git a/aws/protocol/json/decoder_util.go b/aws/protocol/json/decoder_util.go new file mode 100644 index 00000000000..8c1ce9812a9 --- /dev/null +++ b/aws/protocol/json/decoder_util.go @@ -0,0 +1,108 @@ +package json + +import ( + "encoding/json" + "fmt" + "io" + "strings" +) + +// GetErrorInfo util looks for code, __type, and message members in the +// json body. These members are optionally available, and the function +// returns the value of member if it is available. This function is useful to +// identify the error code, msg in a REST JSON error response. +func GetErrorInfo(decoder *json.Decoder) (string, string, error) { + + var code, typeCode, msg string + + startToken, err := decoder.Token() + if err == io.EOF { + return "", "", nil + } + if err != nil { + return "", "", err + } + + if t, ok := startToken.(json.Delim); !ok || t.String() != "{" { + return "", "", fmt.Errorf("expected start token to be {") + } + + for decoder.More() { + t, err := decoder.Token() + if err != nil { + return "", "", err + } + + switch t { + case "code": + v, err := decoder.Token() + if err != nil { + return "", "", err + } + code = v.(string) + break + case "message": + v, err := decoder.Token() + if err != nil { + return "", "", err + } + msg = v.(string) + break + case "__type": + v, err := decoder.Token() + if err != nil { + return "", "", err + } + typeCode = v.(string) + break + default: + DiscardUnknownField(decoder) + break + } + } + + endToken, err := decoder.Token() + if err != nil { + return "", "", err + } + + if t, ok := endToken.(json.Delim); !ok || t.String() != "}" { + return "", "", fmt.Errorf("expected end token to be }") + } + + if len(code) == 0 { + return typeCode, msg, nil + } + return code, msg, nil +} + +// SanitizeErrorCode sanitizes the errorCode string . +// The rule for sanitizing is if a `:` character is present, then take only the +// contents before the first : character in the value. +// If a # character is present, then take only the contents after the +// first # character in the value. +func SanitizeErrorCode(errorCode string) string { + if strings.ContainsAny(errorCode, ":") { + errorCode = strings.SplitN(errorCode, ":", 2)[0] + } + + if strings.ContainsAny(errorCode, "#") { + errorCode = strings.SplitN(errorCode, "#", 2)[1] + } + + return errorCode +} + + +// DiscardUnknownField discards unknown fields from decoder body. +// This function is useful while deserializing json body with additional +// unknown information that should be discarded. +func DiscardUnknownField(decoder *json.Decoder) { + v, _ := decoder.Token() + if _, ok := v.(json.Delim); ok { + for decoder.More() { + DiscardUnknownField(decoder) + } + decoder.Token() + } +} diff --git a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestJsonProtocolGenerator.java b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestJsonProtocolGenerator.java index 60e38b8c4c3..b5a0131dbd3 100644 --- a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestJsonProtocolGenerator.java +++ b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestJsonProtocolGenerator.java @@ -15,8 +15,10 @@ package software.amazon.smithy.aws.go.codegen; +import java.util.Collection; import java.util.Optional; import java.util.Set; +import java.util.TreeSet; import java.util.function.Predicate; import java.util.stream.Collectors; import software.amazon.smithy.codegen.core.CodegenException; @@ -37,11 +39,14 @@ import software.amazon.smithy.model.shapes.MemberShape; import software.amazon.smithy.model.shapes.OperationShape; import software.amazon.smithy.model.shapes.Shape; +import software.amazon.smithy.model.shapes.ShapeId; import software.amazon.smithy.model.shapes.ShapeType; import software.amazon.smithy.model.shapes.StructureShape; import software.amazon.smithy.model.traits.EnumTrait; +import software.amazon.smithy.model.traits.HttpErrorTrait; import software.amazon.smithy.model.traits.JsonNameTrait; import software.amazon.smithy.model.traits.MediaTypeTrait; +import software.amazon.smithy.model.traits.StreamingTrait; import software.amazon.smithy.model.traits.TimestampFormatTrait; import software.amazon.smithy.utils.FunctionalUtils; @@ -299,16 +304,15 @@ protected void writeMiddlewareDocumentSerializerDelegator( if (payloadBinding.isPresent()) { MemberShape memberShape = payloadBinding.get().getMember(); Shape payloadShape = model.expectShape(memberShape.getTarget()); - ShapeType shapeType = payloadShape.getType(); String memberName = symbolProvider.toMemberName(memberShape); Optional mediaTypeTrait = payloadShape.getTrait(MediaTypeTrait.class); mediaTypeTrait.ifPresent(typeTrait -> writer.write("restEncoder.SetHeader(\"Content-Type\").String($S)", typeTrait.getValue())); - if (shapeType == ShapeType.BLOB) { + if (payloadShape.isBlobShape()) { writer.write("documentPayload = input.$L", memberName); - } else if (shapeType == ShapeType.STRING) { + } else if (payloadShape.isStringShape()) { writer.write("documentPayload = []byte(input.$L)", memberName); } else { String functionName = ProtocolGenerator.getDocumentSerializerFunctionName(payloadShape, @@ -354,10 +358,6 @@ protected void generateDocumentBodyShapeSerializers(GenerationContext context, S }); } - @Override - protected void generateDocumentBodyShapeDeserializers(GenerationContext context, Set shapes) { - } - @Override public void generateSharedSerializerComponents(GenerationContext context) { super.generateSharedSerializerComponents(context); @@ -374,4 +374,761 @@ private String getSerializedMemberName(MemberShape memberShape) { Optional jsonNameTrait = memberShape.getTrait(JsonNameTrait.class); return jsonNameTrait.isPresent() ? jsonNameTrait.get().getValue() : memberShape.getMemberName(); } + + @Override + protected void writeMiddlewareDocumentDeserializerDelegator( + GoWriter writer, + Model model, + SymbolProvider symbolProvider, + OperationShape operation, + GoStackStepMiddlewareGenerator generator + ) { + Shape outputShape = model.expectShape(operation.getOutput().get()); + boolean isShapeWithPayloadBinding = isShapeWithResponseBindings(model, operation, HttpBinding.Location.PAYLOAD); + + if (isShapeWithPayloadBinding){ + Set memberShapesWithPayloadBinding = new TreeSet<>(); + model.getKnowledge(HttpBindingIndex.class) + .getResponseBindings(operation).values().stream() + .filter(binding -> binding.getLocation().equals(HttpBinding.Location.PAYLOAD)) + .forEach(binding -> { + memberShapesWithPayloadBinding.add(binding.getMember()); + }); + + + // since payload trait can only be applied to a single member in a output shape + MemberShape memberShape = memberShapesWithPayloadBinding.iterator().next(); + Shape targetShape = model.expectShape(memberShape.getTarget()); + + // if target shape is of type String or type Blob, then delegate deserializers for explicit payload shapes + if (targetShape.isStringShape() || targetShape.isBlobShape()) { + writeMiddlewarePayloadBindingDeserializerDelegator(writer, outputShape, false); + return; + } + } + + writer.write("buff := make([]byte, 1024)"); + writer.write("ringBuffer := sdkio.NewRingBuffer(buff)"); + writer.write(""); + + writer.addUseImports(GoDependency.IO); + writer.write("body := io.TeeReader(response.Body, ringBuffer)"); + writer.write("defer response.Body.Close()"); + writer.write(""); + + writer.addUseImports(GoDependency.JSON); + writer.write("decoder := json.NewDecoder(body)"); + writer.write("decoder.UseNumber()"); + writer.write(""); + + writeMiddlewareDocumentBindingDeserializerDelegator(writer, outputShape, false); + } + + // Writes middleware that delegates to deserializers for shapes that have explicit payload. + private void writeMiddlewarePayloadBindingDeserializerDelegator(GoWriter writer, Shape shape, + Boolean isErrorShape) { + String deserFuncName = isErrorShape? + ProtocolGenerator.getDocumentDeserializerFunctionName(shape, getProtocolName()) : + ProtocolGenerator.getDocumentOutputDeserializerFunctionName(shape, getProtocolName()); + writer.write("err = $L(output, response.Body)", deserFuncName); + writer.openBlock("if err != nil {", "}", () -> { + writer.write(String.format("return out, metadata, &aws.DeserializationError{Err:%s}", + "fmt.Errorf(\"failed to deserialize response payload, %w\", err)")); + }); + } + + + // Write middleware that delegates to deserializers for shapes that have implicit payload and deserializer + private void writeMiddlewareDocumentBindingDeserializerDelegator(GoWriter writer, Shape shape, + Boolean isErrorShape) { + String deserFuncName = isErrorShape? + ProtocolGenerator.getDocumentDeserializerFunctionName(shape, getProtocolName()) : + ProtocolGenerator.getDocumentOutputDeserializerFunctionName(shape, getProtocolName()); + writer.write("err = $L(output, decoder)", deserFuncName); + writer.openBlock("if err != nil {", "}", () -> { + writer.write("var snapshot bytes.Buffer"); + writer.write("io.Copy(&snapshot, ringBuffer)"); + writer.openBlock("return out, metadata, &aws.DeserializationError {", "}", () -> { + writer.write("Err: fmt.Errorf(\"failed to decode response body with invalid JSON, %w\", err),"); + writer.write("Snapshot: snapshot.Bytes(),"); + }); + }); + } + + @Override + protected void writeMiddlewareErrorDeserializer( + GoWriter writer, + Model model, + SymbolProvider symbolProvider, + OperationShape operationShape, + GoStackStepMiddlewareGenerator generator + ) { + Collection ErrorShapeIds = operationShape.getErrors(); + Symbol genericAPIErrorSymbol = SymbolUtils.createValueSymbolBuilder( + "GenericAPIError", GoDependency.SMITHY).build(); + + // checks if response has an error and retrieve the error code from the response + writer.openBlock("if response.StatusCode < 200 || response.StatusCode >= 300 {", "}", () -> { + + // if no modeled exceptions for the operation shape, return the response body as is + if (ErrorShapeIds.size() == 0) { + writer.addUseImports(GoDependency.JSON); + writer.write("decoder := json.NewDecoder(response.Body)"); + writer.write("decoder.UseNumber()"); + writer.write("defer response.Body.Close()"); + writer.write(""); + + writer.addUseImports(GoDependency.AWS_JSON_PROTOCOL_ALIAS); + writer.write("errorType, message, err := jsonprotocol.GetErrorInfo(decoder)"); + writer.openBlock( "if len(errortype) == 0 {", "}", () -> { + writer.write("errorType = response.Headers.Get($S)","X-Amzn-Errortype" ); + }); + writer.write("errorType = jsonprotocol.SanitizeErrorCode(errorType)"); + + writer.openBlock("genericError := $P{","}", + genericAPIErrorSymbol, () ->{ + writer.write("Code : errorType,"); + writer.write("Message : message,"); + }); + + writer.write("return out, metadata, genericError"); + return; + } + + writer.write("buff := make([]byte, 1024)"); + writer.write("ringBuffer := sdkio.NewRingBuffer(buff)"); + writer.write(""); + + writer.write("var errorBody bytes.Buffer"); + + writer.addUseImports(GoDependency.IO); + writer.write("_, err := io.Copy(errorBody, response.Body)"); + writer.openBlock("if err != nil {", "}", () -> { + writer.write(String.format("return out, metadata, &aws.DeserializationError{Err: %s}", + "fmt.Errorf(\"failed to copy error response body, %w\", err)")); + }); + writer.write("body := io.TeeReader(response.Body, ringBuffer)"); + writer.write("defer response.Body.Close()"); + writer.write(""); + + // Retrieve error shape name from response. For REST JSON protocol, the error shape name can be found either + // at Header `X-Amzn-Errortype` or a body field with the name `code`, or a body field named `__type`. + writer.addUseImports(GoDependency.JSON); + writer.write("decoder := json.NewDecoder(body)"); + writer.write("decoder.UseNumber()"); + writer.write(""); + + writer.addUseImports(GoDependency.AWS_JSON_PROTOCOL_ALIAS); + writer.write("errorType, message, err := jsonprotocol.GetErrorInfo(decoder)"); + writer.openBlock("if err != nil {", "}", () -> { + writer.write("var snapshot bytes.Buffer"); + writer.write("io.Copy(&snapshot, ringBuffer)"); + writer.openBlock("return out, metadata, &aws.DeserializationError {", "}", () -> { + writer.write("Err: fmt.Errorf(\"failed to decode response error with invalid JSON, %w\", err),"); + writer.write("Snapshot: snapshot.Bytes(),"); + }); + }); + + writer.openBlock("if len(errorType) == 0 {", "}", () -> { + writer.write("errorType = response.Headers.Get($S)", "X-Amzn-Errortype"); + }); + + writer.write("errorType = jsonprotocol.SanitizeErrorCode(errorType)"); + writer.write(""); + + // generate middleware for modeled error shapes + writeErrorShapeDeserializerDelegator(writer, model, symbolProvider, ErrorShapeIds); + writer.write(""); + + writer.openBlock("genericError := $P{","}", + genericAPIErrorSymbol, () -> { + writer.write("Code : errorType,"); + writer.write("Message : message,"); + }); + writer.write(""); + writer.write("return out, metadata, genericError"); + }); + writer.write(""); + } + + // writeErrorShapeMiddlewareDelegator takes in the list of error shapes, and generates + // middleware error shape delegators. It delegates based on whether the error shape has + // rest bindings, payload bindings, document bindings. + private void writeErrorShapeDeserializerDelegator(GoWriter writer, Model model, SymbolProvider symbolProvider, + Collection ErrorShapeIds){ + + writer.write("body = io.TeeReader(errorBody, ringBuffer)"); + writer.write("decoder = json.NewDecoder(errorBody)"); + writer.write("decoder.UseNumber()"); + writer.write(""); + + for (ShapeId errorShapeId: ErrorShapeIds) { + Shape errorShape = model.expectShape(errorShapeId); + Symbol errorSymbol = symbolProvider.toSymbol(errorShape); + + writer.openBlock("if errorType == $S {", "}", errorShapeId.getName(), () -> { + writer.write("errorResult := &$T{}", errorSymbol); + writer.write("output, ok := errorResult.($P)", errorSymbol); + writer.openBlock("if !ok {", "}", () -> { + writer.write(String.format("return out, metadata, &aws.DeserializationError{Err: %s}", + "fmt.Errorf(\"unknown error result type %T\", out.Result)")); + }); + writer.write(""); + + // If error has an HttpError trait modeled on it, assign the value to the response status code + if (errorShape.hasTrait(HttpErrorTrait.class)) { + int errorStatusCode = errorShape.getTrait(HttpErrorTrait.class).get().getCode(); + writer.addUseImports(GoDependency.NET_HTTP); + writer.write("response.StatusCode = $L", errorStatusCode); + writer.write("response.Status = http.StatusText($L)", errorStatusCode); + writer.write(""); + } + + if (isShapeWithRestResponseBindings(model, errorShape)) { + String deserFuncName = ProtocolGenerator.getOperationHttpBindingsDeserFunctionName( + errorShape, getProtocolName()); + + writer.write("err= $L(output, response)", deserFuncName); + writer.openBlock("if err != nil {", "}", () -> { + writer.write(String.format("return out, metadata, &aws.DeserializationError{Err: %s}", + "fmt.Errorf(\"failed to decode response error with invalid Http bindings, %w\", err)")); + }); + writer.write(""); + } + + if (isShapeWithResponseBindings(model, errorShape, HttpBinding.Location.DOCUMENT) + || isShapeWithResponseBindings(model, errorShape, HttpBinding.Location.PAYLOAD)) { + writeMiddlewareDocumentBindingDeserializerDelegator(writer,errorShape, true); + } + + // TODO: fix variable scoping and shadowing + writer.write("return out, metadata, errorResult"); + }); + writer.write(""); + } + } + + @Override + protected void generateOperationDocumentDeserializer( + GenerationContext context, + OperationShape operation + ) { + Model model = context.getModel(); + HttpBindingIndex bindingIndex = model.getKnowledge(HttpBindingIndex.class); + Set documentBindings = bindingIndex.getResponseBindings(operation, HttpBinding.Location.DOCUMENT) + .stream() + .map(HttpBinding::getMember) + .collect(Collectors.toSet()); + + Shape outputShape = model.expectShape(operation.getOutput() + .orElseThrow(() -> new CodegenException("Output shape missing for operation " + operation.getId()))); + GoWriter writer = context.getWriter(); + + if (documentBindings.size() != 0) { + writeDocumentBindingDeserializer(writer, model, context.getSymbolProvider(), outputShape, + documentBindings::contains, true); + writer.write(""); + } + + Set payloadBindings = bindingIndex.getResponseBindings(operation, HttpBinding.Location.PAYLOAD) + .stream() + .map(HttpBinding::getMember) + .collect(Collectors.toSet()); + + if (payloadBindings.size() == 0) { + return; + } + + writePayloadBindingDeserializer(writer, model, context.getSymbolProvider(), outputShape, + payloadBindings::contains, true); + writer.write(""); + } + + @Override + protected void generateErrorDocumentBindingDeserializer(GenerationContext context, ShapeId shapeId) { + Model model = context.getModel(); + Shape shape = model.expectShape(shapeId); + GoWriter writer = context.getWriter(); + + HttpBindingIndex bindingIndex = model.getKnowledge(HttpBindingIndex.class); + Set errorDocumentBinding = bindingIndex.getResponseBindings(shapeId, HttpBinding.Location.DOCUMENT) + .stream() + .map(HttpBinding::getMember) + .collect(Collectors.toSet()); + + if (errorDocumentBinding.size() != 0) { + writeDocumentBindingDeserializer(writer, model, context.getSymbolProvider(), shape, + errorDocumentBinding::contains, false); + writer.write(""); + } + + Set errorPayloadBinding = bindingIndex.getResponseBindings(shapeId, HttpBinding.Location.PAYLOAD) + .stream() + .map(HttpBinding::getMember) + .collect(Collectors.toSet()); + + // do not generate if no payload binding deserializer for Error Binding + if (errorPayloadBinding.size() == 0) { + return; + } + + writePayloadBindingDeserializer(writer, model, context.getSymbolProvider(), shape, + errorPayloadBinding::contains, false); + writer.write(""); + } + + + @Override + protected void generateDocumentBodyShapeDeserializers(GenerationContext context, Set shapes) { + GoWriter writer = context.getWriter(); + Model model = context.getModel(); + SymbolProvider symbolProvider = context.getSymbolProvider(); + + shapes.forEach(shape -> { + writeDocumentBindingDeserializer(writer, model, symbolProvider, shape, FunctionalUtils.alwaysTrue(), false); + writer.write(""); + }); + } + + + // Generate deserializers for shapes with payload binding + private void writePayloadBindingDeserializer( + GoWriter writer, + Model model, + SymbolProvider symbolProvider, + Shape shape, + Predicate filterMemberShapes, + Boolean isOutputShape + ) { + Symbol shapeSymbol = symbolProvider.toSymbol(shape); + String funcName = isOutputShape? ProtocolGenerator.getDocumentOutputDeserializerFunctionName(shape, getProtocolName()) + : ProtocolGenerator.getDocumentDeserializerFunctionName(shape, getProtocolName()); + + for (MemberShape memberShape : shape.members()) { + if (!filterMemberShapes.test(memberShape)) { + continue; + } + + String memberName = symbolProvider.toMemberName(memberShape); + Shape targetShape = model.expectShape(memberShape.getTarget()); + if (targetShape.isStringShape() || targetShape.isBlobShape()) { + writer.openBlock("func $L(v $P, body io.ReadCloser) error {", "}", + funcName, shapeSymbol, () -> { + writer.openBlock("if v == nil {", "}", () -> { + writer.write("return fmt.Errorf(\"unsupported deserialization of nil %T\", v)"); + }); + writer.write(""); + + if (!targetShape.hasTrait(StreamingTrait.class) && targetShape.isBlobShape()) { + writer.addUseImports(GoDependency.IOUTIL); + writer.write("bs, err := ioutil.ReadAll(body)"); + writer.write("if err != nil { return err }"); + writer.write("v.$L = bs", memberName); + } else { + writer.write("v.$L = body", memberName); + } + + writer.write("return nil"); + }); + } else { + // delegate to Json Document Binding Deserializer + writeDocumentBindingDeserializer(writer, model, symbolProvider, shape, filterMemberShapes, + isOutputShape); + } + } + } + + // Generate deserializers for shape with document binding + private void writeDocumentBindingDeserializer( + GoWriter writer, + Model model, + SymbolProvider symbolProvider, + Shape shape, + Predicate filterMemberShapes, + Boolean isOutputShape + ) { + Symbol jsonDecoder = SymbolUtils.createPointableSymbolBuilder("Decoder", GoDependency.JSON).build(); + Symbol shapeSymbol = symbolProvider.toSymbol(shape); + String functionName = isOutputShape? + ProtocolGenerator.getDocumentOutputDeserializerFunctionName(shape, getProtocolName()) : + ProtocolGenerator.getDocumentDeserializerFunctionName(shape, getProtocolName()); + + writer.addUseImports(GoDependency.FMT); + switch (shape.getType()) { + case STRUCTURE: + writer.openBlock("func $L(v $P, decoder $P) error {", "}", functionName, shapeSymbol, + jsonDecoder, () -> { + writer.openBlock("if v == nil {", "}", () -> { + writer.write("return fmt.Errorf(\"unsupported deserialization of nil %T\", v)"); + }); + writer.write(""); + generateDocumentBindingStructureShapeDeserializer(writer, model, symbolProvider, shape, + filterMemberShapes); + writer.write(""); + writer.write("return nil"); + }); + break; + case SET: + case LIST: + writer.openBlock("func $L(v $P, decoder $P) ($P, error) {", "}", functionName, shapeSymbol, + jsonDecoder, shapeSymbol, () -> { + writer.openBlock("if v == nil {", "}", () -> { + writer.write("return fmt.Errorf(\"unsupported deserialization of nil %T\", v)"); + }); + writer.write(""); + generateDocumentBindingCollectionShapeDeserializer(writer, model, symbolProvider, shape, + filterMemberShapes); + writer.write(""); + writer.write("return v, nil"); + }); + break; + case MAP: + writer.openBlock("func $L(v $P, decoder $P) ($P, error) {", "}", functionName, shapeSymbol, + jsonDecoder, shapeSymbol, () -> { + writer.openBlock("if v == nil {", "}", () -> { + writer.write("return fmt.Errorf(\"unsupported deserialization of nil %T\", v)"); + }); + writer.write(""); + generateDocumentBindingMapShapeDeserializer(writer, model, symbolProvider, shape, + filterMemberShapes); + writer.write(""); + writer.write("return v, nil"); + }); + break; + default: + break; + } + } + + // Generates deserializers for structure Shapes + private void generateDocumentBindingStructureShapeDeserializer( + GoWriter writer, + Model model, + SymbolProvider symbolProvider, + Shape shape, + Predicate filterMemberShapes + ) { + writeJsonTokenizerStartStub(writer, shape); + writer.openBlock("for decoder.More() {", "}", + () -> { + writer.write("t, err := decoder.Token()"); + writer.write("if err != nil { return err }"); + writer.openBlock("switch t {", "}", () -> { + for (MemberShape memberShape : shape.members()) { + if (!filterMemberShapes.test(memberShape)) { + continue; + } + + String memberName = symbolProvider.toMemberName(memberShape); + writer.openBlock("case $S :", "", memberShape.getMemberName(), () -> { + String operand = generateDocumentBindingMemberShapeDeserializer(writer, model, symbolProvider, memberShape); + writer.write(String.format("v.%s = %s", memberName, operand)); + writer.write("break"); + }); + } + + // default case to handle unknown fields + writer.openBlock("default : ","", () -> { + writer.addUseImports(GoDependency.AWS_JSON_PROTOCOL_ALIAS); + writer.write("jsonprotocol.DiscardUnknownField(decoder)"); + writer.write("break"); + }); + }); + }); + writeJsonTokenizerEndStub(writer, shape); + } + + + // Generates deserializers for collection shapes. + private void generateDocumentBindingCollectionShapeDeserializer( + GoWriter writer, + Model model, + SymbolProvider symbolProvider, + Shape shape, + Predicate filterMemberShapes + ) { + writeJsonTokenizerStartStub(writer, shape); + writer.openBlock("for decoder.More() {", "}", () -> { + MemberShape memberShape = shape.members().iterator().next(); + String memberName = symbolProvider.toMemberName(memberShape); + String operand = generateDocumentBindingMemberShapeDeserializer(writer, model, symbolProvider, memberShape); + + writer.write(String.format("v = append(v, %s)", operand)); + writer.write(""); + }); + writeJsonTokenizerEndStub(writer, shape); + } + + // Generates deserializers for map shapes. + private void generateDocumentBindingMapShapeDeserializer( + GoWriter writer, + Model model, + SymbolProvider symbolProvider, + Shape shape, + Predicate filterMemberShapes + ) { + writeJsonTokenizerStartStub(writer, shape); + writer.openBlock("for decoder.More() {", "}", () -> { + MemberShape memberShape = shape.members().iterator().next(); + String memberName = symbolProvider.toMemberName(memberShape); + writer.write("token, err := decoder.Token()"); + writer.write("if err != nil { return err}"); + writer.write(""); + writer.write("key, ok := token.(string)"); + writer.write("if !ok { return fmt.Errof(\"expected map-key of type string, found type %T\", t)}"); + writer.write(""); + + String operand = generateDocumentBindingMemberShapeDeserializer(writer, model, symbolProvider, memberShape); + writer.write(String.format("v[key] = %s", operand)); + writer.write(""); + }); + + writeJsonTokenizerEndStub(writer, shape); + } + + // generateDocumentBindingMemberShapeDeserializer delegates to the appropriate + // deserializer generator function for the member shapes. + private String generateDocumentBindingMemberShapeDeserializer( + GoWriter writer, + Model model, + SymbolProvider symbolProvider, + MemberShape memberShape + ) { + Shape targetShape = model.expectShape(memberShape.getTarget()); + switch (targetShape.getType()) { + case STRING: + return generateDocumentBindingStringMemberDeserializer(writer, model, symbolProvider, memberShape); + case BOOLEAN: + return generateDocumentBindingBooleanMemberDeserializer(writer, symbolProvider, memberShape); + case TIMESTAMP: + return generateDocumentBindingTimestampMemberDeserializer(writer, memberShape); + case BLOB: + return generateDocumentBindingBlobMemberDeserializer(writer, model, symbolProvider, memberShape); + case BYTE: + case SHORT: + case INTEGER: + case LONG: + case BIG_INTEGER: + return generateDocumentBindingIntegerMemberDeserializer(writer, memberShape); + case FLOAT: + case DOUBLE: + case BIG_DECIMAL: + return generateDocumentBindingFloatMemberDeserializer(writer, memberShape); + case SET: + case LIST: + case MAP: + return generateDocumentBindingCollectionMemberDeserializer(writer, model, symbolProvider, memberShape); + case STRUCTURE: + return generateDocumentBindingStructureMemberDeserializer(writer, model, symbolProvider, memberShape); + case UNION: + case DOCUMENT: + writer.write("// TODO: Support " + targetShape.getType() + " Deserialization"); + break; + default: + throw new CodegenException("Unexpected shape deserialization to JSON"); + } + return ""; + } + + + // Generates deserializer for String member shape. + private String generateDocumentBindingStringMemberDeserializer( + GoWriter writer, + Model model, + SymbolProvider symbolProvider, + MemberShape memberShape + ) { + String memberName = symbolProvider.toMemberName(memberShape); + Shape targetShape = model.expectShape(memberShape.getTarget()); + Symbol targetSymbol = symbolProvider.toSymbol(targetShape); + writer.write("val, err := decoder.Token()"); + writer.write("if err != nil { return err }"); + writer.write("st, ok := val.(string)"); + writer.openBlock("if !ok {", "}", () -> { + writer.write("return fmt.Errorf(\"expected $L to be of type $P, got %T instead\", st)" + , memberName, targetSymbol); + }); + + if (targetShape.hasTrait(EnumTrait.class)) { + return String.format("types.%s(st)", targetSymbol.getName()); + } + + return CodegenUtils.generatePointerReferenceIfPointable(memberShape, "st"); + } + + // Generates deserializer for Boolean member shape. + private String generateDocumentBindingBooleanMemberDeserializer( + GoWriter writer, + SymbolProvider symbolProvider, + MemberShape memberShape + ) { + String shapeName = symbolProvider.toMemberName(memberShape); + Symbol shapeSymbol = symbolProvider.toSymbol(memberShape); + writer.write("val, err := decoder.Token()"); + writer.write("if err != nil { return err }"); + writer.write("b, ok := val.(bool)"); + writer.openBlock("if !ok {", "}", () -> { + writer.write("return fmt.Errorf(\"expected $L to be of type $L, got %T instead\", st)" + , shapeName, shapeSymbol.getName()); + }); + + return CodegenUtils.generatePointerReferenceIfPointable(memberShape, "b"); + } + + // Generates deserializer for Byte, Short, Integer, Long, Big Integer member shape. + private String generateDocumentBindingIntegerMemberDeserializer( + GoWriter writer, + MemberShape memberShape + ) { + writer.write("val, err := decoder.Token()"); + writer.write("if err != nil { return err }"); + writer.write("nt, err := val.(json.Number).Int64()"); + writer.write("if err != nil { return err }"); + switch (memberShape.getType()) { + case BYTE: + writer.write("st := byte(nt)"); + break; + case SHORT: + writer.write("st := int16(nt)"); + case INTEGER: + writer.write("st := int32(nt)"); + break; + case LONG: + writer.write("st := nt"); + break; + case BIG_INTEGER: + writer.addUseImports(GoDependency.BIG); + writer.write("st := big.NewInt(nt)"); + break; + } + + return CodegenUtils.generatePointerReferenceIfPointable(memberShape, "st"); + } + + // Generates deserializer for Float, Double, Big Decimal member shape. + private String generateDocumentBindingFloatMemberDeserializer( + GoWriter writer, + MemberShape memberShape + ) { + writer.write("val, err := decoder.Token()"); + writer.write("if err != nil { return err }"); + writer.write("nt, err := val.(json.Number).Float64()"); + writer.write("if err != nil { return err }"); + + switch (memberShape.getType()) { + case FLOAT: + writer.write("st := float32(nt)"); + break; + case DOUBLE: + writer.write("st := nt"); + break; + case BIG_DECIMAL: + writer.addUseImports(GoDependency.BIG); + writer.write("st := big.NewFloat(nt)"); + break; + } + + return CodegenUtils.generatePointerReferenceIfPointable(memberShape, "st"); + } + + // Generates deserializer for Timestamp member shape. + private String generateDocumentBindingTimestampMemberDeserializer( + GoWriter writer, + MemberShape memberShape + ) { + writer.write("val, err := decoder.Token()"); + writer.write("if err != nil { return err }"); + writer.write("nt, err := val.(json.Number).Int64()"); + writer.write("if err != nil { return err }"); + + writer.addUseImports(GoDependency.TIME); + writer.write("ts := time.Unix(nt, 0).UTC()"); + + return CodegenUtils.generatePointerReferenceIfPointable(memberShape, "ts"); + } + + // Generates deserializer for blob member shape. + private String generateDocumentBindingBlobMemberDeserializer( + GoWriter writer, + Model model, + SymbolProvider symbolProvider, + MemberShape memberShape + ) { + Shape targetShape = model.expectShape(memberShape.getTarget()); + Symbol targetSymbol = symbolProvider.toSymbol(targetShape); + + writer.write("bs := $P{}", targetSymbol); + writer.write("err := decoder.Decode(&bs)"); + writer.write("if err != nil { return err }"); + return "bs"; + } + + // Generates deserializer for delegator for structure member shape. + private String generateDocumentBindingStructureMemberDeserializer( + GoWriter writer, + Model model, + SymbolProvider symbolProvider, + MemberShape memberShape + ) { + Shape targetShape = model.expectShape(memberShape.getTarget()); + Symbol targetSymbol = symbolProvider.toSymbol(targetShape); + String deserFunctionName = ProtocolGenerator + .getDocumentDeserializerFunctionName(targetShape, getProtocolName()); + writer.write("val := $T{}", targetSymbol); + writer.write("if err := $L(&val, decoder); err != nil { return err }", + deserFunctionName); + + return CodegenUtils.generatePointerReferenceIfPointable(memberShape, "val"); + } + + // Generates deserializer for delegator for collection member shape and map member shapes. + private String generateDocumentBindingCollectionMemberDeserializer( + GoWriter writer, + Model model, + SymbolProvider symbolProvider, + MemberShape memberShape + ) { + Shape targetShape = model.expectShape(memberShape.getTarget()); + Symbol targetSymbol = symbolProvider.toSymbol(targetShape); + + String deserializerFuncName = ProtocolGenerator + .getDocumentDeserializerFunctionName(targetShape, getProtocolName()); + writer.write("col := $P{}", targetSymbol); + writer.write("if col, err := $L(col, decoder); err != nil { return err }", deserializerFuncName); + return "col"; + } + + // generates Json decoder tokenizer start stub wrt to the shape + private void writeJsonTokenizerStartStub(GoWriter writer, Shape shape) { + String startToken = shape.isListShape() ? "[" : "{"; + writer.write("startToken, err := decoder.Token()"); + writer.write("if err == io.EOF { return nil }"); + writer.write("if err != nil { return err }"); + writer.openBlock("if t, ok := startToken.(json.Delim); !ok || t.String() != $S {", + "}", startToken, () -> { + writer.addUseImports(GoDependency.FMT); + writer.write("return fmt.Errorf($S)", + String.format("expect `%s` as start token", startToken)); + }); + writer.write(""); + } + + // generates Json decoder tokenizer end stub wrt to the shape + private void writeJsonTokenizerEndStub(GoWriter writer, Shape shape) { + String endToken = shape.isListShape() ? "]" : "}"; + writer.write(""); + writer.write("endToken, err := decoder.Token()"); + writer.write("if err != nil { return err }"); + writer.openBlock("if t, ok := endToken.(json.Delim); !ok || t.String() != $S {", + "}", endToken, () -> { + writer.write("return fmt.Errorf($S)", + String.format("expect `%s` as end token", endToken)); + }); + } + + + @Override + public void generateSharedDeserializerComponents(GenerationContext context) { + super.generateSharedDeserializerComponents(context); + } } From 415e6ddef6196b7842a61b17b5f2f45de3d8fda9 Mon Sep 17 00:00:00 2001 From: skotambkar Date: Mon, 1 Jun 2020 14:33:24 -0700 Subject: [PATCH 2/4] adds error check for discard unknown field deserializer util --- aws/protocol/json/decoder_util.go | 25 +++++++++++++++---- .../go/codegen/RestJsonProtocolGenerator.java | 3 ++- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/aws/protocol/json/decoder_util.go b/aws/protocol/json/decoder_util.go index 8c1ce9812a9..3e1652ce721 100644 --- a/aws/protocol/json/decoder_util.go +++ b/aws/protocol/json/decoder_util.go @@ -93,16 +93,31 @@ func SanitizeErrorCode(errorCode string) string { return errorCode } - // DiscardUnknownField discards unknown fields from decoder body. // This function is useful while deserializing json body with additional // unknown information that should be discarded. -func DiscardUnknownField(decoder *json.Decoder) { - v, _ := decoder.Token() +func DiscardUnknownField(decoder *json.Decoder) error { + v, err := decoder.Token() + if err == io.EOF { + return nil + } + if err != nil{ + return err + } + if _, ok := v.(json.Delim); ok { for decoder.More() { - DiscardUnknownField(decoder) + err = DiscardUnknownField(decoder) + } + endToken, err := decoder.Token() + if err != nil{ + return err + } + if _, ok := endToken.(json.Delim); !ok { + return fmt.Errorf("invalid JSON : expected json delimiter, found %T %v", + endToken, endToken) } - decoder.Token() } + + return err } diff --git a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestJsonProtocolGenerator.java b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestJsonProtocolGenerator.java index b5a0131dbd3..fc30b0c0a15 100644 --- a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestJsonProtocolGenerator.java +++ b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestJsonProtocolGenerator.java @@ -830,7 +830,8 @@ private void generateDocumentBindingStructureShapeDeserializer( // default case to handle unknown fields writer.openBlock("default : ","", () -> { writer.addUseImports(GoDependency.AWS_JSON_PROTOCOL_ALIAS); - writer.write("jsonprotocol.DiscardUnknownField(decoder)"); + writer.write("err := jsonprotocol.DiscardUnknownField(decoder)"); + writer.write("if err != nil {return err}"); writer.write("break"); }); }); From 046d36d85283255bfd1f302b6b38d4c9d14fe246 Mon Sep 17 00:00:00 2001 From: skotambkar Date: Mon, 1 Jun 2020 21:40:33 -0700 Subject: [PATCH 3/4] suggested feedback --- aws/protocol/json/decoder_util.go | 75 ++++++------ .../go/codegen/RestJsonProtocolGenerator.java | 111 +++++++++++------- 2 files changed, 108 insertions(+), 78 deletions(-) diff --git a/aws/protocol/json/decoder_util.go b/aws/protocol/json/decoder_util.go index 3e1652ce721..c0cb58aeb99 100644 --- a/aws/protocol/json/decoder_util.go +++ b/aws/protocol/json/decoder_util.go @@ -3,6 +3,7 @@ package json import ( "encoding/json" "fmt" + "github.com/awslabs/smithy-go" "io" "strings" ) @@ -11,10 +12,7 @@ import ( // json body. These members are optionally available, and the function // returns the value of member if it is available. This function is useful to // identify the error code, msg in a REST JSON error response. -func GetErrorInfo(decoder *json.Decoder) (string, string, error) { - - var code, typeCode, msg string - +func GetErrorInfo(decoder *json.Decoder) (errorType string, message string, err error) { startToken, err := decoder.Token() if err == io.EOF { return "", "", nil @@ -28,37 +26,28 @@ func GetErrorInfo(decoder *json.Decoder) (string, string, error) { } for decoder.More() { + var target *string t, err := decoder.Token() if err != nil { return "", "", err } - switch t { - case "code": - v, err := decoder.Token() - if err != nil { - return "", "", err - } - code = v.(string) - break - case "message": - v, err := decoder.Token() - if err != nil { - return "", "", err - } - msg = v.(string) - break - case "__type": - v, err := decoder.Token() - if err != nil { - return "", "", err - } - typeCode = v.(string) - break + switch st := t.(string); { + case strings.EqualFold(st, "code"): + case strings.EqualFold(st, "__type"): + target = &errorType + case strings.EqualFold(st, "message"): + target = &message default: DiscardUnknownField(decoder) - break + continue + } + + v, err := decoder.Token() + if err != nil { + return errorType, message, err } + *target = v.(string) } endToken, err := decoder.Token() @@ -69,11 +58,7 @@ func GetErrorInfo(decoder *json.Decoder) (string, string, error) { if t, ok := endToken.(json.Delim); !ok || t.String() != "}" { return "", "", fmt.Errorf("expected end token to be }") } - - if len(code) == 0 { - return typeCode, msg, nil - } - return code, msg, nil + return errorType, message, nil } // SanitizeErrorCode sanitizes the errorCode string . @@ -101,7 +86,7 @@ func DiscardUnknownField(decoder *json.Decoder) error { if err == io.EOF { return nil } - if err != nil{ + if err != nil { return err } @@ -110,7 +95,7 @@ func DiscardUnknownField(decoder *json.Decoder) error { err = DiscardUnknownField(decoder) } endToken, err := decoder.Token() - if err != nil{ + if err != nil { return err } if _, ok := endToken.(json.Delim); !ok { @@ -119,5 +104,25 @@ func DiscardUnknownField(decoder *json.Decoder) error { } } - return err + return nil +} + +// GetSmithyGenericAPIError returns smithy generic api error and an error interface. +// Takes in json decoder, and error Code string as args. The function retrieves error message +// and error code from the decoder body. If errorCode of length greater than 0 is passed in as +// an argument, it is used instead. +func GetSmithyGenericAPIError(decoder *json.Decoder, errorCode string) (smithy.GenericAPIError, error) { + errorType, message, err := GetErrorInfo(decoder) + if err != nil { + return smithy.GenericAPIError{}, err + } + + if len(errorCode) == 0 { + errorCode = SanitizeErrorCode(errorType) + } + + return smithy.GenericAPIError{ + Code: errorCode, + Message: message, + }, nil } diff --git a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestJsonProtocolGenerator.java b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestJsonProtocolGenerator.java index fc30b0c0a15..bb86f4d5841 100644 --- a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestJsonProtocolGenerator.java +++ b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestJsonProtocolGenerator.java @@ -469,6 +469,9 @@ protected void writeMiddlewareErrorDeserializer( // checks if response has an error and retrieve the error code from the response writer.openBlock("if response.StatusCode < 200 || response.StatusCode >= 300 {", "}", () -> { + writer.write("errorType := response.Headers.Get($S)","X-Amzn-Errortype" ); + writer.write("errorType = jsonprotocol.SanitizeErrorCode(errorType)"); + writer.write(""); // if no modeled exceptions for the operation shape, return the response body as is if (ErrorShapeIds.size() == 0) { @@ -479,18 +482,8 @@ protected void writeMiddlewareErrorDeserializer( writer.write(""); writer.addUseImports(GoDependency.AWS_JSON_PROTOCOL_ALIAS); - writer.write("errorType, message, err := jsonprotocol.GetErrorInfo(decoder)"); - writer.openBlock( "if len(errortype) == 0 {", "}", () -> { - writer.write("errorType = response.Headers.Get($S)","X-Amzn-Errortype" ); - }); - writer.write("errorType = jsonprotocol.SanitizeErrorCode(errorType)"); - - writer.openBlock("genericError := $P{","}", - genericAPIErrorSymbol, () ->{ - writer.write("Code : errorType,"); - writer.write("Message : message,"); - }); - + writer.write("genericError, err := jsonprotocol.GetSmithyGenericAPIError(decoder, errorType)"); + writer.write("if err != nil { return out, metadata, &aws.DeserializationError{ Err: err}}"); writer.write("return out, metadata, genericError"); return; } @@ -502,11 +495,13 @@ protected void writeMiddlewareErrorDeserializer( writer.write("var errorBody bytes.Buffer"); writer.addUseImports(GoDependency.IO); - writer.write("_, err := io.Copy(errorBody, response.Body)"); + writer.write("_, err := io.Copy(&errorBody, response.Body)"); writer.openBlock("if err != nil {", "}", () -> { writer.write(String.format("return out, metadata, &aws.DeserializationError{Err: %s}", "fmt.Errorf(\"failed to copy error response body, %w\", err)")); }); + + writer.write(""); writer.write("body := io.TeeReader(response.Body, ringBuffer)"); writer.write("defer response.Body.Close()"); writer.write(""); @@ -518,33 +513,41 @@ protected void writeMiddlewareErrorDeserializer( writer.write("decoder.UseNumber()"); writer.write(""); + writer.write("var errorMessage string"); writer.addUseImports(GoDependency.AWS_JSON_PROTOCOL_ALIAS); - writer.write("errorType, message, err := jsonprotocol.GetErrorInfo(decoder)"); - writer.openBlock("if err != nil {", "}", () -> { - writer.write("var snapshot bytes.Buffer"); - writer.write("io.Copy(&snapshot, ringBuffer)"); - writer.openBlock("return out, metadata, &aws.DeserializationError {", "}", () -> { - writer.write("Err: fmt.Errorf(\"failed to decode response error with invalid JSON, %w\", err),"); - writer.write("Snapshot: snapshot.Bytes(),"); - }); - }); - writer.openBlock("if len(errorType) == 0 {", "}", () -> { - writer.write("errorType = response.Headers.Get($S)", "X-Amzn-Errortype"); + writer.write("errorType, errorMessage, err = jsonprotocol.GetErrorInfo(decoder)"); + writer.openBlock("if err != nil {", "}", () -> { + writer.write("var snapshot bytes.Buffer"); + writer.write("io.Copy(&snapshot, ringBuffer)"); + writer.openBlock("return out, metadata, &aws.DeserializationError {", "}", () -> { + writer.write("Err: fmt.Errorf(\"failed to decode response error with invalid JSON, %w\", err),"); + writer.write("Snapshot: snapshot.Bytes(),"); + }); + }); + writer.write("errorType = jsonprotocol.SanitizeErrorCode(errorType)"); }); - writer.write("errorType = jsonprotocol.SanitizeErrorCode(errorType)"); writer.write(""); // generate middleware for modeled error shapes writeErrorShapeDeserializerDelegator(writer, model, symbolProvider, ErrorShapeIds); writer.write(""); - writer.openBlock("genericError := $P{","}", - genericAPIErrorSymbol, () -> { - writer.write("Code : errorType,"); - writer.write("Message : message,"); + writer.openBlock("if len(errorMessage) != 0 {", "}", () -> { + writer.openBlock("genericError := $P{","}", + genericAPIErrorSymbol, () -> { + writer.write("Code : errorType,"); + writer.write("Message : errorMessage,"); + }); + writer.write(""); + writer.write("return out, metadata, genericError"); }); + + writer.write(""); + writer.addUseImports(GoDependency.AWS_JSON_PROTOCOL_ALIAS); + writer.write("genericError, err := jsonprotocol.GetSmithyGenericAPIError(decoder, errorType)"); + writer.write("if err != nil { return out, metadata, &aws.DeserializationError{ Err: err }}"); writer.write(""); writer.write("return out, metadata, genericError"); }); @@ -823,7 +826,6 @@ private void generateDocumentBindingStructureShapeDeserializer( writer.openBlock("case $S :", "", memberShape.getMemberName(), () -> { String operand = generateDocumentBindingMemberShapeDeserializer(writer, model, symbolProvider, memberShape); writer.write(String.format("v.%s = %s", memberName, operand)); - writer.write("break"); }); } @@ -832,7 +834,6 @@ private void generateDocumentBindingStructureShapeDeserializer( writer.addUseImports(GoDependency.AWS_JSON_PROTOCOL_ALIAS); writer.write("err := jsonprotocol.DiscardUnknownField(decoder)"); writer.write("if err != nil {return err}"); - writer.write("break"); }); }); }); @@ -909,11 +910,12 @@ private String generateDocumentBindingMemberShapeDeserializer( case SHORT: case INTEGER: case LONG: - case BIG_INTEGER: return generateDocumentBindingIntegerMemberDeserializer(writer, memberShape); + case BIG_INTEGER: + case BIG_DECIMAL: + return generateDocumentBindingBigMemberDeserializer(writer, memberShape); case FLOAT: case DOUBLE: - case BIG_DECIMAL: return generateDocumentBindingFloatMemberDeserializer(writer, memberShape); case SET: case LIST: @@ -976,7 +978,7 @@ private String generateDocumentBindingBooleanMemberDeserializer( return CodegenUtils.generatePointerReferenceIfPointable(memberShape, "b"); } - // Generates deserializer for Byte, Short, Integer, Long, Big Integer member shape. + // Generates deserializer for Byte, Short, Integer, Long member shape. private String generateDocumentBindingIntegerMemberDeserializer( GoWriter writer, MemberShape memberShape @@ -987,7 +989,7 @@ private String generateDocumentBindingIntegerMemberDeserializer( writer.write("if err != nil { return err }"); switch (memberShape.getType()) { case BYTE: - writer.write("st := byte(nt)"); + writer.write("st := int8(nt)"); break; case SHORT: writer.write("st := int16(nt)"); @@ -997,16 +999,41 @@ private String generateDocumentBindingIntegerMemberDeserializer( case LONG: writer.write("st := nt"); break; + default: + break; + } + + return CodegenUtils.generatePointerReferenceIfPointable(memberShape, "st"); + } + + // Generates deserializer for Big Integer, Big Decimal member shape. + private String generateDocumentBindingBigMemberDeserializer( + GoWriter writer, + MemberShape memberShape + ) { + writer.write("val, err := decoder.Token()"); + writer.write("if err != nil { return err }"); + switch (memberShape.getType()) { case BIG_INTEGER: writer.addUseImports(GoDependency.BIG); - writer.write("st := big.NewInt(nt)"); + writer.addUseImports(GoDependency.FMT); + writer.write("st, ok := new(big.Int).SetString(val.(string), 10)"); + writer.write("if !ok { return fmt.Errorf(\"error deserializing big integer type\")}"); + break; + case BIG_DECIMAL: + writer.addUseImports(GoDependency.BIG); + writer.addUseImports(GoDependency.FMT); + writer.write("st, ok := big.ParseFloat(val.(string), 10, 200, big.ToNearestAway)"); + writer.write("if !ok { return fmt.Errorf(\"error deserializing big decimal type\")}"); + break; + default: break; } return CodegenUtils.generatePointerReferenceIfPointable(memberShape, "st"); } - // Generates deserializer for Float, Double, Big Decimal member shape. + // Generates deserializer for Float, Double member shape. private String generateDocumentBindingFloatMemberDeserializer( GoWriter writer, MemberShape memberShape @@ -1023,9 +1050,7 @@ private String generateDocumentBindingFloatMemberDeserializer( case DOUBLE: writer.write("st := nt"); break; - case BIG_DECIMAL: - writer.addUseImports(GoDependency.BIG); - writer.write("st := big.NewFloat(nt)"); + default: break; } @@ -1039,11 +1064,11 @@ private String generateDocumentBindingTimestampMemberDeserializer( ) { writer.write("val, err := decoder.Token()"); writer.write("if err != nil { return err }"); - writer.write("nt, err := val.(json.Number).Int64()"); + writer.write("ft, err := val.(json.Number).Float64()"); writer.write("if err != nil { return err }"); - writer.addUseImports(GoDependency.TIME); - writer.write("ts := time.Unix(nt, 0).UTC()"); + writer.addUseImports(GoDependency.SMITHY_TIME); + writer.write("ts := smithytime.ParseEpochSeconds(ft)"); return CodegenUtils.generatePointerReferenceIfPointable(memberShape, "ts"); } From 1d5b21d1ca008d96350ae4a1e9fa9ac5bec5901f Mon Sep 17 00:00:00 2001 From: skotambkar Date: Tue, 2 Jun 2020 10:26:42 -0700 Subject: [PATCH 4/4] minor feedback updates --- aws/protocol/json/decoder_util.go | 12 ++++++++---- .../aws/go/codegen/RestJsonProtocolGenerator.java | 8 +++++--- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/aws/protocol/json/decoder_util.go b/aws/protocol/json/decoder_util.go index c0cb58aeb99..e8a94bee3d2 100644 --- a/aws/protocol/json/decoder_util.go +++ b/aws/protocol/json/decoder_util.go @@ -34,6 +34,7 @@ func GetErrorInfo(decoder *json.Decoder) (errorType string, message string, err switch st := t.(string); { case strings.EqualFold(st, "code"): + fallthrough case strings.EqualFold(st, "__type"): target = &errorType case strings.EqualFold(st, "message"): @@ -58,6 +59,9 @@ func GetErrorInfo(decoder *json.Decoder) (errorType string, message string, err if t, ok := endToken.(json.Delim); !ok || t.String() != "}" { return "", "", fmt.Errorf("expected end token to be }") } + + // sanitize error + errorType = SanitizeErrorCode(errorType) return errorType, message, nil } @@ -111,17 +115,17 @@ func DiscardUnknownField(decoder *json.Decoder) error { // Takes in json decoder, and error Code string as args. The function retrieves error message // and error code from the decoder body. If errorCode of length greater than 0 is passed in as // an argument, it is used instead. -func GetSmithyGenericAPIError(decoder *json.Decoder, errorCode string) (smithy.GenericAPIError, error) { +func GetSmithyGenericAPIError(decoder *json.Decoder, errorCode string) (*smithy.GenericAPIError, error) { errorType, message, err := GetErrorInfo(decoder) if err != nil { - return smithy.GenericAPIError{}, err + return nil, err } if len(errorCode) == 0 { - errorCode = SanitizeErrorCode(errorType) + errorCode = errorType } - return smithy.GenericAPIError{ + return &smithy.GenericAPIError{ Code: errorCode, Message: message, }, nil diff --git a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestJsonProtocolGenerator.java b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestJsonProtocolGenerator.java index bb86f4d5841..a089680fbeb 100644 --- a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestJsonProtocolGenerator.java +++ b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/RestJsonProtocolGenerator.java @@ -469,6 +469,8 @@ protected void writeMiddlewareErrorDeserializer( // checks if response has an error and retrieve the error code from the response writer.openBlock("if response.StatusCode < 200 || response.StatusCode >= 300 {", "}", () -> { + // Retrieve error shape name from response. For REST JSON protocol, the error shape name can be found either + // at Header `X-Amzn-Errortype` or a body field with the name `code`, or a body field named `__type`. writer.write("errorType := response.Headers.Get($S)","X-Amzn-Errortype" ); writer.write("errorType = jsonprotocol.SanitizeErrorCode(errorType)"); writer.write(""); @@ -506,8 +508,6 @@ protected void writeMiddlewareErrorDeserializer( writer.write("defer response.Body.Close()"); writer.write(""); - // Retrieve error shape name from response. For REST JSON protocol, the error shape name can be found either - // at Header `X-Amzn-Errortype` or a body field with the name `code`, or a body field named `__type`. writer.addUseImports(GoDependency.JSON); writer.write("decoder := json.NewDecoder(body)"); writer.write("decoder.UseNumber()"); @@ -515,6 +515,9 @@ protected void writeMiddlewareErrorDeserializer( writer.write("var errorMessage string"); writer.addUseImports(GoDependency.AWS_JSON_PROTOCOL_ALIAS); + + // If errorType is empty, look for error type in a body field with the name `code`, + // or a body field named `__type`. writer.openBlock("if len(errorType) == 0 {", "}", () -> { writer.write("errorType, errorMessage, err = jsonprotocol.GetErrorInfo(decoder)"); writer.openBlock("if err != nil {", "}", () -> { @@ -525,7 +528,6 @@ protected void writeMiddlewareErrorDeserializer( writer.write("Snapshot: snapshot.Bytes(),"); }); }); - writer.write("errorType = jsonprotocol.SanitizeErrorCode(errorType)"); }); writer.write("");