diff --git a/Cargo.toml b/Cargo.toml index d0fbab2e8..42c7f16de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,13 +2,11 @@ members = [ "tonic", "tonic-build", - # Non-published crates "examples", "interop", - # Tests "tests/included_service", "tests/same_name", "tests/wellknown", -] \ No newline at end of file +] diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 36c66233c..bf49098f7 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -86,27 +86,30 @@ path = "src/uds/client.rs" name = "uds-server" path = "src/uds/server.rs" +[[bin]] +name = "interceptor-client" +path = "src/interceptor/client.rs" + +[[bin]] +name = "interceptor-server" +path = "src/interceptor/server.rs" + [dependencies] tonic = { path = "../tonic", features = ["tls"] } prost = "0.6" - tokio = { version = "0.2", features = ["rt-threaded", "time", "stream", "fs", "macros", "uds"] } -futures = { version = "0.3", default-features = false, features = ["alloc"]} +futures = { version = "0.3", default-features = false, features = ["alloc"] } async-stream = "0.2" -http = "0.2" -tower = "0.3" - +tower = "0.3" # Required for routeguide serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" rand = "0.7" - # Tracing tracing = "0.1" -tracing-subscriber = { version = "0.2.0-alpha", features = ["tracing-log"] } +tracing-subscriber = { version = "0.2.0-alpha", features = ["tracing-log"] } tracing-attributes = "0.1" tracing-futures = "0.2" - # Required for wellknown types prost-types = "0.6" diff --git a/examples/src/authentication/client.rs b/examples/src/authentication/client.rs index b97b02b37..7297c4716 100644 --- a/examples/src/authentication/client.rs +++ b/examples/src/authentication/client.rs @@ -2,23 +2,19 @@ pub mod pb { tonic::include_proto!("grpc.examples.echo"); } -use http::header::HeaderValue; use pb::{echo_client::EchoClient, EchoRequest}; -use tonic::transport::Channel; +use tonic::{metadata::MetadataValue, transport::Channel, Request}; #[tokio::main] async fn main() -> Result<(), Box> { - let channel = Channel::from_static("http://[::1]:50051") - .intercept_headers(|headers| { - headers.insert( - "authorization", - HeaderValue::from_static("Bearer some-secret-token"), - ); - }) - .connect() - .await?; - - let mut client = EchoClient::new(channel); + let channel = Channel::from_static("http://[::1]:50051").connect().await?; + + let token = MetadataValue::from_str("Bearer some-auth-token")?; + + let mut client = EchoClient::with_interceptor(channel, move |mut req: Request<()>| { + req.metadata_mut().insert("authorization", token.clone()); + Ok(req) + }); let request = tonic::Request::new(EchoRequest { message: "hello".into(), diff --git a/examples/src/authentication/server.rs b/examples/src/authentication/server.rs index 94b0dba3a..0726efcc6 100644 --- a/examples/src/authentication/server.rs +++ b/examples/src/authentication/server.rs @@ -5,8 +5,7 @@ pub mod pb { use futures::Stream; use pb::{EchoRequest, EchoResponse}; use std::pin::Pin; -use tonic::{body::BoxBody, transport::Server, Request, Response, Status, Streaming}; -use tower::Service; +use tonic::{metadata::MetadataValue, transport::Server, Request, Response, Status, Streaming}; type EchoResult = Result, Status>; type ResponseStream = Pin> + Send + Sync>>; @@ -52,36 +51,18 @@ async fn main() -> Result<(), Box> { let addr = "[::1]:50051".parse().unwrap(); let server = EchoServer::default(); - Server::builder() - .interceptor_fn(move |svc, req| { - let auth_header = req.headers().get("authorization").clone(); + let svc = pb::echo_server::EchoServer::with_interceptor(server, check_auth); - let authed = if let Some(auth_header) = auth_header { - auth_header == "Bearer some-secret-token" - } else { - false - }; + Server::builder().add_service(svc).serve(addr).await?; - let fut = svc.call(req); + Ok(()) +} - async move { - if authed { - fut.await - } else { - // Cancel the inner future since we never await it - // the IO never gets registered. - drop(fut); - let res = http::Response::builder() - .header("grpc-status", "16") - .body(BoxBody::empty()) - .unwrap(); - Ok(res) - } - } - }) - .add_service(pb::echo_server::EchoServer::new(server)) - .serve(addr) - .await?; +fn check_auth(req: Request<()>) -> Result, Status> { + let token = MetadataValue::from_str("Bearer some-secret-token").unwrap(); - Ok(()) + match req.metadata().get("authorization") { + Some(t) if token == t => Ok(req), + _ => Err(Status::unauthenticated("No valid auth token")), + } } diff --git a/examples/src/gcp/client.rs b/examples/src/gcp/client.rs index 59570fe07..15f97588d 100644 --- a/examples/src/gcp/client.rs +++ b/examples/src/gcp/client.rs @@ -3,8 +3,8 @@ pub mod api { } use api::{publisher_client::PublisherClient, ListTopicsRequest}; -use http::header::HeaderValue; use tonic::{ + metadata::MetadataValue, transport::{Certificate, Channel, ClientTlsConfig}, Request, }; @@ -23,7 +23,7 @@ async fn main() -> Result<(), Box> { .ok_or("Expected a project name as the first argument.".to_string())?; let bearer_token = format!("Bearer {}", token); - let header_value = HeaderValue::from_str(&bearer_token)?; + let header_value = MetadataValue::from_str(&bearer_token)?; let certs = tokio::fs::read("examples/data/gcp/roots.pem").await?; @@ -32,14 +32,15 @@ async fn main() -> Result<(), Box> { .domain_name("pubsub.googleapis.com"); let channel = Channel::from_static(ENDPOINT) - .intercept_headers(move |headers| { - headers.insert("authorization", header_value.clone()); - }) .tls_config(tls_config) .connect() .await?; - let mut service = PublisherClient::new(channel); + let mut service = PublisherClient::with_interceptor(channel, move |mut req: Request<()>| { + req.metadata_mut() + .insert("authorization", header_value.clone()); + Ok(req) + }); let response = service .list_topics(Request::new(ListTopicsRequest { diff --git a/examples/src/interceptor/client.rs b/examples/src/interceptor/client.rs new file mode 100644 index 000000000..f34cf6675 --- /dev/null +++ b/examples/src/interceptor/client.rs @@ -0,0 +1,34 @@ +use hello_world::greeter_client::GreeterClient; +use hello_world::HelloRequest; +use tonic::{transport::Endpoint, Request, Status}; + +pub mod hello_world { + tonic::include_proto!("helloworld"); +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let channel = Endpoint::from_static("http://[::1]:50051") + .connect() + .await?; + + let mut client = GreeterClient::with_interceptor(channel, intercept); + + let request = tonic::Request::new(HelloRequest { + name: "Tonic".into(), + }); + + let response = client.say_hello(request).await?; + + println!("RESPONSE={:?}", response); + + Ok(()) +} + +/// This function will get called on each outbound request. Returning a +/// `Status` here will cancel the request and have that status returned to +/// the caller. +fn intercept(req: Request<()>) -> Result, Status> { + println!("Intercepting request: {:?}", req); + Ok(req) +} diff --git a/examples/src/interceptor/server.rs b/examples/src/interceptor/server.rs new file mode 100644 index 000000000..b73a15d9f --- /dev/null +++ b/examples/src/interceptor/server.rs @@ -0,0 +1,46 @@ +use tonic::{transport::Server, Request, Response, Status}; + +use hello_world::greeter_server::{Greeter, GreeterServer}; +use hello_world::{HelloReply, HelloRequest}; + +pub mod hello_world { + tonic::include_proto!("helloworld"); +} + +#[derive(Default)] +pub struct MyGreeter {} + +#[tonic::async_trait] +impl Greeter for MyGreeter { + async fn say_hello( + &self, + request: Request, + ) -> Result, Status> { + let reply = hello_world::HelloReply { + message: format!("Hello {}!", request.into_inner().name), + }; + Ok(Response::new(reply)) + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let addr = "[::1]:50051".parse().unwrap(); + let greeter = MyGreeter::default(); + + let svc = GreeterServer::with_interceptor(greeter, intercept); + + println!("GreeterServer listening on {}", addr); + + Server::builder().add_service(svc).serve(addr).await?; + + Ok(()) +} + +/// This function will get called on each inbound request, if a `Status` +/// is returned, it will cancel the request and return that status to the +/// client. +fn intercept(req: Request<()>) -> Result, Status> { + println!("Intercepting request: {:?}", req); + Ok(req) +} diff --git a/examples/src/uds/client.rs b/examples/src/uds/client.rs index 21e61d926..f91cd6bed 100644 --- a/examples/src/uds/client.rs +++ b/examples/src/uds/client.rs @@ -5,11 +5,10 @@ pub mod hello_world { } use hello_world::{greeter_client::GreeterClient, HelloRequest}; -use http::Uri; use std::convert::TryFrom; #[cfg(unix)] use tokio::net::UnixStream; -use tonic::transport::Endpoint; +use tonic::transport::{Endpoint, Uri}; use tower::service_fn; #[cfg(unix)] diff --git a/interop/Cargo.toml b/interop/Cargo.toml index ade5c8adc..ab1fc2daa 100644 --- a/interop/Cargo.toml +++ b/interop/Cargo.toml @@ -26,10 +26,9 @@ futures-util = "0.3" async-stream = "0.2" tower = "0.3" http-body = "0.3" - +hyper = "0.13" console = "0.9" structopt = "0.3" - tracing = "0.1" tracing-subscriber = "0.2.0-alpha" tracing-log = "0.1.0" diff --git a/interop/src/bin/server.rs b/interop/src/bin/server.rs index 285a1e737..509e67349 100644 --- a/interop/src/bin/server.rs +++ b/interop/src/bin/server.rs @@ -1,10 +1,7 @@ -use http::header::HeaderName; use structopt::StructOpt; -use tonic::body::BoxBody; -use tonic::client::GrpcService; use tonic::transport::Server; use tonic::transport::{Identity, ServerTlsConfig}; -use tonic_interop::{server, MergeTrailers}; +use tonic_interop::server; #[derive(StructOpt)] struct Opts { @@ -20,33 +17,7 @@ async fn main() -> std::result::Result<(), Box> { let addr = "127.0.0.1:10000".parse().unwrap(); - let mut builder = Server::builder().interceptor_fn(|svc, req| { - let echo_header = req - .headers() - .get("x-grpc-test-echo-initial") - .map(Clone::clone); - - let echo_trailer = req - .headers() - .get("x-grpc-test-echo-trailing-bin") - .map(Clone::clone) - .map(|v| (HeaderName::from_static("x-grpc-test-echo-trailing-bin"), v)); - - let call = svc.call(req); - - async move { - let mut res = call.await?; - - if let Some(echo_header) = echo_header { - res.headers_mut() - .insert("x-grpc-test-echo-initial", echo_header); - } - - Ok(res - .map(|b| MergeTrailers::new(b, echo_trailer)) - .map(BoxBody::new)) - } - }); + let mut builder = Server::builder(); if matches.use_tls { let cert = tokio::fs::read("interop/data/server1.pem").await?; @@ -60,8 +31,11 @@ async fn main() -> std::result::Result<(), Box> { let unimplemented_service = server::UnimplementedServiceServer::new(server::UnimplementedService::default()); + // Wrap this test_service with a service that will echo headers as trailers. + let test_service_svc = server::EchoHeadersSvc::new(test_service); + builder - .add_service(test_service) + .add_service(test_service_svc) .add_service(unimplemented_service) .serve(addr) .await?; diff --git a/interop/src/lib.rs b/interop/src/lib.rs index d744136f8..6239a1b3f 100644 --- a/interop/src/lib.rs +++ b/interop/src/lib.rs @@ -9,13 +9,7 @@ pub mod pb { include!(concat!(env!("OUT_DIR"), "/grpc.testing.rs")); } -use http::header::{HeaderMap, HeaderName, HeaderValue}; -use http_body::Body; -use std::{ - default, fmt, iter, - pin::Pin, - task::{Context, Poll}, -}; +use std::{default, fmt, iter}; pub fn trace_init() { let sub = tracing_subscriber::FmtSubscriber::builder() @@ -147,41 +141,3 @@ macro_rules! test_assert { } }; } - -pub struct MergeTrailers { - inner: B, - trailer: Option<(HeaderName, HeaderValue)>, -} - -impl MergeTrailers { - pub fn new(inner: B, trailer: Option<(HeaderName, HeaderValue)>) -> Self { - Self { inner, trailer } - } -} - -impl Body for MergeTrailers { - type Data = B::Data; - type Error = B::Error; - - fn poll_data( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - Pin::new(&mut self.inner).poll_data(cx) - } - - fn poll_trailers( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - Pin::new(&mut self.inner).poll_trailers(cx).map_ok(|h| { - h.map(|mut headers| { - if let Some((key, value)) = &self.trailer { - headers.insert(key.clone(), value.clone()); - } - - headers - }) - }) - } -} diff --git a/interop/src/server.rs b/interop/src/server.rs index 581969cd5..047de36c7 100644 --- a/interop/src/server.rs +++ b/interop/src/server.rs @@ -1,9 +1,14 @@ use crate::pb::{self, *}; use async_stream::try_stream; use futures_util::{stream, StreamExt, TryStreamExt}; +use http::header::{HeaderMap, HeaderName, HeaderValue}; +use http_body::Body; +use std::future::Future; use std::pin::Pin; +use std::task::{Context, Poll}; use std::time::Duration; -use tonic::{Code, Request, Response, Status}; +use tonic::{body::BoxBody, transport::ServiceName, Code, Request, Response, Status}; +use tower::Service; pub use pb::test_service_server::TestServiceServer; pub use pb::unimplemented_service_server::UnimplementedServiceServer; @@ -159,3 +164,100 @@ impl pb::unimplemented_service_server::UnimplementedService for UnimplementedSer Err(Status::unimplemented("")) } } + +#[derive(Clone, Default)] +pub struct EchoHeadersSvc { + inner: S, +} + +impl ServiceName for EchoHeadersSvc { + const NAME: &'static str = S::NAME; +} + +impl EchoHeadersSvc { + pub fn new(inner: S) -> Self { + Self { inner } + } +} + +impl Service> for EchoHeadersSvc +where + S: Service, Response = http::Response> + Send, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = Pin< + Box> + Send + 'static>, + >; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Ok(()).into() + } + + fn call(&mut self, req: http::Request) -> Self::Future { + let echo_header = req + .headers() + .get("x-grpc-test-echo-initial") + .map(Clone::clone); + + let echo_trailer = req + .headers() + .get("x-grpc-test-echo-trailing-bin") + .map(Clone::clone) + .map(|v| (HeaderName::from_static("x-grpc-test-echo-trailing-bin"), v)); + + let call = self.inner.call(req); + + Box::pin(async move { + let mut res = call.await?; + + if let Some(echo_header) = echo_header { + res.headers_mut() + .insert("x-grpc-test-echo-initial", echo_header); + } + + Ok(res + .map(|b| MergeTrailers::new(b, echo_trailer)) + .map(BoxBody::new)) + }) + } +} + +pub struct MergeTrailers { + inner: B, + trailer: Option<(HeaderName, HeaderValue)>, +} + +impl MergeTrailers { + pub fn new(inner: B, trailer: Option<(HeaderName, HeaderValue)>) -> Self { + Self { inner, trailer } + } +} + +impl Body for MergeTrailers { + type Data = B::Data; + type Error = B::Error; + + fn poll_data( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + Pin::new(&mut self.inner).poll_data(cx) + } + + fn poll_trailers( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + Pin::new(&mut self.inner).poll_trailers(cx).map_ok(|h| { + h.map(|mut headers| { + if let Some((key, value)) = &self.trailer { + headers.insert(key.clone(), value.clone()); + } + + headers + }) + }) + } +} diff --git a/tonic-build/src/client.rs b/tonic-build/src/client.rs index d8f694108..a31697fae 100644 --- a/tonic-build/src/client.rs +++ b/tonic-build/src/client.rs @@ -34,6 +34,11 @@ pub(crate) fn generate(service: &Service, proto: &str) -> TokenStream { Self { inner } } + pub fn with_interceptor(inner: T, interceptor: impl Into) -> Self { + let inner = tonic::client::Grpc::with_interceptor(inner, interceptor); + Self { inner } + } + #methods } diff --git a/tonic-build/src/server.rs b/tonic-build/src/server.rs index 81fa05208..61114a380 100644 --- a/tonic-build/src/server.rs +++ b/tonic-build/src/server.rs @@ -29,12 +29,21 @@ pub(crate) fn generate(service: &Service, proto_path: &str) -> TokenStream { #[derive(Debug)] #[doc(hidden)] pub struct #server_service { - inner: Arc, + inner: _Inner, } + struct _Inner(Arc, Option); + impl #server_service { pub fn new(inner: T) -> Self { let inner = Arc::new(inner); + let inner = _Inner(inner, None); + Self { inner } + } + + pub fn with_interceptor(inner: T, interceptor: impl Into) -> Self { + let inner = Arc::new(inner); + let inner = _Inner(inner, Some(interceptor.into())); Self { inner } } } @@ -72,6 +81,18 @@ pub(crate) fn generate(service: &Service, proto_path: &str) -> TokenStream { } } + impl Clone for _Inner { + fn clone(&self) -> Self { + Self(self.0.clone(), self.1.clone()) + } + } + + impl std::fmt::Debug for _Inner { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.0) + } + } + #transport } } @@ -246,9 +267,17 @@ fn generate_unary( let inner = self.inner.clone(); let fut = async move { + let interceptor = inner.1.clone(); + let inner = inner.0; let method = #service_ident(inner); let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec); + + let mut grpc = if let Some(interceptor) = interceptor { + tonic::server::Grpc::with_interceptor(codec, interceptor) + } else { + tonic::server::Grpc::new(codec) + }; + let res = grpc.unary(method, req).await; Ok(res) }; @@ -289,9 +318,17 @@ fn generate_server_streaming( let inner = self.inner.clone(); let fut = async move { + let interceptor = inner.1; + let inner = inner.0; let method = #service_ident(inner); let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec); + + let mut grpc = if let Some(interceptor) = interceptor { + tonic::server::Grpc::with_interceptor(codec, interceptor) + } else { + tonic::server::Grpc::new(codec) + }; + let res = grpc.server_streaming(method, req).await; Ok(res) }; @@ -330,9 +367,17 @@ fn generate_client_streaming( let inner = self.inner.clone(); let fut = async move { + let interceptor = inner.1; + let inner = inner.0; let method = #service_ident(inner); let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec); + + let mut grpc = if let Some(interceptor) = interceptor { + tonic::server::Grpc::with_interceptor(codec, interceptor) + } else { + tonic::server::Grpc::new(codec) + }; + let res = grpc.client_streaming(method, req).await; Ok(res) }; @@ -373,9 +418,17 @@ fn generate_streaming( let inner = self.inner.clone(); let fut = async move { + let interceptor = inner.1; + let inner = inner.0; let method = #service_ident(inner); let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec); + + let mut grpc = if let Some(interceptor) = interceptor { + tonic::server::Grpc::with_interceptor(codec, interceptor) + } else { + tonic::server::Grpc::new(codec) + }; + let res = grpc.streaming(method, req).await; Ok(res) }; diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs index 07d90b81b..6a42b4022 100644 --- a/tonic/src/client/grpc.rs +++ b/tonic/src/client/grpc.rs @@ -2,6 +2,7 @@ use crate::{ body::{Body, BoxBody}, client::GrpcService, codec::{encode_client, Codec, Streaming}, + interceptor::Interceptor, Code, Request, Response, Status, }; use futures_core::Stream; @@ -28,12 +29,25 @@ use std::fmt; /// [gRPC protocol definition]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests pub struct Grpc { inner: T, + interceptor: Option, } impl Grpc { /// Creates a new gRPC client with the provided [`GrpcService`]. pub fn new(inner: T) -> Self { - Self { inner } + Self { + inner, + interceptor: None, + } + } + + /// Creates a new gRPC client with the provided [`GrpcService`] and will apply + /// the provided interceptor on each request. + pub fn with_interceptor(inner: T, interceptor: impl Into) -> Self { + Self { + inner, + interceptor: Some(interceptor.into()), + } } /// Check if the inner [`GrpcService`] is able to accept a new request. @@ -134,6 +148,12 @@ impl Grpc { M1: Send + Sync + 'static, M2: Send + Sync + 'static, { + let request = if let Some(interceptor) = &self.interceptor { + interceptor.call(request)? + } else { + request + }; + let mut parts = Parts::default(); parts.path_and_query = Some(path); @@ -192,6 +212,7 @@ impl Clone for Grpc { fn clone(&self) -> Self { Self { inner: self.inner.clone(), + interceptor: self.interceptor.clone(), } } } diff --git a/tonic/src/interceptor.rs b/tonic/src/interceptor.rs new file mode 100644 index 000000000..51dcd6e98 --- /dev/null +++ b/tonic/src/interceptor.rs @@ -0,0 +1,54 @@ +use crate::{Request, Status}; +use std::{fmt, sync::Arc}; + +/// Represents a gRPC interceptor. +/// +/// gRPC interceptors are similar to middleware but have much less +/// flexibility. This interceptor allows you to do two main things, +/// one is to add/remove/check items in the `MetadataMap` of each +/// request. Two, cancel a request with any `Status`. +/// +/// An interceptor can be used on both the server and client side through +/// the `tonic-build` crate's generated structs. +/// +/// These interceptors do not allow you to modify the `Message` of the request +/// but allow you to check for metadata. If you would like to apply middleware like +/// features to the body of the request, going through the `tower` abstraction is recommended. +#[derive(Clone)] +pub struct Interceptor { + f: Arc) -> Result, Status> + Send + Sync + 'static>, +} + +impl Interceptor { + /// Create a new `Interceptor` from the provided function. + pub fn new( + f: impl Fn(Request<()>) -> Result, Status> + Send + Sync + 'static, + ) -> Self { + Interceptor { f: Arc::new(f) } + } + + pub(crate) fn call(&self, req: Request) -> Result, Status> { + let (metadata, ext, message) = req.into_parts(); + + let temp_req = Request::from_parts(metadata, ext, ()); + + let (metadata, ext, _) = (self.f)(temp_req)?.into_parts(); + + Ok(Request::from_parts(metadata, ext, message)) + } +} + +impl From for Interceptor +where + F: Fn(Request<()>) -> Result, Status> + Send + Sync + 'static, +{ + fn from(f: F) -> Self { + Interceptor::new(f) + } +} + +impl fmt::Debug for Interceptor { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Interceptor").finish() + } +} diff --git a/tonic/src/lib.rs b/tonic/src/lib.rs index c0d4f40d3..c75784796 100644 --- a/tonic/src/lib.rs +++ b/tonic/src/lib.rs @@ -86,6 +86,7 @@ pub mod server; #[cfg_attr(docsrs, doc(cfg(feature = "transport")))] pub mod transport; +mod interceptor; mod macros; mod request; mod response; @@ -98,6 +99,7 @@ pub use async_trait::async_trait; #[doc(inline)] pub use codec::Streaming; +pub use interceptor::Interceptor; pub use request::{IntoRequest, IntoStreamingRequest, Request}; pub use response::Response; pub use status::{Code, Status}; diff --git a/tonic/src/request.rs b/tonic/src/request.rs index f1658b83b..f2f047ffc 100644 --- a/tonic/src/request.rs +++ b/tonic/src/request.rs @@ -145,6 +145,18 @@ impl Request { self.message } + pub(crate) fn into_parts(self) -> (MetadataMap, Extensions, T) { + (self.metadata, self.extensions, self.message) + } + + pub(crate) fn from_parts(metadata: MetadataMap, extensions: Extensions, message: T) -> Self { + Self { + metadata, + extensions, + message, + } + } + pub(crate) fn from_http_parts(parts: http::request::Parts, message: T) -> Self { Request { metadata: MetadataMap::from_headers(parts.headers), diff --git a/tonic/src/server/grpc.rs b/tonic/src/server/grpc.rs index b099b36a1..3ef566476 100644 --- a/tonic/src/server/grpc.rs +++ b/tonic/src/server/grpc.rs @@ -1,6 +1,7 @@ use crate::{ body::BoxBody, codec::{encode_server, Codec, Streaming}, + interceptor::Interceptor, server::{ClientStreamingService, ServerStreamingService, StreamingService, UnaryService}, Code, Request, Response, Status, }; @@ -9,6 +10,16 @@ use futures_util::{future, stream, TryStreamExt}; use http_body::Body; use std::fmt; +// A try! type macro for intercepting requests +macro_rules! t { + ($expr : expr) => { + match $expr { + Ok(request) => request, + Err(res) => return res, + } + }; +} + /// A gRPC Server handler. /// /// This will wrap some inner [`Codec`] and provide utilities to handle @@ -20,6 +31,7 @@ use std::fmt; /// implements some [`Body`]. pub struct Grpc { codec: T, + interceptor: Option, } impl Grpc @@ -27,9 +39,21 @@ where T: Codec, T::Encode: Sync, { - /// Creates a new gRPC client with the provided [`Codec`]. + /// Creates a new gRPC server with the provided [`Codec`]. pub fn new(codec: T) -> Self { - Self { codec } + Self { + codec, + interceptor: None, + } + } + + /// Creates a new gRPC server with the provided [`Codec`] and will apply the provided + /// interceptor on each inbound request. + pub fn with_interceptor(codec: T, interceptor: impl Into) -> Self { + Self { + codec, + interceptor: Some(interceptor.into()), + } } /// Handle a single unary gRPC request. @@ -53,6 +77,8 @@ where } }; + let request = t!(self.intercept_request(request)); + let response = service .call(request) .await @@ -80,6 +106,8 @@ where } }; + let request = t!(self.intercept_request(request)); + let response = service.call(request).await; self.map_response(response) @@ -97,6 +125,7 @@ where B::Error: Into + Send + 'static, { let request = self.map_request_streaming(req); + let request = t!(self.intercept_request(request)); let response = service .call(request) .await @@ -117,6 +146,7 @@ where B::Error: Into + Send, { let request = self.map_request_streaming(req); + let request = t!(self.intercept_request(request)); let response = service.call(request).await; self.map_response(response) } @@ -180,18 +210,34 @@ where http::Response::from_parts(parts, BoxBody::new(body)) } - Err(status) => { - let (mut parts, _body) = Response::new(()).into_http().into_parts(); + Err(status) => Self::map_status(status), + } + } - parts.headers.insert( - http::header::CONTENT_TYPE, - http::header::HeaderValue::from_static("application/grpc"), - ); + fn map_status(status: Status) -> http::Response { + let (mut parts, _body) = Response::new(()).into_http().into_parts(); + + parts.headers.insert( + http::header::CONTENT_TYPE, + http::header::HeaderValue::from_static("application/grpc"), + ); - status.add_header(&mut parts.headers).unwrap(); + status.add_header(&mut parts.headers).unwrap(); + + http::Response::from_parts(parts, BoxBody::empty()) + } - http::Response::from_parts(parts, BoxBody::empty()) + fn intercept_request(&self, req: Request) -> Result, http::Response> { + if let Some(interceptor) = &self.interceptor { + match interceptor.call(req) { + Ok(req) => Ok(req), + Err(status) => { + let res = Self::map_status(status); + return Err(res); + } } + } else { + Ok(req) } } } diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index 24733d423..54c206ebe 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -10,7 +10,6 @@ use http::uri::{InvalidUri, Uri}; use std::{ convert::{TryFrom, TryInto}, fmt, - sync::Arc, time::Duration, }; use tower_make::MakeConnection; @@ -27,8 +26,6 @@ pub struct Endpoint { #[cfg(feature = "tls")] pub(crate) tls: Option, pub(crate) buffer_size: Option, - pub(crate) interceptor_headers: - Option>, pub(crate) init_stream_window_size: Option, pub(crate) init_connection_window_size: Option, pub(crate) tcp_keepalive: Option, @@ -152,29 +149,6 @@ impl Endpoint { } } - /// Intercept outbound HTTP Request headers; - /// - /// # Example - /// - /// ``` - /// # use tonic::transport::Endpoint; - /// # use std::time::Duration; - /// # let mut builder = Endpoint::from_static("https://example.com"); - /// builder.intercept_headers(|headers| { - /// // Do something with headers - /// headers.insert("hello", "world".parse().unwrap()); - /// }); - /// ``` - pub fn intercept_headers(self, f: F) -> Self - where - F: Fn(&mut http::HeaderMap) + Send + Sync + 'static, - { - Endpoint { - interceptor_headers: Some(Arc::new(f)), - ..self - } - } - /// Configures TLS for the endpoint. #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] @@ -237,7 +211,6 @@ impl From for Endpoint { #[cfg(feature = "tls")] tls: None, buffer_size: None, - interceptor_headers: None, init_stream_window_size: None, init_connection_window_size: None, tcp_keepalive: None, diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index 9141856fa..acbafb2cf 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -21,7 +21,6 @@ use std::{ fmt, future::Future, pin::Pin, - sync::Arc, task::{Context, Poll}, }; use tokio::io::{AsyncRead, AsyncWrite}; @@ -63,7 +62,6 @@ const DEFAULT_BUFFER_SIZE: usize = 1024; #[derive(Clone)] pub struct Channel { svc: Buffer>, - interceptor_headers: Option>, } /// A future that resolves to an HTTP response. @@ -114,14 +112,9 @@ impl Channel { .and_then(|e| e.buffer_size) .unwrap_or(DEFAULT_BUFFER_SIZE); - let interceptor_headers = list - .iter() - .next() - .and_then(|e| e.interceptor_headers.clone()); - let discover = ServiceList::new(list); - Self::balance(discover, buffer_size, interceptor_headers) + Self::balance(discover, buffer_size) } pub(crate) async fn connect(connector: C, endpoint: Endpoint) -> Result @@ -132,7 +125,6 @@ impl Channel { C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, { let buffer_size = endpoint.buffer_size.clone().unwrap_or(DEFAULT_BUFFER_SIZE); - let interceptor_headers = endpoint.interceptor_headers.clone(); let svc = Connection::new(connector, endpoint) .await @@ -140,17 +132,10 @@ impl Channel { let svc = Buffer::new(Either::A(svc), buffer_size); - Ok(Channel { - svc, - interceptor_headers, - }) + Ok(Channel { svc }) } - pub(crate) fn balance( - discover: D, - buffer_size: usize, - interceptor_headers: Option>, - ) -> Self + pub(crate) fn balance(discover: D, buffer_size: usize) -> Self where D: Discover + Unpin + Send + 'static, D::Error: Into, @@ -161,10 +146,7 @@ impl Channel { let svc = BoxService::new(svc); let svc = Buffer::new(Either::B(svc), buffer_size); - Channel { - svc, - interceptor_headers, - } + Channel { svc } } } @@ -177,11 +159,7 @@ impl GrpcService for Channel { GrpcService::poll_ready(&mut self.svc, cx).map_err(|e| super::Error::from_source(e)) } - fn call(&mut self, mut request: Request) -> Self::Future { - if let Some(interceptor) = self.interceptor_headers.clone() { - interceptor(request.headers_mut()); - } - + fn call(&mut self, request: Request) -> Self::Future { let inner = GrpcService::call(&mut self.svc, request); ResponseFuture { inner } } diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index c5418df52..192264ff2 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -12,7 +12,6 @@ //! - Timeouts //! - Concurrency Limits //! - Rate limiting -//! - gRPC Interceptors //! //! # Examples //! @@ -77,10 +76,6 @@ //! .tls_config(ServerTlsConfig::with_rustls() //! .identity(Identity::from_pem(&cert, &key))) //! .concurrency_limit_per_connection(256) -//! .interceptor_fn(|svc, req| { -//! println!("Request: {:?}", req); -//! svc.call(req) -//! }) //! .add_service(my_svc) //! .serve(addr) //! .await?; @@ -104,7 +99,7 @@ pub use self::error::Error; #[doc(inline)] pub use self::server::{Server, ServiceName}; pub use self::tls::{Certificate, Identity}; -pub use hyper::Body; +pub use hyper::{Body, Uri}; #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index e03e03774..0b07bb0af 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -15,7 +15,7 @@ use super::service::TlsAcceptor; use incoming::TcpIncoming; -use super::service::{layer_fn, Or, Routes, ServerIo, ServiceBuilderExt}; +use super::service::{Or, Routes, ServerIo, ServiceBuilderExt}; use crate::{body::BoxBody, request::ConnectionInfo}; use futures_core::Stream; use futures_util::{ @@ -35,15 +35,11 @@ use std::{ }; use tokio::io::{AsyncRead, AsyncWrite}; use tower::{ - layer::{Layer, Stack}, - limit::concurrency::ConcurrencyLimitLayer, - timeout::TimeoutLayer, - Service, ServiceBuilder, + limit::concurrency::ConcurrencyLimitLayer, timeout::TimeoutLayer, Service, ServiceBuilder, }; use tracing_futures::{Instrument, Instrumented}; type BoxService = tower::util::BoxService, Response, crate::Error>; -type Interceptor = Arc + Send + Sync + 'static>; type TraceInterceptor = Arc tracing::Span + Send + Sync + 'static>; /// A default batteries included `transport` server. @@ -56,7 +52,6 @@ type TraceInterceptor = Arc tracing::Span + Send + Sync + /// wanting to create a more complex and/or specific implementation. #[derive(Default, Clone)] pub struct Server { - interceptor: Option, trace_interceptor: Option, concurrency_limit: Option, timeout: Option, @@ -198,35 +193,6 @@ impl Server { } } - /// Intercept the execution of gRPC methods. - /// - /// ``` - /// # use tonic::transport::Server; - /// # use tower_service::Service; - /// # let mut builder = Server::builder(); - /// builder.interceptor_fn(|svc, req| { - /// println!("request={:?}", req); - /// svc.call(req) - /// }); - /// ``` - pub fn interceptor_fn(self, f: F) -> Self - where - F: Fn(&mut BoxService, Request) -> Out + Send + Sync + 'static, - Out: Future, crate::Error>> + Send + 'static, - { - let f = Arc::new(f); - let interceptor = layer_fn(move |mut s| { - let f = f.clone(); - tower::service_fn(move |req| f(&mut s, req)) - }); - let layer = Stack::new(interceptor, layer_fn(BoxService::new)); - - Server { - interceptor: Some(Arc::new(layer)), - ..self - } - } - /// Intercept inbound headers and add a [`tracing::Span`] to each response future. pub fn trace_fn(self, f: F) -> Self where @@ -270,7 +236,6 @@ impl Server { IE: Into, F: Future, { - let interceptor = self.interceptor.clone(); let span = self.trace_interceptor.clone(); let concurrency_limit = self.concurrency_limit; let init_connection_window_size = self.init_connection_window_size; @@ -283,7 +248,6 @@ impl Server { let svc = MakeSvc { inner: svc, - interceptor, concurrency_limit, timeout, span, @@ -480,7 +444,6 @@ impl fmt::Debug for Svc { } struct MakeSvc { - interceptor: Option, concurrency_limit: Option, timeout: Option, inner: S, @@ -508,7 +471,6 @@ where peer_certs: io.peer_certs().map(Arc::new), }; - let interceptor = self.interceptor.clone(); let svc = self.inner.clone(); let concurrency_limit = self.concurrency_limit; let timeout = self.timeout.clone(); @@ -520,20 +482,11 @@ where .optional_layer(timeout.map(TimeoutLayer::new)) .service(svc); - let svc = if let Some(interceptor) = interceptor { - let layered = interceptor.layer(BoxService::new(Svc { - inner: svc, - span, - conn_info, - })); - BoxService::new(layered) - } else { - BoxService::new(Svc { - inner: svc, - span, - conn_info, - }) - }; + let svc = BoxService::new(Svc { + inner: svc, + span, + conn_info, + }); Ok(svc) }) diff --git a/tonic/src/transport/service/layer.rs b/tonic/src/transport/service/layer.rs index c5cce18ee..765f3b917 100644 --- a/tonic/src/transport/service/layer.rs +++ b/tonic/src/transport/service/layer.rs @@ -42,6 +42,8 @@ impl ServiceBuilderExt for ServiceBuilder { } } +// TODO: figure out why this is causing a warning even though its used in optional_layer_fn +#[allow(dead_code)] pub(crate) fn layer_fn(f: F) -> LayerFn { LayerFn(f) } diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index 4419ecead..3bc691871 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -14,7 +14,7 @@ pub(crate) use self::connection::Connection; pub(crate) use self::connector::connector; pub(crate) use self::discover::ServiceList; pub(crate) use self::io::ServerIo; -pub(crate) use self::layer::{layer_fn, ServiceBuilderExt}; +pub(crate) use self::layer::ServiceBuilderExt; pub(crate) use self::router::{Or, Routes}; #[cfg(feature = "tls")] pub(crate) use self::tls::{TlsAcceptor, TlsConnector};