From 2a84e33df86c50c330901bccc60e59dbb48f3d7f Mon Sep 17 00:00:00 2001 From: John Howard Date: Thu, 4 Apr 2024 13:21:07 -0700 Subject: [PATCH] handle various usage of unwrap to proper error handling For https://github.com/istio/ztunnel/issues/9 --- src/admin.rs | 89 ++++++++++++++------------------ src/hyper_util.rs | 33 +++++++----- src/identity/manager.rs | 4 +- src/metrics/server.rs | 11 ++-- src/proxy/inbound.rs | 18 +++---- src/proxy/inbound_passthrough.rs | 2 +- src/proxy/outbound.rs | 16 +++--- src/proxy/socks5.rs | 4 +- src/tls/certificate.rs | 4 +- src/tls/control.rs | 17 +++--- src/tls/workload.rs | 8 ++- 11 files changed, 112 insertions(+), 94 deletions(-) diff --git a/src/admin.rs b/src/admin.rs index d3936374c..3a487eaf0 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -30,6 +30,7 @@ use hyper::{header::HeaderValue, header::CONTENT_TYPE, Request, Response}; use pprof::protos::Message; use std::borrow::Borrow; use std::collections::HashMap; + use std::str::FromStr; use std::sync::Arc; use std::time::SystemTime; @@ -122,25 +123,24 @@ impl Service { pub fn spawn(self) { self.s.spawn(|state, req| async move { match req.uri().path() { - "/debug/pprof/profile" => Ok(handle_pprof(req).await), - "/debug/pprof/heap" => Ok(handle_jemalloc_pprof_heapgen(req).await), + "/debug/pprof/profile" => handle_pprof(req).await, + "/debug/pprof/heap" => handle_jemalloc_pprof_heapgen(req).await, "/quitquitquit" => Ok(handle_server_shutdown( state.shutdown_trigger.clone(), req, state.config.self_termination_deadline, ) .await), - "/config_dump" => Ok(handle_config_dump( - ConfigDump { + "/config_dump" => { + handle_config_dump(ConfigDump { proxy_state: state.proxy_state.clone(), static_config: Default::default(), version: BuildInfo::new(), config: state.config.clone(), certificates: dump_certs(state.cert_manager.borrow()).await, - }, - // req, // bring this back if we start using it - ) - .await), + }) + .await + } "/logging" => Ok(handle_logging(req).await), "/" => Ok(handle_dashboard(req, &state.handlers).await), _ => match Self::find_handler(state.as_ref(), req.uri().path()) { @@ -245,30 +245,22 @@ async fn dump_certs(cert_manager: &SecretManager) -> Vec { dump } -async fn handle_pprof(_req: Request) -> Response> { +async fn handle_pprof(_req: Request) -> anyhow::Result>> { let guard = pprof::ProfilerGuardBuilder::default() .frequency(1000) // .blocklist(&["libc", "libgcc", "pthread", "vdso"]) - .build() - .unwrap(); + .build()?; tokio::time::sleep(Duration::from_secs(10)).await; - match guard.report().build() { - Ok(report) => { - let profile = report.pprof().unwrap(); + let report = guard.report().build()?; + let profile = report.pprof()?; - let body = profile.write_to_bytes().unwrap(); + let body = profile.write_to_bytes()?; - Response::builder() - .status(hyper::StatusCode::OK) - .body(body.into()) - .unwrap() - } - Err(err) => plaintext_response( - hyper::StatusCode::INTERNAL_SERVER_ERROR, - format!("failed to build profile: {err}\n"), - ), - } + Ok(Response::builder() + .status(hyper::StatusCode::OK) + .body(body.into()) + .expect("builder with known status code should not fail")) } async fn handle_server_shutdown( @@ -291,10 +283,7 @@ async fn handle_server_shutdown( } } -async fn handle_config_dump( - mut dump: ConfigDump, - // _req: Request, -) -> Response> { +async fn handle_config_dump(mut dump: ConfigDump) -> anyhow::Result>> { if let Some(cfg) = dump.config.local_xds_config.clone() { match cfg.read_to_string().await { Ok(data) => match serde_yaml::from_str(&data) { @@ -311,11 +300,11 @@ async fn handle_config_dump( } } - let body = serde_json::to_string_pretty(&dump).unwrap(); - Response::builder() + let body = serde_json::to_string_pretty(&dump)?; + Ok(Response::builder() .status(hyper::StatusCode::OK) .body(body.into()) - .unwrap() + .expect("builder with known status code should not fail")) } //mirror envoy's behavior: https://www.envoyproxy.io/docs/envoy/latest/operations/admin#post--logging @@ -393,33 +382,31 @@ fn change_log_level(reset: bool, level: &str) -> Response> { } #[cfg(feature = "jemalloc")] -async fn handle_jemalloc_pprof_heapgen(_req: Request) -> Response> { - let mut prof_ctl = jemalloc_pprof::PROF_CTL.as_ref().unwrap().lock().await; +async fn handle_jemalloc_pprof_heapgen( + _req: Request, +) -> anyhow::Result>> { + let mut prof_ctl = jemalloc_pprof::PROF_CTL.as_ref()?.lock().await; if !prof_ctl.activated() { - Response::builder() + return Ok(Response::builder() .status(hyper::StatusCode::INTERNAL_SERVER_ERROR) .body("jemalloc not enabled".into()) - .unwrap() - } else { - let pprof = prof_ctl.dump_pprof().map_err(|err| { - Response::builder() - .status(hyper::StatusCode::INTERNAL_SERVER_ERROR) - .body(err) - .unwrap() - }); - Response::builder() - .status(hyper::StatusCode::OK) - .body(Bytes::from(pprof.unwrap()).into()) - .unwrap() + .expect("builder with known status code should not fail")); } + let pprof = prof_ctl.dump_pprof()?; + Ok(Response::builder() + .status(hyper::StatusCode::OK) + .body(Bytes::from(pprof?).into()) + .expect("builder with known status code should not fail")) } #[cfg(not(feature = "jemalloc"))] -async fn handle_jemalloc_pprof_heapgen(_req: Request) -> Response> { - Response::builder() +async fn handle_jemalloc_pprof_heapgen( + _req: Request, +) -> anyhow::Result>> { + Ok(Response::builder() .status(hyper::StatusCode::INTERNAL_SERVER_ERROR) .body("jemalloc not enabled".into()) - .unwrap() + .expect("builder with known status code should not fail")) } fn base64_encode(data: String) -> String { @@ -732,7 +719,7 @@ mod tests { // // this could happen for a variety of reasons; for example some types // may need custom serialize/deserialize to be keys in a map, like NetworkAddress - let resp = handle_config_dump(dump).await; + let resp = handle_config_dump(dump).await.unwrap(); let resp_bytes = resp .body() diff --git a/src/hyper_util.rs b/src/hyper_util.rs index 5e63ad70f..cc02289d9 100644 --- a/src/hyper_util.rs +++ b/src/hyper_util.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::convert::Infallible; use std::net::SocketAddr; use std::sync::Arc; use std::task::{Context, Poll}; @@ -23,6 +24,7 @@ use std::{ use bytes::Bytes; use drain::Watch; +use futures_util::TryFutureExt; use http_body_util::Full; use hyper::client; use hyper::rt::Sleep; @@ -205,7 +207,7 @@ impl Server { where S: Send + Sync + 'static, F: Fn(Arc, Request) -> R + Send + Sync + 'static, - R: Future>, hyper::Error>> + Send + Sync + 'static, + R: Future>, anyhow::Error>> + Send + Sync + 'static, { use futures_util::StreamExt as OtherStreamExt; let address = self.address(); @@ -229,18 +231,25 @@ impl Server { let f = f.clone(); let state = state.clone(); tokio::spawn(async move { - let serve = http1_server() - .half_close(true) - .header_read_timeout(Duration::from_secs(2)) - .max_buf_size(8 * 1024) - .serve_connection( - hyper_util::rt::TokioIo::new(socket), - hyper::service::service_fn(move |req| { - let state = state.clone(); + let serve = + http1_server() + .half_close(true) + .header_read_timeout(Duration::from_secs(2)) + .max_buf_size(8 * 1024) + .serve_connection( + hyper_util::rt::TokioIo::new(socket), + hyper::service::service_fn(move |req| { + let state = state.clone(); - f(state, req) - }), - ); + // Failures would abort the whole connection; we just want to return an HTTP error + f(state, req).or_else(|err| async move { + Ok::>, Infallible>(Response::builder() + .status(hyper::StatusCode::INTERNAL_SERVER_ERROR) + .body(err.to_string().into()) + .expect("builder with known status code should not fail")) + }) + }), + ); // Wait for drain to signal or connection serving to complete match futures_util::future::select(Box::pin(drain.signaled()), serve).await { // We got a shutdown request. Start gracful shutdown and wait for the pending requests to complete. diff --git a/src/identity/manager.rs b/src/identity/manager.rs index af288a267..35088e29f 100644 --- a/src/identity/manager.rs +++ b/src/identity/manager.rs @@ -356,7 +356,7 @@ impl Worker { }, // Initiate the next fetch. true = maybe_sleep_until(next), if fetches.len() < self.concurrency as usize => { - let (id, _) = pending.pop().unwrap(); + let (id, _) = pending.pop().expect("pending should always have an element at this point"); processing.insert(id.to_owned(), Fetch::Processing); fetches.push(async move { let res = self.client.fetch_certificate(&id).await; @@ -426,7 +426,7 @@ impl fmt::Debug for SecretManager { impl SecretManager { pub async fn new(cfg: crate::config::Config) -> Result { let caclient = CaClient::new( - cfg.ca_address.unwrap(), + cfg.ca_address.expect("ca_address must be set to use CA"), Box::new(tls::ControlPlaneAuthentication::RootCert( cfg.ca_root_cert.clone(), )), diff --git a/src/metrics/server.rs b/src/metrics/server.rs index 7984e71b3..866193920 100644 --- a/src/metrics/server.rs +++ b/src/metrics/server.rs @@ -61,8 +61,13 @@ async fn handle_metrics( _req: Request, ) -> Response> { let mut buf = String::new(); - let reg = reg.lock().unwrap(); - encode(&mut buf, ®).unwrap(); + let reg = reg.lock().expect("mutex"); + if let Err(err) = encode(&mut buf, ®) { + return Response::builder() + .status(hyper::StatusCode::INTERNAL_SERVER_ERROR) + .body(err.to_string().into()) + .expect("builder with known status code should not fail"); + } Response::builder() .status(hyper::StatusCode::OK) @@ -71,5 +76,5 @@ async fn handle_metrics( "application/openmetrics-text;charset=utf-8;version=1.0.0", ) .body(buf.into()) - .unwrap() + .expect("builder with known status code should not fail") } diff --git a/src/proxy/inbound.rs b/src/proxy/inbound.rs index a05844dc3..5ea67f287 100644 --- a/src/proxy/inbound.rs +++ b/src/proxy/inbound.rs @@ -65,7 +65,7 @@ impl Inbound { // Override with our explicitly configured setting pi.cfg.enable_original_source = Some(transparent); info!( - address=%listener.local_addr().unwrap(), + address=%listener.local_addr().expect("local_addr available"), component="inbound", transparent, "listener established", @@ -78,7 +78,7 @@ impl Inbound { } pub(super) fn address(&self) -> SocketAddr { - self.listener.local_addr().unwrap() + self.listener.local_addr().expect("local_addr available") } pub(super) async fn run(self) { @@ -97,7 +97,7 @@ impl Inbound { let (raw_socket, ssl) = tls.get_ref(); let src_identity: Option = tls::identity_from_connection(ssl); let dst = crate::socket::orig_dst_addr_or_default(raw_socket); - let src = to_canonical(raw_socket.peer_addr().unwrap()); + let src = to_canonical(raw_socket.peer_addr().expect("peer_addr available")); let pi = self.pi.clone(); let connection_manager = self.pi.connection_manager.clone(); let drain = sub_drain.clone(); @@ -315,7 +315,7 @@ impl Inbound { return Ok(Response::builder() .status(StatusCode::BAD_REQUEST) .body(Empty::new()) - .unwrap()); + .expect("builder with known status code")); } }; @@ -328,7 +328,7 @@ impl Inbound { return Ok(Response::builder() .status(StatusCode::BAD_REQUEST) .body(Empty::new()) - .unwrap()); + .expect("builder with known status code")); } }; @@ -371,7 +371,7 @@ impl Inbound { return Ok(Response::builder() .status(StatusCode::UNAUTHORIZED) .body(Empty::new()) - .unwrap()); + .expect("builder with known status code should not fail")); } // This check should be removed in favor of an L4 policy check // We should express as policy whether or not traffic is allowed to bypass a waypoint @@ -381,7 +381,7 @@ impl Inbound { return Ok(Response::builder() .status(StatusCode::UNAUTHORIZED) .body(Empty::new()) - .unwrap()); + .expect("builder with known status code should not fail")); } let source_ip = if from_waypoint { // If the request is from our waypoint, trust the Forwarded header. @@ -457,7 +457,7 @@ impl Inbound { Ok(Response::builder() .status(status_code) .body(Empty::new()) - .unwrap()) + .expect("builder with known status code should not fail")) } // Return the 404 Not Found for other routes. method => { @@ -465,7 +465,7 @@ impl Inbound { Ok(Response::builder() .status(StatusCode::NOT_FOUND) .body(Empty::new()) - .unwrap()) + .expect("builder with known status code should not fail")) } } } diff --git a/src/proxy/inbound_passthrough.rs b/src/proxy/inbound_passthrough.rs index 8c8f0c633..2fba826c6 100644 --- a/src/proxy/inbound_passthrough.rs +++ b/src/proxy/inbound_passthrough.rs @@ -49,7 +49,7 @@ impl InboundPassthrough { pi.cfg.enable_original_source = Some(transparent); info!( - address=%listener.local_addr().unwrap(), + address=%listener.local_addr().expect("local_addr available"), component="inbound plaintext", transparent, "listener established", diff --git a/src/proxy/outbound.rs b/src/proxy/outbound.rs index 57fe47564..d70dd4e66 100644 --- a/src/proxy/outbound.rs +++ b/src/proxy/outbound.rs @@ -57,7 +57,7 @@ impl Outbound { pi.cfg.enable_original_source = Some(transparent); info!( - address=%listener.local_addr().unwrap(), + address=%listener.local_addr().expect("local_addr available"), component="outbound", transparent, "listener established", @@ -70,7 +70,7 @@ impl Outbound { } pub(super) fn address(&self) -> SocketAddr { - self.listener.local_addr().unwrap() + self.listener.local_addr().expect("local_addr available") } pub(super) async fn run(self) { @@ -255,7 +255,11 @@ impl OutboundConnection { } } - allowed_sans.push(req.expected_identity.clone().unwrap()); + allowed_sans.push( + req.expected_identity + .clone() + .expect("HBONE request must have expected identity"), + ); let dst_identity = allowed_sans; let pool_key = pool::Key { @@ -316,10 +320,10 @@ impl OutboundConnection { BAGGAGE_HEADER, baggage(&req, self.pi.cfg.cluster_id.clone()), ) - .header(FORWARDED, f.value().unwrap()) + .header(FORWARDED, f.value().expect("Forwarded value is infallible")) .header(TRACEPARENT_HEADER, self.id.header()) .body(Empty::::new()) - .unwrap(); + .expect("builder with known status code should not fail"); let response = connection.send_request(request).await?; @@ -465,7 +469,7 @@ impl OutboundConnection { }); } - let mut mutable_us = us.unwrap(); + let mut mutable_us = us.expect("option is verified above"); let workload_ip = self .pi .state diff --git a/src/proxy/socks5.rs b/src/proxy/socks5.rs index 02de87da3..fb7b88346 100644 --- a/src/proxy/socks5.rs +++ b/src/proxy/socks5.rs @@ -40,7 +40,7 @@ impl Socks5 { .map_err(|e| Error::Bind(pi.cfg.socks5_addr, e))?; info!( - address=%listener.local_addr().unwrap(), + address=%listener.local_addr().expect("local_addr available"), component="socks5", "listener established", ); @@ -53,7 +53,7 @@ impl Socks5 { } pub(super) fn address(&self) -> SocketAddr { - self.listener.local_addr().unwrap() + self.listener.local_addr().expect("local_addr available") } pub async fn run(self) { diff --git a/src/tls/certificate.rs b/src/tls/certificate.rs index 9f826825c..3139fc9a9 100644 --- a/src/tls/certificate.rs +++ b/src/tls/certificate.rs @@ -108,7 +108,9 @@ pub fn identities(cert: X509Certificate) -> Result, Error> { impl Certificate { // TOOD: I would love to parse this once, but ran into lifetime issues. fn parsed(&self) -> X509Certificate { - x509_parser::parse_x509_certificate(&self.der).unwrap().1 + x509_parser::parse_x509_certificate(&self.der) + .expect("certificate was already parsed successfully before") + .1 } pub fn as_pem(&self) -> String { diff --git a/src/tls/control.rs b/src/tls/control.rs index 51eabb10a..1e1dfaed2 100644 --- a/src/tls/control.rs +++ b/src/tls/control.rs @@ -176,12 +176,17 @@ impl tower::Service> for TlsGrpcChannel { fn call(&mut self, req: http_02::Request) -> Self::Future { let mut req = http02_request_to_http1(req.map(HttpBody04ToHttpBody1::new)); - let uri = Uri::builder() - .scheme(self.uri.scheme().unwrap().to_owned()) - .authority(self.uri.authority().unwrap().to_owned()) - .path_and_query(req.uri().path_and_query().unwrap().to_owned()) - .build() - .unwrap(); + let mut uri = Uri::builder(); + if let Some(scheme) = self.uri.scheme() { + uri = uri.scheme(scheme.to_owned()); + } + if let Some(authority) = self.uri.authority() { + uri = uri.authority(authority.to_owned()); + } + if let Some(path_and_query) = req.uri().path_and_query() { + uri = uri.path_and_query(path_and_query.to_owned()); + } + let uri = uri.build().expect("uri must be valid"); *req.uri_mut() = uri; let future = self.client.request(req); Box::pin(async move { diff --git a/src/tls/workload.rs b/src/tls/workload.rs index 514fe14e2..fcfa37216 100644 --- a/src/tls/workload.rs +++ b/src/tls/workload.rs @@ -164,7 +164,13 @@ impl OutboundConnector { self, stream: TcpStream, ) -> Result, io::Error> { - let dest = ServerName::IpAddress(stream.peer_addr().unwrap().ip().into()); + let dest = ServerName::IpAddress( + stream + .peer_addr() + .expect("peer_addr must be set") + .ip() + .into(), + ); let c = tokio_rustls::TlsConnector::from(self.client_config); c.connect(dest, stream).await }