diff --git a/tonic/src/transport/channel/service/tls.rs b/tonic/src/transport/channel/service/tls.rs index 656e9fb31..629a4fe26 100644 --- a/tonic/src/transport/channel/service/tls.rs +++ b/tonic/src/transport/channel/service/tls.rs @@ -1,5 +1,4 @@ use std::fmt; -use std::io::Cursor; use std::sync::Arc; use hyper_util::rt::TokioIo; @@ -13,7 +12,9 @@ use tokio_rustls::{ }; use super::io::BoxedIo; -use crate::transport::service::tls::{add_certs_from_pem, load_identity, TlsError, ALPN_H2}; +use crate::transport::service::tls::{ + convert_certificate_to_pki_types, convert_identity_to_pki_types, TlsError, ALPN_H2, +}; use crate::transport::tls::{Certificate, Identity}; #[derive(Clone)] @@ -55,13 +56,13 @@ impl TlsConnector { } for cert in ca_certs { - add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?; + roots.add_parsable_certificates(convert_certificate_to_pki_types(&cert)?); } let builder = builder.with_root_certificates(roots); let mut config = match identity { Some(identity) => { - let (client_cert, client_key) = load_identity(identity)?; + let (client_cert, client_key) = convert_identity_to_pki_types(&identity)?; builder.with_client_auth_cert(client_cert, client_key)? } None => builder.with_no_client_auth(), diff --git a/tonic/src/transport/server/service/tls.rs b/tonic/src/transport/server/service/tls.rs index d69a6a46b..395d5132b 100644 --- a/tonic/src/transport/server/service/tls.rs +++ b/tonic/src/transport/server/service/tls.rs @@ -1,4 +1,4 @@ -use std::{fmt, io::Cursor, sync::Arc}; +use std::{fmt, sync::Arc}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::{ @@ -8,7 +8,7 @@ use tokio_rustls::{ }; use crate::transport::{ - service::tls::{add_certs_from_pem, load_identity, ALPN_H2}, + service::tls::{convert_certificate_to_pki_types, convert_identity_to_pki_types, ALPN_H2}, Certificate, Identity, }; @@ -29,7 +29,7 @@ impl TlsAcceptor { None => builder.with_no_client_auth(), Some(cert) => { let mut roots = RootCertStore::empty(); - add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?; + roots.add_parsable_certificates(convert_certificate_to_pki_types(&cert)?); let verifier = if client_auth_optional { WebPkiClientVerifier::builder(roots.into()).allow_unauthenticated() } else { @@ -40,7 +40,7 @@ impl TlsAcceptor { } }; - let (cert, key) = load_identity(identity)?; + let (cert, key) = convert_identity_to_pki_types(&identity)?; let mut config = builder.with_single_cert(cert, key)?; config.alpn_protocols.push(ALPN_H2.into()); diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index cdc2cf7ee..1b0c1c458 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -1,11 +1,8 @@ use std::{fmt, io::Cursor}; -use tokio_rustls::rustls::{ - pki_types::{CertificateDer, PrivateKeyDer}, - RootCertStore, -}; +use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; -use crate::transport::Identity; +use crate::transport::{Certificate, Identity}; /// h2 alpn in plain format for rustls. pub(crate) const ALPN_H2: &[u8] = b"h2"; @@ -38,29 +35,20 @@ impl fmt::Display for TlsError { impl std::error::Error for TlsError {} -pub(crate) fn load_identity( - identity: Identity, -) -> Result<(Vec>, PrivateKeyDer<'static>), TlsError> { - let cert = rustls_pemfile::certs(&mut Cursor::new(identity.cert)) +pub(crate) fn convert_certificate_to_pki_types( + certificate: &Certificate, +) -> Result>, TlsError> { + rustls_pemfile::certs(&mut Cursor::new(certificate)) .collect::, _>>() - .map_err(|_| TlsError::CertificateParseError)?; + .map_err(|_| TlsError::CertificateParseError) +} - let Ok(Some(key)) = rustls_pemfile::private_key(&mut Cursor::new(identity.key)) else { +pub(crate) fn convert_identity_to_pki_types( + identity: &Identity, +) -> Result<(Vec>, PrivateKeyDer<'static>), TlsError> { + let cert = convert_certificate_to_pki_types(&identity.cert)?; + let Ok(Some(key)) = rustls_pemfile::private_key(&mut Cursor::new(&identity.key)) else { return Err(TlsError::PrivateKeyParseError); }; - Ok((cert, key)) } - -pub(crate) fn add_certs_from_pem( - mut certs: &mut dyn std::io::BufRead, - roots: &mut RootCertStore, -) -> Result<(), crate::BoxError> { - for cert in rustls_pemfile::certs(&mut certs).collect::, _>>()? { - roots - .add(cert) - .map_err(|_| TlsError::CertificateParseError)?; - } - - Ok(()) -}