Skip to content

Commit

Permalink
Implement httpLabel with nom
Browse files Browse the repository at this point in the history
Move to `nom` to implement httpLabel instead of using regexes.

Issue: #938

Signed-off-by: Daniele Ahmed <ahmeddan@amazon.com>
  • Loading branch information
82marbag committed Jan 4, 2022
1 parent 56cf68a commit 29cbd04
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ object ServerCargoDependency {
val AsyncTrait: CargoDependency = CargoDependency("async-trait", CratesIo("0.1"))
val AxumCore: CargoDependency = CargoDependency("axum-core", CratesIo("0.1"))
val FuturesUtil: CargoDependency = CargoDependency("futures-util", CratesIo("0.3"))
val Nom: CargoDependency = CargoDependency("nom", CratesIo("7"))
val PinProjectLite: CargoDependency = CargoDependency("pin-project-lite", CratesIo("0.2"))
val SerdeUrlEncoded: CargoDependency = CargoDependency("serde_urlencoded", CratesIo("0.7"))
val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ private class ServerHttpProtocolImplGenerator(
"HttpBody" to CargoDependency.HttpBody.asType(),
"Hyper" to CargoDependency.Hyper.asType(),
"LazyStatic" to CargoDependency.LazyStatic.asType(),
"Nom" to ServerCargoDependency.Nom.asType(),
"PercentEncoding" to CargoDependency.PercentEncoding.asType(),
"Regex" to CargoDependency.Regex.asType(),
"SerdeUrlEncoded" to ServerCargoDependency.SerdeUrlEncoded.asType(),
Expand Down Expand Up @@ -572,46 +573,80 @@ private class ServerHttpProtocolImplGenerator(
if (pathBindings.isEmpty()) {
return
}
val pattern = StringBuilder()
val httpTrait = httpBindingResolver.httpTrait(operationShape)
httpTrait.uri.segments.forEach {
pattern.append("/")
if (it.isLabel) {
pattern.append("(?P<${it.content}>")
if (it.isGreedyLabel) {
pattern.append(".+")
val greedyLabelIndex = httpTrait.uri.segments.indexOfFirst { it.isGreedyLabel }
val segments =
if (greedyLabelIndex >= 0)
httpTrait.uri.segments.slice(0 until (greedyLabelIndex + 1))
else
httpTrait.uri.segments
val restAfterGreedyLabel =
if (greedyLabelIndex >= 0)
httpTrait.uri.segments.slice((greedyLabelIndex + 1) until httpTrait.uri.segments.size).joinToString(prefix = "/", separator = "/")
else
""
val labeledNames = segments
.mapIndexed { index, segment ->
if (segment.isLabel) { "m$index" } else { "_" }
}
.joinToString(prefix = (if (segments.size > 1) "(" else ""), separator = ",", postfix = (if (segments.size > 1) ")" else ""))
val nomParser = segments
.map { segment ->
if (segment.isGreedyLabel) {
"#{Nom}::combinator::rest::<_, #{Nom}::error::Error<&str>>"
} else if (segment.isLabel) {
"""#{Nom}::branch::alt::<_, _, #{Nom}::error::Error<&str>, _>((#{Nom}::bytes::complete::take_until("/"), #{Nom}::combinator::rest))"""
} else {
pattern.append("[^/]+")
"""#{Nom}::bytes::complete::tag::<_, _, #{Nom}::error::Error<&str>>("${segment.content}")"""
}
pattern.append(")")
} else {
pattern.append(it.content)
}
}
.joinToString(
// TODO: tuple() is currently limited to 21 items
prefix = if (segments.size > 1) "#{Nom}::sequence::tuple::<_, _, #{Nom}::error::Error<&str>, _>((" else "",
postfix = if (segments.size > 1) "))" else "",
transform = { parser ->
"""
#{Nom}::sequence::preceded(#{Nom}::bytes::complete::tag("/"), $parser)
""".trimIndent()
}
)
with(writer) {
rustTemplate("let input_string = request.uri().path();")
if (greedyLabelIndex >= 0 && greedyLabelIndex + 1 < httpTrait.uri.segments.size) {
rustTemplate(
"""
if !input_string.ends_with(${restAfterGreedyLabel.dq()}) {
return std::result::Result::Err(#{SmithyRejection}::Deserialize(
aws_smithy_http_server::rejection::Deserialize::from_err(format!("Postfix not found: {}", ${restAfterGreedyLabel.dq()}))));
}
let input_string = &input_string[..(input_string.len() - ${restAfterGreedyLabel.dq()}.len())];
""".trimIndent(),
*codegenScope
)
}
rustTemplate(
"""
#{LazyStatic}::lazy_static! {
static ref RE: #{Regex}::Regex = #{Regex}::Regex::new("$pattern").unwrap();
}
let (input_string, $labeledNames) = $nomParser(input_string)?;
debug_assert_eq!("", input_string);
""".trimIndent(),
*codegenScope,
*codegenScope
)
rustBlock("if let Some(captures) = RE.captures(request.uri().path())") {
pathBindings.forEach {
val deserializer = generateParsePercentEncodedStrFn(it)
rustTemplate(
"""
if let Some(m) = captures.name("${it.locationName}") {
input = input.${it.member.setterName()}(
#{deserializer}(m.as_str())?
segments
.forEachIndexed { index, segment ->
val binding = pathBindings.find { it.memberName == segment.content }
if (binding != null && segment.isLabel) {
val deserializer = generateParsePercentEncodedStrFn(binding)
rustTemplate(
"""
input = input.${binding.member.setterName()}(
#{deserializer}(m$index)?
);
}
""".trimIndent(),
"deserializer" to deserializer,
)
""".trimIndent(),
*codegenScope,
"deserializer" to deserializer,
)
}
}
}
}
}

Expand Down
1 change: 1 addition & 0 deletions rust-runtime/aws-smithy-http-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ http = "0.2"
http-body = "0.4"
hyper = { version = "0.14", features = ["server", "http1", "http2", "tcp"] }
mime = "0.3"
nom = "7"
pin-project-lite = "0.2"
regex = "1.0"
serde_urlencoded = "0.7"
Expand Down
6 changes: 6 additions & 0 deletions rust-runtime/aws-smithy-http-server/src/rejection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,9 @@ impl From<serde_urlencoded::de::Error> for SmithyRejection {
SmithyRejection::Deserialize(Deserialize::from_err(err))
}
}

impl From<nom::Err<nom::error::Error<&str>>> for SmithyRejection {
fn from(err: nom::Err<nom::error::Error<&str>>) -> Self {
SmithyRejection::Deserialize(Deserialize::from_err(err.to_owned()))
}
}

0 comments on commit 29cbd04

Please sign in to comment.