Skip to content

Commit

Permalink
Define OPENSSL_NO_TLS_PHA, typedef PSK callback signatures (#1526)
Browse files Browse the repository at this point in the history
This commit defines a new configuration macro `OPENSSL_NO_TLS_PHA`. This
macro is meant to be used by consuming applications to detect the fact
that we (or other `libssl`s) don't support TLSv1.3's post-handshake
authentication (PHA). We then use this macro in place of
`OPENSSL_IS_AWSLC` to detect PHA support in our CPython patch.

We also enable PSK in our CPython patch and create two PSK-related
callback function signatures [defined by OpenSSL][1] and used by
CPython.

Finally, we fix the now-executed PSK tests in CPython.

[1]:
https://www.openssl.org/docs/man1.1.1/man3/SSL_psk_client_cb_func.html
  • Loading branch information
WillChilds-Klein authored Apr 19, 2024
1 parent 10a389e commit 0aebf17
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 47 deletions.
16 changes: 16 additions & 0 deletions include/openssl/opensslconf.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
#define OPENSSL_HEADER_OPENSSLCONF_H


#if defined(__cplusplus)
extern "C" {
#endif


#define OPENSSL_NO_ASYNC
#define OPENSSL_NO_BF
#define OPENSSL_NO_BLAKE2
Expand Down Expand Up @@ -48,6 +53,11 @@
#define OPENSSL_NO_MD2
#define OPENSSL_NO_MDC2
#define OPENSSL_NO_OCB

// OPENSSL_NO_TLS_PHA indicates lack of support for post-handshake
// authentication (PHA) in TLS >= 1.3
#define OPENSSL_NO_TLS_PHA

#define OPENSSL_NO_RC2
#define OPENSSL_NO_RC5
#define OPENSSL_NO_RFC3779
Expand All @@ -68,4 +78,10 @@
#define OPENSSL_NO_TS
#define OPENSSL_NO_WHIRLPOOL


#if defined(__cplusplus)
}
#endif


#endif // OPENSSL_HEADER_OPENSSLCONF_H
28 changes: 18 additions & 10 deletions include/openssl/ssl.h
Original file line number Diff line number Diff line change
Expand Up @@ -3584,6 +3584,14 @@ OPENSSL_EXPORT const SRTP_PROTECTION_PROFILE *SSL_get_selected_srtp_profile(
// PSK_MAX_PSK_LEN is the maximum supported length of a pre-shared key.
#define PSK_MAX_PSK_LEN 256

// SSL_psk_client_cb_func defines a function signature for the client callback.
typedef unsigned int (*SSL_psk_client_cb_func)(SSL *ssl,
const char *hint,
char *identity,
unsigned int max_identity_len,
uint8_t *psk,
unsigned max_psk_len);

// SSL_CTX_set_psk_client_callback sets the callback to be called when PSK is
// negotiated on the client. This callback must be set to enable PSK cipher
// suites on the client.
Expand All @@ -3596,17 +3604,19 @@ OPENSSL_EXPORT const SRTP_PROTECTION_PROFILE *SSL_get_selected_srtp_profile(
// The callback returns the length of the PSK or 0 if no suitable identity was
// found.
OPENSSL_EXPORT void SSL_CTX_set_psk_client_callback(
SSL_CTX *ctx, unsigned (*cb)(SSL *ssl, const char *hint, char *identity,
unsigned max_identity_len, uint8_t *psk,
unsigned max_psk_len));
SSL_CTX *ctx, SSL_psk_client_cb_func cb);

// SSL_set_psk_client_callback sets the callback to be called when PSK is
// negotiated on the client. This callback must be set to enable PSK cipher
// suites on the client. See also |SSL_CTX_set_psk_client_callback|.
OPENSSL_EXPORT void SSL_set_psk_client_callback(
SSL *ssl, unsigned (*cb)(SSL *ssl, const char *hint, char *identity,
unsigned max_identity_len, uint8_t *psk,
unsigned max_psk_len));
SSL *ssl, SSL_psk_client_cb_func cb);

// SSL_psk_server_cb_func defines a function signature for the server callback.
typedef unsigned (*SSL_psk_server_cb_func)(SSL *ssl,
const char *identity,
uint8_t *psk,
unsigned max_psk_len);

// SSL_CTX_set_psk_server_callback sets the callback to be called when PSK is
// negotiated on the server. This callback must be set to enable PSK cipher
Expand All @@ -3616,15 +3626,13 @@ OPENSSL_EXPORT void SSL_set_psk_client_callback(
// length at most |max_psk_len| to |psk| and return the number of bytes written
// or zero if the PSK identity is unknown.
OPENSSL_EXPORT void SSL_CTX_set_psk_server_callback(
SSL_CTX *ctx, unsigned (*cb)(SSL *ssl, const char *identity, uint8_t *psk,
unsigned max_psk_len));
SSL_CTX *ctx, SSL_psk_server_cb_func cb);

// SSL_set_psk_server_callback sets the callback to be called when PSK is
// negotiated on the server. This callback must be set to enable PSK cipher
// suites on the server. See also |SSL_CTX_set_psk_server_callback|.
OPENSSL_EXPORT void SSL_set_psk_server_callback(
SSL *ssl, unsigned (*cb)(SSL *ssl, const char *identity, uint8_t *psk,
unsigned max_psk_len));
SSL *ssl, SSL_psk_server_cb_func cb);

