From 9161ad79d0731b23a80771ae0cf07878e6a1e63f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=F0=9F=90=99=F0=9F=A6=96=F0=9F=A6=84Nathan=20Lucyk?= Date: Tue, 5 Nov 2024 22:17:32 +0000 Subject: [PATCH 1/5] feat(bindings): enable multiple cert chain on config Adds ref counting and load_pem to cert chain. Makes use of newer C functions that default to app owned instead of lib owned. Config builder extended to make use of cert_chain instead of only creating certs. This will enable multiple cert_chains. Still work in progress. Need to implement C function to get all cert chain and keys on the config. Need to add unit tests. --- bindings/rust/s2n-tls/src/cert_chain.rs | 153 +++++++++++++++++++++++- bindings/rust/s2n-tls/src/config.rs | 27 +++++ tests/unit/s2n_config_test.c | 2 + tls/s2n_config.c | 90 ++++++++++++++ tls/s2n_internal.h | 9 ++ 5 files changed, 276 insertions(+), 5 deletions(-) diff --git a/bindings/rust/s2n-tls/src/cert_chain.rs b/bindings/rust/s2n-tls/src/cert_chain.rs index 007657d15f5..e2f1fc08eb5 100644 --- a/bindings/rust/s2n-tls/src/cert_chain.rs +++ b/bindings/rust/s2n-tls/src/cert_chain.rs @@ -20,6 +20,14 @@ impl CertificateChain<'_> { pub(crate) fn new() -> Result, Error> { unsafe { let ptr = s2n_cert_chain_and_key_new().into_result()?; + let context = Box::::default(); + let context = Box::into_raw(context) as *mut c_void; + + unsafe { + s2n_cert_chain_and_key_set_ctx(ptr.as_ptr(), context) + .into_result() + .unwrap(); + } Ok(CertificateChain { ptr, is_owned: true, @@ -28,6 +36,8 @@ impl CertificateChain<'_> { } } + /// This CertificateChain is not owned and will not increment the reference count. + /// When the rust instance is dropped it will not drop the pointer. pub(crate) unsafe fn from_ptr_reference<'a>( ptr: NonNull, ) -> CertificateChain<'a> { @@ -38,6 +48,27 @@ impl CertificateChain<'_> { } } + /// # Safety + /// + /// This CertificateChain _MUST_ have been initialized with the constructor. + /// Additionally, this does NOT increment the reference count, + /// so consider cloning the result if the source pointer is still + /// valid and usable afterwards. + pub(crate) unsafe fn from_owned_ptr_reference<'a>( + ptr: NonNull, + ) -> CertificateChain<'a> { + let cert_chain = CertificateChain { + ptr, + is_owned: true, + _lifetime: PhantomData, + }; + + // Check if the context can be retrieved. + // If it can't, this is not a valid CertificateChain. + cert_chain.context(); + cert_chain + } + pub fn iter(&self) -> CertificateChainIter<'_> { CertificateChainIter { idx: 0, @@ -75,6 +106,71 @@ impl CertificateChain<'_> { pub(crate) fn as_mut_ptr(&mut self) -> NonNull { self.ptr } + + /// Retrieve a reference to the [`CertificateChainContext`] stored on the CertificateChain. + pub(crate) fn context(&self) -> &CertificateChainContext { + let mut ctx = core::ptr::null_mut(); + unsafe { + ctx = s2n_cert_chain_and_key_get_ctx(self.ptr) + .into_result() + .unwrap(); + &*(ctx as *const CertificateChainContext) + } + } + + /// Retrieve a mutable reference to the [`CertificateChainContext`] stored on the CertificateChain. + pub(crate) fn context_mut(&mut self) -> &mut CertificateChainContext { + let mut ctx = core::ptr::null_mut(); + unsafe { + ctx = s2n_cert_chain_and_key_get_ctx(self.ptr.as_ptr()) + .into_result() + .unwrap(); + &mut *(ctx as *mut CertificateChainContext) + } + } + + pub fn load_pem(&mut self, certificate: &[u8], private_key: &[u8]) { + let certificate = CString::new(certificate).map_err(|_| Error::INVALID_INPUT)?; + let private_key = CString::new(private_key).map_err(|_| Error::INVALID_INPUT)?; + unsafe { + s2n_cert_chain_and_key_load_pem( + self.ptr.as_ptr(), + certificate.as_ptr(), + private_key.as_ptr(), + ) + .into_result() + }?; + Ok(self) + } + + + pub fn set_ocsp_data(&mut self, data: &[u8]) -> Result<&mut Self, Error> { + let size: u32 = data.len().try_into().map_err(|_| Error::INVALID_INPUT)?; + unsafe { + s2n_cert_chain_and_key_set_ocsp_data(self.ptr.as_ptr(), data.as_ptr(), size) + .into_result() + }?; + Ok(self) + } +} + +impl Clone for CertificateChain { + fn clone(&self) -> Self { + let context = self.context(); + + // Safety + // + // Using a relaxed ordering is alright here, as knowledge of the + // original reference prevents other threads from erroneously deleting + // the object. + // https://github.com/rust-lang/rust/blob/e012a191d768adeda1ee36a99ef8b92d51920154/library/alloc/src/sync.rs#L1329 + let _count = context.refcount.fetch_add(1, Ordering::Relaxed); + Self { + ptr: self.ptr, + is_owned: true, // clone only makes sense for owned + _lifetime: PhantomData, + } + } } // # Safety @@ -82,13 +178,46 @@ impl CertificateChain<'_> { // s2n_cert_chain_and_key objects can be sent across threads. unsafe impl Send for CertificateChain<'_> {} +/// # Safety +/// +/// Safety: All C methods that mutate the s2n_cert_chain are wrapped +/// in Rust methods that require a mutable reference. +unsafe impl Sync for CertificateChain<'_> {} + impl Drop for CertificateChain<'_> { fn drop(&mut self) { - if self.is_owned { - // ignore failures since there's not much we can do about it - unsafe { - let _ = s2n_cert_chain_and_key_free(self.ptr.as_ptr()).into_result(); - } + if !self.is_owned { + // not ours to cleanup + return; + } + let context = self.context_mut(); + let count = context.refcount.fetch_sub(1, Ordering::Release); + debug_assert!(count > 0, "refcount should not drop below 1 instance"); + + // only free the cert if this is the last instance + if count != 1 { + return; + } + + // Safety + // + // The use of Ordering and fence mirrors the `Arc` implementation in + // the standard library. + // + // This fence is needed to prevent reordering of use of the data and + // deletion of the data. Because it is marked `Release`, the decreasing + // of the reference count synchronizes with this `Acquire` fence. This + // means that use of the data happens before decreasing the reference + // count, which happens before this fence, which happens before the + // deletion of the data. + // https://github.com/rust-lang/rust/blob/e012a191d768adeda1ee36a99ef8b92d51920154/library/alloc/src/sync.rs#L1637 + std::sync::atomic::fence(Ordering::Acquire); + + unsafe { + // This is the last instance so free the context. + let context = Box::from_raw(context); + drop(context); + let _ = s2n_cert_chain_and_key_free(self.ptr.as_ptr()).into_result(); } } } @@ -152,3 +281,17 @@ impl<'a> Certificate<'a> { // // Certificates just reference data in the chain, so share the Send-ness of the chain. unsafe impl Send for Certificate<'_> {} + +pub(crate) struct CertificateChainContext { + refcount: AtomicUsize, +} + +impl Default for CertificateChainContext { + fn default() -> Self { + // The AtomicUsize is used to manually track the reference count of the CertificateChain. + // This mechanism is used to track when the CertificateChain object should be freed. + Self { + refcount: AtomicUsize::new(1), + } + } +} diff --git a/bindings/rust/s2n-tls/src/config.rs b/bindings/rust/s2n-tls/src/config.rs index c7e6353347e..a616c1a6694 100644 --- a/bindings/rust/s2n-tls/src/config.rs +++ b/bindings/rust/s2n-tls/src/config.rs @@ -155,6 +155,12 @@ impl Drop for Config { // This is the last instance so free the context. let context = Box::from_raw(context); drop(context); + // TODO: drop certs + // let mut certs = core::ptr::null_mut(); + // let _ = s2n_config_get_cert_chains(self.0.as_ptr(), &mut certs).into_result(); + // for(cert : certs) { + // drop(CertificateChain::from_owned_ptr_reference(cert)); + // } let _ = s2n_config_free(self.0.as_ptr()).into_result(); } @@ -277,6 +283,25 @@ impl Builder { Ok(self) } + /// Adds the CertificateChain to the Config. + /// Prefer to use this function over load_pem. + /// This function is not compatible with load_pem and will error if both are used + pub fn add_cert_chain(&mut self, mut cert_chain: CertificateChain) -> Result<&mut Self, Error> { + unsafe { + s2n_config_add_cert_chain_and_key_to_store( + self.as_mut_ptr(), + cert_chain.as_mut_ptr().as_ptr(), + ).into_result()?; + } + // Setting the cert chain on the config creates one additional reference + // so do not drop so prevent Rust from calling `drop()` at the end of this function. + mem::forget(cert_chain); + Ok(self) + } + + /// Creates a certificate chain and binds it to the config. + /// Prefer to use the newer add_cert_chain which enables multiple certs over this function. + /// This function is not compatible with add_cert_chain and will error if both are used pub fn load_pem(&mut self, certificate: &[u8], private_key: &[u8]) -> Result<&mut Self, Error> { let certificate = CString::new(certificate).map_err(|_| Error::INVALID_INPUT)?; let private_key = CString::new(private_key).map_err(|_| Error::INVALID_INPUT)?; @@ -395,6 +420,8 @@ impl Builder { /// Sets the OCSP data for the default certificate chain associated with the Config. /// /// Servers will send the data in response to OCSP stapling requests from clients. + /// + /// Prefer to use add_cert_chain with a CertificateChain that has OCSP data set over this function. // // NOTE: this modifies a certificate chain, NOT the Config itself. This is currently safe // because the certificate chain is set with s2n_config_add_cert_chain_and_key, which diff --git a/tests/unit/s2n_config_test.c b/tests/unit/s2n_config_test.c index e43930bb188..aae13f4a857 100644 --- a/tests/unit/s2n_config_test.c +++ b/tests/unit/s2n_config_test.c @@ -763,6 +763,8 @@ int main(int argc, char **argv) EXPECT_SUCCESS(s2n_connection_set_config(conn, config)); }; + // TODO: add Test s2n_config_get_cert_chains + /* Test loading system certs */ { /* s2n_config_load_system_certs safety */ diff --git a/tls/s2n_config.c b/tls/s2n_config.c index ccc1940c0ac..20abdb2715c 100644 --- a/tls/s2n_config.c +++ b/tls/s2n_config.c @@ -729,6 +729,96 @@ int s2n_config_set_cert_chain_and_key_defaults(struct s2n_config *config, return 0; } +/* Only used in the Rust bindings for cleanup */ +int s2n_config_get_cert_chains(struct s2n_config *config, + struct s2n_cert_chain_and_key ***cert_chains, + uint32_t *chain_count) +{ + POSIX_ENSURE_REF(config); + POSIX_ENSURE_REF(cert_chains); + POSIX_ENSURE_REF(chain_count); + *chain_count = 0; + *cert_chains = NULL; + uint32_t total_possible_chains = 0; + + /* Count all the certs to know how much max memory to allocate */ + for (int i = 0; i < S2N_CERT_TYPE_COUNT; i++) { + if (config->default_certs_by_type.certs[i] != NULL) { + total_possible_chains++; + } + } + if (config->domain_name_to_cert_map != NULL) { + uint32_t domain_count = 0; + POSIX_GUARD_RESULT(s2n_map_size(config->domain_name_to_cert_map, &domain_count)); + total_possible_chains += (domain_count * S2N_CERT_TYPE_COUNT); + } + + if (total_possible_chains == 0) { + return S2N_SUCCESS; + } + DEFER_CLEANUP(struct s2n_blob allocator = {0}, s2n_free); + POSIX_GUARD(s2n_alloc(&allocator, sizeof(struct s2n_cert_chain_and_key*) * total_possible_chains)); + // These two lines to try and fix cast from 'uint8_t *' (aka 'unsigned char *') to 'struct s2n_cert_chain_and_key **' increases required alignment from 1 to 8 [-Werror,-Wcast-align]GCC + POSIX_GUARD(s2n_blob_init(&allocator, allocator.data, sizeof(struct s2n_cert_chain_and_key*) * total_possible_chains)); + POSIX_GUARD(s2n_blob_zero(&allocator)); + *cert_chains = (struct s2n_cert_chain_and_key**)allocator.data; + + + for (int i = 0; i < S2N_CERT_TYPE_COUNT; i++) { + if (config->default_certs_by_type.certs[i] != NULL) { + (*cert_chains)[*chain_count] = config->default_certs_by_type.certs[i]; + (*chain_count)++; + } + } + if (config->domain_name_to_cert_map != NULL) { + struct s2n_map_iterator iter = {0}; + POSIX_GUARD_RESULT(s2n_map_iterator_init(&iter, config->domain_name_to_cert_map)); + + while (s2n_map_iterator_has_next(&iter)) { + struct s2n_blob value = {0}; + POSIX_GUARD_RESULT(s2n_map_iterator_next(&iter, &value)); + + struct certs_by_type *domain_certs = (void *)value.data; + for (int i = 0; i < S2N_CERT_TYPE_COUNT; i++) { + if (domain_certs->certs[i] != NULL) { + bool duplicate = false; + for (uint32_t j = 0; j < *chain_count; j++) { + if ((*cert_chains)[j] == domain_certs->certs[i]) { + duplicate = true; + break; + } + } + + if (!duplicate) { + (*cert_chains)[*chain_count] = domain_certs->certs[i]; + (*chain_count)++; + } + } + } + } + } + + /* If we found fewer chains than allocated, reallocate to the exact size */ + if (*chain_count < total_possible_chains) { + DEFER_CLEANUP(struct s2n_blob right_sized_allocator = {0}, s2n_free); + POSIX_GUARD(s2n_alloc(&right_sized_allocator, sizeof(struct s2n_cert_chain_and_key*) * (*chain_count))); + POSIX_GUARD(s2n_blob_init(&right_sized_allocator, right_sized_allocator.data, + sizeof(struct s2n_cert_chain_and_key*) * (*chain_count))); + struct s2n_cert_chain_and_key **right_sized_chains = (struct s2n_cert_chain_and_key**)right_sized_allocator.data; + + POSIX_CHECKED_MEMCPY(right_sized_chains, *cert_chains, sizeof(struct s2n_cert_chain_and_key*) * (*chain_count)); + *cert_chains = right_sized_chains; + + /* Prevent double free of the memory */ + right_sized_allocator.data = NULL; + } + + /* Prevent double free of the memory */ + allocator.data = NULL; + + return S2N_SUCCESS; +} + int s2n_config_add_dhparams(struct s2n_config *config, const char *dhparams_pem) { DEFER_CLEANUP(struct s2n_stuffer dhparams_in_stuffer = { 0 }, s2n_stuffer_free); diff --git a/tls/s2n_internal.h b/tls/s2n_internal.h index 7c51f656a24..8ba4ab01031 100644 --- a/tls/s2n_internal.h +++ b/tls/s2n_internal.h @@ -59,3 +59,12 @@ S2N_PRIVATE_API int s2n_config_add_cert_chain(struct s2n_config *config, * is still waiting for encryption. */ S2N_PRIVATE_API int s2n_flush(struct s2n_connection *conn, s2n_blocked_status *blocked); + +/* + * Gets all the s2n_cert_chain_and_key set on the config + * + * This method is only useful for the Rust bindings that need to iterate over + * the + */ +S2N_PRIVATE_API int s2n_config_get_cert_chains(struct s2n_config *config, + struct s2n_cert_chain_and_key ***cert_chains, uint32_t *chain_count); From f4a253501e99e4c4be8d8eb1bbad700e48585300 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=F0=9F=90=99=F0=9F=A6=96=F0=9F=A6=84Nathan=20Lucyk?= Date: Fri, 15 Nov 2024 17:19:05 +0000 Subject: [PATCH 2/5] Implement get cert C code for rust bindings w/ unit tests --- bindings/rust/s2n-tls/src/cert_chain.rs | 22 +++++- bindings/rust/s2n-tls/src/config.rs | 32 +++++++-- tests/unit/s2n_config_test.c | 94 ++++++++++++++++++++++++- tls/s2n_config.c | 40 +++++------ 4 files changed, 153 insertions(+), 35 deletions(-) diff --git a/bindings/rust/s2n-tls/src/cert_chain.rs b/bindings/rust/s2n-tls/src/cert_chain.rs index e2f1fc08eb5..71f1ba3b3f3 100644 --- a/bindings/rust/s2n-tls/src/cert_chain.rs +++ b/bindings/rust/s2n-tls/src/cert_chain.rs @@ -6,6 +6,7 @@ use s2n_tls_sys::*; use std::{ marker::PhantomData, ptr::{self, NonNull}, + sync::atomic::{AtomicUsize, Ordering}, }; /// A CertificateChain represents a chain of X.509 certificates. @@ -52,8 +53,7 @@ impl CertificateChain<'_> { /// /// This CertificateChain _MUST_ have been initialized with the constructor. /// Additionally, this does NOT increment the reference count, - /// so consider cloning the result if the source pointer is still - /// valid and usable afterwards. + /// so consider cloning the result if the source pointer is still valid and usable afterwards. pub(crate) unsafe fn from_owned_ptr_reference<'a>( ptr: NonNull, ) -> CertificateChain<'a> { @@ -154,7 +154,7 @@ impl CertificateChain<'_> { } } -impl Clone for CertificateChain { +impl Clone for CertificateChain<'_> { fn clone(&self) -> Self { let context = self.context(); @@ -295,3 +295,19 @@ impl Default for CertificateChainContext { } } } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn clone_and_drop_update_ref_count() { + let original_cert = CertificateChain::new().unwrap(); + assert_eq!(original_cert.context().refcount.load(Ordering::Relaxed), 1); + + let second_cert = original_cert.clone(); + assert_eq!(original_cert.context().refcount.load(Ordering::Relaxed), 2); + + drop(second_cert); + assert_eq!(original_cert.context().refcount.load(Ordering::Relaxed), 1); + } +} \ No newline at end of file diff --git a/bindings/rust/s2n-tls/src/config.rs b/bindings/rust/s2n-tls/src/config.rs index a616c1a6694..29d7893ef90 100644 --- a/bindings/rust/s2n-tls/src/config.rs +++ b/bindings/rust/s2n-tls/src/config.rs @@ -155,12 +155,29 @@ impl Drop for Config { // This is the last instance so free the context. let context = Box::from_raw(context); drop(context); - // TODO: drop certs - // let mut certs = core::ptr::null_mut(); - // let _ = s2n_config_get_cert_chains(self.0.as_ptr(), &mut certs).into_result(); - // for(cert : certs) { - // drop(CertificateChain::from_owned_ptr_reference(cert)); - // } + + // Clean up the certificate chains + let mut cert_chains: *mut *mut s2n_cert_chain_and_key = std::ptr::null_mut(); + let mut chain_count: u32 = 0; + if s2n_config_get_cert_chains(self.0.as_ptr(), &mut cert_chains, &mut chain_count) + .into_result() + .is_ok() + { + if !cert_chains.is_null() && chain_count > 0 { + let cert_slice = std::slice::from_raw_parts(cert_chains, chain_count as usize); + for &cert_ptr in cert_slice { + if !cert_ptr.is_null() { + drop(CertificateChain::from_owned_ptr_reference(cert_ptr)); + } + } + + let _ = s2n_free_object( + &mut (cert_chains as *mut u8), + (chain_count as usize) * std::mem::size_of::<*mut s2n_cert_chain_and_key>(), + ) + .into_result(); + } + } let _ = s2n_config_free(self.0.as_ptr()).into_result(); } @@ -291,7 +308,8 @@ impl Builder { s2n_config_add_cert_chain_and_key_to_store( self.as_mut_ptr(), cert_chain.as_mut_ptr().as_ptr(), - ).into_result()?; + ) + .into_result()?; } // Setting the cert chain on the config creates one additional reference // so do not drop so prevent Rust from calling `drop()` at the end of this function. diff --git a/tests/unit/s2n_config_test.c b/tests/unit/s2n_config_test.c index aae13f4a857..18972339013 100644 --- a/tests/unit/s2n_config_test.c +++ b/tests/unit/s2n_config_test.c @@ -763,7 +763,99 @@ int main(int argc, char **argv) EXPECT_SUCCESS(s2n_connection_set_config(conn, config)); }; - // TODO: add Test s2n_config_get_cert_chains +/* Test s2n_config_get_cert_chains */ +{ + /* Test with no certificates */ + { + struct s2n_config *config = s2n_config_new(); + EXPECT_NOT_NULL(config); + + struct s2n_cert_chain_and_key **cert_chains = NULL; + uint32_t chain_count = 0; + + EXPECT_SUCCESS(s2n_config_get_cert_chains(config, &cert_chains, &chain_count)); + EXPECT_NULL(cert_chains); + EXPECT_EQUAL(chain_count, 0); + + EXPECT_SUCCESS(s2n_config_free(config)); + }; + + /* Test with a single certificate */ + { + struct s2n_config *config = s2n_config_new(); + EXPECT_NOT_NULL(config); + + struct s2n_cert_chain_and_key *chain_and_key = NULL; + EXPECT_SUCCESS(s2n_test_cert_chain_and_key_new(&chain_and_key, S2N_DEFAULT_TEST_CERT_CHAIN, S2N_DEFAULT_TEST_PRIVATE_KEY)); + EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(config, chain_and_key)); + + struct s2n_cert_chain_and_key **cert_chains = NULL; + uint32_t chain_count = 0; + + EXPECT_SUCCESS(s2n_config_get_cert_chains(config, &cert_chains, &chain_count)); + EXPECT_NOT_NULL(cert_chains); + EXPECT_EQUAL(chain_count, 1); + EXPECT_EQUAL(cert_chains[0], chain_and_key); + + EXPECT_SUCCESS(s2n_free_object((uint8_t **)&cert_chains, sizeof(struct s2n_cert_chain_and_key *) * chain_count)); + EXPECT_SUCCESS(s2n_config_free(config)); + EXPECT_SUCCESS(s2n_cert_chain_and_key_free(chain_and_key)); + }; + + /* Test with multiple certificates */ + { + struct s2n_config *config = s2n_config_new(); + EXPECT_NOT_NULL(config); + + struct s2n_cert_chain_and_key *chain1 = NULL; + struct s2n_cert_chain_and_key *chain2 = NULL; + EXPECT_SUCCESS(s2n_test_cert_chain_and_key_new(&chain1, S2N_DEFAULT_TEST_CERT_CHAIN, S2N_DEFAULT_TEST_PRIVATE_KEY)); + EXPECT_SUCCESS(s2n_test_cert_chain_and_key_new(&chain2, S2N_DEFAULT_ECDSA_TEST_CERT_CHAIN, S2N_DEFAULT_ECDSA_TEST_PRIVATE_KEY)); + EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(config, chain1)); + EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(config, chain2)); + + struct s2n_cert_chain_and_key **cert_chains = NULL; + uint32_t chain_count = 0; + + EXPECT_SUCCESS(s2n_config_get_cert_chains(config, &cert_chains, &chain_count)); + EXPECT_NOT_NULL(cert_chains); + EXPECT_EQUAL(chain_count, 2); + EXPECT_TRUE((cert_chains[0] == chain1 && cert_chains[1] == chain2) || + (cert_chains[0] == chain2 && cert_chains[1] == chain1)); + + EXPECT_SUCCESS(s2n_free_object((uint8_t **)&cert_chains, sizeof(struct s2n_cert_chain_and_key *) * chain_count)); + EXPECT_SUCCESS(s2n_config_free(config)); + EXPECT_SUCCESS(s2n_cert_chain_and_key_free(chain1)); + EXPECT_SUCCESS(s2n_cert_chain_and_key_free(chain2)); + }; + + /* Test with domain name to cert map */ + { + struct s2n_config *config = s2n_config_new(); + EXPECT_NOT_NULL(config); + + struct s2n_cert_chain_and_key *chain1 = NULL; + struct s2n_cert_chain_and_key *chain2 = NULL; + EXPECT_SUCCESS(s2n_test_cert_chain_and_key_new(&chain1, S2N_DEFAULT_TEST_CERT_CHAIN, S2N_DEFAULT_TEST_PRIVATE_KEY)); + EXPECT_SUCCESS(s2n_test_cert_chain_and_key_new(&chain2, S2N_DEFAULT_ECDSA_TEST_CERT_CHAIN, S2N_DEFAULT_ECDSA_TEST_PRIVATE_KEY)); + EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(config, chain1)); + EXPECT_SUCCESS(s2n_config_build_domain_name_to_cert_map(config, chain2)); + + struct s2n_cert_chain_and_key **cert_chains = NULL; + uint32_t chain_count = 0; + + EXPECT_SUCCESS(s2n_config_get_cert_chains(config, &cert_chains, &chain_count)); + EXPECT_NOT_NULL(cert_chains); + EXPECT_EQUAL(chain_count, 2); + EXPECT_TRUE((cert_chains[0] == chain1 && cert_chains[1] == chain2) || + (cert_chains[0] == chain2 && cert_chains[1] == chain1)); + + EXPECT_SUCCESS(s2n_free_object((uint8_t **)&cert_chains, sizeof(struct s2n_cert_chain_and_key *) * chain_count)); + EXPECT_SUCCESS(s2n_config_free(config)); + EXPECT_SUCCESS(s2n_cert_chain_and_key_free(chain1)); + EXPECT_SUCCESS(s2n_cert_chain_and_key_free(chain2)); + }; +} /* Test loading system certs */ { diff --git a/tls/s2n_config.c b/tls/s2n_config.c index 20abdb2715c..05e303e6fb5 100644 --- a/tls/s2n_config.c +++ b/tls/s2n_config.c @@ -756,18 +756,23 @@ int s2n_config_get_cert_chains(struct s2n_config *config, if (total_possible_chains == 0) { return S2N_SUCCESS; } + + /* Use a union to ensure proper alignment (casting gives error saying increases required alignment from 1 to 8 ) */ + union { + struct s2n_cert_chain_and_key **chains; + uint8_t *data; + } aligned_chains; + + /* Allocate memory for the array of pointers */ DEFER_CLEANUP(struct s2n_blob allocator = {0}, s2n_free); POSIX_GUARD(s2n_alloc(&allocator, sizeof(struct s2n_cert_chain_and_key*) * total_possible_chains)); - // These two lines to try and fix cast from 'uint8_t *' (aka 'unsigned char *') to 'struct s2n_cert_chain_and_key **' increases required alignment from 1 to 8 [-Werror,-Wcast-align]GCC - POSIX_GUARD(s2n_blob_init(&allocator, allocator.data, sizeof(struct s2n_cert_chain_and_key*) * total_possible_chains)); - POSIX_GUARD(s2n_blob_zero(&allocator)); - *cert_chains = (struct s2n_cert_chain_and_key**)allocator.data; + aligned_chains.data = allocator.data; + uint32_t cert_index = 0; for (int i = 0; i < S2N_CERT_TYPE_COUNT; i++) { if (config->default_certs_by_type.certs[i] != NULL) { - (*cert_chains)[*chain_count] = config->default_certs_by_type.certs[i]; - (*chain_count)++; + aligned_chains.chains[cert_index++] = config->default_certs_by_type.certs[i]; } } if (config->domain_name_to_cert_map != NULL) { @@ -782,36 +787,23 @@ int s2n_config_get_cert_chains(struct s2n_config *config, for (int i = 0; i < S2N_CERT_TYPE_COUNT; i++) { if (domain_certs->certs[i] != NULL) { bool duplicate = false; - for (uint32_t j = 0; j < *chain_count; j++) { - if ((*cert_chains)[j] == domain_certs->certs[i]) { + for (uint32_t j = 0; j < cert_index; j++) { + if (aligned_chains.chains[j] == domain_certs->certs[i]) { duplicate = true; break; } } if (!duplicate) { - (*cert_chains)[*chain_count] = domain_certs->certs[i]; - (*chain_count)++; + aligned_chains.chains[cert_index++] = domain_certs->certs[i]; } } } } } - /* If we found fewer chains than allocated, reallocate to the exact size */ - if (*chain_count < total_possible_chains) { - DEFER_CLEANUP(struct s2n_blob right_sized_allocator = {0}, s2n_free); - POSIX_GUARD(s2n_alloc(&right_sized_allocator, sizeof(struct s2n_cert_chain_and_key*) * (*chain_count))); - POSIX_GUARD(s2n_blob_init(&right_sized_allocator, right_sized_allocator.data, - sizeof(struct s2n_cert_chain_and_key*) * (*chain_count))); - struct s2n_cert_chain_and_key **right_sized_chains = (struct s2n_cert_chain_and_key**)right_sized_allocator.data; - - POSIX_CHECKED_MEMCPY(right_sized_chains, *cert_chains, sizeof(struct s2n_cert_chain_and_key*) * (*chain_count)); - *cert_chains = right_sized_chains; - - /* Prevent double free of the memory */ - right_sized_allocator.data = NULL; - } + *cert_chains = aligned_chains.chains; + *chain_count = cert_index; /* Prevent double free of the memory */ allocator.data = NULL; From eaa8e55f3f98554bd0afc269ae84a12020ffa03b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=F0=9F=90=99=F0=9F=A6=96=F0=9F=A6=84Nathan=20Lucyk?= Date: Sat, 16 Nov 2024 01:18:15 +0000 Subject: [PATCH 3/5] Add unit tests for cert chain and builder --- tests/unit/s2n_config_test.c | 7 ++++--- tls/s2n_config.c | 11 +++++++++++ tls/s2n_internal.h | 9 ++++++++- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/tests/unit/s2n_config_test.c b/tests/unit/s2n_config_test.c index 18972339013..6ad3d8166bd 100644 --- a/tests/unit/s2n_config_test.c +++ b/tests/unit/s2n_config_test.c @@ -776,6 +776,7 @@ int main(int argc, char **argv) EXPECT_SUCCESS(s2n_config_get_cert_chains(config, &cert_chains, &chain_count)); EXPECT_NULL(cert_chains); EXPECT_EQUAL(chain_count, 0); + EXPECT_SUCCESS(s2n_free_config_cert_chains(&cert_chains, chain_count)); EXPECT_SUCCESS(s2n_config_free(config)); }; @@ -797,7 +798,7 @@ int main(int argc, char **argv) EXPECT_EQUAL(chain_count, 1); EXPECT_EQUAL(cert_chains[0], chain_and_key); - EXPECT_SUCCESS(s2n_free_object((uint8_t **)&cert_chains, sizeof(struct s2n_cert_chain_and_key *) * chain_count)); + EXPECT_SUCCESS(s2n_free_config_cert_chains(&cert_chains, chain_count)); EXPECT_SUCCESS(s2n_config_free(config)); EXPECT_SUCCESS(s2n_cert_chain_and_key_free(chain_and_key)); }; @@ -823,7 +824,7 @@ int main(int argc, char **argv) EXPECT_TRUE((cert_chains[0] == chain1 && cert_chains[1] == chain2) || (cert_chains[0] == chain2 && cert_chains[1] == chain1)); - EXPECT_SUCCESS(s2n_free_object((uint8_t **)&cert_chains, sizeof(struct s2n_cert_chain_and_key *) * chain_count)); + EXPECT_SUCCESS(s2n_free_config_cert_chains(&cert_chains, chain_count)); EXPECT_SUCCESS(s2n_config_free(config)); EXPECT_SUCCESS(s2n_cert_chain_and_key_free(chain1)); EXPECT_SUCCESS(s2n_cert_chain_and_key_free(chain2)); @@ -850,7 +851,7 @@ int main(int argc, char **argv) EXPECT_TRUE((cert_chains[0] == chain1 && cert_chains[1] == chain2) || (cert_chains[0] == chain2 && cert_chains[1] == chain1)); - EXPECT_SUCCESS(s2n_free_object((uint8_t **)&cert_chains, sizeof(struct s2n_cert_chain_and_key *) * chain_count)); + EXPECT_SUCCESS(s2n_free_config_cert_chains(&cert_chains, chain_count)); EXPECT_SUCCESS(s2n_config_free(config)); EXPECT_SUCCESS(s2n_cert_chain_and_key_free(chain1)); EXPECT_SUCCESS(s2n_cert_chain_and_key_free(chain2)); diff --git a/tls/s2n_config.c b/tls/s2n_config.c index 05e303e6fb5..80706727330 100644 --- a/tls/s2n_config.c +++ b/tls/s2n_config.c @@ -811,6 +811,17 @@ int s2n_config_get_cert_chains(struct s2n_config *config, return S2N_SUCCESS; } +int s2n_free_config_cert_chains(struct s2n_cert_chain_and_key ***cert_chains, uint32_t chain_count) +{ + POSIX_ENSURE_REF(cert_chains); + if (*cert_chains != NULL) { + POSIX_GUARD(s2n_free_object((uint8_t **)cert_chains, sizeof(struct s2n_cert_chain_and_key *) * chain_count)); + *cert_chains = NULL; + } + + return S2N_SUCCESS; +} + int s2n_config_add_dhparams(struct s2n_config *config, const char *dhparams_pem) { DEFER_CLEANUP(struct s2n_stuffer dhparams_in_stuffer = { 0 }, s2n_stuffer_free); diff --git a/tls/s2n_internal.h b/tls/s2n_internal.h index 8ba4ab01031..5c592acd614 100644 --- a/tls/s2n_internal.h +++ b/tls/s2n_internal.h @@ -64,7 +64,14 @@ S2N_PRIVATE_API int s2n_flush(struct s2n_connection *conn, s2n_blocked_status *b * Gets all the s2n_cert_chain_and_key set on the config * * This method is only useful for the Rust bindings that need to iterate over - * the + * the cert chains for cleanup */ S2N_PRIVATE_API int s2n_config_get_cert_chains(struct s2n_config *config, struct s2n_cert_chain_and_key ***cert_chains, uint32_t *chain_count); + +/* + * Cleanup for memory needed for s2n_config_get_cert_chains + * + * This method is only used by the Rust bindings + */ +S2N_PRIVATE_API int s2n_free_config_cert_chains(struct s2n_cert_chain_and_key ***cert_chains, uint32_t chain_count); From 8a54993b503259894c9a9923cb9e9c260841f6f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=F0=9F=90=99=F0=9F=A6=96=F0=9F=A6=84Nathan=20Lucyk?= Date: Sat, 16 Nov 2024 01:19:10 +0000 Subject: [PATCH 4/5] Add bindings changes for unit tests --- bindings/rust/s2n-tls/src/cert_chain.rs | 57 +++++++------ bindings/rust/s2n-tls/src/config.rs | 13 ++- bindings/rust/s2n-tls/src/testing/s2n_tls.rs | 84 ++++++++++++++++++++ 3 files changed, 121 insertions(+), 33 deletions(-) diff --git a/bindings/rust/s2n-tls/src/cert_chain.rs b/bindings/rust/s2n-tls/src/cert_chain.rs index 71f1ba3b3f3..9916866b0f6 100644 --- a/bindings/rust/s2n-tls/src/cert_chain.rs +++ b/bindings/rust/s2n-tls/src/cert_chain.rs @@ -4,6 +4,7 @@ use crate::error::{Error, Fallible}; use s2n_tls_sys::*; use std::{ + ffi::{c_void, CString}, marker::PhantomData, ptr::{self, NonNull}, sync::atomic::{AtomicUsize, Ordering}, @@ -19,22 +20,22 @@ pub struct CertificateChain<'a> { impl CertificateChain<'_> { /// This allocates a new certificate chain from s2n. pub(crate) fn new() -> Result, Error> { + crate::init::init(); + let ptr = unsafe { s2n_cert_chain_and_key_new().into_result()? }; + + let context = Box::::default(); + let context = Box::into_raw(context) as *mut c_void; + unsafe { - let ptr = s2n_cert_chain_and_key_new().into_result()?; - let context = Box::::default(); - let context = Box::into_raw(context) as *mut c_void; - - unsafe { - s2n_cert_chain_and_key_set_ctx(ptr.as_ptr(), context) - .into_result() - .unwrap(); - } - Ok(CertificateChain { - ptr, - is_owned: true, - _lifetime: PhantomData, - }) + s2n_cert_chain_and_key_set_ctx(ptr.as_ptr(), context) + .into_result() + .unwrap(); } + Ok(CertificateChain { + ptr, + is_owned: true, + _lifetime: PhantomData, + }) } /// This CertificateChain is not owned and will not increment the reference count. @@ -64,7 +65,7 @@ impl CertificateChain<'_> { }; // Check if the context can be retrieved. - // If it can't, this is not a valid CertificateChain. + // If it can't, this is not an owned CertificateChain created through constructor. cert_chain.context(); cert_chain } @@ -109,27 +110,25 @@ impl CertificateChain<'_> { /// Retrieve a reference to the [`CertificateChainContext`] stored on the CertificateChain. pub(crate) fn context(&self) -> &CertificateChainContext { - let mut ctx = core::ptr::null_mut(); unsafe { - ctx = s2n_cert_chain_and_key_get_ctx(self.ptr) + let ctx = s2n_cert_chain_and_key_get_ctx(self.ptr.as_ptr()) .into_result() .unwrap(); - &*(ctx as *const CertificateChainContext) + &*(ctx.as_ptr() as *const CertificateChainContext) } } /// Retrieve a mutable reference to the [`CertificateChainContext`] stored on the CertificateChain. pub(crate) fn context_mut(&mut self) -> &mut CertificateChainContext { - let mut ctx = core::ptr::null_mut(); unsafe { - ctx = s2n_cert_chain_and_key_get_ctx(self.ptr.as_ptr()) + let ctx = s2n_cert_chain_and_key_get_ctx(self.ptr.as_ptr()) .into_result() .unwrap(); - &mut *(ctx as *mut CertificateChainContext) + &mut *(ctx.as_ptr() as *mut CertificateChainContext) } } - pub fn load_pem(&mut self, certificate: &[u8], private_key: &[u8]) { + pub fn load_pem(&mut self, certificate: &[u8], private_key: &[u8]) -> Result<&mut Self, Error> { let certificate = CString::new(certificate).map_err(|_| Error::INVALID_INPUT)?; let private_key = CString::new(private_key).map_err(|_| Error::INVALID_INPUT)?; unsafe { @@ -143,7 +142,6 @@ impl CertificateChain<'_> { Ok(self) } - pub fn set_ocsp_data(&mut self, data: &[u8]) -> Result<&mut Self, Error> { let size: u32 = data.len().try_into().map_err(|_| Error::INVALID_INPUT)?; unsafe { @@ -303,11 +301,18 @@ mod tests { fn clone_and_drop_update_ref_count() { let original_cert = CertificateChain::new().unwrap(); assert_eq!(original_cert.context().refcount.load(Ordering::Relaxed), 1); - + let second_cert = original_cert.clone(); assert_eq!(original_cert.context().refcount.load(Ordering::Relaxed), 2); - + drop(second_cert); assert_eq!(original_cert.context().refcount.load(Ordering::Relaxed), 1); } -} \ No newline at end of file + + // ensure the config context is send and sync + #[test] + fn context_send_sync_test() { + fn assert_send_sync() {} + assert_send_sync::(); + } +} diff --git a/bindings/rust/s2n-tls/src/config.rs b/bindings/rust/s2n-tls/src/config.rs index 29d7893ef90..84f27ea356a 100644 --- a/bindings/rust/s2n-tls/src/config.rs +++ b/bindings/rust/s2n-tls/src/config.rs @@ -5,10 +5,12 @@ use crate::renegotiate::RenegotiateCallback; use crate::{ callbacks::*, + cert_chain::CertificateChain, enums::*, error::{Error, Fallible}, security, }; +use core::mem::{self}; use core::{convert::TryInto, ptr::NonNull}; use s2n_tls_sys::*; use std::{ @@ -166,16 +168,13 @@ impl Drop for Config { if !cert_chains.is_null() && chain_count > 0 { let cert_slice = std::slice::from_raw_parts(cert_chains, chain_count as usize); for &cert_ptr in cert_slice { - if !cert_ptr.is_null() { - drop(CertificateChain::from_owned_ptr_reference(cert_ptr)); + if let Some(non_null_ptr) = NonNull::new(cert_ptr) { + drop(CertificateChain::from_owned_ptr_reference(non_null_ptr)); } } - let _ = s2n_free_object( - &mut (cert_chains as *mut u8), - (chain_count as usize) * std::mem::size_of::<*mut s2n_cert_chain_and_key>(), - ) - .into_result(); + let _ = + s2n_free_config_cert_chains(&mut cert_chains, chain_count).into_result(); } } diff --git a/bindings/rust/s2n-tls/src/testing/s2n_tls.rs b/bindings/rust/s2n-tls/src/testing/s2n_tls.rs index 1d5d920c384..380572fd668 100644 --- a/bindings/rust/s2n-tls/src/testing/s2n_tls.rs +++ b/bindings/rust/s2n-tls/src/testing/s2n_tls.rs @@ -5,6 +5,7 @@ mod tests { use crate::{ callbacks::{ClientHelloCallback, ConnectionFuture, ConnectionFutureResult}, + cert_chain::CertificateChain, enums::ClientAuthType, error::ErrorType, testing::{self, client_hello::*, Error, Result, *}, @@ -114,6 +115,89 @@ mod tests { Ok(()) } + #[test] + fn config_builder_fails_when_mixing_ownership() { + let first_prefix = concat!( + env!("CARGO_MANIFEST_DIR"), + "/../../../tests/pems/rsa_4096_sha512_client_" + ); + let first_keypair = CertKeyPair::from(first_prefix, "cert", "key", "cert"); + let second_prefix = concat!( + env!("CARGO_MANIFEST_DIR"), + "/../../../tests/pems/rsa_4096_sha256_client_" + ); + let second_keypair = CertKeyPair::from(second_prefix, "cert", "key", "cert"); + + let mut cert_chain = CertificateChain::new().unwrap(); + cert_chain + .load_pem(first_keypair.cert(), first_keypair.key()) + .expect("Unable to load pem into cert chain"); + + let mut builder = Builder::new(); + builder + .add_cert_chain(cert_chain) + .expect("Unable to add cert chain to builder"); + + // load_pem and add_cert_chain cannot mix on the builder due to ownership so we should get an error + let result = builder.load_pem(second_keypair.cert(), second_keypair.key()); + assert!(result.is_err()); + } + + #[test] + fn config_builder_accepts_multiple_cert_chains() { + let first_prefix = concat!( + env!("CARGO_MANIFEST_DIR"), + "/../../../tests/pems/rsa_4096_sha512_client_" + ); + let first_keypair = CertKeyPair::from(first_prefix, "cert", "key", "cert"); + let second_prefix = concat!( + env!("CARGO_MANIFEST_DIR"), + "/../../../tests/pems/rsa_4096_sha256_client_" + ); + let second_keypair = CertKeyPair::from(second_prefix, "cert", "key", "cert"); + + let mut first_cert_chain = CertificateChain::new().unwrap(); + first_cert_chain + .load_pem(first_keypair.cert(), first_keypair.key()) + .expect("Unable to load pem into cert chain"); + let mut second_cert_chain = CertificateChain::new().unwrap(); + second_cert_chain + .load_pem(second_keypair.cert(), second_keypair.key()) + .expect("Unable to load pem into cert chain"); + + let mut builder = Builder::new(); + + assert!(builder.add_cert_chain(first_cert_chain).is_ok()); + assert!(builder.add_cert_chain(second_cert_chain).is_ok()); + let _ = builder + .build() + .expect("Failed to build config with multiple cert chains"); + } + + #[test] + fn cert_chain_sharable_by_configs() -> Result<(), Error> { + let keypair = CertKeyPair::default(); + let mut cert_chain = CertificateChain::new().unwrap(); + cert_chain + .load_pem(keypair.cert(), keypair.key()) + .expect("Unable to load pem into cert chain"); + let cloned_cert_chain = cert_chain.clone(); + + let mut first_builder = Builder::new(); + first_builder + .add_cert_chain(cert_chain) + .expect("Failed to add cert chain to first config"); + let mut second_builder = Builder::new(); + second_builder + .add_cert_chain(cloned_cert_chain) + .expect("Failed to add cert chain to second config"); + + first_builder.build().expect("Unable to build first config"); + second_builder.build().expect("Unable to build second config"); + + Ok(()) + } + #[test] fn connnection_waker() { let config = build_config(&security::DEFAULT_TLS13).unwrap(); From 7622f5205c219ffa79a7e2fcbe9194250c158e8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=F0=9F=90=99=F0=9F=A6=96=F0=9F=A6=84Nathan=20Lucyk?= Date: Mon, 18 Nov 2024 19:07:22 +0000 Subject: [PATCH 5/5] typo fix --- bindings/rust/s2n-tls/src/cert_chain.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bindings/rust/s2n-tls/src/cert_chain.rs b/bindings/rust/s2n-tls/src/cert_chain.rs index 9916866b0f6..5638c373288 100644 --- a/bindings/rust/s2n-tls/src/cert_chain.rs +++ b/bindings/rust/s2n-tls/src/cert_chain.rs @@ -309,7 +309,7 @@ mod tests { assert_eq!(original_cert.context().refcount.load(Ordering::Relaxed), 1); } - // ensure the config context is send and sync + // ensure the CertificateChainContext is send and sync #[test] fn context_send_sync_test() { fn assert_send_sync() {}