diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java index 854718f12..a8e4988e6 100644 --- a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java +++ b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java @@ -197,6 +197,8 @@ private void generateRequestTest(OperationShape operation, HttpRequestTestCase t // Execute the command, and catch the expected exception writer.addImport(SmithyPythonDependency.PYTEST.packageName(), "fail", "fail"); + writer.addStdlibImport("urllib.parse", "parse_qs"); + writer.addStdlibImport("typing", "AbstractSet"); writer.write(""" try: await client.$1T(input_) @@ -207,7 +209,23 @@ private void generateRequestTest(OperationShape operation, HttpRequestTestCase t assert actual.method == $3S assert actual.url.path == $4S assert actual.url.host == $5S - $6C + + query = actual.url.query + actual_query_segments: list[str] = query.split("&") if query else [] + expected_query_segments: list[str] = $6J + for expected_query_segment in expected_query_segments: + assert expected_query_segment in actual_query_segments + actual_query_segments.remove(expected_query_segment) + + actual_query_keys: AbstractSet[str] = parse_qs(query).keys() if query else set() + expected_query_keys: set[str] = set($7J) + assert actual_query_keys >= expected_query_keys + + forbidden_query_keys: set[str] = set($8J) + for forbidden_key in forbidden_query_keys: + assert forbidden_key not in actual_query_keys + + $9C except Exception as err: fail(f"Expected '$2L' exception to be thrown, but received {type(err).__name__}: {err}") """, @@ -216,6 +234,9 @@ private void generateRequestTest(OperationShape operation, HttpRequestTestCase t testCase.getMethod(), testCase.getUri(), host, + testCase.getQueryParams(), + testCase.getRequireQueryParams(), + testCase.getForbidQueryParams(), (Runnable) () -> writer.maybeWrite( !testCase.getRequireHeaders().isEmpty(), "assert {h[0] for h in actual.headers} >= $J", diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/HttpBindingProtocolGenerator.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/HttpBindingProtocolGenerator.java index 405703a9a..933b7ff28 100644 --- a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/HttpBindingProtocolGenerator.java +++ b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/HttpBindingProtocolGenerator.java @@ -21,10 +21,13 @@ import static software.amazon.smithy.model.knowledge.HttpBinding.Location.LABEL; import static software.amazon.smithy.model.knowledge.HttpBinding.Location.PAYLOAD; import static software.amazon.smithy.model.knowledge.HttpBinding.Location.PREFIX_HEADERS; +import static software.amazon.smithy.model.knowledge.HttpBinding.Location.QUERY; +import static software.amazon.smithy.model.knowledge.HttpBinding.Location.QUERY_PARAMS; import static software.amazon.smithy.model.traits.TimestampFormatTrait.Format; import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.Set; import java.util.TreeSet; import java.util.stream.Collectors; @@ -45,6 +48,7 @@ import software.amazon.smithy.model.shapes.IntegerShape; import software.amazon.smithy.model.shapes.ListShape; import software.amazon.smithy.model.shapes.LongShape; +import software.amazon.smithy.model.shapes.MapShape; import software.amazon.smithy.model.shapes.MemberShape; import software.amazon.smithy.model.shapes.OperationShape; import software.amazon.smithy.model.shapes.Shape; @@ -61,6 +65,7 @@ import software.amazon.smithy.python.codegen.SmithyPythonDependency; import software.amazon.smithy.utils.CodeSection; import software.amazon.smithy.utils.SmithyUnstableApi; +import software.amazon.smithy.utils.StringUtils; /** * Abstract implementation useful for all protocols that use HTTP bindings. @@ -300,22 +305,97 @@ private void serializeQuery( HttpBindingIndex bindingIndex ) { writer.pushState(new SerializeQuerySection(operation)); - writer.write("query_params: list[tuple[str, str | None]] = []"); - // TODO: implement query serialization + writer.writeInline("query_params: list[tuple[str, str | None]] = ["); + var httpTrait = operation.expectTrait(HttpTrait.class); + for (Map.Entry entry : httpTrait.getUri().getQueryLiterals().entrySet()) { + if (StringUtils.isBlank(entry.getValue())) { + writer.write("($S, None),", entry.getKey()); + } else { + writer.write("($S, $S),", entry.getKey(), entry.getValue()); + } + } + writer.write("]\n"); + serializeIndividualQueryParams(context, writer, operation, bindingIndex); + serializeQueryParamsMap(context, writer, operation, bindingIndex); + + writer.addStdlibImport("urllib.parse", "quote", "urlquote"); writer.write(""" query: str = "" for i, param in enumerate(query_params): - if i != 1: + if i != 0: query += "&" if param[1] is None: - query += param[0] + query += urlquote(param[0], safe='') else: - query += f"{param[0]}={param[1]}" + query += f"{urlquote(param[0], safe='')}={urlquote(param[1], safe='')}" """); writer.popState(); } + private void serializeIndividualQueryParams( + GenerationContext context, + PythonWriter writer, + OperationShape operation, + HttpBindingIndex bindingIndex + ) { + var queryBindings = bindingIndex.getRequestBindings(operation, QUERY); + for (HttpBinding binding : queryBindings) { + var memberName = context.symbolProvider().toMemberName(binding.getMember()); + var locationName = binding.getLocationName(); + var target = context.model().expectShape(binding.getMember().getTarget()); + + CodegenUtils.accessStructureMember(context, writer, "input", binding.getMember(), () -> { + if (target.isListShape()) { + var listMember = target.asListShape().get().getMember(); + var listTarget = context.model().expectShape(listMember.getTarget()); + var memberSerializer = listTarget.accept(new HttpMemberSerVisitor( + context, writer, QUERY, "e", listMember, + getDocumentTimestampFormat())); + writer.write("query_params.extend(($S, $L) for e in input.$L)", + locationName, memberSerializer, memberName); + } else { + var memberSerializer = target.accept(new HttpMemberSerVisitor( + context, writer, QUERY, "input." + memberName, binding.getMember(), + getDocumentTimestampFormat())); + writer.write("query_params.append(($S, $L))", locationName, memberSerializer); + } + }); + } + } + + private void serializeQueryParamsMap( + GenerationContext context, + PythonWriter writer, + OperationShape operation, + HttpBindingIndex bindingIndex + ) { + var queryMapBindings = bindingIndex.getRequestBindings(operation, QUERY_PARAMS); + for (HttpBinding binding : queryMapBindings) { + var memberName = context.symbolProvider().toMemberName(binding.getMember()); + var mapShape = context.model().expectShape(binding.getMember().getTarget(), MapShape.class); + var mapTarget = context.model().expectShape(mapShape.getValue().getTarget()); + + CodegenUtils.accessStructureMember(context, writer, "input", binding.getMember(), () -> { + if (mapTarget.isListShape()) { + var listMember = mapTarget.asListShape().get().getMember(); + var listMemberTarget = context.model().expectShape(listMember.getTarget()); + var memberSerializer = listMemberTarget.accept(new HttpMemberSerVisitor( + context, writer, QUERY, "v", listMember, + getDocumentTimestampFormat())); + writer.write("query_params.extend((k, $1L) for k in input.$2L for v in input.$2L[k])", + memberSerializer, memberName); + } else { + var memberSerializer = mapTarget.accept(new HttpMemberSerVisitor( + context, writer, QUERY, "v", mapShape.getValue(), + getDocumentTimestampFormat())); + writer.write("query_params.extend((k, $L) for k, v in input.$L.items())", + memberSerializer, memberName); + } + }); + } + } + /** * A section that controls query serialization. * diff --git a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/RestJsonProtocolGenerator.java b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/RestJsonProtocolGenerator.java index a3c84500c..b1b7639a1 100644 --- a/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/RestJsonProtocolGenerator.java +++ b/codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/integration/RestJsonProtocolGenerator.java @@ -52,12 +52,15 @@ @SmithyUnstableApi public class RestJsonProtocolGenerator extends HttpBindingProtocolGenerator { - private static final Set OUTPUT_TESTS_TO_SKIP = Set.of( + private static final Set TESTS_TO_SKIP = Set.of( // These two tests essentially try to assert nan == nan, // which is never true. We should update the generator to // make specific assertions for these. "RestJsonSupportsNaNFloatHeaderOutputs", - "RestJsonSupportsNaNFloatInputs" + "RestJsonSupportsNaNFloatInputs", + + // This requires support of idempotency autofill + "RestJsonQueryIdempotencyTokenAutoFill" ); @Override @@ -82,13 +85,17 @@ context, getProtocol(), writer, (shape, testCase) -> filterTests(context, shape, } private boolean filterTests(GenerationContext context, Shape shape, HttpMessageTestCase testCase) { + if (TESTS_TO_SKIP.contains(testCase.getId())) { + return true; + } if (shape.hasTrait(ErrorTrait.class)) { // Error handling isn't implemented yet return true; } if (testCase instanceof HttpRequestTestCase) { // Request serialization isn't finished, so here we only test the bindings that are supported. - Set implementedBindings = SetUtils.of(Location.LABEL, Location.DOCUMENT); + Set implementedBindings = SetUtils.of(Location.LABEL, Location.DOCUMENT, Location.QUERY, + Location.QUERY_PARAMS); var bindingIndex = HttpBindingIndex.of(context.model()); // If any member specified in the test is bound to a location we haven't yet implemented, @@ -109,9 +116,6 @@ private boolean filterTests(GenerationContext context, Shape shape, HttpMessageT } } if (testCase instanceof HttpResponseTestCase) { - if (OUTPUT_TESTS_TO_SKIP.contains(testCase.getId())) { - return true; - } var bindingIndex = HttpBindingIndex.of(context.model()); return bindingIndex.getResponseBindings(shape, Location.PAYLOAD).size() != 0; }