diff --git a/.travis.yml b/.travis.yml index fe53702c..b34256e8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,34 +1,12 @@ language: rust -sudo: false - rust: - nightly +sudo: false +cache: cargo os: + - osx - linux -addons: - apt: - packages: - - libcurl4-openssl-dev - - libelf-dev - - libdw-dev - -before_script: -- | - pip install 'travis-cargo<0.2' --user && - export PATH=$HOME/.local/bin:$PATH - script: -- | - travis-cargo build -- --features tls && travis-cargo test -- --features tls && travis-cargo bench -- --features tls && - rustdoc --test README.md -L target/debug/deps -L target/debug && - travis-cargo build && travis-cargo test && travis-cargo bench - -after_success: -- travis-cargo coveralls --no-sudo - -env: - global: - # override the default `--features unstable` used for the nightly branch - - TRAVIS_CARGO_NIGHTLY_FEATURE="" + - cargo test --all --all-features diff --git a/Cargo.toml b/Cargo.toml index 0505ea54..3ab4bb2b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,56 +1,10 @@ -[package] -name = "tarpc" -version = "0.12.1" -authors = ["Adam Wright ", "Tim Kuehn "] -license = "MIT" -documentation = "https://docs.rs/tarpc" -homepage = "https://github.com/google/tarpc" -repository = "https://github.com/google/tarpc" -keywords = ["rpc", "network", "server", "api", "tls"] -categories = ["asynchronous", "network-programming"] -readme = "README.md" -description = "An RPC framework for Rust with a focus on ease of use." - -[badges] -travis-ci = { repository = "google/tarpc" } - -[dependencies] -bincode = "1.0" -byteorder = "1.0" -bytes = "0.4" -cfg-if = "0.1.0" -futures = "0.1.11" -lazy_static = "1.0" -log = "0.4" -net2 = "0.2" -num_cpus = "1.0" -serde = "1.0" -serde_derive = "1.0" -tarpc-plugins = { path = "src/plugins", version = "0.4.0" } -thread-pool = "0.1.1" -tokio-codec = "0.1" -tokio-core = "0.1.6" -tokio-io = "0.1" -tokio-proto = "0.1.1" -tokio-service = "0.1" - -# Optional dependencies -native-tls = { version = "0.1", optional = true } -tokio-tls = { version = "0.1", optional = true } - -[dev-dependencies] -chrono = "0.4" -env_logger = "0.5" -futures-cpupool = "0.1" -clap = "2.0" -serde_bytes = "0.10" - -[target.'cfg(target_os = "macos")'.dev-dependencies] -security-framework = "0.2" - -[features] -default = [] -tls = ["tokio-tls", "native-tls"] -unstable = ["serde/unstable"] - [workspace] + +members = [ + "example-service", + "rpc", + "trace", + "bincode-transport", + "tarpc", + "plugins", +] diff --git a/README.md b/README.md index bea13b39..1fd2dd9e 100644 --- a/README.md +++ b/README.md @@ -41,265 +41,105 @@ tarpc = "0.12.0" tarpc-plugins = "0.4.0" ``` -## Example: Sync - -tarpc has two APIs: `sync` for blocking code and `future` for asynchronous -code. Here's how to use the sync api. - -```rust -#![feature(plugin, use_extern_macros, proc_macro_path_invoc)] -#![plugin(tarpc_plugins)] - -#[macro_use] -extern crate tarpc; - -use std::sync::mpsc; -use std::thread; -use tarpc::sync::{client, server}; -use tarpc::sync::client::ClientExt; -use tarpc::util::{FirstSocketAddr, Never}; - -service! 
{ - rpc hello(name: String) -> String; -} - -#[derive(Clone)] -struct HelloServer; - -impl SyncService for HelloServer { - fn hello(&self, name: String) -> Result { - Ok(format!("Hello, {}!", name)) - } -} - -fn main() { - let (tx, rx) = mpsc::channel(); - thread::spawn(move || { - let mut handle = HelloServer.listen("localhost:0", server::Options::default()) - .unwrap(); - tx.send(handle.addr()).unwrap(); - handle.run(); - }); - let client = SyncClient::connect(rx.recv().unwrap(), client::Options::default()).unwrap(); - println!("{}", client.hello("Mom".to_string()).unwrap()); -} -``` The `service!` macro expands to a collection of items that form an rpc service. In the above example, the macro is called within the -`hello_service` module. This module will contain `SyncClient`, `AsyncClient`, -and `FutureClient` types, and `SyncService` and `AsyncService` traits. There is -also a `ServiceExt` trait that provides starter `fn`s for services, with an -umbrella impl for all services. These generated types make it easy and -ergonomic to write servers without dealing with sockets or serialization +`hello_service` module. This module will contain a `Client` stub and `Service` trait. There is +These generated types make it easy and ergonomic to write servers without dealing with serialization directly. Simply implement one of the generated traits, and you're off to the -races! See the `tarpc_examples` package for more examples. +races! -## Example: Futures +## Example: -Here's the same service, implemented using futures. +Here's a small service. ```rust -#![feature(plugin, use_extern_macros, proc_macro_path_invoc)] +#![feature(plugin, futures_api, pin, arbitrary_self_types, await_macro, async_await)] #![plugin(tarpc_plugins)] -extern crate futures; -#[macro_use] -extern crate tarpc; -extern crate tokio_core; - -use futures::Future; -use tarpc::future::{client, server}; -use tarpc::future::client::ClientExt; -use tarpc::util::{FirstSocketAddr, Never}; -use tokio_core::reactor; - -service! { +use futures::{ + compat::TokioDefaultSpawner, + future::{self, Ready}, + prelude::*, + spawn, +}; +use tarpc::rpc::{ + client, context, + server::{self, Handler, Server}, +}; +use std::io; + +// This is the service definition. It looks a lot like a trait definition. +// It defines one RPC, hello, which takes one arg, name, and returns a String. +tarpc::service! { rpc hello(name: String) -> String; } +// This is the type that implements the generated Service trait. It is the business logic +// and is used to start the server. #[derive(Clone)] struct HelloServer; -impl FutureService for HelloServer { - type HelloFut = Result; - - fn hello(&self, name: String) -> Self::HelloFut { - Ok(format!("Hello, {}!", name)) - } -} - -fn main() { - let mut reactor = reactor::Core::new().unwrap(); - let (handle, server) = HelloServer.listen("localhost:10000".first_socket_addr(), - &reactor.handle(), - server::Options::default()) - .unwrap(); - reactor.handle().spawn(server); - let options = client::Options::default().handle(reactor.handle()); - reactor.run(FutureClient::connect(handle.addr(), options) - .map_err(tarpc::Error::from) - .and_then(|client| client.hello("Mom".to_string())) - .map(|resp| println!("{}", resp))) - .unwrap(); -} -``` - -## Example: Futures + TLS - -By default, tarpc internally uses a [`TcpStream`] for communication between your clients and -servers. However, TCP by itself has no encryption. As a result, your communication will be sent in -the clear. 
If you want your RPC communications to be encrypted, you can choose to use [TLS]. TLS -operates as an encryption layer on top of TCP. When using TLS, your communication will occur over a -[`TlsStream`]. You can add the ability to make TLS clients and servers by adding `tarpc` -with the `tls` feature flag enabled. - -When using TLS, some additional information is required. You will need to make [`TlsAcceptor`] and -`client::tls::Context` structs; `client::tls::Context` requires a [`TlsConnector`]. The -[`TlsAcceptor`] and [`TlsConnector`] types are defined in the [native-tls]. tarpc re-exports -external TLS-related types in its `native_tls` module (`tarpc::native_tls`). - -[TLS]: https://en.wikipedia.org/wiki/Transport_Layer_Security -[`TcpStream`]: https://docs.rs/tokio-core/0.1/tokio_core/net/struct.TcpStream.html -[`TlsStream`]: https://docs.rs/native-tls/0.1/native_tls/struct.TlsStream.html -[`TlsAcceptor`]: https://docs.rs/native-tls/0.1/native_tls/struct.TlsAcceptor.html -[`TlsConnector`]: https://docs.rs/native-tls/0.1/native_tls/struct.TlsConnector.html -[native-tls]: https://github.com/sfackler/rust-native-tls - -Both TLS streams and TCP streams are supported in the same binary when the `tls` feature is enabled. -However, if you are working with both stream types, ensure that you use the TLS clients with TLS -servers and TCP clients with TCP servers. - -```rust,no_run -#![feature(plugin, use_extern_macros, proc_macro_path_invoc)] -#![plugin(tarpc_plugins)] - -extern crate futures; -#[macro_use] -extern crate tarpc; -extern crate tokio_core; - -use futures::Future; -use tarpc::future::{client, server}; -use tarpc::future::client::ClientExt; -use tarpc::tls; -use tarpc::util::{FirstSocketAddr, Never}; -use tokio_core::reactor; -use tarpc::native_tls::{Pkcs12, TlsAcceptor}; - -service! { - rpc hello(name: String) -> String; -} - -#[derive(Clone)] -struct HelloServer; +impl Service for HelloServer { + // Each defined rpc generates two items in the trait, a fn that serves the RPC, and + // an associated type representing the future output by the fn. -impl FutureService for HelloServer { - type HelloFut = Result; + type HelloFut = Ready; - fn hello(&self, name: String) -> Self::HelloFut { - Ok(format!("Hello, {}!", name)) + fn hello(&self, _: context::Context, name: String) -> Self::HelloFut { + future::ready(format!("Hello, {}!", name)) } } -fn get_acceptor() -> TlsAcceptor { - let buf = include_bytes!("test/identity.p12"); - let pkcs12 = Pkcs12::from_der(buf, "password").unwrap(); - TlsAcceptor::builder(pkcs12).unwrap().build().unwrap() -} - -fn main() { - let mut reactor = reactor::Core::new().unwrap(); - let acceptor = get_acceptor(); - let (handle, server) = HelloServer.listen("localhost:10000".first_socket_addr(), - &reactor.handle(), - server::Options::default().tls(acceptor)).unwrap(); - reactor.handle().spawn(server); - let options = client::Options::default() - .handle(reactor.handle()) - .tls(tls::client::Context::new("foobar.com").unwrap()); - reactor.run(FutureClient::connect(handle.addr(), options) - .map_err(tarpc::Error::from) - .and_then(|client| client.hello("Mom".to_string())) - .map(|resp| println!("{}", resp))) - .unwrap(); -} -``` - -## Tips +async fn run() -> io::Result<()> { + // bincode_transport is provided by the associated crate bincode-transport. It makes it easy + // to start up a serde-powered bincode serialization strategy over TCP. 
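+    // (Any other transport that implements the rpc Transport trait could be used
+    // here instead; a bincode codec over TCP is just the one this example uses.)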
+ let transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; + let addr = transport.local_addr(); -### Sync vs Futures + // The server is configured with the defaults. + let server = Server::new(server::Config::default()) + // Server can listen on any type that implements the Transport trait. + .incoming(transport) + // Close the stream after the client connects + .take(1) + // serve is generated by the service! macro. It takes as input any type implementing + // the generated Service trait. + .respond_with(serve(HelloServer)); -A single `service!` invocation generates code for both synchronous and future-based applications. -It's up to the user whether they want to implement the sync API or the futures API. The sync API has -the simplest programming model, at the cost of some overhead - each RPC is handled in its own -thread. The futures API is based on tokio and can run on any tokio-compatible executor. This mean a -service that implements the futures API for a tarpc service can run on a single thread, avoiding -context switches and the memory overhead of having a thread per RPC. + spawn!(server).unwrap(); -### Errors + let transport = await!(bincode_transport::connect(&addr))?; -All generated tarpc RPC methods return either `tarpc::Result` or something like `Future`. The error type defaults to `tarpc::util::Never` (a wrapper for `!` which implements -`std::error::Error`) if no error type is explicitly specified in the `service!` macro invocation. An -error type can be specified like so: + // new_stub is generated by the service! macro. Like Server, it takes a config and any + // Transport as input, and returns a Client, also generated by the macro. + // by the service mcro. + let mut client = await!(new_stub(client::Config::default(), transport)); -```rust,ignore -use tarpc::util::Message; + // The client has an RPC method for each RPC defined in service!. It takes the same args + // as defined, with the addition of a Context, which is always the first arg. The Context + // specifies a deadline and trace information which can be helpful in debugging requests. + let hello = await!(client.hello(context::current(), "Stim".to_string()))?; -service! { - rpc hello(name: String) -> String | Message -} -``` + println!("{}", hello); -`tarpc::util::Message` is just a wrapper around string that implements `std::error::Error` provided -for service implementations that don't require complex error handling. The pipe is used as syntax -for specifying the error type in a way that's agnostic of whether the service implementation is -synchronous or future-based. Note that in the simpler examples in the readme, no pipe is used, and -the macro automatically chooses `tarpc::util::Never` as the error type. - -The above declaration would produce the following synchronous service trait: - -```rust,ignore -trait SyncService { - fn hello(&self, name: String) -> Result; + Ok(()) } -``` -and the following future-based trait: - -```rust,ignore -trait FutureService { - type HelloFut: IntoFuture; - - fn hello(&mut self, name: String) -> Self::HelloFut; +fn main() { + tokio::run(run() + .map_err(|e| eprintln!("Oh no: {}", e)) + .boxed() + .compat(TokioDefaultSpawner), + ); } ``` -## Documentation +## Service Documentation Use `cargo doc` as you normally would to see the documentation created for all items expanded by a `service!` invocation. -## Additional Features - -- Concurrent requests from a single client. -- Compatible with tokio services. 
-- Run any number of clients and services on a single event loop. -- Any type that `impl`s `serde`'s `Serialize` and `Deserialize` can be used in - rpc signatures. -- Attributes can be specified on rpc methods. These will be included on both the - services' trait methods as well as on the clients' stub methods. - -## Gaps/Potential Improvements (not necessarily actively being worked on) - -- Configurable server rate limiting. -- Automatic client retries with exponential backoff when server is busy. -- Load balancing -- Service discovery -- Automatically reconnect on the client side when the connection cuts out. -- Support generic serialization protocols. - ## Contributing To contribute to tarpc, please see [CONTRIBUTING](CONTRIBUTING.md). diff --git a/benches/latency.rs b/benches/latency.rs deleted file mode 100644 index b6de5674..00000000 --- a/benches/latency.rs +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2016 Google Inc. All Rights Reserved. -// -// Licensed under the MIT License, . -// This file may not be copied, modified, or distributed except according to those terms. - -#![feature(plugin, test, use_extern_macros, proc_macro_path_invoc)] -#![plugin(tarpc_plugins)] - -#[macro_use] -extern crate tarpc; -#[cfg(test)] -extern crate test; -extern crate env_logger; -extern crate futures; -extern crate tokio_core; - -use tarpc::future::{client, server}; -use tarpc::future::client::ClientExt; -use tarpc::util::{FirstSocketAddr, Never}; -#[cfg(test)] -use test::Bencher; -use tokio_core::reactor; - -service! { - rpc ack(); -} - -#[derive(Clone)] -struct Server; - -impl FutureService for Server { - type AckFut = futures::Finished<(), Never>; - fn ack(&self) -> Self::AckFut { - futures::finished(()) - } -} - -#[cfg(test)] -#[bench] -fn latency(bencher: &mut Bencher) { - let _ = env_logger::try_init(); - let mut reactor = reactor::Core::new().unwrap(); - let (handle, server) = Server - .listen( - "localhost:0".first_socket_addr(), - &reactor.handle(), - server::Options::default(), - ) - .unwrap(); - reactor.handle().spawn(server); - let client = FutureClient::connect( - handle.addr(), - client::Options::default().handle(reactor.handle()), - ); - let client = reactor.run(client).unwrap(); - - bencher.iter(|| reactor.run(client.ack()).unwrap()); -} diff --git a/bincode-transport/Cargo.toml b/bincode-transport/Cargo.toml new file mode 100644 index 00000000..b006d714 --- /dev/null +++ b/bincode-transport/Cargo.toml @@ -0,0 +1,35 @@ +cargo-features = ["rename-dependency"] + +[package] +name = "bincode-transport" +version = "0.1.0" +authors = ["Tim Kuehn "] +edition = '2018' + +[dependencies] +bincode = { version = "1.0", features = ["i128"] } +bytes = "0.4" +futures_legacy = { version = "0.1", package = "futures" } +pin-utils = "0.1.0-alpha.2" +rpc = { path = "../rpc", features = ["serde"] } +serde = "1.0" +tokio = "0.1" +tokio-io = "0.1" +tokio-serde-bincode = "0.1" +tokio-tcp = "0.1" +tokio-serde = "0.2" + +[target.'cfg(not(test))'.dependencies] +futures-preview = { git = "https://github.com/rust-lang-nursery/futures-rs", features = ["compat"] } + +[dev-dependencies] +futures-preview = { git = "https://github.com/rust-lang-nursery/futures-rs", features = ["compat", "tokio-compat"] } +env_logger = "0.5" +humantime = "1.0" +log = "0.4" +rand = "0.5" +tokio = "0.1" +tokio-executor = "0.1" +tokio-reactor = "0.1" +tokio-serde = "0.2" +tokio-timer = "0.2" diff --git a/bincode-transport/rustfmt.toml b/bincode-transport/rustfmt.toml new file mode 100644 index 00000000..0ef5137d --- /dev/null +++ 
b/bincode-transport/rustfmt.toml @@ -0,0 +1 @@ +edition = "Edition2018" diff --git a/bincode-transport/src/lib.rs b/bincode-transport/src/lib.rs new file mode 100644 index 00000000..8ada0a88 --- /dev/null +++ b/bincode-transport/src/lib.rs @@ -0,0 +1,285 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the MIT License, . +// This file may not be copied, modified, or distributed except according to those terms. + +//! A TCP [`Transport`] that serializes as bincode. + +#![feature( + futures_api, + pin, + arbitrary_self_types, + underscore_imports, + await_macro, + async_await, +)] +#![deny(missing_docs, missing_debug_implementations)] + +mod vendored; + +use bytes::{Bytes, BytesMut}; +use crate::vendored::tokio_serde_bincode::{IoErrorWrapper, ReadBincode, WriteBincode}; +use futures::{ + Poll, + compat::{Compat01As03, Future01CompatExt, Stream01CompatExt}, + prelude::*, + ready, task, +}; +use futures_legacy::{ + executor::{ + self as executor01, Notify as Notify01, NotifyHandle as NotifyHandle01, + UnsafeNotify as UnsafeNotify01, + }, + sink::SinkMapErr as SinkMapErr01, + sink::With as With01, + stream::MapErr as MapErr01, + Async as Async01, AsyncSink as AsyncSink01, Sink as Sink01, Stream as Stream01, +}; +use pin_utils::unsafe_pinned; +use serde::{Deserialize, Serialize}; +use std::{fmt, io, marker::PhantomData, net::SocketAddr, pin::Pin, task::LocalWaker}; +use tokio::codec::{Framed, LengthDelimitedCodec, length_delimited}; +use tokio_tcp::{self, TcpListener, TcpStream}; + +/// Returns a new bincode transport that reads from and writes to `io`. +pub fn new(io: TcpStream) -> Transport +where + Item: for<'de> Deserialize<'de>, + SinkItem: Serialize, +{ + let peer_addr = io.peer_addr(); + let local_addr = io.local_addr(); + let inner = length_delimited::Builder::new() + .max_frame_length(8_000_000) + .new_framed(io) + .map_err(IoErrorWrapper as _) + .sink_map_err(IoErrorWrapper as _) + .with(freeze as _); + let inner = WriteBincode::new(inner); + let inner = ReadBincode::new(inner); + + Transport { + inner, + staged_item: None, + peer_addr, + local_addr, + } +} + +fn freeze(bytes: BytesMut) -> Result { + Ok(bytes.freeze()) +} + +/// Connects to `addr`, wrapping the connection in a bincode transport. +pub async fn connect(addr: &SocketAddr) -> io::Result> +where + Item: for<'de> Deserialize<'de>, + SinkItem: Serialize, +{ + let stream = await!(TcpStream::connect(addr).compat())?; + Ok(new(stream)) +} + +/// Listens on `addr`, wrapping accepted connections in bincode transports. +pub fn listen(addr: &SocketAddr) -> io::Result> +where + Item: for<'de> Deserialize<'de>, + SinkItem: Serialize, +{ + let listener = TcpListener::bind(addr)?; + let local_addr = listener.local_addr()?; + let incoming = listener.incoming().compat(); + Ok(Incoming { + incoming, + local_addr, + ghost: PhantomData, + }) +} + +/// A [`TcpListener`] that wraps connections in bincode transports. +#[derive(Debug)] +pub struct Incoming { + incoming: Compat01As03, + local_addr: SocketAddr, + ghost: PhantomData<(Item, SinkItem)>, +} + +impl Incoming { + unsafe_pinned!(incoming: Compat01As03); + + /// Returns the address being listened on. 
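+    ///
+    /// A rough usage sketch (assuming `String` payloads in both directions):
+    ///
+    /// ```ignore
+    /// let incoming = bincode_transport::listen::<String, String>(&"0.0.0.0:0".parse().unwrap())?;
+    /// // The OS chose the port, so ask the listener where it actually bound.
+    /// println!("listening on {}", incoming.local_addr());
+    /// ```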
+ pub fn local_addr(&self) -> SocketAddr { + self.local_addr + } +} + +impl Stream for Incoming +where + Item: for<'a> Deserialize<'a>, + SinkItem: Serialize, +{ + type Item = io::Result>; + + fn poll_next(mut self: Pin<&mut Self>, waker: &LocalWaker) -> Poll> { + let next = ready!(self.incoming().poll_next(waker)?); + Poll::Ready(next.map(|conn| Ok(new(conn)))) + } +} + +/// A transport that serializes to, and deserializes from, a [`TcpStream`]. +pub struct Transport { + inner: ReadBincode< + WriteBincode< + With01< + SinkMapErr01< + MapErr01< + Framed, + fn(std::io::Error) -> IoErrorWrapper, + >, + fn(std::io::Error) -> IoErrorWrapper, + >, + BytesMut, + fn(BytesMut) -> Result, + Result + >, + SinkItem, + >, + Item, + >, + staged_item: Option, + peer_addr: io::Result, + local_addr: io::Result, +} + +impl fmt::Debug for Transport { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "Transport") + } +} + +impl Stream for Transport +where + Item: for<'a> Deserialize<'a>, +{ + type Item = io::Result; + + fn poll_next(self: Pin<&mut Self>, waker: &LocalWaker) -> Poll>> { + unsafe { + let inner = &mut Pin::get_mut_unchecked(self).inner; + let mut compat = inner.compat(); + let compat = Pin::new_unchecked(&mut compat); + match ready!(compat.poll_next(waker)) { + None => Poll::Ready(None), + Some(Ok(next)) => Poll::Ready(Some(Ok(next))), + Some(Err(e)) => Poll::Ready(Some(Err(e.0))), + } + } + } +} + +impl Sink for Transport +where + SinkItem: Serialize, +{ + type SinkItem = SinkItem; + type SinkError = io::Error; + + fn start_send(self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> { + let me = unsafe { Pin::get_mut_unchecked(self) }; + assert!(me.staged_item.is_none()); + me.staged_item = Some(item); + Ok(()) + } + + fn poll_ready(self: Pin<&mut Self>, waker: &LocalWaker) -> Poll> { + let notify = &WakerToHandle(waker); + + executor01::with_notify(notify, 0, move || { + let me = unsafe { Pin::get_mut_unchecked(self) }; + match me.staged_item.take() { + Some(staged_item) => match me.inner.start_send(staged_item)? { + AsyncSink01::Ready => Poll::Ready(Ok(())), + AsyncSink01::NotReady(item) => { + me.staged_item = Some(item); + Poll::Pending + } + }, + None => Poll::Ready(Ok(())), + } + }) + } + + fn poll_flush(self: Pin<&mut Self>, waker: &LocalWaker) -> Poll> { + let notify = &WakerToHandle(waker); + + executor01::with_notify(notify, 0, move || { + let me = unsafe { Pin::get_mut_unchecked(self) }; + match me.inner.poll_complete()? { + Async01::Ready(()) => Poll::Ready(Ok(())), + Async01::NotReady => Poll::Pending, + } + }) + } + + fn poll_close(self: Pin<&mut Self>, waker: &LocalWaker) -> Poll> { + let notify = &WakerToHandle(waker); + + executor01::with_notify(notify, 0, move || { + let me = unsafe { Pin::get_mut_unchecked(self) }; + match me.inner.get_mut().close()? { + Async01::Ready(()) => Poll::Ready(Ok(())), + Async01::NotReady => Poll::Pending, + } + }) + } +} + +impl rpc::Transport for Transport +where + Item: for<'de> Deserialize<'de>, + SinkItem: Serialize, +{ + type Item = Item; + type SinkItem = SinkItem; + + fn peer_addr(&self) -> io::Result { + // TODO: should just access from the inner transport. 
+ // https://github.com/alexcrichton/tokio-serde-bincode/issues/4 + Ok(*self.peer_addr.as_ref().unwrap()) + } + + fn local_addr(&self) -> io::Result { + Ok(*self.local_addr.as_ref().unwrap()) + } +} + +#[derive(Clone, Debug)] +struct WakerToHandle<'a>(&'a LocalWaker); + +#[derive(Debug)] +struct NotifyWaker(task::Waker); + +impl Notify01 for NotifyWaker { + fn notify(&self, _: usize) { + self.0.wake(); + } +} + +unsafe impl UnsafeNotify01 for NotifyWaker { + unsafe fn clone_raw(&self) -> NotifyHandle01 { + let ptr = Box::new(NotifyWaker(self.0.clone())); + + NotifyHandle01::new(Box::into_raw(ptr)) + } + + unsafe fn drop_raw(&self) { + let ptr: *const dyn UnsafeNotify01 = self; + drop(Box::from_raw(ptr as *mut dyn UnsafeNotify01)); + } +} + +impl<'a> From> for NotifyHandle01 { + fn from(handle: WakerToHandle<'a>) -> NotifyHandle01 { + unsafe { NotifyWaker(handle.0.clone().into_waker()).clone_raw() } + } +} diff --git a/bincode-transport/src/vendored/mod.rs b/bincode-transport/src/vendored/mod.rs new file mode 100644 index 00000000..2b5432b9 --- /dev/null +++ b/bincode-transport/src/vendored/mod.rs @@ -0,0 +1 @@ +pub(crate) mod tokio_serde_bincode; diff --git a/bincode-transport/src/vendored/tokio_serde_bincode.rs b/bincode-transport/src/vendored/tokio_serde_bincode.rs new file mode 100644 index 00000000..abdc63a5 --- /dev/null +++ b/bincode-transport/src/vendored/tokio_serde_bincode.rs @@ -0,0 +1,224 @@ +//! `Stream` and `Sink` adaptors for serializing and deserializing values using +//! Bincode. +//! +//! This crate provides adaptors for going from a stream or sink of buffers +//! ([`Bytes`]) to a stream or sink of values by performing Bincode encoding or +//! decoding. It is expected that each yielded buffer contains a single +//! serialized Bincode value. The specific strategy by which this is done is left +//! up to the user. One option is to use using [`length_delimited`] from +//! [tokio-io]. +//! +//! [`Bytes`]: https://docs.rs/bytes/0.4/bytes/struct.Bytes.html +//! [`length_delimited`]: http://alexcrichton.com/tokio-io/tokio_io/codec/length_delimited/index.html +//! [tokio-io]: http://github.com/alexcrichton/tokio-io +//! [examples]: https://github.com/carllerche/tokio-serde-json/tree/master/examples + +#![allow(missing_debug_implementations)] + +use bincode::Error; +use bytes::{Bytes, BytesMut}; +use futures_legacy::{Poll, Sink, StartSend, Stream}; +use serde::{Deserialize, Serialize}; +use std::io; +use tokio_serde::{Deserializer, FramedRead, FramedWrite, Serializer}; + +use std::marker::PhantomData; + +/// Adapts a stream of Bincode encoded buffers to a stream of values by +/// deserializing them. +/// +/// `ReadBincode` implements `Stream` by polling the inner buffer stream and +/// deserializing the buffer as Bincode. It expects that each yielded buffer +/// represents a single Bincode value and does not contain any extra trailing +/// bytes. +pub(crate) struct ReadBincode { + inner: FramedRead>, +} + +/// Adapts a buffer sink to a value sink by serializing the values as Bincode. +/// +/// `WriteBincode` implements `Sink` by serializing the submitted values to a +/// buffer. The buffer is then sent to the inner stream, which is responsible +/// for handling framing on the wire. +pub(crate) struct WriteBincode { + inner: FramedWrite>, +} + +struct Bincode { + ghost: PhantomData, +} + +impl ReadBincode +where + T: Stream, + U: for<'de> Deserialize<'de>, + Bytes: From, +{ + /// Creates a new `ReadBincode` with the given buffer stream. 
+ pub fn new(inner: T) -> ReadBincode { + let json = Bincode { ghost: PhantomData }; + ReadBincode { + inner: FramedRead::new(inner, json), + } + } +} + +impl ReadBincode { + /// Returns a mutable reference to the underlying stream wrapped by + /// `ReadBincode`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_mut(&mut self) -> &mut T { + self.inner.get_mut() + } +} + +impl Stream for ReadBincode +where + T: Stream, + U: for<'de> Deserialize<'de>, + Bytes: From, +{ + type Item = U; + type Error = ::Error; + + fn poll(&mut self) -> Poll, Self::Error> { + self.inner.poll() + } +} + +impl Sink for ReadBincode +where + T: Sink, +{ + type SinkItem = T::SinkItem; + type SinkError = T::SinkError; + + fn start_send(&mut self, item: T::SinkItem) -> StartSend { + self.get_mut().start_send(item) + } + + fn poll_complete(&mut self) -> Poll<(), T::SinkError> { + self.get_mut().poll_complete() + } + + fn close(&mut self) -> Poll<(), T::SinkError> { + self.get_mut().close() + } +} + +pub(crate) struct IoErrorWrapper(pub io::Error); +impl From> for IoErrorWrapper { + fn from(e: Box) -> Self { + IoErrorWrapper(match *e { + bincode::ErrorKind::Io(e) => e, + bincode::ErrorKind::InvalidUtf8Encoding(e) => { + io::Error::new(io::ErrorKind::InvalidInput, e) + } + bincode::ErrorKind::InvalidBoolEncoding(e) => { + io::Error::new(io::ErrorKind::InvalidInput, e.to_string()) + } + bincode::ErrorKind::InvalidTagEncoding(e) => { + io::Error::new(io::ErrorKind::InvalidInput, e.to_string()) + } + bincode::ErrorKind::InvalidCharEncoding => { + io::Error::new(io::ErrorKind::InvalidInput, "Invalid char encoding") + } + bincode::ErrorKind::DeserializeAnyNotSupported => { + io::Error::new(io::ErrorKind::InvalidInput, "Deserialize Any not supported") + } + bincode::ErrorKind::SizeLimit => { + io::Error::new(io::ErrorKind::InvalidInput, "Size limit exceeded") + } + bincode::ErrorKind::SequenceMustHaveLength => { + io::Error::new(io::ErrorKind::InvalidInput, "Sequence must have length") + } + bincode::ErrorKind::Custom(s) => io::Error::new(io::ErrorKind::Other, s), + }) + } +} + +impl From for io::Error { + fn from(wrapper: IoErrorWrapper) -> io::Error { + wrapper.0 + } +} + +impl WriteBincode +where + T: Sink, + U: Serialize, +{ + /// Creates a new `WriteBincode` with the given buffer sink. + pub fn new(inner: T) -> WriteBincode { + let json = Bincode { ghost: PhantomData }; + WriteBincode { + inner: FramedWrite::new(inner, json), + } + } +} + +impl WriteBincode { + /// Returns a mutable reference to the underlying sink wrapped by + /// `WriteBincode`. + /// + /// Note that care should be taken to not tamper with the underlying sink as + /// it may corrupt the sequence of frames otherwise being worked with. 
+ pub fn get_mut(&mut self) -> &mut T { + self.inner.get_mut() + } +} + +impl Sink for WriteBincode +where + T: Sink, + U: Serialize, +{ + type SinkItem = U; + type SinkError = ::SinkError; + + fn start_send(&mut self, item: U) -> StartSend { + self.inner.start_send(item) + } + + fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { + self.inner.poll_complete() + } + + fn close(&mut self) -> Poll<(), Self::SinkError> { + self.inner.poll_complete() + } +} + +impl Stream for WriteBincode +where + T: Stream + Sink, +{ + type Item = T::Item; + type Error = T::Error; + + fn poll(&mut self) -> Poll, T::Error> { + self.get_mut().poll() + } +} + +impl Deserializer for Bincode +where + T: for<'de> Deserialize<'de>, +{ + type Error = Error; + + fn deserialize(&mut self, src: &Bytes) -> Result { + bincode::deserialize(src) + } +} + +impl Serializer for Bincode { + type Error = Error; + + fn serialize(&mut self, item: &T) -> Result { + bincode::serialize(item).map(Into::into) + } +} diff --git a/bincode-transport/tests/bench.rs b/bincode-transport/tests/bench.rs new file mode 100644 index 00000000..c635e06a --- /dev/null +++ b/bincode-transport/tests/bench.rs @@ -0,0 +1,115 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the MIT License, . +// This file may not be copied, modified, or distributed except according to those terms. + +//! Tests client/server control flow. + +#![feature( + test, + integer_atomics, + futures_api, + generators, + await_macro, + async_await +)] + +extern crate test; + +use self::test::stats::Stats; +use futures::{compat::TokioDefaultSpawner, prelude::*}; +use rpc::{ + client::{self, Client}, + context, + server::{self, Handler, Server}, +}; +use std::{ + io, + time::{Duration, Instant}, +}; + +async fn bench() -> io::Result<()> { + let listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; + let addr = listener.local_addr(); + + tokio_executor::spawn( + Server::::new(server::Config::default()) + .incoming(listener) + .take(1) + .respond_with(|_ctx, request| futures::future::ready(Ok(request))) + .unit_error() + .boxed() + .compat() + ); + + let conn = await!(bincode_transport::connect(&addr))?; + let client = &mut await!(Client::::new(client::Config::default(), conn))?; + + let total = 10_000usize; + let mut successful = 0u32; + let mut unsuccessful = 0u32; + let mut durations = vec![]; + for _ in 1..=total { + let now = Instant::now(); + let response = await!(client.call(context::current(), 0u32)); + let elapsed = now.elapsed(); + + match response { + Ok(_) => successful += 1, + Err(_) => unsuccessful += 1, + }; + durations.push(elapsed); + } + + let durations_nanos = durations + .iter() + .map(|duration| duration.as_secs() as f64 * 1E9 + duration.subsec_nanos() as f64) + .collect::>(); + + let (lower, median, upper) = durations_nanos.quartiles(); + + println!("Of {} runs:", durations_nanos.len()); + println!("\tSuccessful: {}", successful); + println!("\tUnsuccessful: {}", unsuccessful); + println!( + "\tMean: {:?}", + Duration::from_nanos(durations_nanos.mean() as u64) + ); + println!("\tMedian: {:?}", Duration::from_nanos(median as u64)); + println!( + "\tStd Dev: {:?}", + Duration::from_nanos(durations_nanos.std_dev() as u64) + ); + println!( + "\tMin: {:?}", + Duration::from_nanos(durations_nanos.min() as u64) + ); + println!( + "\tMax: {:?}", + Duration::from_nanos(durations_nanos.max() as u64) + ); + println!( + "\tQuartiles: ({:?}, {:?}, {:?})", + Duration::from_nanos(lower as u64), + Duration::from_nanos(median as 
u64), + Duration::from_nanos(upper as u64) + ); + + Ok(()) +} + +#[test] +fn bench_small_packet() -> io::Result<()> { + env_logger::init(); + rpc::init(TokioDefaultSpawner); + + tokio::run( + bench() + .map_err(|e| panic!(e.to_string())) + .boxed() + .compat(), + ); + println!("done"); + + Ok(()) +} diff --git a/bincode-transport/tests/cancel.rs b/bincode-transport/tests/cancel.rs new file mode 100644 index 00000000..0400211a --- /dev/null +++ b/bincode-transport/tests/cancel.rs @@ -0,0 +1,151 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the MIT License, . +// This file may not be copied, modified, or distributed except according to those terms. + +//! Tests client/server control flow. + +#![feature(generators, await_macro, async_await, futures_api,)] + +use futures::{ + compat::{Future01CompatExt, TokioDefaultSpawner}, + prelude::*, + stream, +}; +use log::{info, trace}; +use rand::distributions::{Distribution, Normal}; +use rpc::{ + client::{self, Client}, + context, + server::{self, Server}, +}; +use std::{ + io, + time::{Duration, Instant, SystemTime}, +}; +use tokio::timer::Delay; + +pub trait AsDuration { + /// Delay of 0 if self is in the past + fn as_duration(&self) -> Duration; +} + +impl AsDuration for SystemTime { + fn as_duration(&self) -> Duration { + self.duration_since(SystemTime::now()).unwrap_or_default() + } +} + +async fn run() -> io::Result<()> { + let listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; + let addr = listener.local_addr(); + let server = Server::::new(server::Config::default()) + .incoming(listener) + .take(1) + .for_each(async move |channel| { + let channel = if let Ok(channel) = channel { + channel + } else { + return; + }; + let client_addr = *channel.client_addr(); + let handler = channel.respond_with(move |ctx, request| { + // Sleep for a time sampled from a normal distribution with: + // - mean: 1/2 the deadline. + // - std dev: 1/2 the deadline. 
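+                // With mean = std dev = deadline/2, roughly 16% of sampled delays
+                // (P(Z > 1) for a standard normal) land past the deadline, so a
+                // noticeable fraction of responses is expected to arrive too late.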
+ let deadline: Duration = ctx.deadline.as_duration(); + let deadline_millis = deadline.as_secs() * 1000 + deadline.subsec_millis() as u64; + let distribution = + Normal::new(deadline_millis as f64 / 2., deadline_millis as f64 / 2.); + let delay_millis = distribution.sample(&mut rand::thread_rng()).max(0.); + let delay = Duration::from_millis(delay_millis as u64); + + trace!( + "[{}/{}] Responding to request in {:?}.", + ctx.trace_id(), + client_addr, + delay, + ); + + let wait = Delay::new(Instant::now() + delay).compat(); + async move { + await!(wait).unwrap(); + Ok(request) + } + }); + tokio_executor::spawn(handler.unit_error().boxed().compat()); + }); + + tokio_executor::spawn(server.unit_error().boxed().compat()); + + let conn = await!(bincode_transport::connect(&addr))?; + let client = await!(Client::::new( + client::Config::default(), + conn + ))?; + + // Proxy service + let listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; + let addr = listener.local_addr(); + let proxy_server = Server::::new(server::Config::default()) + .incoming(listener) + .take(1) + .for_each(move |channel| { + let client = client.clone(); + async move { + let channel = if let Ok(channel) = channel { + channel + } else { + return; + }; + let client_addr = *channel.client_addr(); + let handler = channel.respond_with(move |ctx, request| { + trace!("[{}/{}] Proxying request.", ctx.trace_id(), client_addr); + let mut client = client.clone(); + async move { await!(client.call(ctx, request)) } + }); + tokio_executor::spawn(handler.unit_error().boxed().compat()); + } + }); + + tokio_executor::spawn(proxy_server.unit_error().boxed().compat()); + + let mut config = client::Config::default(); + config.max_in_flight_requests = 10; + config.pending_request_buffer = 10; + + let client = await!(Client::::new( + config, + await!(bincode_transport::connect(&addr))? + ))?; + + // Make 3 speculative requests, returning only the quickest. + let mut clients: Vec<_> = (1..=3u32).map(|_| client.clone()).collect(); + let mut requests = vec![]; + for client in &mut clients { + let mut ctx = context::current(); + ctx.deadline = SystemTime::now() + Duration::from_millis(200); + let trace_id = *ctx.trace_id(); + let response = client.call(ctx, "ping".into()); + requests.push(response.map(move |r| (trace_id, r))); + } + let (fastest_response, _) = await!(stream::futures_unordered(requests).into_future()); + let (trace_id, resp) = fastest_response.unwrap(); + info!("[{}] fastest_response = {:?}", trace_id, resp); + + Ok::<_, io::Error>(()) +} + +#[test] +fn cancel_slower() -> io::Result<()> { + env_logger::init(); + rpc::init(TokioDefaultSpawner); + + tokio::run( + run() + .boxed() + .map_err(|e| panic!(e)) + .compat(), + ); + Ok(()) +} diff --git a/bincode-transport/tests/pushback.rs b/bincode-transport/tests/pushback.rs new file mode 100644 index 00000000..5c403a6e --- /dev/null +++ b/bincode-transport/tests/pushback.rs @@ -0,0 +1,119 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the MIT License, . +// This file may not be copied, modified, or distributed except according to those terms. + +//! Tests client/server control flow. 
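+//!
+//! The client in this test is deliberately throttled; a sketch of the relevant
+//! configuration (values copied from the test body below):
+//!
+//! ```ignore
+//! let mut config = client::Config::default();
+//! config.max_in_flight_requests = 10;
+//! config.pending_request_buffer = 10;
+//! ```
+//!
+//! One hundred cloned clients then issue calls concurrently, so requests beyond
+//! the in-flight cap should exercise the client's pushback behavior.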
+ +#![feature(generators, await_macro, async_await, futures_api,)] + +use futures::{ + compat::{Future01CompatExt, TokioDefaultSpawner}, + prelude::*, +}; +use log::{error, info, trace}; +use rand::distributions::{Distribution, Normal}; +use rpc::{ + client::{self, Client}, + context, + server::{self, Server}, +}; +use std::{ + io, + time::{Duration, Instant, SystemTime}, +}; +use tokio::timer::Delay; + +pub trait AsDuration { + /// Delay of 0 if self is in the past + fn as_duration(&self) -> Duration; +} + +impl AsDuration for SystemTime { + fn as_duration(&self) -> Duration { + self.duration_since(SystemTime::now()).unwrap_or_default() + } +} + +async fn run() -> io::Result<()> { + let listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; + let addr = listener.local_addr(); + let server = Server::::new(server::Config::default()) + .incoming(listener) + .take(1) + .for_each(async move |channel| { + let channel = if let Ok(channel) = channel { + channel + } else { + return; + }; + let client_addr = *channel.client_addr(); + let handler = channel.respond_with(move |ctx, request| { + // Sleep for a time sampled from a normal distribution with: + // - mean: 1/2 the deadline. + // - std dev: 1/2 the deadline. + let deadline: Duration = ctx.deadline.as_duration(); + let deadline_millis = deadline.as_secs() * 1000 + deadline.subsec_millis() as u64; + let distribution = + Normal::new(deadline_millis as f64 / 2., deadline_millis as f64 / 2.); + let delay_millis = distribution.sample(&mut rand::thread_rng()).max(0.); + let delay = Duration::from_millis(delay_millis as u64); + + trace!( + "[{}/{}] Responding to request in {:?}.", + ctx.trace_id(), + client_addr, + delay, + ); + + let sleep = Delay::new(Instant::now() + delay).compat(); + async { + await!(sleep).unwrap(); + Ok(request) + } + }); + tokio_executor::spawn(handler.unit_error().boxed().compat()); + }); + + tokio_executor::spawn(server.unit_error().boxed().compat()); + + let mut config = client::Config::default(); + config.max_in_flight_requests = 10; + config.pending_request_buffer = 10; + + let conn = await!(bincode_transport::connect(&addr))?; + let client = await!(Client::::new(config, conn))?; + + let clients = (1..=100u32).map(|_| client.clone()).collect::>(); + for mut client in clients { + let ctx = context::current(); + tokio_executor::spawn( + async move { + let trace_id = *ctx.trace_id(); + let response = client.call(ctx, "ping".into()); + match await!(response) { + Ok(response) => info!("[{}] response: {}", trace_id, response), + Err(e) => error!("[{}] request error: {:?}: {}", trace_id, e.kind(), e), + } + }.unit_error().boxed().compat() + ); + } + + Ok(()) +} + +#[test] +fn ping_pong() -> io::Result<()> { + env_logger::init(); + rpc::init(TokioDefaultSpawner); + + tokio::run( + run() + .map_ok(|_| println!("done")) + .map_err(|e| panic!(e.to_string())) + .boxed() + .compat(), + ); + + Ok(()) +} diff --git a/example-service/Cargo.toml b/example-service/Cargo.toml new file mode 100644 index 00000000..211332c3 --- /dev/null +++ b/example-service/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "example-service" +version = "0.1.0" +authors = ["Tim Kuehn "] +edition = "2018" + +[dependencies] +bincode-transport = { path = "../bincode-transport" } +futures-preview = { git = "https://github.com/rust-lang-nursery/futures-rs", features = ["compat", "tokio-compat"] } +serde = { version = "1.0" } +tarpc = { path = "../tarpc", features = ["serde"] } +tarpc-plugins = { path = "../plugins" } +tokio = "0.1" +tokio-executor = 
"0.1" + +[lib] +name = "service" +path = "src/lib.rs" + +[[bin]] +name = "server" +path = "src/main.rs" diff --git a/example-service/src/lib.rs b/example-service/src/lib.rs new file mode 100644 index 00000000..9d351e85 --- /dev/null +++ b/example-service/src/lib.rs @@ -0,0 +1,15 @@ +#![feature( + futures_api, + pin, + arbitrary_self_types, + await_macro, + async_await, + proc_macro_hygiene, +)] + +// This is the service definition. It looks a lot like a trait definition. +// It defines one RPC, hello, which takes one arg, name, and returns a String. +tarpc::service! { + /// Returns a greeting for name. + rpc hello(name: String) -> String; +} diff --git a/example-service/src/main.rs b/example-service/src/main.rs new file mode 100644 index 00000000..e38a1f1d --- /dev/null +++ b/example-service/src/main.rs @@ -0,0 +1,79 @@ +#![feature( + futures_api, + pin, + arbitrary_self_types, + await_macro, + async_await, +)] + +use futures::{ + compat::TokioDefaultSpawner, + future::{self, Ready}, + prelude::*, +}; +use tarpc::{ + client, context, + server::{self, Handler, Server}, +}; +use std::io; + +// This is the type that implements the generated Service trait. It is the business logic +// and is used to start the server. +#[derive(Clone)] +struct HelloServer; + +impl service::Service for HelloServer { + // Each defined rpc generates two items in the trait, a fn that serves the RPC, and + // an associated type representing the future output by the fn. + + type HelloFut = Ready; + + fn hello(&self, _: context::Context, name: String) -> Self::HelloFut { + future::ready(format!("Hello, {}!", name)) + } +} + +async fn run() -> io::Result<()> { + // bincode_transport is provided by the associated crate bincode-transport. It makes it easy + // to start up a serde-powered bincode serialization strategy over TCP. + let transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; + let addr = transport.local_addr(); + + // The server is configured with the defaults. + let server = Server::new(server::Config::default()) + // Server can listen on any type that implements the Transport trait. + .incoming(transport) + // Close the stream after the client connects + .take(1) + // serve is generated by the service! macro. It takes as input any type implementing + // the generated Service trait. + .respond_with(service::serve(HelloServer)); + + tokio_executor::spawn(server.unit_error().boxed().compat()); + + let transport = await!(bincode_transport::connect(&addr))?; + + // new_stub is generated by the service! macro. Like Server, it takes a config and any + // Transport as input, and returns a Client, also generated by the macro. + // by the service mcro. + let mut client = await!(service::new_stub(client::Config::default(), transport))?; + + // The client has an RPC method for each RPC defined in service!. It takes the same args + // as defined, with the addition of a Context, which is always the first arg. The Context + // specifies a deadline and trace information which can be helpful in debugging requests. + let hello = await!(client.hello(context::current(), "Stim".to_string()))?; + + println!("{}", hello); + + Ok(()) +} + +fn main() { + tarpc::init(TokioDefaultSpawner); + + tokio::run(run() + .map_err(|e| eprintln!("Oh no: {}", e)) + .boxed() + .compat() + ); +} diff --git a/examples/concurrency.rs b/examples/concurrency.rs deleted file mode 100644 index 705da393..00000000 --- a/examples/concurrency.rs +++ /dev/null @@ -1,208 +0,0 @@ -// Copyright 2016 Google Inc. All Rights Reserved. 
-// -// Licensed under the MIT License, . -// This file may not be copied, modified, or distributed except according to those terms. - -#![feature(plugin, never_type)] -#![plugin(tarpc_plugins)] - -extern crate chrono; -extern crate clap; -extern crate env_logger; -extern crate futures; -#[macro_use] -extern crate log; -extern crate serde_bytes; -#[macro_use] -extern crate tarpc; -extern crate tokio_core; -extern crate futures_cpupool; - -use clap::{Arg, App}; -use futures::{Future, Stream}; -use futures_cpupool::{CpuFuture, CpuPool}; -use std::{cmp, thread}; -use std::sync::{Arc, mpsc}; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::time::{Duration, Instant}; -use tarpc::future::{client, server}; -use tarpc::future::client::ClientExt; -use tarpc::util::{FirstSocketAddr, Never}; -use tokio_core::reactor; - -service! { - rpc read(size: u32) -> serde_bytes::ByteBuf; -} - -#[derive(Clone)] -struct Server { - pool: CpuPool, - request_count: Arc, -} - -impl Server { - fn new() -> Self { - Server { - pool: CpuPool::new_num_cpus(), - request_count: Arc::new(AtomicUsize::new(1)), - } - } -} - -impl FutureService for Server { - type ReadFut = CpuFuture; - - fn read(&self, size: u32) -> Self::ReadFut { - let request_number = self.request_count.fetch_add(1, Ordering::SeqCst); - debug!("Server received read({}) no. {}", size, request_number); - self.pool.spawn(futures::lazy(move || { - let mut vec = Vec::with_capacity(size as usize); - for i in 0..size { - vec.push(((i % 2) << 8) as u8); - } - debug!("Server sending response no. {}", request_number); - Ok(vec.into()) - })) - } -} - -const CHUNK_SIZE: u32 = 1 << 10; - -trait Microseconds { - fn microseconds(&self) -> i64; -} - -impl Microseconds for Duration { - fn microseconds(&self) -> i64 { - chrono::Duration::from_std(*self) - .unwrap() - .num_microseconds() - .unwrap() - } -} - -#[derive(Default)] -struct Stats { - sum: Duration, - count: u64, - min: Option, - max: Option, -} - -/// Spawns a `reactor::Core` running forever on a new thread. 
-fn spawn_core() -> reactor::Remote { - let (tx, rx) = mpsc::channel(); - thread::spawn(move || { - let mut core = reactor::Core::new().unwrap(); - tx.send(core.handle().remote().clone()).unwrap(); - - // Run forever - core.run(futures::empty::<(), !>()).unwrap(); - }); - rx.recv().unwrap() -} - -fn run_once( - clients: Vec, - concurrency: u32, -) -> impl Future + 'static { - let start = Instant::now(); - futures::stream::futures_unordered( - (0..concurrency as usize) - .zip(clients.iter().enumerate().cycle()) - .map(|(iteration, (client_idx, client))| { - let start = Instant::now(); - debug!("Client {} reading (iteration {})...", client_idx, iteration); - client - .read(CHUNK_SIZE) - .map(move |_| (client_idx, iteration, start)) - }), - ).map(|(client_idx, iteration, start)| { - let elapsed = start.elapsed(); - debug!( - "Client {} received reply (iteration {}).", - client_idx, - iteration - ); - elapsed - }) - .map_err(|e| panic!(e)) - .fold(Stats::default(), move |mut stats, elapsed| { - stats.sum += elapsed; - stats.count += 1; - stats.min = Some(cmp::min(stats.min.unwrap_or(elapsed), elapsed)); - stats.max = Some(cmp::max(stats.max.unwrap_or(elapsed), elapsed)); - Ok(stats) - }) - .map(move |stats| { - info!( - "{} requests => Mean={}µs, Min={}µs, Max={}µs, Total={}µs", - stats.count, - stats.sum.microseconds() as f64 / stats.count as f64, - stats.min.unwrap().microseconds(), - stats.max.unwrap().microseconds(), - start.elapsed().microseconds() - ); - }) -} - -fn main() { - env_logger::init(); - let matches = App::new("Tarpc Concurrency") - .about( - "Demonstrates making concurrent requests to a tarpc service.", - ) - .arg( - Arg::with_name("concurrency") - .short("c") - .long("concurrency") - .value_name("LEVEL") - .help("Sets a custom concurrency level") - .takes_value(true), - ) - .arg( - Arg::with_name("clients") - .short("n") - .long("num_clients") - .value_name("AMOUNT") - .help("How many clients to distribute requests between") - .takes_value(true), - ) - .get_matches(); - let concurrency = matches - .value_of("concurrency") - .map(&str::parse) - .map(Result::unwrap) - .unwrap_or(10); - let num_clients = matches - .value_of("clients") - .map(&str::parse) - .map(Result::unwrap) - .unwrap_or(4); - - let mut reactor = reactor::Core::new().unwrap(); - let (handle, server) = Server::new() - .listen( - "localhost:0".first_socket_addr(), - &reactor.handle(), - server::Options::default(), - ) - .unwrap(); - reactor.handle().spawn(server); - info!("Server listening on {}.", handle.addr()); - - let clients = (0..num_clients) - // Spin up a couple threads to drive the clients. - .map(|i| (i, spawn_core())) - .map(|(i, remote)| { - info!("Client {} connecting...", i); - FutureClient::connect(handle.addr(), client::Options::default().remote(remote)) - .map_err(|e| panic!(e)) - }); - - let run = futures::collect(clients).and_then(|clients| run_once(clients, concurrency)); - - info!("Starting..."); - - reactor.run(run).unwrap(); -} diff --git a/examples/pubsub.rs b/examples/pubsub.rs deleted file mode 100644 index 848c8097..00000000 --- a/examples/pubsub.rs +++ /dev/null @@ -1,160 +0,0 @@ -// Copyright 2016 Google Inc. All Rights Reserved. -// -// Licensed under the MIT License, . -// This file may not be copied, modified, or distributed except according to those terms. 
- -#![feature(plugin)] -#![plugin(tarpc_plugins)] - -extern crate env_logger; -extern crate futures; -#[macro_use] -extern crate tarpc; -extern crate tokio_core; - -use futures::{Future, future}; -use publisher::FutureServiceExt as PublisherExt; -use std::cell::RefCell; -use std::collections::HashMap; -use std::net::SocketAddr; -use std::rc::Rc; -use std::thread; -use std::time::Duration; -use subscriber::FutureServiceExt as SubscriberExt; -use tarpc::future::{client, server}; -use tarpc::future::client::ClientExt; -use tarpc::util::{FirstSocketAddr, Message, Never}; -use tokio_core::reactor; - -pub mod subscriber { - service! { - rpc receive(message: String); - } -} - -pub mod publisher { - use std::net::SocketAddr; - use tarpc::util::Message; - - service! { - rpc broadcast(message: String); - rpc subscribe(id: u32, address: SocketAddr) | Message; - rpc unsubscribe(id: u32); - } -} - -#[derive(Clone, Debug)] -struct Subscriber { - id: u32, -} - -impl subscriber::FutureService for Subscriber { - type ReceiveFut = Result<(), Never>; - - fn receive(&self, message: String) -> Self::ReceiveFut { - println!("{} received message: {}", self.id, message); - Ok(()) - } -} - -impl Subscriber { - fn listen(id: u32, handle: &reactor::Handle, options: server::Options) -> server::Handle { - let (server_handle, server) = Subscriber { id: id } - .listen("localhost:0".first_socket_addr(), handle, options) - .unwrap(); - handle.spawn(server); - server_handle - } -} - -#[derive(Clone, Debug)] -struct Publisher { - clients: Rc>>, -} - -impl Publisher { - fn new() -> Publisher { - Publisher { - clients: Rc::new(RefCell::new(HashMap::new())), - } - } -} - -impl publisher::FutureService for Publisher { - type BroadcastFut = Box>; - - fn broadcast(&self, message: String) -> Self::BroadcastFut { - let acks = self.clients - .borrow() - .values() - .map(move |client| client.receive(message.clone()) - // Ignore failing subscribers. In a real pubsub, - // you'd want to continually retry until subscribers - // ack. - .then(|_| Ok(()))) - // Collect to a vec to end the borrow on `self.clients`. 
- .collect::>(); - Box::new(future::join_all(acks).map(|_| ())) - } - - type SubscribeFut = Box>; - - fn subscribe(&self, id: u32, address: SocketAddr) -> Self::SubscribeFut { - let clients = Rc::clone(&self.clients); - Box::new( - subscriber::FutureClient::connect(address, client::Options::default()) - .map(move |subscriber| { - println!("Subscribing {}.", id); - clients.borrow_mut().insert(id, subscriber); - () - }) - .map_err(|e| e.to_string().into()), - ) - } - - type UnsubscribeFut = Box>; - - fn unsubscribe(&self, id: u32) -> Self::UnsubscribeFut { - println!("Unsubscribing {}", id); - self.clients.borrow_mut().remove(&id).unwrap(); - Box::new(futures::finished(())) - } -} - -fn main() { - env_logger::init(); - let mut reactor = reactor::Core::new().unwrap(); - let (publisher_handle, server) = Publisher::new() - .listen( - "localhost:0".first_socket_addr(), - &reactor.handle(), - server::Options::default(), - ) - .unwrap(); - reactor.handle().spawn(server); - - let subscriber1 = Subscriber::listen(0, &reactor.handle(), server::Options::default()); - let subscriber2 = Subscriber::listen(1, &reactor.handle(), server::Options::default()); - - let publisher = reactor - .run(publisher::FutureClient::connect( - publisher_handle.addr(), - client::Options::default(), - )) - .unwrap(); - reactor - .run( - publisher - .subscribe(0, subscriber1.addr()) - .and_then(|_| publisher.subscribe(1, subscriber2.addr())) - .map_err(|e| panic!(e)) - .and_then(|_| { - println!("Broadcasting..."); - publisher.broadcast("hello to all".to_string()) - }) - .and_then(|_| publisher.unsubscribe(1)) - .and_then(|_| publisher.broadcast("hi again".to_string())), - ) - .unwrap(); - thread::sleep(Duration::from_millis(300)); -} diff --git a/examples/readme_errors.rs b/examples/readme_errors.rs deleted file mode 100644 index cf769750..00000000 --- a/examples/readme_errors.rs +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2016 Google Inc. All Rights Reserved. -// -// Licensed under the MIT License, . -// This file may not be copied, modified, or distributed except according to those terms. - -#![feature(plugin)] -#![plugin(tarpc_plugins)] - -#[macro_use] -extern crate tarpc; -#[macro_use] -extern crate serde_derive; - -use std::error::Error; -use std::fmt; -use std::sync::mpsc; -use std::thread; -use tarpc::sync::{client, server}; -use tarpc::sync::client::ClientExt; - -service! 
{ - rpc hello(name: String) -> String | NoNameGiven; -} - -#[derive(Debug, Deserialize, Serialize)] -pub struct NoNameGiven; - -impl fmt::Display for NoNameGiven { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.description()) - } -} - -impl Error for NoNameGiven { - fn description(&self) -> &str { - r#"The empty String, "", is not a valid argument to rpc `hello`."# - } -} - -#[derive(Clone)] -struct HelloServer; - -impl SyncService for HelloServer { - fn hello(&self, name: String) -> Result { - if name == "" { - Err(NoNameGiven) - } else { - Ok(format!("Hello, {}!", name)) - } - } -} - -fn main() { - let (tx, rx) = mpsc::channel(); - thread::spawn(move || { - let handle = HelloServer - .listen("localhost:10000", server::Options::default()) - .unwrap(); - tx.send(handle.addr()).unwrap(); - handle.run(); - }); - let client = SyncClient::connect(rx.recv().unwrap(), client::Options::default()).unwrap(); - println!("{}", client.hello("Mom".to_string()).unwrap()); - println!("{}", client.hello("".to_string()).unwrap_err()); -} diff --git a/examples/readme_futures.rs b/examples/readme_futures.rs deleted file mode 100644 index f206734e..00000000 --- a/examples/readme_futures.rs +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2016 Google Inc. All Rights Reserved. -// -// Licensed under the MIT License, . -// This file may not be copied, modified, or distributed except according to those terms. - -#![feature(plugin)] -#![plugin(tarpc_plugins)] - -extern crate futures; -#[macro_use] -extern crate tarpc; -extern crate tokio_core; - -use futures::Future; -use tarpc::future::{client, server}; -use tarpc::future::client::ClientExt; -use tarpc::util::{FirstSocketAddr, Never}; -use tokio_core::reactor; - -service! { - rpc hello(name: String) -> String; -} - -#[derive(Clone)] -struct HelloServer; - -impl FutureService for HelloServer { - type HelloFut = Result; - - fn hello(&self, name: String) -> Self::HelloFut { - Ok(format!("Hello, {}!", name)) - } -} - -fn main() { - let mut reactor = reactor::Core::new().unwrap(); - let (handle, server) = HelloServer - .listen( - "localhost:10000".first_socket_addr(), - &reactor.handle(), - server::Options::default(), - ) - .unwrap(); - reactor.handle().spawn(server); - - let options = client::Options::default().handle(reactor.handle()); - reactor - .run( - FutureClient::connect(handle.addr(), options) - .map_err(tarpc::Error::from) - .and_then(|client| client.hello("Mom".to_string())) - .map(|resp| println!("{}", resp)), - ) - .unwrap(); -} diff --git a/examples/readme_sync.rs b/examples/readme_sync.rs deleted file mode 100644 index 0652a1c8..00000000 --- a/examples/readme_sync.rs +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2016 Google Inc. All Rights Reserved. -// -// Licensed under the MIT License, . -// This file may not be copied, modified, or distributed except according to those terms. - -// required by `FutureClient` (not used directly in this example) -#![feature(plugin)] -#![plugin(tarpc_plugins)] - -#[macro_use] -extern crate tarpc; - -use std::sync::mpsc; -use std::thread; -use tarpc::sync::{client, server}; -use tarpc::sync::client::ClientExt; -use tarpc::util::Never; - -service! 
{ - rpc hello(name: String) -> String; -} - -#[derive(Clone)] -struct HelloServer; - -impl SyncService for HelloServer { - fn hello(&self, name: String) -> Result { - Ok(format!( - "Hello from thread {}, {}!", - thread::current().name().unwrap(), - name - )) - } -} - -fn main() { - let (tx, rx) = mpsc::channel(); - thread::spawn(move || { - let handle = HelloServer - .listen("localhost:0", server::Options::default()) - .unwrap(); - tx.send(handle.addr()).unwrap(); - handle.run(); - }); - let client = SyncClient::connect(rx.recv().unwrap(), client::Options::default()).unwrap(); - println!("{}", client.hello("Mom".to_string()).unwrap()); -} diff --git a/examples/server_calling_server.rs b/examples/server_calling_server.rs deleted file mode 100644 index f4d2cb51..00000000 --- a/examples/server_calling_server.rs +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright 2016 Google Inc. All Rights Reserved. -// -// Licensed under the MIT License, . -// This file may not be copied, modified, or distributed except according to those terms. - -#![feature(plugin)] -#![plugin(tarpc_plugins)] - -extern crate env_logger; -#[macro_use] -extern crate tarpc; -extern crate futures; -extern crate tokio_core; - -use add::{FutureService as AddFutureService, FutureServiceExt as AddExt}; -use double::{FutureService as DoubleFutureService, FutureServiceExt as DoubleExt}; -use futures::{Future, Stream}; -use tarpc::future::{client, server}; -use tarpc::future::client::ClientExt as Fc; -use tarpc::util::{FirstSocketAddr, Message, Never}; -use tokio_core::reactor; - -pub mod add { - service! { - /// Add two ints together. - rpc add(x: i32, y: i32) -> i32; - } -} - -pub mod double { - use tarpc::util::Message; - - service! { - /// 2 * x - rpc double(x: i32) -> i32 | Message; - } -} - -#[derive(Clone)] -struct AddServer; - -impl AddFutureService for AddServer { - type AddFut = Result; - - fn add(&self, x: i32, y: i32) -> Self::AddFut { - Ok(x + y) - } -} - -#[derive(Clone)] -struct DoubleServer { - client: add::FutureClient, -} - -impl DoubleServer { - fn new(client: add::FutureClient) -> Self { - DoubleServer { client: client } - } -} - -impl DoubleFutureService for DoubleServer { - type DoubleFut = Box>; - - fn double(&self, x: i32) -> Self::DoubleFut { - Box::new(self.client - .add(x, x) - .map_err(|e| e.to_string().into())) - } -} - -fn main() { - env_logger::init(); - let mut reactor = reactor::Core::new().unwrap(); - let (add, server) = AddServer - .listen( - "localhost:0".first_socket_addr(), - &reactor.handle(), - server::Options::default(), - ) - .unwrap(); - reactor.handle().spawn(server); - - let options = client::Options::default().handle(reactor.handle()); - let add_client = reactor - .run(add::FutureClient::connect(add.addr(), options)) - .unwrap(); - - let (double, server) = DoubleServer::new(add_client) - .listen( - "localhost:0".first_socket_addr(), - &reactor.handle(), - server::Options::default(), - ) - .unwrap(); - reactor.handle().spawn(server); - - let double_client = reactor - .run(double::FutureClient::connect( - double.addr(), - client::Options::default(), - )) - .unwrap(); - reactor - .run( - futures::stream::futures_unordered((0..5).map(|i| double_client.double(i))) - .map_err(|e| println!("{}", e)) - .for_each(|i| { - println!("{:?}", i); - Ok(()) - }), - ) - .unwrap(); -} diff --git a/examples/sync_server_calling_server.rs b/examples/sync_server_calling_server.rs deleted file mode 100644 index 73358523..00000000 --- a/examples/sync_server_calling_server.rs +++ /dev/null @@ -1,98 +0,0 @@ -// 
Copyright 2016 Google Inc. All Rights Reserved. -// -// Licensed under the MIT License, . -// This file may not be copied, modified, or distributed except according to those terms. - -#![feature(plugin)] -#![plugin(tarpc_plugins)] - -extern crate env_logger; -#[macro_use] -extern crate tarpc; - -use add::{SyncService as AddSyncService, SyncServiceExt as AddExt}; -use double::{SyncService as DoubleSyncService, SyncServiceExt as DoubleExt}; -use std::sync::mpsc; -use std::thread; -use tarpc::sync::{client, server}; -use tarpc::sync::client::ClientExt as Fc; -use tarpc::util::{FirstSocketAddr, Message, Never}; - -pub mod add { - service! { - /// Add two ints together. - rpc add(x: i32, y: i32) -> i32; - } -} - -pub mod double { - use tarpc::util::Message; - - service! { - /// 2 * x - rpc double(x: i32) -> i32 | Message; - } -} - -#[derive(Clone)] -struct AddServer; - -impl AddSyncService for AddServer { - fn add(&self, x: i32, y: i32) -> Result { - Ok(x + y) - } -} - -#[derive(Clone)] -struct DoubleServer { - client: add::SyncClient, -} - -impl DoubleServer { - fn new(client: add::SyncClient) -> Self { - DoubleServer { client: client } - } -} - -impl DoubleSyncService for DoubleServer { - fn double(&self, x: i32) -> Result { - self.client.add(x, x).map_err(|e| e.to_string().into()) - } -} - -fn main() { - env_logger::init(); - let (tx, rx) = mpsc::channel(); - thread::spawn(move || { - let handle = AddServer - .listen( - "localhost:0".first_socket_addr(), - server::Options::default(), - ) - .unwrap(); - tx.send(handle.addr()).unwrap(); - handle.run(); - }); - - - let add = rx.recv().unwrap(); - let (tx, rx) = mpsc::channel(); - thread::spawn(move || { - let add_client = add::SyncClient::connect(add, client::Options::default()).unwrap(); - let handle = DoubleServer::new(add_client) - .listen( - "localhost:0".first_socket_addr(), - server::Options::default(), - ) - .unwrap(); - tx.send(handle.addr()).unwrap(); - handle.run(); - }); - let double = rx.recv().unwrap(); - - let double_client = double::SyncClient::connect(double, client::Options::default()).unwrap(); - for i in 0..5 { - let doubled = double_client.double(i).unwrap(); - println!("{:?}", doubled); - } -} diff --git a/examples/throughput.rs b/examples/throughput.rs deleted file mode 100644 index 8b8d9ffc..00000000 --- a/examples/throughput.rs +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright 2016 Google Inc. All Rights Reserved. -// -// Licensed under the MIT License, . -// This file may not be copied, modified, or distributed except according to those terms. - -#![feature(plugin)] -#![plugin(tarpc_plugins)] - -#[macro_use] -extern crate lazy_static; -#[macro_use] -extern crate tarpc; -extern crate env_logger; -extern crate serde_bytes; -extern crate tokio_core; - -use std::io::{Read, Write, stdout}; -use std::net; -use std::sync::mpsc; -use std::thread; -use std::time; -use tarpc::future::server; -use tarpc::sync::client::{self, ClientExt}; -use tarpc::util::{FirstSocketAddr, Never}; -use tokio_core::reactor; - -lazy_static! { - static ref BUF: serde_bytes::ByteBuf = gen_vec(CHUNK_SIZE as usize).into(); -} - -fn gen_vec(size: usize) -> Vec { - let mut vec: Vec = Vec::with_capacity(size); - for i in 0..size { - vec.push(((i % 2) << 8) as u8); - } - vec -} - -service! 
{ - rpc read() -> serde_bytes::ByteBuf; -} - -#[derive(Clone)] -struct Server; - -impl FutureService for Server { - type ReadFut = Result; - - fn read(&self) -> Self::ReadFut { - Ok(BUF.clone()) - } -} - -const CHUNK_SIZE: u32 = 1 << 19; - -fn bench_tarpc(target: u64) { - let (tx, rx) = mpsc::channel(); - thread::spawn(move || { - let mut reactor = reactor::Core::new().unwrap(); - let (addr, server) = Server - .listen( - "localhost:0".first_socket_addr(), - &reactor.handle(), - server::Options::default(), - ) - .unwrap(); - tx.send(addr).unwrap(); - reactor.run(server).unwrap(); - }); - let client = - SyncClient::connect(rx.recv().unwrap().addr(), client::Options::default()).unwrap(); - let start = time::Instant::now(); - let mut nread = 0; - while nread < target { - nread += client.read().unwrap().len() as u64; - print!("."); - stdout().flush().unwrap(); - } - println!("done"); - let duration = time::Instant::now() - start; - println!( - "TARPC: {}MB/s", - (target as f64 / (1024f64 * 1024f64)) / - (duration.as_secs() as f64 + duration.subsec_nanos() as f64 / 10E9) - ); -} - -fn bench_tcp(target: u64) { - let l = net::TcpListener::bind("localhost:0").unwrap(); - let addr = l.local_addr().unwrap(); - thread::spawn(move || { - let (mut stream, _) = l.accept().unwrap(); - while let Ok(_) = stream.write_all(&*BUF) {} - }); - let mut stream = net::TcpStream::connect(&addr).unwrap(); - let mut buf = vec![0; CHUNK_SIZE as usize]; - let start = time::Instant::now(); - let mut nread = 0; - while nread < target { - stream.read_exact(&mut buf[..]).unwrap(); - nread += CHUNK_SIZE as u64; - print!("."); - stdout().flush().unwrap(); - } - println!("done"); - let duration = time::Instant::now() - start; - println!( - "TCP: {}MB/s", - (target as f64 / (1024f64 * 1024f64)) / - (duration.as_secs() as f64 + duration.subsec_nanos() as f64 / 10E9) - ); -} - -fn main() { - env_logger::init(); - let _ = *BUF; // To non-lazily initialize it. - bench_tcp(256 << 20); - bench_tarpc(256 << 20); -} diff --git a/examples/two_clients.rs b/examples/two_clients.rs deleted file mode 100644 index b2a8d951..00000000 --- a/examples/two_clients.rs +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2016 Google Inc. All Rights Reserved. -// -// Licensed under the MIT License, . -// This file may not be copied, modified, or distributed except according to those terms. - -#![feature(plugin)] -#![plugin(tarpc_plugins)] - -#[macro_use] -extern crate log; -#[macro_use] -extern crate tarpc; -extern crate env_logger; -extern crate tokio_core; - -use bar::FutureServiceExt as BarExt; -use baz::FutureServiceExt as BazExt; -use std::sync::mpsc; -use std::thread; -use tarpc::future::server; -use tarpc::sync::client; -use tarpc::sync::client::ClientExt; -use tarpc::util::{FirstSocketAddr, Never}; -use tokio_core::reactor; - -mod bar { - service! { - rpc bar(i: i32) -> i32; - } -} - -#[derive(Clone)] -struct Bar; -impl bar::FutureService for Bar { - type BarFut = Result; - - fn bar(&self, i: i32) -> Self::BarFut { - Ok(i) - } -} - -mod baz { - service! 
{ - rpc baz(s: String) -> String; - } -} - -#[derive(Clone)] -struct Baz; -impl baz::FutureService for Baz { - type BazFut = Result; - - fn baz(&self, s: String) -> Self::BazFut { - Ok(format!("Hello, {}!", s)) - } -} - -fn main() { - env_logger::init(); - let bar_client = { - let (tx, rx) = mpsc::channel(); - thread::spawn(move || { - let mut reactor = reactor::Core::new().unwrap(); - let (handle, server) = Bar.listen( - "localhost:0".first_socket_addr(), - &reactor.handle(), - server::Options::default(), - ).unwrap(); - tx.send(handle).unwrap(); - reactor.run(server).unwrap(); - }); - let handle = rx.recv().unwrap(); - bar::SyncClient::connect(handle.addr(), client::Options::default()).unwrap() - }; - - let baz_client = { - let (tx, rx) = mpsc::channel(); - thread::spawn(move || { - let mut reactor = reactor::Core::new().unwrap(); - let (handle, server) = Baz.listen( - "localhost:0".first_socket_addr(), - &reactor.handle(), - server::Options::default(), - ).unwrap(); - tx.send(handle).unwrap(); - reactor.run(server).unwrap(); - }); - let handle = rx.recv().unwrap(); - baz::SyncClient::connect(handle.addr(), client::Options::default()).unwrap() - }; - - - info!("Result: {:?}", bar_client.bar(17)); - - let total = 20; - for i in 1..(total + 1) { - if i % 2 == 0 { - info!("Result 1: {:?}", bar_client.bar(i)); - } else { - info!("Result 2: {:?}", baz_client.baz(i.to_string())); - } - } - - info!("Done."); -} diff --git a/hooks/pre-push b/hooks/pre-push index 3c2a8016..57d5e875 100755 --- a/hooks/pre-push +++ b/hooks/pre-push @@ -91,11 +91,8 @@ if [ "$?" == 0 ]; then try_run "Building ... " cargo build --color=always try_run "Testing ... " cargo test --color=always - try_run "Benching ... " cargo bench --color=always + try_run "Doc Test ... " cargo clean && cargo build --tests && rustdoc --test README.md --edition 2018 -L target/debug/deps -Z unstable-options - try_run "Building with tls ... " cargo build --color=always --features tls - try_run "Testing with tls ... " cargo test --color=always --features tls - try_run "Benching with tls ... 
" cargo bench --color=always --features tls fi exit $PREPUSH_RESULT diff --git a/src/plugins/Cargo.toml b/plugins/Cargo.toml similarity index 84% rename from src/plugins/Cargo.toml rename to plugins/Cargo.toml index e41de1bf..9abea390 100644 --- a/src/plugins/Cargo.toml +++ b/plugins/Cargo.toml @@ -16,6 +16,9 @@ travis-ci = { repository = "google/tarpc" } [dependencies] itertools = "0.7" +syn = { version = "0.15", features = ["full", "extra-traits"] } +quote = "0.6" +proc-macro2 = "0.4" [lib] -plugin = true +proc-macro = true diff --git a/plugins/rustfmt.toml b/plugins/rustfmt.toml new file mode 100644 index 00000000..0ef5137d --- /dev/null +++ b/plugins/rustfmt.toml @@ -0,0 +1 @@ +edition = "Edition2018" diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs new file mode 100644 index 00000000..adf47f96 --- /dev/null +++ b/plugins/src/lib.rs @@ -0,0 +1,85 @@ +extern crate proc_macro; +extern crate proc_macro2; +extern crate syn; +extern crate itertools; +extern crate quote; + +use proc_macro::TokenStream; + +use itertools::Itertools; +use quote::ToTokens; +use syn::{Ident, TraitItemType, TypePath, parse}; +use proc_macro2::Span; +use std::str::FromStr; + +#[proc_macro] +pub fn snake_to_camel(input: TokenStream) -> TokenStream { + let i = input.clone(); + let mut assoc_type = parse::(input).unwrap_or_else(|_| panic!("Could not parse trait item from:\n{}", i)); + + let old_ident = convert(&mut assoc_type.ident); + + for mut attr in &mut assoc_type.attrs { + if let Some(pair) = attr.path.segments.first() { + if pair.value().ident == "doc" { + attr.tts = proc_macro2::TokenStream::from_str(&attr.tts.to_string().replace("{}", &old_ident)).unwrap(); + } + } + } + + assoc_type.into_token_stream().into() +} + +#[proc_macro] +pub fn ty_snake_to_camel(input: TokenStream) -> TokenStream { + let mut path = parse::(input).unwrap(); + + // Only capitalize the final segment + convert(&mut path.path + .segments + .last_mut() + .unwrap() + .into_value() + .ident); + + path.into_token_stream().into() +} + +/// Converts an ident in-place to CamelCase and returns the previous ident. +fn convert(ident: &mut Ident) -> String { + let ident_str = ident.to_string(); + let mut camel_ty = String::new(); + + { + // Find the first non-underscore and add it capitalized. + let mut chars = ident_str.chars(); + + // Find the first non-underscore char, uppercase it, and append it. + // Guaranteed to succeed because all idents must have at least one non-underscore char. + camel_ty.extend(chars.find(|&c| c != '_').unwrap().to_uppercase()); + + // When we find an underscore, we remove it and capitalize the next char. To do this, + // we need to ensure the next char is not another underscore. + let mut chars = chars.coalesce(|c1, c2| { + if c1 == '_' && c2 == '_' { + Ok(c1) + } else { + Err((c1, c2)) + } + }); + + while let Some(c) = chars.next() { + if c != '_' { + camel_ty.push(c); + } else if let Some(c) = chars.next() { + camel_ty.extend(c.to_uppercase()); + } + } + } + + // The Fut suffix is hardcoded right now; this macro isn't really meant to be general-purpose. 
+ camel_ty.push_str("Fut"); + + *ident = Ident::new(&camel_ty, Span::call_site()); + ident_str +} diff --git a/rpc/Cargo.toml b/rpc/Cargo.toml new file mode 100644 index 00000000..2c9a4470 --- /dev/null +++ b/rpc/Cargo.toml @@ -0,0 +1,31 @@ +cargo-features = ["namespaced-features"] + +[package] +name = "rpc" +version = "0.1.0" +authors = ["Tim Kuehn "] +edition = '2018' +namespaced-features = true + +[features] +default = [] +serde = ["trace/serde", "crate:serde", "serde/derive"] + +[dependencies] +fnv = "1.0" +humantime = "1.0" +log = "0.4" +pin-utils = "0.1.0-alpha.2" +rand = "0.5" +tokio-timer = "0.2" +trace = { path = "../trace" } +serde = { optional = true, version = "1.0" } + +[target.'cfg(not(test))'.dependencies] +futures-preview = { git = "https://github.com/rust-lang-nursery/futures-rs", features = ["compat"] } + +[dev-dependencies] +futures-preview = { git = "https://github.com/rust-lang-nursery/futures-rs", features = ["compat", "tokio-compat"] } +futures-test-preview = { git = "https://github.com/rust-lang-nursery/futures-rs" } +env_logger = "0.5" +tokio = "0.1" diff --git a/rpc/rustfmt.toml b/rpc/rustfmt.toml new file mode 100644 index 00000000..0ef5137d --- /dev/null +++ b/rpc/rustfmt.toml @@ -0,0 +1 @@ +edition = "Edition2018" diff --git a/rpc/src/client/dispatch.rs b/rpc/src/client/dispatch.rs new file mode 100644 index 00000000..95e89370 --- /dev/null +++ b/rpc/src/client/dispatch.rs @@ -0,0 +1,708 @@ +use crate::{ + context, + util::{deadline_compat, AsDuration, Compact}, + ClientMessage, ClientMessageKind, Request, Response, Transport, +}; +use fnv::FnvHashMap; +use futures::{ + Poll, + channel::{mpsc, oneshot}, + prelude::*, + ready, + stream::Fuse, + task::LocalWaker, +}; +use humantime::format_rfc3339; +use log::{debug, error, info, trace}; +use pin_utils::unsafe_pinned; +use std::{ + io, + net::SocketAddr, + pin::Pin, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + time::Instant, +}; +use trace::SpanId; + +use super::Config; + +/// Handles communication from the client to request dispatch. +#[derive(Debug)] +pub(crate) struct Channel { + to_dispatch: mpsc::Sender>, + /// Channel to send a cancel message to the dispatcher. + cancellation: RequestCancellation, + /// The ID to use for the next request to stage. + next_request_id: Arc, + server_addr: SocketAddr, +} + +impl Clone for Channel { + fn clone(&self) -> Self { + Self { + to_dispatch: self.to_dispatch.clone(), + cancellation: self.cancellation.clone(), + next_request_id: self.next_request_id.clone(), + server_addr: self.server_addr, + } + } +} + +impl Channel { + /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that + /// resolves when the request is sent (not when the response is received). + pub(crate) async fn send( + &mut self, + mut ctx: context::Context, + request: Req, + ) -> io::Result> { + // Convert the context to the call context. 
+ ctx.trace_context.parent_id = Some(ctx.trace_context.span_id); + ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng()); + + let timeout = ctx.deadline.as_duration(); + let deadline = Instant::now() + timeout; + trace!( + "[{}/{}] Queuing request with deadline {} (timeout {:?}).", + ctx.trace_id(), + self.server_addr, + format_rfc3339(ctx.deadline), + timeout, + ); + + let (response_completion, response) = oneshot::channel(); + let cancellation = self.cancellation.clone(); + let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed); + await!(self.to_dispatch.send(DispatchRequest { + ctx, + request_id, + request, + response_completion, + })).map_err(|_| io::Error::from(io::ErrorKind::ConnectionReset))?; + Ok(DispatchResponse { + response: deadline_compat::Deadline::new(response, deadline), + complete: false, + request_id, + cancellation, + ctx, + server_addr: self.server_addr, + }) + } + + /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that + /// resolves to the response. + pub(crate) async fn call( + &mut self, + context: context::Context, + request: Req, + ) -> io::Result { + let response_future = await!(self.send(context, request))?; + await!(response_future) + } +} + +/// A server response that is completed by request dispatch when the corresponding response +/// arrives off the wire. +#[derive(Debug)] +pub struct DispatchResponse { + response: deadline_compat::Deadline>>, + ctx: context::Context, + complete: bool, + cancellation: RequestCancellation, + request_id: u64, + server_addr: SocketAddr, +} + +impl DispatchResponse { + unsafe_pinned!(server_addr: SocketAddr); + unsafe_pinned!(ctx: context::Context); +} + +impl Future for DispatchResponse { + type Output = io::Result; + + fn poll(mut self: Pin<&mut Self>, waker: &LocalWaker) -> Poll> { + let resp = ready!(self.response.poll_unpin(waker)); + + self.complete = true; + + Poll::Ready(match resp { + Ok(resp) => Ok(resp.message?), + Err(e) => Err({ + let trace_id = *self.ctx().trace_id(); + let server_addr = *self.server_addr(); + + if e.is_elapsed() { + io::Error::new( + io::ErrorKind::TimedOut, + "Client dropped expired request.".to_string(), + ) + } else if e.is_timer() { + let e = e.into_timer().unwrap(); + if e.is_at_capacity() { + io::Error::new( + io::ErrorKind::Other, + "Cancelling request because an expiration could not be set \ + due to the timer being at capacity." + .to_string(), + ) + } else if e.is_shutdown() { + panic!("[{}/{}] Timer was shutdown", trace_id, server_addr) + } else { + panic!( + "[{}/{}] Unrecognized timer error: {}", + trace_id, server_addr, e + ) + } + } else if e.is_inner() { + // The oneshot is Canceled when the dispatch task ends. + io::Error::from(io::ErrorKind::ConnectionReset) + } else { + panic!( + "[{}/{}] Unrecognized deadline error: {}", + trace_id, server_addr, e + ) + } + }), + }) + } +} + +// Cancels the request when dropped, if not already complete. +impl Drop for DispatchResponse { + fn drop(&mut self) { + if !self.complete { + // The receiver needs to be closed to handle the edge case that the request has not + // yet been received by the dispatch task. It is possible for the cancel message to + // arrive before the request itself, in which case the request could get stuck in the + // dispatch map forever if the server never responds (e.g. if the server dies while + // responding). Even if the server does respond, it will have unnecessarily done work + // for a client no longer waiting for a response. 
To avoid this, the dispatch task + // checks if the receiver is closed before inserting the request in the map. By + // closing the receiver before sending the cancel message, it is guaranteed that if the + // dispatch task misses an early-arriving cancellation message, then it will see the + // receiver as closed. + self.response.get_mut().close(); + self.cancellation.cancel(self.request_id); + } + } +} + +/// Spawns a dispatch task on the default executor that manages the lifecycle of requests initiated +/// by the returned [`Channel`]. +pub async fn spawn( + config: Config, + transport: C, + server_addr: SocketAddr, +) -> io::Result> +where + Req: Send, + Resp: Send, + C: Transport, SinkItem = ClientMessage> + Send, +{ + let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer); + let (cancellation, canceled_requests) = cancellations(); + + crate::spawn( + RequestDispatch { + config, + server_addr, + canceled_requests, + transport: transport.fuse(), + in_flight_requests: FnvHashMap::default(), + pending_requests: pending_requests.fuse(), + }.unwrap_or_else(move |e| error!("[{}] Connection broken: {}", server_addr, e)) + ).map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + format!( + "Could not spawn client dispatch task. Is shutdown: {}", + e.is_shutdown() + ), + ) + })?; + + Ok(Channel { + to_dispatch, + cancellation, + server_addr, + next_request_id: Arc::new(AtomicU64::new(0)), + }) +} + +/// Handles the lifecycle of requests, writing requests to the wire, managing cancellations, +/// and dispatching responses to the appropriate channel. +struct RequestDispatch { + /// Writes requests to the wire and reads responses off the wire. + transport: Fuse, + /// Requests waiting to be written to the wire. + pending_requests: Fuse>>, + /// Requests that were dropped. + canceled_requests: CanceledRequests, + /// Requests already written to the wire that haven't yet received responses. + in_flight_requests: FnvHashMap>, + /// Configures limits to prevent unlimited resource usage. + config: Config, + /// The address of the server connected to. + server_addr: SocketAddr, +} + +impl RequestDispatch +where + Req: Send, + Resp: Send, + C: Transport, SinkItem = ClientMessage>, +{ + unsafe_pinned!(server_addr: SocketAddr); + unsafe_pinned!(in_flight_requests: FnvHashMap>); + unsafe_pinned!(canceled_requests: CanceledRequests); + unsafe_pinned!(pending_requests: Fuse>>); + unsafe_pinned!(transport: Fuse); + + fn pump_read(self: &mut Pin<&mut Self>, waker: &LocalWaker) -> Poll>> { + Poll::Ready(match ready!(self.transport().poll_next(waker)?) { + Some(response) => { + self.complete(response); + Some(Ok(())) + } + None => { + trace!("[{}] read half closed", self.server_addr()); + None + } + }) + } + + fn pump_write(self: &mut Pin<&mut Self>, waker: &LocalWaker) -> Poll>> { + enum ReceiverStatus { + NotReady, + Closed, + } + + let pending_requests_status = match self.poll_next_request(waker)? { + Poll::Ready(Some(dispatch_request)) => { + self.write_request(dispatch_request)?; + return Poll::Ready(Some(Ok(()))); + } + Poll::Ready(None) => ReceiverStatus::Closed, + Poll::Pending => ReceiverStatus::NotReady, + }; + + let canceled_requests_status = match self.poll_next_cancellation(waker)? 
{ + Poll::Ready(Some((context, request_id))) => { + self.write_cancel(context, request_id)?; + return Poll::Ready(Some(Ok(()))); + } + Poll::Ready(None) => ReceiverStatus::Closed, + Poll::Pending => ReceiverStatus::NotReady, + }; + + match (pending_requests_status, canceled_requests_status) { + (ReceiverStatus::Closed, ReceiverStatus::Closed) => { + ready!(self.transport().poll_flush(waker)?); + Poll::Ready(None) + } + (ReceiverStatus::NotReady, _) | (_, ReceiverStatus::NotReady) => { + // No more messages to process, so flush any messages buffered in the transport. + ready!(self.transport().poll_flush(waker)?); + + // Even if we fully-flush, we return Pending, because we have no more requests + // or cancellations right now. + Poll::Pending + } + } + } + + /// Yields the next pending request, if one is ready to be sent. + fn poll_next_request( + self: &mut Pin<&mut Self>, + waker: &LocalWaker, + ) -> Poll>>> { + if self.in_flight_requests().len() >= self.config.max_in_flight_requests { + info!( + "At in-flight request capacity ({}/{}).", + self.in_flight_requests().len(), + self.config.max_in_flight_requests + ); + + // No need to schedule a wakeup, because timers and responses are responsible + // for clearing out in-flight requests. + return Poll::Pending; + } + + while let Poll::Pending = self.transport().poll_ready(waker)? { + // We can't yield a request-to-be-sent before the transport is capable of buffering it. + ready!(self.transport().poll_flush(waker)?); + } + + loop { + match ready!(self.pending_requests().poll_next_unpin(waker)) { + Some(request) => { + if request.response_completion.is_canceled() { + trace!( + "[{}] Request canceled before being sent.", + request.ctx.trace_id() + ); + continue; + } + + return Poll::Ready(Some(Ok(request))); + } + None => { + trace!("[{}] pending_requests closed", self.server_addr()); + return Poll::Ready(None); + } + } + } + } + + /// Yields the next pending cancellation, and, if one is ready, cancels the associated request. + fn poll_next_cancellation( + self: &mut Pin<&mut Self>, + waker: &LocalWaker, + ) -> Poll>> { + while let Poll::Pending = self.transport().poll_ready(waker)? 
{ + ready!(self.transport().poll_flush(waker)?); + } + + loop { + match ready!(self.canceled_requests().poll_next_unpin(waker)) { + Some(request_id) => { + if let Some(in_flight_data) = self.in_flight_requests().remove(&request_id) { + self.in_flight_requests().compact(0.1); + + debug!( + "[{}/{}] Removed request.", + in_flight_data.ctx.trace_id(), + self.server_addr() + ); + + return Poll::Ready(Some(Ok((in_flight_data.ctx, request_id)))); + } + } + None => { + trace!("[{}] canceled_requests closed.", self.server_addr()); + return Poll::Ready(None); + } + } + } + } + + fn write_request( + self: &mut Pin<&mut Self>, + dispatch_request: DispatchRequest, + ) -> io::Result<()> { + let request_id = dispatch_request.request_id; + let request = ClientMessage { + trace_context: dispatch_request.ctx.trace_context, + message: ClientMessageKind::Request(Request { + id: request_id, + message: dispatch_request.request, + deadline: dispatch_request.ctx.deadline, + }), + }; + self.transport().start_send(request)?; + self.in_flight_requests().insert( + request_id, + InFlightData { + ctx: dispatch_request.ctx, + response_completion: dispatch_request.response_completion, + }, + ); + Ok(()) + } + + fn write_cancel( + self: &mut Pin<&mut Self>, + context: context::Context, + request_id: u64, + ) -> io::Result<()> { + let trace_id = *context.trace_id(); + let cancel = ClientMessage { + trace_context: context.trace_context, + message: ClientMessageKind::Cancel { request_id }, + }; + self.transport().start_send(cancel)?; + trace!("[{}/{}] Cancel message sent.", trace_id, self.server_addr()); + return Ok(()); + } + + /// Sends a server response to the client task that initiated the associated request. + fn complete(self: &mut Pin<&mut Self>, response: Response) -> bool { + if let Some(in_flight_data) = self.in_flight_requests().remove(&response.request_id) { + self.in_flight_requests().compact(0.1); + + trace!( + "[{}/{}] Received response.", + in_flight_data.ctx.trace_id(), + self.server_addr() + ); + let _ = in_flight_data.response_completion.send(response); + return true; + } + + debug!( + "[{}] No in-flight request found for request_id = {}.", + self.server_addr(), + response.request_id + ); + + // If the response completion was absent, then the request was already canceled. + false + } +} + +impl Future for RequestDispatch +where + Req: Send, + Resp: Send, + C: Transport, SinkItem = ClientMessage>, +{ + type Output = io::Result<()>; + + fn poll(mut self: Pin<&mut Self>, waker: &LocalWaker) -> Poll> { + trace!("[{}] RequestDispatch::poll", self.server_addr()); + loop { + match (self.pump_read(waker)?, self.pump_write(waker)?) { + (read, write @ Poll::Ready(None)) => { + if self.in_flight_requests().is_empty() { + info!( + "[{}] Shutdown: write half closed, and no requests in flight.", + self.server_addr() + ); + return Poll::Ready(Ok(())); + } + match read { + Poll::Ready(Some(())) => continue, + _ => { + trace!( + "[{}] read: {:?}, write: {:?}, (not ready)", + self.server_addr(), + read, + write, + ); + return Poll::Pending; + } + } + } + (read @ Poll::Ready(Some(())), write) | (read, write @ Poll::Ready(Some(()))) => { + trace!( + "[{}] read: {:?}, write: {:?}", + self.server_addr(), + read, + write, + ) + } + (read, write) => { + trace!( + "[{}] read: {:?}, write: {:?} (not ready)", + self.server_addr(), + read, + write, + ); + return Poll::Pending; + } + } + } + } +} + +/// A server-bound request sent from a [`Channel`] to request dispatch, which will then manage +/// the lifecycle of the request. 
+#[derive(Debug)] +struct DispatchRequest { + ctx: context::Context, + request_id: u64, + request: Req, + response_completion: oneshot::Sender>, +} + +struct InFlightData { + ctx: context::Context, + response_completion: oneshot::Sender>, +} + +/// Sends request cancellation signals. +#[derive(Debug, Clone)] +struct RequestCancellation(mpsc::UnboundedSender); + +/// A stream of IDs of requests that have been canceled. +#[derive(Debug)] +struct CanceledRequests(mpsc::UnboundedReceiver); + +/// Returns a channel to send request cancellation messages. +fn cancellations() -> (RequestCancellation, CanceledRequests) { + // Unbounded because messages are sent in the drop fn. This is fine, because it's still + // bounded by the number of in-flight requests. Additionally, each request has a clone + // of the sender, so the bounded channel would have the same behavior, + // since it guarantees a slot. + let (tx, rx) = mpsc::unbounded(); + (RequestCancellation(tx), CanceledRequests(rx)) +} + +impl RequestCancellation { + /// Cancels the request with ID `request_id`. + fn cancel(&mut self, request_id: u64) { + let _ = self.0.unbounded_send(request_id); + } +} + +impl Stream for CanceledRequests { + type Item = u64; + + fn poll_next(mut self: Pin<&mut Self>, waker: &LocalWaker) -> Poll> { + self.0.poll_next_unpin(waker) + } +} + +#[cfg(test)] +mod tests { + use super::{CanceledRequests, Channel, RequestCancellation, RequestDispatch}; + use crate::{ + client::Config, + context, + transport::{self, channel::UnboundedChannel}, + ClientMessage, Response, + }; + use fnv::FnvHashMap; + use futures::{Poll, channel::mpsc, prelude::*}; + use futures_test::task::{noop_local_waker_ref}; + use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + pin::Pin, + sync::atomic::AtomicU64, + sync::Arc, + }; + + #[test] + fn stage_request() { + let (mut dispatch, mut channel, _server_channel) = set_up(); + + // Test that a request future dropped before it's processed by dispatch will cause the request + // to not be added to the in-flight request map. + let _resp = tokio::runtime::current_thread::block_on_all( + channel + .send(context::current(), "hi".to_string()) + .boxed() + .compat(), + ); + + let mut dispatch = Pin::new(&mut dispatch); + let waker = &noop_local_waker_ref(); + + let req = dispatch.poll_next_request(waker).ready(); + assert!(req.is_some()); + + let req = req.unwrap(); + assert_eq!(req.request_id, 0); + assert_eq!(req.request, "hi".to_string()); + } + + #[test] + fn stage_request_response_future_dropped() { + let (mut dispatch, mut channel, _server_channel) = set_up(); + + // Test that a request future dropped before it's processed by dispatch will cause the request + // to not be added to the in-flight request map. + let resp = tokio::runtime::current_thread::block_on_all( + channel + .send(context::current(), "hi".into()) + .boxed() + .compat(), + ).unwrap(); + drop(resp); + drop(channel); + + let mut dispatch = Pin::new(&mut dispatch); + let waker = &noop_local_waker_ref(); + + dispatch.poll_next_cancellation(waker).unwrap(); + assert!(dispatch.poll_next_request(waker).ready().is_none()); + } + + #[test] + fn stage_request_response_future_closed() { + let (mut dispatch, mut channel, _server_channel) = set_up(); + + // Test that a request future that's closed its receiver but not yet canceled its request -- + // i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request + // map. 
+ let resp = tokio::runtime::current_thread::block_on_all( + channel + .send(context::current(), "hi".into()) + .boxed() + .compat(), + ).unwrap(); + drop(resp); + drop(channel); + + let mut dispatch = Pin::new(&mut dispatch); + let waker = &noop_local_waker_ref(); + assert!(dispatch.poll_next_request(waker).ready().is_none()); + } + + fn set_up() -> ( + RequestDispatch, ClientMessage>>, + Channel, + UnboundedChannel, Response>, + ) { + let _ = env_logger::try_init(); + + let (to_dispatch, pending_requests) = mpsc::channel(1); + let (cancel_tx, canceled_requests) = mpsc::unbounded(); + let (client_channel, server_channel) = transport::channel::unbounded(); + + let dispatch = RequestDispatch:: { + transport: client_channel.fuse(), + pending_requests: pending_requests.fuse(), + canceled_requests: CanceledRequests(canceled_requests), + in_flight_requests: FnvHashMap::default(), + config: Config::default(), + server_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + }; + + let cancellation = RequestCancellation(cancel_tx); + let channel = Channel { + to_dispatch, + cancellation, + next_request_id: Arc::new(AtomicU64::new(0)), + server_addr: SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0), + }; + + (dispatch, channel, server_channel) + } + + trait PollTest { + type T; + fn unwrap(self) -> Poll; + fn ready(self) -> Self::T; + } + + impl PollTest for Poll>> + where + E: ::std::fmt::Display + Send + 'static, + { + type T = Option; + + fn unwrap(self) -> Poll> { + match self { + Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(t)), + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(Err(e))) => panic!(e.to_string()), + Poll::Pending => Poll::Pending, + } + } + + fn ready(self) -> Option { + match self { + Poll::Ready(Some(Ok(t))) => Some(t), + Poll::Ready(None) => None, + Poll::Ready(Some(Err(e))) => panic!(e.to_string()), + Poll::Pending => panic!("Pending"), + } + } + } + +} diff --git a/rpc/src/client/mod.rs b/rpc/src/client/mod.rs new file mode 100644 index 00000000..88832daf --- /dev/null +++ b/rpc/src/client/mod.rs @@ -0,0 +1,85 @@ +//! Provides a client that connects to a server and sends multiplexed requests. + +use crate::{context::Context, ClientMessage, Response, Transport}; +use log::warn; +use std::{ + io, + net::{Ipv4Addr, SocketAddr}, +}; + +mod dispatch; + +/// Sends multiplexed requests to, and receives responses from, a server. +#[derive(Debug)] +pub struct Client { + /// Channel to send requests to the dispatch task. + channel: dispatch::Channel, +} + +impl Clone for Client { + fn clone(&self) -> Self { + Client { + channel: self.channel.clone(), + } + } +} + +/// Settings that control the behavior of the client. +#[non_exhaustive] +#[derive(Clone, Debug)] +pub struct Config { + /// The number of requests that can be in flight at once. + /// `max_in_flight_requests` controls the size of the map used by the client + /// for storing pending requests. + pub max_in_flight_requests: usize, + /// The number of requests that can be buffered client-side before being sent. + /// `pending_requests_buffer` controls the size of the channel clients use + /// to communicate with the request dispatch task. + pub pending_request_buffer: usize, +} + +impl Default for Config { + fn default() -> Self { + Config { + max_in_flight_requests: 1_000, + pending_request_buffer: 100, + } + } +} + +impl Client +where + Req: Send, + Resp: Send, +{ + /// Creates a new Client by wrapping a [`Transport`] and spawning a dispatch task + /// that manages the lifecycle of requests. 
+ /// + /// Must only be called from on an executor. + pub async fn new(config: Config, transport: T) -> io::Result + where + T: Transport, SinkItem = ClientMessage> + Send, + { + let server_addr = transport.peer_addr().unwrap_or_else(|e| { + warn!( + "Setting peer to unspecified because peer could not be determined: {}", + e + ); + SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0) + }); + + Ok(Client { + channel: await!(dispatch::spawn(config, transport, server_addr))?, + }) + } + + /// Initiates a request, sending it to the dispatch task. + /// + /// Returns a [`Future`] that resolves to this client and the future response + /// once the request is successfully enqueued. + /// + /// [`Future`]: futures::Future + pub async fn call(&mut self, ctx: Context, request: Req) -> io::Result { + await!(self.channel.call(ctx, request)) + } +} diff --git a/rpc/src/context.rs b/rpc/src/context.rs new file mode 100644 index 00000000..82abfc2b --- /dev/null +++ b/rpc/src/context.rs @@ -0,0 +1,44 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the MIT License, . +// This file may not be copied, modified, or distributed except according to those terms. + +//! Provides a request context that carries a deadline and trace context. This context is sent from +//! client to server and is used by the server to enforce response deadlines. + +use std::time::{Duration, SystemTime}; +use trace::{self, TraceId}; + +/// A request context that carries request-scoped information like deadlines and trace information. +/// It is sent from client to server and is used by the server to enforce response deadlines. +/// +/// The context should not be stored directly in a server implementation, because the context will +/// be different for each request in scope. +#[derive(Clone, Copy, Debug)] +#[non_exhaustive] +pub struct Context { + /// When the client expects the request to be complete by. The server should cancel the request + /// if it is not complete by this time. + pub deadline: SystemTime, + /// Uniquely identifies requests originating from the same source. + /// When a service handles a request by making requests itself, those requests should + /// include the same `trace_id` as that included on the original request. This way, + /// users can trace related actions across a distributed system. + pub trace_context: trace::Context, +} + +/// Returns the context for the current request, or a default Context if no request is active. +// TODO: populate Context with request-scoped data, with default fallbacks. +pub fn current() -> Context { + Context { + deadline: SystemTime::now() + Duration::from_secs(10), + trace_context: trace::Context::new_root(), + } +} + +impl Context { + /// Returns the ID of the request-scoped trace. + pub fn trace_id(&self) -> &TraceId { + &self.trace_context.trace_id + } +} diff --git a/rpc/src/lib.rs b/rpc/src/lib.rs new file mode 100644 index 00000000..2c3b3305 --- /dev/null +++ b/rpc/src/lib.rs @@ -0,0 +1,214 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the MIT License, . +// This file may not be copied, modified, or distributed except according to those terms. + +#![feature( + const_fn, + non_exhaustive, + integer_atomics, + try_trait, + nll, + futures_api, + pin, + arbitrary_self_types, + await_macro, + async_await, + generators, + optin_builtin_traits, + generator_trait, + gen_future, + decl_macro, +)] +#![deny(missing_docs, missing_debug_implementations)] + +//! An RPC framework providing client and server. +//! +//! 
Features: +//! * RPC deadlines, both client- and server-side. +//! * Cascading cancellation (works with multiple hops). +//! * Configurable limits +//! * In-flight requests, both client and server-side. +//! * Server-side limit is per-connection. +//! * When the server reaches the in-flight request maximum, it returns a throttled error +//! to the client. +//! * When the client reaches the in-flight request max, messages are buffered up to a +//! configurable maximum, beyond which the requests are back-pressured. +//! * Server connections. +//! * Total and per-IP limits. +//! * When an incoming connection is accepted, if already at maximum, the connection is +//! dropped. +//! * Transport agnostic. + +pub mod client; +pub mod context; +pub mod server; +pub mod transport; +pub(crate) mod util; + +pub use crate::{client::Client, server::Server, transport::Transport}; + +use futures::{Future, task::{Spawn, SpawnExt, SpawnError}}; +use std::{cell::RefCell, io, sync::Once, time::SystemTime}; + +/// A message from a client to a server. +#[derive(Debug)] +#[cfg_attr( + feature = "serde", + derive(serde::Serialize, serde::Deserialize) +)] +#[non_exhaustive] +pub struct ClientMessage { + /// The trace context associates the message with a specific chain of causally-related actions, + /// possibly orchestrated across many distributed systems. + pub trace_context: trace::Context, + /// The message payload. + pub message: ClientMessageKind, +} + +/// Different messages that can be sent from a client to a server. +#[derive(Debug)] +#[cfg_attr( + feature = "serde", + derive(serde::Serialize, serde::Deserialize) +)] +#[non_exhaustive] +pub enum ClientMessageKind { + /// A request initiated by a user. The server responds to a request by invoking a + /// service-provided request handler. The handler completes with a [`response`](Response), which + /// the server sends back to the client. + Request(Request), + /// A command to cancel an in-flight request, automatically sent by the client when a response + /// future is dropped. + /// + /// When received, the server will immediately cancel the main task (top-level future) of the + /// request handler for the associated request. Any tasks spawned by the request handler will + /// not be canceled, because the framework layer does not + /// know about them. + Cancel { + /// The ID of the request to cancel. + request_id: u64, + }, +} + +/// A request from a client to a server. +#[derive(Debug)] +#[cfg_attr( + feature = "serde", + derive(serde::Serialize, serde::Deserialize) +)] +#[non_exhaustive] +pub struct Request { + /// Uniquely identifies the request across all requests sent over a single channel. + pub id: u64, + /// The request body. + pub message: T, + /// When the client expects the request to be complete by. The server will cancel the request + /// if it is not complete by this time. + #[cfg_attr( + feature = "serde", + serde(serialize_with = "util::serde::serialize_epoch_secs") + )] + #[cfg_attr( + feature = "serde", + serde(deserialize_with = "util::serde::deserialize_epoch_secs") + )] + pub deadline: SystemTime, +} + +/// A response from a server to a client. +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr( + feature = "serde", + derive(serde::Serialize, serde::Deserialize) +)] +#[non_exhaustive] +pub struct Response { + /// The ID of the request being responded to. + pub request_id: u64, + /// The response body, or an error if the request failed. + pub message: Result, +} + +/// An error response from a server to a client. 
+#[derive(Debug, PartialEq, Eq)] +#[cfg_attr( + feature = "serde", + derive(serde::Serialize, serde::Deserialize) +)] +#[non_exhaustive] +pub struct ServerError { + #[cfg_attr( + feature = "serde", + serde(serialize_with = "util::serde::serialize_io_error_kind_as_u32") + )] + #[cfg_attr( + feature = "serde", + serde(deserialize_with = "util::serde::deserialize_io_error_kind_from_u32") + )] + /// The type of error that occurred to fail the request. + pub kind: io::ErrorKind, + /// A message describing more detail about the error that occurred. + pub detail: Option, +} + +impl From for io::Error { + fn from(e: ServerError) -> io::Error { + io::Error::new(e.kind, e.detail.unwrap_or_default()) + } +} + +impl Request { + /// Returns the deadline for this request. + pub fn deadline(&self) -> &SystemTime { + &self.deadline + } +} + +static INIT: Once = Once::new(); +static mut SEED_SPAWN: Option> = None; +thread_local! { + static SPAWN: RefCell> = { + unsafe { + // INIT must always be called before accessing SPAWN. + // Otherwise, accessing SPAWN can trigger undefined behavior due to race conditions. + INIT.call_once(|| {}); + RefCell::new(SEED_SPAWN.clone().expect("init() must be called.")) + } + }; +} + +/// Initializes the RPC library with a mechanism to spawn futures on the user's runtime. +/// Client stubs and servers both use the initialized spawn. +/// +/// Init only has an effect the first time it is called. If called previously, successive calls to +/// init are noops. +pub fn init(spawn: impl Spawn + Clone + 'static) { + unsafe { + INIT.call_once(|| { + SEED_SPAWN = Some(Box::new(spawn)); + }); + } +} + +pub(crate) fn spawn(future: impl Future + Send + 'static) -> Result<(), SpawnError> { + SPAWN.with(|spawn| { + spawn.borrow_mut().spawn(future) + }) +} + +trait CloneSpawn: Spawn { + fn box_clone(&self) -> Box; +} + +impl Clone for Box { + fn clone(&self) -> Self { + self.box_clone() + } +} + +impl CloneSpawn for S { + fn box_clone(&self) -> Box { + Box::new(self.clone()) + } +} diff --git a/rpc/src/server/filter.rs b/rpc/src/server/filter.rs new file mode 100644 index 00000000..d5d76174 --- /dev/null +++ b/rpc/src/server/filter.rs @@ -0,0 +1,251 @@ +use crate::{ + server::{Channel, Config}, + util::Compact, + ClientMessage, Response, Transport, +}; +use fnv::FnvHashMap; +use futures::{channel::mpsc, prelude::*, ready, stream::Fuse, task::{LocalWaker, Poll}}; +use log::{debug, error, info, trace, warn}; +use pin_utils::unsafe_pinned; +use std::{ + collections::hash_map::Entry, + io, + marker::PhantomData, + net::{IpAddr, SocketAddr}, + ops::Try, + option::NoneError, + pin::Pin, +}; + +/// Drops connections under configurable conditions: +/// +/// 1. If the max number of connections is reached. +/// 2. If the max number of connections for a single IP is reached. 
+#[derive(Debug)] +pub struct ConnectionFilter { + listener: Fuse, + closed_connections: mpsc::UnboundedSender, + closed_connections_rx: mpsc::UnboundedReceiver, + config: Config, + connections_per_ip: FnvHashMap, + open_connections: usize, + ghost: PhantomData<(Req, Resp)>, +} + +enum NewConnection { + Filtered, + Accepted(Channel), +} + +impl Try for NewConnection { + type Ok = Channel; + type Error = NoneError; + + fn into_result(self) -> Result, NoneError> { + match self { + NewConnection::Filtered => Err(NoneError), + NewConnection::Accepted(channel) => Ok(channel), + } + } + + fn from_error(_: NoneError) -> Self { + NewConnection::Filtered + } + + fn from_ok(channel: Channel) -> Self { + NewConnection::Accepted(channel) + } +} + +impl ConnectionFilter { + unsafe_pinned!(open_connections: usize); + unsafe_pinned!(config: Config); + unsafe_pinned!(connections_per_ip: FnvHashMap); + unsafe_pinned!(closed_connections_rx: mpsc::UnboundedReceiver); + unsafe_pinned!(listener: Fuse); + + /// Sheds new connections to stay under configured limits. + pub fn filter(listener: S, config: Config) -> Self + where + S: Stream>, + C: Transport, SinkItem = Response> + Send, + { + let (closed_connections, closed_connections_rx) = mpsc::unbounded(); + + ConnectionFilter { + listener: listener.fuse(), + closed_connections, + closed_connections_rx, + config, + connections_per_ip: FnvHashMap::default(), + open_connections: 0, + ghost: PhantomData, + } + } + + fn handle_new_connection(self: &mut Pin<&mut Self>, stream: C) -> NewConnection + where + C: Transport, SinkItem = Response> + Send, + { + let peer = match stream.peer_addr() { + Ok(peer) => peer, + Err(e) => { + warn!("Could not get peer_addr of new connection: {}", e); + return NewConnection::Filtered; + } + }; + + let open_connections = *self.open_connections(); + if open_connections >= self.config().max_connections { + warn!( + "[{}] Shedding connection because the maximum open connections \ + limit is reached ({}/{}).", + peer, + open_connections, + self.config().max_connections + ); + return NewConnection::Filtered; + } + + let config = self.config.clone(); + let open_connections_for_ip = self.increment_connections_for_ip(&peer)?; + *self.open_connections() += 1; + + debug!( + "[{}] Opening channel ({}/{} connections for IP, {} total).", + peer, + open_connections_for_ip, + config.max_connections_per_ip, + self.open_connections(), + ); + + NewConnection::Accepted(Channel { + client_addr: peer, + closed_connections: self.closed_connections.clone(), + transport: stream.fuse(), + config, + ghost: PhantomData, + }) + } + + fn handle_closed_connection(self: &mut Pin<&mut Self>, addr: &SocketAddr) { + *self.open_connections() -= 1; + debug!( + "[{}] Closing channel. {} open connections remaining.", + addr, self.open_connections + ); + self.decrement_connections_for_ip(&addr); + self.connections_per_ip().compact(0.1); + } + + fn increment_connections_for_ip(self: &mut Pin<&mut Self>, peer: &SocketAddr) -> Option { + let max_connections_per_ip = self.config().max_connections_per_ip; + let mut occupied; + let mut connections_per_ip = self.connections_per_ip(); + let occupied = match connections_per_ip.entry(peer.ip()) { + Entry::Vacant(vacant) => vacant.insert(0), + Entry::Occupied(o) => { + if *o.get() < max_connections_per_ip { + // Store the reference outside the block to extend the lifetime. 
+ occupied = o; + occupied.get_mut() + } else { + info!( + "[{}] Opened max connections from IP ({}/{}).", + peer, + o.get(), + max_connections_per_ip + ); + return None; + } + } + }; + *occupied += 1; + Some(*occupied) + } + + fn decrement_connections_for_ip(self: &mut Pin<&mut Self>, addr: &SocketAddr) { + let should_compact = match self.connections_per_ip().entry(addr.ip()) { + Entry::Vacant(_) => { + error!("[{}] Got vacant entry when closing connection.", addr); + return; + } + Entry::Occupied(mut occupied) => { + *occupied.get_mut() -= 1; + if *occupied.get() == 0 { + occupied.remove(); + true + } else { + false + } + } + }; + if should_compact { + self.connections_per_ip().compact(0.1); + } + } + + fn poll_listener( + self: &mut Pin<&mut Self>, + cx: &LocalWaker, + ) -> Poll>>> + where + S: Stream>, + C: Transport, SinkItem = Response> + Send, + { + match ready!(self.listener().poll_next_unpin(cx)?) { + Some(codec) => Poll::Ready(Some(Ok(self.handle_new_connection(codec)))), + None => Poll::Ready(None), + } + } + + fn poll_closed_connections( + self: &mut Pin<&mut Self>, + cx: &LocalWaker, + ) -> Poll> { + match ready!(self.closed_connections_rx().poll_next_unpin(cx)) { + Some(addr) => { + self.handle_closed_connection(&addr); + Poll::Ready(Ok(())) + } + None => unreachable!("Holding a copy of closed_connections and didn't close it."), + } + } +} + +impl Stream for ConnectionFilter +where + S: Stream>, + T: Transport, SinkItem = Response> + Send, +{ + type Item = io::Result>; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &LocalWaker, + ) -> Poll>>> { + loop { + match (self.poll_listener(cx)?, self.poll_closed_connections(cx)?) { + (Poll::Ready(Some(NewConnection::Accepted(channel))), _) => { + return Poll::Ready(Some(Ok(channel))) + } + (Poll::Ready(Some(NewConnection::Filtered)), _) | (_, Poll::Ready(())) => { + trace!("Filtered a connection; {} open.", self.open_connections()); + continue; + } + (Poll::Pending, Poll::Pending) => return Poll::Pending, + (Poll::Ready(None), Poll::Pending) => { + if *self.open_connections() > 0 { + trace!( + "Listener closed; {} open connections.", + self.open_connections() + ); + return Poll::Pending; + } + trace!("Shutting down listener: all connections closed, and no more coming."); + return Poll::Ready(None); + } + } + } + } +} diff --git a/rpc/src/server/mod.rs b/rpc/src/server/mod.rs new file mode 100644 index 00000000..2c8f862e --- /dev/null +++ b/rpc/src/server/mod.rs @@ -0,0 +1,599 @@ +//! Provides a server that concurrently handles many connections sending multiplexed requests. + +use crate::{ + context::Context, util::deadline_compat, util::AsDuration, util::Compact, ClientMessage, + ClientMessageKind, Request, Response, ServerError, Transport, +}; +use fnv::FnvHashMap; +use futures::{ + channel::mpsc, + future::{abortable, AbortHandle}, + prelude::*, + ready, + stream::Fuse, + task::{LocalWaker, Poll}, + try_ready, +}; +use humantime::format_rfc3339; +use log::{debug, error, info, trace, warn}; +use pin_utils::{unsafe_pinned, unsafe_unpinned}; +use std::{ + error::Error as StdError, + io, + marker::PhantomData, + net::SocketAddr, + pin::Pin, + time::{Instant, SystemTime}, +}; +use tokio_timer::timeout; +use trace::{self, TraceId}; + +mod filter; + +/// Manages clients, serving multiplexed requests over each connection. +#[derive(Debug)] +pub struct Server { + config: Config, + ghost: PhantomData<(Req, Resp)>, +} + +/// Settings that control the behavior of the server. 
+#[non_exhaustive] +#[derive(Clone, Debug)] +pub struct Config { + /// The maximum number of clients that can be connected to the server at once. When at the + /// limit, existing connections are honored and new connections are rejected. + pub max_connections: usize, + /// The maximum number of clients per IP address that can be connected to the server at once. + /// When an IP is at the limit, existing connections are honored and new connections on that IP + /// address are rejected. + pub max_connections_per_ip: usize, + /// The maximum number of requests that can be in flight for each client. When a client is at + /// the in-flight request limit, existing requests are fulfilled and new requests are rejected. + /// Rejected requests are sent a response error. + pub max_in_flight_requests_per_connection: usize, + /// The number of responses per client that can be buffered server-side before being sent. + /// `pending_response_buffer` controls the buffer size of the channel that a server's + /// response tasks use to send responses to the client handler task. + pub pending_response_buffer: usize, +} + +impl Default for Config { + fn default() -> Self { + Config { + max_connections: 1_000_000, + max_connections_per_ip: 1_000, + max_in_flight_requests_per_connection: 1_000, + pending_response_buffer: 100, + } + } +} + +impl Server { + /// Returns a new server with configuration specified `config`. + pub fn new(config: Config) -> Self { + Server { + config, + ghost: PhantomData, + } + } + + /// Returns the config for this server. + pub fn config(&self) -> &Config { + &self.config + } + + /// Returns a stream of the incoming connections to the server. + pub fn incoming( + self, + listener: S, + ) -> impl Stream>> + where + Req: Send, + Resp: Send, + S: Stream>, + T: Transport, SinkItem = Response> + Send, + { + self::filter::ConnectionFilter::filter(listener, self.config.clone()) + } +} + +/// The future driving the server. +#[derive(Debug)] +pub struct Running { + incoming: S, + request_handler: F, +} + +impl Running { + unsafe_pinned!(incoming: S); + unsafe_unpinned!(request_handler: F); +} + +impl Future for Running +where + S: Sized + Stream>>, + Req: Send + 'static, + Resp: Send + 'static, + T: Transport, SinkItem = Response> + Send + 'static, + F: FnMut(Context, Req) -> Fut + Send + 'static + Clone, + Fut: Future> + Send + 'static, +{ + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &LocalWaker) -> Poll<()> { + while let Some(channel) = ready!(self.incoming().poll_next(cx)) { + match channel { + Ok(channel) => { + let peer = channel.client_addr; + if let Err(e) = crate::spawn(channel.respond_with(self.request_handler().clone())) + { + warn!("[{}] Failed to spawn connection handler: {:?}", peer, e); + } + } + Err(e) => { + warn!("Incoming connection error: {}", e); + } + } + } + info!("Server shutting down."); + return Poll::Ready(()); + } +} + +/// A utility trait enabling a stream to fluently chain a request handler. +pub trait Handler +where + Self: Sized + Stream>>, + Req: Send, + Resp: Send, + T: Transport, SinkItem = Response> + Send, +{ + /// Responds to all requests with `request_handler`. 
+ fn respond_with(self, request_handler: F) -> Running + where + F: FnMut(Context, Req) -> Fut + Send + 'static + Clone, + Fut: Future> + Send + 'static, + { + Running { + incoming: self, + request_handler, + } + } +} + +impl Handler for S +where + S: Sized + Stream>>, + Req: Send, + Resp: Send, + T: Transport, SinkItem = Response> + Send, +{} + +/// Responds to all requests with `request_handler`. +/// The server end of an open connection with a client. +#[derive(Debug)] +pub struct Channel { + /// Writes responses to the wire and reads requests off the wire. + transport: Fuse, + /// Signals the connection is closed when `Channel` is dropped. + closed_connections: mpsc::UnboundedSender, + /// Channel limits to prevent unlimited resource usage. + config: Config, + /// The address of the server connected to. + client_addr: SocketAddr, + /// Types the request and response. + ghost: PhantomData<(Req, Resp)>, +} + +impl Drop for Channel { + fn drop(&mut self) { + trace!("[{}] Closing channel.", self.client_addr); + + // Even in a bounded channel, each connection would have a guaranteed slot, so using + // an unbounded sender is actually no different. And, the bound is on the maximum number + // of open connections. + if self + .closed_connections + .unbounded_send(self.client_addr) + .is_err() + { + warn!( + "[{}] Failed to send closed connection message.", + self.client_addr + ); + } + } +} + +impl Channel { + unsafe_pinned!(transport: Fuse); +} + +impl Channel +where + T: Transport, SinkItem = Response> + Send, + Req: Send, + Resp: Send, +{ + pub(crate) fn start_send(self: &mut Pin<&mut Self>, response: Response) -> io::Result<()> { + self.transport().start_send(response) + } + + pub(crate) fn poll_ready( + self: &mut Pin<&mut Self>, + cx: &LocalWaker, + ) -> Poll> { + self.transport().poll_ready(cx) + } + + pub(crate) fn poll_flush( + self: &mut Pin<&mut Self>, + cx: &LocalWaker, + ) -> Poll> { + self.transport().poll_flush(cx) + } + + pub(crate) fn poll_next( + self: &mut Pin<&mut Self>, + cx: &LocalWaker, + ) -> Poll>>> { + self.transport().poll_next(cx) + } + + /// Returns the address of the client connected to the channel. + pub fn client_addr(&self) -> &SocketAddr { + &self.client_addr + } + + /// Respond to requests coming over the channel with `f`. Returns a future that drives the + /// responses and resolves when the connection is closed. + pub fn respond_with(self, f: F) -> impl Future + where + F: FnMut(Context, Req) -> Fut + Send + 'static, + Fut: Future> + Send + 'static, + Req: 'static, + Resp: 'static, + { + let (responses_tx, responses) = mpsc::channel(self.config.pending_response_buffer); + let responses = responses.fuse(); + let peer = self.client_addr; + + ClientHandler { + channel: self, + f, + pending_responses: responses, + responses_tx, + in_flight_requests: FnvHashMap::default(), + }.unwrap_or_else(move |e| { + info!("[{}] ClientHandler errored out: {}", peer, e); + }) + } +} + +#[derive(Debug)] +struct ClientHandler { + channel: Channel, + /// Responses waiting to be written to the wire. + pending_responses: Fuse)>>, + /// Handed out to request handlers to fan in responses. + responses_tx: mpsc::Sender<(Context, Response)>, + /// Number of requests currently being responded to. + in_flight_requests: FnvHashMap, + /// Request handler. 
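+    /// Called once per incoming request; the future it returns is wrapped in a
+    /// deadline, spawned, and tracked in `in_flight_requests` via an abort handle
+    /// so that a client-sent cancellation can stop it early.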
+ f: F, +} + +impl ClientHandler { + unsafe_pinned!(channel: Channel); + unsafe_pinned!(in_flight_requests: FnvHashMap); + unsafe_pinned!(pending_responses: Fuse)>>); + unsafe_pinned!(responses_tx: mpsc::Sender<(Context, Response)>); + // For this to be safe, field f must be private, and code in this module must never + // construct PinMut. + unsafe_unpinned!(f: F); +} + +impl ClientHandler +where + Req: Send + 'static, + Resp: Send + 'static, + T: Transport, SinkItem = Response> + Send, + F: FnMut(Context, Req) -> Fut + Send + 'static, + Fut: Future> + Send + 'static, +{ + /// If at max in-flight requests, check that there's room to immediately write a throttled + /// response. + fn poll_ready_if_throttling( + self: &mut Pin<&mut Self>, + cx: &LocalWaker, + ) -> Poll> { + if self.in_flight_requests.len() + >= self.channel.config.max_in_flight_requests_per_connection + { + let peer = self.channel().client_addr; + + while let Poll::Pending = self.channel().poll_ready(cx)? { + info!( + "[{}] In-flight requests at max ({}), and transport is not ready.", + peer, + self.in_flight_requests().len(), + ); + try_ready!(self.channel().poll_flush(cx)); + } + } + Poll::Ready(Ok(())) + } + + fn pump_read(self: &mut Pin<&mut Self>, cx: &LocalWaker) -> Poll>> { + ready!(self.poll_ready_if_throttling(cx)?); + + Poll::Ready(match ready!(self.channel().poll_next(cx)?) { + Some(message) => { + match message.message { + ClientMessageKind::Request(request) => { + self.handle_request(message.trace_context, request)?; + } + ClientMessageKind::Cancel { request_id } => { + self.cancel_request(&message.trace_context, request_id); + } + } + Some(Ok(())) + } + None => { + trace!("[{}] Read half closed", self.channel.client_addr); + None + } + }) + } + + fn pump_write( + self: &mut Pin<&mut Self>, + cx: &LocalWaker, + read_half_closed: bool, + ) -> Poll>> { + match self.poll_next_response(cx)? { + Poll::Ready(Some((_, response))) => { + self.channel().start_send(response)?; + Poll::Ready(Some(Ok(()))) + } + Poll::Ready(None) => { + // Shutdown can't be done before we finish pumping out remaining responses. + ready!(self.channel().poll_flush(cx)?); + Poll::Ready(None) + } + Poll::Pending => { + // No more requests to process, so flush any requests buffered in the transport. + ready!(self.channel().poll_flush(cx)?); + + // Being here means there are no staged requests and all written responses are + // fully flushed. So, if the read half is closed and there are no in-flight + // requests, then we can close the write half. + if read_half_closed && self.in_flight_requests().is_empty() { + Poll::Ready(None) + } else { + Poll::Pending + } + } + } + } + + fn poll_next_response( + self: &mut Pin<&mut Self>, + cx: &LocalWaker, + ) -> Poll)>>> { + // Ensure there's room to write a response. + while let Poll::Pending = self.channel().poll_ready(cx)? { + ready!(self.channel().poll_flush(cx)?); + } + + let peer = self.channel().client_addr; + + match ready!(self.pending_responses().poll_next(cx)) { + Some((ctx, response)) => { + if let Some(_) = self.in_flight_requests().remove(&response.request_id) { + self.in_flight_requests().compact(0.1); + } + trace!( + "[{}/{}] Staging response. In-flight requests = {}.", + ctx.trace_id(), + peer, + self.in_flight_requests().len(), + ); + return Poll::Ready(Some(Ok((ctx, response)))); + } + None => { + // This branch likely won't happen, since the ClientHandler is holding a Sender. 
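+                // `responses_tx` is owned by this ClientHandler, so the receiver cannot
+                // be exhausted while we are still polling it; if it somehow is, treat it
+                // as "no more responses" rather than panicking.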
+ trace!("[{}] No new responses.", peer); + Poll::Ready(None) + } + } + } + + fn handle_request( + self: &mut Pin<&mut Self>, + trace_context: trace::Context, + request: Request, + ) -> io::Result<()> { + let request_id = request.id; + let peer = self.channel().client_addr; + let ctx = Context { + deadline: request.deadline, + trace_context, + }; + let request = request.message; + + if self.in_flight_requests().len() + >= self.channel().config.max_in_flight_requests_per_connection + { + debug!( + "[{}/{}] Client has reached in-flight request limit ({}/{}).", + ctx.trace_id(), + peer, + self.in_flight_requests().len(), + self.channel().config.max_in_flight_requests_per_connection + ); + + self.channel().start_send(Response { + request_id, + message: Err(ServerError { + kind: io::ErrorKind::WouldBlock, + detail: Some("Server throttled the request.".into()), + }), + })?; + return Ok(()); + } + + let deadline = ctx.deadline; + let timeout = deadline.as_duration(); + trace!( + "[{}/{}] Received request with deadline {} (timeout {:?}).", + ctx.trace_id(), + peer, + format_rfc3339(deadline), + timeout, + ); + let mut response_tx = self.responses_tx().clone(); + + let trace_id = *ctx.trace_id(); + let response = self.f()(ctx.clone(), request); + let response = deadline_compat::Deadline::new(response, Instant::now() + timeout).then( + async move |result| { + let response = Response { + request_id, + message: match result { + Ok(message) => Ok(message), + Err(e) => Err(make_server_error(e, trace_id, peer, deadline)), + }, + }; + trace!("[{}/{}] Sending response.", trace_id, peer); + await!(response_tx.send((ctx, response)).unwrap_or_else(|_| ())); + }, + ); + let (abortable_response, abort_handle) = abortable(response); + crate::spawn(abortable_response.map(|_| ())) + .map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + format!( + "Could not spawn response task. Is shutdown: {}", + e.is_shutdown() + ), + ) + })?; + self.in_flight_requests().insert(request_id, abort_handle); + Ok(()) + } + + fn cancel_request(self: &mut Pin<&mut Self>, trace_context: &trace::Context, request_id: u64) { + // It's possible the request was already completed, so it's fine + // if this is None. + if let Some(cancel_handle) = self.in_flight_requests().remove(&request_id) { + self.in_flight_requests().compact(0.1); + + cancel_handle.abort(); + let remaining = self.in_flight_requests().len(); + trace!( + "[{}/{}] Request canceled. In-flight requests = {}", + trace_context.trace_id, + self.channel.client_addr, + remaining, + ); + } else { + trace!( + "[{}/{}] Received cancellation, but response handler \ + is already complete.", + trace_context.trace_id, + self.channel.client_addr + ); + } + } +} + +impl Future for ClientHandler +where + Req: Send + 'static, + Resp: Send + 'static, + T: Transport, SinkItem = Response> + Send, + F: FnMut(Context, Req) -> Fut + Send + 'static, + Fut: Future> + Send + 'static, +{ + type Output = io::Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &LocalWaker) -> Poll> { + trace!("[{}] ClientHandler::poll", self.channel.client_addr); + loop { + let read = self.pump_read(cx)?; + match (read, self.pump_write(cx, read == Poll::Ready(None))?) 
{ + (Poll::Ready(None), Poll::Ready(None)) => { + info!("[{}] Client disconnected.", self.channel.client_addr); + return Poll::Ready(Ok(())); + } + (read @ Poll::Ready(Some(())), write) | (read, write @ Poll::Ready(Some(()))) => { + trace!( + "[{}] read: {:?}, write: {:?}.", + self.channel.client_addr, + read, + write + ) + } + (read, write) => { + trace!( + "[{}] read: {:?}, write: {:?} (not ready).", + self.channel.client_addr, + read, + write, + ); + return Poll::Pending; + } + } + } + } +} + +fn make_server_error( + e: timeout::Error, + trace_id: TraceId, + peer: SocketAddr, + deadline: SystemTime, +) -> ServerError { + if e.is_elapsed() { + debug!( + "[{}/{}] Response did not complete before deadline of {}s.", + trace_id, + peer, + format_rfc3339(deadline) + ); + // No point in responding, since the client will have dropped the request. + ServerError { + kind: io::ErrorKind::TimedOut, + detail: Some(format!( + "Response did not complete before deadline of {}s.", + format_rfc3339(deadline) + )), + } + } else if e.is_timer() { + error!( + "[{}/{}] Response failed because of an issue with a timer: {}", + trace_id, peer, e + ); + + ServerError { + kind: io::ErrorKind::Other, + detail: Some(format!("{}", e)), + } + } else if e.is_inner() { + let e = e.into_inner().unwrap(); + ServerError { + kind: e.kind(), + detail: Some(e.description().into()), + } + } else { + error!("[{}/{}] Unexpected response failure: {}", trace_id, peer, e); + + ServerError { + kind: io::ErrorKind::Other, + detail: Some(format!("Server unexpectedly failed to respond: {}", e)), + } + } +} diff --git a/rpc/src/transport/channel.rs b/rpc/src/transport/channel.rs new file mode 100644 index 00000000..8c0368ec --- /dev/null +++ b/rpc/src/transport/channel.rs @@ -0,0 +1,151 @@ +//! Transports backed by in-memory channels. + +use crate::Transport; +use futures::{channel::mpsc, task::{LocalWaker}, Poll, Sink, Stream}; +use pin_utils::unsafe_pinned; +use std::pin::Pin; +use std::{ + io, + net::{IpAddr, Ipv4Addr, SocketAddr}, +}; + +/// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's +/// [`Sink`]. +pub fn unbounded() -> ( + UnboundedChannel, + UnboundedChannel, +) { + let (tx1, rx2) = mpsc::unbounded(); + let (tx2, rx1) = mpsc::unbounded(); + ( + UnboundedChannel { tx: tx1, rx: rx1 }, + UnboundedChannel { tx: tx2, rx: rx2 }, + ) +} + +/// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender) +/// and [`UnboundedReceiver`](mpsc::UnboundedReceiver). 
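+///
+/// A minimal pairing sketch (item types are placeholders): each half implements both
+/// `Stream` and `Sink`, so whatever is sent into one half is yielded by the other.
+///
+/// ```ignore
+/// let (left, right) = unbounded::<String, String>();
+/// // `left` sinks what `right` streams, and vice versa.
+/// ```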
+#[derive(Debug)] +pub struct UnboundedChannel { + rx: mpsc::UnboundedReceiver, + tx: mpsc::UnboundedSender, +} + +impl UnboundedChannel { + unsafe_pinned!(rx: mpsc::UnboundedReceiver); + unsafe_pinned!(tx: mpsc::UnboundedSender); +} + +impl Stream for UnboundedChannel { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &LocalWaker) -> Poll>> { + self.rx().poll_next(cx).map(|option| option.map(Ok)) + } +} + +impl Sink for UnboundedChannel { + type SinkItem = SinkItem; + type SinkError = io::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &LocalWaker) -> Poll> { + self.tx() + .poll_ready(cx) + .map_err(|_| io::Error::from(io::ErrorKind::NotConnected)) + } + + fn start_send(mut self: Pin<&mut Self>, item: SinkItem) -> io::Result<()> { + self.tx() + .start_send(item) + .map_err(|_| io::Error::from(io::ErrorKind::NotConnected)) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &LocalWaker, + ) -> Poll> { + self.tx() + .poll_flush(cx) + .map_err(|_| io::Error::from(io::ErrorKind::NotConnected)) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &LocalWaker) -> Poll> { + self.tx() + .poll_close(cx) + .map_err(|_| io::Error::from(io::ErrorKind::NotConnected)) + } +} + +impl Transport for UnboundedChannel { + type Item = Item; + type SinkItem = SinkItem; + + fn peer_addr(&self) -> io::Result { + Ok(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)) + } + + fn local_addr(&self) -> io::Result { + Ok(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)) + } +} + +#[cfg(test)] +mod tests { + use crate::{client::{self, Client}, context, server::{self, Handler, Server}, transport}; + use futures::{prelude::*, stream, compat::TokioDefaultSpawner}; + use log::trace; + use std::io; + + #[test] + fn integration() { + let _ = env_logger::try_init(); + crate::init(TokioDefaultSpawner); + + let (client_channel, server_channel) = transport::channel::unbounded(); + let server = Server::::new(server::Config::default()) + .incoming(stream::once(future::ready(Ok(server_channel)))) + .respond_with(|_ctx, request| { + future::ready(request.parse::().map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("{:?} is not an int", request), + ) + })) + }); + + let responses = async { + let mut client = await!(Client::new(client::Config::default(), client_channel))?; + + let response1 = await!(client.call(context::current(), "123".into())); + let response2 = await!(client.call(context::current(), "abc".into())); + + Ok::<_, io::Error>((response1, response2)) + }; + + let (response1, response2) = + run_future(server.join(responses.unwrap_or_else(|e| panic!(e)))).1; + + trace!("response1: {:?}, response2: {:?}", response1, response2); + + assert!(response1.is_ok()); + assert_eq!(response1.ok().unwrap(), 123); + + assert!(response2.is_err()); + assert_eq!(response2.err().unwrap().kind(), io::ErrorKind::InvalidInput); + } + + fn run_future(f: F) -> F::Output + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + let (tx, rx) = futures::channel::oneshot::channel(); + tokio::run( + f.map(|result| tx.send(result).unwrap_or_else(|_| unreachable!())) + .boxed() + .unit_error() + .compat(), + ); + futures::executor::block_on(rx).unwrap() + } +} diff --git a/rpc/src/transport/mod.rs b/rpc/src/transport/mod.rs new file mode 100644 index 00000000..2821fb38 --- /dev/null +++ b/rpc/src/transport/mod.rs @@ -0,0 +1,26 @@ +//! Provides a [`Transport`] trait as well as implementations. +//! +//! The rpc crate is transport- and protocol-agnostic. 
Any transport that impls [`Transport`] +//! can be plugged in, using whatever protocol it wants. + +use futures::prelude::*; +use std::{io, net::SocketAddr}; + +pub mod channel; + +/// A bidirectional stream ([`Sink`] + [`Stream`]) of messages. +pub trait Transport +where + Self: Stream::Item>>, + Self: Sink::SinkItem, SinkError = io::Error>, +{ + /// The type read off the transport. + type Item; + /// The type written to the transport. + type SinkItem; + + /// The address of the remote peer this transport is in communication with. + fn peer_addr(&self) -> io::Result; + /// The address of the local half of this transport. + fn local_addr(&self) -> io::Result; +} diff --git a/rpc/src/util/deadline_compat.rs b/rpc/src/util/deadline_compat.rs new file mode 100644 index 00000000..3a73bfd1 --- /dev/null +++ b/rpc/src/util/deadline_compat.rs @@ -0,0 +1,63 @@ +use futures::{ + compat::{Compat01As03, Future01CompatExt}, + prelude::*, + ready, task::{Poll, LocalWaker}, +}; +use pin_utils::unsafe_pinned; +use std::pin::Pin; +use std::time::Instant; +use tokio_timer::{timeout, Delay}; + +#[must_use = "futures do nothing unless polled"] +#[derive(Debug)] +pub struct Deadline { + future: T, + delay: Compat01As03, +} + +impl Deadline { + unsafe_pinned!(future: T); + unsafe_pinned!(delay: Compat01As03); + + /// Create a new `Deadline` that completes when `future` completes or when + /// `deadline` is reached. + pub fn new(future: T, deadline: Instant) -> Deadline { + Deadline::new_with_delay(future, Delay::new(deadline)) + } + + pub(crate) fn new_with_delay(future: T, delay: Delay) -> Deadline { + Deadline { + future, + delay: delay.compat(), + } + } + + /// Gets a mutable reference to the underlying future in this deadline. + pub fn get_mut(&mut self) -> &mut T { + &mut self.future + } +} +impl Future for Deadline +where + T: TryFuture, +{ + type Output = Result>; + + fn poll(mut self: Pin<&mut Self>, waker: &LocalWaker) -> Poll { + + // First, try polling the future + match self.future().try_poll(waker) { + Poll::Ready(Ok(v)) => return Poll::Ready(Ok(v)), + Poll::Pending => {} + Poll::Ready(Err(e)) => return Poll::Ready(Err(timeout::Error::inner(e))), + } + + let delay = self.delay().poll_unpin(waker); + + // Now check the timer + match ready!(delay) { + Ok(_) => Poll::Ready(Err(timeout::Error::elapsed())), + Err(e) => Poll::Ready(Err(timeout::Error::timer(e))), + } + } +} diff --git a/rpc/src/util/mod.rs b/rpc/src/util/mod.rs new file mode 100644 index 00000000..472d04fe --- /dev/null +++ b/rpc/src/util/mod.rs @@ -0,0 +1,40 @@ +use std::{ + collections::HashMap, + hash::{BuildHasher, Hash}, + time::{Duration, SystemTime}, +}; + +pub mod deadline_compat; +#[cfg(feature = "serde")] +pub mod serde; + +/// Types that can be represented by a [`Duration`]. +pub trait AsDuration { + fn as_duration(&self) -> Duration; +} + +impl AsDuration for SystemTime { + /// Duration of 0 if self is earlier than [`SystemTime::now`]. + fn as_duration(&self) -> Duration { + self.duration_since(SystemTime::now()).unwrap_or_default() + } +} + +/// Collection compaction; configurable `shrink_to_fit`. +pub trait Compact { + /// Compacts space if the ratio of length : capacity is less than `usage_ratio_threshold`. 
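+    ///
+    /// A hedged sketch of the intent, using the 0.1 threshold this crate passes elsewhere:
+    ///
+    /// ```ignore
+    /// let mut map: HashMap<u64, ()> = HashMap::with_capacity(1_000);
+    /// map.insert(1, ());
+    /// map.compact(0.1); // 1 / 1_000 < 0.1, so the map is shrunk via `shrink_to_fit`.
+    /// ```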
+ fn compact(&mut self, usage_ratio_threshold: f64); +} + +impl Compact for HashMap +where + K: Eq + Hash, + H: BuildHasher, +{ + fn compact(&mut self, usage_ratio_threshold: f64) { + let usage_ratio = self.len() as f64 / self.capacity() as f64; + if usage_ratio < usage_ratio_threshold { + self.shrink_to_fit(); + } + } +} diff --git a/rpc/src/util/serde.rs b/rpc/src/util/serde.rs new file mode 100644 index 00000000..260b47b3 --- /dev/null +++ b/rpc/src/util/serde.rs @@ -0,0 +1,88 @@ +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::{ + io, + time::{Duration, SystemTime}, +}; + +/// Serializes `system_time` as a `u64` equal to the number of seconds since the epoch. +pub fn serialize_epoch_secs(system_time: &SystemTime, serializer: S) -> Result +where + S: Serializer, +{ + system_time + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap_or(Duration::from_secs(0)) + .as_secs() // Only care about second precision + .serialize(serializer) +} + +/// Deserializes [`SystemTime`] from a `u64` equal to the number of seconds since the epoch. +pub fn deserialize_epoch_secs<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + Ok(SystemTime::UNIX_EPOCH + Duration::from_secs(u64::deserialize(deserializer)?)) +} + +/// Serializes [`io::ErrorKind`] as a `u32`. +pub fn serialize_io_error_kind_as_u32( + kind: &io::ErrorKind, + serializer: S, +) -> Result +where + S: Serializer, +{ + use std::io::ErrorKind::*; + match *kind { + NotFound => 0, + PermissionDenied => 1, + ConnectionRefused => 2, + ConnectionReset => 3, + ConnectionAborted => 4, + NotConnected => 5, + AddrInUse => 6, + AddrNotAvailable => 7, + BrokenPipe => 8, + AlreadyExists => 9, + WouldBlock => 10, + InvalidInput => 11, + InvalidData => 12, + TimedOut => 13, + WriteZero => 14, + Interrupted => 15, + Other => 16, + UnexpectedEof => 17, + _ => 16, + }.serialize(serializer) +} + +/// Deserializes [`io::ErrorKind`] from a `u32`. +pub fn deserialize_io_error_kind_from_u32<'de, D>( + deserializer: D, +) -> Result +where + D: Deserializer<'de>, +{ + use std::io::ErrorKind::*; + Ok(match u32::deserialize(deserializer)? { + 0 => NotFound, + 1 => PermissionDenied, + 2 => ConnectionRefused, + 3 => ConnectionReset, + 4 => ConnectionAborted, + 5 => NotConnected, + 6 => AddrInUse, + 7 => AddrNotAvailable, + 8 => BrokenPipe, + 9 => AlreadyExists, + 10 => WouldBlock, + 11 => InvalidInput, + 12 => InvalidData, + 13 => TimedOut, + 14 => WriteZero, + 15 => Interrupted, + 16 => Other, + 17 => UnexpectedEof, + _ => Other, + }) +} diff --git a/rustfmt.toml b/rustfmt.toml deleted file mode 100644 index 44148a2d..00000000 --- a/rustfmt.toml +++ /dev/null @@ -1 +0,0 @@ -reorder_imports = true diff --git a/src/errors.rs b/src/errors.rs deleted file mode 100644 index a3a62267..00000000 --- a/src/errors.rs +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright 2016 Google Inc. All Rights Reserved. -// -// Licensed under the MIT License, . -// This file may not be copied, modified, or distributed except according to those terms. - -use serde::{Deserialize, Serialize}; -use std::{fmt, io}; -use std::error::Error as StdError; - -/// All errors that can occur during the use of tarpc. -#[derive(Debug)] -pub enum Error { - /// Any IO error. - Io(io::Error), - /// Error deserializing the server response. - /// - /// Typically this indicates a faulty implementation of `serde::Serialize` or - /// `serde::Deserialize`. - ResponseDeserialize(::bincode::Error), - /// Error deserializing the client request. 
- /// - /// Typically this indicates a faulty implementation of `serde::Serialize` or - /// `serde::Deserialize`. - RequestDeserialize(String), - /// The server was unable to reply to the rpc for some reason. - /// - /// This is a service-specific error. Its type is individually specified in the - /// `service!` macro for each rpc. - App(E), -} - -impl<'a, E: StdError + Deserialize<'a> + Serialize + Send + 'static> fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Error::ResponseDeserialize(ref e) => write!(f, r#"{}: "{}""#, self.description(), e), - Error::RequestDeserialize(ref e) => write!(f, r#"{}: "{}""#, self.description(), e), - Error::App(ref e) => fmt::Display::fmt(e, f), - Error::Io(ref e) => fmt::Display::fmt(e, f), - } - } -} - -impl<'a, E: StdError + Deserialize<'a> + Serialize + Send + 'static> StdError for Error { - fn description(&self) -> &str { - match *self { - Error::ResponseDeserialize(_) => "The client failed to deserialize the response.", - Error::RequestDeserialize(_) => "The server failed to deserialize the request.", - Error::App(ref e) => e.description(), - Error::Io(ref e) => e.description(), - } - } - - fn cause(&self) -> Option<&StdError> { - match *self { - Error::ResponseDeserialize(ref e) => e.cause(), - Error::RequestDeserialize(_) | Error::App(_) => None, - Error::Io(ref e) => e.cause(), - } - } -} - -impl From for Error { - fn from(err: io::Error) -> Self { - Error::Io(err) - } -} - -impl From> for Error { - fn from(err: WireError) -> Self { - match err { - WireError::RequestDeserialize(s) => Error::RequestDeserialize(s), - WireError::App(e) => Error::App(e), - } - } -} - -/// A serializable, server-supplied error. -#[doc(hidden)] -#[derive(Deserialize, Serialize, Clone, Debug)] -pub enum WireError { - /// Server-side error in deserializing the client request. - RequestDeserialize(String), - /// The server was unable to reply to the rpc for some reason. - App(E), -} - -/// Convert `native_tls::Error` to `std::io::Error` -#[cfg(feature = "tls")] -pub fn native_to_io(e: ::native_tls::Error) -> io::Error { - io::Error::new(io::ErrorKind::Other, e) -} diff --git a/src/future/client.rs b/src/future/client.rs deleted file mode 100644 index 374f69f3..00000000 --- a/src/future/client.rs +++ /dev/null @@ -1,278 +0,0 @@ -// Copyright 2016 Google Inc. All Rights Reserved. -// -// Licensed under the MIT License, . -// This file may not be copied, modified, or distributed except according to those terms. - -use {REMOTE, bincode}; -use future::server::Response; -use futures::{self, Future, future}; -use protocol::Proto; -use serde::Serialize; -use serde::de::DeserializeOwned; -use std::fmt; -use std::io; -use std::net::SocketAddr; -use stream_type::StreamType; -use tokio_core::net::TcpStream; -use tokio_core::reactor; -use tokio_proto::BindClient as ProtoBindClient; -use tokio_proto::multiplex::ClientService; -use tokio_service::Service; - -cfg_if! { - if #[cfg(feature = "tls")] { - use errors::native_to_io; - use tls::client::Context; - use tokio_tls::TlsConnectorExt; - } else {} -} - -/// Additional options to configure how the client connects and operates. -#[derive(Debug)] -pub struct Options { - /// Max packet size in bytes. 
- max_payload_size: u64, - reactor: Option, - #[cfg(feature = "tls")] - tls_ctx: Option, -} - -impl Default for Options { - #[cfg(feature = "tls")] - fn default() -> Self { - Options { - max_payload_size: 2 << 20, - reactor: None, - tls_ctx: None, - } - } - - #[cfg(not(feature = "tls"))] - fn default() -> Self { - Options { - max_payload_size: 2 << 20, - reactor: None, - } - } -} - -impl Options { - /// Set the max payload size in bytes. The default is 2 << 20 (2 MiB). - pub fn max_payload_size(mut self, bytes: u64) -> Self { - self.max_payload_size = bytes; - self - } - - /// Drive using the given reactor handle. - pub fn handle(mut self, handle: reactor::Handle) -> Self { - self.reactor = Some(Reactor::Handle(handle)); - self - } - - /// Drive using the given reactor remote. - pub fn remote(mut self, remote: reactor::Remote) -> Self { - self.reactor = Some(Reactor::Remote(remote)); - self - } - - /// Connect using the given `Context` - #[cfg(feature = "tls")] - pub fn tls(mut self, tls_ctx: Context) -> Self { - self.tls_ctx = Some(tls_ctx); - self - } -} - -enum Reactor { - Handle(reactor::Handle), - Remote(reactor::Remote), -} - -impl fmt::Debug for Reactor { - fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - const HANDLE: &str = "Reactor::Handle"; - const HANDLE_INNER: &str = "Handle { .. }"; - const REMOTE: &str = "Reactor::Remote"; - const REMOTE_INNER: &str = "Remote { .. }"; - - match *self { - Reactor::Handle(_) => f.debug_tuple(HANDLE).field(&HANDLE_INNER).finish(), - Reactor::Remote(_) => f.debug_tuple(REMOTE).field(&REMOTE_INNER).finish(), - } - } -} - -#[doc(hidden)] -pub struct Client -where - Req: Serialize + 'static, - Resp: DeserializeOwned + 'static, - E: DeserializeOwned + 'static, -{ - inner: ClientService>>, -} - -impl Clone for Client -where - Req: Serialize + 'static, - Resp: DeserializeOwned + 'static, - E: DeserializeOwned + 'static, -{ - fn clone(&self) -> Self { - Client { - inner: self.inner.clone(), - } - } -} - -impl Service for Client -where - Req: Serialize + Send + 'static, - Resp: DeserializeOwned + Send + 'static, - E: DeserializeOwned + Send + 'static, -{ - type Request = Req; - type Response = Resp; - type Error = ::Error; - type Future = ResponseFuture; - - fn call(&self, request: Self::Request) -> Self::Future { - fn identity(t: T) -> T { - t - } - self.inner - .call(request) - .map(Self::map_err as _) - .map_err(::Error::from as _) - .and_then(identity as _) - } -} - -impl Client -where - Req: Serialize + 'static, - Resp: DeserializeOwned + 'static, - E: DeserializeOwned + 'static, -{ - fn bind(handle: &reactor::Handle, tcp: StreamType, max_payload_size: u64) -> Self - where - Req: Serialize + Send + 'static, - Resp: DeserializeOwned + Send + 'static, - E: DeserializeOwned + Send + 'static, - { - let inner = Proto::new(max_payload_size).bind_client(handle, tcp); - Client { inner } - } - - fn map_err(resp: WireResponse) -> Result> { - resp.map(|r| r.map_err(::Error::from)) - .map_err(::Error::ResponseDeserialize) - .and_then(|r| r) - } -} - -impl fmt::Debug for Client -where - Req: Serialize + 'static, - Resp: DeserializeOwned + 'static, - E: DeserializeOwned + 'static, -{ - fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - write!(f, "Client {{ .. }}") - } -} - -/// Extension methods for clients. -pub trait ClientExt: Sized { - /// The type of the future returned when calling `connect`. - type ConnectFut: Future; - - /// Connects to a server located at the given address, using the given options. 
- fn connect(addr: SocketAddr, options: Options) -> Self::ConnectFut; -} - -/// A future that resolves to a `Client` or an `io::Error`. -pub type ConnectFuture = futures::Flatten< - futures::MapErr< - futures::Oneshot>>, - fn(futures::Canceled) -> io::Error, - >, ->; - -impl ClientExt for Client -where - Req: Serialize + Send + 'static, - Resp: DeserializeOwned + Send + 'static, - E: DeserializeOwned + Send + 'static, -{ - type ConnectFut = ConnectFuture; - - fn connect(addr: SocketAddr, options: Options) -> Self::ConnectFut { - // we need to do this for tls because we need to avoid moving the entire `Options` - // struct into the `setup` closure, since `Reactor` is not `Send`. - #[cfg(feature = "tls")] - let mut options = options; - #[cfg(feature = "tls")] - let tls_ctx = options.tls_ctx.take(); - - let max_payload_size = options.max_payload_size; - - let connect = move |handle: &reactor::Handle| { - let handle2 = handle.clone(); - TcpStream::connect(&addr, handle) - .and_then(move |socket| { - // TODO(https://github.com/tokio-rs/tokio-proto/issues/132): move this into the - // ServerProto impl - #[cfg(feature = "tls")] - match tls_ctx { - Some(tls_ctx) => { - future::Either::A( - tls_ctx - .tls_connector - .connect_async(&tls_ctx.domain, socket) - .map(StreamType::Tls) - .map_err(native_to_io), - ) - } - None => future::Either::B(future::ok(StreamType::Tcp(socket))), - } - #[cfg(not(feature = "tls"))] future::ok(StreamType::Tcp(socket)) - }) - .map(move |tcp| Client::bind(&handle2, tcp, max_payload_size)) - }; - let (tx, rx) = futures::oneshot(); - let setup = move |handle: &reactor::Handle| { - connect(handle).then(move |result| { - // If send fails it means the client no longer cared about connecting. - let _ = tx.send(result); - Ok(()) - }) - }; - - match options.reactor { - Some(Reactor::Handle(handle)) => { - handle.spawn(setup(&handle)); - } - Some(Reactor::Remote(remote)) => { - remote.spawn(setup); - } - None => { - REMOTE.spawn(setup); - } - } - fn panic(canceled: futures::Canceled) -> io::Error { - unreachable!(canceled) - } - rx.map_err(panic as _).flatten() - } -} - -type ResponseFuture = - futures::AndThen>> as Service>::Future, - fn(WireResponse) -> Result>>, - fn(io::Error) -> ::Error>, - Result>, - fn(Result>) -> Result>>; - -type WireResponse = Result, bincode::Error>; diff --git a/src/future/mod.rs b/src/future/mod.rs deleted file mode 100644 index 79011a35..00000000 --- a/src/future/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -/// Provides the base client stubs used by the service macro. -pub mod client; -/// Provides the base server boilerplate used by service implementations. 
-pub mod server; diff --git a/src/future/server/connection.rs b/src/future/server/connection.rs deleted file mode 100644 index 8fb41371..00000000 --- a/src/future/server/connection.rs +++ /dev/null @@ -1,76 +0,0 @@ -use futures::unsync; -use std::io; -use tokio_service::{NewService, Service}; - -#[derive(Debug)] -pub enum Action { - Increment, - Decrement, -} - -#[derive(Clone, Debug)] -pub struct Tracker { - pub tx: unsync::mpsc::UnboundedSender, -} - -impl Tracker { - pub fn pair() -> (Self, unsync::mpsc::UnboundedReceiver) { - let (tx, rx) = unsync::mpsc::unbounded(); - (Self { tx }, rx) - } - - pub fn increment(&self) { - let _ = self.tx.unbounded_send(Action::Increment); - } - - pub fn decrement(&self) { - debug!("Closing connection"); - let _ = self.tx.unbounded_send(Action::Decrement); - } -} - -#[derive(Debug)] -pub struct TrackingService { - pub service: S, - pub tracker: Tracker, -} - -#[derive(Debug)] -pub struct TrackingNewService { - pub new_service: S, - pub connection_tracker: Tracker, -} - -impl Service for TrackingService { - type Request = S::Request; - type Response = S::Response; - type Error = S::Error; - type Future = S::Future; - - fn call(&self, req: Self::Request) -> Self::Future { - trace!("Calling service."); - self.service.call(req) - } -} - -impl Drop for TrackingService { - fn drop(&mut self) { - debug!("Dropping ConnnectionTrackingService."); - self.tracker.decrement(); - } -} - -impl NewService for TrackingNewService { - type Request = S::Request; - type Response = S::Response; - type Error = S::Error; - type Instance = TrackingService; - - fn new_service(&self) -> io::Result { - self.connection_tracker.increment(); - Ok(TrackingService { - service: self.new_service.new_service()?, - tracker: self.connection_tracker.clone(), - }) - } -} diff --git a/src/future/server/mod.rs b/src/future/server/mod.rs deleted file mode 100644 index fcfffcef..00000000 --- a/src/future/server/mod.rs +++ /dev/null @@ -1,471 +0,0 @@ -// Copyright 2016 Google Inc. All Rights Reserved. -// -// Licensed under the MIT License, . -// This file may not be copied, modified, or distributed except according to those terms. - -use {bincode, net2}; -use errors::WireError; -use futures::{Async, Future, Poll, Stream, future as futures}; -use protocol::Proto; -use serde::Serialize; -use serde::de::DeserializeOwned; -use std::fmt; -use std::io; -use std::net::SocketAddr; -use stream_type::StreamType; -use tokio_core::net::{Incoming, TcpListener, TcpStream}; -use tokio_core::reactor; -use tokio_io::{AsyncRead, AsyncWrite}; -use tokio_proto::BindServer; -use tokio_service::NewService; - -mod connection; -mod shutdown; - -cfg_if! { - if #[cfg(feature = "tls")] { - use native_tls::{self, TlsAcceptor}; - use tokio_tls::{AcceptAsync, TlsAcceptorExt, TlsStream}; - use errors::native_to_io; - } else {} -} - -pub use self::shutdown::{Shutdown, ShutdownFuture}; - -/// A handle to a bound server. -#[derive(Clone, Debug)] -pub struct Handle { - addr: SocketAddr, - shutdown: Shutdown, -} - -impl Handle { - /// Returns a hook for shutting down the server. - pub fn shutdown(&self) -> &Shutdown { - &self.shutdown - } - - /// The socket address the server is bound to. 
- pub fn addr(&self) -> SocketAddr { - self.addr - } -} - -enum Acceptor { - Tcp, - #[cfg(feature = "tls")] - Tls(TlsAcceptor), -} - -struct Accept { - #[cfg(feature = "tls")] - inner: futures::Either< - futures::MapErr< - futures::Map, fn(TlsStream) -> StreamType>, - fn(native_tls::Error) -> io::Error, - >, - futures::FutureResult, - >, - #[cfg(not(feature = "tls"))] - inner: futures::FutureResult, -} - -impl Future for Accept { - type Item = StreamType; - type Error = io::Error; - - fn poll(&mut self) -> Poll { - self.inner.poll() - } -} - -impl Acceptor { - // TODO(https://github.com/tokio-rs/tokio-proto/issues/132): move this into the ServerProto impl - #[cfg(feature = "tls")] - fn accept(&self, socket: TcpStream) -> Accept { - Accept { - inner: match *self { - Acceptor::Tls(ref tls_acceptor) => { - futures::Either::A( - tls_acceptor - .accept_async(socket) - .map(StreamType::Tls as _) - .map_err(native_to_io), - ) - } - Acceptor::Tcp => futures::Either::B(futures::ok(StreamType::Tcp(socket))), - }, - } - } - - #[cfg(not(feature = "tls"))] - fn accept(&self, socket: TcpStream) -> Accept { - Accept { - inner: futures::ok(StreamType::Tcp(socket)), - } - } -} - -#[cfg(feature = "tls")] -impl From for Acceptor { - fn from(options: Options) -> Self { - match options.tls_acceptor { - Some(tls_acceptor) => Acceptor::Tls(tls_acceptor), - None => Acceptor::Tcp, - } - } -} - -#[cfg(not(feature = "tls"))] -impl From for Acceptor { - fn from(_: Options) -> Self { - Acceptor::Tcp - } -} - -impl fmt::Debug for Acceptor { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - use self::Acceptor::*; - #[cfg(feature = "tls")] - const TLS: &str = "TlsAcceptor { .. }"; - - match *self { - Tcp => fmt.debug_tuple("Acceptor::Tcp").finish(), - #[cfg(feature = "tls")] - Tls(_) => fmt.debug_tuple("Acceptor::Tls").field(&TLS).finish(), - } - } -} - -impl fmt::Debug for Accept { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.debug_struct("Accept").finish() - } -} - -#[derive(Debug)] -struct AcceptStream { - stream: S, - acceptor: Acceptor, - future: Option, -} - -impl Stream for AcceptStream -where - S: Stream, -{ - type Item = ::Item; - type Error = io::Error; - - fn poll(&mut self) -> Poll, io::Error> { - if self.future.is_none() { - let stream = match try_ready!(self.stream.poll()) { - None => return Ok(Async::Ready(None)), - Some((stream, _)) => stream, - }; - self.future = Some(self.acceptor.accept(stream)); - } - assert!(self.future.is_some()); - match self.future.as_mut().unwrap().poll() { - Ok(Async::Ready(e)) => { - self.future = None; - Ok(Async::Ready(Some(e))) - } - Err(e) => { - self.future = None; - Err(e) - } - Ok(Async::NotReady) => Ok(Async::NotReady), - } - } -} - -/// Additional options to configure how the server operates. -pub struct Options { - /// Max packet size in bytes. - max_payload_size: u64, - #[cfg(feature = "tls")] - tls_acceptor: Option, -} - -impl Default for Options { - #[cfg(not(feature = "tls"))] - fn default() -> Self { - Options { - max_payload_size: 2 << 20, - } - } - - #[cfg(feature = "tls")] - fn default() -> Self { - Options { - max_payload_size: 2 << 20, - tls_acceptor: None, - } - } -} - -impl Options { - /// Set the max payload size in bytes. The default is 2 << 20 (2 MiB). 
- pub fn max_payload_size(mut self, bytes: u64) -> Self { - self.max_payload_size = bytes; - self - } - - /// Sets the `TlsAcceptor` - #[cfg(feature = "tls")] - pub fn tls(mut self, tls_acceptor: TlsAcceptor) -> Self { - self.tls_acceptor = Some(tls_acceptor); - self - } -} - -impl fmt::Debug for Options { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - #[cfg(feature = "tls")] - const SOME: &str = "Some(_)"; - #[cfg(feature = "tls")] - const NONE: &str = "None"; - - let mut debug_struct = fmt.debug_struct("Options"); - #[cfg(feature = "tls")] - debug_struct.field( - "tls_acceptor", - if self.tls_acceptor.is_some() { - &SOME - } else { - &NONE - }, - ); - debug_struct.finish() - } -} - -/// A message from server to client. -#[doc(hidden)] -pub type Response = Result>; - -#[doc(hidden)] -pub fn listen(new_service: S, - addr: SocketAddr, - handle: &reactor::Handle, - options: Options) - -> io::Result<(Handle, Listen)> - where S: NewService, - Response = Response, - Error = io::Error> + 'static, - Req: DeserializeOwned + 'static, - Resp: Serialize + 'static, - E: Serialize + 'static -{ - let (addr, shutdown, server) = listen_with( - new_service, - addr, - handle, - options.max_payload_size, - Acceptor::from(options), - )?; - Ok(( - Handle { - addr: addr, - shutdown: shutdown, - }, - server, - )) -} - -/// Spawns a service that binds to the given address using the given handle. -fn listen_with(new_service: S, - addr: SocketAddr, - handle: &reactor::Handle, - max_payload_size: u64, - acceptor: Acceptor) - -> io::Result<(SocketAddr, Shutdown, Listen)> - where S: NewService, - Response = Response, - Error = io::Error> + 'static, - Req: DeserializeOwned + 'static, - Resp: Serialize + 'static, - E: Serialize + 'static -{ - let listener = listener(&addr, handle)?; - let addr = listener.local_addr()?; - debug!("Listening on {}.", addr); - - let handle = handle.clone(); - let (connection_tracker, shutdown, shutdown_future) = shutdown::Watcher::triple(); - let server = BindStream { - handle: handle, - new_service: connection::TrackingNewService { - connection_tracker: connection_tracker, - new_service: new_service, - }, - stream: AcceptStream { - stream: listener.incoming(), - acceptor: acceptor, - future: None, - }, - max_payload_size: max_payload_size, - }; - - let server = AlwaysOkUnit(server.select(shutdown_future)); - Ok((addr, shutdown, Listen { inner: server })) -} - -fn listener(addr: &SocketAddr, handle: &reactor::Handle) -> io::Result { - const PENDING_CONNECTION_BACKLOG: i32 = 1024; - - let builder = match *addr { - SocketAddr::V4(_) => net2::TcpBuilder::new_v4(), - SocketAddr::V6(_) => net2::TcpBuilder::new_v6(), - }?; - configure_tcp(&builder)?; - builder.reuse_address(true)?; - builder - .bind(addr)? 
- .listen(PENDING_CONNECTION_BACKLOG) - .and_then(|l| TcpListener::from_listener(l, addr, handle)) -} - -#[cfg(unix)] -fn configure_tcp(tcp: &net2::TcpBuilder) -> io::Result<()> { - use net2::unix::UnixTcpBuilderExt; - tcp.reuse_port(true)?; - Ok(()) -} - -#[cfg(windows)] -fn configure_tcp(_tcp: &net2::TcpBuilder) -> io::Result<()> { - Ok(()) -} - -struct BindStream { - handle: reactor::Handle, - new_service: connection::TrackingNewService, - stream: St, - max_payload_size: u64, -} - -impl fmt::Debug for BindStream -where - S: fmt::Debug, - St: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - f.debug_struct("BindStream") - .field("handle", &self.handle) - .field("new_service", &self.new_service) - .field("stream", &self.stream) - .finish() - } -} - -impl BindStream - where S: NewService, - Response = Response, - Error = io::Error> + 'static, - Req: DeserializeOwned + 'static, - Resp: Serialize + 'static, - E: Serialize + 'static, - I: AsyncRead + AsyncWrite + 'static, - St: Stream -{ - fn bind_each(&mut self) -> Poll<(), io::Error> { - loop { - match try!(self.stream.poll()) { - Async::Ready(Some(socket)) => { - Proto::new(self.max_payload_size).bind_server(&self.handle, - socket, - self.new_service.new_service()?); - } - Async::Ready(None) => return Ok(Async::Ready(())), - Async::NotReady => return Ok(Async::NotReady), - } - } - } -} - -impl Future for BindStream - where S: NewService, - Response = Response, - Error = io::Error> + 'static, - Req: DeserializeOwned + 'static, - Resp: Serialize + 'static, - E: Serialize + 'static, - I: AsyncRead + AsyncWrite + 'static, - St: Stream -{ - type Item = (); - type Error = (); - - fn poll(&mut self) -> Poll { - match self.bind_each() { - Ok(Async::Ready(())) => Ok(Async::Ready(())), - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => { - error!("While processing incoming connections: {}", e); - Err(()) - } - } - } -} - -/// The future representing a running server. 
-#[doc(hidden)] -pub struct Listen - where S: NewService, - Response = Response, - Error = io::Error> + 'static, - Req: DeserializeOwned + 'static, - Resp: Serialize + 'static, - E: Serialize + 'static -{ - inner: AlwaysOkUnit>, shutdown::Watcher>>, -} - -impl Future for Listen - where S: NewService, - Response = Response, - Error = io::Error> + 'static, - Req: DeserializeOwned + 'static, - Resp: Serialize + 'static, - E: Serialize + 'static -{ - type Item = (); - type Error = (); - - fn poll(&mut self) -> Poll<(), ()> { - self.inner.poll() - } -} - -impl fmt::Debug for Listen - where S: NewService, - Response = Response, - Error = io::Error> + 'static, - Req: DeserializeOwned + 'static, - Resp: Serialize + 'static, - E: Serialize + 'static -{ - fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - f.debug_struct("Listen").finish() - } -} - -#[derive(Debug)] -struct AlwaysOkUnit(F); - -impl Future for AlwaysOkUnit -where - F: Future, -{ - type Item = (); - type Error = (); - - fn poll(&mut self) -> Poll<(), ()> { - match self.0.poll() { - Ok(Async::Ready(_)) | Err(_) => Ok(Async::Ready(())), - Ok(Async::NotReady) => Ok(Async::NotReady), - } - } -} diff --git a/src/future/server/shutdown.rs b/src/future/server/shutdown.rs deleted file mode 100644 index 720201c7..00000000 --- a/src/future/server/shutdown.rs +++ /dev/null @@ -1,182 +0,0 @@ - - -use super::{AlwaysOkUnit, connection}; -use futures::{Async, Future, Poll, Stream, future as futures, stream}; -use futures::sync::{mpsc, oneshot}; -use futures::unsync; - -/// A hook to shut down a running server. -#[derive(Clone, Debug)] -pub struct Shutdown { - tx: mpsc::UnboundedSender>, -} - -/// A future that resolves when server shutdown completes. -#[derive(Debug)] -pub struct ShutdownFuture { - inner: futures::Either, AlwaysOkUnit>>, -} - -impl Future for ShutdownFuture { - type Item = (); - type Error = (); - - fn poll(&mut self) -> Poll<(), ()> { - self.inner.poll() - } -} - -impl Shutdown { - /// Initiates an orderly server shutdown. - /// - /// First, the server enters lameduck mode, in which - /// existing connections are honored but no new connections are accepted. Then, once all - /// connections are closed, it initates total shutdown. - /// - /// The returned future resolves when the server is completely shut down. 
- pub fn shutdown(&self) -> ShutdownFuture { - let (tx, rx) = oneshot::channel(); - let inner = if self.tx.unbounded_send(tx).is_err() { - trace!("Server already initiated shutdown."); - futures::Either::A(futures::ok(())) - } else { - futures::Either::B(AlwaysOkUnit(rx)) - }; - ShutdownFuture { inner: inner } - } -} - -#[derive(Debug)] -pub struct Watcher { - shutdown_rx: stream::Take>>, - connections: unsync::mpsc::UnboundedReceiver, - queued_error: Option<()>, - shutdown: Option>, - done: bool, - num_connections: u64, -} - -impl Watcher { - pub fn triple() -> (connection::Tracker, Shutdown, Self) { - let (connection_tx, connections) = connection::Tracker::pair(); - let (shutdown_tx, shutdown_rx) = mpsc::unbounded(); - ( - connection_tx, - Shutdown { tx: shutdown_tx }, - Watcher { - shutdown_rx: shutdown_rx.take(1), - connections: connections, - queued_error: None, - shutdown: None, - done: false, - num_connections: 0, - }, - ) - } - - fn process_connection(&mut self, action: connection::Action) { - match action { - connection::Action::Increment => self.num_connections += 1, - connection::Action::Decrement => self.num_connections -= 1, - } - } - - fn poll_shutdown_requests(&mut self) -> Poll, ()> { - Ok(Async::Ready(match try_ready!(self.shutdown_rx.poll()) { - Some(tx) => { - debug!("Received shutdown request."); - self.shutdown = Some(tx); - Some(()) - } - None => None, - })) - } - - fn poll_connections(&mut self) -> Poll, ()> { - Ok(Async::Ready(match try_ready!(self.connections.poll()) { - Some(action) => { - self.process_connection(action); - Some(()) - } - None => None, - })) - } - - fn poll_shutdown_requests_and_connections(&mut self) -> Poll, ()> { - if let Some(e) = self.queued_error.take() { - return Err(e); - } - - match try!(self.poll_shutdown_requests()) { - Async::NotReady => { - match try_ready!(self.poll_connections()) { - Some(()) => Ok(Async::Ready(Some(()))), - None => Ok(Async::NotReady), - } - } - Async::Ready(None) => { - match try_ready!(self.poll_connections()) { - Some(()) => Ok(Async::Ready(Some(()))), - None => Ok(Async::Ready(None)), - } - } - Async::Ready(Some(())) => { - match self.poll_connections() { - Err(e) => { - self.queued_error = Some(e); - Ok(Async::Ready(Some(()))) - } - Ok(Async::NotReady) | Ok(Async::Ready(None)) | Ok(Async::Ready(Some(()))) => { - Ok(Async::Ready(Some(()))) - } - } - } - } - } - - fn should_continue(&mut self) -> bool { - match self.shutdown.take() { - Some(shutdown) => { - debug!("Lameduck mode: {} open connections", self.num_connections); - if self.num_connections == 0 { - debug!("Shutting down."); - // Not required for the shutdown future to be waited on, so this - // can fail (which is fine). - let _ = shutdown.send(()); - false - } else { - self.shutdown = Some(shutdown); - true - } - } - None => true, - } - } - - fn process_request(&mut self) -> Poll, ()> { - if self.done { - return Ok(Async::Ready(None)); - } - if self.should_continue() { - self.poll_shutdown_requests_and_connections() - } else { - self.done = true; - Ok(Async::Ready(None)) - } - } -} - -impl Future for Watcher { - type Item = (); - type Error = (); - - fn poll(&mut self) -> Poll<(), ()> { - loop { - match try!(self.process_request()) { - Async::Ready(Some(())) => continue, - Async::Ready(None) => return Ok(Async::Ready(())), - Async::NotReady => return Ok(Async::NotReady), - } - } - } -} diff --git a/src/lib.rs b/src/lib.rs deleted file mode 100644 index 717426a7..00000000 --- a/src/lib.rs +++ /dev/null @@ -1,213 +0,0 @@ -// Copyright 2016 Google Inc. 
All Rights Reserved. -// -// Licensed under the MIT License, . -// This file may not be copied, modified, or distributed except according to those terms. - -//! tarpc is an RPC framework for rust with a focus on ease of use. Defining a -//! service can be done in just a few lines of code, and most of the boilerplate of -//! writing a server is taken care of for you. -//! -//! ## What is an RPC framework? -//! "RPC" stands for "Remote Procedure Call," a function call where the work of -//! producing the return value is being done somewhere else. When an rpc function is -//! invoked, behind the scenes the function contacts some other process somewhere -//! and asks them to evaluate the function instead. The original function then -//! returns the value produced by the other process. -//! -//! RPC frameworks are a fundamental building block of most microservices-oriented -//! architectures. Two well-known ones are [gRPC](http://www.grpc.io) and -//! [Cap'n Proto](https://capnproto.org/). -//! -//! tarpc differentiates itself from other RPC frameworks by defining the schema in code, -//! rather than in a separate language such as .proto. This means there's no separate compilation -//! process, and no cognitive context switching between different languages. Additionally, it -//! works with the community-backed library serde: any serde-serializable type can be used as -//! arguments to tarpc fns. -//! -//! Example usage: -//! -//! ``` -//! #![feature(plugin, use_extern_macros, proc_macro_path_invoc)] -//! #![plugin(tarpc_plugins)] -//! -//! #[macro_use] -//! extern crate tarpc; -//! extern crate tokio_core; -//! -//! use tarpc::sync::{client, server}; -//! use tarpc::sync::client::ClientExt; -//! use tarpc::util::Never; -//! use tokio_core::reactor; -//! use std::sync::mpsc; -//! use std::thread; -//! -//! service! { -//! rpc hello(name: String) -> String; -//! } -//! -//! #[derive(Clone)] -//! struct HelloServer; -//! -//! impl SyncService for HelloServer { -//! fn hello(&self, name: String) -> Result { -//! Ok(format!("Hello, {}!", name)) -//! } -//! } -//! -//! fn main() { -//! let (tx, rx) = mpsc::channel(); -//! thread::spawn(move || { -//! let mut handle = HelloServer.listen("localhost:10000", -//! server::Options::default()).unwrap(); -//! tx.send(handle.addr()).unwrap(); -//! handle.run(); -//! }); -//! let addr = rx.recv().unwrap(); -//! let client = SyncClient::connect(addr, client::Options::default()).unwrap(); -//! println!("{}", client.hello("Mom".to_string()).unwrap()); -//! } -//! ``` -//! -//! Example usage with TLS: -//! -//! ```no-run -//! #![feature(plugin, use_extern_macros, proc_macro_path_invoc)] -//! #![plugin(tarpc_plugins)] -//! -//! #[macro_use] -//! extern crate tarpc; -//! -//! use tarpc::sync::{client, server}; -//! use tarpc::sync::client::ClientExt; -//! use tarpc::tls; -//! use tarpc::util::Never; -//! use tarpc::native_tls::{TlsAcceptor, Pkcs12}; -//! -//! service! { -//! rpc hello(name: String) -> String; -//! } -//! -//! #[derive(Clone)] -//! struct HelloServer; -//! -//! impl SyncService for HelloServer { -//! fn hello(&self, name: String) -> Result { -//! Ok(format!("Hello, {}!", name)) -//! } -//! } -//! -//! fn get_acceptor() -> TlsAcceptor { -//! let buf = include_bytes!("test/identity.p12"); -//! let pkcs12 = Pkcs12::from_der(buf, "password").unwrap(); -//! TlsAcceptor::builder(pkcs12).unwrap().build().unwrap() -//! } -//! -//! fn main() { -//! let addr = "localhost:10000"; -//! let acceptor = get_acceptor(); -//! 
let _server = HelloServer.listen(addr, server::Options::default().tls(acceptor)); -//! let client = SyncClient::connect(addr, -//! client::Options::default() -//! .tls(tls::client::Context::new("foobar.com").unwrap())) -//! .unwrap(); -//! println!("{}", client.hello("Mom".to_string()).unwrap()); -//! } -//! ``` - -#![deny(missing_docs, missing_debug_implementations)] -#![feature(never_type)] -#![cfg_attr(test, feature(plugin))] -#![cfg_attr(test, plugin(tarpc_plugins))] - -extern crate byteorder; -extern crate bytes; -#[macro_use] -extern crate cfg_if; -#[macro_use] -extern crate lazy_static; -#[macro_use] -extern crate log; -extern crate net2; -extern crate num_cpus; -extern crate thread_pool; -extern crate tokio_codec; -extern crate tokio_io; - -#[doc(hidden)] -pub extern crate bincode; -#[doc(hidden)] -#[macro_use] -pub extern crate futures; -#[doc(hidden)] -pub extern crate serde; -#[doc(hidden)] -#[macro_use] -pub extern crate serde_derive; -#[doc(hidden)] -pub extern crate tokio_core; -#[doc(hidden)] -pub extern crate tokio_proto; -#[doc(hidden)] -pub extern crate tokio_service; - -pub use errors::Error; -#[doc(hidden)] -pub use errors::WireError; - -/// Provides some utility error types, as well as a trait for spawning futures on the default event -/// loop. -pub mod util; - -/// Provides the macro used for constructing rpc services and client stubs. -#[macro_use] -mod macros; -/// Synchronous version of the tarpc API -pub mod sync; -/// Futures-based version of the tarpc API. -pub mod future; -/// TLS-specific functionality. -#[cfg(feature = "tls")] -pub mod tls; -/// Provides implementations of `ClientProto` and `ServerProto` that implement the tarpc protocol. -/// The tarpc protocol is a length-delimited, bincode-serialized payload. -mod protocol; -/// Provides a few different error types. -mod errors; -/// Provides an abstraction over TLS and TCP streams. -mod stream_type; - -use std::sync::mpsc; -use std::thread; -use tokio_core::reactor; - -lazy_static! { - /// The `Remote` for the default reactor core. - static ref REMOTE: reactor::Remote = { - spawn_core() - }; -} - -/// Spawns a `reactor::Core` running forever on a new thread. -fn spawn_core() -> reactor::Remote { - let (tx, rx) = mpsc::channel(); - thread::spawn(move || { - let mut core = reactor::Core::new().unwrap(); - tx.send(core.handle().remote().clone()).unwrap(); - - // Run forever - core.run(futures::empty::<(), !>()).unwrap(); - }); - rx.recv().unwrap() -} - -cfg_if! { - if #[cfg(feature = "tls")] { - extern crate tokio_tls; - extern crate native_tls as native_tls_inner; - - /// Re-exported TLS-related types from the `native_tls` crate. - pub mod native_tls { - pub use native_tls_inner::{Error, Pkcs12, TlsAcceptor, TlsConnector}; - } - } else {} -} diff --git a/src/macros.rs b/src/macros.rs deleted file mode 100644 index 98267837..00000000 --- a/src/macros.rs +++ /dev/null @@ -1,1191 +0,0 @@ -// Copyright 2016 Google Inc. All Rights Reserved. -// -// Licensed under the MIT License, . -// This file may not be copied, modified, or distributed except according to those terms. - -#[doc(hidden)] -#[macro_export] -macro_rules! as_item { - ($i:item) => {$i}; -} - -/// The main macro that creates RPC services. -/// -/// Rpc methods are specified, mirroring trait syntax: -/// -/// ``` -/// # #![feature(plugin, use_extern_macros, proc_macro_path_invoc)] -/// # #![plugin(tarpc_plugins)] -/// # #[macro_use] extern crate tarpc; -/// # fn main() {} -/// # service! 
{ -/// /// Say hello -/// rpc hello(name: String) -> String; -/// # } -/// ``` -/// -/// Attributes can be attached to each rpc. These attributes -/// will then be attached to the generated service traits' -/// corresponding `fn`s, as well as to the client stubs' RPCs. -/// -/// The following items are expanded in the enclosing module: -/// -/// * `FutureService` -- the trait defining the RPC service via a `Future` API. -/// * `SyncService` -- a service trait that provides a synchronous API for when -/// spawning a thread per request is acceptable. -/// * `FutureServiceExt` -- provides the methods for starting a service. There is an umbrella impl -/// for all implers of `FutureService`. It's a separate trait to prevent -/// name collisions with RPCs. -/// * `SyncServiceExt` -- same as `FutureServiceExt` but for `SyncService`. -/// * `FutureClient` -- a client whose RPCs return `Future`s. -/// * `SyncClient` -- a client whose RPCs block until the reply is available. Easiest -/// interface to use, as it looks the same as a regular function call. -/// -#[macro_export] -macro_rules! service { -// Entry point - ( - $( - $(#[$attr:meta])* - rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ) $(-> $out:ty)* $(| $error:ty)*; - )* - ) => { - service! {{ - $( - $(#[$attr])* - rpc $fn_name( $( $arg : $in_ ),* ) $(-> $out)* $(| $error)*; - )* - }} - }; -// Pattern for when the next rpc has an implicit unit return type and no error type. - ( - { - $(#[$attr:meta])* - rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ); - - $( $unexpanded:tt )* - } - $( $expanded:tt )* - ) => { - service! { - { $( $unexpanded )* } - - $( $expanded )* - - $(#[$attr])* - rpc $fn_name( $( $arg : $in_ ),* ) -> () | $crate::util::Never; - } - }; -// Pattern for when the next rpc has an explicit return type and no error type. - ( - { - $(#[$attr:meta])* - rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty; - - $( $unexpanded:tt )* - } - $( $expanded:tt )* - ) => { - service! { - { $( $unexpanded )* } - - $( $expanded )* - - $(#[$attr])* - rpc $fn_name( $( $arg : $in_ ),* ) -> $out | $crate::util::Never; - } - }; -// Pattern for when the next rpc has an implicit unit return type and an explicit error type. - ( - { - $(#[$attr:meta])* - rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ) | $error:ty; - - $( $unexpanded:tt )* - } - $( $expanded:tt )* - ) => { - service! { - { $( $unexpanded )* } - - $( $expanded )* - - $(#[$attr])* - rpc $fn_name( $( $arg : $in_ ),* ) -> () | $error; - } - }; -// Pattern for when the next rpc has an explicit return type and an explicit error type. - ( - { - $(#[$attr:meta])* - rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty | $error:ty; - - $( $unexpanded:tt )* - } - $( $expanded:tt )* - ) => { - service! 
{ - { $( $unexpanded )* } - - $( $expanded )* - - $(#[$attr])* - rpc $fn_name( $( $arg : $in_ ),* ) -> $out | $error; - } - }; -// Pattern for when all return types have been expanded - ( - { } // none left to expand - $( - $(#[$attr:meta])* - rpc $fn_name:ident ( $( $arg:ident : $in_:ty ),* ) -> $out:ty | $error:ty; - )* - ) => { - - #[doc(hidden)] - #[allow(non_camel_case_types, unused)] - #[derive($crate::serde_derive::Serialize, $crate::serde_derive::Deserialize)] - pub enum Request__ { - $( - $fn_name{ $($arg: $in_,)* } - ),* - } - - #[doc(hidden)] - #[allow(non_camel_case_types, unused)] - #[derive($crate::serde_derive::Serialize, $crate::serde_derive::Deserialize)] - pub enum Response__ { - $( - $fn_name($out) - ),* - } - - #[doc(hidden)] - #[allow(non_camel_case_types, unused)] - #[derive(Debug, $crate::serde_derive::Deserialize, $crate::serde_derive::Serialize)] - pub enum Error__ { - $( - $fn_name($error) - ),* - } - -/// Defines the `Future` RPC service. Implementors must be `Clone` and `'static`, -/// as required by `tokio_proto::NewService`. This is required so that the service can be used -/// to respond to multiple requests concurrently. - pub trait FutureService: - ::std::clone::Clone + - 'static - { - $( - snake_to_camel! { - /// The type of future returned by `{}`. - type $fn_name: $crate::futures::IntoFuture; - } - - $(#[$attr])* - fn $fn_name(&self, $($arg:$in_),*) -> ty_snake_to_camel!(Self::$fn_name); - )* - } - - #[allow(non_camel_case_types)] - #[derive(Clone)] - struct TarpcNewService(S); - - impl ::std::fmt::Debug for TarpcNewService { - fn fmt(&self, fmt: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { - fmt.debug_struct("TarpcNewService").finish() - } - } - - #[allow(non_camel_case_types)] - type ResponseFuture__ = - $crate::futures::Finished<$crate::future::server::Response, - ::std::io::Error>; - - #[allow(non_camel_case_types)] - enum FutureReply__ { - DeserializeError(ResponseFuture__), - $($fn_name( - $crate::futures::Then< - ::Future, - ResponseFuture__, - fn(::std::result::Result<$out, $error>) -> ResponseFuture__>)),* - } - - impl $crate::futures::Future for FutureReply__ { - type Item = $crate::future::server::Response; - type Error = ::std::io::Error; - - fn poll(&mut self) -> $crate::futures::Poll { - match *self { - FutureReply__::DeserializeError(ref mut future__) => { - $crate::futures::Future::poll(future__) - } - $( - FutureReply__::$fn_name(ref mut future__) => { - $crate::futures::Future::poll(future__) - } - ),* - } - } - } - - - #[allow(non_camel_case_types)] - impl $crate::tokio_service::Service for TarpcNewService - where S__: FutureService - { - type Request = ::std::result::Result; - type Response = $crate::future::server::Response; - type Error = ::std::io::Error; - type Future = FutureReply__; - - fn call(&self, request__: Self::Request) -> Self::Future { - let request__ = match request__ { - Ok(request__) => request__, - Err(err__) => { - return FutureReply__::DeserializeError( - $crate::futures::finished( - ::std::result::Result::Err( - $crate::WireError::RequestDeserialize( - ::std::string::ToString::to_string(&err__))))); - } - }; - #[allow(unreachable_patterns)] - match request__ { - $( - Request__::$fn_name{ $($arg,)* } => { - fn wrap__(response__: ::std::result::Result<$out, $error>) - -> ResponseFuture__ - { - $crate::futures::finished( - response__ - .map(Response__::$fn_name) - .map_err(|err__| { - $crate::WireError::App(Error__::$fn_name(err__)) - }) - ) - } - return FutureReply__::$fn_name( - 
$crate::futures::Future::then( - $crate::futures::IntoFuture::into_future( - FutureService::$fn_name(&self.0, $($arg),*)), - wrap__)); - } - )* - _ => unreachable!(), - } - } - } - - #[allow(non_camel_case_types)] - impl $crate::tokio_service::NewService - for TarpcNewService - where S__: FutureService - { - type Request = ::Request; - type Response = ::Response; - type Error = ::Error; - type Instance = Self; - - fn new_service(&self) -> ::std::io::Result { - Ok(self.clone()) - } - } - - /// The future returned by `FutureServiceExt::listen`. - #[allow(unused)] - pub struct Listen - where S: FutureService, - { - inner: $crate::future::server::Listen, - Request__, - Response__, - Error__>, - } - - impl $crate::futures::Future for Listen - where S: FutureService - { - type Item = (); - type Error = (); - - fn poll(&mut self) -> $crate::futures::Poll<(), ()> { - self.inner.poll() - } - } - - /// Provides a function for starting the service. This is a separate trait from - /// `FutureService` to prevent collisions with the names of RPCs. - pub trait FutureServiceExt: FutureService { - /// Spawns the service, binding to the given address and running on - /// the `reactor::Core` associated with `handle`. - /// - /// Returns the address being listened on as well as the server future. The future - /// must be executed for the server to run. - fn listen(self, - addr: ::std::net::SocketAddr, - handle: &$crate::tokio_core::reactor::Handle, - options: $crate::future::server::Options) - -> ::std::io::Result<($crate::future::server::Handle, Listen)> - { - $crate::future::server::listen(TarpcNewService(self), - addr, - handle, - options) - .map(|(handle, inner)| (handle, Listen { inner })) - } - } - - /// Defines the blocking RPC service. Must be `Clone`, `Send`, and `'static`, -/// as required by `tokio_proto::NewService`. This is required so that the service can be used -/// to respond to multiple requests concurrently. - pub trait SyncService: - ::std::marker::Send + - ::std::clone::Clone + - 'static - { - $( - $(#[$attr])* - fn $fn_name(&self, $($arg:$in_),*) -> ::std::result::Result<$out, $error>; - )* - } - - /// Provides a function for starting the service. This is a separate trait from - /// `SyncService` to prevent collisions with the names of RPCs. - pub trait SyncServiceExt: SyncService { - /// Spawns the service, binding to the given address and returning the server handle. - /// - /// To actually run the server, call `run` on the returned handle. - fn listen(self, addr: A, options: $crate::sync::server::Options) - -> ::std::io::Result<$crate::sync::server::Handle> - where A: ::std::net::ToSocketAddrs - { - #[derive(Clone)] - struct BlockingFutureService(S); - impl FutureService for BlockingFutureService { - $( - impl_snake_to_camel! 
{ - type $fn_name = - $crate::util::Lazy< - fn((S, $($in_),*)) -> ::std::result::Result<$out, $error>, - (S, $($in_),*), - ::std::result::Result<$out, $error>>; - } - - $(#[$attr])* - fn $fn_name(&self, $($arg:$in_),*) - -> $crate::util::Lazy< - fn((S, $($in_),*)) -> ::std::result::Result<$out, $error>, - (S, $($in_),*), - ::std::result::Result<$out, $error>> { - fn execute((s, $($arg),*): (S, $($in_),*)) - -> ::std::result::Result<$out, $error> { - SyncService::$fn_name(&s, $($arg),*) - } - $crate::util::lazy(execute, (self.0.clone(), $($arg),*)) - } - )* - } - - let tarpc_service__ = TarpcNewService(BlockingFutureService(self)); - let addr__ = $crate::util::FirstSocketAddr::try_first_socket_addr(&addr)?; - return $crate::sync::server::listen(tarpc_service__, addr__, options); - } - } - - impl FutureServiceExt for A where A: FutureService {} - impl SyncServiceExt for S where S: SyncService {} - - /// The client stub that makes RPC calls to the server. Exposes a blocking interface. - #[allow(unused)] - #[derive(Clone, Debug)] - pub struct SyncClient { - inner: SyncClient__, - } - - impl $crate::sync::client::ClientExt for SyncClient { - fn connect(addr_: A, options_: $crate::sync::client::Options) - -> ::std::io::Result - where A: ::std::net::ToSocketAddrs, - { - let client_ = - ::connect(addr_, options_)?; - ::std::result::Result::Ok(SyncClient { - inner: client_, - }) - } - } - - impl SyncClient { - $( - #[allow(unused)] - $(#[$attr])* - pub fn $fn_name(&self, $($arg: $in_),*) - -> ::std::result::Result<$out, $crate::Error<$error>> - { - tarpc_service_then__!($out, $error, $fn_name); - let resp__ = self.inner.call(Request__::$fn_name { $($arg,)* }); - tarpc_service_then__(resp__) - } - )* - } - - #[allow(non_camel_case_types)] - type FutureClient__ = $crate::future::client::Client; - - #[allow(non_camel_case_types)] - type SyncClient__ = $crate::sync::client::Client; - - #[allow(non_camel_case_types)] - /// A future representing a client connecting to a server. - pub struct Connect { - inner: - $crate::futures::Map< - $crate::future::client::ConnectFuture< Request__, Response__, Error__>, - fn(FutureClient__) -> T>, - } - - impl $crate::futures::Future for Connect { - type Item = T; - type Error = ::std::io::Error; - - fn poll(&mut self) -> $crate::futures::Poll { - $crate::futures::Future::poll(&mut self.inner) - } - } - - #[allow(unused)] - #[derive(Clone, Debug)] - /// The client stub that makes RPC calls to the server. Exposes a Future interface. - pub struct FutureClient(FutureClient__); - - impl<'a> $crate::future::client::ClientExt for FutureClient { - type ConnectFut = Connect; - - fn connect(addr__: ::std::net::SocketAddr, - options__: $crate::future::client::Options) - -> Self::ConnectFut - { - let client = - ::connect(addr__, - options__); - - Connect { - inner: $crate::futures::Future::map(client, FutureClient) - } - } - } - - impl FutureClient { - $( - #[allow(unused)] - $(#[$attr])* - pub fn $fn_name(&self, $($arg: $in_),*) - -> $crate::futures::future::Then< - ::Future, - ::std::result::Result<$out, $crate::Error<$error>>, - fn(::std::result::Result>) - -> ::std::result::Result<$out, $crate::Error<$error>>> { - tarpc_service_then__!($out, $error, $fn_name); - - let request__ = Request__::$fn_name { $($arg,)* }; - let future__ = $crate::tokio_service::Service::call(&self.0, request__); - return $crate::futures::Future::then(future__, tarpc_service_then__); - } - )* - } - } -} - -#[doc(hidden)] -#[macro_export] -macro_rules! 
tarpc_service_then__ { - ($out:ty, $error:ty, $fn_name:ident) => { - fn tarpc_service_then__(msg__: - ::std::result::Result>) - -> ::std::result::Result<$out, $crate::Error<$error>> { - match msg__ { - ::std::result::Result::Ok(msg__) => { - #[allow(unreachable_patterns)] - match msg__ { - Response__::$fn_name(msg__) => - ::std::result::Result::Ok(msg__), - _ => unreachable!(), - } - } - ::std::result::Result::Err(err__) => { - ::std::result::Result::Err(match err__ { - $crate::Error::App(err__) => { - #[allow(unreachable_patterns)] - match err__ { - Error__::$fn_name(err__) => - $crate::Error::App(err__), - _ => unreachable!(), - } - } - $crate::Error::RequestDeserialize(err__) => { - $crate::Error::RequestDeserialize(err__) - } - $crate::Error::ResponseDeserialize(err__) => { - $crate::Error::ResponseDeserialize(err__) - } - $crate::Error::Io(err__) => { - $crate::Error::Io(err__) - } - }) - } - } - } - }; -} - -// allow dead code; we're just testing that the macro expansion compiles -#[allow(dead_code)] -#[cfg(test)] -mod syntax_test { - use util::Never; - - service! { - #[deny(warnings)] - #[allow(non_snake_case)] - rpc TestCamelCaseDoesntConflict(); - rpc hello() -> String; - #[doc="attr"] - rpc attr(s: String) -> String; - rpc no_args_no_return(); - rpc no_args() -> (); - rpc one_arg(foo: String) -> i32; - rpc two_args_no_return(bar: String, baz: u64); - rpc two_args(bar: String, baz: u64) -> String; - rpc no_args_ret_error() -> i32 | Never; - rpc one_arg_ret_error(foo: String) -> String | Never; - rpc no_arg_implicit_return_error() | Never; - #[doc="attr"] - rpc one_arg_implicit_return_error(foo: String) | Never; - } -} - -#[cfg(test)] -mod functional_test { - use {sync, future}; - use futures::{Future, failed}; - use std::io; - use std::net::SocketAddr; - use tokio_core::reactor; - use util::FirstSocketAddr; - extern crate env_logger; - - macro_rules! unwrap { - ($e:expr) => (match $e { - Ok(e) => e, - Err(e) => panic!("{} failed with {:?}", stringify!($e), e), - }) - } - - service! { - rpc add(x: i32, y: i32) -> i32; - rpc hey(name: String) -> String; - } - - cfg_if! { - if #[cfg(feature = "tls")] { - const DOMAIN: &str = "foobar.com"; - - use tls::client::Context; - use native_tls::{Pkcs12, TlsAcceptor, TlsConnector}; - - fn get_tls_acceptor() -> TlsAcceptor { - let buf = include_bytes!("../test/identity.p12"); - let pkcs12 = unwrap!(Pkcs12::from_der(buf, "mypass")); - unwrap!(unwrap!(TlsAcceptor::builder(pkcs12)).build()) - } - - fn get_future_tls_server_options() -> future::server::Options { - future::server::Options::default().tls(get_tls_acceptor()) - } - - fn get_sync_tls_server_options() -> sync::server::Options { - sync::server::Options::default().tls(get_tls_acceptor()) - } - - // Making the TlsConnector for testing needs to be OS-dependent just like native-tls. - // We need to go through this trickery because the test self-signed cert is not part - // of the system's cert chain. If it was, then all that is required is - // `TlsConnector::builder().unwrap().build().unwrap()`. - cfg_if! 
{ - if #[cfg(target_os = "macos")] { - extern crate security_framework; - - use native_tls_inner::Certificate; - - fn get_future_tls_client_options() -> future::client::Options { - future::client::Options::default().tls(get_tls_client_context()) - } - - fn get_sync_tls_client_options() -> sync::client::Options { - sync::client::Options::default().tls(get_tls_client_context()) - } - - fn get_tls_client_context() -> Context { - let buf = include_bytes!("../test/root-ca.der"); - let cert = unwrap!(Certificate::from_der(buf)); - let mut connector = unwrap!(TlsConnector::builder()); - connector.add_root_certificate(cert).unwrap(); - - Context { - domain: DOMAIN.into(), - tls_connector: unwrap!(connector.build()), - } - } - } else if #[cfg(all(not(target_os = "macos"), not(windows)))] { - use native_tls_inner::backend::openssl::TlsConnectorBuilderExt; - - fn get_sync_tls_client_options() -> sync::client::Options { - sync::client::Options::default() - .tls(get_tls_client_context()) - } - - fn get_future_tls_client_options() -> future::client::Options { - future::client::Options::default() - .tls(get_tls_client_context()) - } - - fn get_tls_client_context() -> Context { - let mut connector = unwrap!(TlsConnector::builder()); - unwrap!(connector.builder_mut() - .set_ca_file("test/root-ca.pem")); - Context { - domain: DOMAIN.into(), - tls_connector: unwrap!(connector.build()), - } - } - // not implemented for windows or other platforms - } else { - fn get_tls_client_context() -> Context { - unimplemented!() - } - } - } - - fn get_sync_client(addr: SocketAddr) -> io::Result - where C: sync::client::ClientExt - { - C::connect(addr, get_sync_tls_client_options()) - } - - fn get_future_client(addr: SocketAddr, handle: reactor::Handle) -> C::ConnectFut - where C: future::client::ClientExt - { - C::connect(addr, get_future_tls_client_options().handle(handle)) - } - - fn start_server_with_sync_client(server: S) - -> io::Result<(SocketAddr, C, future::server::Shutdown)> - where C: sync::client::ClientExt, S: SyncServiceExt - { - let options = get_sync_tls_server_options(); - let (tx, rx) = ::std::sync::mpsc::channel(); - ::std::thread::spawn(move || { - let handle = unwrap!(server.listen("localhost:0".first_socket_addr(), - options)); - tx.send((handle.addr(), handle.shutdown())).unwrap(); - handle.run(); - }); - let (addr, shutdown) = rx.recv().unwrap(); - let client = unwrap!(C::connect(addr, get_sync_tls_client_options())); - Ok((addr, client, shutdown)) - } - - fn start_server_with_async_client(server: S) - -> io::Result<(future::server::Handle, reactor::Core, C)> - where C: future::client::ClientExt, S: FutureServiceExt - { - let mut reactor = reactor::Core::new()?; - let server_options = get_future_tls_server_options(); - let (handle, server) = server.listen("localhost:0".first_socket_addr(), - &reactor.handle(), - server_options)?; - reactor.handle().spawn(server); - let client_options = get_future_tls_client_options().handle(reactor.handle()); - let client = unwrap!(reactor.run(C::connect(handle.addr(), client_options))); - Ok((handle, reactor, client)) - } - - fn return_server(server: S) - -> io::Result<(future::server::Handle, reactor::Core, Listen)> - where S: FutureServiceExt - { - let reactor = reactor::Core::new()?; - let server_options = get_future_tls_server_options(); - let (handle, server) = server.listen("localhost:0".first_socket_addr(), - &reactor.handle(), - server_options)?; - Ok((handle, reactor, server)) - } - - fn start_err_server_with_async_client(server: S) - -> 
io::Result<(future::server::Handle, reactor::Core, C)> - where C: future::client::ClientExt, S: error_service::FutureServiceExt - { - let mut reactor = reactor::Core::new()?; - let server_options = get_future_tls_server_options(); - let (handle, server) = server.listen("localhost:0".first_socket_addr(), - &reactor.handle(), - server_options)?; - reactor.handle().spawn(server); - let client_options = get_future_tls_client_options().handle(reactor.handle()); - let client = unwrap!(reactor.run(C::connect(handle.addr(), client_options))); - Ok((handle, reactor, client)) - } - } else { - fn get_future_server_options() -> future::server::Options { - future::server::Options::default() - } - - fn get_sync_server_options() -> sync::server::Options { - sync::server::Options::default() - } - - fn get_future_client_options() -> future::client::Options { - future::client::Options::default() - } - - fn get_sync_client_options() -> sync::client::Options { - sync::client::Options::default() - } - - fn get_sync_client(addr: SocketAddr) -> io::Result - where C: sync::client::ClientExt - { - C::connect(addr, get_sync_client_options()) - } - - fn get_future_client(addr: SocketAddr, handle: reactor::Handle) -> C::ConnectFut - where C: future::client::ClientExt - { - C::connect(addr, get_future_client_options().handle(handle)) - } - - fn start_server_with_sync_client(server: S) - -> io::Result<(SocketAddr, C, future::server::Shutdown)> - where C: sync::client::ClientExt, S: SyncServiceExt - { - let options = get_sync_server_options(); - let (tx, rx) = ::std::sync::mpsc::channel(); - ::std::thread::spawn(move || { - let handle = unwrap!(server.listen("localhost:0".first_socket_addr(), options)); - tx.send((handle.addr(), handle.shutdown())).unwrap(); - handle.run(); - }); - let (addr, shutdown) = rx.recv().unwrap(); - let client = unwrap!(get_sync_client(addr)); - Ok((addr, client, shutdown)) - } - - fn start_server_with_async_client(server: S) - -> io::Result<(future::server::Handle, reactor::Core, C)> - where C: future::client::ClientExt, S: FutureServiceExt - { - let mut reactor = reactor::Core::new()?; - let options = get_future_server_options(); - let (handle, server) = server.listen("localhost:0".first_socket_addr(), - &reactor.handle(), - options)?; - reactor.handle().spawn(server); - let client = unwrap!(reactor.run(C::connect(handle.addr(), - get_future_client_options()))); - Ok((handle, reactor, client)) - } - - fn return_server(server: S) - -> io::Result<(future::server::Handle, reactor::Core, Listen)> - where S: FutureServiceExt - { - let reactor = reactor::Core::new()?; - let options = get_future_server_options(); - let (handle, server) = server.listen("localhost:0".first_socket_addr(), - &reactor.handle(), - options)?; - Ok((handle, reactor, server)) - } - - fn start_err_server_with_async_client(server: S) - -> io::Result<(future::server::Handle, reactor::Core, C)> - where C: future::client::ClientExt, S: error_service::FutureServiceExt - { - let mut reactor = reactor::Core::new()?; - let options = get_future_server_options(); - let (handle, server) = server.listen("localhost:0".first_socket_addr(), - &reactor.handle(), - options)?; - reactor.handle().spawn(server); - let client = C::connect(handle.addr(), get_future_client_options()); - let client = unwrap!(reactor.run(client)); - Ok((handle, reactor, client)) - } - } - } - - mod sync_tests { - use super::{SyncClient, SyncService, get_sync_client, env_logger, - start_server_with_sync_client}; - use util::Never; - - #[derive(Clone, Copy)] - struct 
Server; - - impl SyncService for Server { - fn add(&self, x: i32, y: i32) -> Result { - Ok(x + y) - } - fn hey(&self, name: String) -> Result { - Ok(format!("Hey, {}.", name)) - } - } - - #[test] - fn simple() { - let _ = env_logger::try_init(); - let (_, client, _) = - unwrap!(start_server_with_sync_client::(Server)); - assert_eq!(3, client.add(1, 2).unwrap()); - assert_eq!("Hey, Tim.", client.hey("Tim".to_string()).unwrap()); - } - - #[test] - fn shutdown() { - use futures::{Async, Future}; - - let _ = env_logger::try_init(); - let (addr, client, shutdown) = - unwrap!(start_server_with_sync_client::(Server)); - assert_eq!(3, unwrap!(client.add(1, 2))); - assert_eq!("Hey, Tim.", unwrap!(client.hey("Tim".to_string()))); - - info!("Dropping client."); - drop(client); - let (tx, rx) = ::std::sync::mpsc::channel(); - let (tx2, rx2) = ::std::sync::mpsc::channel(); - let shutdown2 = shutdown.clone(); - ::std::thread::spawn(move || { - let client = unwrap!(get_sync_client::(addr)); - let add = unwrap!(client.add(3, 2)); - unwrap!(tx.send(())); - drop(client); - // Make sure 2 shutdowns are concurrent safe. - unwrap!(shutdown2.shutdown().wait()); - unwrap!(tx2.send(add)); - }); - unwrap!(rx.recv()); - let mut shutdown1 = shutdown.shutdown(); - unwrap!(shutdown.shutdown().wait()); - // Assert shutdown2 blocks until shutdown is complete. - if let Async::NotReady = unwrap!(shutdown1.poll()) { - panic!("Shutdown should have completed"); - } - // Existing clients are served - assert_eq!(5, unwrap!(rx2.recv())); - - let e = get_sync_client::(addr).err().unwrap(); - debug!("(Success) shutdown caused client err: {}", e); - } - - #[test] - fn no_shutdown() { - let _ = env_logger::try_init(); - let (addr, client, shutdown) = - unwrap!(start_server_with_sync_client::(Server)); - assert_eq!(3, client.add(1, 2).unwrap()); - assert_eq!("Hey, Tim.", client.hey("Tim".to_string()).unwrap()); - - drop(shutdown); - - // Existing clients are served. - assert_eq!(3, client.add(1, 2).unwrap()); - // New connections are accepted. - assert!(get_sync_client::(addr).is_ok()); - } - - #[test] - fn other_service() { - let _ = env_logger::try_init(); - let (_, client, _) = unwrap!(start_server_with_sync_client::< - super::other_service::SyncClient, - Server, - >(Server)); - match client.foo().err().expect("failed unwrap") { - ::Error::RequestDeserialize(_) => {} // good - bad => panic!("Expected Error::RequestDeserialize but got {}", bad), - } - } - } - - mod bad_serialize { - use serde::{Serialize, Serializer}; - use serde::ser::SerializeSeq; - use sync::{client, server}; - use sync::client::ClientExt; - - #[derive(Deserialize)] - pub struct Bad; - - impl Serialize for Bad { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - serializer.serialize_seq(None)?.end() - } - } - - service! 
{ - rpc bad(bad: Bad) | (); - } - - impl SyncService for () { - fn bad(&self, _: Bad) -> Result<(), ()> { - Ok(()) - } - } - - #[test] - fn bad_serialize() { - let handle = () - .listen("localhost:0", server::Options::default()) - .unwrap(); - let client = SyncClient::connect(handle.addr(), client::Options::default()).unwrap(); - client.bad(Bad).err().unwrap(); - } - } - - mod future_tests { - use super::{FutureClient, FutureService, env_logger, get_future_client, return_server, - start_server_with_async_client}; - use futures::{Finished, finished}; - use tokio_core::reactor; - use util::Never; - - #[derive(Clone)] - struct Server; - - impl FutureService for Server { - type AddFut = Finished; - - fn add(&self, x: i32, y: i32) -> Self::AddFut { - finished(x + y) - } - - type HeyFut = Finished; - - fn hey(&self, name: String) -> Self::HeyFut { - finished(format!("Hey, {}.", name)) - } - } - - #[test] - fn simple() { - let _ = env_logger::try_init(); - let (_, mut reactor, client) = unwrap!( - start_server_with_async_client::(Server) - ); - assert_eq!(3, reactor.run(client.add(1, 2)).unwrap()); - assert_eq!( - "Hey, Tim.", - reactor.run(client.hey("Tim".to_string())).unwrap() - ); - } - - #[test] - fn shutdown() { - use futures::Future; - use tokio_core::reactor; - - let _ = env_logger::try_init(); - let (handle, mut reactor, server) = unwrap!(return_server::(Server)); - - let (tx, rx) = ::std::sync::mpsc::channel(); - ::std::thread::spawn(move || { - let mut reactor = reactor::Core::new().unwrap(); - let client = get_future_client::(handle.addr(), reactor.handle()); - let client = reactor.run(client).unwrap(); - let add = reactor.run(client.add(3, 2)).unwrap(); - assert_eq!(add, 5); - trace!("Dropping client."); - drop(reactor); - debug!("Shutting down..."); - handle.shutdown().shutdown().wait().unwrap(); - tx.send(add).unwrap(); - }); - reactor.run(server).unwrap(); - assert_eq!(rx.recv().unwrap(), 5); - } - - #[test] - fn concurrent() { - let _ = env_logger::try_init(); - let (_, mut reactor, client) = unwrap!( - start_server_with_async_client::(Server) - ); - let req1 = client.add(1, 2); - let req2 = client.add(3, 4); - let req3 = client.hey("Tim".to_string()); - assert_eq!(3, reactor.run(req1).unwrap()); - assert_eq!(7, reactor.run(req2).unwrap()); - assert_eq!("Hey, Tim.", reactor.run(req3).unwrap()); - } - - #[test] - fn other_service() { - let _ = env_logger::try_init(); - let (_, mut reactor, client) = unwrap!(start_server_with_async_client::< - super::other_service::FutureClient, - Server, - >(Server)); - match reactor.run(client.foo()).err().unwrap() { - ::Error::RequestDeserialize(_) => {} // good - bad => panic!(r#"Expected Error::RequestDeserialize but got "{}""#, bad), - } - } - - #[test] - fn reuse_addr() { - use util::FirstSocketAddr; - use future::server; - use super::FutureServiceExt; - - let _ = env_logger::try_init(); - let reactor = reactor::Core::new().unwrap(); - let handle = Server - .listen( - "localhost:0".first_socket_addr(), - &reactor.handle(), - server::Options::default(), - ) - .unwrap() - .0; - Server - .listen(handle.addr(), &reactor.handle(), server::Options::default()) - .unwrap(); - } - - #[test] - fn drop_client() { - use future::{client, server}; - use future::client::ClientExt; - use util::FirstSocketAddr; - use super::{FutureClient, FutureServiceExt}; - - let _ = env_logger::try_init(); - let mut reactor = reactor::Core::new().unwrap(); - let (handle, server) = Server - .listen( - "localhost:0".first_socket_addr(), - &reactor.handle(), - 
server::Options::default(), - ) - .unwrap(); - reactor.handle().spawn(server); - - let client = FutureClient::connect( - handle.addr(), - client::Options::default().handle(reactor.handle()), - ); - let client = unwrap!(reactor.run(client)); - assert_eq!(reactor.run(client.add(1, 2)).unwrap(), 3); - drop(client); - - let client = FutureClient::connect( - handle.addr(), - client::Options::default().handle(reactor.handle()), - ); - let client = unwrap!(reactor.run(client)); - assert_eq!(reactor.run(client.add(1, 2)).unwrap(), 3); - } - - #[cfg(feature = "tls")] - #[test] - fn tcp_and_tls() { - use future::{client, server}; - use util::FirstSocketAddr; - use future::client::ClientExt; - use super::FutureServiceExt; - - let _ = env_logger::try_init(); - let (_, mut reactor, client) = unwrap!( - start_server_with_async_client::(Server) - ); - assert_eq!(3, reactor.run(client.add(1, 2)).unwrap()); - assert_eq!( - "Hey, Tim.", - reactor.run(client.hey("Tim".to_string())).unwrap() - ); - - let (handle, server) = Server - .listen( - "localhost:0".first_socket_addr(), - &reactor.handle(), - server::Options::default(), - ) - .unwrap(); - reactor.handle().spawn(server); - let options = client::Options::default().handle(reactor.handle()); - let client = reactor - .run(FutureClient::connect(handle.addr(), options)) - .unwrap(); - assert_eq!(3, reactor.run(client.add(1, 2)).unwrap()); - assert_eq!( - "Hey, Tim.", - reactor.run(client.hey("Tim".to_string())).unwrap() - ); - } - } - - pub mod error_service { - service! { - rpc bar() -> u32 | ::util::Message; - } - } - - #[derive(Clone)] - struct ErrorServer; - - impl error_service::FutureService for ErrorServer { - type BarFut = ::futures::Failed; - - fn bar(&self) -> Self::BarFut { - info!("Called bar"); - failed("lol jk".into()) - } - } - - #[test] - fn error() { - use std::error::Error as E; - use self::error_service::*; - let _ = env_logger::try_init(); - - let (_, mut reactor, client) = - start_err_server_with_async_client::(ErrorServer).unwrap(); - reactor - .run(client.bar().then(move |result| { - match result.err().unwrap() { - ::Error::App(e) => { - assert_eq!(e.description(), "lol jk"); - Ok::<_, ()>(()) - } // good - bad => panic!("Expected Error::App but got {:?}", bad), - } - })) - .unwrap(); - } - - pub mod other_service { - service! { - rpc foo(); - } - } -} diff --git a/src/plugins/src/lib.rs b/src/plugins/src/lib.rs deleted file mode 100644 index cec47295..00000000 --- a/src/plugins/src/lib.rs +++ /dev/null @@ -1,197 +0,0 @@ -#![feature(plugin_registrar, rustc_private)] - -extern crate itertools; -extern crate rustc_plugin; -extern crate smallvec; -extern crate syntax; - -use itertools::Itertools; -use rustc_plugin::Registry; -use smallvec::SmallVec; -use syntax::ast::{self, Ident, TraitRef, Ty, TyKind}; -use syntax::ext::base::{ExtCtxt, MacResult, DummyResult, MacEager}; -use syntax::ext::quote::rt::Span; -use syntax::parse::{self, token, str_lit, PResult}; -use syntax::parse::parser::{Parser, PathStyle}; -use syntax::symbol::Symbol; -use syntax::ptr::P; -use syntax::tokenstream::{TokenTree, TokenStream}; - -fn snake_to_camel(cx: &mut ExtCtxt, sp: Span, tts: &[TokenTree]) -> Box { - let mut parser = parse::new_parser_from_tts(cx.parse_sess(), tts.into()); - // The `expand_expr` method is called so that any macro calls in the - // parsed expression are expanded. 
- - let mut item = match parser.parse_trait_item(&mut false) { - Ok(s) => s, - Err(mut diagnostic) => { - diagnostic.emit(); - return DummyResult::any(sp); - } - }; - - if let Err(mut diagnostic) = parser.expect(&token::Eof) { - diagnostic.emit(); - return DummyResult::any(sp); - } - - let old_ident = convert(&mut item.ident); - - // As far as I know, it's not possible in macro_rules! to reference an $ident in a doc string, - // so this is the hacky workaround. - // - // This code looks intimidating, but it's just iterating through the trait item's attributes - // copying non-doc attributes, and modifying doc attributes such that replacing any {} in the - // doc string instead holds the original, snake_case ident. - let attrs: Vec<_> = item.attrs - .drain(..) - .map(|mut attr| { - if !attr.is_sugared_doc { - return attr; - } - - // Getting at the underlying doc comment is surprisingly painful. - // The call-chain goes something like: - // - // - https://github.com/rust-lang/rust/blob/9c15de4fd59bee290848b5443c7e194fd5afb02c/src/libsyntax/attr.rs#L283 - // - https://github.com/rust-lang/rust/blob/9c15de4fd59bee290848b5443c7e194fd5afb02c/src/libsyntax/attr.rs#L1067 - // - https://github.com/rust-lang/rust/blob/9c15de4fd59bee290848b5443c7e194fd5afb02c/src/libsyntax/attr.rs#L1196 - // - https://github.com/rust-lang/rust/blob/9c15de4fd59bee290848b5443c7e194fd5afb02c/src/libsyntax/parse/mod.rs#L399 - // - https://github.com/rust-lang/rust/blob/9c15de4fd59bee290848b5443c7e194fd5afb02c/src/libsyntax/parse/mod.rs#L268 - // - // Note that a docstring (i.e., something with is_sugared_doc) *always* has exactly two - // tokens: an Eq followed by a Literal, where the Literal contains a Str_. We therefore - // match against that, modifying the inner Str with our modified Symbol. - let mut tokens = attr.tokens.clone().into_trees(); - if let Some(tt @ TokenTree::Token(_, token::Eq)) = tokens.next() { - let mut docstr = tokens.next().expect("Docstrings must have literal docstring"); - if let TokenTree::Token(_, token::Literal(token::Str_(ref mut doc), _)) = docstr { - *doc = Symbol::intern(&str_lit(&doc.as_str(), None).replace("{}", &old_ident)); - } else { - unreachable!(); - } - attr.tokens = TokenStream::concat(vec![tt.into(), docstr.into()]); - } else { - unreachable!(); - } - - attr - }) - .collect(); - item.attrs.extend(attrs.into_iter()); - - MacEager::trait_items(SmallVec::from_buf([item])) -} - -fn impl_snake_to_camel(cx: &mut ExtCtxt, sp: Span, tts: &[TokenTree]) -> Box { - let mut parser = parse::new_parser_from_tts(cx.parse_sess(), tts.into()); - // The `expand_expr` method is called so that any macro calls in the - // parsed expression are expanded. - - let mut item = match parser.parse_impl_item(&mut false) { - Ok(s) => s, - Err(mut diagnostic) => { - diagnostic.emit(); - return DummyResult::any(sp); - } - }; - - if let Err(mut diagnostic) = parser.expect(&token::Eof) { - diagnostic.emit(); - return DummyResult::any(sp); - } - - convert(&mut item.ident); - MacEager::impl_items(SmallVec::from_buf([item])) -} - -fn ty_snake_to_camel(cx: &mut ExtCtxt, sp: Span, tts: &[TokenTree]) -> Box { - let mut parser = parse::new_parser_from_tts(cx.parse_sess(), tts.into()); - // The `expand_expr` method is called so that any macro calls in the - // parsed expression are expanded. 
- - let mut path = match parser.parse_path(PathStyle::Type) { - Ok(s) => s, - Err(mut diagnostic) => { - diagnostic.emit(); - return DummyResult::any(sp); - } - }; - - if let Err(mut diagnostic) = parser.expect(&token::Eof) { - diagnostic.emit(); - return DummyResult::any(sp); - } - - // Only capitalize the final segment - convert(&mut path.segments - .last_mut() - .unwrap() - .ident); - MacEager::ty(P(Ty { - id: ast::DUMMY_NODE_ID, - node: TyKind::Path(None, path), - span: sp, - })) -} - -/// Converts an ident in-place to CamelCase and returns the previous ident. -fn convert(ident: &mut Ident) -> String { - let ident_str = ident.to_string(); - let mut camel_ty = String::new(); - - { - // Find the first non-underscore and add it capitalized. - let mut chars = ident_str.chars(); - - // Find the first non-underscore char, uppercase it, and append it. - // Guaranteed to succeed because all idents must have at least one non-underscore char. - camel_ty.extend(chars.find(|&c| c != '_').unwrap().to_uppercase()); - - // When we find an underscore, we remove it and capitalize the next char. To do this, - // we need to ensure the next char is not another underscore. - let mut chars = chars.coalesce(|c1, c2| { - if c1 == '_' && c2 == '_' { - Ok(c1) - } else { - Err((c1, c2)) - } - }); - - while let Some(c) = chars.next() { - if c != '_' { - camel_ty.push(c); - } else if let Some(c) = chars.next() { - camel_ty.extend(c.to_uppercase()); - } - } - } - - // The Fut suffix is hardcoded right now; this macro isn't really meant to be general-purpose. - camel_ty.push_str("Fut"); - - *ident = Ident::with_empty_ctxt(Symbol::intern(&camel_ty)); - ident_str -} - -trait ParseTraitRef { - fn parse_trait_ref(&mut self) -> PResult; -} - -impl<'a> ParseTraitRef for Parser<'a> { - /// Parse a::B - fn parse_trait_ref(&mut self) -> PResult { - Ok(TraitRef { - path: self.parse_path(PathStyle::Type)?, - ref_id: ast::DUMMY_NODE_ID, - }) - } -} - -#[plugin_registrar] -#[doc(hidden)] -pub fn plugin_registrar(reg: &mut Registry) { - reg.register_macro("snake_to_camel", snake_to_camel); - reg.register_macro("impl_snake_to_camel", impl_snake_to_camel); - reg.register_macro("ty_snake_to_camel", ty_snake_to_camel); -} diff --git a/src/protocol.rs b/src/protocol.rs deleted file mode 100644 index a80bee17..00000000 --- a/src/protocol.rs +++ /dev/null @@ -1,248 +0,0 @@ -// Copyright 2016 Google Inc. All Rights Reserved. -// -// Licensed under the MIT License, . -// This file may not be copied, modified, or distributed except according to those terms. - -use bincode; -use byteorder::{BigEndian, ByteOrder}; -use bytes::BytesMut; -use bytes::buf::BufMut; -use serde; -use std::io; -use std::marker::PhantomData; -use std::mem; -use tokio_io::{AsyncRead, AsyncWrite}; -use tokio_codec::{Encoder, Decoder, Framed}; -use tokio_proto::multiplex::{ClientProto, ServerProto}; -use tokio_proto::streaming::multiplex::RequestId; - -// `Encode` is the type that `Codec` encodes. `Decode` is the type it decodes. 
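// Frame layout, as implemented by `encode`/`decode` below: an 8-byte big-endian request id,
// then an 8-byte big-endian payload length, then the bincode-serialized payload. `decode`
// steps through the Id -> Len -> Payload state machine until a whole frame has been
// buffered, and both directions reject payloads larger than `max_payload_size`.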
-#[derive(Debug)] -pub struct Codec { - max_payload_size: u64, - state: CodecState, - _phantom_data: PhantomData<(Encode, Decode)>, -} - -#[derive(Debug)] -enum CodecState { - Id, - Len { id: u64 }, - Payload { id: u64, len: u64 }, -} - -impl Codec { - fn new(max_payload_size: u64) -> Self { - Codec { - max_payload_size, - state: CodecState::Id, - _phantom_data: PhantomData, - } - } -} - -fn too_big(payload_size: u64, max_payload_size: u64) -> io::Error { - warn!( - "Not sending too-big packet of size {} (max is {})", - payload_size, - max_payload_size - ); - io::Error::new( - io::ErrorKind::InvalidData, - format!( - "Maximum payload size is {} bytes but got a payload of {}", - max_payload_size, - payload_size - ), - ) -} - -impl Encoder for Codec -where - Encode: serde::Serialize, - Decode: serde::de::DeserializeOwned, -{ - type Item = (RequestId, Encode); - type Error = io::Error; - - fn encode(&mut self, (id, message): Self::Item, buf: &mut BytesMut) -> io::Result<()> { - let payload_size = bincode::serialized_size(&message).map_err(|serialize_err| { - io::Error::new(io::ErrorKind::Other, serialize_err) - })?; - if payload_size > self.max_payload_size { - return Err(too_big(payload_size, self.max_payload_size)); - } - let message_size = 2 * mem::size_of::() + payload_size as usize; - buf.reserve(message_size); - buf.put_u64_be(id); - trace!("Encoded request id = {} as {:?}", id, buf); - buf.put_u64_be(payload_size); - bincode::serialize_into(&mut buf.writer(), &message) - .map_err(|serialize_err| { - io::Error::new(io::ErrorKind::Other, serialize_err) - })?; - trace!("Encoded buffer: {:?}", buf); - Ok(()) - } -} - -impl Decoder for Codec -where - Decode: serde::de::DeserializeOwned, -{ - type Item = (RequestId, Result); - type Error = io::Error; - - fn decode(&mut self, buf: &mut BytesMut) -> io::Result> { - use self::CodecState::*; - trace!("Codec::decode: {:?}", buf); - - loop { - match self.state { - Id if buf.len() < mem::size_of::() => { - trace!("--> Buf len is {}; waiting for 8 to parse id.", buf.len()); - return Ok(None); - } - Id => { - let mut id_buf = buf.split_to(mem::size_of::()); - let id = BigEndian::read_u64(&*id_buf); - trace!("--> Parsed id = {} from {:?}", id, id_buf); - self.state = Len { id }; - } - Len { .. } if buf.len() < mem::size_of::() => { - trace!( - "--> Buf len is {}; waiting for 8 to parse packet length.", - buf.len() - ); - return Ok(None); - } - Len { id } => { - let len_buf = buf.split_to(mem::size_of::()); - let len = BigEndian::read_u64(&*len_buf); - trace!( - "--> Parsed payload length = {}, remaining buffer length = {}", - len, - buf.len() - ); - if len > self.max_payload_size { - return Err(too_big(len, self.max_payload_size)); - } - self.state = Payload { id, len }; - } - Payload { len, .. } if buf.len() < len as usize => { - trace!( - "--> Buf len is {}; waiting for {} to parse payload.", - buf.len(), - len - ); - return Ok(None); - } - Payload { id, len } => { - let payload = buf.split_to(len as usize); - let result = bincode::deserialize(&payload); - // Reset the state machine because, either way, we're done processing this - // message. - self.state = Id; - - return Ok(Some((id, result))); - } - } - } - } -} - -/// Implements the `multiplex::ServerProto` trait. -#[derive(Debug)] -pub struct Proto { - max_payload_size: u64, - _phantom_data: PhantomData<(Encode, Decode)>, -} - -impl Proto { - /// Returns a new `Proto`. 
- pub fn new(max_payload_size: u64) -> Self { - Proto { - max_payload_size: max_payload_size, - _phantom_data: PhantomData, - } - } -} - -impl ServerProto for Proto -where - T: AsyncRead + AsyncWrite + 'static, - Encode: serde::Serialize + 'static, - Decode: serde::de::DeserializeOwned + 'static, -{ - type Response = Encode; - type Request = Result; - type Transport = Framed>; - type BindTransport = Result; - - fn bind_transport(&self, io: T) -> Self::BindTransport { - Ok(Framed::new(io, Codec::new(self.max_payload_size))) - } -} - -impl ClientProto for Proto -where - T: AsyncRead + AsyncWrite + 'static, - Encode: serde::Serialize + 'static, - Decode: serde::de::DeserializeOwned + 'static, -{ - type Response = Result; - type Request = Encode; - type Transport = Framed>; - type BindTransport = Result; - - fn bind_transport(&self, io: T) -> Self::BindTransport { - Ok(Framed::new(io, Codec::new(self.max_payload_size))) - } -} - -#[test] -fn serialize() { - const MSG: (u64, (char, char, char)) = (4, ('a', 'b', 'c')); - let mut buf = BytesMut::with_capacity(10); - - // Serialize twice to check for idempotence. - for _ in 0..2 { - let mut codec: Codec<(char, char, char), (char, char, char)> = Codec::new(2_000_000); - codec.encode(MSG, &mut buf).unwrap(); - let actual: Result< - Option<(u64, Result<(char, char, char), bincode::Error>)>, - io::Error, - > = codec.decode(&mut buf); - - match actual { - Ok(Some((id, ref v))) if id == MSG.0 && *v.as_ref().unwrap() == MSG.1 => {} - bad => panic!("Expected {:?}, but got {:?}", Some(MSG), bad), - } - - assert!(buf.is_empty(), "Expected empty buf but got {:?}", buf); - } -} - -#[test] -fn deserialize_big() { - let mut codec: Codec, Vec> = Codec::new(24); - - let mut buf = BytesMut::with_capacity(40); - assert_eq!( - codec - .encode((0, vec![0; 24]), &mut buf) - .err() - .unwrap() - .kind(), - io::ErrorKind::InvalidData - ); - - // Header - buf.put_slice(&mut [0u8; 8]); - // Len - buf.put_slice(&mut [0u8, 0, 0, 0, 0, 0, 0, 25]); - assert_eq!( - codec.decode(&mut buf).err().unwrap().kind(), - io::ErrorKind::InvalidData - ); -} diff --git a/src/stream_type.rs b/src/stream_type.rs deleted file mode 100644 index b60ddd6c..00000000 --- a/src/stream_type.rs +++ /dev/null @@ -1,94 +0,0 @@ -use bytes::{Buf, BufMut}; -use futures::Poll; -use std::io; -use tokio_core::net::TcpStream; -use tokio_io::{AsyncRead, AsyncWrite}; -#[cfg(feature = "tls")] -use tokio_tls::TlsStream; - -#[derive(Debug)] -pub enum StreamType { - Tcp(TcpStream), - #[cfg(feature = "tls")] - Tls(TlsStream), -} - -impl From for StreamType { - fn from(stream: TcpStream) -> Self { - StreamType::Tcp(stream) - } -} - -#[cfg(feature = "tls")] -impl From> for StreamType { - fn from(stream: TlsStream) -> Self { - StreamType::Tls(stream) - } -} - -impl io::Read for StreamType { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match *self { - StreamType::Tcp(ref mut stream) => stream.read(buf), - #[cfg(feature = "tls")] - StreamType::Tls(ref mut stream) => stream.read(buf), - } - } -} - -impl io::Write for StreamType { - fn write(&mut self, buf: &[u8]) -> io::Result { - match *self { - StreamType::Tcp(ref mut stream) => stream.write(buf), - #[cfg(feature = "tls")] - StreamType::Tls(ref mut stream) => stream.write(buf), - } - } - - fn flush(&mut self) -> io::Result<()> { - match *self { - StreamType::Tcp(ref mut stream) => stream.flush(), - #[cfg(feature = "tls")] - StreamType::Tls(ref mut stream) => stream.flush(), - } - } -} - -impl AsyncRead for StreamType { - // By overriding this fn, `StreamType` 
is obliged to never read the uninitialized buffer. - // Most sane implementations would never have a reason to, and `StreamType` does not, so - // this is safe. - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { - match *self { - StreamType::Tcp(ref stream) => stream.prepare_uninitialized_buffer(buf), - #[cfg(feature = "tls")] - StreamType::Tls(ref stream) => stream.prepare_uninitialized_buffer(buf), - } - } - - fn read_buf(&mut self, buf: &mut B) -> Poll { - match *self { - StreamType::Tcp(ref mut stream) => stream.read_buf(buf), - #[cfg(feature = "tls")] - StreamType::Tls(ref mut stream) => stream.read_buf(buf), - } - } -} - -impl AsyncWrite for StreamType { - fn shutdown(&mut self) -> Poll<(), io::Error> { - match *self { - StreamType::Tcp(ref mut stream) => stream.shutdown(), - #[cfg(feature = "tls")] - StreamType::Tls(ref mut stream) => stream.shutdown(), - } - } - - fn write_buf(&mut self, buf: &mut B) -> Poll { - match *self { - StreamType::Tcp(ref mut stream) => stream.write_buf(buf), - #[cfg(feature = "tls")] - StreamType::Tls(ref mut stream) => stream.write_buf(buf), - } - } -} diff --git a/src/sync/client.rs b/src/sync/client.rs deleted file mode 100644 index 53db42f4..00000000 --- a/src/sync/client.rs +++ /dev/null @@ -1,253 +0,0 @@ -use future::client::{Client as FutureClient, ClientExt as FutureClientExt, - Options as FutureOptions}; -use futures::{Future, Stream}; -use serde::Serialize; -use serde::de::DeserializeOwned; -use std::fmt; -use std::io; -use std::net::{SocketAddr, ToSocketAddrs}; -use std::sync::mpsc; -use std::thread; -#[cfg(feature = "tls")] -use tls::client::Context; -use tokio_core::reactor; -use tokio_proto::util::client_proxy::{ClientProxy, Receiver, pair}; -use tokio_service::Service; -use util::FirstSocketAddr; - -#[doc(hidden)] -pub struct Client { - proxy: ClientProxy>, -} - -impl Clone for Client { - fn clone(&self) -> Self { - Client { - proxy: self.proxy.clone(), - } - } -} - -impl fmt::Debug for Client { - fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - const PROXY: &str = "ClientProxy { .. }"; - f.debug_struct("Client").field("proxy", &PROXY).finish() - } -} - -impl Client -where - Req: Serialize + Send + 'static, - Resp: DeserializeOwned + Send + 'static, - E: DeserializeOwned + Send + 'static, -{ - /// Drives an RPC call for the given request. - pub fn call(&self, request: Req) -> Result> { - // Must call wait here to block on the response. - // The request handler relies on this fact to safely unwrap the - // oneshot send. - self.proxy.call(request).wait() - } -} - -/// Additional options to configure how the client connects and operates. -pub struct Options { - /// Max packet size in bytes. - max_payload_size: u64, - #[cfg(feature = "tls")] - tls_ctx: Option, -} - -impl Default for Options { - #[cfg(not(feature = "tls"))] - fn default() -> Self { - Options { - max_payload_size: 2_000_000, - } - } - - #[cfg(feature = "tls")] - fn default() -> Self { - Options { - max_payload_size: 2_000_000, - tls_ctx: None, - } - } -} - -impl Options { - /// Set the max payload size in bytes. The default is 2,000,000 (2 MB). 
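    // A minimal usage sketch (the 8 MB figure is arbitrary; `Options` is the type
    // defined just above):
    //
    //     let options = Options::default().max_payload_size(8_000_000);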
- pub fn max_payload_size(mut self, bytes: u64) -> Self { - self.max_payload_size = bytes; - self - } - - /// Connect using the given `Context` - #[cfg(feature = "tls")] - pub fn tls(mut self, ctx: Context) -> Self { - self.tls_ctx = Some(ctx); - self - } -} - -impl fmt::Debug for Options { - fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - #[cfg(feature = "tls")] - const SOME: &str = "Some(_)"; - #[cfg(feature = "tls")] - const NONE: &str = "None"; - let mut f = f.debug_struct("Options"); - #[cfg(feature = "tls")] f.field("tls_ctx", if self.tls_ctx.is_some() { &SOME } else { &NONE }); - f.finish() - } -} - -impl Into for (reactor::Handle, Options) { - #[cfg(feature = "tls")] - fn into(self) -> FutureOptions { - let (handle, options) = self; - let mut opts = FutureOptions::default().max_payload_size(options.max_payload_size).handle(handle); - if let Some(tls_ctx) = options.tls_ctx { - opts = opts.tls(tls_ctx); - } - opts - } - - #[cfg(not(feature = "tls"))] - fn into(self) -> FutureOptions { - let (handle, options) = self; - FutureOptions::default().max_payload_size(options.max_payload_size).handle(handle) - } -} - -/// Extension methods for Clients. -pub trait ClientExt: Sized { - /// Connects to a server located at the given address. - fn connect(addr: A, options: Options) -> io::Result - where - A: ToSocketAddrs; -} - -impl ClientExt for Client -where - Req: Serialize + Send + 'static, - Resp: DeserializeOwned + Send + 'static, - E: DeserializeOwned + Send + 'static, -{ - fn connect(addr: A, options: Options) -> io::Result - where - A: ToSocketAddrs, - { - let addr = addr.try_first_socket_addr()?; - let (connect_tx, connect_rx) = mpsc::channel(); - thread::spawn(move || match RequestHandler::connect(addr, options) { - Ok((proxy, mut handler)) => { - connect_tx.send(Ok(proxy)).unwrap(); - handler.handle_requests(); - } - Err(e) => connect_tx.send(Err(e)).unwrap(), - }); - Ok(connect_rx.recv().unwrap()?) - } -} - -/// Forwards incoming requests of type `Req` -/// with expected response `Result>` -/// to service `S`. -struct RequestHandler { - reactor: reactor::Core, - client: S, - requests: Receiver>, -} - -impl RequestHandler> -where - Req: Serialize + Send + 'static, - Resp: DeserializeOwned + Send + 'static, - E: DeserializeOwned + Send + 'static, -{ - /// Creates a new `RequestHandler` by connecting a `FutureClient` to the given address - /// using the given options. - fn connect(addr: SocketAddr, options: Options) -> io::Result<(Client, Self)> { - let mut reactor = reactor::Core::new()?; - let options = (reactor.handle(), options).into(); - let client = reactor.run(FutureClient::connect(addr, options))?; - let (proxy, requests) = pair(); - Ok(( - Client { proxy }, - RequestHandler { - reactor, - client, - requests, - }, - )) - } -} - -impl RequestHandler -where - Req: Serialize + 'static, - Resp: DeserializeOwned + 'static, - E: DeserializeOwned + 'static, - S: Service>, - S::Future: 'static, -{ - fn handle_requests(&mut self) { - let RequestHandler { - ref mut reactor, - ref mut requests, - ref mut client, - } = *self; - let handle = reactor.handle(); - let requests = requests - .map(|result| { - match result { - Ok(req) => req, - // The ClientProxy never sends Err currently - Err(e) => panic!("Unimplemented error handling in RequestHandler: {}", e), - } - }) - .for_each(|(request, response_tx)| { - let request = client.call(request).then(move |response| { - // Safe to unwrap because clients always block on the response future. 
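                    // (The blocking `Client::call` above drives each request with `.wait()`,
                    // so the corresponding oneshot receiver cannot have been dropped yet.)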
- response_tx - .send(response) - .map_err(|_| ()) - .expect("Client should block on response"); - Ok(()) - }); - handle.spawn(request); - Ok(()) - }); - reactor.run(requests).unwrap(); - } -} - -#[test] -fn handle_requests() { - use futures::future; - - struct Client; - impl Service for Client { - type Request = i32; - type Response = i32; - type Error = ::Error<()>; - type Future = future::FutureResult>; - - fn call(&self, req: i32) -> Self::Future { - future::ok(req) - } - } - - let (request, requests) = ::futures::sync::mpsc::unbounded(); - let reactor = reactor::Core::new().unwrap(); - let client = Client; - let mut request_handler = RequestHandler { - reactor, - client, - requests, - }; - // Test that `handle_requests` returns when all request senders are dropped. - drop(request); - request_handler.handle_requests(); -} diff --git a/src/sync/mod.rs b/src/sync/mod.rs deleted file mode 100644 index 79011a35..00000000 --- a/src/sync/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -/// Provides the base client stubs used by the service macro. -pub mod client; -/// Provides the base server boilerplate used by service implementations. -pub mod server; diff --git a/src/sync/server.rs b/src/sync/server.rs deleted file mode 100644 index 65f69eeb..00000000 --- a/src/sync/server.rs +++ /dev/null @@ -1,249 +0,0 @@ -use {bincode, future, num_cpus}; -use future::server::{Response, Shutdown}; -use futures::{Future, future as futures}; -use futures::sync::oneshot; -#[cfg(feature = "tls")] -use native_tls_inner::TlsAcceptor; -use serde::Serialize; -use serde::de::DeserializeOwned; -use std::fmt; -use std::io; -use std::net::SocketAddr; -use std::time::Duration; -use std::usize; -use thread_pool::{self, Sender, Task, ThreadPool}; -use tokio_core::reactor; -use tokio_service::{NewService, Service}; - -/// Additional options to configure how the server operates. -#[derive(Debug)] -pub struct Options { - thread_pool: thread_pool::Builder, - opts: future::server::Options, -} - -impl Default for Options { - fn default() -> Self { - let num_cpus = num_cpus::get(); - Options { - thread_pool: thread_pool::Builder::new() - .keep_alive(Duration::from_secs(60)) - .max_pool_size(num_cpus * 100) - .core_pool_size(num_cpus) - .work_queue_capacity(usize::MAX) - .name_prefix("request-thread-"), - opts: future::server::Options::default(), - } - } -} - -impl Options { - /// Set the max payload size in bytes. The default is 2,000,000 (2 MB). - pub fn max_payload_size(mut self, bytes: u64) -> Self { - self.opts = self.opts.max_payload_size(bytes); - self - } - - /// Sets the thread pool builder to use when creating the server's thread pool. - pub fn thread_pool(mut self, builder: thread_pool::Builder) -> Self { - self.thread_pool = builder; - self - } - - /// Set the `TlsAcceptor` - #[cfg(feature = "tls")] - pub fn tls(mut self, tls_acceptor: TlsAcceptor) -> Self { - self.opts = self.opts.tls(tls_acceptor); - self - } -} - -/// A handle to a bound server. Must be run to start serving requests. -#[must_use = "A server does nothing until `run` is called."] -pub struct Handle { - reactor: reactor::Core, - handle: future::server::Handle, - server: Box>, -} - -impl Handle { - /// Runs the server on the current thread, blocking indefinitely. - pub fn run(mut self) { - trace!("Running..."); - match self.reactor.run(self.server) { - Ok(()) => debug!("Server successfully shutdown."), - Err(()) => debug!("Server shutdown due to error."), - } - } - - /// Returns a hook for shutting down the server. 
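    // A minimal usage sketch (`MyService` is a hypothetical `SyncService` implementor
    // started through the generated `SyncServiceExt::listen`):
    //
    //     let handle = MyService.listen("localhost:0", Options::default()).unwrap();
    //     let shutdown = handle.shutdown(); // take the hook before `run` consumes the handle
    //     handle.run();                     // blocks, serving requests, until shut down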
- pub fn shutdown(&self) -> Shutdown { - self.handle.shutdown().clone() - } - - /// The socket address the server is bound to. - pub fn addr(&self) -> SocketAddr { - self.handle.addr() - } -} - -impl fmt::Debug for Handle { - fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - const SERVER: &str = "Box>"; - - f.debug_struct("Handle") - .field("reactor", &self.reactor) - .field("handle", &self.handle) - .field("server", &SERVER) - .finish() - } -} - -#[doc(hidden)] -pub fn listen(new_service: S, - addr: SocketAddr, - options: Options) - -> io::Result - where S: NewService, - Response = Response, - Error = io::Error> + 'static, - ::Future: Send + 'static, - S::Response: Send, - S::Error: Send, - Req: DeserializeOwned + 'static, - Resp: Serialize + 'static, - E: Serialize + 'static -{ - let new_service = NewThreadService::new(new_service, options.thread_pool); - let reactor = reactor::Core::new()?; - let (handle, server) = - future::server::listen(new_service, addr, &reactor.handle(), options.opts)?; - let server = Box::new(server); - Ok(Handle { - reactor: reactor, - handle: handle, - server: server, - }) -} - -/// A service that uses a thread pool. -struct NewThreadService -where - S: NewService, -{ - new_service: S, - sender: Sender::Future>>, - _pool: ThreadPool::Future>>, -} - -/// A service that runs by executing request handlers in a thread pool. -struct ThreadService -where - S: Service, -{ - service: S, - sender: Sender>, -} - -/// A task that handles a single request. -struct ServiceTask -where - F: Future, -{ - future: F, - tx: oneshot::Sender>, -} - -impl NewThreadService -where - S: NewService, - ::Future: Send + 'static, - S::Response: Send, - S::Error: Send, -{ - /// Create a NewThreadService by wrapping another service. - fn new(new_service: S, pool: thread_pool::Builder) -> Self { - let (sender, _pool) = pool.build(); - NewThreadService { - new_service, - sender, - _pool, - } - } -} - -impl NewService for NewThreadService -where - S: NewService, - ::Future: Send + 'static, - S::Response: Send, - S::Error: Send, -{ - type Request = S::Request; - type Response = S::Response; - type Error = S::Error; - type Instance = ThreadService; - - fn new_service(&self) -> io::Result { - Ok(ThreadService { - service: self.new_service.new_service()?, - sender: self.sender.clone(), - }) - } -} - -impl Task for ServiceTask -where - F: Future + Send + 'static, - F::Item: Send, - F::Error: Send, -{ - fn run(self) { - // Don't care if sending fails. It just means the request is no longer - // being handled (I think). - let _ = self.tx.send(self.future.wait()); - } -} - -impl Service for ThreadService -where - S: Service, - S::Future: Send + 'static, - S::Response: Send, - S::Error: Send, -{ - type Request = S::Request; - type Response = S::Response; - type Error = S::Error; - type Future = futures::AndThen< - futures::MapErr< - oneshot::Receiver>, - fn(oneshot::Canceled) -> Self::Error, - >, - Result, - fn(Result) - -> Result, - >; - - fn call(&self, request: Self::Request) -> Self::Future { - let (tx, rx) = oneshot::channel(); - self.sender - .send(ServiceTask { - future: self.service.call(request), - tx: tx, - }) - .unwrap(); - rx.map_err(unreachable as _).and_then(ident) - } -} - -fn unreachable(t: T) -> U -where - T: fmt::Display, -{ - unreachable!(t) -} - -fn ident(t: T) -> T { - t -} diff --git a/src/tls.rs b/src/tls.rs deleted file mode 100644 index f5bd1e60..00000000 --- a/src/tls.rs +++ /dev/null @@ -1,50 +0,0 @@ -/// TLS-specific functionality for clients. 
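// A minimal client-side sketch ("foobar.com" mirrors the domain used in the crate docs;
// substitute the hostname presented by your server's certificate):
//
//     let ctx = client::Context::new("foobar.com").unwrap();
//     let options = ::sync::client::Options::default().tls(ctx);
//
// `Context::new` builds a default `TlsConnector`; `Context::from_connector` accepts a
// preconfigured connector (for example, one with an extra root certificate added for tests).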
-pub mod client { - use native_tls::{Error, TlsConnector}; - use std::fmt; - - /// TLS context for client - pub struct Context { - /// Domain to connect to - pub domain: String, - /// TLS connector - pub tls_connector: TlsConnector, - } - - impl Context { - /// Try to construct a new `Context`. - /// - /// The provided domain will be used for both - /// [SNI](https://en.wikipedia.org/wiki/Server_Name_Indication) and certificate hostname - /// validation. - pub fn new>(domain: S) -> Result { - Ok(Context { - domain: domain.into(), - tls_connector: TlsConnector::builder()?.build()?, - }) - } - - /// Construct a new `Context` using the provided domain and `TlsConnector` - /// - /// The domain will be used for both - /// [SNI](https://en.wikipedia.org/wiki/Server_Name_Indication) and certificate hostname - /// validation. - pub fn from_connector>(domain: S, tls_connector: TlsConnector) -> Self { - Context { - domain: domain.into(), - tls_connector: tls_connector, - } - } - } - - impl fmt::Debug for Context { - fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - const TLS_CONNECTOR: &str = "TlsConnector { .. }"; - f.debug_struct("Context") - .field("domain", &self.domain) - .field("tls_connector", &TLS_CONNECTOR) - .finish() - } - } - -} diff --git a/src/util.rs b/src/util.rs deleted file mode 100644 index 5ddb64a6..00000000 --- a/src/util.rs +++ /dev/null @@ -1,185 +0,0 @@ -// Copyright 2016 Google Inc. All Rights Reserved. -// -// Licensed under the MIT License, . -// This file may not be copied, modified, or distributed except according to those terms. - -use futures::{Future, IntoFuture, Poll}; -use futures::stream::Stream; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use std::{fmt, io, mem}; -use std::error::Error; -use std::net::{SocketAddr, ToSocketAddrs}; - -/// A bottom type that impls `Error`, `Serialize`, and `Deserialize`. It is impossible to -/// instantiate this type. -#[allow(unreachable_code)] -pub struct Never(!); - -impl fmt::Debug for Never { - fn fmt(&self, _: &mut fmt::Formatter) -> fmt::Result { - self.0 - } -} - -impl Error for Never { - fn description(&self) -> &str { - self.0 - } -} - -impl fmt::Display for Never { - fn fmt(&self, _: &mut fmt::Formatter) -> fmt::Result { - self.0 - } -} - -impl Future for Never { - type Item = Never; - type Error = Never; - - fn poll(&mut self) -> Poll { - self.0 - } -} - -impl Stream for Never { - type Item = Never; - type Error = Never; - - fn poll(&mut self) -> Poll, Self::Error> { - self.0 - } -} - -impl Serialize for Never { - fn serialize(&self, _: S) -> Result - where - S: Serializer, - { - self.0 - } -} - -// Please don't try to deserialize this. :( -impl<'a> Deserialize<'a> for Never { - fn deserialize(_: D) -> Result - where - D: Deserializer<'a>, - { - panic!("Never cannot be instantiated!"); - } -} - -/// A `String` that impls `std::error::Error`. Useful for quick-and-dirty error propagation. -#[derive(Debug, Serialize, Deserialize)] -pub struct Message(pub String); - -impl Error for Message { - fn description(&self) -> &str { - &self.0 - } -} - -impl fmt::Display for Message { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - fmt::Display::fmt(&self.0, f) - } -} - -impl> From for Message { - fn from(s: S) -> Self { - Message(s.into()) - } -} - - -/// Provides a utility method for more ergonomically parsing a `SocketAddr` when only one is -/// needed. -pub trait FirstSocketAddr: ToSocketAddrs { - /// Returns the first resolved `SocketAddr`, if one exists. 
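    // For example, "localhost:0".try_first_socket_addr() yields the first address the
    // system resolver returns; the panicking `first_socket_addr` below is the convenience
    // form used throughout the crate's examples and tests.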
- fn try_first_socket_addr(&self) -> io::Result { - if let Some(a) = self.to_socket_addrs()?.next() { - Ok(a) - } else { - Err(io::Error::new( - io::ErrorKind::AddrNotAvailable, - "`ToSocketAddrs::to_socket_addrs` returned an empty iterator.", - )) - } - } - - /// Returns the first resolved `SocketAddr` or panics otherwise. - fn first_socket_addr(&self) -> SocketAddr { - self.try_first_socket_addr().unwrap() - } -} - -impl FirstSocketAddr for A {} - -/// Creates a new future which will eventually be the same as the one created -/// by calling the closure provided with the arguments provided. -/// -/// The provided closure is only run once the future has a callback scheduled -/// on it, otherwise the callback never runs. Once run, however, this future is -/// the same as the one the closure creates. -pub fn lazy(f: F, args: A) -> Lazy -where - F: FnOnce(A) -> R, - R: IntoFuture, -{ - Lazy { - inner: _Lazy::First(f, args), - } -} - -/// A future which defers creation of the actual future until a callback is -/// scheduled. -/// -/// This is created by the `lazy` function. -#[derive(Debug)] -#[must_use = "futures do nothing unless polled"] -pub struct Lazy { - inner: _Lazy, -} - -#[derive(Debug)] -enum _Lazy { - First(F, A), - Second(R), - Moved, -} - -impl Lazy -where - F: FnOnce(A) -> R, - R: IntoFuture, -{ - fn get(&mut self) -> &mut R::Future { - match self.inner { - _Lazy::First(..) => {} - _Lazy::Second(ref mut f) => return f, - _Lazy::Moved => panic!(), // can only happen if `f()` panics - } - match mem::replace(&mut self.inner, _Lazy::Moved) { - _Lazy::First(f, args) => self.inner = _Lazy::Second(f(args).into_future()), - _ => panic!(), // we already found First - } - match self.inner { - _Lazy::Second(ref mut f) => f, - _ => panic!(), // we just stored Second - } - } -} - -impl Future for Lazy -where - F: FnOnce(A) -> R, - R: IntoFuture, -{ - type Item = R::Item; - type Error = R::Error; - - fn poll(&mut self) -> Poll { - self.get().poll() - } -} diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml new file mode 100644 index 00000000..cc9d21b6 --- /dev/null +++ b/tarpc/Cargo.toml @@ -0,0 +1,39 @@ +cargo-features = ["namespaced-features"] + +[package] +name = "tarpc" +version = "0.12.1" +authors = ["Adam Wright ", "Tim Kuehn "] +edition = "2018" +namespaced-features = true +license = "MIT" +documentation = "https://docs.rs/tarpc" +homepage = "https://github.com/google/tarpc" +repository = "https://github.com/google/tarpc" +keywords = ["rpc", "network", "server", "api", "tls"] +categories = ["asynchronous", "network-programming"] +readme = "README.md" +description = "An RPC framework for Rust with a focus on ease of use." 
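The removed `util::FirstSocketAddr` trait above wrapped a common pattern: resolve a `host:port` string via `ToSocketAddrs` and take the first address, failing if resolution yields nothing. A std-only sketch of the same idea (the free function `first_socket_addr` is illustrative, not part of the crate):

```rust
use std::io;
use std::net::{SocketAddr, ToSocketAddrs};

// Resolve the first SocketAddr for anything that implements ToSocketAddrs,
// mirroring the removed `FirstSocketAddr::try_first_socket_addr`.
fn first_socket_addr<A: ToSocketAddrs>(addr: A) -> io::Result<SocketAddr> {
    addr.to_socket_addrs()?.next().ok_or_else(|| {
        io::Error::new(
            io::ErrorKind::AddrNotAvailable,
            "`to_socket_addrs` returned an empty iterator",
        )
    })
}

fn main() -> io::Result<()> {
    let addr = first_socket_addr("localhost:8080")?;
    println!("resolved to {}", addr);
    Ok(())
}
```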
+ +[features] +serde = ["rpc/serde", "crate:serde", "serde/derive"] + +[badges] +travis-ci = { repository = "google/tarpc" } + +[dependencies] +log = "0.4" +serde = { optional = true, version = "1.0" } +tarpc-plugins = { path = "../plugins", version = "0.4.0" } +rpc = { path = "../rpc" } + +[target.'cfg(not(test))'.dependencies] +futures-preview = { git = "https://github.com/rust-lang-nursery/futures-rs" } + +[dev-dependencies] +humantime = "1.0" +futures-preview = { git = "https://github.com/rust-lang-nursery/futures-rs", features = ["compat", "tokio-compat"] } +bincode-transport = { path = "../bincode-transport" } +env_logger = "0.5" +tokio = "0.1" +tokio-executor = "0.1" diff --git a/clippy.toml b/tarpc/clippy.toml similarity index 100% rename from clippy.toml rename to tarpc/clippy.toml diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs new file mode 100644 index 00000000..1a8d2a39 --- /dev/null +++ b/tarpc/examples/pubsub.rs @@ -0,0 +1,190 @@ +// Copyright 2016 Google Inc. All Rights Reserved. +// +// Licensed under the MIT License, . +// This file may not be copied, modified, or distributed except according to those terms. + +#![feature( + arbitrary_self_types, + pin, + futures_api, + await_macro, + async_await, + existential_type, + proc_macro_hygiene, +)] + +use futures::{ + future::{self, Ready}, + prelude::*, + Future, +}; +use rpc::{ + client, context, + server::{self, Handler, Server}, +}; +use std::{ + collections::HashMap, + io, + net::SocketAddr, + sync::{Arc, Mutex}, + thread, + time::Duration, +}; + +pub mod subscriber { + tarpc::service! { + rpc receive(message: String); + } +} + +pub mod publisher { + use std::net::SocketAddr; + tarpc::service! { + rpc broadcast(message: String); + rpc subscribe(id: u32, address: SocketAddr) -> Result<(), String>; + rpc unsubscribe(id: u32); + } +} + +#[derive(Clone, Debug)] +struct Subscriber { + id: u32, +} + +impl subscriber::Service for Subscriber { + type ReceiveFut = Ready<()>; + + fn receive(&self, _: context::Context, message: String) -> Self::ReceiveFut { + println!("{} received message: {}", self.id, message); + future::ready(()) + } +} + +impl Subscriber { + async fn listen(id: u32, config: server::Config) -> io::Result { + let incoming = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; + let addr = incoming.local_addr(); + tokio_executor::spawn( + Server::new(config) + .incoming(incoming) + .take(1) + .respond_with(subscriber::serve(Subscriber { id })) + .unit_error() + .boxed() + .compat() + ); + Ok(addr) + } +} + +#[derive(Clone, Debug)] +struct Publisher { + clients: Arc>>, +} + +impl Publisher { + fn new() -> Publisher { + Publisher { + clients: Arc::new(Mutex::new(HashMap::new())), + } + } +} + +impl publisher::Service for Publisher { + existential type BroadcastFut: Future; + + fn broadcast(&self, _: context::Context, message: String) -> Self::BroadcastFut { + async fn broadcast(clients: Arc>>, message: String) { + let mut clients = clients.lock().unwrap().clone(); + for client in clients.values_mut() { + // Ignore failing subscribers. In a real pubsub, + // you'd want to continually retry until subscribers + // ack. 
+ let _ = await!(client.receive(context::current(), message.clone())); + } + } + + broadcast(self.clients.clone(), message) + } + + existential type SubscribeFut: Future>; + + fn subscribe(&self, _: context::Context, id: u32, addr: SocketAddr) -> Self::SubscribeFut { + async fn subscribe( + clients: Arc>>, + id: u32, + addr: SocketAddr, + ) -> io::Result<()> { + let conn = await!(bincode_transport::connect(&addr))?; + let subscriber = await!(subscriber::new_stub(client::Config::default(), conn))?; + println!("Subscribing {}.", id); + clients.lock().unwrap().insert(id, subscriber); + Ok(()) + } + + subscribe(Arc::clone(&self.clients), id, addr).map_err(|e| e.to_string()) + } + + existential type UnsubscribeFut: Future; + + fn unsubscribe(&self, _: context::Context, id: u32) -> Self::UnsubscribeFut { + println!("Unsubscribing {}", id); + let mut clients = self.clients.lock().unwrap(); + if let None = clients.remove(&id) { + eprintln!( + "Client {} not found. Existings clients: {:?}", + id, &*clients + ); + } + future::ready(()) + } +} + +async fn run() -> io::Result<()> { + env_logger::init(); + let transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; + let publisher_addr = transport.local_addr(); + tokio_executor::spawn( + Server::new(server::Config::default()) + .incoming(transport) + .take(1) + .respond_with(publisher::serve(Publisher::new())) + .unit_error() + .boxed() + .compat() + ); + + let subscriber1 = await!(Subscriber::listen(0, server::Config::default()))?; + let subscriber2 = await!(Subscriber::listen(1, server::Config::default()))?; + + let publisher_conn = bincode_transport::connect(&publisher_addr); + let publisher_conn = await!(publisher_conn)?; + let mut publisher = await!(publisher::new_stub( + client::Config::default(), + publisher_conn + ))?; + + if let Err(e) = await!(publisher.subscribe(context::current(), 0, subscriber1))? { + eprintln!("Couldn't subscribe subscriber 0: {}", e); + } + if let Err(e) = await!(publisher.subscribe(context::current(), 1, subscriber2))? { + eprintln!("Couldn't subscribe subscriber 1: {}", e); + } + + println!("Broadcasting..."); + await!(publisher.broadcast(context::current(), "hello to all".to_string()))?; + await!(publisher.unsubscribe(context::current(), 1))?; + await!(publisher.broadcast(context::current(), "hi again".to_string()))?; + Ok(()) +} + +fn main() { + tokio::run( + run() + .boxed() + .map_err(|e| panic!(e)) + .boxed() + .compat(), + ); + thread::sleep(Duration::from_millis(100)); +} diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs new file mode 100644 index 00000000..65fb0d68 --- /dev/null +++ b/tarpc/examples/readme.rs @@ -0,0 +1,90 @@ +// Copyright 2016 Google Inc. All Rights Reserved. +// +// Licensed under the MIT License, . +// This file may not be copied, modified, or distributed except according to those terms. + +#![feature( + futures_api, + pin, + arbitrary_self_types, + await_macro, + async_await, + proc_macro_hygiene, +)] + +use futures::{ + future::{self, Ready}, + prelude::*, +}; +use rpc::{ + client, context, + server::{self, Handler, Server}, +}; +use std::io; + +// This is the service definition. It looks a lot like a trait definition. +// It defines one RPC, hello, which takes one arg, name, and returns a String. + +tarpc::service! { + rpc hello(name: String) -> String; +} + +// This is the type that implements the generated Service trait. It is the business logic +// and is used to start the server. 
+#[derive(Clone)] +struct HelloServer; + +impl Service for HelloServer { + // Each defined rpc generates two items in the trait, a fn that serves the RPC, and + // an associated type representing the future output by the fn. + + type HelloFut = Ready; + + fn hello(&self, _: context::Context, name: String) -> Self::HelloFut { + future::ready(format!("Hello, {}!", name)) + } +} + +async fn run() -> io::Result<()> { + // bincode_transport is provided by the associated crate bincode-transport. It makes it easy + // to start up a serde-powered bincode serialization strategy over TCP. + let transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; + let addr = transport.local_addr(); + + // The server is configured with the defaults. + let server = Server::new(server::Config::default()) + // Server can listen on any type that implements the Transport trait. + .incoming(transport) + // Close the stream after the client connects + .take(1) + // serve is generated by the tarpc::service! macro. It takes as input any type implementing + // the generated Service trait. + .respond_with(serve(HelloServer)); + + tokio_executor::spawn(server.unit_error().boxed().compat()); + + let transport = await!(bincode_transport::connect(&addr))?; + + // new_stub is generated by the tarpc::service! macro. Like Server, it takes a config and any + // Transport as input, and returns a Client, also generated by the macro. + // by the service mcro. + let mut client = await!(new_stub(client::Config::default(), transport))?; + + // The client has an RPC method for each RPC defined in tarpc::service!. It takes the same args + // as defined, with the addition of a Context, which is always the first arg. The Context + // specifies a deadline and trace information which can be helpful in debugging requests. + let hello = await!(client.hello(context::current(), "Stim".to_string()))?; + + println!("{}", hello); + + Ok(()) +} + +fn main() { + tokio::run( + run() + .map_err(|e| eprintln!("Oh no: {}", e)) + .boxed() + .compat(), + ); +} diff --git a/tarpc/examples/server_calling_server.rs b/tarpc/examples/server_calling_server.rs new file mode 100644 index 00000000..0854403a --- /dev/null +++ b/tarpc/examples/server_calling_server.rs @@ -0,0 +1,110 @@ +// Copyright 2016 Google Inc. All Rights Reserved. +// +// Licensed under the MIT License, . +// This file may not be copied, modified, or distributed except according to those terms. + +#![feature( + existential_type, + arbitrary_self_types, + pin, + futures_api, + await_macro, + async_await, + proc_macro_hygiene, +)] + +use crate::{add::Service as AddService, double::Service as DoubleService}; +use futures::{ + future::{self, Ready}, + prelude::*, +}; +use rpc::{ + client, context, + server::{self, Handler, Server}, +}; +use std::io; + +pub mod add { + tarpc::service! { + /// Add two ints together. + rpc add(x: i32, y: i32) -> i32; + } +} + +pub mod double { + tarpc::service! 
{ + /// 2 * x + rpc double(x: i32) -> Result; + } +} + +#[derive(Clone)] +struct AddServer; + +impl AddService for AddServer { + type AddFut = Ready; + + fn add(&self, _: context::Context, x: i32, y: i32) -> Self::AddFut { + future::ready(x + y) + } +} + +#[derive(Clone)] +struct DoubleServer { + add_client: add::Client, +} + +impl DoubleService for DoubleServer { + existential type DoubleFut: Future> + Send; + + fn double(&self, _: context::Context, x: i32) -> Self::DoubleFut { + async fn double(mut client: add::Client, x: i32) -> Result { + let result = await!(client.add(context::current(), x, x)); + result.map_err(|e| e.to_string()) + } + + double(self.add_client.clone(), x) + } +} + +async fn run() -> io::Result<()> { + let add_listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; + let addr = add_listener.local_addr(); + let add_server = Server::new(server::Config::default()) + .incoming(add_listener) + .take(1) + .respond_with(add::serve(AddServer)); + tokio_executor::spawn(add_server.unit_error().boxed().compat()); + + let to_add_server = await!(bincode_transport::connect(&addr))?; + let add_client = await!(add::new_stub(client::Config::default(), to_add_server))?; + + let double_listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; + let addr = double_listener.local_addr(); + let double_server = rpc::Server::new(server::Config::default()) + .incoming(double_listener) + .take(1) + .respond_with(double::serve(DoubleServer { add_client })); + tokio_executor::spawn(double_server.unit_error().boxed().compat()); + + let to_double_server = await!(bincode_transport::connect(&addr))?; + let mut double_client = await!(double::new_stub( + client::Config::default(), + to_double_server + ))?; + + for i in 1..=5 { + println!("{:?}", await!(double_client.double(context::current(), i))?); + } + Ok(()) +} + +fn main() { + env_logger::init(); + tokio::run( + run() + .map_err(|e| panic!(e)) + .boxed() + .compat(), + ); +} diff --git a/tarpc/rustfmt.toml b/tarpc/rustfmt.toml new file mode 100644 index 00000000..0ef5137d --- /dev/null +++ b/tarpc/rustfmt.toml @@ -0,0 +1 @@ +edition = "Edition2018" diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs new file mode 100644 index 00000000..6e652c1c --- /dev/null +++ b/tarpc/src/lib.rs @@ -0,0 +1,135 @@ +// Copyright 2016 Google Inc. All Rights Reserved. +// +// Licensed under the MIT License, . +// This file may not be copied, modified, or distributed except according to those terms. + +//! tarpc is an RPC framework for rust with a focus on ease of use. Defining a +//! service can be done in just a few lines of code, and most of the boilerplate of +//! writing a server is taken care of for you. +//! +//! ## What is an RPC framework? +//! "RPC" stands for "Remote Procedure Call," a function call where the work of +//! producing the return value is being done somewhere else. When an rpc function is +//! invoked, behind the scenes the function contacts some other process somewhere +//! and asks them to evaluate the function instead. The original function then +//! returns the value produced by the other process. +//! +//! RPC frameworks are a fundamental building block of most microservices-oriented +//! architectures. Two well-known ones are [gRPC](http://www.grpc.io) and +//! [Cap'n Proto](https://capnproto.org/). +//! +//! tarpc differentiates itself from other RPC frameworks by defining the schema in code, +//! rather than in a separate language such as .proto. This means there's no separate compilation +//! 
process, and no cognitive context switching between different languages. Additionally, it +//! works with the community-backed library serde: any serde-serializable type can be used as +//! arguments to tarpc fns. +//! +//! ## Example +//! +//! Here's a small service. +//! +//! ```rust +//! #![feature(futures_api, pin, arbitrary_self_types, await_macro, async_await, proc_macro_hygiene)] +//! +//! +//! use futures::{ +//! compat::TokioDefaultSpawner, +//! future::{self, Ready}, +//! prelude::*, +//! }; +//! use tarpc::{ +//! client, context, +//! server::{self, Handler, Server}, +//! }; +//! use std::io; +//! +//! // This is the service definition. It looks a lot like a trait definition. +//! // It defines one RPC, hello, which takes one arg, name, and returns a String. +//! tarpc::service! { +//! /// Returns a greeting for name. +//! rpc hello(name: String) -> String; +//! } +//! +//! // This is the type that implements the generated Service trait. It is the business logic +//! // and is used to start the server. +//! #[derive(Clone)] +//! struct HelloServer; +//! +//! impl Service for HelloServer { +//! // Each defined rpc generates two items in the trait, a fn that serves the RPC, and +//! // an associated type representing the future output by the fn. +//! +//! type HelloFut = Ready; +//! +//! fn hello(&self, _: context::Context, name: String) -> Self::HelloFut { +//! future::ready(format!("Hello, {}!", name)) +//! } +//! } +//! +//! async fn run() -> io::Result<()> { +//! // bincode_transport is provided by the associated crate bincode-transport. It makes it easy +//! // to start up a serde-powered bincode serialization strategy over TCP. +//! let transport = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; +//! let addr = transport.local_addr(); +//! +//! // The server is configured with the defaults. +//! let server = Server::new(server::Config::default()) +//! // Server can listen on any type that implements the Transport trait. +//! .incoming(transport) +//! // Close the stream after the client connects +//! .take(1) +//! // serve is generated by the service! macro. It takes as input any type implementing +//! // the generated Service trait. +//! .respond_with(serve(HelloServer)); +//! +//! tokio_executor::spawn(server.unit_error().boxed().compat()); +//! +//! let transport = await!(bincode_transport::connect(&addr))?; +//! +//! // new_stub is generated by the service! macro. Like Server, it takes a config and any +//! // Transport as input, and returns a Client, also generated by the macro. +//! // by the service mcro. +//! let mut client = await!(new_stub(client::Config::default(), transport))?; +//! +//! // The client has an RPC method for each RPC defined in service!. It takes the same args +//! // as defined, with the addition of a Context, which is always the first arg. The Context +//! // specifies a deadline and trace information which can be helpful in debugging requests. +//! let hello = await!(client.hello(context::current(), "Stim".to_string()))?; +//! +//! println!("{}", hello); +//! +//! Ok(()) +//! } +//! +//! fn main() { +//! tarpc::init(TokioDefaultSpawner); +//! tokio::run(run() +//! .map_err(|e| eprintln!("Oh no: {}", e)) +//! .boxed() +//! .compat(), +//! ); +//! } +//! 
``` + +#![deny(missing_docs, missing_debug_implementations)] +#![feature( + futures_api, + pin, + await_macro, + async_await, + decl_macro, +)] +#![cfg_attr(test, feature(proc_macro_hygiene, arbitrary_self_types))] + +#[doc(hidden)] +pub use futures; +pub use rpc::*; +#[cfg(feature = "serde")] +#[doc(hidden)] +pub use serde; +#[doc(hidden)] +pub use tarpc_plugins::*; + +/// Provides the macro used for constructing rpc services and client stubs. +#[macro_use] +mod macros; diff --git a/tarpc/src/macros.rs b/tarpc/src/macros.rs new file mode 100644 index 00000000..b335db44 --- /dev/null +++ b/tarpc/src/macros.rs @@ -0,0 +1,363 @@ +// Copyright 2016 Google Inc. All Rights Reserved. +// +// Licensed under the MIT License, . +// This file may not be copied, modified, or distributed except according to those terms. + +#[cfg(feature = "serde")] +#[doc(hidden)] +#[macro_export] +macro_rules! add_serde_if_enabled { + ($(#[$attr:meta])* -- $i:item) => { + $(#[$attr])* + #[derive($crate::serde::Serialize, $crate::serde::Deserialize)] + $i + } +} + +#[cfg(not(feature = "serde"))] +#[doc(hidden)] +#[macro_export] +macro_rules! add_serde_if_enabled { + ($(#[$attr:meta])* -- $i:item) => { + $(#[$attr])* + $i + } +} + +/// The main macro that creates RPC services. +/// +/// Rpc methods are specified, mirroring trait syntax: +/// +/// ``` +/// # #![feature(await_macro, pin, arbitrary_self_types, async_await, futures_api, proc_macro_hygiene)] +/// # fn main() {} +/// # tarpc::service! { +/// /// Say hello +/// rpc hello(name: String) -> String; +/// # } +/// ``` +/// +/// Attributes can be attached to each rpc. These attributes +/// will then be attached to the generated service traits' +/// corresponding `fn`s, as well as to the client stubs' RPCs. +/// +/// The following items are expanded in the enclosing module: +/// +/// * `trait Service` -- defines the RPC service. +/// * `fn serve` -- turns a service impl into a request handler. +/// * `Client` -- a client stub with a fn for each RPC. +/// * `fn new_stub` -- creates a new Client stub. +/// +#[macro_export] +macro_rules! service { +// Entry point + ( + $( + $(#[$attr:meta])* + rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ) $(-> $out:ty)*; + )* + ) => { + $crate::service! {{ + $( + $(#[$attr])* + rpc $fn_name( $( $arg : $in_ ),* ) $(-> $out)*; + )* + }} + }; +// Pattern for when the next rpc has an implicit unit return type. + ( + { + $(#[$attr:meta])* + rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ); + + $( $unexpanded:tt )* + } + $( $expanded:tt )* + ) => { + $crate::service! { + { $( $unexpanded )* } + + $( $expanded )* + + $(#[$attr])* + rpc $fn_name( $( $arg : $in_ ),* ) -> (); + } + }; +// Pattern for when the next rpc has an explicit return type. + ( + { + $(#[$attr:meta])* + rpc $fn_name:ident( $( $arg:ident : $in_:ty ),* ) -> $out:ty; + + $( $unexpanded:tt )* + } + $( $expanded:tt )* + ) => { + $crate::service! { + { $( $unexpanded )* } + + $( $expanded )* + + $(#[$attr])* + rpc $fn_name( $( $arg : $in_ ),* ) -> $out; + } + }; +// Pattern for when all return types have been expanded + ( + { } // none left to expand + $( + $(#[$attr:meta])* + rpc $fn_name:ident ( $( $arg:ident : $in_:ty ),* ) -> $out:ty; + )* + ) => { + $crate::add_serde_if_enabled! { + #[derive(Debug)] + #[doc(hidden)] + #[allow(non_camel_case_types, unused)] + -- + pub enum Request__ { + $( + $fn_name{ $($arg: $in_,)* } + ),* + } + } + + $crate::add_serde_if_enabled! 
{ + #[derive(Debug)] + #[doc(hidden)] + #[allow(non_camel_case_types, unused)] + -- + pub enum Response__ { + $( + $fn_name($out) + ),* + } + } + + // TODO: proc_macro can't currently parse $crate, so this needs to be imported for the + // usage of snake_to_camel! to work. + use $crate::futures::Future as Future__; + + /// Defines the RPC service. The additional trait bounds are required so that services can + /// multiplex requests across multiple tasks, potentially on multiple threads. + pub trait Service: Clone + Send + 'static { + $( + $crate::snake_to_camel! { + /// The type of future returned by `{}`. + type $fn_name: Future__ + Send; + } + + $(#[$attr])* + fn $fn_name(&self, ctx: $crate::context::Context, $($arg:$in_),*) -> $crate::ty_snake_to_camel!(Self::$fn_name); + )* + } + + // TODO: use an existential type instead of this when existential types work. + #[allow(non_camel_case_types)] + pub enum Response { + $( + $fn_name($crate::ty_snake_to_camel!(::$fn_name)), + )* + } + + impl ::std::fmt::Debug for Response { + fn fmt(&self, fmt: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + fmt.debug_struct("Response").finish() + } + } + + impl ::std::future::Future for Response { + type Output = ::std::io::Result; + + fn poll(self: ::std::pin::Pin<&mut Self>, waker: &::std::task::LocalWaker) + -> ::std::task::Poll<::std::io::Result> + { + unsafe { + match ::std::pin::Pin::get_mut_unchecked(self) { + $( + Response::$fn_name(resp) => + ::std::pin::Pin::new_unchecked(resp) + .poll(waker) + .map(Response__::$fn_name) + .map(Ok), + )* + } + } + } + } + + /// Returns a serving function to use with rpc::server::Server. + pub fn serve(service: S) + -> impl FnMut($crate::context::Context, Request__) -> Response + Send + 'static + Clone { + move |ctx, req| { + match req { + $( + Request__::$fn_name{ $($arg,)* } => { + let resp = Service::$fn_name(&mut service.clone(), ctx, $($arg),*); + Response::$fn_name(resp) + } + )* + } + } + } + + #[allow(unused)] + #[derive(Clone, Debug)] + /// The client stub that makes RPC calls to the server. Exposes a Future interface. + pub struct Client($crate::client::Client); + + /// Returns a new client stub that sends requests over the given transport. + pub async fn new_stub(config: $crate::client::Config, transport: T) + -> ::std::io::Result + where + T: $crate::Transport< + Item = $crate::Response, + SinkItem = $crate::ClientMessage> + Send, + { + Ok(Client(await!($crate::client::Client::new(config, transport))?)) + } + + impl Client { + $( + #[allow(unused)] + $(#[$attr])* + pub fn $fn_name(&mut self, ctx: $crate::context::Context, $($arg: $in_),*) + -> impl ::std::future::Future> + '_ { + let request__ = Request__::$fn_name { $($arg,)* }; + let resp = self.0.call(ctx, request__); + async move { + match await!(resp)? { + Response__::$fn_name(msg__) => ::std::result::Result::Ok(msg__), + _ => unreachable!(), + } + } + } + )* + } + } +} + +// allow dead code; we're just testing that the macro expansion compiles +#[allow(dead_code)] +#[cfg(test)] +mod syntax_test { + service! 
{ + #[deny(warnings)] + #[allow(non_snake_case)] + rpc TestCamelCaseDoesntConflict(); + rpc hello() -> String; + #[doc="attr"] + rpc attr(s: String) -> String; + rpc no_args_no_return(); + rpc no_args() -> (); + rpc one_arg(foo: String) -> i32; + rpc two_args_no_return(bar: String, baz: u64); + rpc two_args(bar: String, baz: u64) -> String; + rpc no_args_ret_error() -> i32; + rpc one_arg_ret_error(foo: String) -> String; + rpc no_arg_implicit_return_error(); + #[doc="attr"] + rpc one_arg_implicit_return_error(foo: String); + } +} + +#[cfg(test)] +mod functional_test { + use futures::{ + compat::TokioDefaultSpawner, + future::{ready, Ready}, + prelude::*, + }; + use rpc::{ + client, context, + server::{self, Handler}, + transport::channel, + }; + use std::io; + use tokio::runtime::current_thread; + + service! { + rpc add(x: i32, y: i32) -> i32; + rpc hey(name: String) -> String; + } + + #[derive(Clone)] + struct Server; + + impl Service for Server { + type AddFut = Ready; + + fn add(&self, _: context::Context, x: i32, y: i32) -> Self::AddFut { + ready(x + y) + } + + type HeyFut = Ready; + + fn hey(&self, _: context::Context, name: String) -> Self::HeyFut { + ready(format!("Hey, {}.", name)) + } + } + + #[test] + fn sequential() { + let _ = env_logger::try_init(); + rpc::init(TokioDefaultSpawner); + + let test = async { + let (tx, rx) = channel::unbounded(); + tokio_executor::spawn( + rpc::Server::new(server::Config::default()) + .incoming(stream::once(ready(Ok(rx)))) + .respond_with(serve(Server)) + .unit_error() + .boxed() + .compat() + ); + + let mut client = await!(new_stub(client::Config::default(), tx))?; + assert_eq!(3, await!(client.add(context::current(), 1, 2))?); + assert_eq!( + "Hey, Tim.", + await!(client.hey(context::current(), "Tim".to_string()))? + ); + Ok::<_, io::Error>(()) + } + .map_err(|e| panic!(e.to_string())); + + current_thread::block_on_all(test.boxed().compat()).unwrap(); + } + + #[test] + fn concurrent() { + let _ = env_logger::try_init(); + rpc::init(TokioDefaultSpawner); + + let test = async { + let (tx, rx) = channel::unbounded(); + tokio_executor::spawn( + rpc::Server::new(server::Config::default()) + .incoming(stream::once(ready(Ok(rx)))) + .respond_with(serve(Server)) + .unit_error() + .boxed() + .compat() + ); + + let client = await!(new_stub(client::Config::default(), tx))?; + let mut c = client.clone(); + let req1 = c.add(context::current(), 1, 2); + let mut c = client.clone(); + let req2 = c.add(context::current(), 3, 4); + let mut c = client.clone(); + let req3 = c.hey(context::current(), "Tim".to_string()); + + assert_eq!(3, await!(req1)?); + assert_eq!(7, await!(req2)?); + assert_eq!("Hey, Tim.", await!(req3)?); + Ok::<_, io::Error>(()) + } + .map_err(|e| panic!("test failed: {}", e)); + + current_thread::block_on_all(test.boxed().compat()).unwrap(); + } +} diff --git a/tarpc/tests/latency.rs b/tarpc/tests/latency.rs new file mode 100644 index 00000000..7def3bd5 --- /dev/null +++ b/tarpc/tests/latency.rs @@ -0,0 +1,130 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the MIT License, . +// This file may not be copied, modified, or distributed except according to those terms. 
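The latency benchmark that follows (tarpc/tests/latency.rs) times each `ack` round trip and summarizes the distribution with the unstable `test::stats` traits (`mean`, `std_dev`, `quartiles`). A std-only sketch of the same aggregation, using a simpler nearest-rank quartile estimate instead of `test::stats`' interpolation (`summarize` is an illustrative helper, not from the crate):

```rust
use std::time::Duration;

// Summarize latencies (in nanoseconds) as (mean, q1, median, q3),
// roughly as the benchmark below reports them.
fn summarize(mut nanos: Vec<f64>) -> (f64, f64, f64, f64) {
    nanos.sort_by(|a, b| a.partial_cmp(b).unwrap());
    let mean = nanos.iter().sum::<f64>() / nanos.len() as f64;
    // Nearest-rank percentile; a coarser estimator than test::stats uses.
    let pct = |p: f64| nanos[((nanos.len() - 1) as f64 * p) as usize];
    (mean, pct(0.25), pct(0.5), pct(0.75))
}

fn main() {
    let durations = vec![
        Duration::from_micros(90),
        Duration::from_micros(110),
        Duration::from_micros(130),
        Duration::from_micros(500),
    ];
    let nanos: Vec<f64> = durations
        .iter()
        .map(|d| d.as_secs() as f64 * 1e9 + f64::from(d.subsec_nanos()))
        .collect();
    let (mean, lower, median, upper) = summarize(nanos);
    println!(
        "mean={:?} quartiles=({:?}, {:?}, {:?})",
        Duration::from_nanos(mean as u64),
        Duration::from_nanos(lower as u64),
        Duration::from_nanos(median as u64),
        Duration::from_nanos(upper as u64),
    );
}
```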
+ +#![feature( + test, + arbitrary_self_types, + pin, + integer_atomics, + futures_api, + generators, + await_macro, + async_await, + proc_macro_hygiene, +)] + +extern crate test; + +use self::test::stats::Stats; +use futures::{compat::TokioDefaultSpawner, future, prelude::*}; +use rpc::{ + client, context, + server::{self, Handler, Server}, +}; +use std::{ + io, + time::{Duration, Instant}, +}; + +mod ack { + tarpc::service! { + rpc ack(); + } +} + +#[derive(Clone)] +struct Serve; + +impl ack::Service for Serve { + type AckFut = future::Ready<()>; + + fn ack(&self, _: context::Context) -> Self::AckFut { + future::ready(()) + } +} + +async fn bench() -> io::Result<()> { + let listener = bincode_transport::listen(&"0.0.0.0:0".parse().unwrap())?; + let addr = listener.local_addr(); + + tokio_executor::spawn( + Server::new(server::Config::default()) + .incoming(listener) + .take(1) + .respond_with(ack::serve(Serve)) + .unit_error() + .boxed() + .compat() + ); + + let conn = await!(bincode_transport::connect(&addr))?; + let mut client = await!(ack::new_stub(client::Config::default(), conn))?; + + let total = 10_000usize; + let mut successful = 0u32; + let mut unsuccessful = 0u32; + let mut durations = vec![]; + for _ in 1..=total { + let now = Instant::now(); + let response = await!(client.ack(context::current())); + let elapsed = now.elapsed(); + + match response { + Ok(_) => successful += 1, + Err(_) => unsuccessful += 1, + }; + durations.push(elapsed); + } + + let durations_nanos = durations + .iter() + .map(|duration| duration.as_secs() as f64 * 1E9 + duration.subsec_nanos() as f64) + .collect::>(); + + let (lower, median, upper) = durations_nanos.quartiles(); + + println!("Of {:?} runs:", durations_nanos.len()); + println!("\tSuccessful: {:?}", successful); + println!("\tUnsuccessful: {:?}", unsuccessful); + println!( + "\tMean: {:?}", + Duration::from_nanos(durations_nanos.mean() as u64) + ); + println!("\tMedian: {:?}", Duration::from_nanos(median as u64)); + println!( + "\tStd Dev: {:?}", + Duration::from_nanos(durations_nanos.std_dev() as u64) + ); + println!( + "\tMin: {:?}", + Duration::from_nanos(durations_nanos.min() as u64) + ); + println!( + "\tMax: {:?}", + Duration::from_nanos(durations_nanos.max() as u64) + ); + println!( + "\tQuartiles: ({:?}, {:?}, {:?})", + Duration::from_nanos(lower as u64), + Duration::from_nanos(median as u64), + Duration::from_nanos(upper as u64) + ); + + println!("done"); + Ok(()) +} + +#[test] +fn bench_small_packet() { + env_logger::init(); + tarpc::init(TokioDefaultSpawner); + + tokio::run( + bench() + .map_err(|e| panic!(e.to_string())) + .boxed() + .compat(), + ) +} diff --git a/test/identity.p12 b/test/identity.p12 deleted file mode 100644 index d16abb8c..00000000 Binary files a/test/identity.p12 and /dev/null differ diff --git a/test/root-ca.der b/test/root-ca.der deleted file mode 100644 index a9335c6f..00000000 Binary files a/test/root-ca.der and /dev/null differ diff --git a/test/root-ca.pem b/test/root-ca.pem deleted file mode 100644 index 4ec2f538..00000000 --- a/test/root-ca.pem +++ /dev/null @@ -1,21 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIDXTCCAkWgAwIBAgIJAOIvDiVb18eVMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV -BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX -aWRnaXRzIFB0eSBMdGQwHhcNMTYwODE0MTY1NjExWhcNMjYwODEyMTY1NjExWjBF -MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50 -ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB -CgKCAQEArVHWFn52Lbl1l59exduZntVSZyDYpzDND+S2LUcO6fRBWhV/1Kzox+2G 
-ZptbuMGmfI3iAnb0CFT4uC3kBkQQlXonGATSVyaFTFR+jq/lc0SP+9Bd7SBXieIV -eIXlY1TvlwIvj3Ntw9zX+scTA4SXxH6M0rKv9gTOub2vCMSHeF16X8DQr4XsZuQr -7Cp7j1I4aqOJyap5JTl5ijmG8cnu0n+8UcRlBzy99dLWJG0AfI3VRJdWpGTNVZ92 -aFff3RpK3F/WI2gp3qV1ynRAKuvmncGC3LDvYfcc2dgsc1N6Ffq8GIrkgRob6eBc -klDHp1d023Lwre+VaVDSo1//Y72UFwIDAQABo1AwTjAdBgNVHQ4EFgQUbNOlA6sN -XyzJjYqciKeId7g3/ZowHwYDVR0jBBgwFoAUbNOlA6sNXyzJjYqciKeId7g3/Zow -DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAVVaR5QWLZIRR4Dw6TSBn -BQiLpBSXN6oAxdDw6n4PtwW6CzydaA+creiK6LfwEsiifUfQe9f+T+TBSpdIYtMv -Z2H2tjlFX8VrjUFvPrvn5c28CuLI0foBgY8XGSkR2YMYzWw2jPEq3Th/KM5Catn3 -AFm3bGKWMtGPR4v+90chEN0jzaAmJYRrVUh9vea27bOCn31Nse6XXQPmSI6Gyncy -OAPUsvPClF3IjeL1tmBotWqSGn1cYxLo+Lwjk22A9h6vjcNQRyZF2VLVvtwYrNU3 -mwJ6GCLsLHpwW/yjyvn8iEltnJvByM/eeRnfXV6WDObyiZsE/n6DxIRJodQzFqy9 -GA== ------END CERTIFICATE----- diff --git a/trace/Cargo.toml b/trace/Cargo.toml new file mode 100644 index 00000000..ad1308d5 --- /dev/null +++ b/trace/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "trace" +version = "0.1.0" +authors = ["tikue "] +edition = '2018' + +[dependencies] +rand = "0.5" + +[dependencies.serde] +version = "1.0" +optional = true +features = ["derive"] diff --git a/trace/rustfmt.toml b/trace/rustfmt.toml new file mode 100644 index 00000000..0ef5137d --- /dev/null +++ b/trace/rustfmt.toml @@ -0,0 +1 @@ +edition = "Edition2018" diff --git a/trace/src/lib.rs b/trace/src/lib.rs new file mode 100644 index 00000000..71d702d8 --- /dev/null +++ b/trace/src/lib.rs @@ -0,0 +1,105 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the MIT License, . +// This file may not be copied, modified, or distributed except according to those terms. + +#![deny(missing_docs, missing_debug_implementations)] + +//! Provides building blocks for tracing distributed programs. +//! +//! A trace is logically a tree of causally-related events called spans. Traces are tracked via a +//! [context](Context) that identifies the current trace, span, and parent of the current span. In +//! distributed systems, a context can be sent from client to server to connect events occurring on +//! either side. +//! +//! This crate's design is based on [opencensus +//! tracing](https://opencensus.io/core-concepts/tracing/). + +use rand::Rng; +use std::{ + fmt::{self, Formatter}, + mem, +}; + +/// A context for tracing the execution of processes, distributed or otherwise. +/// +/// Consists of a span identifying an event, an optional parent span identifying a causal event +/// that triggered the current span, and a trace with which all related spans are associated. +#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] +#[cfg_attr( + feature = "serde", + derive(serde::Serialize, serde::Deserialize) +)] +pub struct Context { + /// An identifier of the trace associated with the current context. A trace ID is typically + /// created at a root span and passed along through all causal events. + pub trace_id: TraceId, + /// An identifier of the current span. In typical RPC usage, a span is created by a client + /// before making an RPC, and the span ID is sent to the server. The server is free to create + /// its own spans, for which it sets the client's span as the parent span. + pub span_id: SpanId, + /// An identifier of the span that originated the current span. For example, if a server sends + /// an RPC in response to a client request that included a span, the server would create a span + /// for the RPC and set its parent to the span_id in the incoming request's context. 
+    ///
+    /// If `parent_id` is `None`, then this is a root context.
+    pub parent_id: Option<SpanId>,
+}
+
+/// A 128-bit UUID identifying a trace. All spans caused by the same originating span share the
+/// same trace ID.
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
+#[cfg_attr(
+    feature = "serde",
+    derive(serde::Serialize, serde::Deserialize)
+)]
+pub struct TraceId(u128);
+
+/// A 64-bit identifier of a span within a trace. The identifier is unique within the span's trace.
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
+#[cfg_attr(
+    feature = "serde",
+    derive(serde::Serialize, serde::Deserialize)
+)]
+pub struct SpanId(u64);
+
+impl Context {
+    /// Constructs a new root context. A root context is one with no parent span.
+    pub fn new_root() -> Self {
+        let rng = &mut rand::thread_rng();
+        Context {
+            trace_id: TraceId::random(rng),
+            span_id: SpanId::random(rng),
+            parent_id: None,
+        }
+    }
+}
+
+impl TraceId {
+    /// Returns a random trace ID that can be assumed to be globally unique if `rng` generates
+    /// actually-random numbers.
+    pub fn random<R: Rng>(rng: &mut R) -> Self {
+        TraceId((rng.next_u64() as u128) << (mem::size_of::<u64>() * 8) | rng.next_u64() as u128)
+    }
+}
+
+impl SpanId {
+    /// Returns a random span ID that can be assumed to be unique within a single trace.
+    pub fn random<R: Rng>(rng: &mut R) -> Self {
+        SpanId(rng.next_u64())
+    }
+}
+
+impl fmt::Display for TraceId {
+    fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
+        write!(f, "{:02x}", self.0)?;
+        Ok(())
+    }
+}
+
+impl fmt::Display for SpanId {
+    fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
+        write!(f, "{:02x}", self.0)?;
+        Ok(())
+    }
+}
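A natural use of the `trace` crate above is deriving a child context for an outgoing call: keep the trace ID, mint a fresh span ID, and record the caller's span as the parent. A short sketch against the API shown in this diff, assuming the generic bounds restored above (`R: Rng`, `Option<SpanId>`) and `rand 0.5` from trace/Cargo.toml; `child_of` is an illustrative helper, not part of the crate:

```rust
use trace::{Context, SpanId};

// Derive a child context for an outgoing call: same trace, new span,
// parent set to the caller's span.
fn child_of(parent: &Context) -> Context {
    Context {
        trace_id: parent.trace_id,
        span_id: SpanId::random(&mut rand::thread_rng()),
        parent_id: Some(parent.span_id),
    }
}

fn main() {
    let root = Context::new_root();
    let child = child_of(&root);
    assert_eq!(child.trace_id, root.trace_id);
    assert_eq!(child.parent_id, Some(root.span_id));
    println!(
        "trace={} root_span={} child_span={}",
        root.trace_id, root.span_id, child.span_id
    );
}
```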