Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix response payload and incorrectly parsing error response #66

Merged
merged 8 commits into from
Jan 10, 2020
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
import software.amazon.smithy.model.shapes.NumberShape;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.ShapeType;
import software.amazon.smithy.model.shapes.StringShape;
import software.amazon.smithy.model.shapes.StructureShape;
import software.amazon.smithy.model.shapes.TimestampShape;
Expand All @@ -72,6 +71,17 @@ public abstract class HttpBindingProtocolGenerator implements ProtocolGenerator
private final Set<Shape> serializingDocumentShapes = new TreeSet<>();
private final Set<Shape> deserializingDocumentShapes = new TreeSet<>();
private final Set<StructureShape> deserializingErrorShapes = new TreeSet<>();
private final boolean isErrorCodeInBody;

/**
* Creates a Http binding protocol generator.
*
* @param isErrorCodeInBody A boolean that indicates if the error code for the implementing protocol is located in
* the error response body, meaning this generator will parse the body before attempting to load an error code.
*/
public HttpBindingProtocolGenerator(boolean isErrorCodeInBody) {
this.isErrorCodeInBody = isErrorCodeInBody;
}

@Override
public ApplicationProtocol getApplicationProtocol() {
Expand Down Expand Up @@ -120,6 +130,8 @@ public void generateSharedComponents(GenerationContext context) {
generateDocumentBodyShapeSerializers(context, serializingDocumentShapes);
generateDocumentBodyShapeDeserializers(context, deserializingDocumentShapes);
HttpProtocolGeneratorUtils.generateMetadataDeserializer(context, getApplicationProtocol().getResponseType());
HttpProtocolGeneratorUtils.generateCollectBody(context);
HttpProtocolGeneratorUtils.generateCollectBodyString(context);
}

/**
Expand Down Expand Up @@ -594,7 +606,7 @@ private void generateOperationDeserializer(

// Write out the error deserialization dispatcher.
Set<StructureShape> errorShapes = HttpProtocolGeneratorUtils.generateErrorDispatcher(
context, operation, responseType, this::writeErrorCodeParser);
context, operation, responseType, this::writeErrorCodeParser, isErrorCodeInBody);
deserializingErrorShapes.addAll(errorShapes);
}

Expand All @@ -608,9 +620,10 @@ private void generateErrorDeserializer(GenerationContext context, StructureShape
context.getProtocolName()) + "Response";

writer.openBlock("const $L = async (\n"
+ " output: any,\n"
+ " $L: any,\n"
+ " context: __SerdeContext\n"
+ "): Promise<$T> => {", "};", errorDeserMethodName, errorSymbol, () -> {
+ "): Promise<$T> => {", "};",
errorDeserMethodName, isErrorCodeInBody ? "parsedOutput" : "output", errorSymbol, () -> {
writer.openBlock("const contents: $T = {", "};", errorSymbol, () -> {
writer.write("__type: $S,", error.getId().getName());
writer.write("$$fault: $S,", error.getTrait(ErrorTrait.class).get().getValue());
Expand All @@ -621,7 +634,7 @@ private void generateErrorDeserializer(GenerationContext context, StructureShape
});

readHeaders(context, error, bindingIndex);
List<HttpBinding> documentBindings = readResponseBody(context, error, bindingIndex);
List<HttpBinding> documentBindings = readErrorResponseBody(context, error, bindingIndex);
// Track all shapes bound to the document so their deserializers may be generated.
documentBindings.forEach(binding -> {
Shape target = model.expectShape(binding.getMember().getTarget());
Expand All @@ -633,6 +646,24 @@ private void generateErrorDeserializer(GenerationContext context, StructureShape
writer.write("");
}

private List<HttpBinding> readErrorResponseBody(
GenerationContext context,
Shape error,
HttpBindingIndex bindingIndex
) {
TypeScriptWriter writer = context.getWriter();
if (isErrorCodeInBody) {
// Body is already parsed in error dispatcher, simply assign body to data.
writer.write("const data: any = output.body;");
List<HttpBinding> responseBindings = bindingIndex.getResponseBindings(error, Location.DOCUMENT);
responseBindings.sort(Comparator.comparing(HttpBinding::getMemberName));
return responseBindings;
} else {
// Deserialize response body just like in normal response.
return readResponseBody(context, error, bindingIndex);
}
}

private void readHeaders(
GenerationContext context,
Shape operationOrError,
Expand Down Expand Up @@ -691,42 +722,48 @@ private List<HttpBinding> readResponseBody(
documentBindings.sort(Comparator.comparing(HttpBinding::getMemberName));
List<HttpBinding> payloadBindings = bindingIndex.getResponseBindings(operationOrError, Location.PAYLOAD);

// Detect if operation output or error shape contains a streaming member.
OperationIndex operationIndex = context.getModel().getKnowledge(OperationIndex.class);
StructureShape operationOutputOrError = operationOrError.asStructureShape()
.orElseGet(() -> operationIndex.getOutput(operationOrError).orElse(null));
boolean hasStreamingComponent = Optional.ofNullable(operationOutputOrError)
.map(structure -> structure.getAllMembers().values().stream()
.anyMatch(memberShape -> memberShape.hasTrait(StreamingTrait.class)))
.orElse(false);

if (!documentBindings.isEmpty()) {
readReponseBodyData(context, operationOrError);
// If response has document binding, the body can be parsed to JavaScript object.
writer.write("const data: any = await parseBody(output.body, context);");
deserializeOutputDocument(context, operationOrError, documentBindings);
return documentBindings;
}
if (!payloadBindings.isEmpty()) {
readReponseBodyData(context, operationOrError);
// There can only be one payload binding.
HttpBinding binding = payloadBindings.get(0);
Shape target = context.getModel().expectShape(binding.getMember().getTarget());
if (hasStreamingComponent) {
// If payload is streaming, return raw low-level stream directly.
writer.write("const data: any = output.body;");
} else if (target instanceof BlobShape) {
// If payload is non-streaming blob, only need to collect stream to binary data(Uint8Array).
writer.write("const data: any = await collectBody(output.body, context);");
} else if (target instanceof StructureShape || target instanceof UnionShape) {
// If body is Structure or Union, they we need to parse the string into JavaScript object.
writer.write("const data: any = await parseBody(output.body, context);");
} else if (target instanceof StringShape) {
// If payload is string, we need to collect body and encode binary to string.
writer.write("const data: any = await collectBodyString(output.body, context);");
} else {
throw new CodegenException(String.format("Unexpected shape type bound to payload: `%s`",
target.getType()));
}
writer.write("contents.$L = $L;", binding.getMemberName(), getOutputValue(context,
Location.PAYLOAD, "data", binding.getMember(), target));
return payloadBindings;
}
return ListUtils.of();
}

private void readReponseBodyData(GenerationContext context, Shape operationOrError) {
TypeScriptWriter writer = context.getWriter();
// Prepare response body for deserializing.
OperationIndex operationIndex = context.getModel().getKnowledge(OperationIndex.class);
StructureShape operationOutputOrError = operationOrError.asStructureShape()
.orElseGet(() -> operationIndex.getOutput(operationOrError).orElse(null));
boolean hasStreamingComponent = Optional.ofNullable(operationOutputOrError)
.map(structure -> structure.getAllMembers().values().stream()
.anyMatch(memberShape -> memberShape.hasTrait(StreamingTrait.class)))
.orElse(false);
if (hasStreamingComponent || operationOrError.getType().equals(ShapeType.STRUCTURE)) {
// For operations with streaming output or errors with streaming body we keep the body intact.
writer.write("const data: any = output.body;");
} else {
// Otherwise, we collect the response body to structured object with parseBody().
writer.write("const data: any = await parseBody(output.body, context);");
}
}

/**
* Given context and a source of data, generate an output value provider for the
* shape. This may use native types (like generating a Date for timestamps,)
Expand Down Expand Up @@ -890,10 +927,16 @@ private String getNumberOutputParam(Location bindingType, String dataSource, Sha
* Writes the code that loads an {@code errorCode} String with the content used
* to dispatch errors to specific serializers.
*
* <p>Three variables will be in scope:
* <p>Two variables will be in scope:
* <ul>
* <li>{@code output}: a value of the HttpResponse type.</li>
* <li>{@code data}: the contents of the response body.</li>
* <li>{@code output} or {@code parsedOutput}: a value of the HttpResponse type.
* <ul>
* <li>{@code output} is a raw HttpResponse, available when {@code isErrorCodeInBody} is set to
* {@code false}</li>
* <li>{@code parsedOutput} is a HttpResponse type with body parsed to JavaScript object, available
* when {@code isErrorCodeInBody} is set to {@code true}</li>
* </ul>
* </li>
* <li>{@code context}: the SerdeContext.</li>
* </ul>
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,44 @@ static void generateMetadataDeserializer(GenerationContext context, SymbolRefere
writer.write("");
}

/**
* Writes a response body stream collector. This function converts the low-level response body stream to
* Uint8Array binary data.
*
* @param context The generation context.
*/
static void generateCollectBody(GenerationContext context) {
TypeScriptWriter writer = context.getWriter();

writer.addImport("SerdeContext", "__SerdeContext", "@aws-sdk/types");
writer.write("// Collect low-level response body stream to Uint8Array.");
writer.openBlock("const collectBody = (streamBody: any, context: __SerdeContext): Promise<Uint8Array> => {",
"};", () -> {
writer.write("return context.streamCollector(streamBody) || new Uint8Array();");
});

writer.write("");
}

/**
* Writes a function converting the low-level response body stream to utf-8 encoded string. It depends on
* response body stream collector {@link #generateCollectBody(GenerationContext)}.
*
* @param context The generation context
*/
static void generateCollectBodyString(GenerationContext context) {
TypeScriptWriter writer = context.getWriter();

writer.addImport("SerdeContext", "__SerdeContext", "@aws-sdk/types");
writer.write("// Encode Uint8Array data into string with utf-8.");
writer.openBlock("const collectBodyString = (streamBody: any, context: __SerdeContext): Promise<string> => {",
"};", () -> {
writer.write("return collectBody(streamBody, context).then(body => context.utf8Encoder(body));");
});

writer.write("");
}

/**
* Writes a function used to dispatch to the proper error deserializer
* for each error that the operation can return. The generated function
Expand All @@ -118,13 +156,15 @@ static void generateMetadataDeserializer(GenerationContext context, SymbolRefere
* @param operation The operation to generate for.
* @param responseType The response type for the HTTP protocol.
* @param errorCodeGenerator A consumer
* @param shouldParseErrorBody Flag indicating whether need to parse response body in this dispatcher function
* @return A set of all error structure shapes for the operation that were dispatched to.
*/
static Set<StructureShape> generateErrorDispatcher(
GenerationContext context,
OperationShape operation,
SymbolReference responseType,
Consumer<GenerationContext> errorCodeGenerator
Consumer<GenerationContext> errorCodeGenerator,
boolean shouldParseErrorBody
) {
TypeScriptWriter writer = context.getWriter();
SymbolProvider symbolProvider = context.getSymbolProvider();
Expand All @@ -138,14 +178,14 @@ static Set<StructureShape> generateErrorDispatcher(
+ " output: $T,\n"
+ " context: __SerdeContext,\n"
+ "): Promise<$T> {", "}", errorMethodName, responseType, outputType, () -> {
writer.write("const data: any = await parseBody(output.body, context);");
// We only consume the parsedOutput if we're dispatching, so only generate if we will.
if (!operation.getErrors().isEmpty()) {
// Create a holding object since we have already parsed the body, but retain the rest of the output.
writer.openBlock("const parsedOutput: any = {", "};", () -> {
writer.write("...output,");
writer.write("body: data,");
});
// Prepare error response for parsing error code. If error code needs to be parsed from response body
// then we collect body and parse it to JS object, otherwise leave the response body as is.
if (shouldParseErrorBody) {
writer.openBlock("const parsedOutput: any = {", "};",
() -> {
writer.write("...output,");
writer.write("body: await parseBody(output.body, context)");
});
}

// Error responses must be at least SmithyException and MetadataBearer implementations.
Expand All @@ -167,7 +207,8 @@ static Set<StructureShape> generateErrorDispatcher(
context.getProtocolName()) + "Response";
writer.openBlock("case $S:\ncase $S:", " break;", errorId.getName(), errorId.toString(), () -> {
// Dispatch to the error deserialization function.
writer.write("response = await $L(parsedOutput, context);", errorDeserMethodName);
writer.write("response = await $L($L, context);",
errorDeserMethodName, shouldParseErrorBody ? "parsedOutput" : "output");
});
});

Expand Down
Loading