Skip to content

Commit

Permalink
fixup jmespath multiselect codegen (#551)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucix-aws authored Nov 6, 2024
1 parent a4c9efc commit 253cd26
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
import static software.amazon.smithy.go.codegen.util.ShapeUtil.BOOL_SHAPE;
import static software.amazon.smithy.go.codegen.util.ShapeUtil.INT_SHAPE;
import static software.amazon.smithy.go.codegen.util.ShapeUtil.STRING_SHAPE;
import static software.amazon.smithy.go.codegen.util.ShapeUtil.listOf;
import static software.amazon.smithy.utils.StringUtils.capitalize;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import software.amazon.smithy.codegen.core.CodegenException;
Expand Down Expand Up @@ -64,6 +64,12 @@ public class GoJmespathExpressionGenerator {

private int idIndex = 0;

// as we traverse an expression, we may produce intermediate "synthetic" lists - e.g. list of string, list of list
// of string, etc.
// we may need to pull the member shapes back out later, but they're not guaranteed to be in the model - so keep a
// shadow map of synthetic -> member to short-circuit the model lookup
private final Map<Shape, Shape> synthetics = new HashMap<>();

public GoJmespathExpressionGenerator(GoCodegenContext ctx, GoWriter writer) {
this.ctx = ctx;
this.writer = writer;
Expand Down Expand Up @@ -121,8 +127,17 @@ private Variable visitMultiSelectList(MultiSelectListExpression expr, Variable c
var first = items.get(0);

var ident = nextIdent();
writer.write("$L := []$P{$L}", ident, first.type,
String.join(",", items.stream().map(it -> it.ident).toList()));
writer.write("$L := []$T{}", ident, first.type);
for (var item : items) {
if (isPointable(item.type)) {
writer.write("""
if $2L != nil {
$1L = append($1L, *$2L)
}""", ident, item.ident);
} else {
writer.write("$1L = append($1L, $2L)", ident, item.ident);
}
}

return new Variable(listOf(first.shape), ident, sliceOf(first.type));
}
Expand Down Expand Up @@ -247,11 +262,13 @@ private Variable visitProjection(ProjectionExpression expr, Variable current) {
writer.indent();
// projected.shape is the _member_ of the resulting list
var projected = visit(expr.getRight(), new Variable(leftMember, "v", leftSymbol));
if (isPointable(lookahead.type)) { // projections implicitly filter out nil evaluations of RHS
if (isPointable(lookahead.type)) { // projections implicitly filter out nil evaluations of RHS...
var deref = lookahead.shape instanceof CollectionShape || lookahead.shape instanceof MapShape
? "" : "*"; // ...but slices/maps do not get dereferenced
writer.write("""
if $2L != nil {
$1L = append($1L, *$2L)
}""", ident, projected.ident);
if $1L != nil {
$2L = append($2L, $3L$1L)
}""", projected.ident, ident, deref);
} else {
writer.write("$1L = append($1L, $2L)", ident, projected.ident);
}
Expand Down Expand Up @@ -348,22 +365,22 @@ private String nextIdent() {
return "v" + idIndex;
}

private Shape listOf(Shape shape) {
var list = ShapeUtil.listOf(shape);
synthetics.putIfAbsent(list, shape);
return list;
}

private Shape expectMember(CollectionShape shape) {
return switch (shape.getMember().getTarget().toString()) {
case "smithy.go.synthetic#StringList" -> listOf(STRING_SHAPE);
case "smithy.go.synthetic#IntegerList" -> listOf(INT_SHAPE);
case "smithy.go.synthetic#BooleanList" -> listOf(BOOL_SHAPE);
default -> ShapeUtil.expectMember(ctx.model(), shape);
};
return synthetics.containsKey(shape)
? synthetics.get(shape)
: ShapeUtil.expectMember(ctx.model(), shape);
}

private Shape expectMember(MapShape shape) {
return switch (shape.getValue().getTarget().toString()) {
case "smithy.go.synthetic#StringList" -> listOf(STRING_SHAPE);
case "smithy.go.synthetic#IntegerList" -> listOf(INT_SHAPE);
case "smithy.go.synthetic#BooleanList" -> listOf(BOOL_SHAPE);
default -> ShapeUtil.expectMember(ctx.model(), shape);
};
return synthetics.containsKey(shape)
? synthetics.get(shape)
: ShapeUtil.expectMember(ctx.model(), shape);
}

// helper to generate comparisons from two results, automatically handling any dereferencing in the process
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,11 @@
import software.amazon.smithy.go.codegen.GoJmespathExpressionGenerator;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SmithyGoTypes;
import software.amazon.smithy.go.codegen.knowledge.GoPointableIndex;
import software.amazon.smithy.jmespath.JmespathExpression;
import software.amazon.smithy.model.node.Node;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.StructureShape;
import software.amazon.smithy.rulesengine.language.EndpointRuleSet;
import software.amazon.smithy.rulesengine.language.syntax.Identifier;
import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType;
import software.amazon.smithy.rulesengine.traits.ContextParamTrait;
import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait;
import software.amazon.smithy.rulesengine.traits.OperationContextParamDefinition;
Expand Down Expand Up @@ -103,33 +100,14 @@ private GoWriter.Writable generateOperationContextParamBindings() {
}

private GoWriter.Writable generateOpContextParamBinding(String paramName, OperationContextParamDefinition def) {
var param = rules.getParameters().get(Identifier.of(paramName)).get();
var expr = JmespathExpression.parse(def.getPath());

return writer -> {
var generator = new GoJmespathExpressionGenerator(ctx, writer);

writer.write("func() {"); // contain the scope for each binding
var result = generator.generate(expr, new GoJmespathExpressionGenerator.Variable(input, "in"));

if (param.getType().equals(ParameterType.STRING_ARRAY)) {
// projections can result in either []string OR []*string -- if the latter, we have to unwrap
var target = result.shape().asListShape().get().getMember().getTarget();
if (GoPointableIndex.of(ctx.model()).isPointable(target)) {
writer.write("""
deref := []string{}
for _, v := range $L {
if v != nil {
deref = append(deref, *v)
}
}
p.$L = deref""", result.ident(), capitalize(paramName));
} else {
writer.write("p.$L = $L", capitalize(paramName), result.ident());
}
} else {
writer.write("p.$L = $L", capitalize(paramName), result.ident());
}
writer.write("p.$L = $L", capitalize(paramName), result.ident());
writer.write("}()");
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@
import software.amazon.smithy.rulesengine.language.Endpoint;
import software.amazon.smithy.rulesengine.language.EndpointRuleSet;
import software.amazon.smithy.rulesengine.language.error.RuleError;
import software.amazon.smithy.rulesengine.language.evaluation.type.ArrayType;
import software.amazon.smithy.rulesengine.language.evaluation.type.OptionalType;
import software.amazon.smithy.rulesengine.language.evaluation.type.StringType;
import software.amazon.smithy.rulesengine.language.evaluation.type.Type;
import software.amazon.smithy.rulesengine.language.syntax.Identifier;
import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression;
import software.amazon.smithy.rulesengine.language.syntax.expressions.ExpressionVisitor;
Expand Down Expand Up @@ -304,14 +307,21 @@ private GoWriter.Writable generateRule(Rule rule, List<Condition> conditions, Sc
}

if (fn.type() instanceof OptionalType || isConditionalFnResultOptional(condition, fn)) {
// []string (e.g. in endpoint params) needs to be casted as stringSlice instead of dereferenced for index
// operations
var isStringSlice = false;
if (fn.type() instanceof OptionalType opt) {
isStringSlice = isStringSlice(opt.inner());
}
return goTemplate("""
if exprVal := $target:W; exprVal != nil {
$conditionIdent:L := *exprVal
$conditionIdent:L := $exprVal:L
_ = $conditionIdent:L
$next:W
}
""",
MapUtils.of(
"exprVal", isStringSlice ? "stringSlice(exprVal)" : "*exprVal",
"conditionIdent", conditionIdentifier,
"target", generator.generate(fn),
"next", generateRule(
Expand Down Expand Up @@ -601,4 +611,11 @@ private GoWriter.Writable generateStringSliceHelper() {
return &v
}""");
}

private boolean isStringSlice(Type type) {
if (!(type instanceof ArrayType array)) {
return false;
}
return array.getMember() instanceof StringType;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@

public final class ShapeUtil {
public static final StringShape STRING_SHAPE = StringShape.builder()
.id("smithy.api#String")
.id("smithy.api#PrimitiveString")
.build();

public static final IntegerShape INT_SHAPE = IntegerShape.builder()
.id("smithy.api#Integer")
.id("smithy.api#PrimitiveInteger")
.build();

public static final BooleanShape BOOL_SHAPE = BooleanShape.builder()
.id("smithy.api#Boolean")
.id("smithy.api#PrimitiveBoolean")
.build();

private ShapeUtil() {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,13 @@ public void testMultiSelect() {
assertThat(writer.toString(), Matchers.containsString("""
v1 := input.SimpleShape
v2 := input.SimpleShape2
v3 := []*string{v1,v2}
v3 := []string{}
if v1 != nil {
v3 = append(v3, *v1)
}
if v2 != nil {
v3 = append(v3, *v2)
}
"""));
}

Expand All @@ -510,9 +516,12 @@ public void testMultiSelectFlatten() {
var v2 [][]string
for _, v := range v1 {
v3 := v.Key
v4 := []*string{v3}
v4 := []string{}
if v3 != nil {
v4 = append(v4, *v3)
}
if v4 != nil {
v2 = append(v2, *v4)
v2 = append(v2, v4)
}
}
var v5 []string
Expand Down

0 comments on commit 253cd26

Please sign in to comment.