Skip to content

shorten internal serde fn names #730

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -448,8 +448,8 @@ private void writeSerdeDispatcher(boolean isInput) {
writer.write("throw new Error(\"No supported protocol was found\");");
} else {
String serdeFunctionName = isInput
? ProtocolGenerator.getSerFunctionName(symbol, protocolGenerator.getName())
: ProtocolGenerator.getDeserFunctionName(symbol, protocolGenerator.getName());
? ProtocolGenerator.getSerFunctionShortName(symbol)
: ProtocolGenerator.getDeserFunctionShortName(symbol);
writer.addImport(serdeFunctionName, serdeFunctionName,
Paths.get(".", CodegenUtils.SOURCE_FOLDER, ProtocolGenerator.PROTOCOLS_FOLDER,
ProtocolGenerator.getSanitizedName(protocolGenerator.getName())).toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ private String getDelegateDeserializer(Shape shape) {
private String getDelegateDeserializer(Shape shape, String customDataSource) {
// Use the shape for the function name.
Symbol symbol = context.getSymbolProvider().toSymbol(shape);
return ProtocolGenerator.getDeserFunctionName(symbol, context.getProtocolName())
return ProtocolGenerator.getDeserFunctionShortName(symbol)
+ "(" + customDataSource + ", context)";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ public final String unionShape(UnionShape shape) {
private String getDelegateSerializer(Shape shape) {
// Use the shape for the function name.
Symbol symbol = context.getSymbolProvider().toSymbol(shape);
return ProtocolGenerator.getSerFunctionName(symbol, context.getProtocolName())
return ProtocolGenerator.getSerFunctionShortName(symbol)
+ "(" + dataSource + ", context)";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,12 @@ protected final void generateDeserFunction(

Symbol symbol = symbolProvider.toSymbol(shape);
// Use the shape name for the function name.
String methodName = ProtocolGenerator.getDeserFunctionName(symbol, context.getProtocolName());
String methodName = ProtocolGenerator.getDeserFunctionShortName(symbol);
String methodLongName =
ProtocolGenerator.getDeserFunctionName(symbol, context.getProtocolName());

writer.addImport(symbol, symbol.getName());
writer.writeDocs(methodLongName);
writer.openBlock("const $L = (\n"
+ " output: any,\n"
+ " context: __SerdeContext\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,12 @@ private void generateSerFunction(

Symbol symbol = symbolProvider.toSymbol(shape);
// Use the shape name for the function name.
String methodName = ProtocolGenerator.getSerFunctionName(symbol, context.getProtocolName());
String methodName = ProtocolGenerator.getSerFunctionShortName(symbol);
String methodLongName = ProtocolGenerator.getSerFunctionName(symbol, context.getProtocolName());

writer.addImport(symbol, symbol.getName());

writer.writeDocs(methodLongName);
writer.openBlock("const $L = (\n"
+ " input: $T,\n"
+ " context: __SerdeContext\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,31 +177,34 @@ public void generateEventStreamDeserializers(

private void generateEventStreamSerializer(GenerationContext context, UnionShape eventsUnion) {
String methodName = getSerFunctionName(context, eventsUnion);
String methodLongName = ProtocolGenerator.getSerFunctionName(getSymbol(context, eventsUnion),
context.getProtocolName());
Symbol eventsUnionSymbol = getSymbol(context, eventsUnion);
TypeScriptWriter writer = context.getWriter();
Model model = context.getModel();
writer.addImport("Message", "__Message", TypeScriptDependency.AWS_SDK_TYPES.packageName);

writer.writeDocs(methodLongName);
writer.openBlock("const $L = (\n"
+ " input: any,\n"
+ " context: $L\n"
+ "): any => {", "}", methodName, getEventStreamSerdeContextType(context, eventsUnion), () -> {
writer.openBlock("const eventMarshallingVisitor = (event: any): __Message => $T.visit(event, {", "});",
eventsUnionSymbol, () -> {
eventsUnion.getAllMembers().forEach((memberName, memberShape) -> {
StructureShape target = model.expectShape(memberShape.getTarget(), StructureShape.class);
StructureShape target = model.expectShape(memberShape.getTarget(), StructureShape.class);
String eventSerMethodName = getEventSerFunctionName(context, target);
writer.write("$L: value => $L(value, context),", memberName, eventSerMethodName);
});
writer.write("_: value => value as any");
});
writer.write("return context.eventStreamMarshaller.serialize(input, eventMarshallingVisitor);");
});
writer.write("return context.eventStreamMarshaller.serialize(input, eventMarshallingVisitor);");
});
}

private String getSerFunctionName(GenerationContext context, Shape shape) {
Symbol symbol = getSymbol(context, shape);
String protocolName = context.getProtocolName();
return ProtocolGenerator.getSerFunctionName(symbol, protocolName);
return ProtocolGenerator.getSerFunctionShortName(symbol);
}

public String getEventSerFunctionName(GenerationContext context, Shape shape) {
Expand Down Expand Up @@ -347,7 +350,7 @@ private void writeEventBody(
writer.write("body = context.utf8Decoder(input.$L);", payloadMemberName);
} else if (payloadShape instanceof BlobShape || payloadShape instanceof StringShape) {
Symbol symbol = getSymbol(context, payloadShape);
String serFunctionName = ProtocolGenerator.getSerFunctionName(symbol, context.getProtocolName());
String serFunctionName = ProtocolGenerator.getSerFunctionShortName(symbol);
documentShapesToSerialize.add(payloadShape);
writer.write("body = $L(input.$L, context);", payloadMemberName, serFunctionName);
serializeInputEventDocumentPayload.run();
Expand All @@ -364,7 +367,7 @@ private void writeEventBody(
}
}
Symbol symbol = getSymbol(context, event);
String serFunctionName = ProtocolGenerator.getSerFunctionName(symbol, context.getProtocolName());
String serFunctionName = ProtocolGenerator.getSerFunctionShortName(symbol);
documentShapesToSerialize.add(event);
writer.write("body = $L(input, context);", serFunctionName);
serializeInputEventDocumentPayload.run();
Expand All @@ -373,10 +376,14 @@ private void writeEventBody(

private void generateEventStreamDeserializer(GenerationContext context, UnionShape eventsUnion) {
String methodName = getDeserFunctionName(context, eventsUnion);
String methodLongName = ProtocolGenerator.getDeserFunctionName(getSymbol(context, eventsUnion),
context.getProtocolName());
Symbol eventsUnionSymbol = getSymbol(context, eventsUnion);
TypeScriptWriter writer = context.getWriter();
Model model = context.getModel();
String contextType = getEventStreamSerdeContextType(context, eventsUnion);

writer.writeDocs(methodLongName);
writer.openBlock("const $L = (\n"
+ " output: any,\n"
+ " context: $L\n"
Expand All @@ -401,8 +408,7 @@ private void generateEventStreamDeserializer(GenerationContext context, UnionSha

private String getDeserFunctionName(GenerationContext context, Shape shape) {
Symbol symbol = getSymbol(context, shape);
String protocolName = context.getProtocolName();
return ProtocolGenerator.getDeserFunctionName(symbol, protocolName);
return ProtocolGenerator.getDeserFunctionShortName(symbol);
}

public String getEventDeserFunctionName(GenerationContext context, Shape shape) {
Expand Down Expand Up @@ -444,7 +450,7 @@ private void generateErrorEventUnmarshaller(
TypeScriptWriter writer = context.getWriter();
// If this is an error event, we need to generate the error deserializer.
errorShapesToDeserialize.add(event);
String errorDeserMethodName = getDeserFunctionName(context, event) + "Response";
String errorDeserMethodName = getDeserFunctionName(context, event) + "Res";
if (isErrorCodeInBody) {
// If error code is in body, parseBody() won't be called inside error deser. So we parse body here.
// It's ok to parse body here because body won't be streaming if 'isErrorCodeInBody' is set.
Expand Down Expand Up @@ -489,14 +495,14 @@ private void readEventBody(
} else if (payloadShape instanceof StructureShape || payloadShape instanceof UnionShape) {
writer.write("const data: any = await parseBody(output.body, context);");
Symbol symbol = getSymbol(context, payloadShape);
String deserFunctionName = ProtocolGenerator.getDeserFunctionName(symbol, context.getProtocolName());
String deserFunctionName = ProtocolGenerator.getDeserFunctionShortName(symbol);
writer.write("contents.$L = $L(data, context);", payloadMemberName, deserFunctionName);
eventShapesToDeserialize.add(payloadShape);
}
} else {
writer.write("const data: any = await parseBody(output.body, context);");
Symbol symbol = getSymbol(context, event);
String deserFunctionName = ProtocolGenerator.getDeserFunctionName(symbol, context.getProtocolName());
String deserFunctionName = ProtocolGenerator.getDeserFunctionShortName(symbol);
writer.write("Object.assign(contents, $L(data, context));", deserFunctionName);
eventShapesToDeserialize.add(event);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ public void generateServiceHandlerFactory(GenerationContext context) {
generateServiceMux(context);
writer.addImport("ServiceException", "__ServiceException", "@aws-smithy/server-common");
writer.openBlock("const serFn: (op: $1T) => __OperationSerializer<$2T<Context>, $1T, __ServiceException> = "
+ "(op) => {", "};", operationsSymbol, serviceSymbol, () -> {
+ "(op) => {", "};", operationsSymbol, serviceSymbol, () -> {
writer.openBlock("switch (op) {", "}", () -> {
operations.stream()
.filter(o -> o.getTrait(HttpTrait.class).isPresent())
Expand Down Expand Up @@ -440,12 +440,12 @@ public void generateOperationHandlerFactory(GenerationContext context, Operation

if (context.getSettings().isDisableDefaultValidation()) {
writer.write("export const get$L = <Context>(operation: __Operation<$T, $T, Context>, "
+ "customizer: __ValidationCustomizer<$S>): "
+ "__ServiceHandler<Context, __HttpRequest, __HttpResponse> => {",
+ "customizer: __ValidationCustomizer<$S>): "
+ "__ServiceHandler<Context, __HttpRequest, __HttpResponse> => {",
operationHandlerSymbol.getName(), inputType, outputType, operationSymbol.getName());
} else {
writer.write("export const get$L = <Context>(operation: __Operation<$T, $T, Context>): "
+ "__ServiceHandler<Context, __HttpRequest, __HttpResponse> => {",
+ "__ServiceHandler<Context, __HttpRequest, __HttpResponse> => {",
operationHandlerSymbol.getName(), inputType, outputType);
}
writer.indent();
Expand Down Expand Up @@ -642,19 +642,25 @@ private void generateOperationRequestSerializer(
// Ensure that the request type is imported.
writer.addUseImports(requestType);
writer.addImport("Endpoint", "__Endpoint", "@aws-sdk/types");

// e.g., se_ES
String methodName = ProtocolGenerator.getSerFunctionShortName(symbol);
// e.g., serializeAws_restJson1_1ExecuteStatement
String methodName = ProtocolGenerator.getSerFunctionName(symbol, getName());
String methodLongName = ProtocolGenerator.getSerFunctionName(symbol, getName());

// Add the normalized input type.
Symbol inputType = symbol.expectProperty("inputType", Symbol.class);
String contextType = CodegenUtils.getOperationSerializerContextType(writer, context.getModel(), operation);

writer.writeDocs(methodLongName);
writer.openBlock("export const $L = async(\n"
+ " input: $T,\n"
+ " context: $L\n"
+ "): Promise<$T> => {", "}", methodName, inputType, contextType, requestType, () -> {

// Get the hostname, path, port, and scheme from client's resolved endpoint. Then construct the request from
// them. The client's resolved endpoint can be default one or supplied by users.
// Get the hostname, path, port, and scheme from client's resolved endpoint.
// Then construct the request from them. The client's resolved endpoint can
// be default one or supplied by users.
writer.write("const {hostname, protocol = $S, port, path: basePath} = await context.endpoint();", "https");

writeRequestHeaders(context, operation, bindingIndex);
Expand Down Expand Up @@ -777,12 +783,12 @@ private void writeResolvedPath(
Shape target = model.expectShape(binding.getMember().getTarget());

String labelValueProvider = "() => " + getInputValue(
context,
binding.getLocation(),
"input." + memberName + "!",
binding.getMember(),
target
);
context,
binding.getLocation(),
"input." + memberName + "!",
binding.getMember(),
target
);

// Get the correct label to use.
Segment uriLabel = uriLabels.stream().filter(s -> s.getContent().equals(memberName)).findFirst().get();
Expand Down Expand Up @@ -1342,7 +1348,7 @@ private String getNamedMembersInputParam(
switch (bindingType) {
case PAYLOAD:
Symbol symbol = context.getSymbolProvider().toSymbol(target);
return ProtocolGenerator.getSerFunctionName(symbol, context.getProtocolName())
return ProtocolGenerator.getSerFunctionShortName(symbol)
+ "(" + dataSource + ", context)";
default:
throw new CodegenException("Unexpected named member shape binding location `" + bindingType + "`");
Expand Down Expand Up @@ -1887,17 +1893,18 @@ private void readDirectQueryBindings(GenerationContext context, List<HttpBinding
"@aws-smithy/server-common");
writer.write("let queryValue: string;");
writer.openBlock("if (Array.isArray(query[$S])) {", "}",
binding.getLocationName(),
() -> {
writer.openBlock("if (query[$S].length === 1) {", "}",
binding.getLocationName(),
() -> {
writer.write("queryValue = query[$S][0];", binding.getLocationName());
});
writer.openBlock("else {", "}", () -> {
writer.write("throw new __SerializationException();");
});
binding.getLocationName(),
() -> {
writer.openBlock("if (query[$S].length === 1) {", "}",
binding.getLocationName(),
() -> {
writer.write("queryValue = query[$S][0];", binding.getLocationName());
}
);
writer.openBlock("else {", "}", () -> {
writer.write("throw new __SerializationException();");
});
});
writer.openBlock("else {", "}", () -> {
writer.write("queryValue = query[$S] as string;", binding.getLocationName());
});
Expand Down Expand Up @@ -2052,18 +2059,21 @@ private void generateOperationResponseDeserializer(
// Ensure that the response type is imported.
writer.addUseImports(responseType);
// e.g., deserializeAws_restJson1_1ExecuteStatement
String methodName = ProtocolGenerator.getDeserFunctionName(symbol, getName());
String methodName = ProtocolGenerator.getDeserFunctionShortName(symbol);
String methodLongName = ProtocolGenerator.getDeserFunctionName(symbol, getName());
String errorMethodName = methodName + "Error";
// Add the normalized output type.
Symbol outputType = symbol.expectProperty("outputType", Symbol.class);
String contextType = CodegenUtils.getOperationDeserializerContextType(context.getSettings(), writer,
context.getModel(), operation);

// Handle the general response.
writer.writeDocs(methodLongName);
writer.openBlock("export const $L = async(\n"
+ " output: $T,\n"
+ " context: $L\n"
+ "): Promise<$T> => {", "}", methodName, responseType, contextType, outputType, () -> {
+ "): Promise<$T> => {", "}",
methodName, responseType, contextType, outputType, () -> {
// Redirect error deserialization to the dispatcher if we receive an error range
// status code that's not the modeled code (300 or higher). This allows for
// returning other 2XX codes that don't match the defined value.
Expand Down Expand Up @@ -2103,10 +2113,13 @@ private void generateErrorDeserializer(GenerationContext context, StructureShape
HttpBindingIndex bindingIndex = HttpBindingIndex.of(context.getModel());
Model model = context.getModel();
Symbol errorSymbol = symbolProvider.toSymbol(error);
String errorDeserMethodName = ProtocolGenerator.getDeserFunctionName(errorSymbol,
context.getProtocolName()) + "Response";
String errorDeserMethodName = ProtocolGenerator.getDeserFunctionShortName(errorSymbol) + "Res";
String errorDeserMethodLongName = ProtocolGenerator.getDeserFunctionName(errorSymbol, context.getProtocolName())
+ "Res";

String outputName = isErrorCodeInBody ? "parsedOutput" : "output";

writer.writeDocs(errorDeserMethodLongName);
writer.openBlock("const $L = async (\n"
+ " $L: any,\n"
+ " context: __SerdeContext\n"
Expand Down Expand Up @@ -2661,8 +2674,8 @@ private String getNamedMembersOutputParam(
case PAYLOAD:
// Redirect to a deserialization function.
Symbol symbol = context.getSymbolProvider().toSymbol(target);
return ProtocolGenerator.getDeserFunctionName(symbol, context.getProtocolName())
+ "(" + dataSource + ", context)";
return ProtocolGenerator.getDeserFunctionShortName(symbol)
+ "(" + dataSource + ", context)";
default:
throw new CodegenException("Unexpected named member shape binding location `" + bindingType + "`");
}
Expand Down
Loading