diff --git a/tests/integration_tests/tests/user_agent.rs b/tests/integration_tests/tests/user_agent.rs
new file mode 100644
index 000000000..12b7b3f66
--- /dev/null
+++ b/tests/integration_tests/tests/user_agent.rs
@@ -0,0 +1,55 @@
+use futures_util::FutureExt;
+use integration_tests::pb::{test_client, test_server, Input, Output};
+use std::time::Duration;
+use tokio::sync::oneshot;
+use tonic::{
+ transport::{Endpoint, Server},
+ Request, Response, Status,
+};
+
+#[tokio::test]
+async fn writes_user_agent_header() {
+ struct Svc;
+
+ #[tonic::async_trait]
+ impl test_server::Test for Svc {
+ async fn unary_call(&self, req: Request) -> Result, Status> {
+ match req.metadata().get("user-agent") {
+ Some(_) => Ok(Response::new(Output {})),
+ None => Err(Status::internal("user-agent header is missing")),
+ }
+ }
+ }
+
+ let svc = test_server::TestServer::new(Svc);
+
+ let (tx, rx) = oneshot::channel::<()>();
+
+ let jh = tokio::spawn(async move {
+ Server::builder()
+ .add_service(svc)
+ .serve_with_shutdown("127.0.0.1:1322".parse().unwrap(), rx.map(drop))
+ .await
+ .unwrap();
+ });
+
+ tokio::time::delay_for(Duration::from_millis(100)).await;
+
+ let channel = Endpoint::from_static("http://127.0.0.1:1322")
+ .user_agent("my-client")
+ .expect("valid user agent")
+ .connect()
+ .await
+ .unwrap();
+
+ let mut client = test_client::TestClient::new(channel);
+
+ match client.unary_call(Input {}).await {
+ Ok(_) => {}
+ Err(status) => panic!("{}", status.message()),
+ }
+
+ tx.send(()).unwrap();
+
+ jh.await.unwrap();
+}
diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs
index 0e379996e..c558ff531 100644
--- a/tonic/src/transport/channel/endpoint.rs
+++ b/tonic/src/transport/channel/endpoint.rs
@@ -6,7 +6,10 @@ use super::ClientTlsConfig;
use crate::transport::service::TlsConnector;
use crate::transport::Error;
use bytes::Bytes;
-use http::uri::{InvalidUri, Uri};
+use http::{
+ uri::{InvalidUri, Uri},
+ HeaderValue,
+};
use std::{
convert::{TryFrom, TryInto},
fmt,
@@ -20,6 +23,7 @@ use tower_make::MakeConnection;
#[derive(Clone)]
pub struct Endpoint {
pub(crate) uri: Uri,
+ pub(crate) user_agent: Option,
pub(crate) timeout: Option,
pub(crate) concurrency_limit: Option,
pub(crate) rate_limit: Option<(u64, Duration)>,
@@ -74,6 +78,30 @@ impl Endpoint {
Ok(Self::from(uri))
}
+ /// Set a custom user-agent header.
+ ///
+ /// `user_agent` will be prepended to Tonic's default user-agent string (`tonic/x.x.x`).
+ /// It must be a value that can be converted into a valid `http::HeaderValue` or building
+ /// the endpoint will fail.
+ /// ```
+ /// # use tonic::transport::Endpoint;
+ /// # let mut builder = Endpoint::from_static("https://example.com");
+ /// builder.user_agent("Greeter").expect("Greeter should be a valid header value");
+ /// // user-agent: "Greeter tonic/x.x.x"
+ /// ```
+ pub fn user_agent(self, user_agent: T) -> Result
+ where
+ T: TryInto,
+ {
+ user_agent
+ .try_into()
+ .map(|ua| Endpoint {
+ user_agent: Some(ua),
+ ..self
+ })
+ .map_err(|_| Error::new_invalid_user_agent())
+ }
+
/// Apply a timeout to each request.
///
/// ```
@@ -276,6 +304,7 @@ impl From for Endpoint {
fn from(uri: Uri) -> Self {
Self {
uri,
+ user_agent: None,
concurrency_limit: None,
rate_limit: None,
timeout: None,
diff --git a/tonic/src/transport/error.rs b/tonic/src/transport/error.rs
index 584164080..042e5172d 100644
--- a/tonic/src/transport/error.rs
+++ b/tonic/src/transport/error.rs
@@ -21,6 +21,7 @@ struct ErrorImpl {
pub(crate) enum Kind {
Transport,
InvalidUri,
+ InvalidUserAgent,
}
impl Error {
@@ -43,10 +44,15 @@ impl Error {
Error::new(Kind::InvalidUri)
}
+ pub(crate) fn new_invalid_user_agent() -> Self {
+ Error::new(Kind::InvalidUserAgent)
+ }
+
fn description(&self) -> &str {
match &self.inner.kind {
Kind::Transport => "transport error",
Kind::InvalidUri => "invalid URI",
+ Kind::InvalidUserAgent => "user agent is not a valid header value",
}
}
}
diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs
index c3e8769fe..1875a145c 100644
--- a/tonic/src/transport/service/connection.rs
+++ b/tonic/src/transport/service/connection.rs
@@ -1,4 +1,4 @@
-use super::{layer::ServiceBuilderExt, reconnect::Reconnect, AddOrigin};
+use super::{layer::ServiceBuilderExt, reconnect::Reconnect, AddOrigin, UserAgent};
use crate::{body::BoxBody, transport::Endpoint};
use http::Uri;
use hyper::client::conn::Builder;
@@ -55,6 +55,7 @@ impl Connection {
let stack = ServiceBuilder::new()
.layer_fn(|s| AddOrigin::new(s, endpoint.uri.clone()))
+ .layer_fn(|s| UserAgent::new(s, endpoint.user_agent.clone()))
.optional_layer(endpoint.timeout.map(TimeoutLayer::new))
.optional_layer(endpoint.concurrency_limit.map(ConcurrencyLimitLayer::new))
.optional_layer(endpoint.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d)))
diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs
index 92453cdbf..eab3b40ef 100644
--- a/tonic/src/transport/service/mod.rs
+++ b/tonic/src/transport/service/mod.rs
@@ -8,6 +8,7 @@ mod reconnect;
mod router;
#[cfg(feature = "tls")]
mod tls;
+mod user_agent;
pub(crate) use self::add_origin::AddOrigin;
pub(crate) use self::connection::Connection;
@@ -18,3 +19,4 @@ pub(crate) use self::layer::ServiceBuilderExt;
pub(crate) use self::router::{Or, Routes};
#[cfg(feature = "tls")]
pub(crate) use self::tls::{TlsAcceptor, TlsConnector};
+pub(crate) use self::user_agent::UserAgent;
diff --git a/tonic/src/transport/service/user_agent.rs b/tonic/src/transport/service/user_agent.rs
new file mode 100644
index 000000000..6ceaea640
--- /dev/null
+++ b/tonic/src/transport/service/user_agent.rs
@@ -0,0 +1,70 @@
+use http::{header::USER_AGENT, HeaderValue, Request};
+use std::task::{Context, Poll};
+use tower_service::Service;
+
+const TONIC_USER_AGENT: &str = concat!("tonic/", env!("CARGO_PKG_VERSION"));
+
+#[derive(Debug)]
+pub(crate) struct UserAgent {
+ inner: T,
+ user_agent: HeaderValue,
+}
+
+impl UserAgent {
+ pub(crate) fn new(inner: T, user_agent: Option) -> Self {
+ let user_agent = user_agent
+ .map(|value| {
+ let mut buf = Vec::new();
+ buf.extend(value.as_bytes());
+ buf.push(b' ');
+ buf.extend(TONIC_USER_AGENT.as_bytes());
+ HeaderValue::from_bytes(&buf).expect("user-agent should be valid")
+ })
+ .unwrap_or(HeaderValue::from_static(TONIC_USER_AGENT));
+
+ Self { inner, user_agent }
+ }
+}
+
+impl Service> for UserAgent
+where
+ T: Service>,
+{
+ type Response = T::Response;
+ type Error = T::Error;
+ type Future = T::Future;
+
+ fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> {
+ self.inner.poll_ready(cx)
+ }
+
+ fn call(&mut self, mut req: Request) -> Self::Future {
+ req.headers_mut()
+ .insert(USER_AGENT, self.user_agent.clone());
+
+ self.inner.call(req)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ struct Svc;
+
+ #[test]
+ fn sets_default_if_no_custom_user_agent() {
+ assert_eq!(
+ UserAgent::new(Svc, None).user_agent,
+ HeaderValue::from_static(TONIC_USER_AGENT)
+ )
+ }
+
+ #[test]
+ fn prepends_custom_user_agent_to_default() {
+ assert_eq!(
+ UserAgent::new(Svc, Some(HeaderValue::from_static("Greeter 1.1"))).user_agent,
+ HeaderValue::from_str(&format!("Greeter 1.1 {}", TONIC_USER_AGENT)).unwrap()
+ )
+ }
+}