Skip to content

Commit

Permalink
Macro for marking future as Send
Browse files Browse the repository at this point in the history
  • Loading branch information
kflansburg committed Mar 20, 2024
1 parent a430c23 commit fca35f4
Show file tree
Hide file tree
Showing 10 changed files with 82 additions and 14 deletions.
26 changes: 26 additions & 0 deletions worker-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ mod durable_object;
mod event;

use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, ItemFn};

#[proc_macro_attribute]
pub fn durable_object(_attr: TokenStream, item: TokenStream) -> TokenStream {
Expand All @@ -21,3 +23,27 @@ pub fn event(attr: TokenStream, item: TokenStream) -> TokenStream {
pub fn event(attr: TokenStream, item: TokenStream) -> TokenStream {
event::expand_macro(attr, item, false)
}

#[proc_macro_attribute]
pub fn send(_attr: TokenStream, stream: TokenStream) -> TokenStream {
let stream_clone = stream.clone();
let input = parse_macro_input!(stream_clone as ItemFn);

let ItemFn {
attrs,
vis,
sig,
block,
} = input;
let stmts = &block.stmts;

let tokens = quote! {
#(#attrs)* #vis #sig {
worker::SendFuture::new(async {
#(#stmts)*
}).await
}
};

TokenStream::from(tokens)
}
9 changes: 7 additions & 2 deletions worker-sandbox/src/fetch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use worker::{
Response, Result, RouteContext,
};

pub async fn handle_fetch(req: Request, _env: Env, _data: SomeSharedData) -> Result<Response> {
pub async fn handle_fetch(_req: Request, _env: Env, _data: SomeSharedData) -> Result<Response> {
let req = Request::new("https://example.com", Method::Post)?;
let resp = Fetch::Request(req).send().await?;
let resp2 = Fetch::Url("https://example.com".parse()?).send().await?;
Expand All @@ -17,7 +17,12 @@ pub async fn handle_fetch(req: Request, _env: Env, _data: SomeSharedData) -> Res
))
}

pub async fn handle_fetch_json(req: Request, _env: Env, _data: SomeSharedData) -> Result<Response> {
#[worker::send]
pub async fn handle_fetch_json(
_req: Request,
_env: Env,
_data: SomeSharedData,
) -> Result<Response> {
let data: ApiData = Fetch::Url(
"https://jsonplaceholder.typicode.com/todos/1"
.parse()
Expand Down
14 changes: 14 additions & 0 deletions worker-sandbox/src/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{
#[cfg(feature = "http")]
use std::convert::TryInto;
use std::sync::atomic::Ordering;

use worker::{console_log, Fetch, Headers, Request, Response, Result, RouteContext};

#[cfg(not(feature = "http"))]
Expand Down Expand Up @@ -44,6 +45,19 @@ macro_rules! handler (
}
);

#[cfg(feature = "http")]
#[debug_handler]
#[worker::send]
async fn test(
Extension(env): Extension<Env>,
Extension(data): Extension<SomeSharedData>,

Check warning on line 53 in worker-sandbox/src/router.rs

View workflow job for this annotation

GitHub Actions / Test

unused variable: `data`
req: axum::extract::Request,

Check warning on line 54 in worker-sandbox/src/router.rs

View workflow job for this annotation

GitHub Actions / Test

unused variable: `req`
) -> &'static str {
let foo = env.kv("SOME_NAMESPACE").unwrap();
foo.put("test", "test").unwrap().execute().await.unwrap();
"hello world"
}

#[cfg(feature = "http")]
pub fn make_router(data: SomeSharedData, env: Env) -> axum::Router {
axum::Router::new()
Expand Down
2 changes: 1 addition & 1 deletion worker-sandbox/src/user.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use serde::Serialize;
use worker::{Date, DateInit, Env, Request, Response, Result, RouteContext};
use worker::{Date, DateInit, Env, Request, Response, Result};

use crate::SomeSharedData;

Expand Down
3 changes: 1 addition & 2 deletions worker-sandbox/src/ws.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use super::SomeSharedData;
use futures_util::StreamExt;
use worker::{
wasm_bindgen_futures, Env, Request, Response, Result, RouteContext, WebSocket, WebSocketPair,
WebsocketEvent,
wasm_bindgen_futures, Env, Request, Response, Result, WebSocket, WebSocketPair, WebsocketEvent,
};

pub async fn handle_websocket(_req: Request, env: Env, _data: SomeSharedData) -> Result<Response> {
Expand Down
3 changes: 0 additions & 3 deletions worker/src/formdata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,3 @@ impl From<web_sys::File> for File {
Self(file)
}
}

unsafe impl Send for File {}
unsafe impl Sync for File {}
4 changes: 3 additions & 1 deletion worker/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ pub use wasm_bindgen_futures;
pub use worker_kv as kv;

pub use cf::{Cf, TlsClientAuth};
pub use worker_macros::{durable_object, event};
pub use worker_macros::{durable_object, event, send};
#[doc(hidden)]
pub use worker_sys;
pub use worker_sys::{console_debug, console_error, console_log, console_warn};
Expand Down Expand Up @@ -100,6 +100,7 @@ pub use crate::request_init::*;
pub use crate::response::{Response, ResponseBody};
pub use crate::router::{RouteContext, RouteParams, Router};
pub use crate::schedule::*;
pub use crate::send_future::SendFuture;
pub use crate::socket::*;
pub use crate::streams::*;
pub use crate::websocket::*;
Expand Down Expand Up @@ -135,6 +136,7 @@ mod request_init;
mod response;
mod router;
mod schedule;
mod send_future;
mod socket;
mod streams;
mod websocket;
Expand Down
28 changes: 28 additions & 0 deletions worker/src/send_future.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use futures_util::future::Future;
use pin_project::pin_project;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;

#[pin_project]
pub struct SendFuture<F> {
#[pin]
inner: F,
}

impl<F> SendFuture<F> {
pub fn new(inner: F) -> Self {
Self { inner }
}
}

unsafe impl<F> Send for SendFuture<F> {}

impl<F: Future> Future for SendFuture<F> {
type Output = F::Output;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
this.inner.poll(cx)
}
}
3 changes: 0 additions & 3 deletions worker/src/streams.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@ pub struct ByteStream {
pub(crate) inner: IntoStream<'static>,
}

unsafe impl Send for ByteStream {}
unsafe impl Sync for ByteStream {}

impl Stream for ByteStream {
type Item = Result<Vec<u8>>;

Expand Down
4 changes: 2 additions & 2 deletions worker/src/websocket.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{Error, Method, Request, Result};
use futures_channel::mpsc::UnboundedReceiver;
use futures_util::{Future, FutureExt, Stream};
use futures_util::Stream;
use serde::Serialize;
use url::Url;
use worker_sys::ext::WebSocketExt;
Expand All @@ -16,7 +16,7 @@ use std::rc::Rc;
use std::task::{Context, Poll};
use wasm_bindgen::convert::FromWasmAbi;
use wasm_bindgen::prelude::Closure;
use wasm_bindgen::{JsCast, JsValue};
use wasm_bindgen::JsCast;
#[cfg(feature = "http")]
use wasm_bindgen_futures::JsFuture;

Expand Down

0 comments on commit fca35f4

Please sign in to comment.