diff --git a/codegen/smithy-go-codegen-test/model/main.smithy b/codegen/smithy-go-codegen-test/model/main.smithy index 71570cafa..7f1952135 100644 --- a/codegen/smithy-go-codegen-test/model/main.smithy +++ b/codegen/smithy-go-codegen-test/model/main.smithy @@ -200,10 +200,31 @@ apply ListCities @httpRequestTests([ } ]) +integer DefaultInteger +boolean DefaultBool + structure ListCitiesInput { @httpQuery("nextToken") nextToken: String, + @httpQuery("aString") + aString: String, + + @httpQuery("defaultBool") + defaultBool: DefaultBool, + + @httpQuery("boxedBool") + boxedBool: Boolean, + + @httpQuery("defaultNumber") + defaultNumber: DefaultInteger, + + @httpQuery("boxedNumber") + boxedNumber: Integer, + + @httpQuery("someEnum") + someEnum: SimpleYesNo, + @httpQuery("pageSize") pageSize: Integer } @@ -211,8 +232,16 @@ structure ListCitiesInput { structure ListCitiesOutput { nextToken: String, + someEnum: SimpleYesNo, + aString: String, + defaultBool: DefaultBool, + boxedBool: Boolean, + defaultNumber: DefaultInteger, + boxedNumber: Integer, + @required items: CitySummaries, + sparseItems: SparseCitySummaries, } // CitySummaries is a list of CitySummary structures. @@ -220,6 +249,12 @@ list CitySummaries { member: CitySummary } +// CitySummaries is a sparse list of CitySummary structures. +@sparse +list SparseCitySummaries { + member: CitySummary +} + // CitySummary contains a reference to a City. @references([{resource: City}]) structure CitySummary { diff --git a/codegen/smithy-go-codegen/build.gradle.kts b/codegen/smithy-go-codegen/build.gradle.kts index 721ffab3e..1d472d60a 100644 --- a/codegen/smithy-go-codegen/build.gradle.kts +++ b/codegen/smithy-go-codegen/build.gradle.kts @@ -18,8 +18,8 @@ extra["displayName"] = "Smithy :: Go :: Codegen" extra["moduleName"] = "software.amazon.smithy.go.codegen" dependencies { - api("software.amazon.smithy:smithy-codegen-core:[1.2.0,2.0.0[") + api("software.amazon.smithy:smithy-codegen-core:[1.3.0,2.0.0[") compile("com.atlassian.commonmark:commonmark:0.15.2") api("org.jsoup:jsoup:1.13.1") - implementation("software.amazon.smithy:smithy-protocol-test-traits:[1.2.0,2.0.0[") + implementation("software.amazon.smithy:smithy-protocol-test-traits:[1.3.0,2.0.0[") } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/CodegenUtils.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/CodegenUtils.java index 4e39a61bf..6f6146b21 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/CodegenUtils.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/CodegenUtils.java @@ -28,15 +28,14 @@ import software.amazon.smithy.codegen.core.CodegenException; import software.amazon.smithy.codegen.core.Symbol; import software.amazon.smithy.codegen.core.SymbolProvider; +import software.amazon.smithy.go.codegen.knowledge.GoPointableIndex; import software.amazon.smithy.model.Model; import software.amazon.smithy.model.shapes.CollectionShape; import software.amazon.smithy.model.shapes.MemberShape; import software.amazon.smithy.model.shapes.ServiceShape; import software.amazon.smithy.model.shapes.Shape; -import software.amazon.smithy.model.shapes.ShapeId; import software.amazon.smithy.model.shapes.ShapeType; import software.amazon.smithy.model.shapes.StructureShape; -import software.amazon.smithy.model.traits.EnumTrait; import software.amazon.smithy.model.traits.RequiredTrait; import software.amazon.smithy.model.traits.TitleTrait; import software.amazon.smithy.utils.StringUtils; @@ -146,26 +145,103 @@ public static String getSyntheticTypeNamespace() { return CodegenUtils.SYNTHETIC_NAMESPACE; } + + /** + * Returns the operand decorated with an & if the address of the shape type can be taken. + * + * @param model API model reference + * @param pointableIndex pointable index + * @param shape shape to use + * @param operand value to decorate + * @return updated operand + */ + public static String asAddressIfAddressable( + Model model, + GoPointableIndex pointableIndex, + Shape shape, + String operand + ) { + boolean isStruct = shape.getType() == ShapeType.STRUCTURE; + if (shape.isMemberShape()) { + isStruct = model.expectShape(shape.asMemberShape().get().getTarget()).getType() == ShapeType.STRUCTURE; + } + + boolean shouldAddress = pointableIndex.isPointable(shape) && isStruct; + return shouldAddress ? "&" + operand : operand; + } + + /** + * Returns the operand decorated with an "*" if the shape is dereferencable. + * + * @param pointableIndex knowledge index for if shape is pointable. + * @param shape The shape whose value needs to be read. + * @param operand The value to be read from. + * @return updated operand + */ + public static String getAsValueIfDereferencable( + GoPointableIndex pointableIndex, + Shape shape, + String operand + ) { + if (!pointableIndex.isDereferencable(shape)) { + return operand; + } + + return '*' + operand; + } + + /** + * Returns the operand decorated as a pointer type, without creating double pointer. + * + * @param pointableIndex knowledge index for if shape is pointable. + * @param shape The shape whose value of the type. + * @param operand The value to read. + * @return updated operand + */ + public static String getTypeAsTypePointer( + GoPointableIndex pointableIndex, + Shape shape, + String operand + ) { + if (pointableIndex.isPointable(shape)) { + return operand; + } + + return '*' + operand; + } + /** * Get the pointer reference to operand , if symbol is pointable. * This method can be used by deserializers to get pointer to * operand. * - * @param writer The writer dependencies will be added to, if needed. - * @param shape The shape whose value needs to be assigned. - * @param operand The Operand is the value to be assigned to the symbol shape. + * @param model model for api. + * @param writer The writer dependencies will be added to, if needed. + * @param pointableIndex knowledge index for if shape is pointable. + * @param shape The shape whose value needs to be assigned. + * @param operand The Operand is the value to be assigned to the symbol shape. * @return The Operand, along with pointer reference if applicable */ - public static String generatePointerValueIfPointable(GoWriter writer, Shape shape, String operand) { + public static String getAsPointerIfPointable( + Model model, + GoWriter writer, + GoPointableIndex pointableIndex, + Shape shape, + String operand + ) { + if (!pointableIndex.isPointable(shape)) { + return operand; + } + + if (shape.isMemberShape()) { + shape = model.expectShape(shape.asMemberShape().get().getTarget()); + } + String prefix = ""; String suffix = ")"; switch (shape.getType()) { case STRING: - if (shape.hasTrait(EnumTrait.class)) { - return operand; - } - prefix = "ptr.String("; break; @@ -198,126 +274,13 @@ public static String generatePointerValueIfPointable(GoWriter writer, Shape shap break; default: - if (isShapePassByReference(shape)) { - return '&' + operand; - } - return operand; - } - - writer.addUseImports(SmithyGoDependency.SMITHY_PTR); - return prefix + operand + suffix; - } - - /** - * Gets a value version of the operate based on the shape type. Returns a string with dereferencing the provided - * operand value if needed. Shapes like Structure, maps, and slices are not dereferenced. - * - * @param writer The writer dependencies will be added to, if needed. - * @param shape The shape whose value needs to be assigned. - * @param operand The Operand is the value to be assigned to the symbol shape. - * @return The Operand, along with pointer reference if applicable - */ - public static String operandValueIfScalar(GoWriter writer, Shape shape, String operand) { - String prefix = ""; - String suffix = ")"; - - switch (shape.getType()) { - case STRING: - if (shape.hasTrait(EnumTrait.class)) { - return operand; - } - - prefix = "ptr.ToString("; - break; - - case BOOLEAN: - prefix = "ptr.ToBool("; - break; - - case BYTE: - prefix = "ptr.ToInt8("; - break; - case SHORT: - prefix = "ptr.ToInt16("; - break; - case INTEGER: - prefix = "ptr.ToInt32("; - break; - case LONG: - prefix = "ptr.ToInt64("; - break; - - case FLOAT: - prefix = "ptr.ToFloat32("; - break; - case DOUBLE: - prefix = "ptr.ToFloat64("; - break; - - case TIMESTAMP: - prefix = "ptr.ToTime("; - break; - - default: - return operand; + return '&' + operand; } writer.addUseImports(SmithyGoDependency.SMITHY_PTR); return prefix + operand + suffix; } - /** - * Returns whether the shape should be passed by value in Go. - * - * @param shape the shape - * @return whether the shape should be passed by value - */ - public static boolean isShapePassByValue(Shape shape) { - return shape.getType() == ShapeType.LIST - || shape.getType() == ShapeType.SET - || shape.getType() == ShapeType.UNION - || shape.getType() == ShapeType.MAP - || shape.getType() == ShapeType.BLOB - || (shape.getType() == ShapeType.STRING && shape.hasTrait(EnumTrait.class)); - } - - /** - * Returns whether the shape should be passed by pointer reference in Go. - * - * @param shape the shape - * @return whether the shape should be passed by reference - */ - public static boolean isShapePassByReference(Shape shape) { - return !isShapePassByValue(shape); - } - - /** - * Returns whether the provided shape can have a Go nil value assigned to it. - * If provided a MemberShape it will use the target shape and the aggregate shape containing - * the member is a reference frame to determine if nil is allowed. - * - * @param model the model - * @param shape the shape to test - * @return if the shape can be assigned a nil Go value - */ - public static boolean isNilAssignableToShape(Model model, Shape shape) { - if (shape instanceof MemberShape) { - ShapeId memberShapeId = shape.getId(); - - Shape aggregateShape = model.expectShape(ShapeId.fromParts(memberShapeId.getNamespace(), - memberShapeId.getName())); - - // If the aggregate parent shape is not a structure the member shape is expected to not be a pointer - if (!(aggregateShape instanceof StructureShape)) { - return false; - } - - shape = model.expectShape(((MemberShape) shape).getTarget()); - } - - return !shape.hasTrait(EnumTrait.class); - } - /** * Returns the shape unpacked as a CollectionShape. Throws and exception if the passed in * shape is not a list or set. @@ -411,4 +374,24 @@ public static MemberShape expectMember(StructureShape shape, Predicate m public static String getServiceTitle(ServiceShape shape, String fallback) { return shape.getTrait(TitleTrait.class).map(TitleTrait::getValue).orElse(fallback); } + + /** + * isNumber returns if the shape is a number shape. + * + * @param shape shape to check + * @return true if is a number shape. + */ + public static boolean isNumber(Shape shape) { + switch (shape.getType()) { + case BYTE: + case SHORT: + case INTEGER: + case LONG: + case FLOAT: + case DOUBLE: + return true; + default: + return false; + } + } } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoValueAccessUtils.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoValueAccessUtils.java new file mode 100644 index 000000000..f741b47d0 --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoValueAccessUtils.java @@ -0,0 +1,305 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + * + */ + +package software.amazon.smithy.go.codegen; + +import java.util.function.Consumer; +import software.amazon.smithy.go.codegen.integration.ProtocolGenerator; +import software.amazon.smithy.go.codegen.knowledge.GoPointableIndex; +import software.amazon.smithy.model.shapes.CollectionShape; +import software.amazon.smithy.model.shapes.MemberShape; +import software.amazon.smithy.model.shapes.Shape; +import software.amazon.smithy.model.shapes.ShapeType; +import software.amazon.smithy.model.traits.EnumTrait; + +/** + * Utilities for generating accessor checks around other generated blocks. + */ +public final class GoValueAccessUtils { + private GoValueAccessUtils() { + } + + /** + * Writes non-zero conditional checks around a lambda specific to the member shape type. + * + * Note: Collections and map member values by default will not have individual checks on member values. To check + * not empty strings set the ignoreEmptyString to false. + * + * @param context generation context + * @param writer go writer + * @param member API shape member to determine wrapping check with + * @param operand string of text with access to value + * @param ignoreEmptyString if empty strings also checked + * @param lambda lambda to run + */ + public static void writeIfNonZeroValue( + ProtocolGenerator.GenerationContext context, + GoWriter writer, + MemberShape member, + String operand, + Boolean ignoreEmptyString, + Runnable lambda + ) { + Shape targetShape = context.getModel().expectShape(member.getTarget()); + Shape container = context.getModel().expectShape(member.getContainer()); + + // default to empty block for variable scoping with not value check. + String check = "{"; + + if (GoPointableIndex.of(context.getModel()).isNillable(member)) { + if (!ignoreEmptyString && targetShape.getType() == ShapeType.STRING) { + check = String.format("if %s != nil && len(*%s) > 0 {", operand, operand); + } else { + check = String.format("if %s != nil {", operand); + } + } else if (container instanceof CollectionShape || container.getType() == ShapeType.MAP) { + if (!ignoreEmptyString && targetShape.getType() == ShapeType.STRING) { + check = String.format("if len(%s) > 0 {", operand); + } + } else if (targetShape.hasTrait(EnumTrait.class)) { + check = String.format("if len(%s) > 0 {", operand); + + } else if (targetShape.getType() == ShapeType.BOOLEAN) { + check = String.format("if %s {", operand); + + } else if (CodegenUtils.isNumber(targetShape)) { + check = String.format("if %s != 0 {", operand); + + } else if (!ignoreEmptyString && targetShape.getType() == ShapeType.STRING) { + check = String.format("if len(%s) > 0 {", operand); + } + + writer.openBlock(check, "}", lambda); + } + + /** + * Writes non-zero conditional checks around a lambda specific to the member shape type. + * + * Ignores empty strings of string pointers, and nested within list and maps. + * + * @param context generation context + * @param writer go writer + * @param member API shape member to determine wrapping check with + * @param operand string of text with access to value + * @param lambda lambda to run + */ + public static void writeIfNonZeroValue( + ProtocolGenerator.GenerationContext context, + GoWriter writer, + MemberShape member, + String operand, + Runnable lambda + ) { + writeIfNonZeroValue(context, writer, member, operand, true, lambda); + } + + /** + * Writes non-zero conditional check around a lambda specific to a member of a container. + * + * Ignores empty strings of string pointers, and members nested within list and maps. + * + * @param context generation context + * @param writer go writer + * @param member API shape member to determine wrapping check with + * @param container operand of source member is a part of. + * @param lambda lambda to run + */ + public static void writeIfNonZeroValueMember( + ProtocolGenerator.GenerationContext context, + GoWriter writer, + MemberShape member, + String container, + Consumer lambda + ) { + String memberName = context.getSymbolProvider().toMemberName(member); + String operand = container + "." + memberName; + + writeIfNonZeroValue(context, writer, member, operand, true, () -> { + lambda.accept(operand); + }); + } + + /** + * Writes non-zero conditional check around a lambda specific to a member of a container. + * + * Note: Collections and map member values by default will not have individual checks on member values. To check + * not empty strings set the ignoreEmptyString to false. + * + * @param context generation context + * @param writer go writer + * @param member API shape member to determine wrapping check with + * @param container operand of source member is a part of. + * @param ignoreEmptyString if empty strings also checked + * @param lambda lambda to run + */ + public static void writeIfNonZeroValueMember( + ProtocolGenerator.GenerationContext context, + GoWriter writer, + MemberShape member, + String container, + boolean ignoreEmptyString, + Consumer lambda + ) { + String memberName = context.getSymbolProvider().toMemberName(member); + String operand = container + "." + memberName; + + writeIfNonZeroValue(context, writer, member, operand, ignoreEmptyString, () -> { + lambda.accept(operand); + }); + } + + /** + * Writes zero conditional checks around a lambda specific to the member shape type. + * + * Members with containers of Collection and map shapes, will ignore the lambda block + * and not call it. Optionally will ignore empty strings based on the ignoreEmptyString flag. + * + * Non-nillable shapes other than Enum, Boolean, and Number will ignore the lambda block. Optionally will ignore + * empty strings based on the ignoreEmptyString flag. + * + * Note: Collections and map member values by default will not have individual checks on member values. To check + * for empty strings set the ignoreEmptyString to false. + * + * @param context generation context + * @param writer go writer + * @param member API shape member to determine wrapping check with + * @param operand string of text with access to value + * @param ignoreEmptyString if empty strings also checked + * @param lambda lambda to run + */ + public static void writeIfZeroValue( + ProtocolGenerator.GenerationContext context, + GoWriter writer, + MemberShape member, + String operand, + Boolean ignoreEmptyString, + Runnable lambda + ) { + Shape targetShape = context.getModel().expectShape(member.getTarget()); + Shape container = context.getModel().expectShape(member.getContainer()); + + String check = "{"; + if (GoPointableIndex.of(context.getModel()).isNillable(member)) { + if (!ignoreEmptyString && targetShape.getType() == ShapeType.STRING) { + check = String.format("if %s == nil || len(*%s) == 0 {", operand, operand); + } else { + check = String.format("if %s == nil {", operand); + } + } else if (container instanceof CollectionShape || container.getType() == ShapeType.MAP) { + // Always serialize values in map/list/sets, no additional check, which means that the + // lambda will not be run, because there is no zero value to check against. + if (!ignoreEmptyString && targetShape.getType() == ShapeType.STRING) { + check = String.format("if len(%s) == 0 {", operand); + } else { + return; + } + + } else if (targetShape.hasTrait(EnumTrait.class)) { + check = String.format("if len(%s) == 0 {", operand); + + } else if (targetShape.getType() == ShapeType.BOOLEAN) { + check = String.format("if !%s {", operand); + + } else if (CodegenUtils.isNumber(targetShape)) { + check = String.format("if %s == 0 {", operand); + + } else if (!ignoreEmptyString && targetShape.getType() == ShapeType.STRING) { + check = String.format("if len(%s) == 0 {", operand); + + } else { + // default to empty block for variable scoping with not value check. + return; + } + + writer.openBlock(check, "}", lambda); + } + + /** + * Writes zero conditional checks around a lambda specific to the member shape type. + * + * Ignores empty strings of string pointers, and members nested within list and maps. + * + * @param context generation context + * @param writer go writer + * @param member API shape member to determine wrapping check with + * @param operand string of text with access to value + * @param lambda lambda to run + */ + public static void writeIfZeroValue( + ProtocolGenerator.GenerationContext context, + GoWriter writer, + MemberShape member, + String operand, + Runnable lambda + ) { + writeIfZeroValue(context, writer, member, operand, true, lambda); + } + + /** + * Writes zero conditional check around a lambda specific to a member of a container. + * + * Ignores empty strings of string pointers, and members nested within list and maps. + * + * @param context generation context + * @param writer go writer + * @param member API shape member to determine wrapping check with + * @param container operand of source member is a part of. + * @param lambda lambda to run + */ + public static void writeIfZeroValueMember( + ProtocolGenerator.GenerationContext context, + GoWriter writer, + MemberShape member, + String container, + Consumer lambda + ) { + String memberName = context.getSymbolProvider().toMemberName(member); + String operand = container + "." + memberName; + + writeIfZeroValue(context, writer, member, operand, () -> { + lambda.accept(operand); + }); + } + + /** + * Writes zero conditional check around a lambda specific to a member of a container. + * + * Ignores empty strings of string pointers, and members nested within list and maps. + * + * @param context generation context + * @param writer go writer + * @param member API shape member to determine wrapping check with + * @param container operand of source member is a part of. + * @param ignoreEmptyString if empty strings also checked + * @param lambda lambda to run + */ + public static void writeIfZeroValueMember( + ProtocolGenerator.GenerationContext context, + GoWriter writer, + MemberShape member, + String container, + boolean ignoreEmptyString, + Consumer lambda + ) { + String memberName = context.getSymbolProvider().toMemberName(member); + String operand = container + "." + memberName; + + writeIfZeroValue(context, writer, member, operand, ignoreEmptyString, () -> { + lambda.accept(operand); + }); + } +} diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoWriter.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoWriter.java index 8c448b3e1..e2f9c372d 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoWriter.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoWriter.java @@ -346,7 +346,7 @@ public String apply(Object type, String indent) { if (isSlice || isMap) { resolvedSymbol = resolvedSymbol.getProperty(SymbolUtils.GO_ELEMENT_TYPE, Symbol.class) .orElseThrow(() -> new CodegenException("Expected go element type property to be defined")); - literal = apply(resolvedSymbol, indent); + literal = new PointableGoSymbolFormatter().apply(resolvedSymbol, "nested"); } else if (!SymbolUtils.isUniverseType(resolvedSymbol) && isExternalNamespace(resolvedSymbol.getNamespace())) { literal = formatWithNamespace(resolvedSymbol); diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/ShapeValueGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/ShapeValueGenerator.java index 7e9562661..bf7ebe4ec 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/ShapeValueGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/ShapeValueGenerator.java @@ -18,11 +18,11 @@ package software.amazon.smithy.go.codegen; import java.util.Map; -import java.util.Optional; import java.util.logging.Logger; import software.amazon.smithy.codegen.core.CodegenException; import software.amazon.smithy.codegen.core.Symbol; import software.amazon.smithy.codegen.core.SymbolProvider; +import software.amazon.smithy.go.codegen.knowledge.GoPointableIndex; import software.amazon.smithy.model.Model; import software.amazon.smithy.model.node.ArrayNode; import software.amazon.smithy.model.node.BooleanNode; @@ -32,8 +32,6 @@ import software.amazon.smithy.model.node.NumberNode; import software.amazon.smithy.model.node.ObjectNode; import software.amazon.smithy.model.node.StringNode; -import software.amazon.smithy.model.shapes.CollectionShape; -import software.amazon.smithy.model.shapes.MapShape; import software.amazon.smithy.model.shapes.MemberShape; import software.amazon.smithy.model.shapes.Shape; import software.amazon.smithy.model.shapes.ShapeType; @@ -51,6 +49,7 @@ public final class ShapeValueGenerator { protected final Model model; protected final SymbolProvider symbolProvider; + protected final GoPointableIndex pointableIndex; /** * Initializes a shape value generator. @@ -61,6 +60,7 @@ public final class ShapeValueGenerator { public ShapeValueGenerator(Model model, SymbolProvider symbolProvider) { this.model = model; this.symbolProvider = symbolProvider; + this.pointableIndex = GoPointableIndex.of(model); } /** @@ -70,42 +70,70 @@ public ShapeValueGenerator(Model model, SymbolProvider symbolProvider) { * @param shape the shape that will be declared. * @param params parameters to fill the generated shape declaration. */ - public void writeShapeValueInline(GoWriter writer, Shape shape, Node params) { + public void writePointableStructureShapeValueInline(GoWriter writer, StructureShape shape, Node params) { if (params.isNullNode()) { - if (shape.isStringShape() && shape.hasTrait(EnumTrait.class)) { - Symbol enumSymbol = symbolProvider.toSymbol(shape); + writer.writeInline("nil"); + } + + // Input/output struct top level shapes are special since they are the only shape that can be used directly, + // not within the context of a member shape reference. + Symbol symbol = symbolProvider.toSymbol(shape); + writer.write("&$T{", symbol); + params.accept(new ShapeValueNodeVisitor(writer, this, shape)); + writer.writeInline("}"); + } + + /** + * Writes generation of a member shape value type declaration for the given the parameters. + * + * @param writer writer to write generated code with. + * @param member the shape that will be declared. + * @param params parameters to fill the generated shape declaration. + */ + protected void writeMemberValueInline(GoWriter writer, MemberShape member, Node params) { + Shape targetShape = model.expectShape(member.getTarget()); + + // Null params need to be represented as zero values for member, + if (params.isNullNode()) { + if (pointableIndex.isNillable(member)) { + writer.writeInline("nil"); + + } else if (targetShape.getType() == ShapeType.STRING && targetShape.hasTrait(EnumTrait.class)) { + Symbol enumSymbol = symbolProvider.toSymbol(targetShape); writer.writeInline("$T($S)", enumSymbol, ""); + } else { - writer.writeInline("nil"); + Symbol shapeSymbol = symbolProvider.toSymbol(member); + writer.writeInline("func() (v $P) { return v }()", shapeSymbol); } return; } - switch (shape.getType()) { + switch (targetShape.getType()) { case STRUCTURE: - structDeclShapeValue(writer, shape.asStructureShape().get(), params); + structDeclShapeValue(writer, member, params); break; case SET: case LIST: - listDeclShapeValue(writer, (CollectionShape) shape, params); + listDeclShapeValue(writer, member, params); break; case MAP: - mapDeclShapeValue(writer, shape.asMapShape().get(), params); + mapDeclShapeValue(writer, member, params); break; case UNION: - unionDeclShapeValue(writer, shape.asUnionShape().get(), params.expectObjectNode()); + unionDeclShapeValue(writer, member, params.expectObjectNode()); break; case DOCUMENT: - LOGGER.warning("Skipping " + shape.getType() + " shape type not supported, " + shape.getId()); + LOGGER.warning("Skipping " + member.getType() + " shape type not supported, " + member.getId()); writer.writeInline("nil"); break; default: - writeScalarPointerInline(writer, shape, params); + writeScalarPointerInline(writer, member, params); } } @@ -113,14 +141,15 @@ public void writeShapeValueInline(GoWriter writer, Shape shape, Node params) { * Writes the declaration for a Go structure. Delegates to the runner for member fields within the structure. * * @param writer writer to write generated code with. - * @param shape the structure shape + * @param member the structure shape * @param params parameters to fill the generated shape declaration. */ - protected void structDeclShapeValue(GoWriter writer, StructureShape shape, Node params) { - Symbol symbol = symbolProvider.toSymbol(shape); + protected void structDeclShapeValue(GoWriter writer, MemberShape member, Node params) { + Symbol symbol = symbolProvider.toSymbol(member); - writer.write("&$T{", symbol); - params.accept(new ShapeValueNodeVisitor(writer, this, shape)); + String addr = CodegenUtils.asAddressIfAddressable(model, pointableIndex, member, ""); + writer.write("$L$T{", addr, symbol); + params.accept(new ShapeValueNodeVisitor(writer, this, model.expectShape(member.getTarget()))); writer.writeInline("}"); } @@ -128,28 +157,33 @@ protected void structDeclShapeValue(GoWriter writer, StructureShape shape, Node * Writes the declaration for a Go union. * * @param writer writer to write generated code with. - * @param shape the union shape. + * @param member the union shape. * @param params the params. */ - protected void unionDeclShapeValue(GoWriter writer, UnionShape shape, ObjectNode params) { - Symbol symbol = symbolProvider.toSymbol(shape); + protected void unionDeclShapeValue(GoWriter writer, MemberShape member, ObjectNode params) { + UnionShape targetShape = (UnionShape) model.expectShape(member.getTarget()); + for (Map.Entry entry : params.getMembers().entrySet()) { - Optional member = shape.getMember(entry.getKey().toString()); - if (member.isPresent()) { - Shape target = model.expectShape(member.get().getTarget()); - Symbol memberSymbol = SymbolUtils.createValueSymbolBuilder( - symbolProvider.toMemberName(member.get()), - symbol.getNamespace() + targetShape.getMember(entry.getKey().toString()).ifPresent((unionMember) -> { + Shape unionTarget = model.expectShape(unionMember.getTarget()); + + // Need to manually create a symbol builder for a union member struct type because the "member" + // of a union will return the inner value type not the member not the member type it self. + Symbol memberSymbol = SymbolUtils.createPointableSymbolBuilder( + symbolProvider.toMemberName(unionMember), + symbolProvider.toSymbol(targetShape).getNamespace() ).build(); + // Union member types are always pointers writer.writeInline("&$T{Value: ", memberSymbol); - if (target instanceof SimpleShape) { - writeScalarValueInline(writer, target, entry.getValue()); + if (unionTarget instanceof SimpleShape) { + writeScalarValueInline(writer, unionMember, entry.getValue()); } else { - writeShapeValueInline(writer, target, entry.getValue()); + writeMemberValueInline(writer, unionMember, entry.getValue()); } writer.writeInline("}"); - } + }); + return; } } @@ -158,15 +192,12 @@ protected void unionDeclShapeValue(GoWriter writer, UnionShape shape, ObjectNode * Writes the declaration for a Go slice. Delegates to the runner for fields within the slice. * * @param writer writer to write generated code with. - * @param shape the collection shape + * @param member the collection shape * @param params parameters to fill the generated shape declaration. */ - protected void listDeclShapeValue(GoWriter writer, CollectionShape shape, Node params) { - Shape memberShape = model.expectShape(shape.getMember().getTarget()); - Symbol memberSymbol = symbolProvider.toSymbol(memberShape); - - writer.write("[]$P{", memberSymbol); - params.accept(new ShapeValueNodeVisitor(writer, this, shape)); + protected void listDeclShapeValue(GoWriter writer, MemberShape member, Node params) { + writer.write("$P{", symbolProvider.toSymbol(member)); + params.accept(new ShapeValueNodeVisitor(writer, this, model.expectShape(member.getTarget()))); writer.writeInline("}"); } @@ -174,103 +205,98 @@ protected void listDeclShapeValue(GoWriter writer, CollectionShape shape, Node p * Writes the declaration for a Go map. Delegates to the runner for key/value fields within the map. * * @param writer writer to write generated code with. - * @param shape the map shape. + * @param member the map shape. * @param params parameters to fill the generated shape declaration. */ - protected void mapDeclShapeValue(GoWriter writer, MapShape shape, Node params) { - Shape valueShape = model.expectShape(shape.getValue().getTarget()); - Shape keyShape = model.expectShape(shape.getKey().getTarget()); - - Symbol valueSymbol = symbolProvider.toSymbol(valueShape); - Symbol keySymbol = symbolProvider.toSymbol(keyShape); - - writer.write("map[$T]$P{", keySymbol, valueSymbol); - params.accept(new ShapeValueNodeVisitor(writer, this, shape)); + protected void mapDeclShapeValue(GoWriter writer, MemberShape member, Node params) { + writer.write("$P{", symbolProvider.toSymbol(member)); + params.accept(new ShapeValueNodeVisitor(writer, this, model.expectShape(member.getTarget()))); writer.writeInline("}"); } + private void writeScalarWrapper( + GoWriter writer, + MemberShape member, + Node params, + String funcName, + TriConsumer inner + ) { + if (pointableIndex.isPointable(member)) { + writer.addUseImports(SmithyGoDependency.SMITHY_PTR); + writer.writeInline("ptr." + funcName + "("); + inner.accept(writer, member, params); + writer.writeInline(")"); + } else { + inner.accept(writer, member, params); + } + } + /** * Writes scalar values with pointer value wrapping as needed based on the shape type. * * @param writer writer to write generated code with. - * @param shape scalar shape. + * @param member scalar shape. * @param params parameters to fill the generated shape declaration. */ - protected void writeScalarPointerInline(GoWriter writer, Shape shape, Node params) { - boolean withPtrImport = true; - String closing = ")"; + protected void writeScalarPointerInline(GoWriter writer, MemberShape member, Node params) { + Shape target = model.expectShape(member.getTarget()); - switch (shape.getType()) { + String funcName = ""; + switch (target.getType()) { case BOOLEAN: - writer.writeInline("ptr.Bool("); - break; - - case BLOB: - closing = ""; - withPtrImport = false; + funcName = "Bool"; break; case STRING: - // Enum are not pointers, but string alias values - if (shape.hasTrait(StreamingTrait.class) || shape.hasTrait(EnumTrait.class)) { - closing = ""; - withPtrImport = false; - } else { - writer.writeInline("ptr.String("); - } - + funcName = "String"; break; case TIMESTAMP: - writer.writeInline("ptr.Time("); + funcName = "Time"; break; case BYTE: - writer.writeInline("ptr.Int8("); + funcName = "Int8"; break; - case SHORT: - writer.writeInline("ptr.Int16("); + funcName = "Int16"; break; - case INTEGER: - writer.writeInline("ptr.Int32("); + funcName = "Int32"; break; - case LONG: - writer.writeInline("ptr.Int64("); + funcName = "Int64"; break; case FLOAT: - writer.writeInline("ptr.Float32("); + funcName = "Float32"; break; - case DOUBLE: - writer.writeInline("ptr.Float64("); + funcName = "Float64"; + break; + + case BLOB: break; case BIG_INTEGER: case BIG_DECIMAL: - writeScalarValueInline(writer, shape, params); return; default: - throw new CodegenException("unexpected shape type " + shape.getType()); - } - - if (withPtrImport) { - writer.addUseImports(SmithyGoDependency.SMITHY_PTR); + throw new CodegenException("unexpected shape type " + target.getType()); } - writeScalarValueInline(writer, shape, params); - writer.writeInline(closing); + writeScalarWrapper(writer, member, params, funcName, this::writeScalarValueInline); } - protected void writeScalarValueInline(GoWriter writer, Shape shape, Node params) { + protected void writeScalarValueInline(GoWriter writer, MemberShape member, Node params) { + Shape target = model.expectShape(member.getTarget()); + String closing = ""; - switch (shape.getType()) { + switch (target.getType()) { case BLOB: - if (shape.hasTrait(StreamingTrait.class)) { + // blob streams are io.Readers not byte slices. + if (target.hasTrait(StreamingTrait.class)) { writer.addUseImports(SmithyGoDependency.SMITHY_IO); writer.addUseImports(SmithyGoDependency.BYTES); writer.writeInline("smithyio.ReadSeekNopCloser{ReadSeeker: bytes.NewReader([]byte("); @@ -282,15 +308,16 @@ protected void writeScalarValueInline(GoWriter writer, Shape shape, Node params) break; case STRING: - // Enum are not pointers, but string alias values - if (shape.hasTrait(StreamingTrait.class)) { + // String streams are io.Readers not strings. + if (target.hasTrait(StreamingTrait.class)) { writer.addUseImports(SmithyGoDependency.SMITHY_IO); writer.addUseImports(SmithyGoDependency.STRINGS); writer.writeInline("smithyio.ReadSeekNopCloser{ReadSeeker: strings.NewReader("); closing = ")}"; - } else if (shape.hasTrait(EnumTrait.class)) { - Symbol enumSymbol = symbolProvider.toSymbol(shape); + } else if (target.hasTrait(EnumTrait.class)) { + // Enum are not pointers, but string alias values + Symbol enumSymbol = symbolProvider.toSymbol(target); writer.writeInline("$T(", enumSymbol); closing = ")"; } @@ -299,7 +326,8 @@ protected void writeScalarValueInline(GoWriter writer, Shape shape, Node params) default: break; } - params.accept(new ShapeValueNodeVisitor(writer, this, shape)); + + params.accept(new ShapeValueNodeVisitor(writer, this, target)); writer.writeInline(closing); } @@ -332,10 +360,10 @@ private ShapeValueNodeVisitor(GoWriter writer, ShapeValueGenerator valueGen, Sha */ @Override public Void arrayNode(ArrayNode node) { - Shape memberShape = model.expectShape(((CollectionShape) this.currentShape).getMember().getTarget()); + MemberShape memberShape = CodegenUtils.expectCollectionShape(this.currentShape).getMember(); node.getElements().forEach(element -> { - valueGen.writeShapeValueInline(writer, memberShape, element); + valueGen.writeMemberValueInline(writer, memberShape, element); writer.write(","); }); return null; @@ -350,10 +378,9 @@ public Void arrayNode(ArrayNode node) { @Override public Void objectNode(ObjectNode node) { node.getMembers().forEach((keyNode, valueNode) -> { - Shape memberShape; + MemberShape member; switch (currentShape.getType()) { case STRUCTURE: - MemberShape member; if (currentShape.asStructureShape().get().getMember(keyNode.getValue()).isPresent()) { member = currentShape.asStructureShape().get().getMember(keyNode.getValue()).get(); } else { @@ -361,19 +388,17 @@ public Void objectNode(ObjectNode node) { "unknown member " + currentShape.getId() + "." + keyNode.getValue()); } - memberShape = model.expectShape(member.getTarget()); String memberName = symbolProvider.toMemberName(member); - writer.write("$L: ", memberName); - valueGen.writeShapeValueInline(writer, memberShape, valueNode); + valueGen.writeMemberValueInline(writer, member, valueNode); writer.write(","); break; case MAP: - memberShape = model.expectShape(this.currentShape.asMapShape().get().getValue().getTarget()); + member = this.currentShape.asMapShape().get().getValue(); writer.write("$S: ", keyNode.getValue()); - valueGen.writeShapeValueInline(writer, memberShape, valueNode); + valueGen.writeMemberValueInline(writer, member, valueNode); writer.write(","); break; @@ -409,14 +434,7 @@ public Void booleanNode(BooleanNode node) { */ @Override public Void nullNode(NullNode node) { - if (currentShape.getType() == ShapeType.STRING && currentShape.hasTrait(EnumTrait.class)) { - Symbol enumSymbol = symbolProvider.toSymbol(currentShape); - writer.writeInline("$T($S)", enumSymbol, ""); - } else { - writer.writeInline("nil"); - } - - return null; + throw new CodegenException("unexpected null node walked, should not be encountered in walker"); } /** diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/StructureGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/StructureGenerator.java index 9b93f24b0..70cf72f23 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/StructureGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/StructureGenerator.java @@ -21,8 +21,6 @@ import software.amazon.smithy.codegen.core.SymbolProvider; import software.amazon.smithy.model.Model; import software.amazon.smithy.model.shapes.MemberShape; -import software.amazon.smithy.model.shapes.Shape; -import software.amazon.smithy.model.shapes.ShapeType; import software.amazon.smithy.model.shapes.StructureShape; import software.amazon.smithy.model.traits.ErrorTrait; import software.amazon.smithy.utils.MapUtils; @@ -158,16 +156,4 @@ private void renderErrorStructure() { } writer.write("func (e *$L) ErrorFault() smithy.ErrorFault { return $L }", structureSymbol.getName(), fault); } - - private String getterReturnFormatter(Shape shape) { - if (CodegenUtils.isShapePassByValue(shape)) { - return "$P"; - } - - if (shape.getType() == ShapeType.STRUCTURE) { - return "$P"; - } - - return "$T"; - } } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SymbolVisitor.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SymbolVisitor.java index 5cccde1e9..99126e968 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SymbolVisitor.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SymbolVisitor.java @@ -25,6 +25,7 @@ import software.amazon.smithy.codegen.core.ReservedWordsBuilder; import software.amazon.smithy.codegen.core.Symbol; import software.amazon.smithy.codegen.core.SymbolProvider; +import software.amazon.smithy.go.codegen.knowledge.GoPointableIndex; import software.amazon.smithy.go.codegen.trait.UnexportedMemberTrait; import software.amazon.smithy.model.Model; import software.amazon.smithy.model.knowledge.NeighborProviderIndex; @@ -69,7 +70,6 @@ * suffixed with "_". See "reserved-words.txt" for the list of words. */ final class SymbolVisitor implements SymbolProvider, ShapeVisitor { - private static final Logger LOGGER = Logger.getLogger(SymbolVisitor.class.getName()); private final Model model; @@ -78,12 +78,14 @@ final class SymbolVisitor implements SymbolProvider, ShapeVisitor { private final ReservedWordSymbolProvider.Escaper escaper; private final ReservedWordSymbolProvider.Escaper errorMemberEscaper; private final Map structureSpecificMemberEscapers = new HashMap<>(); + private final GoPointableIndex pointableIndex; SymbolVisitor(Model model, String rootModuleName) { this.model = model; this.rootModuleName = rootModuleName; this.typesPackageName = rootModuleName + "/types"; + this.pointableIndex = GoPointableIndex.of(model); // Reserve the generated names for union members, including the unknown case. ReservedWordsBuilder reservedNames = new ReservedWordsBuilder() @@ -129,7 +131,7 @@ final class SymbolVisitor implements SymbolProvider, ShapeVisitor { * *

These have the format {UnionName}Member{MemberName}. * - * @param model The model whose unions should be reserved. + * @param model The model whose unions should be reserved. * @param builder A reserved words builder to add on to. */ private void reserveUnionMemberNames(Model model, ReservedWordsBuilder builder) { @@ -265,19 +267,19 @@ private boolean isErrorMember(MemberShape shape) { @Override public Symbol blobShape(BlobShape shape) { if (shape.hasTrait(StreamingTrait.ID)) { - Symbol inputVariant = SymbolUtils.createValueSymbolBuilder(shape, "Reader", SmithyGoDependency.IO).build(); - return SymbolUtils.createValueSymbolBuilder(shape, "ReadCloser", SmithyGoDependency.IO) + Symbol inputVariant = symbolBuilderFor(shape, "Reader", SmithyGoDependency.IO).build(); + return symbolBuilderFor(shape, "ReadCloser", SmithyGoDependency.IO) .putProperty(SymbolUtils.INPUT_VARIANT, inputVariant) .build(); } - return SymbolUtils.createValueSymbolBuilder(shape, "[]byte") + return symbolBuilderFor(shape, "[]byte") .putProperty(SymbolUtils.GO_UNIVERSE_TYPE, true) .build(); } @Override public Symbol booleanShape(BooleanShape shape) { - return SymbolUtils.createPointableSymbolBuilder(shape, "bool") + return symbolBuilderFor(shape, "bool") .putProperty(SymbolUtils.GO_UNIVERSE_TYPE, true) .build(); } @@ -299,7 +301,7 @@ private Symbol createCollectionSymbol(CollectionShape shape) { Symbol reference = toSymbol(shape.getMember()); // Shape name will be unused for symbols that represent a slice, but in the event it does we set the collection // shape's name to make debugging simpler. - return SymbolUtils.createValueSymbolBuilder(shape, getDefaultShapeName(shape)) + return symbolBuilderFor(shape, getDefaultShapeName(shape)) .putProperty(SymbolUtils.GO_SLICE, true) .putProperty(SymbolUtils.GO_UNIVERSE_TYPE, reference.getProperty(SymbolUtils.GO_UNIVERSE_TYPE, Boolean.class).orElse(false)) @@ -312,7 +314,7 @@ public Symbol mapShape(MapShape shape) { Symbol reference = toSymbol(shape.getValue()); // Shape name will be unused for symbols that represent a map, but in the event it does we set the map shape's // name to make debugging simpler. - return SymbolUtils.createValueSymbolBuilder(shape, getDefaultShapeName(shape)) + return symbolBuilderFor(shape, getDefaultShapeName(shape)) .putProperty(SymbolUtils.GO_MAP, true) .putProperty(SymbolUtils.GO_UNIVERSE_TYPE, reference.getProperty(SymbolUtils.GO_UNIVERSE_TYPE, Boolean.class).orElse(false)) @@ -320,47 +322,68 @@ public Symbol mapShape(MapShape shape) { .build(); } + private Symbol.Builder symbolBuilderFor(Shape shape, String typeName) { + if (pointableIndex.isPointable(shape)) { + return SymbolUtils.createPointableSymbolBuilder(shape, typeName); + } + + return SymbolUtils.createValueSymbolBuilder(shape, typeName); + } + + private Symbol.Builder symbolBuilderFor(Shape shape, String typeName, GoDependency namespace) { + if (pointableIndex.isPointable(shape)) { + return SymbolUtils.createPointableSymbolBuilder(shape, typeName, namespace); + } + + return SymbolUtils.createValueSymbolBuilder(shape, typeName, namespace); + } + + private Symbol.Builder symbolBuilderFor(Shape shape, String typeName, String namespace) { + if (pointableIndex.isPointable(shape)) { + return SymbolUtils.createPointableSymbolBuilder(shape, typeName, namespace); + } + + return SymbolUtils.createValueSymbolBuilder(shape, typeName, namespace); + } + @Override public Symbol byteShape(ByteShape shape) { - return SymbolUtils.createPointableSymbolBuilder(shape, "int8").build(); + return symbolBuilderFor(shape, "int8") + .putProperty(SymbolUtils.GO_UNIVERSE_TYPE, true) + .build(); } @Override public Symbol shortShape(ShortShape shape) { - return SymbolUtils.createPointableSymbolBuilder(shape, "int16") + return symbolBuilderFor(shape, "int16") .putProperty(SymbolUtils.GO_UNIVERSE_TYPE, true) .build(); } @Override public Symbol integerShape(IntegerShape shape) { - return SymbolUtils.createPointableSymbolBuilder(shape, "int32") + return symbolBuilderFor(shape, "int32") .putProperty(SymbolUtils.GO_UNIVERSE_TYPE, true) .build(); } @Override public Symbol longShape(LongShape shape) { - return SymbolUtils.createPointableSymbolBuilder(shape, "int64") + return symbolBuilderFor(shape, "int64") .putProperty(SymbolUtils.GO_UNIVERSE_TYPE, true) .build(); } @Override public Symbol floatShape(FloatShape shape) { - return SymbolUtils.createPointableSymbolBuilder(shape, "float32") + return symbolBuilderFor(shape, "float32") .putProperty(SymbolUtils.GO_UNIVERSE_TYPE, true) .build(); } - @Override - public Symbol documentShape(DocumentShape shape) { - return SymbolUtils.createValueSymbolBuilder(shape, "Document", SmithyGoDependency.SMITHY).build(); - } - @Override public Symbol doubleShape(DoubleShape shape) { - return SymbolUtils.createPointableSymbolBuilder(shape, "float64") + return symbolBuilderFor(shape, "float64") .putProperty(SymbolUtils.GO_UNIVERSE_TYPE, true) .build(); } @@ -372,11 +395,19 @@ public Symbol bigIntegerShape(BigIntegerShape shape) { @Override public Symbol bigDecimalShape(BigDecimalShape shape) { + return createBigSymbol(shape, "Float"); } private Symbol createBigSymbol(Shape shape, String symbolName) { - return SymbolUtils.createPointableSymbolBuilder(shape, symbolName, SmithyGoDependency.BIG).build(); + return symbolBuilderFor(shape, symbolName, SmithyGoDependency.BIG) + .build(); + } + + @Override + public Symbol documentShape(DocumentShape shape) { + return symbolBuilderFor(shape, "Document", SmithyGoDependency.SMITHY) + .build(); } @Override @@ -395,7 +426,7 @@ public Symbol resourceShape(ResourceShape shape) { @Override public Symbol serviceShape(ServiceShape shape) { - return SymbolUtils.createPointableSymbolBuilder(shape, "Client", rootModuleName) + return symbolBuilderFor(shape, "Client", rootModuleName) .definitionFile("./api_client.go") .build(); } @@ -404,12 +435,12 @@ public Symbol serviceShape(ServiceShape shape) { public Symbol stringShape(StringShape shape) { if (shape.hasTrait(EnumTrait.class)) { String name = getDefaultShapeName(shape); - return SymbolUtils.createValueSymbolBuilder(shape, name, typesPackageName) + return symbolBuilderFor(shape, name, typesPackageName) .definitionFile("./types/enums.go") .build(); } - return SymbolUtils.createPointableSymbolBuilder(shape, "string") + return symbolBuilderFor(shape, "string") .putProperty(SymbolUtils.GO_UNIVERSE_TYPE, true) .build(); } @@ -420,23 +451,24 @@ public Symbol structureShape(StructureShape shape) { if (shape.getId().getNamespace().equals(CodegenUtils.getSyntheticTypeNamespace())) { Optional boundOperationName = getNameOfBoundOperation(shape); if (boundOperationName.isPresent()) { - return SymbolUtils.createPointableSymbolBuilder(shape, name, rootModuleName) + return symbolBuilderFor(shape, name, rootModuleName) .definitionFile("./api_op_" + boundOperationName.get() + ".go") .build(); } } - Symbol.Builder builder = SymbolUtils.createPointableSymbolBuilder(shape, name, typesPackageName); + Symbol.Builder builder = symbolBuilderFor(shape, name, typesPackageName); if (shape.hasTrait(ErrorTrait.ID)) { builder.definitionFile("./types/errors.go"); } else { builder.definitionFile("./types/types.go"); } + return builder.build(); } private Optional getNameOfBoundOperation(StructureShape shape) { - NeighborProvider provider = model.getKnowledge(NeighborProviderIndex.class).getReverseProvider(); + NeighborProvider provider = NeighborProviderIndex.of(model).getReverseProvider(); for (Relationship relationship : provider.getNeighbors(shape)) { RelationshipType relationshipType = relationship.getRelationshipType(); if (relationshipType == RelationshipType.INPUT || relationshipType == RelationshipType.OUTPUT) { @@ -449,20 +481,22 @@ private Optional getNameOfBoundOperation(StructureShape shape) { @Override public Symbol unionShape(UnionShape shape) { String name = getDefaultShapeName(shape); - return SymbolUtils.createValueSymbolBuilder(shape, name, typesPackageName) + return symbolBuilderFor(shape, name, typesPackageName) .definitionFile("./types/types.go") .build(); } @Override - public Symbol memberShape(MemberShape shape) { - Shape targetShape = model.getShape(shape.getTarget()) - .orElseThrow(() -> new CodegenException("Shape not found: " + shape.getTarget())); - return toSymbol(targetShape); + public Symbol memberShape(MemberShape member) { + Shape targetShape = model.expectShape(member.getTarget()); + return toSymbol(targetShape) + .toBuilder() + .putProperty(SymbolUtils.POINTABLE, pointableIndex.isPointable(member)) + .build(); } @Override public Symbol timestampShape(TimestampShape shape) { - return SymbolUtils.createPointableSymbolBuilder(shape, "Time", SmithyGoDependency.TIME).build(); + return symbolBuilderFor(shape, "Time", SmithyGoDependency.TIME).build(); } } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpBindingProtocolGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpBindingProtocolGenerator.java index 20d93e3fc..17a9cc49a 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpBindingProtocolGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpBindingProtocolGenerator.java @@ -17,7 +17,6 @@ import static software.amazon.smithy.go.codegen.integration.HttpProtocolGeneratorUtils.isShapeWithResponseBindings; import static software.amazon.smithy.go.codegen.integration.ProtocolUtils.requiresDocumentSerdeFunction; -import static software.amazon.smithy.go.codegen.integration.ProtocolUtils.writeSafeMemberAccessor; import java.util.Collection; import java.util.Comparator; @@ -26,7 +25,6 @@ import java.util.Set; import java.util.TreeSet; import java.util.function.BiConsumer; -import java.util.function.Consumer; import java.util.logging.Logger; import java.util.stream.Collectors; import software.amazon.smithy.codegen.core.CodegenException; @@ -35,9 +33,11 @@ import software.amazon.smithy.go.codegen.ApplicationProtocol; import software.amazon.smithy.go.codegen.CodegenUtils; import software.amazon.smithy.go.codegen.GoStackStepMiddlewareGenerator; +import software.amazon.smithy.go.codegen.GoValueAccessUtils; import software.amazon.smithy.go.codegen.GoWriter; import software.amazon.smithy.go.codegen.SmithyGoDependency; import software.amazon.smithy.go.codegen.SymbolUtils; +import software.amazon.smithy.go.codegen.knowledge.GoPointableIndex; import software.amazon.smithy.go.codegen.trait.NoSerializeTrait; import software.amazon.smithy.model.Model; import software.amazon.smithy.model.knowledge.HttpBinding; @@ -390,7 +390,7 @@ protected void writeMiddlewarePayloadSerializerDelegator( Model model = context.getModel(); Shape payloadShape = model.expectShape(memberShape.getTarget()); - writeSafeMemberAccessor(context, memberShape, "input", s -> { + GoValueAccessUtils.writeIfNonZeroValueMember(context, writer, memberShape, "input", (s) -> { writer.openBlock("if !restEncoder.HasHeader(\"Content-Type\") {", "}", () -> { writer.write("restEncoder.SetHeader(\"Content-Type\").String($S)", getPayloadShapeMediaType(payloadShape)); @@ -589,29 +589,28 @@ private boolean isHttpDateTimestamp(Model model, HttpBinding.Location location, return false; } - TimestampFormatTrait.Format format = model.getKnowledge(HttpBindingIndex.class).determineTimestampFormat( + TimestampFormatTrait.Format format = HttpBindingIndex.of(model).determineTimestampFormat( memberShape, location, getDocumentTimestampFormat()); return format == Format.HTTP_DATE; } private void writeHttpBindingSetter( - Model model, + GenerationContext context, GoWriter writer, MemberShape memberShape, HttpBinding.Location location, String operand, BiConsumer locationEncoder ) { + Model model = context.getModel(); Shape targetShape = model.expectShape(memberShape.getTarget()); // We only need to dereference if we pass the shape around as reference in Go. // Note we make two exceptions here: big.Int and big.Float should still be passed as reference to the helper // method as they can be arbitrarily large. - operand = CodegenUtils.isShapePassByReference(targetShape) - && targetShape.getType() != ShapeType.BIG_INTEGER - && targetShape.getType() != ShapeType.BIG_DECIMAL - ? "*" + operand : operand; + operand = CodegenUtils.getAsValueIfDereferencable(GoPointableIndex.of(context.getModel()), memberShape, + operand); switch (targetShape.getType()) { case BOOLEAN: @@ -663,10 +662,10 @@ private void writeHttpBindingMember( Shape targetShape = model.expectShape(memberShape.getTarget()); HttpBinding.Location location = binding.getLocation(); - // throw an error if member shape targets location label, but is unset + // return an error if member shape targets location label, but is unset. if (location.equals(HttpBinding.Location.LABEL)) { - // labels must always be set to be serialized on URI - throwSerializationErrorIfMemberUnset(context, memberShape, "v", operand -> { + // labels must always be set to be serialized on URI, and non empty strings, + GoValueAccessUtils.writeIfZeroValueMember(context, writer, memberShape, "v", false, operand -> { writer.addUseImports(SmithyGoDependency.SMITHY); writer.write("return &smithy.SerializationError { " + "Err: fmt.Errorf(\"input member $L must not be empty\")}", @@ -674,18 +673,19 @@ private void writeHttpBindingMember( }); } + boolean allowZeroStrings = location != HttpBinding.Location.HEADER; - writeSafeMemberAccessor(context, memberShape, "v", operand -> { + GoValueAccessUtils.writeIfNonZeroValueMember(context, writer, memberShape, "v", allowZeroStrings, (operand) -> { final String locationName = binding.getLocationName().isEmpty() ? memberShape.getMemberName() : binding.getLocationName(); - switch (location) { case HEADER: writer.write("locationName := $S", getCanonicalHeader(locationName)); writeHeaderBinding(context, memberShape, operand, location, "locationName", "encoder"); break; case PREFIX_HEADERS: - MemberShape valueMemberShape = model.expectShape(targetShape.getId(), MapShape.class).getValue(); + MemberShape valueMemberShape = model.expectShape(targetShape.getId(), + MapShape.class).getValue(); Shape valueMemberTarget = model.expectShape(valueMemberShape.getTarget()); if (targetShape.getType() != ShapeType.MAP) { @@ -696,37 +696,38 @@ private void writeHttpBindingMember( writer.write("hv := encoder.Headers($S)", getCanonicalHeader(locationName)); writer.addUseImports(SmithyGoDependency.NET_HTTP); writer.openBlock("for mapKey, mapVal := range $L {", "}", operand, () -> { - writeHeaderBinding(context, valueMemberShape, "mapVal", location, - "http.CanonicalHeaderKey(mapKey)", "hv"); + GoValueAccessUtils.writeIfNonZeroValue(context, writer, valueMemberShape, "mapVal", false, + () -> { + writeHeaderBinding(context, valueMemberShape, "mapVal", location, + "http.CanonicalHeaderKey(mapKey)", "hv"); + }); }); break; case LABEL: - writeHttpBindingSetter(model, writer, memberShape, location, operand, (w, s) -> { - // throw a serializationErrror if target shape is string/enum, and is empty. - // labels must always be set to be serialized on URI - throwSerializationErrorIfStringMemberIsEmpty(w, memberShape, targetShape, operand); - w.writeInline("if err := encoder.SetURI($S).$L", locationName, s); - w.write("; err != nil {\n" - + "\treturn err\n" - + "}"); + writeHttpBindingSetter(context, writer, memberShape, location, operand, (w, s) -> { + w.openBlock("if err := encoder.SetURI($S).$L; err != nil {", "}", locationName, s, + () -> { + w.write("return err"); + }); }); break; case QUERY: if (targetShape instanceof CollectionShape) { - MemberShape collectionMember = ((CollectionShape) targetShape).getMember(); + MemberShape collectionMember = CodegenUtils.expectCollectionShape(targetShape).getMember(); writer.openBlock("for i := range $L {", "}", operand, () -> { - Shape collectionMemberTargetShape = model.expectShape(collectionMember.getTarget()); - if (!collectionMemberTargetShape.hasTrait(EnumTrait.class)) { - writer.openBlock("if $L == nil { continue }", operand + "[i]"); - } - writeHttpBindingSetter(model, writer, collectionMember, location, operand + "[i]", + GoValueAccessUtils.writeIfZeroValue(context, writer, collectionMember, + operand + "[i]", () -> { + writer.write("continue"); + }); + writeHttpBindingSetter(context, writer, collectionMember, location, operand + "[i]", (w, s) -> { w.writeInline("encoder.AddQuery($S).$L", locationName, s); }); }); } else { - writeHttpBindingSetter(model, writer, memberShape, location, operand, (w, s) -> w.writeInline( - "encoder.SetQuery($S).$L", locationName, s)); + writeHttpBindingSetter(context, writer, memberShape, location, operand, + (w, s) -> w.writeInline( + "encoder.SetQuery($S).$L", locationName, s)); } break; default: @@ -735,77 +736,6 @@ private void writeHttpBindingMember( }); } - /** - * Throws a serialization error if string member value is empty. - * - * @param writer Gowriter - * @param memberShape member shape - * @param targetShape target shape - * @param operand operand used to denote member - */ - private void throwSerializationErrorIfStringMemberIsEmpty( - GoWriter writer, - MemberShape memberShape, - Shape targetShape, - String operand - ) { - if (!targetShape.isStringShape()) { - return; - } - - operand = CodegenUtils.isShapePassByReference(targetShape) - ? "*" + operand : operand; - operand = targetShape.hasTrait(EnumTrait.class) - ? "string(" + operand + ")" : operand; - - // add validation for URI string members to not be empty - writer.openBlock("if len($L) == 0 {", "}", - operand, () -> { - writer.addUseImports(SmithyGoDependency.SMITHY); - writer.write("return &smithy.SerializationError { " - + "Err: fmt.Errorf(\"input member $L must not be empty\")}", - memberShape.getMemberName()); - }); - } - - /** - * throws a serialization error if passed in member is unset. - * - * @param context Generation Context - * @param member Member shape - * @param container The name that the structure is assigned to. - * @param consumer unset member consumer - */ - private void throwSerializationErrorIfMemberUnset( - GenerationContext context, - MemberShape member, - String container, - Consumer consumer - ) { - Model model = context.getModel(); - Shape target = model.expectShape(member.getTarget()); - String memberName = context.getSymbolProvider().toMemberName(member); - String operand = container + "." + memberName; - - boolean enumShape = target.hasTrait(EnumTrait.class); - - if (!enumShape && !CodegenUtils.isNilAssignableToShape(model, member)) { - consumer.accept(operand); - return; - } - - String conditionCheck; - if (enumShape) { - conditionCheck = "len(" + operand + ") == 0"; - } else { - conditionCheck = operand + " == nil"; - } - - context.getWriter().openBlock("if $L {", "}", conditionCheck, () -> { - consumer.accept(operand); - }); - } - private void writeHeaderBinding( GenerationContext context, MemberShape memberShape, @@ -816,28 +746,24 @@ private void writeHeaderBinding( ) { GoWriter writer = context.getWriter(); Model model = context.getModel(); - SymbolProvider symbolProvider = context.getSymbolProvider(); Shape targetShape = model.expectShape(memberShape.getTarget()); if (!(targetShape instanceof CollectionShape)) { - // Only set non-empty non-nil header values - writeHeaderOperandNotEmptyCheck(model, symbolProvider, memberShape, operand, writer, () -> { - String op = conditionallyBase64Encode(writer, targetShape, operand); - writeHttpBindingSetter(model, writer, memberShape, location, op, (w, s) -> { - w.writeInline("$L.SetHeader($L).$L", dest, locationName, s); - }); + String op = conditionallyBase64Encode(writer, targetShape, operand); + writeHttpBindingSetter(context, writer, memberShape, location, op, (w, s) -> { + w.writeInline("$L.SetHeader($L).$L", dest, locationName, s); }); return; } - MemberShape collectionMemberShape = ((CollectionShape) targetShape).getMember(); + MemberShape collectionMemberShape = CodegenUtils.expectCollectionShape(targetShape).getMember(); writer.openBlock("for i := range $L {", "}", operand, () -> { // Only set non-empty non-nil header values String indexedOperand = operand + "[i]"; - writeHeaderOperandNotNilAndNotEmptyCheck(model, symbolProvider, collectionMemberShape, indexedOperand, - writer, () -> { + GoValueAccessUtils.writeIfNonZeroValue(context, writer, collectionMemberShape, indexedOperand, false, + () -> { String op = conditionallyBase64Encode(writer, targetShape, indexedOperand); - writeHttpBindingSetter(model, writer, collectionMemberShape, location, op, (w, s) -> { + writeHttpBindingSetter(context, writer, collectionMemberShape, location, op, (w, s) -> { w.writeInline("$L.AddHeader($L).$L", dest, locationName, s); }); }); @@ -855,47 +781,6 @@ private String conditionallyBase64Encode(GoWriter writer, Shape targetShape, Str return operand; } - protected void writeHeaderOperandNotNilAndNotEmptyCheck( - Model model, - SymbolProvider symbolProvider, - MemberShape memberShape, - String operand, - GoWriter writer, - Runnable consumer - ) { - Shape targetShape = model.expectShape(memberShape.getTarget()); - - String conditionCheck; - if (targetShape.hasTrait(EnumTrait.class)) { - conditionCheck = "len(" + operand + ") > 0"; - } else if (targetShape.getType() == ShapeType.STRING) { - conditionCheck = operand + " != nil && len(*" + operand + ") > 0"; - } else { - conditionCheck = operand + " != nil"; - } - - writer.openBlock("if " + conditionCheck + " {", "}", consumer); - } - - protected void writeHeaderOperandNotEmptyCheck( - Model model, - SymbolProvider symbolProvider, - MemberShape memberShape, - String operand, - GoWriter writer, - Runnable consumer - ) { - Shape targetShape = model.expectShape(memberShape.getTarget()); - - if (targetShape.hasTrait(EnumTrait.class) || targetShape.getType() != ShapeType.STRING) { - consumer.run(); - return; - } - - String conditionCheck = "len(*" + operand + ") > 0"; - writer.openBlock("if " + conditionCheck + " {", "}", consumer); - } - /** * Generates serialization functions for shapes in the passed set. These functions * should return a value that can then be serialized by the implementation of @@ -958,7 +843,7 @@ private void generateHttpBindingDeserializer(GenerationContext context, Shape sh writer.write(""); for (HttpBinding binding : bindings) { - writeRestDeserializerMember(writer, model, symbolProvider, binding); + writeRestDeserializerMember(context, writer, binding); writer.write(""); } writer.write("return nil"); @@ -966,14 +851,13 @@ private void generateHttpBindingDeserializer(GenerationContext context, Shape sh } private String generateHttpHeaderValue( + GenerationContext context, GoWriter writer, - Model model, - SymbolProvider symbolProvider, MemberShape memberShape, HttpBinding binding, String operand ) { - Shape targetShape = model.expectShape(memberShape.getTarget()); + Shape targetShape = context.getModel().expectShape(memberShape.getTarget()); if (targetShape.getType() != ShapeType.LIST && targetShape.getType() != ShapeType.SET) { writer.addUseImports(SmithyGoDependency.STRINGS); @@ -1002,7 +886,7 @@ private String generateHttpHeaderValue( return "vv"; case TIMESTAMP: writer.addUseImports(SmithyGoDependency.SMITHY_TIME); - HttpBindingIndex bindingIndex = model.getKnowledge(HttpBindingIndex.class); + HttpBindingIndex bindingIndex = context.getModel().getKnowledge(HttpBindingIndex.class); TimestampFormatTrait.Format format = bindingIndex.determineTimestampFormat( memberShape, binding.getLocation(), @@ -1088,7 +972,7 @@ private String generateHttpHeaderValue( case LIST: // handle list/Set as target shape MemberShape targetValueListMemberShape = CodegenUtils.expectCollectionShape(targetShape).getMember(); - return getHttpHeaderCollectionDeserializer(writer, model, symbolProvider, targetValueListMemberShape, + return getHttpHeaderCollectionDeserializer(context, writer, targetValueListMemberShape, binding, operand); default: @@ -1097,50 +981,48 @@ private String generateHttpHeaderValue( } private String getHttpHeaderCollectionDeserializer( + GenerationContext context, GoWriter writer, - Model model, - SymbolProvider symbolProvider, MemberShape memberShape, HttpBinding binding, String operand ) { - Shape targetShape = model.expectShape(memberShape.getTarget()); - Symbol targetSymbol = symbolProvider.toSymbol(targetShape); - writer.write("var list []$P", targetSymbol); + writer.write("var list []$P", context.getSymbolProvider().toSymbol(memberShape)); String operandValue = operand + "Val"; writer.openBlock("for _, $L := range $L {", "}", operandValue, operand, () -> { - String value = generateHttpHeaderValue(writer, model, symbolProvider, memberShape, binding, - operandValue); + String value = generateHttpHeaderValue(context, writer, memberShape, binding, operandValue); writer.write("list = append(list, $L)", - CodegenUtils.generatePointerValueIfPointable(writer, targetShape, value)); + CodegenUtils.getAsPointerIfPointable(context.getModel(), writer, + GoPointableIndex.of(context.getModel()), memberShape, value)); }); return "list"; } private void writeRestDeserializerMember( + GenerationContext context, GoWriter writer, - Model model, - SymbolProvider symbolProvider, HttpBinding binding ) { MemberShape memberShape = binding.getMember(); - Shape targetShape = model.expectShape(memberShape.getTarget()); - String memberName = symbolProvider.toMemberName(memberShape); + Shape targetShape = context.getModel().expectShape(memberShape.getTarget()); + String memberName = context.getSymbolProvider().toMemberName(memberShape); switch (binding.getLocation()) { case HEADER: - writeHeaderDeserializerFunction(writer, model, symbolProvider, memberName, memberShape, binding); + writeHeaderDeserializerFunction(context, writer, memberName, memberShape, binding); break; case PREFIX_HEADERS: if (!targetShape.isMapShape()) { throw new CodegenException("unexpected prefix-header shape type found in Http bindings"); } - writePrefixHeaderDeserializerFunction(writer, model, symbolProvider, memberName, memberShape, binding); + writePrefixHeaderDeserializerFunction(context, writer, memberName, memberShape, binding); break; case RESPONSE_CODE: writer.addUseImports(SmithyGoDependency.SMITHY_PTR); - writer.write("v.$L = ptr.Int32(int32(response.StatusCode))", memberName); + writer.write("v.$L = $L", memberName, + CodegenUtils.getAsPointerIfPointable(context.getModel(), writer, + GoPointableIndex.of(context.getModel()), memberShape, "int32(response.StatusCode)")); break; default: throw new CodegenException("unexpected http binding found"); @@ -1148,46 +1030,44 @@ private void writeRestDeserializerMember( } private void writeHeaderDeserializerFunction( + GenerationContext context, GoWriter writer, - Model model, - SymbolProvider symbolProvider, String memberName, MemberShape memberShape, HttpBinding binding ) { writer.openBlock("if headerValues := response.Header.Values($S); len(headerValues) != 0 {", "}", binding.getLocationName(), () -> { - Shape targetShape = model.expectShape(memberShape.getTarget()); + Shape targetShape = context.getModel().expectShape(memberShape.getTarget()); String operand = "headerValues"; - operand = writeHeaderValueAccessor(writer, model, targetShape, binding, operand); + operand = writeHeaderValueAccessor(context, writer, targetShape, binding, operand); - String value = generateHttpHeaderValue(writer, model, symbolProvider, memberShape, binding, + String value = generateHttpHeaderValue(context, writer, memberShape, binding, operand); writer.write("v.$L = $L", memberName, - CodegenUtils.generatePointerValueIfPointable(writer, targetShape, value)); + CodegenUtils.getAsPointerIfPointable(context.getModel(), writer, + GoPointableIndex.of(context.getModel()), memberShape, value)); }); } private void writePrefixHeaderDeserializerFunction( + GenerationContext context, GoWriter writer, - Model model, - SymbolProvider symbolProvider, String memberName, MemberShape memberShape, HttpBinding binding ) { String prefix = binding.getLocationName(); - Shape targetShape = model.expectShape(memberShape.getTarget()); + Shape targetShape = context.getModel().expectShape(memberShape.getTarget()); MemberShape valueMemberShape = targetShape.asMapShape() .orElseThrow(() -> new CodegenException("prefix headers must target map shape")) .getValue(); - Shape valueMemberTarget = model.expectShape(valueMemberShape.getTarget()); writer.openBlock("for headerKey, headerValues := range response.Header {", "}", () -> { writer.addUseImports(SmithyGoDependency.STRINGS); - Symbol targetSymbol = symbolProvider.toSymbol(targetShape); + Symbol targetSymbol = context.getSymbolProvider().toSymbol(targetShape); writer.openBlock( "if lenPrefix := len($S); " @@ -1198,12 +1078,13 @@ private void writePrefixHeaderDeserializerFunction( }); String operand = "headerValues"; - operand = writeHeaderValueAccessor(writer, model, targetShape, binding, operand); + operand = writeHeaderValueAccessor(context, writer, targetShape, binding, operand); - String value = generateHttpHeaderValue(writer, model, symbolProvider, valueMemberShape, + String value = generateHttpHeaderValue(context, writer, valueMemberShape, binding, operand); writer.write("v.$L[headerKey[lenPrefix:]] = $L", memberName, - CodegenUtils.generatePointerValueIfPointable(writer, valueMemberTarget, value)); + CodegenUtils.getAsPointerIfPointable(context.getModel(), writer, + GoPointableIndex.of(context.getModel()), valueMemberShape, value)); }); }); } @@ -1212,16 +1093,16 @@ private void writePrefixHeaderDeserializerFunction( * Returns the header value accessor operand, and also if the target shape is a list/set will write the splitting * of the header values by comma(,) utility helper. * + * @param context generation context * @param writer writer - * @param model smithy model * @param targetShape target shape * @param binding http binding location * @param operand operand of the header values. * @return returns operand for accessing the header values */ private String writeHeaderValueAccessor( + GenerationContext context, GoWriter writer, - Model model, Shape targetShape, HttpBinding binding, String operand @@ -1229,7 +1110,7 @@ private String writeHeaderValueAccessor( switch (targetShape.getType()) { case LIST: case SET: - writerHeaderListValuesSplit(writer, model, CodegenUtils.expectCollectionShape(targetShape), binding, + writerHeaderListValuesSplit(context, writer, CodegenUtils.expectCollectionShape(targetShape), binding, operand); break; default: @@ -1246,19 +1127,23 @@ private String writeHeaderValueAccessor( * has special case handling for HttpDate timestamp format when serialized as a header list. Assigns the split * header values back to the same operand name. * + * @param context generation context * @param writer writer - * @param model smithy model * @param shape target collection shape * @param binding http binding location * @param operand operand of the header values. */ private void writerHeaderListValuesSplit( - GoWriter writer, Model model, CollectionShape shape, HttpBinding binding, String operand + GenerationContext context, + GoWriter writer, + CollectionShape shape, + HttpBinding binding, + String operand ) { writer.openBlock("{", "}", () -> { writer.write("var err error"); writer.addUseImports(SmithyGoDependency.SMITHY_HTTP_TRANSPORT); - if (isHttpDateTimestamp(model, binding.getLocation(), shape.getMember())) { + if (isHttpDateTimestamp(context.getModel(), binding.getLocation(), shape.getMember())) { writer.write("$L, err = smithyhttp.SplitHTTPDateTimestampHeaderListValues($L)", operand, operand); } else { writer.write("$L, err = smithyhttp.SplitHeaderListValues($L)", operand, operand); diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestGenerator.java index afb281618..b8d40d7fa 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestGenerator.java @@ -35,8 +35,8 @@ import software.amazon.smithy.model.shapes.MemberShape; import software.amazon.smithy.model.shapes.OperationShape; import software.amazon.smithy.model.shapes.ServiceShape; -import software.amazon.smithy.model.shapes.Shape; import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.model.shapes.StructureShape; import software.amazon.smithy.model.traits.IdempotencyTokenTrait; import software.amazon.smithy.protocoltests.traits.HttpMessageTestCase; import software.amazon.smithy.utils.SmithyBuilder; @@ -56,9 +56,9 @@ public abstract class HttpProtocolUnitTestGenerator clientConfigValues = new TreeSet<>(); @@ -82,11 +82,13 @@ protected HttpProtocolUnitTestGenerator(Builder builder) { opSymbol = symbolProvider.toSymbol(operation); inputShape = model.expectShape(operation.getInput() - .orElseThrow(() -> new CodegenException("missing input shape for operation: " + operation.getId()))); + .orElseThrow(() -> new CodegenException("missing input shape for operation: " + operation.getId())), + StructureShape.class); inputSymbol = symbolProvider.toSymbol(inputShape); outputShape = model.expectShape(operation.getOutput() - .orElseThrow(() -> new CodegenException("missing output shape for operation: " + operation.getId()))); + .orElseThrow(() -> new CodegenException("missing output shape for operation: " + operation.getId())), + StructureShape.class); outputSymbol = symbolProvider.toSymbol(outputShape); } @@ -284,7 +286,7 @@ protected void writeStructField(GoWriter writer, String field, String valueForma * @param shape the shape the field member. * @param params the node of values to fill the member with. */ - protected void writeStructField(GoWriter writer, String field, Shape shape, ObjectNode params) { + protected void writeStructField(GoWriter writer, String field, StructureShape shape, ObjectNode params) { writer.writeInline("$L: ", field); writeShapeValueInline(writer, shape, params); writer.write(","); @@ -545,9 +547,9 @@ protected void writeAssertForbidHeader(GoWriter writer, String expect, String ac * @param shape shape of the value type to be created. * @param params values to initialize shape type with. */ - protected void writeShapeValueInline(GoWriter writer, Shape shape, ObjectNode params) { + protected void writeShapeValueInline(GoWriter writer, StructureShape shape, ObjectNode params) { new ShapeValueGenerator(model, symbolProvider) - .writeShapeValueInline(writer, shape, params); + .writePointableStructureShapeValueInline(writer, shape, params); } /** @@ -744,12 +746,12 @@ public ConfigValue build() { public static class SkipTest implements Comparable { private final ShapeId service; private final ShapeId operation; - private final String testName; + private final List testNames; SkipTest(Builder builder) { this.service = SmithyBuilder.requiredState("service id", builder.service); this.operation = SmithyBuilder.requiredState("operation id", builder.operation); - this.testName = builder.testName; + this.testNames = builder.testNames; } /** @@ -771,12 +773,12 @@ public ShapeId getOperation() { } /** - * Get the name of the test the skip test applies to. + * Get the names of the tests the skip test applies to. * * @return the name of the test to skip */ - public String getTestName() { - return testName; + public List getTestNames() { + return testNames; } /** @@ -798,11 +800,11 @@ public boolean matches(ShapeId service, ShapeId operation, String testName) { } // SkipTests not for specific test should not match this check. - if (this.testName == null || this.testName.length() == 0) { + if (this.testNames.isEmpty()) { return false; } - return this.testName.equals(testName); + return this.testNames.contains(testName); } /** @@ -823,7 +825,7 @@ public boolean matches(ShapeId service, ShapeId operation) { } // SkipTests for specific test should not match this check. - return (this.testName == null || this.testName.length() == 0); + return this.testNames.isEmpty(); } public static Builder builder() { @@ -848,12 +850,12 @@ public boolean equals(Object o) { SkipTest that = (SkipTest) o; return Objects.equals(getService(), that.getService()) && Objects.equals(getOperation(), that.getOperation()) - && Objects.equals(getTestName(), that.getTestName()); + && Objects.equals(getTestNames(), that.getTestNames()); } @Override public int hashCode() { - return Objects.hash(service, operation, testName); + return Objects.hash(service, operation, testNames); } /** @@ -862,7 +864,7 @@ public int hashCode() { public static final class Builder implements SmithyBuilder { private ShapeId service; private ShapeId operation; - private String testName; + private List testNames = new ArrayList<>(); private Builder() { } @@ -895,8 +897,8 @@ public Builder operation(ShapeId operation) { * @param testName is the name of the test to skip * @return the builder */ - public Builder testName(String testName) { - this.testName = testName; + public Builder addTestName(String testName) { + this.testNames.add(testName); return this; } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestResponseErrorGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestResponseErrorGenerator.java index 97f9ad1d5..c00f76e1a 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestResponseErrorGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestResponseErrorGenerator.java @@ -26,14 +26,14 @@ import software.amazon.smithy.model.Model; import software.amazon.smithy.model.shapes.OperationShape; import software.amazon.smithy.model.shapes.ServiceShape; -import software.amazon.smithy.model.shapes.Shape; +import software.amazon.smithy.model.shapes.StructureShape; import software.amazon.smithy.protocoltests.traits.HttpResponseTestCase; /** * Generates HTTP protocol unit tests for HTTP response API error test cases. */ public class HttpProtocolUnitTestResponseErrorGenerator extends HttpProtocolUnitTestResponseGenerator { - protected final Shape errorShape; + protected final StructureShape errorShape; protected final Symbol errorSymbol; /** @@ -152,7 +152,7 @@ protected void generateTestAssertions(GoWriter writer) { } public static class Builder extends HttpProtocolUnitTestResponseGenerator.Builder { - protected Shape error; + protected StructureShape error; // TODO should be a way not to define these override methods since they are all defined in the base Builder. // the return type breaks this though since this builder adds a new builder field. @@ -187,7 +187,7 @@ public Builder operation(OperationShape operation) { return this; } - public Builder error(Shape error) { + public Builder error(StructureShape error) { this.error = error; return this; } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/ProtocolUtils.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/ProtocolUtils.java index 1c7d2bc7e..bab20653b 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/ProtocolUtils.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/ProtocolUtils.java @@ -19,9 +19,10 @@ import java.util.TreeSet; import java.util.function.Consumer; import software.amazon.smithy.codegen.core.CodegenException; -import software.amazon.smithy.go.codegen.CodegenUtils; +import software.amazon.smithy.go.codegen.GoWriter; import software.amazon.smithy.go.codegen.MiddlewareIdentifier; import software.amazon.smithy.go.codegen.integration.ProtocolGenerator.GenerationContext; +import software.amazon.smithy.go.codegen.knowledge.GoPointableIndex; import software.amazon.smithy.model.Model; import software.amazon.smithy.model.knowledge.OperationIndex; import software.amazon.smithy.model.neighbor.RelationshipType; @@ -31,10 +32,7 @@ import software.amazon.smithy.model.shapes.Shape; import software.amazon.smithy.model.shapes.ShapeId; import software.amazon.smithy.model.shapes.ShapeType; -import software.amazon.smithy.model.shapes.SimpleShape; import software.amazon.smithy.model.shapes.StructureShape; -import software.amazon.smithy.model.traits.EnumTrait; -import software.amazon.smithy.model.traits.StreamingTrait; import software.amazon.smithy.utils.SetUtils; /** @@ -53,7 +51,8 @@ public final class ProtocolUtils { RelationshipType.SET_MEMBER, RelationshipType.MAP_VALUE, RelationshipType.MEMBER_TARGET ); - private ProtocolUtils() {} + private ProtocolUtils() { + } /** * Resolves the entire set of shapes that will require serde given an initial set of shapes. @@ -92,7 +91,7 @@ public static Set resolveRequiredDocumentShapeSerde(Model model, Set * The following shape types will require a serde function: maps, lists, sets, documents, structures, and unions. * * @param shape the shape @@ -105,7 +104,7 @@ public static boolean requiresDocumentSerdeFunction(Shape shape) { /** * Gets the operation input as a structure shape or throws an exception. * - * @param model The model that contains the operation. + * @param model The model that contains the operation. * @param operation The operation to get the input from. * @return The operation's input as a structure shape. */ @@ -118,7 +117,7 @@ public static StructureShape expectInput(Model model, OperationShape operation) /** * Gets the operation output as a structure shape or throws an exception. * - * @param model The model that contains the operation. + * @param model The model that contains the operation. * @param operation The operation to get the output from. * @return The operation's output as a structure shape. */ @@ -129,51 +128,86 @@ public static StructureShape expectOutput(Model model, OperationShape operation) } /** - * Safely accesses a given structure member. + * Wraps the protocol's delegation function changing the destination variable to a double pointer if the + * shape type is not pointable. * - * @param context The generation context. - * @param member The member being accessed. - * @param container The name that the structure is assigned to. - * @param consumer A string consumer that is given the snippet to access the member value. + * @param context generation context + * @param writer go writer + * @param member shape to determine if pointable + * @param origDestOperand original variable name + * @param lambda runnable */ - public static void writeSafeMemberAccessor( + public static void writeDeserDelegateFunction( GenerationContext context, + GoWriter writer, MemberShape member, - String container, - Consumer consumer + String origDestOperand, + Consumer lambda ) { - Model model = context.getModel(); - Shape target = model.expectShape(member.getTarget()); - String memberName = context.getSymbolProvider().toMemberName(member); - String operand = container + "." + memberName; + Shape targetShape = context.getModel().expectShape(member.getTarget()); + Shape container = context.getModel().expectShape(member.getContainer()); - boolean enumShape = target.hasTrait(EnumTrait.class); + boolean withAddr = !GoPointableIndex.of(context.getModel()).isPointable(member) + && GoPointableIndex.of(context.getModel()).isPointable(targetShape); + boolean isMap = container.getType() == ShapeType.MAP; - if (!enumShape && !CodegenUtils.isNilAssignableToShape(model, member)) { - consumer.accept(operand); - return; + String destOperand = origDestOperand; + if (isMap) { + writer.write("mapVar := $L", origDestOperand); + destOperand = "mapVar"; } - String conditionCheck; - if (enumShape) { - conditionCheck = "len(" + operand + ") > 0"; - } else { - conditionCheck = operand + " != nil"; + if (withAddr) { + writer.write("destAddr := &$L", destOperand); + destOperand = "destAddr"; } - context.getWriter().openBlock("if $L {", "}", conditionCheck, () -> { - consumer.accept(operand); - }); + lambda.accept(destOperand); + + if (isMap || withAddr) { + if (withAddr) { + destOperand = "*" + destOperand; + } + + writer.write("$L = $L", origDestOperand, destOperand); + } } /** - * Determines whether a given shape will use a scalar when the shape is used as a union value. + * Writes helper variables for the delegation function to ensure that map values are safely delegated down + * each level. * - * @param shape the shape to check - * @return false if the shape should use pointers + * @param context generation context + * @param writer go writer + * @param member shape to determine if pointable + * @param origDestOperand original variable name + * @param lambda runnable */ - public static boolean usesScalarWhenUnionValue(Shape shape) { - return !(shape instanceof SimpleShape) || shape.isBlobShape() || shape.hasTrait(EnumTrait.class) - || shape.hasTrait(StreamingTrait.class); + public static void writeSerDelegateFunction( + GenerationContext context, + GoWriter writer, + MemberShape member, + String origDestOperand, + Consumer lambda + ) { + Shape targetShape = context.getModel().expectShape(member.getTarget()); + Shape container = context.getModel().expectShape(member.getContainer()); + + boolean withAddr = !GoPointableIndex.of(context.getModel()).isPointable(member) + && GoPointableIndex.of(context.getModel()).isPointable(targetShape); + boolean isMap = container.getType() == ShapeType.MAP; + + String destOperand = origDestOperand; + if (isMap && withAddr) { + writer.write("mapVar := $L", origDestOperand); + destOperand = "mapVar"; + } + + String acceptVar = destOperand; + if (withAddr) { + acceptVar = "&" + destOperand; + } + + lambda.accept(acceptVar); } } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/ValidationGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/ValidationGenerator.java index 4a72cf0e1..d2f5e785c 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/ValidationGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/ValidationGenerator.java @@ -40,6 +40,7 @@ import software.amazon.smithy.codegen.core.CodegenException; import software.amazon.smithy.codegen.core.Symbol; import software.amazon.smithy.codegen.core.SymbolProvider; +import software.amazon.smithy.go.codegen.CodegenUtils; import software.amazon.smithy.go.codegen.GoSettings; import software.amazon.smithy.go.codegen.GoStackStepMiddlewareGenerator; import software.amazon.smithy.go.codegen.GoWriter; @@ -47,10 +48,12 @@ import software.amazon.smithy.go.codegen.SmithyGoDependency; import software.amazon.smithy.go.codegen.SymbolUtils; import software.amazon.smithy.go.codegen.TriConsumer; +import software.amazon.smithy.go.codegen.knowledge.GoPointableIndex; import software.amazon.smithy.go.codegen.knowledge.GoValidationIndex; import software.amazon.smithy.model.Model; import software.amazon.smithy.model.shapes.CollectionShape; import software.amazon.smithy.model.shapes.MapShape; +import software.amazon.smithy.model.shapes.MemberShape; import software.amazon.smithy.model.shapes.OperationShape; import software.amazon.smithy.model.shapes.ServiceShape; import software.amazon.smithy.model.shapes.Shape; @@ -126,13 +129,19 @@ private void generateShapeValidationFunctions( Set operationInputShapes, Set shapesWithHelpers ) { + GoPointableIndex pointableIndex = GoPointableIndex.of(model); + for (Shape shape : shapesWithHelpers) { boolean topLevelShape = operationInputShapes.contains(shape); String functionName = getShapeValidationFunctionName(shape, topLevelShape); Symbol shapeSymbol = symbolProvider.toSymbol(shape); writer.openBlock("func $L(v $P) error {", "}", functionName, shapeSymbol, () -> { writer.addUseImports(SmithyGoDependency.SMITHY); - writer.openBlock("if v == nil {", "}", () -> writer.write("return nil")); + + if (pointableIndex.isNillable(shape)) { + writer.openBlock("if v == nil {", "}", () -> writer.write("return nil")); + } + writer.write("invalidParams := smithy.InvalidParamsError{Context: $S}", shapeSymbol.getName()); switch (shape.getType()) { case STRUCTURE: @@ -142,19 +151,26 @@ private void generateShapeValidationFunctions( boolean required = GoValidationIndex.isRequiredParameter(model, memberShape, topLevelShape); boolean hasHelper = shapesWithHelpers.contains(targetShape); boolean isEnum = targetShape.getTrait(EnumTrait.class).isPresent(); + if (required) { + Runnable runnable = () -> { + writer.write("invalidParams.Add(smithy.NewErrParamRequired($S))", memberName); + if (hasHelper) { + writer.writeInline("} else "); + } else { + writer.write("}"); + } + }; + if (isEnum) { writer.write("if len(v.$L) == 0 {", memberName); - } else { + runnable.run(); + } else if (pointableIndex.isNillable(memberShape)) { writer.write("if v.$L == nil {", memberName); - } - writer.write("invalidParams.Add(smithy.NewErrParamRequired($S))", memberName); - if (hasHelper) { - writer.writeInline("} else "); - } else { - writer.write("}"); + runnable.run(); } } + if (hasHelper) { Runnable runnable = () -> { String helperName = getShapeValidationFunctionName(targetShape, false); @@ -166,42 +182,59 @@ private void generateShapeValidationFunctions( memberName); }); }; + if (isEnum) { writer.openBlock("if len(v.$L) > 0 {", "}", memberName, runnable); - } else { + } else if (pointableIndex.isNillable(memberShape)) { writer.openBlock("if v.$L != nil {", "}", memberName, runnable); } - } }); break; + case LIST: case SET: - String helperName = getShapeValidationFunctionName(model.expectShape(((CollectionShape) shape) - .getMember().getTarget()), false); + CollectionShape collectionShape = CodegenUtils.expectCollectionShape(shape); + MemberShape member = collectionShape.getMember(); + Shape memberTarget = model.expectShape(member.getTarget()); + String helperName = getShapeValidationFunctionName(memberTarget, false); + writer.openBlock("for i := range v {", "}", () -> { - writer.openBlock("if err := $L(v[i]); err != nil {", "}", helperName, () -> { + String addr = ""; + if (!pointableIndex.isPointable(member) && pointableIndex.isPointable(memberTarget)) { + addr = "&"; + } + writer.openBlock("if err := $L($Lv[i]); err != nil {", "}", helperName, addr, () -> { writer.addUseImports(SmithyGoDependency.SMITHY); - writer.write( - "invalidParams.AddNested(fmt.Sprintf(\"[%d]\", i), " + writer.write("invalidParams.AddNested(fmt.Sprintf(\"[%d]\", i), " + "err.(smithy.InvalidParamsError))"); }); }); break; + case MAP: - helperName = getShapeValidationFunctionName(model.expectShape(((MapShape) shape).getValue() - .getTarget()), false); + MapShape mapShape = shape.asMapShape().get(); + MemberShape mapValue = mapShape.getValue(); + Shape valueTarget = model.expectShape(mapValue.getTarget()); + helperName = getShapeValidationFunctionName(valueTarget, false); + writer.openBlock("for key := range v {", "}", () -> { - writer.openBlock("if err := $L(v[key]); err != nil {", "}", helperName, () -> { + String valueVar = "v[key]"; + if (!pointableIndex.isPointable(mapValue) && pointableIndex.isPointable(valueTarget)) { + writer.write("value := $L", valueVar); + valueVar = "&value"; + } + writer.openBlock("if err := $L($L); err != nil {", "}", helperName, valueVar, () -> { writer.addUseImports(SmithyGoDependency.SMITHY); - writer.write( - "invalidParams.AddNested(fmt.Sprintf(\"[%q]\", key), " + writer.write("invalidParams.AddNested(fmt.Sprintf(\"[%q]\", key), " + "err.(smithy.InvalidParamsError))"); }); }); break; + case UNION: // TODO: Implement Union support + default: throw new CodegenException("Unexpected validation helper shape type " + shape.getType()); } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/knowledge/GoPointableIndex.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/knowledge/GoPointableIndex.java new file mode 100644 index 000000000..83e1fcf6f --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/knowledge/GoPointableIndex.java @@ -0,0 +1,242 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + * + */ + +package software.amazon.smithy.go.codegen.knowledge; + +import java.util.HashSet; +import java.util.Set; +import java.util.logging.Logger; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.knowledge.KnowledgeIndex; +import software.amazon.smithy.model.knowledge.NeighborProviderIndex; +import software.amazon.smithy.model.knowledge.NullableIndex; +import software.amazon.smithy.model.neighbor.NeighborProvider; +import software.amazon.smithy.model.neighbor.Relationship; +import software.amazon.smithy.model.neighbor.RelationshipType; +import software.amazon.smithy.model.shapes.MemberShape; +import software.amazon.smithy.model.shapes.Shape; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.model.shapes.ShapeType; +import software.amazon.smithy.model.shapes.ToShapeId; +import software.amazon.smithy.model.traits.EnumTrait; +import software.amazon.smithy.model.traits.StreamingTrait; +import software.amazon.smithy.utils.SetUtils; + +/** + * An index that checks if a member or shape type should be a pointer type in Go. + *

+ * Extends the rules of smithy's NullableIndex for Go's translation of the smithy shapes to Go types. + */ +public class GoPointableIndex implements KnowledgeIndex { + private static final Logger LOGGER = Logger.getLogger(GoPointableIndex.class.getName()); + + // All types that are Go value types + private static final Set INHERENTLY_VALUE = SetUtils.of( + ShapeType.BLOB, + ShapeType.LIST, + ShapeType.SET, + ShapeType.MAP, + ShapeType.UNION, + ShapeType.DOCUMENT + ); + + // All types that are Go pointer types + private static final Set INHERENTLY_POINTABLE = SetUtils.of( + ShapeType.BIG_DECIMAL, + ShapeType.BIG_INTEGER + ); + + // All types that cannot be dereferenced + private static final Set INHERENTLY_NONDEREFERENCABLE = SetUtils.of( + // built in slice/map + ShapeType.BLOB, + ShapeType.LIST, + ShapeType.SET, + ShapeType.MAP, + + // Interfaces + ShapeType.UNION, + ShapeType.DOCUMENT, + + // known pointer types. + ShapeType.BIG_DECIMAL, + ShapeType.BIG_INTEGER + ); + + // All types types that are comparable to nil + private static final Set INHERENTLY_NILLABLE = SetUtils.of( + // built in slice/map + ShapeType.BLOB, + ShapeType.LIST, + ShapeType.SET, + ShapeType.MAP, + + // Interfaces + ShapeType.UNION, + ShapeType.DOCUMENT, + + // known pointer types. + ShapeType.BIG_DECIMAL, + ShapeType.BIG_INTEGER + ); + + + + private final Model model; + private final NullableIndex nullableIndex; + private final Set pointableShapes = new HashSet<>(); + private final Set nillableShapes = new HashSet<>(); + private final Set dereferencableShapes = new HashSet<>(); + + public GoPointableIndex(Model model) { + this.model = model; + this.nullableIndex = NullableIndex.of(model); + + for (Shape shape : model.toSet()) { + if (shape.asMemberShape().isPresent()) { + MemberShape member = shape.asMemberShape().get(); + Shape targetShape = model.expectShape(member.getTarget()); + + if (isMemberPointable(member, targetShape)) { + pointableShapes.add(shape.getId()); + } + if (isMemberNillable(member, targetShape)) { + nillableShapes.add(shape.getId()); + } + if (isMemberDereferencable(member, targetShape)) { + dereferencableShapes.add(shape.getId()); + } + } else { + if (isShapePointable(shape)) { + pointableShapes.add(shape.getId()); + nillableShapes.add(shape.getId()); + } + if (isShapeNillable(shape)) { + nillableShapes.add(shape.getId()); + } + if (isShapeDereferencable(shape)) { + dereferencableShapes.add(shape.getId()); + } + } + } + } + + public static GoPointableIndex of(Model model) { + return model.getKnowledge(GoPointableIndex.class, GoPointableIndex::new); + } + + private boolean isMemberDereferencable(MemberShape member, Shape targetShape) { + return isShapeDereferencable(targetShape) && isMemberPointable(member, targetShape); + } + + private boolean isMemberNillable(MemberShape member, Shape targetShape) { + return INHERENTLY_NILLABLE.contains(targetShape.getType()) || isMemberPointable(member, targetShape); + } + + private boolean isMemberPointable(MemberShape member, Shape targetShape) { + return isShapePointable(targetShape) && nullableIndex.isNullable(member); + } + + private boolean isShapeDereferencable(Shape shape) { + return !INHERENTLY_NONDEREFERENCABLE.contains(shape.getType()) && isShapePointable(shape); + } + + private boolean isShapeNillable(Shape shape) { + return INHERENTLY_NILLABLE.contains(shape.getType()) || isShapePointable(shape); + } + + private boolean isShapePointable(Shape shape) { + // All operation input and output shapes are pointable. + if (isOperationStruct(shape)) { + return true; + } + + // Streamed blob shapes are never pointers because they are interfaces + if (isBlobStream(shape)) { + return false; + } + + if (shape.isServiceShape()) { + return true; + } + + // This is odd because its not a go type but a function with receiver + if (shape.isOperationShape()) { + return false; + } + + if (INHERENTLY_POINTABLE.contains(shape.getType())) { + return true; + } + + if (INHERENTLY_VALUE.contains(shape.getType()) || isShapeEnum(shape)) { + return false; + } + + return nullableIndex.isNullable(shape); + } + + private boolean isShapeEnum(Shape shape) { + return shape.getType() == ShapeType.STRING && shape.hasTrait(EnumTrait.class); + } + + private boolean isBlobStream(Shape shape) { + return shape.getType() == ShapeType.BLOB && shape.hasTrait(StreamingTrait.ID); + } + + private boolean isOperationStruct(Shape shape) { + NeighborProvider provider = NeighborProviderIndex.of(model).getReverseProvider(); + for (Relationship relationship : provider.getNeighbors(shape)) { + RelationshipType relationshipType = relationship.getRelationshipType(); + if (relationshipType == RelationshipType.INPUT || relationshipType == RelationshipType.OUTPUT) { + return true; + } + } + + return false; + } + + /** + * Returns if the shape should be generated as a Go pointer type or not. + * + * @param shape the shape to check if should be pointable type. + * @return if the shape is should be a Go pointer type. + */ + public final boolean isPointable(ToShapeId shape) { + return pointableShapes.contains(shape.toShapeId()); + } + + /** + * Returns if the Go type generated for the shape is comparable to nil. + * + * @param shape the shape to check + * @return if the shape's go type is comparable to nil + */ + public final boolean isNillable(ToShapeId shape) { + return nillableShapes.contains(shape.toShapeId()); + } + + /** + * Returns if the Go type generated for the shape can be dereferenced. + * + * @param shape the shape to check + * @return if the shape's go type is dereferencable + */ + public final boolean isDereferencable(ToShapeId shape) { + return dereferencableShapes.contains(shape.toShapeId()); + } +}