Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(ext/tls): use concrete error types #26174

Merged
merged 4 commits into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion ext/net/ops_tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ pub fn op_tls_cert_resolver_resolve_error(
#[string] sni: String,
#[string] error: String,
) {
lookup.resolve(sni, Err(anyhow!(error)))
lookup.resolve(sni, Err(error))
}

#[op2]
Expand Down
1 change: 1 addition & 0 deletions ext/tls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ rustls-pemfile.workspace = true
rustls-tokio-stream.workspace = true
rustls-webpki.workspace = true
serde.workspace = true
thiserror.workspace = true
tokio.workspace = true
webpki-roots.workspace = true
87 changes: 43 additions & 44 deletions ext/tls/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,12 @@ pub use rustls_tokio_stream::*;
pub use webpki;
pub use webpki_roots;

use deno_core::anyhow::anyhow;
use deno_core::error::custom_error;
use deno_core::error::AnyError;

use rustls::client::danger::HandshakeSignatureValid;
use rustls::client::danger::ServerCertVerified;
use rustls::client::danger::ServerCertVerifier;
use rustls::client::WebPkiServerVerifier;
use rustls::ClientConfig;
use rustls::DigitallySignedStruct;
use rustls::Error;
use rustls::RootCertStore;
use rustls_pemfile::certs;
use rustls_pemfile::ec_private_keys;
Expand All @@ -35,12 +30,30 @@ use std::sync::Arc;
mod tls_key;
pub use tls_key::*;

