Skip to content

Commit

Permalink
Add new API to clean up OpenSSL threads.
Browse files Browse the repository at this point in the history
Signed-off-by: Norman Ashley <nashley@cisco.com>
  • Loading branch information
ashman-p committed Oct 23, 2024
1 parent 90030a4 commit 98bd132
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 60 deletions.
6 changes: 6 additions & 0 deletions src/common/common.c
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,12 @@ OQS_API void OQS_init(void) {
#endif
}

OQS_API void OQS_thread_stop(void) {
#if defined(OQS_USE_OPENSSL)
oqs_thread_stop();
#endif
}

OQS_API const char *OQS_version(void) {
return OQS_VERSION_TEXT;
}
Expand Down
98 changes: 54 additions & 44 deletions src/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
* SPDX-License-Identifier: MIT
*/


#ifndef OQS_COMMON_H
#define OQS_COMMON_H

#include <limits.h>
#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>
#include <stdlib.h>

#include <oqs/oqsconfig.h>

Expand All @@ -27,14 +26,15 @@ extern "C" {
* using OpenSSL functions when OQS_USE_OPENSSL is defined, and
* standard C library functions otherwise.
*/
#if (defined(OQS_USE_OPENSSL) || defined(OQS_DLOPEN_OPENSSL)) && defined(OPENSSL_VERSION_NUMBER)
#if (defined(OQS_USE_OPENSSL) || defined(OQS_DLOPEN_OPENSSL)) && \
defined(OPENSSL_VERSION_NUMBER)
#include <openssl/crypto.h>

/**
* Allocates memory of a given size.
* @param size The size of the memory to be allocated in bytes.
* @return A pointer to the allocated memory.
*/
* Allocates memory of a given size.
* @param size The size of the memory to be allocated in bytes.
* @return A pointer to the allocated memory.
*/
#define OQS_MEM_malloc(size) OPENSSL_malloc(size)

/**
Expand All @@ -43,7 +43,8 @@ extern "C" {
* @param element_size The size of each element in bytes.
* @return A pointer to the allocated memory.
*/
#define OQS_MEM_calloc(num_elements, element_size) OPENSSL_zalloc((num_elements) * (element_size))
#define OQS_MEM_calloc(num_elements, element_size) \
OPENSSL_zalloc((num_elements) * (element_size))
/**
* Duplicates a string.
* @param str The string to be duplicated.
Expand All @@ -52,10 +53,10 @@ extern "C" {
#define OQS_MEM_strdup(str) OPENSSL_strdup(str)
#else
/**
* Allocates memory of a given size.
* @param size The size of the memory to be allocated in bytes.
* @return A pointer to the allocated memory.
*/
* Allocates memory of a given size.
* @param size The size of the memory to be allocated in bytes.
* @return A pointer to the allocated memory.
*/
#define OQS_MEM_malloc(size) malloc(size) // IGNORE memory-check

/**
Expand All @@ -64,7 +65,8 @@ extern "C" {
* @param element_size The size of each element in bytes.
* @return A pointer to the allocated memory.
*/
#define OQS_MEM_calloc(num_elements, element_size) calloc(num_elements, element_size) // IGNORE memory-check
#define OQS_MEM_calloc(num_elements, element_size) \
calloc(num_elements, element_size) // IGNORE memory-check
/**
* Duplicates a string.
* @param str The string to be duplicated.
Expand All @@ -77,13 +79,14 @@ extern "C" {
* Macro for terminating the program if x is
* a null pointer.
*/
#define OQS_EXIT_IF_NULLPTR(x, loc) \
do { \
if ( (x) == (void*)0 ) { \
fprintf(stderr, "Unexpected NULL returned from %s API. Exiting.\n", loc); \
exit(EXIT_FAILURE); \
} \
} while (0)
#define OQS_EXIT_IF_NULLPTR(x, loc) \
do { \
if ((x) == (void *)0) { \
fprintf(stderr, "Unexpected NULL returned from %s API. Exiting.\n", \
loc); \
exit(EXIT_FAILURE); \
} \
} while (0)

/**
* This macro is intended to replace those assert()s
Expand All @@ -98,22 +101,24 @@ extern "C" {
*/
#ifdef OQS_USE_OPENSSL
#ifdef OPENSSL_NO_STDIO
#define OQS_OPENSSL_GUARD(x) \
do { \
if( 1 != (x) ) { \
fprintf(stderr, "Error return value from OpenSSL API: %d. Exiting.\n", x); \
exit(EXIT_FAILURE); \
} \
} while (0)
#define OQS_OPENSSL_GUARD(x) \
do { \
if (1 != (x)) { \
fprintf(stderr, "Error return value from OpenSSL API: %d. Exiting.\n", \
x); \
exit(EXIT_FAILURE); \
} \
} while (0)
#else // OPENSSL_NO_STDIO
#define OQS_OPENSSL_GUARD(x) \
do { \
if( 1 != (x) ) { \
fprintf(stderr, "Error return value from OpenSSL API: %d. Exiting.\n", x); \
OSSL_FUNC(ERR_print_errors_fp)(stderr); \
exit(EXIT_FAILURE); \
} \
} while (0)
#define OQS_OPENSSL_GUARD(x) \
do { \
if (1 != (x)) { \
fprintf(stderr, "Error return value from OpenSSL API: %d. Exiting.\n", \
x); \
OSSL_FUNC(ERR_print_errors_fp)(stderr); \
exit(EXIT_FAILURE); \
} \
} while (0)
#endif // OPENSSL_NO_STDIO
#endif // OQS_USE_OPENSSL

Expand All @@ -123,13 +128,13 @@ extern "C" {
* only handle values up to INT_MAX for those parameters.
* This macro is a temporary workaround for such functions.
*/
#define SIZE_T_TO_INT_OR_EXIT(size_t_var_name, int_var_name) \
int int_var_name = 0; \
if (size_t_var_name <= INT_MAX) { \
int_var_name = (int)size_t_var_name; \
} else { \
exit(EXIT_FAILURE); \
}
#define SIZE_T_TO_INT_OR_EXIT(size_t_var_name, int_var_name) \
int int_var_name = 0; \
if (size_t_var_name <= INT_MAX) { \
int_var_name = (int)size_t_var_name; \
} else { \
exit(EXIT_FAILURE); \
}

/**
* Defines which functions should be exposed outside the LibOQS library
Expand Down Expand Up @@ -213,6 +218,11 @@ OQS_API int OQS_CPU_has_extension(OQS_CPU_EXT ext);
*/
OQS_API void OQS_init(void);

/**
* This function stops and frees OpenSSL threads resources in the correct order
*/
OQS_API void OQS_thread_stop(void);

/**
* This function frees prefetched OpenSSL objects
*/
Expand Down Expand Up @@ -277,8 +287,8 @@ OQS_API void OQS_MEM_insecure_free(void *ptr);
* Allocates size bytes of uninitialized memory with a base pointer that is
* a multiple of alignment. Alignment must be a power of two and a multiple
* of sizeof(void *). Size must be a multiple of alignment.
* @note The allocated memory should be freed with `OQS_MEM_aligned_free` when it
* is no longer needed.
* @note The allocated memory should be freed with `OQS_MEM_aligned_free` when
* it is no longer needed.
*/
void *OQS_MEM_aligned_alloc(size_t alignment, size_t size);

Expand Down
36 changes: 23 additions & 13 deletions src/common/ossl_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,26 @@ VOID_FUNC(void, ERR_print_errors_fp, (FILE *fp), (fp))
VOID_FUNC(void, EVP_CIPHER_CTX_free, (EVP_CIPHER_CTX *c), (c))
FUNC(EVP_CIPHER_CTX *, EVP_CIPHER_CTX_new, (void), ())
FUNC(int, EVP_CIPHER_CTX_set_padding, (EVP_CIPHER_CTX *c, int pad), (c, pad))
FUNC(int, EVP_DigestFinalXOF, (EVP_MD_CTX *ctx, unsigned char *md, size_t len), (ctx, md, len))
FUNC(int, EVP_DigestFinal_ex, (EVP_MD_CTX *ctx, unsigned char *md, unsigned int *s), (ctx, md, s))
FUNC(int, EVP_DigestInit_ex, (EVP_MD_CTX *ctx, const EVP_MD *type, ENGINE *impl), (ctx, type, impl))
FUNC(int, EVP_DigestUpdate, (EVP_MD_CTX *ctx, const void *d, size_t cnt), (ctx, d, cnt))
FUNC(int, EVP_EncryptFinal_ex, (EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl), (ctx, out, outl))
FUNC(int, EVP_EncryptInit_ex, (EVP_CIPHER_CTX *ctx, const EVP_CIPHER *cipher, ENGINE *impl, const unsigned char *key, const unsigned char *iv),
FUNC(int, EVP_DigestFinalXOF, (EVP_MD_CTX *ctx, unsigned char *md, size_t len),
(ctx, md, len))
FUNC(int, EVP_DigestFinal_ex,
(EVP_MD_CTX *ctx, unsigned char *md, unsigned int *s), (ctx, md, s))
FUNC(int, EVP_DigestInit_ex,
(EVP_MD_CTX *ctx, const EVP_MD *type, ENGINE *impl), (ctx, type, impl))
FUNC(int, EVP_DigestUpdate, (EVP_MD_CTX *ctx, const void *d, size_t cnt),
(ctx, d, cnt))
FUNC(int, EVP_EncryptFinal_ex,
(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl), (ctx, out, outl))
FUNC(int, EVP_EncryptInit_ex,
(EVP_CIPHER_CTX *ctx, const EVP_CIPHER *cipher, ENGINE *impl,
const unsigned char *key, const unsigned char *iv),
(ctx, cipher, impl, key, iv))
FUNC(int, EVP_EncryptUpdate, (EVP_CIPHER_CTX *ctx, unsigned char *out,
int *outl, const unsigned char *in, int inl),
FUNC(int, EVP_EncryptUpdate,
(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl,
const unsigned char *in, int inl),
(ctx, out, outl, in, inl))
FUNC(int, EVP_MD_CTX_copy_ex, (EVP_MD_CTX *out, const EVP_MD_CTX *in), (out, in))
FUNC(int, EVP_MD_CTX_copy_ex, (EVP_MD_CTX *out, const EVP_MD_CTX *in),
(out, in))
VOID_FUNC(void, EVP_MD_CTX_free, (EVP_MD_CTX *ctx), (ctx))
FUNC(EVP_MD_CTX *, EVP_MD_CTX_new, (void), ())
FUNC(int, EVP_MD_CTX_reset, (EVP_MD_CTX *ctx), (ctx))
Expand All @@ -29,12 +38,12 @@ FUNC(const EVP_CIPHER *, EVP_aes_128_ctr, (void), ())
FUNC(const EVP_CIPHER *, EVP_aes_256_ecb, (void), ())
FUNC(const EVP_CIPHER *, EVP_aes_256_ctr, (void), ())
#if OPENSSL_VERSION_NUMBER >= 0x30000000L
FUNC(EVP_CIPHER *, EVP_CIPHER_fetch, (OSSL_LIB_CTX *ctx, const char *algorithm,
const char *properties),
FUNC(EVP_CIPHER *, EVP_CIPHER_fetch,
(OSSL_LIB_CTX *ctx, const char *algorithm, const char *properties),
(ctx, algorithm, properties))
VOID_FUNC(void, EVP_CIPHER_free, (EVP_CIPHER *cipher), (cipher))
FUNC(EVP_MD *, EVP_MD_fetch, (OSSL_LIB_CTX *ctx, const char *algorithm,
const char *properties),
FUNC(EVP_MD *, EVP_MD_fetch,
(OSSL_LIB_CTX *ctx, const char *algorithm, const char *properties),
(ctx, algorithm, properties))
VOID_FUNC(void, EVP_MD_free, (EVP_MD *md), (md))
#else
Expand All @@ -51,3 +60,4 @@ VOID_FUNC(void, OPENSSL_cleanse, (void *ptr, size_t len), (ptr, len))
FUNC(int, RAND_bytes, (unsigned char *buf, int num), (buf, num))
FUNC(int, RAND_poll, (void), ())
FUNC(int, RAND_status, (void), ())
VOID_FUNC(void, OPENSSL_thread_stop, (void), ())
6 changes: 6 additions & 0 deletions src/common/ossl_helpers.c
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ void oqs_ossl_destroy(void) {
#endif
}

void oqs_thread_stop(void) {
#if defined(OQS_USE_PTHREADS) && defined(OQS_USE_OPENSSL)
OSSL_FUNC(OPENSSL_thread_stop)();
#endif
}

const EVP_MD *oqs_sha256(void) {
#if OPENSSL_VERSION_NUMBER >= 0x30000000L
#if defined(OQS_USE_PTHREADS)
Expand Down
5 changes: 3 additions & 2 deletions src/common/ossl_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ extern "C" {
#if defined(OQS_USE_OPENSSL)
void oqs_ossl_destroy(void);

void oqs_thread_stop(void);

const EVP_MD *oqs_sha256(void);

const EVP_MD *oqs_sha384(void);
Expand All @@ -39,8 +41,7 @@ const EVP_CIPHER *oqs_aes_256_ctr(void);

#ifdef OQS_DLOPEN_OPENSSL

#define FUNC(ret, name, args, cargs) \
ret _oqs_ossl_##name args;
#define FUNC(ret, name, args, cargs) ret _oqs_ossl_##name args;
#define VOID_FUNC FUNC
#include "ossl_functions.h"
#undef VOID_FUNC
Expand Down
3 changes: 3 additions & 0 deletions tests/test_kem.c
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ struct thread_data {
void *test_wrapper(void *arg) {
struct thread_data *td = arg;
td->rc = kem_test_correctness(td->alg_name);
#if defined(OQS_USE_OPENSSL)
OQS_thread_stop();
#endif
return NULL;
}
#endif
Expand Down
3 changes: 3 additions & 0 deletions tests/test_sig.c
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ struct thread_data {
void *test_wrapper(void *arg) {
struct thread_data *td = arg;
td->rc = sig_test_correctness(td->alg_name);
#if defined(OQS_USE_OPENSSL)
OQS_thread_stop();
#endif
return NULL;
}
#endif
Expand Down
16 changes: 15 additions & 1 deletion tests/test_sig_stfl.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

#if OQS_USE_PTHREADS_IN_TESTS
#include <pthread.h>

static pthread_mutex_t *test_sk_lock = NULL;
static pthread_mutex_t *sk_lock = NULL;
#endif
Expand Down Expand Up @@ -990,6 +989,9 @@ void *test_query_key(void *arg) {
struct lock_test_data *td = arg;
printf("\n%s: Start Query Stateful Key info\n", __func__);
td->rc = sig_stfl_test_query_key(td->alg_name);
#if defined(OQS_USE_OPENSSL)
OQS_thread_stop();
#endif
printf("%s: End Query Stateful Key info\n\n", __func__);
return NULL;
}
Expand All @@ -998,6 +1000,9 @@ void *test_sig_gen(void *arg) {
struct lock_test_data *td = arg;
printf("\n%s: Start Generate Stateful Signature\n", __func__);
td->rc = sig_stfl_test_sig_gen(td->alg_name);
#if defined(OQS_USE_OPENSSL)
OQS_thread_stop();
#endif
printf("%s: End Generate Stateful Signature\n\n", __func__);
return NULL;
}
Expand All @@ -1006,19 +1011,28 @@ void *test_create_keys(void *arg) {
struct lock_test_data *td = arg;
printf("\n%s: Start Generate Keys\n", __func__);
td->rc = sig_stfl_test_secret_key_lock(td->alg_name, td->katfile);
#if defined(OQS_USE_OPENSSL)
OQS_thread_stop();
#endif
printf("%s: End Generate Stateful Keys\n\n", __func__);
return NULL;
}

void *test_correctness_wrapper(void *arg) {
struct thread_data *td = arg;
td->rc = sig_stfl_test_correctness(td->alg_name, td->katfile);
#if defined(OQS_USE_OPENSSL)
OQS_thread_stop();
#endif
return NULL;
}

void *test_secret_key_wrapper(void *arg) {
struct thread_data *td = arg;
td->rc = sig_stfl_test_secret_key(td->alg_name, td->katfile);
#if defined(OQS_USE_OPENSSL)
OQS_thread_stop();
#endif
return NULL;
}
#endif
Expand Down

0 comments on commit 98bd132

Please sign in to comment.