Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

codegen: Add support for list and maps generated as value members instead of pointer #887

Merged
merged 9 commits into from
Nov 15, 2020
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,24 @@ static void generateHttpProtocolTests(GenerationContext context) {
.operation(ShapeId.from("aws.protocoltests.restjson#InlineDocumentAsPayload"))
.build(),

// Null lists/maps without sparse tag
HttpProtocolUnitTestGenerator.SkipTest.builder()
.service(ShapeId.from("aws.protocoltests.restjson#RestJson"))
.operation(ShapeId.from("aws.protocoltests.restjson#JsonLists"))
.addTestName("RestJsonListsSerializeNull")
.build(),
HttpProtocolUnitTestGenerator.SkipTest.builder()
.service(ShapeId.from("aws.protocoltests.restjson#RestJson"))
.operation(ShapeId.from("aws.protocoltests.restjson#JsonMaps"))
.addTestName("RestJsonSerializesNullMapValues")
.build(),
HttpProtocolUnitTestGenerator.SkipTest.builder()
.service(ShapeId.from("aws.protocoltests.json#JsonProtocol"))
.operation(ShapeId.from("aws.protocoltests.json#NullOperation"))
.addTestName("AwsJson11MapsSerializeNullValues")
.addTestName("AwsJson11ListsSerializeNull")
.build(),

// JSON RPC Documents
HttpProtocolUnitTestGenerator.SkipTest.builder()
.service(ShapeId.from("aws.protocoltests.json#JsonProtocol"))
Expand All @@ -125,7 +143,7 @@ static void generateHttpProtocolTests(GenerationContext context) {
HttpProtocolUnitTestGenerator.SkipTest.builder()
.service(ShapeId.from("aws.protocoltests.restxml#RestXml"))
.operation(ShapeId.from("aws.protocoltests.restxml#HttpPrefixHeaders"))
.testName("HttpPrefixHeadersAreNotPresent")
.addTestName("HttpPrefixHeadersAreNotPresent")
.build()
));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@

package software.amazon.smithy.aws.go.codegen;