// SSL_CTX_use_psk_identity_hint configures server connections to advertise an
// identity hint of |identity_hint|. It returns one on success and zero on
Expand Down
19 changes: 4 additions & 15 deletions ssl/ssl_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2821,36 +2821,25 @@ const char *SSL_get_psk_identity(const SSL *ssl) {
return session->psk_identity.get();
}

void SSL_set_psk_client_callback(
SSL *ssl, unsigned (*cb)(SSL *ssl, const char *hint, char *identity,
unsigned max_identity_len, uint8_t *psk,
unsigned max_psk_len)) {
void SSL_set_psk_client_callback(SSL *ssl, SSL_psk_client_cb_func cb) {
if (!ssl->config) {
return;
}
ssl->config->psk_client_callback = cb;
}

void SSL_CTX_set_psk_client_callback(
SSL_CTX *ctx, unsigned (*cb)(SSL *ssl, const char *hint, char *identity,
unsigned max_identity_len, uint8_t *psk,
unsigned max_psk_len)) {
void SSL_CTX_set_psk_client_callback(SSL_CTX *ctx, SSL_psk_client_cb_func cb) {
ctx->psk_client_callback = cb;
}

void SSL_set_psk_server_callback(SSL *ssl,
unsigned (*cb)(SSL *ssl, const char *identity,
uint8_t *psk,
unsigned max_psk_len)) {
void SSL_set_psk_server_callback(SSL *ssl, SSL_psk_server_cb_func cb) {
if (!ssl->config) {
return;
}
ssl->config->psk_server_callback = cb;
}

void SSL_CTX_set_psk_server_callback(
SSL_CTX *ctx, unsigned (*cb)(SSL *ssl, const char *identity, uint8_t *psk,
unsigned max_psk_len)) {
void SSL_CTX_set_psk_server_callback(SSL_CTX *ctx, SSL_psk_server_cb_func cb) {
ctx->psk_server_callback = cb;
}

Expand Down
76 changes: 54 additions & 22 deletions tests/ci/integration/python_patch/main/aws-lc-cpython.patch
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ index 6e63a88..7dc83d7 100644
# just check status of PHA flag
h = client.HTTPSConnection('localhost', 443)
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index 3fa806d..0983212 100644
index 0e50d09..f4b7b3c 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -41,6 +41,7 @@
Expand All @@ -32,15 +32,49 @@ index 3fa806d..0983212 100644
def seclevel_workaround(*ctxs):
""""Lower security level to '1' and allow all ciphers for TLS 1.0/1"""
for ctx in ctxs:
@@ -3997,6 +4001,7 @@ def test_no_legacy_server_connect(self):
@@ -4001,6 +4002,7 @@ def test_no_legacy_server_connect(self):
sni_name=hostname)

@unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
+ @unittest.skipIf(Py_OPENSSL_IS_AWSLC, "AWS-LC doesn't support (FF)DHE")
def test_dh_params(self):
# Check we can get a connection with ephemeral Diffie-Hellman
client_context, server_context, hostname = testing_context()
@@ -4457,7 +4462,10 @@ def server_callback(identity):
@@ -4364,14 +4366,14 @@ def test_session_handling(self):
def test_psk(self):
psk = bytes.fromhex('deadbeef')

- client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ client_context, server_context, _ = testing_context()
+
client_context.check_hostname = False
client_context.verify_mode = ssl.CERT_NONE
client_context.maximum_version = ssl.TLSVersion.TLSv1_2
client_context.set_ciphers('PSK')
client_context.set_psk_client_callback(lambda hint: (None, psk))

- server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
server_context.maximum_version = ssl.TLSVersion.TLSv1_2
server_context.set_ciphers('PSK')
server_context.set_psk_server_callback(lambda identity: psk)
@@ -4443,14 +4445,14 @@ def server_callback(identity):
self.assertEqual(identity, client_identity)
return psk

- client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ client_context, server_context, _ = testing_context()
+
client_context.check_hostname = False
client_context.verify_mode = ssl.CERT_NONE
client_context.minimum_version = ssl.TLSVersion.TLSv1_3
client_context.set_ciphers('PSK')
client_context.set_psk_client_callback(client_callback)

- server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
server_context.minimum_version = ssl.TLSVersion.TLSv1_3
server_context.set_ciphers('PSK')
server_context.set_psk_server_callback(server_callback, identity_hint)
@@ -4461,7 +4463,10 @@ def server_callback(identity):
s.connect((HOST, server.port))


Expand All @@ -52,7 +86,7 @@ index 3fa806d..0983212 100644
class TestPostHandshakeAuth(unittest.TestCase):
def test_pha_setter(self):
protocols = [
@@ -4733,6 +4741,31 @@ def test_internal_chain_server(self):
@@ -4737,6 +4742,31 @@ def test_internal_chain_server(self):
self.assertEqual(res, b'\x02\n')


Expand Down Expand Up @@ -106,24 +140,22 @@ index cd1cf24..53bcc4c 100644
# The _tkinter module.
#
diff --git a/Modules/_ssl.c b/Modules/_ssl.c
index d00f407..7049f79 100644
index f7fdbf4..204d501 100644
--- a/Modules/_ssl.c
+++ b/Modules/_ssl.c
@@ -188,6 +188,13 @@ extern const SSL_METHOD *TLSv1_2_method(void);
@@ -187,6 +187,11 @@ extern const SSL_METHOD *TLSv1_2_method(void);
#endif


+
+#if defined(OPENSSL_IS_AWSLC) || !defined(TLS1_3_VERSION) || defined(OPENSSL_NO_TLS1_3)


+#if defined(OPENSSL_NO_TLS_PHA) || !defined(TLS1_3_VERSION) || defined(OPENSSL_NO_TLS1_3)
+ #define PY_SSL_NO_POST_HS_AUTH
+ #define OPENSSL_NO_PSK
+#endif
+
+
enum py_ssl_error {
/* these mirror ssl.h */
PY_SSL_ERROR_NONE,
@@ -232,7 +239,7 @@ enum py_proto_version {
@@ -231,7 +236,7 @@ enum py_proto_version {
PY_PROTO_TLSv1 = TLS1_VERSION,
PY_PROTO_TLSv1_1 = TLS1_1_VERSION,
PY_PROTO_TLSv1_2 = TLS1_2_VERSION,
Expand All @@ -132,7 +164,7 @@ index d00f407..7049f79 100644
PY_PROTO_TLSv1_3 = TLS1_3_VERSION,
#else
PY_PROTO_TLSv1_3 = 0x304,
@@ -294,7 +301,7 @@ typedef struct {
@@ -293,7 +298,7 @@ typedef struct {
*/
unsigned int hostflags;
int protocol;
Expand All @@ -141,7 +173,7 @@ index d00f407..7049f79 100644
int post_handshake_auth;
#endif
PyObject *msg_cb;
@@ -886,7 +893,7 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock,
@@ -873,7 +878,7 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock,
SSL_set_mode(self->ssl,
SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_AUTO_RETRY);

Expand All @@ -150,15 +182,15 @@ index d00f407..7049f79 100644
if (sslctx->post_handshake_auth == 1) {
if (socket_type == PY_SSL_SERVER) {
/* bpo-37428: OpenSSL does not ignore SSL_VERIFY_POST_HANDSHAKE.
@@ -1029,6 +1036,7 @@ _ssl__SSLSocket_do_handshake_impl(PySSLSocket *self)
@@ -1016,6 +1021,7 @@ _ssl__SSLSocket_do_handshake_impl(PySSLSocket *self)
} while (err.ssl == SSL_ERROR_WANT_READ ||
err.ssl == SSL_ERROR_WANT_WRITE);
Py_XDECREF(sock);
+
if (ret < 1)
return PySSL_SetError(self, ret, __FILE__, __LINE__);
return PySSL_SetError(self, __FILE__, __LINE__);
if (PySSL_ChainExceptions(self) < 0)
@@ -2788,7 +2796,7 @@ static PyObject *
@@ -2775,7 +2781,7 @@ static PyObject *
_ssl__SSLSocket_verify_client_post_handshake_impl(PySSLSocket *self)
/*[clinic end generated code: output=532147f3b1341425 input=6bfa874810a3d889]*/
{
Expand All @@ -167,7 +199,7 @@ index d00f407..7049f79 100644
int err = SSL_verify_client_post_handshake(self->ssl);
if (err == 0)
return _setSSLError(get_state_sock(self), NULL, 0, __FILE__, __LINE__);
@@ -3212,7 +3220,7 @@ _ssl__SSLContext_impl(PyTypeObject *type, int proto_version)
@@ -3198,7 +3204,7 @@ _ssl__SSLContext_impl(PyTypeObject *type, int proto_version)
X509_VERIFY_PARAM_set_flags(params, X509_V_FLAG_TRUSTED_FIRST);
X509_VERIFY_PARAM_set_hostflags(params, self->hostflags);

Expand All @@ -176,7 +208,7 @@ index d00f407..7049f79 100644
self->post_handshake_auth = 0;
SSL_CTX_set_post_handshake_auth(self->ctx, self->post_handshake_auth);
#endif
@@ -3590,7 +3598,7 @@ set_maximum_version(PySSLContext *self, PyObject *arg, void *c)
@@ -3576,7 +3582,7 @@ set_maximum_version(PySSLContext *self, PyObject *arg, void *c)
return set_min_max_proto_version(self, arg, 1);
}

Expand All @@ -185,7 +217,7 @@ index d00f407..7049f79 100644
static PyObject *
get_num_tickets(PySSLContext *self, void *c)
{
@@ -3621,7 +3629,7 @@ set_num_tickets(PySSLContext *self, PyObject *arg, void *c)
@@ -3607,7 +3613,7 @@ set_num_tickets(PySSLContext *self, PyObject *arg, void *c)

PyDoc_STRVAR(PySSLContext_num_tickets_doc,
"Control the number of TLSv1.3 session tickets");
Expand All @@ -194,7 +226,7 @@ index d00f407..7049f79 100644

static PyObject *
get_security_level(PySSLContext *self, void *c)
@@ -3724,14 +3732,14 @@ set_check_hostname(PySSLContext *self, PyObject *arg, void *c)
@@ -3710,14 +3716,14 @@ set_check_hostname(PySSLContext *self, PyObject *arg, void *c)

static PyObject *
get_post_handshake_auth(PySSLContext *self, void *c) {
Expand All @@ -211,7 +243,7 @@ index d00f407..7049f79 100644
static int
set_post_handshake_auth(PySSLContext *self, PyObject *arg, void *c) {
if (arg == NULL) {
@@ -4973,14 +4981,14 @@ static PyGetSetDef context_getsetlist[] = {
@@ -4959,14 +4965,14 @@ static PyGetSetDef context_getsetlist[] = {
(setter) _PySSLContext_set_msg_callback, NULL},
{"sni_callback", (getter) get_sni_callback,
(setter) set_sni_callback, PySSLContext_sni_callback_doc},
Expand Down
1 change: 1 addition & 0 deletions util/doc.config
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
},{
"Name": "SSL implementation",
"Headers": [
"include/openssl/opensslconf.h",
"include/openssl/ssl.h"
]
}]
Expand Down

0 comments on commit 0aebf17

Please sign in to comment.