diff --git a/src/identity/auth.rs b/src/identity/auth.rs index ee53058c2..af7c70c54 100644 --- a/src/identity/auth.rs +++ b/src/identity/auth.rs @@ -15,10 +15,6 @@ use std::io; use std::path::PathBuf; -use tonic::metadata::AsciiMetadataValue; -use tonic::service::Interceptor; -use tonic::{Code, Request, Status}; - #[derive(Clone, Debug, PartialEq, Eq)] pub enum AuthSource { // JWT authentication source which contains the token file path and the cluster id. @@ -30,15 +26,18 @@ pub enum AuthSource { } impl AuthSource { - fn read_token(&self) -> io::Result, String)>> { - Ok(match self { + pub async fn insert_headers(&self, request: &mut http::HeaderMap) -> anyhow::Result<()> { + const AUTHORIZATION: &str = "authorization"; + const CLUSTER: &str = "clusterid"; + match self { AuthSource::Token(path, cluster_id) => { - let token = load_token(path).map(|mut t| { + let token = load_token(path).await.map(|mut t| { let mut bearer: Vec = b"Bearer ".to_vec(); bearer.append(&mut t); bearer })?; - Some((token, cluster_id.to_string())) + request.insert(AUTHORIZATION, token.try_into()?); + request.insert(CLUSTER, cluster_id.try_into()?); } AuthSource::StaticToken(token, cluster_id) => { let token = { @@ -46,15 +45,17 @@ impl AuthSource { bearer.extend_from_slice(token.as_bytes()); bearer }; - Some((token, cluster_id.to_string())) + request.insert(AUTHORIZATION, token.try_into()?); + request.insert(CLUSTER, cluster_id.try_into()?); } - AuthSource::None => None, - }) + AuthSource::None => {} + } + Ok(()) } } -fn load_token(path: &PathBuf) -> io::Result> { - let t = std::fs::read(path)?; +async fn load_token(path: &PathBuf) -> io::Result> { + let t = tokio::fs::read(path).await?; if t.is_empty() { return Err(io::Error::new( @@ -64,22 +65,3 @@ fn load_token(path: &PathBuf) -> io::Result> { } Ok(t) } - -impl Interceptor for AuthSource { - fn call(&mut self, mut request: Request<()>) -> Result, Status> { - if let Some((token, cluster_id)) = self - .read_token() - .map_err(|e| Status::new(Code::Unauthenticated, e.to_string()))? - { - let token = AsciiMetadataValue::try_from(token) - .map_err(|e| Status::new(Code::Unauthenticated, e.to_string()))?; - request.metadata_mut().insert("authorization", token); - if !cluster_id.is_empty() { - let id = AsciiMetadataValue::try_from(cluster_id.as_bytes().to_vec()) - .map_err(|e| Status::new(Code::Unauthenticated, e.to_string()))?; - request.metadata_mut().insert("clusterid", id); - } - } - Ok(request) - } -} diff --git a/src/identity/caclient.rs b/src/identity/caclient.rs index 8381af21b..7102e3c36 100644 --- a/src/identity/caclient.rs +++ b/src/identity/caclient.rs @@ -17,7 +17,6 @@ use std::collections::BTreeMap; use async_trait::async_trait; use prost_types::value::Kind; use prost_types::Struct; -use tonic::codegen::InterceptedService; use tracing::{error, instrument, warn}; @@ -29,7 +28,7 @@ use crate::xds::istio::ca::istio_certificate_service_client::IstioCertificateSer use crate::xds::istio::ca::IstioCertificateRequest; pub struct CaClient { - pub client: IstioCertificateServiceClient>, + pub client: IstioCertificateServiceClient, pub enable_impersonated_identity: bool, pub secret_ttl: i64, } @@ -42,9 +41,8 @@ impl CaClient { enable_impersonated_identity: bool, secret_ttl: i64, ) -> Result { - let svc = tls::grpc_connector(address, cert_provider.fetch_cert().await?)?; - - let client = IstioCertificateServiceClient::with_interceptor(svc, auth); + let svc = tls::grpc_connector(address, auth, cert_provider.fetch_cert().await?)?; + let client = IstioCertificateServiceClient::new(svc); Ok(CaClient { client, enable_impersonated_identity, diff --git a/src/tls/control.rs b/src/tls/control.rs index a00cdaa6f..1c6f60a62 100644 --- a/src/tls/control.rs +++ b/src/tls/control.rs @@ -13,10 +13,10 @@ // limitations under the License. use crate::config::RootCert; +use crate::identity::AuthSource; use crate::tls::lib::provider; use crate::tls::{ControlPlaneClientCertProvider, Error, WorkloadCertificate}; -use bytes::Bytes; -use http_body::{Body, Frame}; +use hyper::body::Incoming; use hyper::Uri; use hyper_rustls::HttpsConnector; use hyper_util::client::legacy::connect::HttpConnector; @@ -24,11 +24,9 @@ use rustls::ClientConfig; use std::future::Future; use std::io::Cursor; use std::pin::Pin; - -use hyper::body::Incoming; +use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; - use tonic::body::BoxBody; async fn root_to_store(root_cert: &RootCert) -> Result { @@ -89,20 +87,19 @@ async fn control_plane_client_config(root_cert: &RootCert) -> Result, BoxBody>; #[derive(Clone, Debug)] pub struct TlsGrpcChannel { uri: Uri, client: hyper_util::client::legacy::Client, BoxBody>, + auth: Arc, } /// grpc_connector provides a client TLS channel for gRPC requests. -pub async fn grpc_tls_connector(uri: String, root_cert: RootCert) -> Result { - grpc_connector(uri, control_plane_client_config(&root_cert).await?) -} - -/// grpc_connector provides a client TLS channel for gRPC requests. -pub fn grpc_connector(uri: String, cc: ClientConfig) -> Result { +pub fn grpc_connector( + uri: String, + auth: AuthSource, + cc: ClientConfig, +) -> Result { let uri = Uri::try_from(uri)?; let _is_localhost_call = uri.host() == Some("localhost"); let mut http: HttpConnector = HttpConnector::new(); @@ -131,37 +128,16 @@ pub fn grpc_connector(uri: String, cc: ClientConfig) -> Result, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>>> { - match self.get_mut() { - DefaultIncoming::Some(ref mut i) => Pin::new(i).poll_frame(cx), - DefaultIncoming::Empty => Pin::new(&mut http_body_util::Empty::::new()) - .poll_frame(cx) - .map_err(|_| unreachable!()), - } - } + Ok(TlsGrpcChannel { + uri, + auth: Arc::new(auth), + client, + }) } impl tower::Service> for TlsGrpcChannel { - type Response = http::Response; - type Error = hyper_util::client::legacy::Error; + type Response = http::Response; + type Error = anyhow::Error; type Future = Pin> + Send>>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { @@ -181,10 +157,12 @@ impl tower::Service> for TlsGrpcChannel { } let uri = uri.build().expect("uri must be valid"); *req.uri_mut() = uri; - let future = self.client.request(req); + + let client = self.client.clone(); + let auth = self.auth.clone(); Box::pin(async move { - let res = future.await?; - Ok(res.map(DefaultIncoming::Some)) + auth.insert_headers(req.headers_mut()).await?; + Ok(client.request(req).await?) }) } } diff --git a/src/xds/client.rs b/src/xds/client.rs index ecb0f0285..e39406744 100644 --- a/src/xds/client.rs +++ b/src/xds/client.rs @@ -619,16 +619,14 @@ impl AdsClient { let addr = self.config.address.clone(); let tls_grpc_channel = tls::grpc_connector( self.config.address.clone(), + self.config.auth.clone(), self.config.tls_builder.fetch_cert().await?, )?; - let ads_connection = AggregatedDiscoveryServiceClient::with_interceptor( - tls_grpc_channel, - self.config.auth.clone(), - ) - .max_decoding_message_size(200 * 1024 * 1024) - .delta_aggregated_resources(tonic::Request::new(outbound)) - .await; + let ads_connection = AggregatedDiscoveryServiceClient::new(tls_grpc_channel) + .max_decoding_message_size(200 * 1024 * 1024) + .delta_aggregated_resources(tonic::Request::new(outbound)) + .await; let mut response_stream = ads_connection .map_err(|src| Error::Connection(addr, src))?