From 98bd132c64f4c9447c7a3ba1ce4891311f5b933e Mon Sep 17 00:00:00 2001 From: Norman Ashley Date: Wed, 23 Oct 2024 19:48:30 +0000 Subject: [PATCH] Add new API to clean up OpenSSL threads. Signed-off-by: Norman Ashley --- src/common/common.c | 6 +++ src/common/common.h | 98 ++++++++++++++++++++----------------- src/common/ossl_functions.h | 36 +++++++++----- src/common/ossl_helpers.c | 6 +++ src/common/ossl_helpers.h | 5 +- tests/test_kem.c | 3 ++ tests/test_sig.c | 3 ++ tests/test_sig_stfl.c | 16 +++++- 8 files changed, 113 insertions(+), 60 deletions(-) diff --git a/src/common/common.c b/src/common/common.c index 78d0dcb24..cb70d558c 100644 --- a/src/common/common.c +++ b/src/common/common.c @@ -235,6 +235,12 @@ OQS_API void OQS_init(void) { #endif } +OQS_API void OQS_thread_stop(void) { +#if defined(OQS_USE_OPENSSL) + oqs_thread_stop(); +#endif +} + OQS_API const char *OQS_version(void) { return OQS_VERSION_TEXT; } diff --git a/src/common/common.h b/src/common/common.h index b15e244a3..dcf15fd7b 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -5,14 +5,13 @@ * SPDX-License-Identifier: MIT */ - #ifndef OQS_COMMON_H #define OQS_COMMON_H #include #include -#include #include +#include #include @@ -27,14 +26,15 @@ extern "C" { * using OpenSSL functions when OQS_USE_OPENSSL is defined, and * standard C library functions otherwise. */ -#if (defined(OQS_USE_OPENSSL) || defined(OQS_DLOPEN_OPENSSL)) && defined(OPENSSL_VERSION_NUMBER) +#if (defined(OQS_USE_OPENSSL) || defined(OQS_DLOPEN_OPENSSL)) && \ + defined(OPENSSL_VERSION_NUMBER) #include /** -* Allocates memory of a given size. -* @param size The size of the memory to be allocated in bytes. -* @return A pointer to the allocated memory. -*/ + * Allocates memory of a given size. + * @param size The size of the memory to be allocated in bytes. + * @return A pointer to the allocated memory. + */ #define OQS_MEM_malloc(size) OPENSSL_malloc(size) /** @@ -43,7 +43,8 @@ extern "C" { * @param element_size The size of each element in bytes. * @return A pointer to the allocated memory. */ -#define OQS_MEM_calloc(num_elements, element_size) OPENSSL_zalloc((num_elements) * (element_size)) +#define OQS_MEM_calloc(num_elements, element_size) \ + OPENSSL_zalloc((num_elements) * (element_size)) /** * Duplicates a string. * @param str The string to be duplicated. @@ -52,10 +53,10 @@ extern "C" { #define OQS_MEM_strdup(str) OPENSSL_strdup(str) #else /** -* Allocates memory of a given size. -* @param size The size of the memory to be allocated in bytes. -* @return A pointer to the allocated memory. -*/ + * Allocates memory of a given size. + * @param size The size of the memory to be allocated in bytes. + * @return A pointer to the allocated memory. + */ #define OQS_MEM_malloc(size) malloc(size) // IGNORE memory-check /** @@ -64,7 +65,8 @@ extern "C" { * @param element_size The size of each element in bytes. * @return A pointer to the allocated memory. */ -#define OQS_MEM_calloc(num_elements, element_size) calloc(num_elements, element_size) // IGNORE memory-check +#define OQS_MEM_calloc(num_elements, element_size) \ + calloc(num_elements, element_size) // IGNORE memory-check /** * Duplicates a string. * @param str The string to be duplicated. @@ -77,13 +79,14 @@ extern "C" { * Macro for terminating the program if x is * a null pointer. */ -#define OQS_EXIT_IF_NULLPTR(x, loc) \ - do { \ - if ( (x) == (void*)0 ) { \ - fprintf(stderr, "Unexpected NULL returned from %s API. Exiting.\n", loc); \ - exit(EXIT_FAILURE); \ - } \ - } while (0) +#define OQS_EXIT_IF_NULLPTR(x, loc) \ + do { \ + if ((x) == (void *)0) { \ + fprintf(stderr, "Unexpected NULL returned from %s API. Exiting.\n", \ + loc); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) /** * This macro is intended to replace those assert()s @@ -98,22 +101,24 @@ extern "C" { */ #ifdef OQS_USE_OPENSSL #ifdef OPENSSL_NO_STDIO -#define OQS_OPENSSL_GUARD(x) \ - do { \ - if( 1 != (x) ) { \ - fprintf(stderr, "Error return value from OpenSSL API: %d. Exiting.\n", x); \ - exit(EXIT_FAILURE); \ - } \ - } while (0) +#define OQS_OPENSSL_GUARD(x) \ + do { \ + if (1 != (x)) { \ + fprintf(stderr, "Error return value from OpenSSL API: %d. Exiting.\n", \ + x); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) #else // OPENSSL_NO_STDIO -#define OQS_OPENSSL_GUARD(x) \ - do { \ - if( 1 != (x) ) { \ - fprintf(stderr, "Error return value from OpenSSL API: %d. Exiting.\n", x); \ - OSSL_FUNC(ERR_print_errors_fp)(stderr); \ - exit(EXIT_FAILURE); \ - } \ - } while (0) +#define OQS_OPENSSL_GUARD(x) \ + do { \ + if (1 != (x)) { \ + fprintf(stderr, "Error return value from OpenSSL API: %d. Exiting.\n", \ + x); \ + OSSL_FUNC(ERR_print_errors_fp)(stderr); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) #endif // OPENSSL_NO_STDIO #endif // OQS_USE_OPENSSL @@ -123,13 +128,13 @@ extern "C" { * only handle values up to INT_MAX for those parameters. * This macro is a temporary workaround for such functions. */ -#define SIZE_T_TO_INT_OR_EXIT(size_t_var_name, int_var_name) \ - int int_var_name = 0; \ - if (size_t_var_name <= INT_MAX) { \ - int_var_name = (int)size_t_var_name; \ - } else { \ - exit(EXIT_FAILURE); \ - } +#define SIZE_T_TO_INT_OR_EXIT(size_t_var_name, int_var_name) \ + int int_var_name = 0; \ + if (size_t_var_name <= INT_MAX) { \ + int_var_name = (int)size_t_var_name; \ + } else { \ + exit(EXIT_FAILURE); \ + } /** * Defines which functions should be exposed outside the LibOQS library @@ -213,6 +218,11 @@ OQS_API int OQS_CPU_has_extension(OQS_CPU_EXT ext); */ OQS_API void OQS_init(void); +/** + * This function stops and frees OpenSSL threads resources in the correct order + */ +OQS_API void OQS_thread_stop(void); + /** * This function frees prefetched OpenSSL objects */ @@ -277,8 +287,8 @@ OQS_API void OQS_MEM_insecure_free(void *ptr); * Allocates size bytes of uninitialized memory with a base pointer that is * a multiple of alignment. Alignment must be a power of two and a multiple * of sizeof(void *). Size must be a multiple of alignment. - * @note The allocated memory should be freed with `OQS_MEM_aligned_free` when it - * is no longer needed. + * @note The allocated memory should be freed with `OQS_MEM_aligned_free` when + * it is no longer needed. */ void *OQS_MEM_aligned_alloc(size_t alignment, size_t size); diff --git a/src/common/ossl_functions.h b/src/common/ossl_functions.h index 438ec1faf..7e02898b3 100644 --- a/src/common/ossl_functions.h +++ b/src/common/ossl_functions.h @@ -10,17 +10,26 @@ VOID_FUNC(void, ERR_print_errors_fp, (FILE *fp), (fp)) VOID_FUNC(void, EVP_CIPHER_CTX_free, (EVP_CIPHER_CTX *c), (c)) FUNC(EVP_CIPHER_CTX *, EVP_CIPHER_CTX_new, (void), ()) FUNC(int, EVP_CIPHER_CTX_set_padding, (EVP_CIPHER_CTX *c, int pad), (c, pad)) -FUNC(int, EVP_DigestFinalXOF, (EVP_MD_CTX *ctx, unsigned char *md, size_t len), (ctx, md, len)) -FUNC(int, EVP_DigestFinal_ex, (EVP_MD_CTX *ctx, unsigned char *md, unsigned int *s), (ctx, md, s)) -FUNC(int, EVP_DigestInit_ex, (EVP_MD_CTX *ctx, const EVP_MD *type, ENGINE *impl), (ctx, type, impl)) -FUNC(int, EVP_DigestUpdate, (EVP_MD_CTX *ctx, const void *d, size_t cnt), (ctx, d, cnt)) -FUNC(int, EVP_EncryptFinal_ex, (EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl), (ctx, out, outl)) -FUNC(int, EVP_EncryptInit_ex, (EVP_CIPHER_CTX *ctx, const EVP_CIPHER *cipher, ENGINE *impl, const unsigned char *key, const unsigned char *iv), +FUNC(int, EVP_DigestFinalXOF, (EVP_MD_CTX *ctx, unsigned char *md, size_t len), + (ctx, md, len)) +FUNC(int, EVP_DigestFinal_ex, + (EVP_MD_CTX *ctx, unsigned char *md, unsigned int *s), (ctx, md, s)) +FUNC(int, EVP_DigestInit_ex, + (EVP_MD_CTX *ctx, const EVP_MD *type, ENGINE *impl), (ctx, type, impl)) +FUNC(int, EVP_DigestUpdate, (EVP_MD_CTX *ctx, const void *d, size_t cnt), + (ctx, d, cnt)) +FUNC(int, EVP_EncryptFinal_ex, + (EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl), (ctx, out, outl)) +FUNC(int, EVP_EncryptInit_ex, + (EVP_CIPHER_CTX *ctx, const EVP_CIPHER *cipher, ENGINE *impl, + const unsigned char *key, const unsigned char *iv), (ctx, cipher, impl, key, iv)) -FUNC(int, EVP_EncryptUpdate, (EVP_CIPHER_CTX *ctx, unsigned char *out, - int *outl, const unsigned char *in, int inl), +FUNC(int, EVP_EncryptUpdate, + (EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl, + const unsigned char *in, int inl), (ctx, out, outl, in, inl)) -FUNC(int, EVP_MD_CTX_copy_ex, (EVP_MD_CTX *out, const EVP_MD_CTX *in), (out, in)) +FUNC(int, EVP_MD_CTX_copy_ex, (EVP_MD_CTX *out, const EVP_MD_CTX *in), + (out, in)) VOID_FUNC(void, EVP_MD_CTX_free, (EVP_MD_CTX *ctx), (ctx)) FUNC(EVP_MD_CTX *, EVP_MD_CTX_new, (void), ()) FUNC(int, EVP_MD_CTX_reset, (EVP_MD_CTX *ctx), (ctx)) @@ -29,12 +38,12 @@ FUNC(const EVP_CIPHER *, EVP_aes_128_ctr, (void), ()) FUNC(const EVP_CIPHER *, EVP_aes_256_ecb, (void), ()) FUNC(const EVP_CIPHER *, EVP_aes_256_ctr, (void), ()) #if OPENSSL_VERSION_NUMBER >= 0x30000000L -FUNC(EVP_CIPHER *, EVP_CIPHER_fetch, (OSSL_LIB_CTX *ctx, const char *algorithm, - const char *properties), +FUNC(EVP_CIPHER *, EVP_CIPHER_fetch, + (OSSL_LIB_CTX *ctx, const char *algorithm, const char *properties), (ctx, algorithm, properties)) VOID_FUNC(void, EVP_CIPHER_free, (EVP_CIPHER *cipher), (cipher)) -FUNC(EVP_MD *, EVP_MD_fetch, (OSSL_LIB_CTX *ctx, const char *algorithm, - const char *properties), +FUNC(EVP_MD *, EVP_MD_fetch, + (OSSL_LIB_CTX *ctx, const char *algorithm, const char *properties), (ctx, algorithm, properties)) VOID_FUNC(void, EVP_MD_free, (EVP_MD *md), (md)) #else @@ -51,3 +60,4 @@ VOID_FUNC(void, OPENSSL_cleanse, (void *ptr, size_t len), (ptr, len)) FUNC(int, RAND_bytes, (unsigned char *buf, int num), (buf, num)) FUNC(int, RAND_poll, (void), ()) FUNC(int, RAND_status, (void), ()) +VOID_FUNC(void, OPENSSL_thread_stop, (void), ()) \ No newline at end of file diff --git a/src/common/ossl_helpers.c b/src/common/ossl_helpers.c index f3e505dcd..4d977ac75 100644 --- a/src/common/ossl_helpers.c +++ b/src/common/ossl_helpers.c @@ -87,6 +87,12 @@ void oqs_ossl_destroy(void) { #endif } +void oqs_thread_stop(void) { +#if defined(OQS_USE_PTHREADS) && defined(OQS_USE_OPENSSL) + OSSL_FUNC(OPENSSL_thread_stop)(); +#endif +} + const EVP_MD *oqs_sha256(void) { #if OPENSSL_VERSION_NUMBER >= 0x30000000L #if defined(OQS_USE_PTHREADS) diff --git a/src/common/ossl_helpers.h b/src/common/ossl_helpers.h index 3e1bc9ff2..7587d80f3 100644 --- a/src/common/ossl_helpers.h +++ b/src/common/ossl_helpers.h @@ -13,6 +13,8 @@ extern "C" { #if defined(OQS_USE_OPENSSL) void oqs_ossl_destroy(void); +void oqs_thread_stop(void); + const EVP_MD *oqs_sha256(void); const EVP_MD *oqs_sha384(void); @@ -39,8 +41,7 @@ const EVP_CIPHER *oqs_aes_256_ctr(void); #ifdef OQS_DLOPEN_OPENSSL -#define FUNC(ret, name, args, cargs) \ - ret _oqs_ossl_##name args; +#define FUNC(ret, name, args, cargs) ret _oqs_ossl_##name args; #define VOID_FUNC FUNC #include "ossl_functions.h" #undef VOID_FUNC diff --git a/tests/test_kem.c b/tests/test_kem.c index 3c6c70b70..2abd07c23 100644 --- a/tests/test_kem.c +++ b/tests/test_kem.c @@ -206,6 +206,9 @@ struct thread_data { void *test_wrapper(void *arg) { struct thread_data *td = arg; td->rc = kem_test_correctness(td->alg_name); +#if defined(OQS_USE_OPENSSL) + OQS_thread_stop(); +#endif return NULL; } #endif diff --git a/tests/test_sig.c b/tests/test_sig.c index a5246cc9d..e856dd92e 100644 --- a/tests/test_sig.c +++ b/tests/test_sig.c @@ -183,6 +183,9 @@ struct thread_data { void *test_wrapper(void *arg) { struct thread_data *td = arg; td->rc = sig_test_correctness(td->alg_name); +#if defined(OQS_USE_OPENSSL) + OQS_thread_stop(); +#endif return NULL; } #endif diff --git a/tests/test_sig_stfl.c b/tests/test_sig_stfl.c index f95e61369..e9fc93b80 100644 --- a/tests/test_sig_stfl.c +++ b/tests/test_sig_stfl.c @@ -22,7 +22,6 @@ #if OQS_USE_PTHREADS_IN_TESTS #include - static pthread_mutex_t *test_sk_lock = NULL; static pthread_mutex_t *sk_lock = NULL; #endif @@ -990,6 +989,9 @@ void *test_query_key(void *arg) { struct lock_test_data *td = arg; printf("\n%s: Start Query Stateful Key info\n", __func__); td->rc = sig_stfl_test_query_key(td->alg_name); +#if defined(OQS_USE_OPENSSL) + OQS_thread_stop(); +#endif printf("%s: End Query Stateful Key info\n\n", __func__); return NULL; } @@ -998,6 +1000,9 @@ void *test_sig_gen(void *arg) { struct lock_test_data *td = arg; printf("\n%s: Start Generate Stateful Signature\n", __func__); td->rc = sig_stfl_test_sig_gen(td->alg_name); +#if defined(OQS_USE_OPENSSL) + OQS_thread_stop(); +#endif printf("%s: End Generate Stateful Signature\n\n", __func__); return NULL; } @@ -1006,6 +1011,9 @@ void *test_create_keys(void *arg) { struct lock_test_data *td = arg; printf("\n%s: Start Generate Keys\n", __func__); td->rc = sig_stfl_test_secret_key_lock(td->alg_name, td->katfile); +#if defined(OQS_USE_OPENSSL) + OQS_thread_stop(); +#endif printf("%s: End Generate Stateful Keys\n\n", __func__); return NULL; } @@ -1013,12 +1021,18 @@ void *test_create_keys(void *arg) { void *test_correctness_wrapper(void *arg) { struct thread_data *td = arg; td->rc = sig_stfl_test_correctness(td->alg_name, td->katfile); +#if defined(OQS_USE_OPENSSL) + OQS_thread_stop(); +#endif return NULL; } void *test_secret_key_wrapper(void *arg) { struct thread_data *td = arg; td->rc = sig_stfl_test_secret_key(td->alg_name, td->katfile); +#if defined(OQS_USE_OPENSSL) + OQS_thread_stop(); +#endif return NULL; } #endif