Skip to content

Commit

Permalink
feat: Use http::Extensions directly
Browse files Browse the repository at this point in the history
  • Loading branch information
tottoto committed Jun 12, 2024
1 parent 9c1f2f9 commit 9bf800a
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 82 deletions.
75 changes: 0 additions & 75 deletions tonic/src/extensions.rs
Original file line number Diff line number Diff line change
@@ -1,78 +1,3 @@
use std::fmt;

/// A type map of protocol extensions.
///
/// `Extensions` can be used by [`Interceptor`] and [`Request`] to store extra data derived from
/// the underlying protocol.
///
/// [`Interceptor`]: crate::service::Interceptor
/// [`Request`]: crate::Request
#[derive(Default)]
pub struct Extensions {
inner: http::Extensions,
}

impl Extensions {
pub(crate) fn new() -> Self {
Self {
inner: http::Extensions::new(),
}
}

/// Insert a type into this `Extensions`.
///
/// If a extension of this type already existed, it will
/// be returned.
#[inline]
pub fn insert<T: Clone + Send + Sync + 'static>(&mut self, val: T) -> Option<T> {
self.inner.insert(val)
}

/// Get a reference to a type previously inserted on this `Extensions`.
#[inline]
pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
self.inner.get()
}

/// Get a mutable reference to a type previously inserted on this `Extensions`.
#[inline]
pub fn get_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
self.inner.get_mut()
}

/// Remove a type from this `Extensions`.
///
/// If a extension of this type existed, it will be returned.
#[inline]
pub fn remove<T: Send + Sync + 'static>(&mut self) -> Option<T> {
self.inner.remove()
}

/// Clear the `Extensions` of all inserted extensions.
#[inline]
pub fn clear(&mut self) {
self.inner.clear()
}

#[inline]
/// Convert from `http::Extensions`
pub fn from_http(http: http::Extensions) -> Self {
Self { inner: http }
}

/// Convert to `http::Extensions` and consume self.
#[inline]
pub fn into_http(self) -> http::Extensions {
self.inner
}
}

impl fmt::Debug for Extensions {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Extensions").finish()
}
}

/// A gRPC Method info extension.
#[derive(Debug, Clone)]
pub struct GrpcMethod {
Expand Down
3 changes: 2 additions & 1 deletion tonic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ pub use async_trait::async_trait;

#[doc(inline)]
pub use codec::Streaming;
pub use extensions::{Extensions, GrpcMethod};
pub use extensions::GrpcMethod;
pub use http::Extensions;
pub use request::{IntoRequest, IntoStreamingRequest, Request};
pub use response::Response;
pub use status::{Code, Status};
Expand Down
6 changes: 3 additions & 3 deletions tonic/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::metadata::{MetadataMap, MetadataValue};
use crate::transport::server::TcpConnectInfo;
#[cfg(feature = "tls")]
use crate::transport::server::TlsConnectInfo;
use crate::Extensions;
use http::Extensions;
#[cfg(feature = "transport")]
use std::net::SocketAddr;
#[cfg(feature = "tls")]
Expand Down Expand Up @@ -161,7 +161,7 @@ impl<T> Request<T> {
Request {
metadata: MetadataMap::from_headers(parts.headers),
message,
extensions: Extensions::from_http(parts.extensions),
extensions: parts.extensions,
}
}

Expand All @@ -187,7 +187,7 @@ impl<T> Request<T> {
SanitizeHeaders::Yes => self.metadata.into_sanitized_headers(),
SanitizeHeaders::No => self.metadata.into_headers(),
};
*request.extensions_mut() = self.extensions.into_http();
*request.extensions_mut() = self.extensions;

request
}
Expand Down
8 changes: 5 additions & 3 deletions tonic/src/response.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::{metadata::MetadataMap, Extensions};
use http::Extensions;

use crate::metadata::MetadataMap;

/// A gRPC response and metadata from an RPC call.
#[derive(Debug)]
Expand Down Expand Up @@ -73,7 +75,7 @@ impl<T> Response<T> {
Response {
metadata: MetadataMap::from_headers(head.headers),
message,
extensions: Extensions::from_http(head.extensions),
extensions: head.extensions,
}
}

Expand All @@ -82,7 +84,7 @@ impl<T> Response<T> {

*res.version_mut() = http::Version::HTTP_2;
*res.headers_mut() = self.metadata.into_sanitized_headers();
*res.extensions_mut() = self.extensions.into_http();
*res.extensions_mut() = self.extensions;

res
}
Expand Down

0 comments on commit 9bf800a

Please sign in to comment.