-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(tonic): add Request
and Response
extensions
#642
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
use futures_util::FutureExt; | ||
use hyper::{Body, Request as HyperRequest, Response as HyperResponse}; | ||
use integration_tests::pb::{test_client, test_server, Input, Output}; | ||
use std::{ | ||
task::{Context, Poll}, | ||
time::Duration, | ||
}; | ||
use tokio::sync::oneshot; | ||
use tonic::{ | ||
body::BoxBody, | ||
transport::{Endpoint, NamedService, Server}, | ||
Request, Response, Status, | ||
}; | ||
use tower_service::Service; | ||
|
||
struct ExtensionValue(i32); | ||
|
||
#[tokio::test] | ||
async fn setting_extension_from_interceptor() { | ||
struct Svc; | ||
|
||
#[tonic::async_trait] | ||
impl test_server::Test for Svc { | ||
async fn unary_call(&self, req: Request<Input>) -> Result<Response<Output>, Status> { | ||
let value = req.extensions().get::<ExtensionValue>().unwrap(); | ||
assert_eq!(value.0, 42); | ||
|
||
Ok(Response::new(Output {})) | ||
} | ||
} | ||
|
||
let svc = test_server::TestServer::with_interceptor(Svc, |mut req: Request<()>| { | ||
req.extensions_mut().insert(ExtensionValue(42)); | ||
Ok(req) | ||
}); | ||
|
||
let (tx, rx) = oneshot::channel::<()>(); | ||
|
||
let jh = tokio::spawn(async move { | ||
Server::builder() | ||
.add_service(svc) | ||
.serve_with_shutdown("127.0.0.1:1323".parse().unwrap(), rx.map(drop)) | ||
.await | ||
.unwrap(); | ||
}); | ||
|
||
tokio::time::sleep(Duration::from_millis(100)).await; | ||
|
||
let channel = Endpoint::from_static("http://127.0.0.1:1323") | ||
.connect() | ||
.await | ||
.unwrap(); | ||
|
||
let mut client = test_client::TestClient::new(channel); | ||
|
||
client.unary_call(Input {}).await.unwrap(); | ||
|
||
tx.send(()).unwrap(); | ||
|
||
jh.await.unwrap(); | ||
} | ||
|
||
#[tokio::test] | ||
async fn setting_extension_from_tower() { | ||
struct Svc; | ||
|
||
#[tonic::async_trait] | ||
impl test_server::Test for Svc { | ||
async fn unary_call(&self, req: Request<Input>) -> Result<Response<Output>, Status> { | ||
let value = req.extensions().get::<ExtensionValue>().unwrap(); | ||
assert_eq!(value.0, 42); | ||
|
||
Ok(Response::new(Output {})) | ||
} | ||
} | ||
|
||
let svc = InterceptedService { | ||
inner: test_server::TestServer::new(Svc), | ||
}; | ||
|
||
let (tx, rx) = oneshot::channel::<()>(); | ||
|
||
let jh = tokio::spawn(async move { | ||
Server::builder() | ||
.add_service(svc) | ||
.serve_with_shutdown("127.0.0.1:1324".parse().unwrap(), rx.map(drop)) | ||
.await | ||
.unwrap(); | ||
}); | ||
|
||
tokio::time::sleep(Duration::from_millis(100)).await; | ||
|
||
let channel = Endpoint::from_static("http://127.0.0.1:1324") | ||
.connect() | ||
.await | ||
.unwrap(); | ||
|
||
let mut client = test_client::TestClient::new(channel); | ||
|
||
client.unary_call(Input {}).await.unwrap(); | ||
|
||
tx.send(()).unwrap(); | ||
|
||
jh.await.unwrap(); | ||
} | ||
|
||
#[derive(Debug, Clone)] | ||
struct InterceptedService<S> { | ||
inner: S, | ||
} | ||
|
||
impl<S> Service<HyperRequest<Body>> for InterceptedService<S> | ||
where | ||
S: Service<HyperRequest<Body>, Response = HyperResponse<BoxBody>> | ||
+ NamedService | ||
+ Clone | ||
+ Send | ||
+ 'static, | ||
S::Future: Send + 'static, | ||
{ | ||
type Response = S::Response; | ||
type Error = S::Error; | ||
type Future = futures::future::BoxFuture<'static, Result<Self::Response, Self::Error>>; | ||
|
||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { | ||
self.inner.poll_ready(cx) | ||
} | ||
|
||
fn call(&mut self, mut req: HyperRequest<Body>) -> Self::Future { | ||
let clone = self.inner.clone(); | ||
let mut inner = std::mem::replace(&mut self.inner, clone); | ||
|
||
req.extensions_mut().insert(ExtensionValue(42)); | ||
|
||
Box::pin(async move { | ||
let response = inner.call(req).await?; | ||
Ok(response) | ||
}) | ||
} | ||
} | ||
|
||
impl<S: NamedService> NamedService for InterceptedService<S> { | ||
const NAME: &'static str = S::NAME; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
use std::fmt; | ||
|
||
/// A type map of protocol extensions. | ||
/// | ||
/// `Extensions` can be used by [`Interceptor`] and [`Request`] to store extra data derived from | ||
/// the underlying protocol. | ||
/// | ||
/// [`Interceptor`]: crate::Interceptor | ||
/// [`Request`]: crate::Request | ||
pub struct Extensions { | ||
inner: http::Extensions, | ||
} | ||
|
||
impl Extensions { | ||
pub(crate) fn new() -> Self { | ||
Self { | ||
inner: http::Extensions::new(), | ||
} | ||
} | ||
|
||
/// Insert a type into this `Extensions`. | ||
/// | ||
/// If a extension of this type already existed, it will | ||
/// be returned. | ||
#[inline] | ||
pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) -> Option<T> { | ||
self.inner.insert(val) | ||
} | ||
|
||
/// Get a reference to a type previously inserted on this `Extensions`. | ||
#[inline] | ||
pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> { | ||
self.inner.get() | ||
} | ||
|
||
/// Get a mutable reference to a type previously inserted on this `Extensions`. | ||
#[inline] | ||
pub fn get_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> { | ||
self.inner.get_mut() | ||
} | ||
|
||
/// Remove a type from this `Extensions`. | ||
/// | ||
/// If a extension of this type existed, it will be returned. | ||
#[inline] | ||
pub fn remove<T: Send + Sync + 'static>(&mut self) -> Option<T> { | ||
self.inner.remove() | ||
} | ||
|
||
/// Clear the `Extensions` of all inserted extensions. | ||
#[inline] | ||
pub fn clear(&mut self) { | ||
self.inner.clear() | ||
} | ||
|
||
#[inline] | ||
pub(crate) fn from_http(http: http::Extensions) -> Self { | ||
Self { inner: http } | ||
} | ||
|
||
#[inline] | ||
pub(crate) fn into_http(self) -> http::Extensions { | ||
self.inner | ||
} | ||
} | ||
|
||
impl fmt::Debug for Extensions { | ||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
f.debug_struct("Extensions").finish() | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -62,6 +62,7 @@ | |
//! [`transport`]: transport/index.html | ||
|
||
#![recursion_limit = "256"] | ||
#![allow(clippy::inconsistent_struct_constructor)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the heck is this lint lmao There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The most pedantic lint ever https://rust-lang.github.io/rust-clippy/master/#inconsistent_struct_constructor. rust-analyzer was complaining about it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. lmaooooo |
||
#![warn( | ||
missing_debug_implementations, | ||
missing_docs, | ||
|
@@ -87,6 +88,7 @@ pub mod server; | |
#[cfg_attr(docsrs, doc(cfg(feature = "transport")))] | ||
pub mod transport; | ||
|
||
mod extensions; | ||
mod interceptor; | ||
mod macros; | ||
mod request; | ||
|
@@ -100,6 +102,7 @@ pub use async_trait::async_trait; | |
|
||
#[doc(inline)] | ||
pub use codec::Streaming; | ||
pub use extensions::Extensions; | ||
pub use interceptor::Interceptor; | ||
pub use request::{IntoRequest, IntoStreamingRequest, Request}; | ||
pub use response::Response; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Its crazy how complex this is :(
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I really wanna try and find a way to make this easier. Having to implement
NamedService
also complicates things a bit because you cannot take a middleware from tower and use it with tonics router.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, though I don't have a proper solution either :(