Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ private void generateRequestTest(OperationShape operation, HttpRequestTestCase t

// Execute the command, and catch the expected exception
writer.addImport(SmithyPythonDependency.PYTEST.packageName(), "fail", "fail");
writer.addStdlibImport("urllib.parse", "parse_qs");
writer.addStdlibImport("typing", "AbstractSet");
writer.write("""
try:
await client.$1T(input_)
Expand All @@ -207,7 +209,23 @@ private void generateRequestTest(OperationShape operation, HttpRequestTestCase t
assert actual.method == $3S
assert actual.url.path == $4S
assert actual.url.host == $5S
$6C

query = actual.url.query
actual_query_segments: list[str] = query.split("&") if query else []
expected_query_segments: list[str] = $6J
for expected_query_segment in expected_query_segments:
assert expected_query_segment in actual_query_segments
actual_query_segments.remove(expected_query_segment)

actual_query_keys: AbstractSet[str] = parse_qs(query).keys() if query else set()
expected_query_keys: set[str] = set($7J)
assert actual_query_keys >= expected_query_keys

forbidden_query_keys: set[str] = set($8J)
for forbidden_key in forbidden_query_keys:
assert forbidden_key not in actual_query_keys

$9C
except Exception as err:
fail(f"Expected '$2L' exception to be thrown, but received {type(err).__name__}: {err}")
""",
Expand All @@ -216,6 +234,9 @@ private void generateRequestTest(OperationShape operation, HttpRequestTestCase t
testCase.getMethod(),
testCase.getUri(),
host,
testCase.getQueryParams(),
testCase.getRequireQueryParams(),
testCase.getForbidQueryParams(),
(Runnable) () -> writer.maybeWrite(
!testCase.getRequireHeaders().isEmpty(),
"assert {h[0] for h in actual.headers} >= $J",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@
import static software.amazon.smithy.model.knowledge.HttpBinding.Location.LABEL;
import static software.amazon.smithy.model.knowledge.HttpBinding.Location.PAYLOAD;
import static software.amazon.smithy.model.knowledge.HttpBinding.Location.PREFIX_HEADERS;
import static software.amazon.smithy.model.knowledge.HttpBinding.Location.QUERY;
import static software.amazon.smithy.model.knowledge.HttpBinding.Location.QUERY_PARAMS;
import static software.amazon.smithy.model.traits.TimestampFormatTrait.Format;

import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
Expand All @@ -45,6 +48,7 @@
import software.amazon.smithy.model.shapes.IntegerShape;
import software.amazon.smithy.model.shapes.ListShape;
import software.amazon.smithy.model.shapes.LongShape;
import software.amazon.smithy.model.shapes.MapShape;
import software.amazon.smithy.model.shapes.MemberShape;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.Shape;
Expand All @@ -61,6 +65,7 @@
import software.amazon.smithy.python.codegen.SmithyPythonDependency;
import software.amazon.smithy.utils.CodeSection;
import software.amazon.smithy.utils.SmithyUnstableApi;
import software.amazon.smithy.utils.StringUtils;

/**
* Abstract implementation useful for all protocols that use HTTP bindings.
Expand Down Expand Up @@ -300,22 +305,97 @@ private void serializeQuery(
HttpBindingIndex bindingIndex
) {
writer.pushState(new SerializeQuerySection(operation));
writer.write("query_params: list[tuple[str, str | None]] = []");
// TODO: implement query serialization
writer.writeInline("query_params: list[tuple[str, str | None]] = [");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to make sure I understand how we intend this to work. If I want mylonginfovalue1234567890 appended to the end of my URI as the only part of a query string, it would be something like:

query_params = [('mylonginfovalue1234567890', None)]
uri = build_uri(query_params=query_params, **kwargs)
>>> https://example.com/?mylonginfovalue1234567890

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

var httpTrait = operation.expectTrait(HttpTrait.class);
for (Map.Entry<String, String> entry : httpTrait.getUri().getQueryLiterals().entrySet()) {
if (StringUtils.isBlank(entry.getValue())) {
writer.write("($S, None),", entry.getKey());
} else {
writer.write("($S, $S),", entry.getKey(), entry.getValue());
}
}
writer.write("]\n");

serializeIndividualQueryParams(context, writer, operation, bindingIndex);
serializeQueryParamsMap(context, writer, operation, bindingIndex);

writer.addStdlibImport("urllib.parse", "quote", "urlquote");
writer.write("""
query: str = ""
for i, param in enumerate(query_params):
if i != 1:
if i != 0:
query += "&"
if param[1] is None:
query += param[0]
query += urlquote(param[0], safe='')
else:
query += f"{param[0]}={param[1]}"
query += f"{urlquote(param[0], safe='')}={urlquote(param[1], safe='')}"
""");
writer.popState();
}

private void serializeIndividualQueryParams(
GenerationContext context,
PythonWriter writer,
OperationShape operation,
HttpBindingIndex bindingIndex
) {
var queryBindings = bindingIndex.getRequestBindings(operation, QUERY);
for (HttpBinding binding : queryBindings) {
var memberName = context.symbolProvider().toMemberName(binding.getMember());
var locationName = binding.getLocationName();
var target = context.model().expectShape(binding.getMember().getTarget());

CodegenUtils.accessStructureMember(context, writer, "input", binding.getMember(), () -> {
if (target.isListShape()) {
var listMember = target.asListShape().get().getMember();
var listTarget = context.model().expectShape(listMember.getTarget());
var memberSerializer = listTarget.accept(new HttpMemberSerVisitor(
context, writer, QUERY, "e", listMember,
getDocumentTimestampFormat()));
writer.write("query_params.extend(($S, $L) for e in input.$L)",
locationName, memberSerializer, memberName);
} else {
var memberSerializer = target.accept(new HttpMemberSerVisitor(
context, writer, QUERY, "input." + memberName, binding.getMember(),
getDocumentTimestampFormat()));
writer.write("query_params.append(($S, $L))", locationName, memberSerializer);
}
});
}
}

private void serializeQueryParamsMap(
GenerationContext context,
PythonWriter writer,
OperationShape operation,
HttpBindingIndex bindingIndex
) {
var queryMapBindings = bindingIndex.getRequestBindings(operation, QUERY_PARAMS);
for (HttpBinding binding : queryMapBindings) {
var memberName = context.symbolProvider().toMemberName(binding.getMember());
var mapShape = context.model().expectShape(binding.getMember().getTarget(), MapShape.class);
var mapTarget = context.model().expectShape(mapShape.getValue().getTarget());

CodegenUtils.accessStructureMember(context, writer, "input", binding.getMember(), () -> {
if (mapTarget.isListShape()) {
var listMember = mapTarget.asListShape().get().getMember();
var listMemberTarget = context.model().expectShape(listMember.getTarget());
var memberSerializer = listMemberTarget.accept(new HttpMemberSerVisitor(
context, writer, QUERY, "v", listMember,
getDocumentTimestampFormat()));
writer.write("query_params.extend((k, $1L) for k in input.$2L for v in input.$2L[k])",
memberSerializer, memberName);
} else {
var memberSerializer = mapTarget.accept(new HttpMemberSerVisitor(
context, writer, QUERY, "v", mapShape.getValue(),
getDocumentTimestampFormat()));
writer.write("query_params.extend((k, $L) for k, v in input.$L.items())",
memberSerializer, memberName);
}
});
}
}

/**
* A section that controls query serialization.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,15 @@
@SmithyUnstableApi
public class RestJsonProtocolGenerator extends HttpBindingProtocolGenerator {

private static final Set<String> OUTPUT_TESTS_TO_SKIP = Set.of(
private static final Set<String> TESTS_TO_SKIP = Set.of(
// These two tests essentially try to assert nan == nan,
// which is never true. We should update the generator to
// make specific assertions for these.
"RestJsonSupportsNaNFloatHeaderOutputs",
"RestJsonSupportsNaNFloatInputs"
"RestJsonSupportsNaNFloatInputs",

// This requires support of idempotency autofill
"RestJsonQueryIdempotencyTokenAutoFill"
);

@Override
Expand All @@ -82,13 +85,17 @@ context, getProtocol(), writer, (shape, testCase) -> filterTests(context, shape,
}

private boolean filterTests(GenerationContext context, Shape shape, HttpMessageTestCase testCase) {
if (TESTS_TO_SKIP.contains(testCase.getId())) {
return true;
}
if (shape.hasTrait(ErrorTrait.class)) {
// Error handling isn't implemented yet
return true;
}
if (testCase instanceof HttpRequestTestCase) {
// Request serialization isn't finished, so here we only test the bindings that are supported.
Set<Location> implementedBindings = SetUtils.of(Location.LABEL, Location.DOCUMENT);
Set<Location> implementedBindings = SetUtils.of(Location.LABEL, Location.DOCUMENT, Location.QUERY,
Location.QUERY_PARAMS);
var bindingIndex = HttpBindingIndex.of(context.model());

// If any member specified in the test is bound to a location we haven't yet implemented,
Expand All @@ -109,9 +116,6 @@ private boolean filterTests(GenerationContext context, Shape shape, HttpMessageT
}
}
if (testCase instanceof HttpResponseTestCase) {
if (OUTPUT_TESTS_TO_SKIP.contains(testCase.getId())) {
return true;
}
var bindingIndex = HttpBindingIndex.of(context.model());
return bindingIndex.getResponseBindings(shape, Location.PAYLOAD).size() != 0;
}
Expand Down