#[derive(Debug, thiserror::Error)]
pub enum TlsError {
#[error(transparent)]
Rustls(#[from] rustls::Error),
#[error("Unable to add pem file to certificate store: {0}")]
UnableAddPemFileToCert(std::io::Error),
#[error("Unable to decode certificate")]
CertInvalid,
#[error("No certificates found in certificate data")]
CertsNotFound,
#[error("No keys found in key data")]
KeysNotFound,
#[error("Unable to decode key")]
KeyDecode,
}

/// Lazily resolves the root cert store.
///
/// This was done because the root cert store is not needed in all cases
/// and takes a bit of time to initialize.
pub trait RootCertStoreProvider: Send + Sync {
fn get_or_try_init(&self) -> Result<&RootCertStore, AnyError>;
fn get_or_try_init(
&self,
) -> Result<&RootCertStore, deno_core::error::AnyError>;
}

// This extension has no runtime apis, it only exports some shared native functions.
Expand Down Expand Up @@ -77,7 +90,7 @@ impl ServerCertVerifier for NoCertificateVerification {
server_name: &rustls::pki_types::ServerName<'_>,
ocsp_response: &[u8],
now: rustls::pki_types::UnixTime,
) -> Result<ServerCertVerified, Error> {
) -> Result<ServerCertVerified, rustls::Error> {
if self.ic_allowlist.is_empty() {
return Ok(ServerCertVerified::assertion());
}
Expand All @@ -89,7 +102,9 @@ impl ServerCertVerifier for NoCertificateVerification {
_ => {
// NOTE(bartlomieju): `ServerName` is a non-exhaustive enum
// so we have this catch all errors here.
return Err(Error::General("Unknown `ServerName` variant".to_string()));
return Err(rustls::Error::General(
"Unknown `ServerName` variant".to_string(),
));
}
};
if self.ic_allowlist.contains(&dns_name_or_ip_address) {
Expand All @@ -110,7 +125,7 @@ impl ServerCertVerifier for NoCertificateVerification {
message: &[u8],
cert: &rustls::pki_types::CertificateDer,
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, Error> {
) -> Result<HandshakeSignatureValid, rustls::Error> {
if self.ic_allowlist.is_empty() {
return Ok(HandshakeSignatureValid::assertion());
}
Expand All @@ -126,7 +141,7 @@ impl ServerCertVerifier for NoCertificateVerification {
message: &[u8],
cert: &rustls::pki_types::CertificateDer,
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, Error> {
) -> Result<HandshakeSignatureValid, rustls::Error> {
if self.ic_allowlist.is_empty() {
return Ok(HandshakeSignatureValid::assertion());
}
Expand Down Expand Up @@ -178,7 +193,7 @@ pub fn create_client_config(
unsafely_ignore_certificate_errors: Option<Vec<String>>,
maybe_cert_chain_and_key: TlsKeys,
socket_use: SocketUse,
) -> Result<ClientConfig, AnyError> {
) -> Result<ClientConfig, TlsError> {
if let Some(ic_allowlist) = unsafely_ignore_certificate_errors {
let client_config = ClientConfig::builder()
.dangerous()
Expand Down Expand Up @@ -214,10 +229,7 @@ pub fn create_client_config(
root_cert_store.add(cert)?;
}
Err(e) => {
return Err(anyhow!(
"Unable to add pem file to certificate store: {}",
e
));
return Err(TlsError::UnableAddPemFileToCert(e));
}
}
}
Expand Down Expand Up @@ -255,74 +267,61 @@ fn add_alpn(client: &mut ClientConfig, socket_use: SocketUse) {

pub fn load_certs(
reader: &mut dyn BufRead,
) -> Result<Vec<CertificateDer<'static>>, AnyError> {
) -> Result<Vec<CertificateDer<'static>>, TlsError> {
let certs: Result<Vec<_>, _> = certs(reader).collect();

let certs = certs
.map_err(|_| custom_error("InvalidData", "Unable to decode certificate"))?;
let certs = certs.map_err(|_| TlsError::CertInvalid)?;

if certs.is_empty() {
return Err(cert_not_found_err());
return Err(TlsError::CertsNotFound);
}

Ok(certs)
}

fn key_decode_err() -> AnyError {
custom_error("InvalidData", "Unable to decode key")
}

fn key_not_found_err() -> AnyError {
custom_error("InvalidData", "No keys found in key data")
}

fn cert_not_found_err() -> AnyError {
custom_error("InvalidData", "No certificates found in certificate data")
}

/// Starts with -----BEGIN RSA PRIVATE KEY-----
fn load_rsa_keys(
mut bytes: &[u8],
) -> Result<Vec<PrivateKeyDer<'static>>, AnyError> {
) -> Result<Vec<PrivateKeyDer<'static>>, TlsError> {
let keys: Result<Vec<_>, _> = rsa_private_keys(&mut bytes).collect();
let keys = keys.map_err(|_| key_decode_err())?;
let keys = keys.map_err(|_| TlsError::KeyDecode)?;
Ok(keys.into_iter().map(PrivateKeyDer::Pkcs1).collect())
}

/// Starts with -----BEGIN EC PRIVATE KEY-----
fn load_ec_keys(
mut bytes: &[u8],
) -> Result<Vec<PrivateKeyDer<'static>>, AnyError> {
) -> Result<Vec<PrivateKeyDer<'static>>, TlsError> {
let keys: Result<Vec<_>, std::io::Error> =
ec_private_keys(&mut bytes).collect();
let keys2 = keys.map_err(|_| key_decode_err())?;
let keys2 = keys.map_err(|_| TlsError::KeyDecode)?;
Ok(keys2.into_iter().map(PrivateKeyDer::Sec1).collect())
}

/// Starts with -----BEGIN PRIVATE KEY-----
fn load_pkcs8_keys(
mut bytes: &[u8],
) -> Result<Vec<PrivateKeyDer<'static>>, AnyError> {
) -> Result<Vec<PrivateKeyDer<'static>>, TlsError> {
let keys: Result<Vec<_>, std::io::Error> =
pkcs8_private_keys(&mut bytes).collect();
let keys2 = keys.map_err(|_| key_decode_err())?;
let keys2 = keys.map_err(|_| TlsError::KeyDecode)?;
Ok(keys2.into_iter().map(PrivateKeyDer::Pkcs8).collect())
}

fn filter_invalid_encoding_err(
to_be_filtered: Result<HandshakeSignatureValid, Error>,
) -> Result<HandshakeSignatureValid, Error> {
to_be_filtered: Result<HandshakeSignatureValid, rustls::Error>,
) -> Result<HandshakeSignatureValid, rustls::Error> {
match to_be_filtered {
Err(Error::InvalidCertificate(rustls::CertificateError::BadEncoding)) => {
Ok(HandshakeSignatureValid::assertion())
}
Err(rustls::Error::InvalidCertificate(
rustls::CertificateError::BadEncoding,
)) => Ok(HandshakeSignatureValid::assertion()),
res => res,
}
}

pub fn load_private_keys(
bytes: &[u8],
) -> Result<Vec<PrivateKeyDer<'static>>, AnyError> {
) -> Result<Vec<PrivateKeyDer<'static>>, TlsError> {
let mut keys = load_rsa_keys(bytes)?;

