Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(channel): Make channel feature additive #1574

Merged
merged 14 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,13 @@ tracing = ["dep:tracing", "dep:tracing-subscriber"]
uds = ["tokio-stream/net", "dep:tower", "dep:hyper", "dep:hyper-util"]
streaming = ["tokio-stream", "dep:h2"]
mock = ["tokio-stream", "dep:tower", "dep:hyper-util"]
tower = ["dep:hyper", "dep:hyper-util", "dep:tower", "dep:http"]
tower = ["dep:hyper", "dep:hyper-util", "dep:tower", "tower?/timeout", "dep:http"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the previous commit fail to compile the examples? If so, IMO we should squash this commit into it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the previous commit fail to compile the examples?

Yes. Squashed these two commits.

json-codec = ["dep:serde", "dep:serde_json", "dep:bytes"]
compression = ["tonic/gzip"]
tls = ["tonic/tls"]
tls-rustls = ["dep:hyper", "dep:hyper-util", "dep:hyper-rustls", "dep:tower", "tower-http/util", "tower-http/add-extension", "dep:rustls-pemfile", "dep:tokio-rustls", "dep:pin-project", "dep:http-body-util"]
dynamic-load-balance = ["dep:tower"]
timeout = ["tokio/time", "dep:tower"]
timeout = ["tokio/time", "dep:tower", "tower?/timeout"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

tls-client-auth = ["tonic/tls"]
types = ["dep:tonic-types"]
h2c = ["dep:hyper", "dep:tower", "dep:http", "dep:hyper-util"]
Expand Down
27 changes: 17 additions & 10 deletions tonic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,29 @@ version = "0.11.0"
codegen = ["dep:async-trait"]
gzip = ["dep:flate2"]
zstd = ["dep:zstd"]
default = ["transport", "codegen", "prost"]
default = ["channel", "codegen", "prost"]
prost = ["dep:prost"]
tls = ["dep:rustls-pemfile", "transport", "dep:tokio-rustls", "dep:tokio", "tokio?/rt", "tokio?/macros"]
tls-roots = ["tls-roots-common", "dep:rustls-native-certs"]
tls-roots-common = ["tls"]
tls-roots-common = ["tls", "channel"]
tls-webpki-roots = ["tls-roots-common", "dep:webpki-roots"]
router = ["dep:axum"]
transport = [
"router",
"dep:async-stream",
"channel",
"dep:h2",
"dep:hyper", "dep:hyper-util", "dep:hyper-timeout",
"dep:hyper", "dep:hyper-util",
"dep:socket2",
"dep:tokio", "tokio?/macros", "tokio?/net", "tokio?/time",
"dep:tower",
"dep:tower", "tower?/util", "tower?/limit",
]
channel = [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update the crate root documentation to clarify a bit what the distinction is between channel and transport? I don't think the current docs are very clear:

transport: Enables the fully featured, batteries included client and server implementation based on hyper, tower and tokio. Enabled by default.
channel: Enables just the full featured channel/client portion of the transport feature.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to reflect the state of these feature flags.

"transport",
"dep:hyper", "hyper?/client",
"dep:hyper-util", "hyper-util?/client-legacy",
"dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/load", "tower?/make",
"dep:hyper-timeout",
]
channel = []

# [[bench]]
# name = "bench_main"
Expand Down Expand Up @@ -71,13 +76,12 @@ async-trait = {version = "0.1.13", optional = true}
# transport
async-stream = {version = "0.3", optional = true}
h2 = {version = "0.4", optional = true}
hyper = {version = "1", features = ["full"], optional = true}
hyper-util = { version = ">=0.1.4, <0.2", features = ["full"], optional = true }
hyper-timeout = {version = "0.5", optional = true}
hyper = {version = "1", features = ["http1", "http2", "server"], optional = true}
hyper-util = { version = ">=0.1.4, <0.2", features = ["service", "server-auto", "tokio"], optional = true }
socket2 = { version = ">=0.4.7, <0.6.0", optional = true, features = ["all"] }
tokio = {version = "1", default-features = false, optional = true}
tokio-stream = { version = "0.1", features = ["net"] }
tower = {version = "0.4.7", default-features = false, features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true}
tower = {version = "0.4.7", default-features = false, optional = true}
axum = {version = "0.7", default-features = false, optional = true}

# rustls
Expand All @@ -90,6 +94,9 @@ webpki-roots = { version = "0.26", optional = true }
flate2 = {version = "1.0", optional = true}
zstd = { version = "0.13.0", optional = true }

# channel
hyper-timeout = {version = "0.5", optional = true}

[dev-dependencies]
bencher = "0.1.5"
quickcheck = "1.0"
Expand Down
7 changes: 3 additions & 4 deletions tonic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@
//!
//! # Feature Flags
//!
//! - `transport`: Enables the fully featured, batteries included client and server
//! implementation based on [`hyper`], [`tower`] and [`tokio`]. Enabled by default.
//! - `channel`: Enables just the full featured channel/client portion of the `transport`
//! feature.
//! - `transport`: Enables just the full featured server portion of the `channel` feature.
//! - `channel`: Enables the fully featured, batteries included client and server
//! implementation based on [`hyper`], [`tower`] and [`tokio`]. Enabled by default.
//! - `codegen`: Enables all the required exports and optional dependencies required
//! for [`tonic-build`]. Enabled by default.
//! - `tls`: Enables the `rustls` based TLS options for the `transport` feature. Not
Expand Down
6 changes: 4 additions & 2 deletions tonic/src/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -618,8 +618,10 @@ fn find_status_in_source_chain(err: &(dyn Error + 'static)) -> Option<Status> {
// matches the spec of:
// > The service is currently unavailable. This is most likely a transient condition that
// > can be corrected if retried with a backoff.
#[cfg(feature = "transport")]
if let Some(connect) = err.downcast_ref::<crate::transport::ConnectError>() {
#[cfg(feature = "channel")]
if let Some(connect) =
err.downcast_ref::<crate::transport::channel::service::ConnectError>()
{
return Some(Status::unavailable(connect.to_string()));
}

Expand Down
8 changes: 4 additions & 4 deletions tonic/src/transport/channel/endpoint.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use super::super::service;
#[cfg(feature = "tls")]
use super::service::TlsConnector;
use super::service::{self, Executor, SharedExec};
use super::Channel;
#[cfg(feature = "tls")]
use super::ClientTlsConfig;
#[cfg(feature = "tls")]
use crate::transport::service::TlsConnector;
use crate::transport::{service::SharedExec, Error, Executor};
use crate::transport::Error;
use bytes::Bytes;
use http::{uri::Uri, HeaderValue};
use hyper::rt;
Expand Down
4 changes: 2 additions & 2 deletions tonic/src/transport/channel/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Client implementation and builder.

mod endpoint;
pub(crate) mod service;
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
mod tls;
Expand All @@ -9,9 +10,8 @@ pub use endpoint::Endpoint;
#[cfg(feature = "tls")]
pub use tls::ClientTlsConfig;

use super::service::{Connection, DynamicServiceStream, SharedExec};
use self::service::{Connection, DynamicServiceStream, Executor, SharedExec};
use crate::body::BoxBody;
use crate::transport::Executor;
use bytes::Bytes;
use http::{
uri::{InvalidUri, Uri},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use super::SharedExec;
use super::{grpc_timeout::GrpcTimeout, reconnect::Reconnect, AddOrigin, UserAgent};
use super::{AddOrigin, Reconnect, SharedExec, UserAgent};
use crate::{
body::{boxed, BoxBody},
transport::{BoxFuture, Endpoint},
transport::{service::GrpcTimeout, BoxFuture, Endpoint},
};
use http::Uri;
use hyper::rt;
Expand Down Expand Up @@ -36,7 +35,7 @@ impl Connection {
C::Future: Unpin + Send,
C::Response: rt::Read + rt::Write + Unpin + Send + 'static,
{
let mut settings: Builder<super::SharedExec> = Builder::new(endpoint.executor.clone())
let mut settings: Builder<SharedExec> = Builder::new(endpoint.executor.clone())
.initial_stream_window_size(endpoint.init_stream_window_size)
.initial_connection_window_size(endpoint.init_connection_window_size)
.keep_alive_interval(endpoint.http2_keep_alive_interval)
Expand Down Expand Up @@ -158,12 +157,12 @@ impl tower::Service<http::Request<BoxBody>> for SendRequest {

struct MakeSendRequestService<C> {
connector: C,
executor: super::SharedExec,
settings: Builder<super::SharedExec>,
executor: SharedExec,
settings: Builder<SharedExec>,
}

impl<C> MakeSendRequestService<C> {
fn new(connector: C, executor: SharedExec, settings: Builder<super::SharedExec>) -> Self {
fn new(connector: C, executor: SharedExec, settings: Builder<SharedExec>) -> Self {
Self {
connector,
executor,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::super::BoxFuture;
use super::io::BoxedIo;
use super::BoxedIo;
#[cfg(feature = "tls")]
use super::tls::TlsConnector;
use super::TlsConnector;
use crate::transport::BoxFuture;
use http::Uri;
use std::fmt;
use std::task::{Context, Poll};
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use super::connection::Connection;
use crate::transport::Endpoint;
use super::super::{Connection, Endpoint};

use hyper_util::client::legacy::connect::HttpConnector;
use std::{
Expand Down
67 changes: 67 additions & 0 deletions tonic/src/transport/channel/service/io.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
use std::io::{self, IoSlice};
use std::pin::Pin;
use std::task::{Context, Poll};

use hyper::rt;
use hyper_util::client::legacy::connect::{Connected as HyperConnected, Connection};

pub(in crate::transport) trait Io:
rt::Read + rt::Write + Send + 'static
{
}

impl<T> Io for T where T: rt::Read + rt::Write + Send + 'static {}

pub(crate) struct BoxedIo(Pin<Box<dyn Io>>);

impl BoxedIo {
pub(in crate::transport) fn new<I: Io>(io: I) -> Self {
BoxedIo(Box::pin(io))
}
}

impl Connection for BoxedIo {
fn connected(&self) -> HyperConnected {
HyperConnected::new()
}
}

impl rt::Read for BoxedIo {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: rt::ReadBufCursor<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}

impl rt::Write for BoxedIo {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
}

fn is_write_vectored(&self) -> bool {
self.0.is_write_vectored()
}
}
28 changes: 28 additions & 0 deletions tonic/src/transport/channel/service/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
mod add_origin;
use self::add_origin::AddOrigin;

mod user_agent;
use self::user_agent::UserAgent;

mod reconnect;
use self::reconnect::Reconnect;

mod connection;
pub(super) use self::connection::Connection;

mod discover;
pub(super) use self::discover::DynamicServiceStream;

mod io;
use self::io::BoxedIo;

mod connector;
pub(crate) use self::connector::{ConnectError, Connector};

mod executor;
pub(super) use self::executor::{Executor, SharedExec};

#[cfg(feature = "tls")]
mod tls;
#[cfg(feature = "tls")]
pub(super) use self::tls::TlsConnector;
83 changes: 83 additions & 0 deletions tonic/src/transport/channel/service/tls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use std::fmt;
use std::io::Cursor;
use std::sync::Arc;

use hyper_util::rt::TokioIo;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::{
rustls::{pki_types::ServerName, ClientConfig, RootCertStore},
TlsConnector as RustlsConnector,
};

use super::io::BoxedIo;
use crate::transport::service::tls::{add_certs_from_pem, load_identity, TlsError, ALPN_H2};
use crate::transport::tls::{Certificate, Identity};

#[derive(Clone)]
pub(crate) struct TlsConnector {
config: Arc<ClientConfig>,
domain: Arc<ServerName<'static>>,
assume_http2: bool,
}

impl TlsConnector {
pub(crate) fn new(
ca_certs: Vec<Certificate>,
identity: Option<Identity>,
domain: &str,
assume_http2: bool,
) -> Result<Self, crate::Error> {
let builder = ClientConfig::builder();
let mut roots = RootCertStore::empty();

#[cfg(feature = "tls-roots")]
roots.add_parsable_certificates(rustls_native_certs::load_native_certs()?);

#[cfg(feature = "tls-webpki-roots")]
roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

for cert in ca_certs {
add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?;
}

let builder = builder.with_root_certificates(roots);
let mut config = match identity {
Some(identity) => {
let (client_cert, client_key) = load_identity(identity)?;
builder.with_client_auth_cert(client_cert, client_key)?
}
None => builder.with_no_client_auth(),
};

config.alpn_protocols.push(ALPN_H2.into());
Ok(Self {
config: Arc::new(config),
domain: Arc::new(ServerName::try_from(domain)?.to_owned()),
assume_http2,
})
}

pub(crate) async fn connect<I>(&self, io: I) -> Result<BoxedIo, crate::Error>
where
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
let io = RustlsConnector::from(self.config.clone())
.connect(self.domain.as_ref().to_owned(), io)
.await?;

// Generally we require ALPN to be negotiated, but if the user has
// explicitly set `assume_http2` to true, we'll allow it to be missing.
let (_, session) = io.get_ref();
let alpn_protocol = session.alpn_protocol();
if !(alpn_protocol == Some(ALPN_H2) || self.assume_http2) {
return Err(TlsError::H2NotNegotiated.into());
}
Ok(BoxedIo::new(TokioIo::new(io)))
}
}

impl fmt::Debug for TlsConnector {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TlsConnector").finish()
}
}
2 changes: 1 addition & 1 deletion tonic/src/transport/channel/tls.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::service::TlsConnector;
use crate::transport::{
service::TlsConnector,
tls::{Certificate, Identity},
Error,
};
Expand Down
Loading