Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add gRPC interceptors #232

Merged
merged 4 commits into from
Jan 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
members = [
"tonic",
"tonic-build",

# Non-published crates
"examples",
"interop",

# Tests
"tests/included_service",
"tests/same_name",
"tests/wellknown",
]
]
19 changes: 11 additions & 8 deletions examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,27 +86,30 @@ path = "src/uds/client.rs"
name = "uds-server"
path = "src/uds/server.rs"

[[bin]]
name = "interceptor-client"
path = "src/interceptor/client.rs"

[[bin]]
name = "interceptor-server"
path = "src/interceptor/server.rs"

[dependencies]
tonic = { path = "../tonic", features = ["tls"] }
prost = "0.6"

tokio = { version = "0.2", features = ["rt-threaded", "time", "stream", "fs", "macros", "uds"] }
futures = { version = "0.3", default-features = false, features = ["alloc"]}
futures = { version = "0.3", default-features = false, features = ["alloc"] }
async-stream = "0.2"
http = "0.2"
tower = "0.3"

tower = "0.3"
# Required for routeguide
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
rand = "0.7"

# Tracing
tracing = "0.1"
tracing-subscriber = { version = "0.2.0-alpha", features = ["tracing-log"] }
tracing-subscriber = { version = "0.2.0-alpha", features = ["tracing-log"] }
tracing-attributes = "0.1"
tracing-futures = "0.2"

# Required for wellknown types
prost-types = "0.6"

Expand Down
22 changes: 9 additions & 13 deletions examples/src/authentication/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,19 @@ pub mod pb {
tonic::include_proto!("grpc.examples.echo");
}

use http::header::HeaderValue;
use pb::{echo_client::EchoClient, EchoRequest};
use tonic::transport::Channel;
use tonic::{metadata::MetadataValue, transport::Channel, Request};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let channel = Channel::from_static("http://[::1]:50051")
.intercept_headers(|headers| {
headers.insert(
"authorization",
HeaderValue::from_static("Bearer some-secret-token"),
);
})
.connect()
.await?;

let mut client = EchoClient::new(channel);
let channel = Channel::from_static("http://[::1]:50051").connect().await?;

let token = MetadataValue::from_str("Bearer some-auth-token")?;

let mut client = EchoClient::with_interceptor(channel, move |mut req: Request<()>| {
req.metadata_mut().insert("authorization", token.clone());
Ok(req)
});

