Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Server: require State to be Clone #644

Merged
merged 2 commits into from
Jul 17, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ route-recognizer = "0.2.0"
logtest = "2.0.0"
async-trait = "0.1.36"
futures-util = "0.3.5"
pin-project-lite = "0.1.7"

[dev-dependencies]
async-std = { version = "1.6.0", features = ["unstable", "attributes"] }
Expand Down
7 changes: 5 additions & 2 deletions examples/graphql.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use async_std::task;
use juniper::{http::graphiql, http::GraphQLRequest, RootNode};
use std::sync::RwLock;
Expand Down Expand Up @@ -37,8 +39,9 @@ impl NewUser {
}
}

#[derive(Clone)]
pub struct State {
users: RwLock<Vec<User>>,
users: Arc<RwLock<Vec<User>>>,
}
impl juniper::Context for State {}

Expand Down Expand Up @@ -96,7 +99,7 @@ async fn handle_graphiql(_: Request<State>) -> tide::Result<impl Into<Response>>
fn main() -> std::io::Result<()> {
task::block_on(async {
let mut app = Server::with_state(State {
users: RwLock::new(Vec::new()),
users: Arc::new(RwLock::new(Vec::new())),
});
app.at("/").get(Redirect::permanent("/graphiql"));
app.at("/graphql").post(handle_graphql);
Expand Down
4 changes: 2 additions & 2 deletions examples/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ struct User {
name: String,
}

#[derive(Default, Debug)]
#[derive(Clone, Default, Debug)]
struct UserDatabase;
impl UserDatabase {
async fn find_user(&self) -> Option<User> {
Expand Down Expand Up @@ -62,7 +62,7 @@ impl RequestCounterMiddleware {
struct RequestCount(usize);

#[tide::utils::async_trait]
impl<State: Send + Sync + 'static> Middleware<State> for RequestCounterMiddleware {
impl<State: Clone + Send + Sync + 'static> Middleware<State> for RequestCounterMiddleware {
async fn handle(&self, mut req: Request<State>, next: Next<'_, State>) -> Result {
let count = self.requests_counted.fetch_add(1, Ordering::Relaxed);
tide::log::trace!("request counter", { count: count });
Expand Down
8 changes: 5 additions & 3 deletions examples/upload.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use async_std::{fs::OpenOptions, io};
use tempfile::TempDir;
use tide::prelude::*;
Expand All @@ -6,15 +8,15 @@ use tide::{Body, Request, Response, StatusCode};
#[async_std::main]
async fn main() -> Result<(), std::io::Error> {
tide::log::start();
let mut app = tide::with_state(tempfile::tempdir()?);
let mut app = tide::with_state(Arc::new(tempfile::tempdir()?));

// To test this example:
// $ cargo run --example upload
// $ curl -T ./README.md locahost:8080 # this writes the file to a temp directory
// $ curl localhost:8080/README.md # this reads the file from the same temp directory

app.at(":file")
.put(|req: Request<TempDir>| async move {
.put(|req: Request<Arc<TempDir>>| async move {
let path: String = req.param("file")?;
let fs_path = req.state().path().join(path);

Expand All @@ -33,7 +35,7 @@ async fn main() -> Result<(), std::io::Error> {

Ok(json!({ "bytes": bytes_written }))
})
.get(|req: Request<TempDir>| async move {
.get(|req: Request<Arc<TempDir>>| async move {
let path: String = req.param("file")?;
let fs_path = req.state().path().join(path);

Expand Down
2 changes: 1 addition & 1 deletion src/cookies/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl CookiesMiddleware {
}

#[async_trait]
impl<State: Send + Sync + 'static> Middleware<State> for CookiesMiddleware {
impl<State: Clone + Send + Sync + 'static> Middleware<State> for CookiesMiddleware {
async fn handle(&self, mut ctx: Request<State>, next: Next<'_, State>) -> crate::Result {
let cookie_jar = if let Some(cookie_data) = ctx.ext::<CookieData>() {
cookie_data.content.clone()
Expand Down
8 changes: 4 additions & 4 deletions src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ use crate::{Middleware, Request, Response};
///
/// Tide routes will also accept endpoints with `Fn` signatures of this form, but using the `async` keyword has better ergonomics.
#[async_trait]
pub trait Endpoint<State: Send + Sync + 'static>: Send + Sync + 'static {
pub trait Endpoint<State: Clone + Send + Sync + 'static>: Send + Sync + 'static {
/// Invoke the endpoint within the given context
async fn call(&self, req: Request<State>) -> crate::Result;
}
Expand All @@ -55,7 +55,7 @@ pub(crate) type DynEndpoint<State> = dyn Endpoint<State>;
#[async_trait]
impl<State, F, Fut, Res> Endpoint<State> for F
where
State: Send + Sync + 'static,
State: Clone + Send + Sync + 'static,
F: Send + Sync + 'static + Fn(Request<State>) -> Fut,
Fut: Future<Output = Result<Res>> + Send + 'static,
Res: Into<Response> + 'static,
Expand Down Expand Up @@ -93,7 +93,7 @@ impl<E, State> std::fmt::Debug for MiddlewareEndpoint<E, State> {

impl<E, State> MiddlewareEndpoint<E, State>
where
State: Send + Sync + 'static,
State: Clone + Send + Sync + 'static,
E: Endpoint<State>,
{
pub fn wrap_with_middleware(ep: E, middleware: &[Arc<dyn Middleware<State>>]) -> Self {
Expand All @@ -107,7 +107,7 @@ where
#[async_trait]
impl<E, State> Endpoint<State> for MiddlewareEndpoint<E, State>
where
State: Send + Sync + 'static,
State: Clone + Send + Sync + 'static,
E: Endpoint<State>,
{
async fn call(&self, req: Request<State>) -> crate::Result {
Expand Down
6 changes: 2 additions & 4 deletions src/fs/serve_dir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ impl ServeDir {
#[async_trait::async_trait]
impl<State> Endpoint<State> for ServeDir
where
State: Send + Sync + 'static,
State: Clone + Send + Sync + 'static,
{
async fn call(&self, req: Request<State>) -> Result {
let path = req.url().path();
Expand Down Expand Up @@ -60,8 +60,6 @@ where
mod test {
use super::*;

use async_std::sync::Arc;

use std::fs::{self, File};
use std::io::Write;

Expand All @@ -83,7 +81,7 @@ mod test {
let request = crate::http::Request::get(
crate::http::Url::parse(&format!("http://localhost/{}", path)).unwrap(),
);
crate::Request::new(Arc::new(()), request, vec![])
crate::Request::new((), request, vec![])
}

#[async_std::test]
Expand Down
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ pub fn new() -> server::Server<()> {
/// use tide::Request;
///
/// /// The shared application state.
/// #[derive(Clone)]
/// struct State {
/// name: String,
/// }
Expand All @@ -279,7 +280,7 @@ pub fn new() -> server::Server<()> {
/// ```
pub fn with_state<State>(state: State) -> server::Server<State>
where
State: Send + Sync + 'static,
State: Clone + Send + Sync + 'static,
{
Server::with_state(state)
}
Expand Down
4 changes: 2 additions & 2 deletions src/listener/concurrent_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use futures_util::stream::{futures_unordered::FuturesUnordered, StreamExt};
#[derive(Default)]
pub struct ConcurrentListener<State>(Vec<Box<dyn Listener<State>>>);

impl<State: Send + Sync + 'static> ConcurrentListener<State> {
impl<State: Clone + Send + Sync + 'static> ConcurrentListener<State> {
/// creates a new ConcurrentListener
pub fn new() -> Self {
Self(vec![])
Expand Down Expand Up @@ -78,7 +78,7 @@ impl<State: Send + Sync + 'static> ConcurrentListener<State> {
}

#[async_trait::async_trait]
impl<State: Send + Sync + 'static> Listener<State> for ConcurrentListener<State> {
impl<State: Clone + Send + Sync + 'static> Listener<State> for ConcurrentListener<State> {
async fn listen(&mut self, app: Server<State>) -> io::Result<()> {
let mut futures_unordered = FuturesUnordered::new();

Expand Down
4 changes: 2 additions & 2 deletions src/listener/failover_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use async_std::io;
#[derive(Default)]
pub struct FailoverListener<State>(Vec<Box<dyn Listener<State>>>);

impl<State: Send + Sync + 'static> FailoverListener<State> {
impl<State: Clone + Send + Sync + 'static> FailoverListener<State> {
/// creates a new FailoverListener
pub fn new() -> Self {
Self(vec![])
Expand Down Expand Up @@ -80,7 +80,7 @@ impl<State: Send + Sync + 'static> FailoverListener<State> {
}

#[async_trait::async_trait]
impl<State: Send + Sync + 'static> Listener<State> for FailoverListener<State> {
impl<State: Clone + Send + Sync + 'static> Listener<State> for FailoverListener<State> {
async fn listen(&mut self, app: Server<State>) -> io::Result<()> {
for listener in self.0.iter_mut() {
let app = app.clone();
Expand Down
2 changes: 1 addition & 1 deletion src/listener/parsed_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl Display for ParsedListener {
}

#[async_trait::async_trait]
impl<State: Send + Sync + 'static> Listener<State> for ParsedListener {
impl<State: Clone + Send + Sync + 'static> Listener<State> for ParsedListener {
async fn listen(&mut self, app: Server<State>) -> io::Result<()> {
match self {
#[cfg(unix)]
Expand Down
4 changes: 2 additions & 2 deletions src/listener/tcp_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl TcpListener {
}
}

fn handle_tcp<State: Send + Sync + 'static>(app: Server<State>, stream: TcpStream) {
fn handle_tcp<State: Clone + Send + Sync + 'static>(app: Server<State>, stream: TcpStream) {
task::spawn(async move {
let local_addr = stream.local_addr().ok();
let peer_addr = stream.peer_addr().ok();
Expand All @@ -69,7 +69,7 @@ fn handle_tcp<State: Send + Sync + 'static>(app: Server<State>, stream: TcpStrea
}

#[async_trait::async_trait]
impl<State: Send + Sync + 'static> Listener<State> for TcpListener {
impl<State: Clone + Send + Sync + 'static> Listener<State> for TcpListener {
async fn listen(&mut self, app: Server<State>) -> io::Result<()> {
self.connect().await?;
let listener = self.listener()?;
Expand Down
38 changes: 20 additions & 18 deletions src/listener/to_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ use std::net::ToSocketAddrs;
/// # Other implementations
/// See below for additional provided implementations of ToListener.

pub trait ToListener<State: Send + Sync + 'static> {
pub trait ToListener<State: Clone + Send + Sync + 'static> {
type Listener: Listener<State>;
/// Transform self into a
/// [`Listener`](crate::listener::Listener). Unless self is
Expand All @@ -63,7 +63,7 @@ pub trait ToListener<State: Send + Sync + 'static> {
fn to_listener(self) -> io::Result<Self::Listener>;
}

impl<State: Send + Sync + 'static> ToListener<State> for Url {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for Url {
type Listener = ParsedListener;

fn to_listener(self) -> io::Result<Self::Listener> {
Expand Down Expand Up @@ -106,14 +106,14 @@ impl<State: Send + Sync + 'static> ToListener<State> for Url {
}
}

impl<State: Send + Sync + 'static> ToListener<State> for String {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for String {
type Listener = ParsedListener;
fn to_listener(self) -> io::Result<Self::Listener> {
ToListener::<State>::to_listener(self.as_str())
}
}

impl<State: Send + Sync + 'static> ToListener<State> for &str {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for &str {
type Listener = ParsedListener;

fn to_listener(self) -> io::Result<Self::Listener> {
Expand All @@ -133,36 +133,36 @@ impl<State: Send + Sync + 'static> ToListener<State> for &str {
}

#[cfg(unix)]
impl<State: Send + Sync + 'static> ToListener<State> for async_std::path::PathBuf {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for async_std::path::PathBuf {
type Listener = UnixListener;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(UnixListener::from_path(self))
}
}

#[cfg(unix)]
impl<State: Send + Sync + 'static> ToListener<State> for std::path::PathBuf {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for std::path::PathBuf {
type Listener = UnixListener;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(UnixListener::from_path(self))
}
}

impl<State: Send + Sync + 'static> ToListener<State> for async_std::net::TcpListener {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for async_std::net::TcpListener {
type Listener = TcpListener;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(TcpListener::from_listener(self))
}
}

impl<State: Send + Sync + 'static> ToListener<State> for std::net::TcpListener {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for std::net::TcpListener {
type Listener = TcpListener;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(TcpListener::from_listener(self))
}
}

impl<State: Send + Sync + 'static> ToListener<State> for (&str, u16) {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for (&str, u16) {
type Listener = TcpListener;

fn to_listener(self) -> io::Result<Self::Listener> {
Expand All @@ -171,65 +171,67 @@ impl<State: Send + Sync + 'static> ToListener<State> for (&str, u16) {
}

#[cfg(unix)]
impl<State: Send + Sync + 'static> ToListener<State> for async_std::os::unix::net::UnixListener {
impl<State: Clone + Send + Sync + 'static> ToListener<State>
for async_std::os::unix::net::UnixListener
{
type Listener = UnixListener;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(UnixListener::from_listener(self))
}
}

#[cfg(unix)]
impl<State: Send + Sync + 'static> ToListener<State> for std::os::unix::net::UnixListener {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for std::os::unix::net::UnixListener {
type Listener = UnixListener;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(UnixListener::from_listener(self))
}
}

impl<State: Send + Sync + 'static> ToListener<State> for TcpListener {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for TcpListener {
type Listener = Self;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(self)
}
}

#[cfg(unix)]
impl<State: Send + Sync + 'static> ToListener<State> for UnixListener {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for UnixListener {
type Listener = Self;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(self)
}
}

impl<State: Send + Sync + 'static> ToListener<State> for ConcurrentListener<State> {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for ConcurrentListener<State> {
type Listener = Self;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(self)
}
}

impl<State: Send + Sync + 'static> ToListener<State> for ParsedListener {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for ParsedListener {
type Listener = Self;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(self)
}
}

impl<State: Send + Sync + 'static> ToListener<State> for FailoverListener<State> {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for FailoverListener<State> {
type Listener = Self;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(self)
}
}

impl<State: Send + Sync + 'static> ToListener<State> for std::net::SocketAddr {
impl<State: Clone + Send + Sync + 'static> ToListener<State> for std::net::SocketAddr {
type Listener = TcpListener;
fn to_listener(self) -> io::Result<Self::Listener> {
Ok(TcpListener::from_addrs(vec![self]))
}
}

impl<TL: ToListener<State>, State: Send + Sync + 'static> ToListener<State> for Vec<TL> {
impl<TL: ToListener<State>, State: Clone + Send + Sync + 'static> ToListener<State> for Vec<TL> {
type Listener = ConcurrentListener<State>;
fn to_listener(self) -> io::Result<Self::Listener> {
let mut concurrent_listener = ConcurrentListener::new();
Expand Down
Loading