if keys.is_empty() {
Expand All @@ -334,7 +333,7 @@ pub fn load_private_keys(
}

if keys.is_empty() {
return Err(key_not_found_err());
return Err(TlsError::KeysNotFound);
}

Ok(keys)
Expand Down
28 changes: 19 additions & 9 deletions ext/tls/tls_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
//! key lookup can handle closing one end of the pair, in which case they will just
//! attempt to clean up the associated resources.
use deno_core::anyhow::anyhow;
use deno_core::error::AnyError;
use deno_core::futures::future::poll_fn;
use deno_core::futures::future::Either;
use deno_core::futures::FutureExt;
Expand All @@ -33,7 +31,19 @@ use tokio::sync::oneshot;
use webpki::types::CertificateDer;
use webpki::types::PrivateKeyDer;

type ErrorType = Rc<AnyError>;
#[derive(Debug, thiserror::Error)]
pub enum TlsKeyError {
#[error(transparent)]
Rustls(#[from] rustls::Error),
#[error("Failed: {0}")]
Failed(ErrorType),
#[error(transparent)]
JoinError(#[from] tokio::task::JoinError),
#[error(transparent)]
RecvError(#[from] tokio::sync::broadcast::error::RecvError),
}

type ErrorType = Arc<Box<str>>;

/// A TLS certificate/private key pair.
/// see https://docs.rs/rustls-pki-types/latest/rustls_pki_types/#cloning-private-keys
Expand Down Expand Up @@ -114,7 +124,7 @@ impl TlsKeyResolver {
&self,
sni: String,
alpn: Vec<Vec<u8>>,
) -> Result<Arc<ServerConfig>, AnyError> {
) -> Result<Arc<ServerConfig>, TlsKeyError> {
let key = self.resolve(sni).await?;

let mut tls_config = ServerConfig::builder()
Expand Down Expand Up @@ -183,7 +193,7 @@ impl TlsKeyResolver {
pub fn resolve(
&self,
sni: String,
) -> impl Future<Output = Result<TlsKey, AnyError>> {
) -> impl Future<Output = Result<TlsKey, TlsKeyError>> {
let mut cache = self.inner.cache.borrow_mut();
let mut recv = match cache.get(&sni) {
None => {
Expand All @@ -194,7 +204,7 @@ impl TlsKeyResolver {
}
Some(TlsKeyState::Resolving(recv)) => recv.resubscribe(),
Some(TlsKeyState::Resolved(res)) => {
return Either::Left(ready(res.clone().map_err(|_| anyhow!("Failed"))));
return Either::Left(ready(res.clone().map_err(TlsKeyError::Failed)));
}
};
drop(cache);
Expand All @@ -212,7 +222,7 @@ impl TlsKeyResolver {
// Someone beat us to it
}
}
res.map_err(|_| anyhow!("Failed"))
res.map_err(TlsKeyError::Failed)
});
Either::Right(async move { handle.await? })
}
Expand Down Expand Up @@ -247,13 +257,13 @@ impl TlsKeyLookup {
}

/// Resolve a previously polled item.
pub fn resolve(&self, sni: String, res: Result<TlsKey, AnyError>) {
pub fn resolve(&self, sni: String, res: Result<TlsKey, String>) {
_ = self
.pending
.borrow_mut()
.remove(&sni)
.unwrap()
.send(res.map_err(Rc::new));
.send(res.map_err(|e| Arc::new(e.into_boxed_str())));
}
}

Expand Down
1 change: 1 addition & 0 deletions ext/websocket/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ pub fn create_ws_client_config(
TlsKeys::Null,
socket_use,
)
.map_err(|e| e.into())
}

/// Headers common to both http/1.1 and h2 requests.
Expand Down
13 changes: 13 additions & 0 deletions runtime/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use deno_core::error::AnyError;
use deno_core::serde_json;
use deno_core::url;
use deno_core::ModuleResolutionError;
use deno_tls::TlsError;
use std::env;
use std::error::Error;
use std::io;
Expand Down Expand Up @@ -153,12 +154,24 @@ pub fn get_nix_error_class(error: &nix::Error) -> &'static str {
}
}

fn get_tls_error_class(e: &TlsError) -> &'static str {
match e {
TlsError::Rustls(_) => "Error",
TlsError::UnableAddPemFileToCert(e) => get_io_error_class(e),
TlsError::CertInvalid => "InvalidData",
TlsError::CertsNotFound => "InvalidData",
TlsError::KeysNotFound => "InvalidData",
TlsError::KeyDecode => "InvalidData",
}
}

pub fn get_error_class_name(e: &AnyError) -> Option<&'static str> {
deno_core::error::get_custom_error_class(e)
.or_else(|| deno_webgpu::error::get_error_class_name(e))
.or_else(|| deno_web::get_error_class_name(e))
.or_else(|| deno_webstorage::get_not_supported_error_class_name(e))
.or_else(|| deno_websocket::get_network_error_class_name(e))
.or_else(|| e.downcast_ref::<TlsError>().map(get_tls_error_class))
.or_else(|| {
e.downcast_ref::<dlopen2::Error>()
.map(get_dlopen_error_class)
Expand Down