From d795e33a293ffdecb5c76330e23f27c7707ae7a5 Mon Sep 17 00:00:00 2001 From: John Howard Date: Thu, 19 Sep 2024 14:53:18 -0700 Subject: [PATCH] auth: async-ify and clean up (#1321) The Interceptor usage in tonic is not really useful for us since we make our own client for other reasons (tls, etc) so we can already easily insert some headers. It also makes us do some awkward DefaultIncoming stuff, and doesn't allow async. This moves the auth to async, which is generally a good idea and more flexible, and cleans up some of the cruft --- src/identity/auth.rs | 46 +++++++++-------------------- src/identity/caclient.rs | 8 ++--- src/tls/control.rs | 64 +++++++++++++--------------------------- src/xds/client.rs | 12 ++++---- 4 files changed, 43 insertions(+), 87 deletions(-) diff --git a/src/identity/auth.rs b/src/identity/auth.rs index ee53058c2..af7c70c54 100644 --- a/src/identity/auth.rs +++ b/src/identity/auth.rs @@ -15,10 +15,6 @@ use std::io; use std::path::PathBuf; -use tonic::metadata::AsciiMetadataValue; -use tonic::service::Interceptor; -use tonic::{Code, Request, Status}; - #[derive(Clone, Debug, PartialEq, Eq)] pub enum AuthSource { // JWT authentication source which contains the token file path and the cluster id. @@ -30,15 +26,18 @@ pub enum AuthSource { } impl AuthSource { - fn read_token(&self) -> io::Result, String)>> { - Ok(match self { + pub async fn insert_headers(&self, request: &mut http::HeaderMap) -> anyhow::Result<()> { + const AUTHORIZATION: &str = "authorization"; + const CLUSTER: &str = "clusterid"; + match self { AuthSource::Token(path, cluster_id) => { - let token = load_token(path).map(|mut t| { + let token = load_token(path).await.map(|mut t| { let mut bearer: Vec = b"Bearer ".to_vec(); bearer.append(&mut t); bearer })?; - Some((token, cluster_id.to_string())) + request.insert(AUTHORIZATION, token.try_into()?); + request.insert(CLUSTER, cluster_id.try_into()?); } AuthSource::StaticToken(token, cluster_id) => { let token = { @@ -46,15 +45,17 @@ impl AuthSource { bearer.extend_from_slice(token.as_bytes()); bearer }; - Some((token, cluster_id.to_string())) + request.insert(AUTHORIZATION, token.try_into()?); + request.insert(CLUSTER, cluster_id.try_into()?); } - AuthSource::None => None, - }) + AuthSource::None => {} + } + Ok(()) } } -fn load_token(path: &PathBuf) -> io::Result> { - let t = std::fs::read(path)?; +async fn load_token(path: &PathBuf) -> io::Result> { + let t = tokio::fs::read(path).await?; if t.is_empty() { return Err(io::Error::new( @@ -64,22 +65,3 @@ fn load_token(path: &PathBuf) -> io::Result> { } Ok(t) } - -impl Interceptor for AuthSource { - fn call(&mut self, mut request: Request<()>) -> Result, Status> { - if let Some((token, cluster_id)) = self - .read_token() - .map_err(|e| Status::new(Code::Unauthenticated, e.to_string()))? - { - let token = AsciiMetadataValue::try_from(token) - .map_err(|e| Status::new(Code::Unauthenticated, e.to_string()))?; - request.metadata_mut().insert("authorization", token); - if !cluster_id.is_empty() { - let id = AsciiMetadataValue::try_from(cluster_id.as_bytes().to_vec()) - .map_err(|e| Status::new(Code::Unauthenticated, e.to_string()))?; - request.metadata_mut().insert("clusterid", id); - } - } - Ok(request) - } -} diff --git a/src/identity/caclient.rs b/src/identity/caclient.rs index 8381af21b..7102e3c36 100644 --- a/src/identity/caclient.rs +++ b/src/identity/caclient.rs @@ -17,7 +17,6 @@ use std::collections::BTreeMap; use async_trait::async_trait; use prost_types::value::Kind; use prost_types::Struct; -use tonic::codegen::InterceptedService; use tracing::{error, instrument, warn}; @@ -29,7 +28,7 @@ use crate::xds::istio::ca::istio_certificate_service_client::IstioCertificateSer use crate::xds::istio::ca::IstioCertificateRequest; pub struct CaClient { - pub client: IstioCertificateServiceClient>, + pub client: IstioCertificateServiceClient, pub enable_impersonated_identity: bool, pub secret_ttl: i64, } @@ -42,9 +41,8 @@ impl CaClient { enable_impersonated_identity: bool, secret_ttl: i64, ) -> Result { - let svc = tls::grpc_connector(address, cert_provider.fetch_cert().await?)?; - - let client = IstioCertificateServiceClient::with_interceptor(svc, auth); + let svc = tls::grpc_connector(address, auth, cert_provider.fetch_cert().await?)?; + let client = IstioCertificateServiceClient::new(svc); Ok(CaClient { client, enable_impersonated_identity, diff --git a/src/tls/control.rs b/src/tls/control.rs index a00cdaa6f..1c6f60a62 100644 --- a/src/tls/control.rs +++ b/src/tls/control.rs @@ -13,10 +13,10 @@ // limitations under the License. use crate::config::RootCert; +use crate::identity::AuthSource; use crate::tls::lib::provider; use crate::tls::{ControlPlaneClientCertProvider, Error, WorkloadCertificate}; -use bytes::Bytes; -use http_body::{Body, Frame}; +use hyper::body::Incoming; use hyper::Uri; use hyper_rustls::HttpsConnector; use hyper_util::client::legacy::connect::HttpConnector; @@ -24,11 +24,9 @@ use rustls::ClientConfig; use std::future::Future; use std::io::Cursor; use std::pin::Pin; - -use hyper::body::Incoming; +use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; - use tonic::body::BoxBody; async fn root_to_store(root_cert: &RootCert) -> Result { @@ -89,20 +87,19 @@ async fn control_plane_client_config(root_cert: &RootCert) -> Result, BoxBody>; #[derive(Clone, Debug)] pub struct TlsGrpcChannel { uri: Uri, client: hyper_util::client::legacy::Client, BoxBody>, + auth: Arc, } /// grpc_connector provides a client TLS channel for gRPC requests. -pub async fn grpc_tls_connector(uri: String, root_cert: RootCert) -> Result { - grpc_connector(uri, control_plane_client_config(&root_cert).await?) -} - -/// grpc_connector provides a client TLS channel for gRPC requests. -pub fn grpc_connector(uri: String, cc: ClientConfig) -> Result { +pub fn grpc_connector( + uri: String, + auth: AuthSource, + cc: ClientConfig, +) -> Result { let uri = Uri::try_from(uri)?; let _is_localhost_call = uri.host() == Some("localhost"); let mut http: HttpConnector = HttpConnector::new(); @@ -131,37 +128,16 @@ pub fn grpc_connector(uri: String, cc: ClientConfig) -> Result, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>>> { - match self.get_mut() { - DefaultIncoming::Some(ref mut i) => Pin::new(i).poll_frame(cx), - DefaultIncoming::Empty => Pin::new(&mut http_body_util::Empty::::new()) - .poll_frame(cx) - .map_err(|_| unreachable!()), - } - } + Ok(TlsGrpcChannel { + uri, + auth: Arc::new(auth), + client, + }) } impl tower::Service> for TlsGrpcChannel { - type Response = http::Response; - type Error = hyper_util::client::legacy::Error; + type Response = http::Response; + type Error = anyhow::Error; type Future = Pin> + Send>>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { @@ -181,10 +157,12 @@ impl tower::Service> for TlsGrpcChannel { } let uri = uri.build().expect("uri must be valid"); *req.uri_mut() = uri; - let future = self.client.request(req); + + let client = self.client.clone(); + let auth = self.auth.clone(); Box::pin(async move { - let res = future.await?; - Ok(res.map(DefaultIncoming::Some)) + auth.insert_headers(req.headers_mut()).await?; + Ok(client.request(req).await?) }) } } diff --git a/src/xds/client.rs b/src/xds/client.rs index ecb0f0285..e39406744 100644 --- a/src/xds/client.rs +++ b/src/xds/client.rs @@ -619,16 +619,14 @@ impl AdsClient { let addr = self.config.address.clone(); let tls_grpc_channel = tls::grpc_connector( self.config.address.clone(), + self.config.auth.clone(), self.config.tls_builder.fetch_cert().await?, )?; - let ads_connection = AggregatedDiscoveryServiceClient::with_interceptor( - tls_grpc_channel, - self.config.auth.clone(), - ) - .max_decoding_message_size(200 * 1024 * 1024) - .delta_aggregated_resources(tonic::Request::new(outbound)) - .await; + let ads_connection = AggregatedDiscoveryServiceClient::new(tls_grpc_channel) + .max_decoding_message_size(200 * 1024 * 1024) + .delta_aggregated_resources(tonic::Request::new(outbound)) + .await; let mut response_stream = ads_connection .map_err(|src| Error::Connection(addr, src))?