import java.util.function.Consumer;
import software.amazon.smithy.codegen.core.CodegenException;
import software.amazon.smithy.go.codegen.CodegenUtils;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SmithyGoDependency;
import software.amazon.smithy.go.codegen.integration.ProtocolGenerator;
import software.amazon.smithy.go.codegen.integration.ProtocolGenerator.GenerationContext;
import software.amazon.smithy.go.codegen.integration.ProtocolUtils;
import software.amazon.smithy.go.codegen.knowledge.GoPointableIndex;
import software.amazon.smithy.model.knowledge.NullableIndex;
import software.amazon.smithy.model.shapes.BigDecimalShape;
import software.amazon.smithy.model.shapes.BigIntegerShape;
import software.amazon.smithy.model.shapes.BlobShape;
Expand All @@ -40,6 +43,7 @@
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.SetShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.ShapeType;
import software.amazon.smithy.model.shapes.ShapeVisitor;
import software.amazon.smithy.model.shapes.ShortShape;
import software.amazon.smithy.model.shapes.StringShape;
Expand Down Expand Up @@ -75,63 +79,63 @@ public DocumentMemberSerVisitor(

@Override
public Void blobShape(BlobShape shape) {
String source = conditionallyDereference(shape, dataSource);
String source = CodegenUtils.getAsValueIfDereferencable(context.getPointableIndex(), member, dataSource);
context.getWriter().write("$L.Base64EncodeBytes($L)", dataDest, source);
return null;
}

@Override
public Void booleanShape(BooleanShape shape) {
String source = conditionallyDereference(shape, dataSource);
String source = CodegenUtils.getAsValueIfDereferencable(context.getPointableIndex(), member, dataSource);
context.getWriter().write("$L.Boolean($L)", dataDest, source);
return null;
}

@Override
public Void byteShape(ByteShape shape) {
String source = conditionallyDereference(shape, dataSource);
String source = CodegenUtils.getAsValueIfDereferencable(context.getPointableIndex(), member, dataSource);
context.getWriter().write("$L.Byte($L)", dataDest, source);
return null;
}

@Override
public Void shortShape(ShortShape shape) {
String source = conditionallyDereference(shape, dataSource);
String source = CodegenUtils.getAsValueIfDereferencable(context.getPointableIndex(), member, dataSource);
context.getWriter().write("$L.Short($L)", dataDest, source);
return null;
}

@Override
public Void integerShape(IntegerShape shape) {
String source = conditionallyDereference(shape, dataSource);
String source = CodegenUtils.getAsValueIfDereferencable(context.getPointableIndex(), member, dataSource);
context.getWriter().write("$L.Integer($L)", dataDest, source);
return null;
}

@Override
public Void longShape(LongShape shape) {
String source = conditionallyDereference(shape, dataSource);
String source = CodegenUtils.getAsValueIfDereferencable(context.getPointableIndex(), member, dataSource);
context.getWriter().write("$L.Long($L)", dataDest, source);
return null;
}

@Override
public Void floatShape(FloatShape shape) {
String source = conditionallyDereference(shape, dataSource);
String source = CodegenUtils.getAsValueIfDereferencable(context.getPointableIndex(), member, dataSource);
context.getWriter().write("$L.Float($L)", dataDest, source);
return null;
}

@Override
public Void doubleShape(DoubleShape shape) {
String source = conditionallyDereference(shape, dataSource);
String source = CodegenUtils.getAsValueIfDereferencable(context.getPointableIndex(), member, dataSource);
context.getWriter().write("$L.Double($L)", dataDest, source);
return null;
}

@Override
public Void timestampShape(TimestampShape shape) {
String source = conditionallyDereference(shape, dataSource);
String source = CodegenUtils.getAsValueIfDereferencable(context.getPointableIndex(), member, dataSource);
GoWriter writer = context.getWriter();
writer.addUseImports(SmithyGoDependency.SMITHY_TIME);

Expand All @@ -153,23 +157,14 @@ public Void timestampShape(TimestampShape shape) {

@Override
public Void stringShape(StringShape shape) {
String source = conditionallyDereference(shape, dataSource);
String source = CodegenUtils.getAsValueIfDereferencable(context.getPointableIndex(), member, dataSource);
if (shape.hasTrait(EnumTrait.class)) {
source = String.format("string(%s)", source);
}
context.getWriter().write("$L.String($L)", dataDest, source);
return null;
}

private String conditionallyDereference(Shape shape, String dataSource) {
boolean shouldDereference = CodegenUtils.isShapePassByReference(shape);
if (context.getModel().expectShape(member.getContainer()).isUnionShape()) {
Shape target = context.getModel().expectShape(member.getTarget());
shouldDereference &= ProtocolUtils.usesScalarWhenUnionValue(target);
}
return shouldDereference ? "*" + dataSource : dataSource;
}

@Override
public Void bigIntegerShape(BigIntegerShape shape) {
// Fail instead of losing precision through Number.
Expand All @@ -184,7 +179,7 @@ public Void bigDecimalShape(BigDecimalShape shape) {
return null;
}

private String unsupportedShape(Shape shape) {
private void unsupportedShape(Shape shape) {
throw new CodegenException(String.format("Cannot serialize shape type %s on protocol, shape: %s.",
shape.getType(), shape.getId()));
}
Expand Down Expand Up @@ -248,8 +243,10 @@ public Void mapShape(MapShape shape) {
private void writeDelegateFunction(Shape shape) {
String serFunctionName = ProtocolGenerator.getDocumentSerializerFunctionName(shape, context.getProtocolName());
GoWriter writer = context.getWriter();
writer.openBlock("if err := $L($L, $L); err != nil {", "}", serFunctionName, dataSource, dataDest, () -> {
writer.write("return err");

ProtocolUtils.writeSerDelegateFunction(context, writer, member, dataSource, (srcVar) -> {
writer.openBlock("if err := $L($L, $L); err != nil {", "}", serFunctionName, srcVar, dataDest,
() -> writer.write("return err"));
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import software.amazon.smithy.go.codegen.SmithyGoDependency;
import software.amazon.smithy.go.codegen.integration.ProtocolGenerator;
import software.amazon.smithy.go.codegen.integration.ProtocolGenerator.GenerationContext;
import software.amazon.smithy.go.codegen.integration.ProtocolUtils;
import software.amazon.smithy.model.shapes.BigDecimalShape;
import software.amazon.smithy.model.shapes.BigIntegerShape;
import software.amazon.smithy.model.shapes.BlobShape;
Expand Down Expand Up @@ -55,11 +56,13 @@
*/
public class JsonMemberDeserVisitor implements ShapeVisitor<Void> {
private final GenerationContext context;
private final MemberShape member;
private final String dataDest;
private final Format timestampFormat;

public JsonMemberDeserVisitor(GenerationContext context, String dataDest, Format timestampFormat) {
public JsonMemberDeserVisitor(GenerationContext context, MemberShape member, String dataDest, Format timestampFormat) {
this.context = context;
this.member = member;
this.dataDest = dataDest;
this.timestampFormat = timestampFormat;
}
Expand All @@ -70,13 +73,20 @@ public Void blobShape(BlobShape shape) {
writer.addUseImports(SmithyGoDependency.FMT);
writer.addUseImports(SmithyGoDependency.BASE64);
final String typeError = "return fmt.Errorf(\"expected $L to be []byte, got %T instead\", value)";

writer.openBlock("if value != nil {", "}", () -> {
writer.write("jtv, ok := value.(string)");
writer.openBlock("if !ok {", "}", () -> writer.write(typeError, shape.getId().getName()));
writer.openBlock("if !ok {", "}", () -> {
writer.write(typeError, shape.getId().getName());
});

writer.write("dv, err := base64.StdEncoding.DecodeString(jtv)");
writer.openBlock("if err != nil {", "}", () -> writer.write("return fmt.Errorf(\"failed to base64 decode "
+ "$L, %w\", err)", shape.getId().getName()));
writer.write("$L = dv", dataDest);
writer.openBlock("if err != nil {", "}", () -> {
writer.write("return fmt.Errorf(\"failed to base64 decode $L, %w\", err)", shape.getId().getName());
});

writer.write("$L = $L", dataDest, CodegenUtils.getAsPointerIfPointable(context.getModel(),
context.getWriter(), context.getPointableIndex(), member, "dv"));
});
return null;
}
Expand All @@ -91,7 +101,8 @@ public Void booleanShape(BooleanShape shape) {
writer.write("return fmt.Errorf(\"expected $L to be of type *bool, got %T instead\", value)",
shape.getId().getName());
});
writer.write("$L = &jtv", dataDest);
writer.write("$L = $L", dataDest, CodegenUtils.getAsPointerIfPointable(context.getModel(),
context.getWriter(), context.getPointableIndex(), member, "jtv"));
});
return null;
}
Expand All @@ -100,27 +111,31 @@ public Void booleanShape(BooleanShape shape) {
public Void byteShape(ByteShape shape) {
GoWriter writer = context.getWriter();
// Smithy's byte shape represents a signed 8-bit int, which doesn't line up with Go's unsigned byte
handleInteger(shape, CodegenUtils.generatePointerValueIfPointable(writer, shape, "int8(i64)"));
handleInteger(shape, CodegenUtils.getAsPointerIfPointable(context.getModel(), writer,
context.getPointableIndex(), member, "int8(i64)"));
return null;
}

@Override
public Void shortShape(ShortShape shape) {
GoWriter writer = context.getWriter();
handleInteger(shape, CodegenUtils.generatePointerValueIfPointable(writer, shape, "int16(i64)"));
handleInteger(shape, CodegenUtils.getAsPointerIfPointable(context.getModel(), writer,
context.getPointableIndex(), member, "int16(i64)"));
return null;
}

@Override
public Void integerShape(IntegerShape shape) {
GoWriter writer = context.getWriter();
handleInteger(shape, CodegenUtils.generatePointerValueIfPointable(writer, shape, "int32(i64)"));
handleInteger(shape, CodegenUtils.getAsPointerIfPointable(context.getModel(), writer,
context.getPointableIndex(), member, "int32(i64)"));
return null;
}

@Override
public Void longShape(LongShape shape) {
handleInteger(shape, "&i64");
handleInteger(shape, CodegenUtils.getAsPointerIfPointable(context.getModel(), context.getWriter(),
context.getPointableIndex(), member, "i64"));
return null;
}

Expand Down Expand Up @@ -165,14 +180,15 @@ private void handleNumber(Shape shape, Runnable r) {

@Override
public Void floatShape(FloatShape shape) {
GoWriter writer = context.getWriter();
handleFloat(shape, CodegenUtils.generatePointerValueIfPointable(writer, shape, "float32(f64)"));
handleFloat(shape, CodegenUtils.getAsPointerIfPointable(context.getModel(), context.getWriter(),
context.getPointableIndex(), member, "float32(f64)"));
return null;
}

@Override
public Void doubleShape(DoubleShape shape) {
handleFloat(shape, "&f64");
handleFloat(shape, CodegenUtils.getAsPointerIfPointable(context.getModel(), context.getWriter(),
context.getPointableIndex(), member, "f64"));
return null;
}

Expand Down Expand Up @@ -201,7 +217,8 @@ public Void stringShape(StringShape shape) {
if (shape.hasTrait(EnumTrait.class)) {
handleString(shape, () -> writer.write("$L = $P(jtv)", dataDest, symbol));
} else {
handleString(shape, () -> writer.write("$L = &jtv", dataDest));
handleString(shape, () -> writer.write("$L = $L", dataDest, CodegenUtils.getAsPointerIfPointable(
context.getModel(), context.getWriter(), context.getPointableIndex(), member, "jtv")));
}

return null;
Expand Down Expand Up @@ -239,19 +256,22 @@ public Void timestampShape(TimestampShape shape) {
handleString(shape, () -> {
writer.write("t, err := smithytime.ParseDateTime(jtv)");
writer.write("if err != nil { return err }");
writer.write("$L = &t", dataDest);
writer.write("$L = $L", dataDest, CodegenUtils.getAsPointerIfPointable(context.getModel(),
context.getWriter(), context.getPointableIndex(), member, "t"));
});
break;
case HTTP_DATE:
handleString(shape, () -> {
writer.write("t, err := smithytime.ParseHTTPDate(jtv)");
writer.write("if err != nil { return err }");
writer.write("$L = &t", dataDest);
writer.write("$L = $L", dataDest, CodegenUtils.getAsPointerIfPointable(context.getModel(),
context.getWriter(), context.getPointableIndex(), member, "t"));
});
break;
case EPOCH_SECONDS:
writer.addUseImports(SmithyGoDependency.SMITHY_PTR);
handleFloat(shape, "ptr.Time(smithytime.ParseEpochSeconds(f64))");
handleFloat(shape, CodegenUtils.getAsPointerIfPointable(context.getModel(), context.getWriter(),
context.getPointableIndex(), member, "smithytime.ParseEpochSeconds(f64)"));
break;
default:
throw new CodegenException(String.format("Unknown timestamp format %s", timestampFormat));
Expand Down Expand Up @@ -340,8 +360,11 @@ public Void mapShape(MapShape shape) {
private void writeDelegateFunction(Shape shape) {
String functionName = ProtocolGenerator.getDocumentDeserializerFunctionName(shape, context.getProtocolName());
GoWriter writer = context.getWriter();
writer.openBlock("if err := $L(&$L, value); err != nil {", "}", functionName, dataDest, () -> {
writer.write("return err");

ProtocolUtils.writeDeserDelegateFunction(context, writer, member, dataDest, (destVar) -> {
writer.openBlock("if err := $L(&$L, value); err != nil {", "}", functionName, destVar, () -> {
writer.write("return err");
});
});
}
}
Loading