diff --git a/mbedtls/src/ssl/config.rs b/mbedtls/src/ssl/config.rs index 7f2c5debc..3cb8bfa9d 100644 --- a/mbedtls/src/ssl/config.rs +++ b/mbedtls/src/ssl/config.rs @@ -28,10 +28,7 @@ use crate::private::UnsafeFrom; use crate::rng::RngCallback; use crate::ssl::context::HandshakeContext; use crate::ssl::ticket::TicketCallback; -use crate::x509::Certificate; -use crate::x509::Crl; -use crate::x509::Profile; -use crate::x509::VerifyError; +use crate::x509::{self, Certificate, Crl, Profile, VerifyCallback}; #[allow(non_camel_case_types)] #[derive(Eq, PartialEq, PartialOrd, Ord, Debug, Copy, Clone)] @@ -98,7 +95,6 @@ define!( } ); -callback!(VerifyCallback: Fn(&Certificate, i32, &mut VerifyError) -> Result<()>); #[cfg(feature = "std")] callback!(DbgCallback: Fn(i32, Cow<'_, str>, i32, Cow<'_, str>) -> ()); callback!(SniCallback: Fn(&mut HandshakeContext, &[u8]) -> Result<()>); @@ -343,40 +339,8 @@ impl Config { where F: VerifyCallback + 'static, { - unsafe extern "C" fn verify_callback( - closure: *mut c_void, - crt: *mut x509_crt, - depth: c_int, - flags: *mut u32, - ) -> c_int - where - F: VerifyCallback + 'static, - { - if crt.is_null() || closure.is_null() || flags.is_null() { - return ::mbedtls_sys::ERR_X509_BAD_INPUT_DATA; - } - - let cb = &mut *(closure as *mut F); - let crt: &mut Certificate = UnsafeFrom::from(crt).expect("valid certificate"); - - let mut verify_error = match VerifyError::from_bits(*flags) { - Some(ve) => ve, - // This can only happen if mbedtls is setting flags in VerifyError that are - // missing from our definition. - None => return ::mbedtls_sys::ERR_X509_BAD_INPUT_DATA, - }; - - let res = cb(crt, depth, &mut verify_error); - *flags = verify_error.bits(); - match res { - Ok(()) => 0, - Err(e) => e.to_int(), - } - } - - self.verify_callback = Some(Arc::new(cb)); - unsafe { ssl_conf_verify(self.into(), Some(verify_callback::), &**self.verify_callback.as_mut().unwrap() as *const _ as *mut c_void) } + unsafe { ssl_conf_verify(self.into(), Some(x509::verify_callback::), &**self.verify_callback.as_ref().unwrap() as *const _ as *mut c_void) } } pub fn set_ca_callback(&mut self, cb: F) diff --git a/mbedtls/src/x509/certificate.rs b/mbedtls/src/x509/certificate.rs index f5e7db533..56f2c3cb2 100644 --- a/mbedtls/src/x509/certificate.rs +++ b/mbedtls/src/x509/certificate.rs @@ -11,7 +11,7 @@ use core::iter::FromIterator; use core::ptr::NonNull; use mbedtls_sys::*; -use mbedtls_sys::types::raw_types::c_char; +use mbedtls_sys::types::raw_types::{c_char, c_void}; use crate::alloc::{List as MbedtlsList, Box as MbedtlsBox}; #[cfg(not(feature = "std"))] @@ -21,7 +21,7 @@ use crate::hash::Type as MdType; use crate::pk::Pk; use crate::private::UnsafeFrom; use crate::rng::Random; -use crate::x509::Time; +use crate::x509::{self, Time, VerifyCallback}; extern "C" { pub(crate) fn forward_mbedtls_calloc(n: mbedtls_sys::types::size_t, size: mbedtls_sys::types::size_t) -> *mut mbedtls_sys::types::raw_types::c_void; @@ -221,11 +221,21 @@ impl Certificate { MdType::from(self.inner.sig_md) } - pub fn verify( + fn verify_ex( chain: &MbedtlsList, trust_ca: &MbedtlsList, err_info: Option<&mut String>, - ) -> Result<()> { + cb: Option, + ) -> Result<()> + where + F: VerifyCallback + 'static, + { + let (f_vrfy, p_vrfy): (Option _>, _) = if let Some(cb) = cb.as_ref() { + (Some(x509::verify_callback::), + cb as *const _ as *mut c_void) + } else { + (None, ::core::ptr::null_mut()) + }; let mut flags = 0; let result = unsafe { x509_crt_verify( @@ -234,8 +244,8 @@ impl Certificate { ::core::ptr::null_mut(), ::core::ptr::null(), &mut flags, - None, - ::core::ptr::null_mut(), + f_vrfy, + p_vrfy, ) } .into_result(); @@ -253,6 +263,26 @@ impl Certificate { } result.map(|_| ()) } + + pub fn verify( + chain: &MbedtlsList, + trust_ca: &MbedtlsList, + err_info: Option<&mut String>, + ) -> Result<()> { + Self::verify_ex(chain, trust_ca, err_info, None::<&dyn VerifyCallback>) + } + + pub fn verify_with_callback( + chain: &MbedtlsList, + trust_ca: &MbedtlsList, + err_info: Option<&mut String>, + cb: F, + ) -> Result<()> + where + F: VerifyCallback + 'static, + { + Self::verify_ex(chain, trust_ca, err_info, Some(cb)) + } } // TODO @@ -719,10 +749,10 @@ impl Extend> for MbedtlsList { } } - #[cfg(test)] mod tests { use super::*; + use crate::x509::VerifyError; struct Test { key1: Pk, @@ -995,7 +1025,22 @@ cYp0bH/RcPTC0Z+ZaqSWMtfxRrk63MJQF9EXpDCdvQRcTMD9D85DJrMKn8aumq0M // try again after fixing the chain chain.push(c_int2.clone()); + + + let mut err_str = String::new(); + + let verify_callback = |_crt: &Certificate, _depth: i32, verify_flags: &mut VerifyError| { + verify_flags.remove(VerifyError::CERT_EXPIRED); + Ok(()) + }; + Certificate::verify(&chain, &mut c_root, None).unwrap(); + let res = Certificate::verify_with_callback(&chain, &mut c_root, Some(&mut err_str), verify_callback); + + match res { + Ok(()) => (), + Err(e) => assert!(false, "Failed to verify, error: {}, err_str: {}", e, err_str), + }; } { @@ -1005,6 +1050,19 @@ cYp0bH/RcPTC0Z+ZaqSWMtfxRrk63MJQF9EXpDCdvQRcTMD9D85DJrMKn8aumq0M chain.push(c_int2.clone()); Certificate::verify(&chain, &mut c_root, None).unwrap(); + + let verify_callback = |_crt: &Certificate, _depth: i32, verify_flags: &mut VerifyError| { + verify_flags.remove(VerifyError::CERT_EXPIRED); + Ok(()) + }; + + let mut err_str = String::new(); + let res = Certificate::verify_with_callback(&chain, &mut c_root, Some(&mut err_str), verify_callback); + + match res { + Ok(()) => (), + Err(e) => assert!(false, "Failed to verify, error: {}, err_str: {}", e, err_str), + }; } } diff --git a/mbedtls/src/x509/mod.rs b/mbedtls/src/x509/mod.rs index bb08072c6..fcb12763c 100644 --- a/mbedtls/src/x509/mod.rs +++ b/mbedtls/src/x509/mod.rs @@ -17,6 +17,8 @@ pub mod profile; // write_crt // write_csr +use crate::error::Error; +use crate::private::UnsafeFrom; #[doc(inline)] pub use self::certificate::Certificate; pub use self::crl::Crl; @@ -26,7 +28,7 @@ pub use self::csr::Csr; pub use self::profile::Profile; use mbedtls_sys::*; -use mbedtls_sys::types::raw_types::c_uint; +use mbedtls_sys::types::raw_types::{c_int, c_uint, c_void}; bitflags! { pub struct KeyUsage: c_uint { const DIGITAL_SIGNATURE = X509_KU_DIGITAL_SIGNATURE as c_uint; @@ -117,6 +119,39 @@ impl VerifyError { } } +callback!(VerifyCallback: Fn(&Certificate, i32, &mut VerifyError) -> Result<(), Error>); + +pub(crate) unsafe extern "C" fn verify_callback( + closure: *mut c_void, + crt: *mut x509_crt, + depth: c_int, + flags: *mut u32, +) -> c_int +where + F: VerifyCallback + 'static, +{ + if crt.is_null() || closure.is_null() || flags.is_null() { + return ::mbedtls_sys::ERR_X509_BAD_INPUT_DATA; + } + + let cb = &*(closure as *const F); + let crt: &mut Certificate = UnsafeFrom::from(crt).expect("valid certificate"); + + let mut verify_error = match VerifyError::from_bits(*flags) { + Some(ve) => ve, + // This can only happen if mbedtls is setting flags in VerifyError that are + // missing from our definition. + None => return ::mbedtls_sys::ERR_X509_BAD_INPUT_DATA, + }; + + let res = cb(crt, depth, &mut verify_error); + *flags = verify_error.bits(); + match res { + Ok(()) => 0, + Err(e) => e.to_int(), + } +} + /// A specific moment in time in UTC #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct Time {