From 333ff006018395a39b1910d840857545254b5863 Mon Sep 17 00:00:00 2001 From: Mitchell Dorrell Date: Mon, 11 Sep 2023 11:10:23 -0400 Subject: [PATCH] Add missing return-code checking in CC20-MT Check return codes in CC20-MT initialization function and implement cleanup upon encountering a failure condition. Also properly add missing cleanup code for OpenSSL 1.x in the chachapoly_free_mt() function. --- cipher-chachapoly-libcrypto-mt.c | 115 ++++++++++++++++++++++++++----- 1 file changed, 99 insertions(+), 16 deletions(-) diff --git a/cipher-chachapoly-libcrypto-mt.c b/cipher-chachapoly-libcrypto-mt.c index e6c480f1207..9c9b3dc056e 100644 --- a/cipher-chachapoly-libcrypto-mt.c +++ b/cipher-chachapoly-libcrypto-mt.c @@ -322,6 +322,15 @@ chachapoly_free_mt(struct chachapoly_ctx_mt * ctx_mt) EVP_MAC_CTX_free(ctx_mt->poly_ctx); ctx_mt->poly_ctx = NULL; } +#elif !defined(WITH_OPENSSL3) && defined(EVP_PKEY_POLY1305) + if (ctx_mt->md_ctx != NULL) { + EVP_MD_CTX_free(ctx_mt->md_ctx); + ctx_mt->md_ctx = NULL; + } + if (ctx_mt->pkey != NULL) { + EVP_PKEY_free(ctx_mt->pkey); + ctx_mt->pkey = NULL; + } #endif /* @@ -408,7 +417,6 @@ chachapoly_new_mt(u_int startseqnr, const u_char * key, u_int keylen) ctx_mt->batchID = startseqnr / NUMSTREAMS; #ifdef OPENSSL_HAVE_POLY_EVP - /* TODO: more error checks! */ EVP_MAC *mac = NULL; if ((mac = EVP_MAC_fetch(NULL, "POLY1305", NULL)) == NULL) { freezero(ctx_mt, sizeof(*ctx_mt)); @@ -421,32 +429,73 @@ chachapoly_new_mt(u_int startseqnr, const u_char * key, u_int keylen) return NULL; } #elif !defined(WITH_OPENSSL3) && defined(EVP_PKEY_POLY1305) - ctx_mt->md_ctx = EVP_MD_CTX_new(); - ctx_mt->pkey = EVP_PKEY_new_mac_key(EVP_PKEY_POLY1305, NULL, ctx_mt->zeros, - POLY1305_KEYLEN); - EVP_DigestSignInit(ctx_mt->md_ctx, &ctx_mt->poly_ctx, NULL, NULL, ctx_mt->pkey); + if ((ctx_mt->md_ctx = EVP_MD_CTX_new()) == NULL) { + freezero(ctx_mt, sizeof(*ctx_mt)); + explicit_bzero(&startseqnr, sizeof(startseqnr)); + return NULL; + } + if ((ctx_mt->pkey = EVP_PKEY_new_mac_key(EVP_PKEY_POLY1305, NULL, + ctx_mt->zeros, POLY1305_KEYLEN)) == NULL) { + EVP_MD_CTX_free(ctx_mt->md_ctx); + freezero(ctx_mt, sizeof(*ctx_mt)); + explicit_bzero(&startseqnr, sizeof(startseqnr)); + return NULL; + } + if (EVP_DigestSignInit(ctx_mt->md_ctx, &ctx_mt->poly_ctx, NULL, NULL, + ctx_mt->pkey) == 0) { + EVP_PKEY_free(ctx_mt->pkey); + EVP_MD_CTX_free(ctx_mt->md_ctx); + freezero(ctx_mt, sizeof(*ctx_mt)); + explicit_bzero(&startseqnr, sizeof(startseqnr)); + return NULL; + } #else ctx_mt->poly_ctx = NULL; #endif - /* TODO: add error checks */ - pthread_mutex_init(&ctx_mt->batchID_lock, NULL); - pthread_mutex_init(&(ctx_mt->tid_lock), NULL); - pthread_cond_init(&(ctx_mt->cond), NULL); + if (pthread_mutex_init(&ctx_mt->batchID_lock, NULL) != 0) + goto failinitmutexbatchIDlock; + if (pthread_mutex_init(&(ctx_mt->tid_lock), NULL) != 0) + goto failinitmutexTID; + if (pthread_cond_init(&(ctx_mt->cond), NULL) != 0) + goto failinitcond; ctx_mt->batches[ctx_mt->batchID % 2].batchID = ctx_mt->batchID; ctx_mt->batches[(ctx_mt->batchID + 1) % 2].batchID = ctx_mt->batchID + 1; - for (int i=0; i<2; i++) { - struct mt_keystream_batch * batch = &(ctx_mt->batches[i]); - pthread_mutex_init(&(batch->lock), NULL); + int mutexI; + for (mutexI = 0; mutexI < 2; mutexI++) { + struct mt_keystream_batch * batch = &(ctx_mt->batches[mutexI]); + if (pthread_mutex_init(&(batch->lock), NULL) != 0) + break; + /* + * TODO: these are likely to change, don't worry about error + * checking here + */ pthread_barrier_init(&(batch->bar_start), NULL, NUMTHREADS); pthread_barrier_init(&(batch->bar_end), NULL, NUMTHREADS); } + if (mutexI < 2) { + /* Backtrack starting with 'mutexI - 1' */ + for (mutexI--; mutexI >= 0; mutexI--) { + struct mt_keystream_batch * batch = + &(ctx_mt->batches[mutexI]); + pthread_mutex_destroy(&(batch->lock)); + } + goto failinitmutex; + } - for (int i=0; itds[i]), key); - ctx_mt->tds[i].batchID = ctx_mt->batchID; + int tDataI; + for (tDataI = 0; tDataI < NUMTHREADS; tDataI++) { + if (initialize_threadData(&(ctx_mt->tds[tDataI]), key) != 0) + break; + ctx_mt->tds[tDataI].batchID = ctx_mt->batchID; + } + if (tDataI < NUMTHREADS) { + /* Backtrack starting with 'tDataI - 1' */ + for (tDataI--; tDataI >= 0; tDataI--) + free_threadData(&(ctx_mt->tds[tDataI])); + goto failinitthreadData; } struct threadData * mainData; @@ -484,7 +533,8 @@ chachapoly_new_mt(u_int startseqnr, const u_char * key, u_int keylen) ctx_mt->adv_tid = ctx_mt->self_tid; int ret=0; /* Block workers from reading their thread IDs before we set them. */ - pthread_mutex_lock(&(ctx_mt->tid_lock)); + if (pthread_mutex_lock(&(ctx_mt->tid_lock)) != 0) + goto faillockprethreads; /* was reporting the TID using gettid() but it's not portable */ debug2_f("
", getpid(), pthread_self()); for (int i=0; itds[i])); + failinitthreadData: + for (int i = 0; i < 2; i++) { + struct mt_keystream_batch * batch = &(ctx_mt->batches[i]); + pthread_mutex_destroy(&(batch->lock)); + } + failinitmutex: + pthread_cond_destroy(&(ctx_mt->cond)); + failinitcond: + pthread_mutex_destroy(&(ctx_mt->tid_lock)); + failinitmutexTID: + pthread_mutex_destroy(&ctx_mt->batchID_lock); + failinitmutexbatchIDlock: +#ifdef OPENSSL_HAVE_POLY_EVP + if (ctx_mt->poly_ctx != NULL) { + EVP_MAC_CTX_free(ctx_mt->poly_ctx); + ctx_mt->poly_ctx = NULL; + } +#elif !defined(WITH_OPENSSL3) && defined(EVP_PKEY_POLY1305) + if (ctx_mt->md_ctx != NULL) { + EVP_MD_CTX_free(ctx_mt->md_ctx); + ctx_mt->md_ctx = NULL; + } + if (ctx_mt->pkey != NULL) { + EVP_PKEY_free(ctx_mt->pkey); + ctx_mt->pkey = NULL; + } +#endif + freezero(ctx_mt, sizeof(*ctx_mt)); + explicit_bzero(&startseqnr, sizeof(startseqnr)); + return NULL; } /* a fast method to XOR the keystream against the data */