Skip to content

Commit

Permalink
Make the fields constructor infallible, and add fields.from-list
Browse files Browse the repository at this point in the history
  • Loading branch information
elliottt committed Oct 30, 2023
1 parent ae3d7b8 commit 66ad714
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 21 deletions.
2 changes: 1 addition & 1 deletion crates/test-programs/src/bin/api_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct T;

impl bindings::exports::wasi::http::incoming_handler::Guest for T {
fn handle(_request: IncomingRequest, outparam: ResponseOutparam) {
let hdrs = bindings::wasi::http::types::Headers::new(&[]);
let hdrs = bindings::wasi::http::types::Headers::new();
let resp = bindings::wasi::http::types::OutgoingResponse::new(200, hdrs);
let body = resp.body().expect("outgoing response");

Expand Down
11 changes: 6 additions & 5 deletions crates/test-programs/src/bin/api_proxy_streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async fn handle_request(request: IncomingRequest, response_out: ResponseOutparam

let response = OutgoingResponse::new(
200,
Fields::new(&[("content-type".to_string(), b"text/plain".to_vec())]),
Fields::from_list(&[("content-type".to_string(), b"text/plain".to_vec())]).unwrap(),
);

let mut body =
Expand All @@ -75,12 +75,13 @@ async fn handle_request(request: IncomingRequest, response_out: ResponseOutparam
(Method::Post, Some("/echo")) => {
let response = OutgoingResponse::new(
200,
Fields::new(
Fields::from_list(
&headers
.into_iter()
.filter_map(|(k, v)| (k == "content-type").then_some((k, v)))
.collect::<Vec<_>>(),
),
)
.unwrap(),
);

let mut body =
Expand Down Expand Up @@ -108,7 +109,7 @@ async fn handle_request(request: IncomingRequest, response_out: ResponseOutparam
}

_ => {
let response = OutgoingResponse::new(405, Fields::new(&[]));
let response = OutgoingResponse::new(405, Fields::new());

let body = response.body().expect("response should be writable");

Expand Down Expand Up @@ -137,7 +138,7 @@ async fn hash(url: &Url) -> Result<String> {
String::new()
}
)),
Fields::new(&[]),
Fields::new(),
);

let response = executor::outgoing_request_send(request).await?;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use test_programs::wasi::http::types::{HeaderError, Headers};

fn main() {
let hdrs = Headers::new(&[]);
let hdrs = Headers::new();
assert!(matches!(
hdrs.append(&"malformed header name".to_owned(), &b"ok value".to_vec()),
Err(HeaderError::InvalidSyntax)
Expand Down Expand Up @@ -42,4 +42,19 @@ fn main() {
),
Err(HeaderError::Forbidden)
));

assert!(matches!(
Headers::from_list(&[("bad header".to_owned(), b"value".to_vec())]),
Err(HeaderError::InvalidSyntax)
));

assert!(matches!(
Headers::from_list(&[("custom-forbidden-header".to_owned(), b"value".to_vec())]),
Err(HeaderError::Forbidden)
));

assert!(matches!(
Headers::from_list(&[("ok-header-name".to_owned(), b"bad\nvalue".to_vec())]),
Err(HeaderError::InvalidSyntax)
));
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ use test_programs::wasi::http::types as http_types;
fn main() {
println!("Called _start");
{
let headers = http_types::Headers::new(&[(
let headers = http_types::Headers::from_list(&[(
"Content-Type".to_string(),
"application/json".to_string().into_bytes(),
)]);
)])
.unwrap();
let request = http_types::OutgoingRequest::new(
&http_types::Method::Get,
None,
Expand All @@ -21,10 +22,11 @@ fn main() {
.unwrap();
}
{
let headers = http_types::Headers::new(&[(
let headers = http_types::Headers::from_list(&[(
"Content-Type".to_string(),
"application/text".to_string().into_bytes(),
)]);
)])
.unwrap();
let response = http_types::OutgoingResponse::new(200, headers);
let outgoing_body = response.body().unwrap();
let response_body = outgoing_body.write().unwrap();
Expand Down
4 changes: 2 additions & 2 deletions crates/test-programs/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub fn request(
fn header_val(v: &str) -> Vec<u8> {
v.to_string().into_bytes()
}
let headers = http_types::Headers::new(
let headers = http_types::Headers::from_list(
&[
&[
("User-agent".to_string(), header_val("WASI-HTTP/0.0.1")),
Expand All @@ -51,7 +51,7 @@ pub fn request(
additional_headers.unwrap_or(&[]),
]
.concat(),
);
)?;

let request = http_types::OutgoingRequest::new(
&method,
Expand Down
36 changes: 30 additions & 6 deletions crates/wasi-http/src/types_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,38 @@ fn is_forbidden_header<T: WasiHttpView>(view: &mut T, name: &HeaderName) -> bool
}

impl<T: WasiHttpView> crate::bindings::http::types::HostFields for T {
fn new(&mut self, entries: Vec<(String, Vec<u8>)>) -> wasmtime::Result<Resource<HostFields>> {
fn new(&mut self) -> wasmtime::Result<Resource<HostFields>> {
let id = self
.table()
.push(HostFields::Owned {
fields: hyper::HeaderMap::new(),
})
.context("[new_fields] pushing fields")?;

Ok(id)
}

fn from_list(
&mut self,
entries: Vec<(String, Vec<u8>)>,
) -> wasmtime::Result<Result<Resource<HostFields>, HeaderError>> {
let mut map = hyper::HeaderMap::new();

for (header, value) in entries {
// This will trap for an invalid header name, but there's no other way to communicate
// the error out from a constructor.
let header = hyper::header::HeaderName::from_bytes(header.as_bytes())?;
let value = hyper::header::HeaderValue::from_bytes(&value)?;
let header = match hyper::header::HeaderName::from_bytes(header.as_bytes()) {
Ok(header) => header,
Err(_) => return Ok(Err(HeaderError::InvalidSyntax)),
};

if is_forbidden_header(self, &header) {
return Ok(Err(HeaderError::Forbidden));
}

let value = match hyper::header::HeaderValue::from_bytes(&value) {
Ok(value) => value,
Err(_) => return Ok(Err(HeaderError::InvalidSyntax)),
};

map.append(header, value);
}

Expand All @@ -87,7 +111,7 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostFields for T {
.push(HostFields::Owned { fields: map })
.context("[new_fields] pushing fields")?;

Ok(id)
Ok(Ok(id))
}

fn drop(&mut self, fields: Resource<HostFields>) -> wasmtime::Result<()> {
Expand Down
10 changes: 9 additions & 1 deletion crates/wasi-http/wit/deps/http/types.wit
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ interface types {
/// Headers and Trailers.
resource fields {

/// Construct an empty HTTP Fields.
constructor();

/// Construct an HTTP Fields.
///
/// The list represents each key-value pair in the Fields. Keys
Expand All @@ -66,7 +69,12 @@ interface types {
/// Value, represented as a list of bytes. In a valid Fields, all keys
/// and values are valid UTF-8 strings. However, values are not always
/// well-formed, so they are represented as a raw list of bytes.
constructor(entries: list<tuple<field-key,field-value>>);
///
/// An error result will be returned if any header or value was
/// syntactically invalid, or if a header was forbidden.
from-list: static func(
entries: list<tuple<field-key,field-value>>
) -> result<fields, header-error>;

/// Get all of the values corresponding to a key.
get: func(name: field-key) -> list<field-value>;
Expand Down
10 changes: 9 additions & 1 deletion crates/wasi/wit/deps/http/types.wit
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ interface types {
/// Headers and Trailers.
resource fields {

/// Construct an empty HTTP Fields.
constructor();

/// Construct an HTTP Fields.
///
/// The list represents each key-value pair in the Fields. Keys
Expand All @@ -66,7 +69,12 @@ interface types {
/// Value, represented as a list of bytes. In a valid Fields, all keys
/// and values are valid UTF-8 strings. However, values are not always
/// well-formed, so they are represented as a raw list of bytes.
constructor(entries: list<tuple<field-key,field-value>>);
///
/// An error result will be returned if any header or value was
/// syntactically invalid, or if a header was forbidden.
from-list: static func(
entries: list<tuple<field-key,field-value>>
) -> result<fields, header-error>;

/// Get all of the values corresponding to a key.
get: func(name: field-key) -> list<field-value>;
Expand Down

0 comments on commit 66ad714

Please sign in to comment.