let request = tonic::Request::new(EchoRequest {
message: "hello".into(),
Expand Down
41 changes: 11 additions & 30 deletions examples/src/authentication/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ pub mod pb {
use futures::Stream;
use pb::{EchoRequest, EchoResponse};
use std::pin::Pin;
use tonic::{body::BoxBody, transport::Server, Request, Response, Status, Streaming};
use tower::Service;
use tonic::{metadata::MetadataValue, transport::Server, Request, Response, Status, Streaming};

type EchoResult<T> = Result<Response<T>, Status>;
type ResponseStream = Pin<Box<dyn Stream<Item = Result<EchoResponse, Status>> + Send + Sync>>;
Expand Down Expand Up @@ -52,36 +51,18 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let addr = "[::1]:50051".parse().unwrap();
let server = EchoServer::default();

Server::builder()
.interceptor_fn(move |svc, req| {
let auth_header = req.headers().get("authorization").clone();
let svc = pb::echo_server::EchoServer::with_interceptor(server, check_auth);

let authed = if let Some(auth_header) = auth_header {
auth_header == "Bearer some-secret-token"
} else {
false
};
Server::builder().add_service(svc).serve(addr).await?;

let fut = svc.call(req);
Ok(())
}

async move {
if authed {
fut.await
} else {
// Cancel the inner future since we never await it
// the IO never gets registered.
drop(fut);
let res = http::Response::builder()
.header("grpc-status", "16")
.body(BoxBody::empty())
.unwrap();
Ok(res)
}
}
})
.add_service(pb::echo_server::EchoServer::new(server))
.serve(addr)
.await?;
fn check_auth(req: Request<()>) -> Result<Request<()>, Status> {
let token = MetadataValue::from_str("Bearer some-secret-token").unwrap();

Ok(())
match req.metadata().get("authorization") {
Some(t) if token == t => Ok(req),
_ => Err(Status::unauthenticated("No valid auth token")),
}
}
13 changes: 7 additions & 6 deletions examples/src/gcp/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ pub mod api {
}

use api::{publisher_client::PublisherClient, ListTopicsRequest};
use http::header::HeaderValue;
use tonic::{
metadata::MetadataValue,
transport::{Certificate, Channel, ClientTlsConfig},
Request,
};
Expand All @@ -23,7 +23,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.ok_or("Expected a project name as the first argument.".to_string())?;

let bearer_token = format!("Bearer {}", token);
let header_value = HeaderValue::from_str(&bearer_token)?;
let header_value = MetadataValue::from_str(&bearer_token)?;

let certs = tokio::fs::read("examples/data/gcp/roots.pem").await?;

Expand All @@ -32,14 +32,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.domain_name("pubsub.googleapis.com");

let channel = Channel::from_static(ENDPOINT)
.intercept_headers(move |headers| {
headers.insert("authorization", header_value.clone());
})
.tls_config(tls_config)
.connect()
.await?;

let mut service = PublisherClient::new(channel);
let mut service = PublisherClient::with_interceptor(channel, move |mut req: Request<()>| {
req.metadata_mut()
.insert("authorization", header_value.clone());
Ok(req)
});

let response = service
.list_topics(Request::new(ListTopicsRequest {
Expand Down
34 changes: 34 additions & 0 deletions examples/src/interceptor/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use hello_world::greeter_client::GreeterClient;
use hello_world::HelloRequest;
use tonic::{transport::Endpoint, Request, Status};

pub mod hello_world {
tonic::include_proto!("helloworld");
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let channel = Endpoint::from_static("http://[::1]:50051")
.connect()
.await?;

let mut client = GreeterClient::with_interceptor(channel, intercept);

let request = tonic::Request::new(HelloRequest {
name: "Tonic".into(),
});

let response = client.say_hello(request).await?;

println!("RESPONSE={:?}", response);

Ok(())
}

/// This function will get called on each outbound request. Returning a
/// `Status` here will cancel the request and have that status returned to
/// the caller.
fn intercept(req: Request<()>) -> Result<Request<()>, Status> {
println!("Intercepting request: {:?}", req);
Ok(req)
}
46 changes: 46 additions & 0 deletions examples/src/interceptor/server.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use tonic::{transport::Server, Request, Response, Status};

use hello_world::greeter_server::{Greeter, GreeterServer};
use hello_world::{HelloReply, HelloRequest};

pub mod hello_world {
tonic::include_proto!("helloworld");
}

#[derive(Default)]
pub struct MyGreeter {}

#[tonic::async_trait]
impl Greeter for MyGreeter {
async fn say_hello(
&self,
request: Request<HelloRequest>,
) -> Result<Response<HelloReply>, Status> {
let reply = hello_world::HelloReply {
message: format!("Hello {}!", request.into_inner().name),
};
Ok(Response::new(reply))
}
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let addr = "[::1]:50051".parse().unwrap();
let greeter = MyGreeter::default();

let svc = GreeterServer::with_interceptor(greeter, intercept);

println!("GreeterServer listening on {}", addr);

Server::builder().add_service(svc).serve(addr).await?;

Ok(())
}

/// This function will get called on each inbound request, if a `Status`
/// is returned, it will cancel the request and return that status to the
/// client.
fn intercept(req: Request<()>) -> Result<Request<()>, Status> {
println!("Intercepting request: {:?}", req);
Ok(req)
}
3 changes: 1 addition & 2 deletions examples/src/uds/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@ pub mod hello_world {
}

use hello_world::{greeter_client::GreeterClient, HelloRequest};
use http::Uri;
use std::convert::TryFrom;
#[cfg(unix)]
use tokio::net::UnixStream;
use tonic::transport::Endpoint;
use tonic::transport::{Endpoint, Uri};
use tower::service_fn;

#[cfg(unix)]
Expand Down
3 changes: 1 addition & 2 deletions interop/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@ futures-util = "0.3"
async-stream = "0.2"
tower = "0.3"
http-body = "0.3"

hyper = "0.13"
console = "0.9"
structopt = "0.3"

tracing = "0.1"
tracing-subscriber = "0.2.0-alpha"
tracing-log = "0.1.0"
Expand Down
38 changes: 6 additions & 32 deletions interop/src/bin/server.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
use http::header::HeaderName;
use structopt::StructOpt;
use tonic::body::BoxBody;
use tonic::client::GrpcService;
use tonic::transport::Server;
use tonic::transport::{Identity, ServerTlsConfig};
use tonic_interop::{server, MergeTrailers};
use tonic_interop::server;

#[derive(StructOpt)]
struct Opts {
Expand All @@ -20,33 +17,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {

let addr = "127.0.0.1:10000".parse().unwrap();

let mut builder = Server::builder().interceptor_fn(|svc, req| {
let echo_header = req
.headers()
.get("x-grpc-test-echo-initial")
.map(Clone::clone);

let echo_trailer = req
.headers()
.get("x-grpc-test-echo-trailing-bin")
.map(Clone::clone)
.map(|v| (HeaderName::from_static("x-grpc-test-echo-trailing-bin"), v));

let call = svc.call(req);

async move {
let mut res = call.await?;

if let Some(echo_header) = echo_header {
res.headers_mut()
.insert("x-grpc-test-echo-initial", echo_header);
}

Ok(res
.map(|b| MergeTrailers::new(b, echo_trailer))
.map(BoxBody::new))
}
});
let mut builder = Server::builder();

if matches.use_tls {
let cert = tokio::fs::read("interop/data/server1.pem").await?;
Expand All @@ -60,8 +31,11 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let unimplemented_service =
server::UnimplementedServiceServer::new(server::UnimplementedService::default());

// Wrap this test_service with a service that will echo headers as trailers.
let test_service_svc = server::EchoHeadersSvc::new(test_service);

builder
.add_service(test_service)
.add_service(test_service_svc)
.add_service(unimplemented_service)
.serve(addr)
.await?;
Expand Down
Loading