Skip to content

Commit bfd6ad9

Browse files
Carglglzdpgeorge
authored andcommitted
extmod/asyncio: Add ssl support with SSLContext.
This adds asyncio ssl support with SSLContext and the corresponding tests in `tests/net_inet` and `tests/multi_net`. Note that not doing the handshake on connect will delegate the handshake to the following `mbedtls_ssl_read/write` calls. However if the handshake fails when a client certificate is required and not presented by the peer, it needs to be notified of this handshake error (otherwise it will hang until timeout if any). Finally at MicroPython side raise the proper mbedtls error code and message. Signed-off-by: Carlos Gil <carlosgilglez@gmail.com>
1 parent f33dfb9 commit bfd6ad9

13 files changed

+469
-5
lines changed

.gitattributes

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
*.jpg binary
1414
*.dxf binary
1515
*.mpy binary
16+
*.der binary
1617

1718
# These should also not be modified by git.
1819
tests/basics/string_cr_conversion.py -text

extmod/asyncio/stream.py

+24-5
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def readline(self):
6363
while True:
6464
yield core._io_queue.queue_read(self.s)
6565
l2 = self.s.readline() # may do multiple reads but won't block
66+
if l2 is None:
67+
continue
6668
l += l2
6769
if not l2 or l[-1] == 10: # \n (check l in case l2 is str)
6870
return l
@@ -100,19 +102,29 @@ def drain(self):
100102
# Create a TCP stream connection to a remote host
101103
#
102104
# async
103-
def open_connection(host, port):
105+
def open_connection(host, port, ssl=None, server_hostname=None):
104106
from errno import EINPROGRESS
105107
import socket
106108

107109
ai = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)[0] # TODO this is blocking!
108110
s = socket.socket(ai[0], ai[1], ai[2])
109111
s.setblocking(False)
110-
ss = Stream(s)
111112
try:
112113
s.connect(ai[-1])
113114
except OSError as er:
114115
if er.errno != EINPROGRESS:
115116
raise er
117+
# wrap with SSL, if requested
118+
if ssl:
119+
if ssl is True:
120+
import ssl as _ssl
121+
122+
ssl = _ssl.SSLContext(_ssl.PROTOCOL_TLS_CLIENT)
123+
if not server_hostname:
124+
server_hostname = host
125+
s = ssl.wrap_socket(s, server_hostname=server_hostname, do_handshake_on_connect=False)
126+
s.setblocking(False)
127+
ss = Stream(s)
116128
yield core._io_queue.queue_write(s)
117129
return ss, ss
118130

@@ -135,7 +147,7 @@ def close(self):
135147
async def wait_closed(self):
136148
await self.task
137149

138-
async def _serve(self, s, cb):
150+
async def _serve(self, s, cb, ssl):
139151
self.state = False
140152
# Accept incoming connections
141153
while True:
@@ -156,14 +168,21 @@ async def _serve(self, s, cb):
156168
except:
157169
# Ignore a failed accept
158170
continue
171+
if ssl:
172+
try:
173+
s2 = ssl.wrap_socket(s2, server_side=True, do_handshake_on_connect=False)
174+
except OSError as e:
175+
core.sys.print_exception(e)
176+
s2.close()
177+
continue
159178
s2.setblocking(False)
160179
s2s = Stream(s2, {"peername": addr})
161180
core.create_task(cb(s2s, s2s))
162181

163182

164183
# Helper function to start a TCP stream server, running as a new task
165184
# TODO could use an accept-callback on socket read activity instead of creating a task
166-
async def start_server(cb, host, port, backlog=5):
185+
async def start_server(cb, host, port, backlog=5, ssl=None):
167186
import socket
168187

169188
# Create and bind server socket.
@@ -176,7 +195,7 @@ async def start_server(cb, host, port, backlog=5):
176195

177196
# Create and return server object and task.
178197
srv = Server()
179-
srv.task = core.create_task(srv._serve(s, cb))
198+
srv.task = core.create_task(srv._serve(s, cb, ssl))
180199
try:
181200
# Ensure that the _serve task has been scheduled so that it gets to
182201
# handle cancellation.

extmod/modssl_mbedtls.c

+42
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,46 @@ STATIC NORETURN void mbedtls_raise_error(int err) {
166166
#endif
167167
}
168168

169+
STATIC void ssl_check_async_handshake_failure(mp_obj_ssl_socket_t *sslsock, int *errcode) {
170+
if (
171+
#if MBEDTLS_VERSION_NUMBER >= 0x03000000
172+
(*errcode < 0) && (mbedtls_ssl_is_handshake_over(&sslsock->ssl) == 0) && (*errcode != MBEDTLS_ERR_SSL_CONN_EOF)
173+
#else
174+
(*errcode < 0) && (*errcode != MBEDTLS_ERR_SSL_CONN_EOF)
175+
#endif
176+
) {
177+
// Asynchronous handshake is done by mbdetls_ssl_read/write. If the return code is
178+
// MBEDTLS_ERR_XX (i.e < 0) and the handshake is not done due to a handshake failure,
179+
// then notify peer with proper error code and raise local error with mbedtls_raise_error.
180+
181+
if (*errcode == MBEDTLS_ERR_SSL_NO_CLIENT_CERTIFICATE) {
182+
// Check if TLSv1.3 and use proper alert for this case (to be implemented)
183+
// uint8_t alert = MBEDTLS_SSL_ALERT_MSG_CERT_REQUIRED; tlsv1.3
184+
// uint8_t alert = MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE; tlsv1.2
185+
mbedtls_ssl_send_alert_message(&sslsock->ssl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
186+
MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE);
187+
}
188+
189+
if (*errcode == MBEDTLS_ERR_X509_CERT_VERIFY_FAILED) {
190+
// The certificate may have been rejected for several reasons.
191+
char xcbuf[256];
192+
uint32_t flags = mbedtls_ssl_get_verify_result(&sslsock->ssl);
193+
int ret = mbedtls_x509_crt_verify_info(xcbuf, sizeof(xcbuf), "\n", flags);
194+
// The length of the string written (not including the terminated nul byte),
195+
// or a negative err code.
196+
if (ret > 0) {
197+
sslsock->sock = MP_OBJ_NULL;
198+
mbedtls_ssl_free(&sslsock->ssl);
199+
mp_raise_msg_varg(&mp_type_ValueError, MP_ERROR_TEXT("%s"), xcbuf);
200+
}
201+
}
202+
203+
sslsock->sock = MP_OBJ_NULL;
204+
mbedtls_ssl_free(&sslsock->ssl);
205+
mbedtls_raise_error(*errcode);
206+
}
207+
}
208+
169209
/******************************************************************************/
170210
// SSLContext type.
171211

@@ -614,6 +654,7 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc
614654
} else {
615655
o->last_error = ret;
616656
}
657+
ssl_check_async_handshake_failure(o, &ret);
617658
*errcode = ret;
618659
return MP_STREAM_ERROR;
619660
}
@@ -642,6 +683,7 @@ STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, in
642683
} else {
643684
o->last_error = ret;
644685
}
686+
ssl_check_async_handshake_failure(o, &ret);
645687
*errcode = ret;
646688
return MP_STREAM_ERROR;
647689
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Test asyncio TCP server and client with TLS, transferring some data.
2+
3+
try:
4+
import os
5+
import asyncio
6+
import ssl
7+
except ImportError:
8+
print("SKIP")
9+
raise SystemExit
10+
11+
PORT = 8000
12+
13+
# These are test certificates. See tests/README.md for details.
14+
cert = cafile = "multi_net/rsa_cert.der"
15+
key = "multi_net/rsa_key.der"
16+
17+
try:
18+
os.stat(cafile)
19+
os.stat(key)
20+
except OSError:
21+
print("SKIP")
22+
raise SystemExit
23+
24+
25+
async def handle_connection(reader, writer):
26+
data = await reader.read(100)
27+
print("echo:", data)
28+
writer.write(data)
29+
await writer.drain()
30+
31+
print("close")
32+
writer.close()
33+
await writer.wait_closed()
34+
35+
print("done")
36+
ev.set()
37+
38+
39+
async def tcp_server():
40+
global ev
41+
42+
server_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
43+
server_ctx.load_cert_chain(cert, key)
44+
ev = asyncio.Event()
45+
server = await asyncio.start_server(handle_connection, "0.0.0.0", PORT, ssl=server_ctx)
46+
print("server running")
47+
multitest.next()
48+
async with server:
49+
await asyncio.wait_for(ev.wait(), 10)
50+
51+
52+
async def tcp_client(message):
53+
client_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
54+
client_ctx.verify_mode = ssl.CERT_REQUIRED
55+
client_ctx.load_verify_locations(cafile=cafile)
56+
reader, writer = await asyncio.open_connection(
57+
IP, PORT, ssl=client_ctx, server_hostname="micropython.local"
58+
)
59+
print("write:", message)
60+
writer.write(message)
61+
await writer.drain()
62+
data = await reader.read(100)
63+
print("read:", data)
64+
65+
66+
def instance0():
67+
multitest.globals(IP=multitest.get_network_ip())
68+
asyncio.run(tcp_server())
69+
70+
71+
def instance1():
72+
multitest.next()
73+
asyncio.run(tcp_client(b"client data"))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
--- instance0 ---
2+
server running
3+
echo: b'client data'
4+
close
5+
done
6+
--- instance1 ---
7+
write: b'client data'
8+
read: b'client data'
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Test asyncio TCP server and client with TLS, giving a cert required error.
2+
3+
try:
4+
import os
5+
import asyncio
6+
import ssl
7+
except ImportError:
8+
print("SKIP")
9+
raise SystemExit
10+
11+
PORT = 8000
12+
13+
# These are test certificates. See tests/README.md for details.
14+
cert = cafile = "multi_net/rsa_cert.der"
15+
key = "multi_net/rsa_key.der"
16+
17+
try:
18+
os.stat(cafile)
19+
os.stat(key)
20+
except OSError:
21+
print("SKIP")
22+
raise SystemExit
23+
24+
25+
async def handle_connection(reader, writer):
26+
print("handle connection")
27+
try:
28+
data = await reader.read(100)
29+
except Exception as e:
30+
print(e)
31+
ev.set()
32+
33+
34+
async def tcp_server():
35+
global ev
36+
37+
server_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
38+
server_ctx.load_cert_chain(cert, key)
39+
server_ctx.verify_mode = ssl.CERT_REQUIRED
40+
server_ctx.load_verify_locations(cafile=cert)
41+
ev = asyncio.Event()
42+
server = await asyncio.start_server(handle_connection, "0.0.0.0", PORT, ssl=server_ctx)
43+
print("server running")
44+
multitest.next()
45+
async with server:
46+
await asyncio.wait_for(ev.wait(), 10)
47+
multitest.wait("finished")
48+
print("server done")
49+
50+
51+
async def tcp_client(message):
52+
client_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
53+
client_ctx.verify_mode = ssl.CERT_REQUIRED
54+
client_ctx.load_verify_locations(cafile=cafile)
55+
reader, writer = await asyncio.open_connection(
56+
IP, PORT, ssl=client_ctx, server_hostname="micropython.local"
57+
)
58+
try:
59+
print("write:", message)
60+
writer.write(message)
61+
print("drain")
62+
await writer.drain()
63+
except Exception as e:
64+
print(e)
65+
print("client done")
66+
multitest.broadcast("finished")
67+
68+
69+
def instance0():
70+
multitest.globals(IP=multitest.get_network_ip())
71+
asyncio.run(tcp_server())
72+
73+
74+
def instance1():
75+
multitest.next()
76+
asyncio.run(tcp_client(b"client data"))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
--- instance0 ---
2+
server running
3+
handle connection
4+
(-29824, 'MBEDTLS_ERR_SSL_NO_CLIENT_CERTIFICATE')
5+
server done
6+
--- instance1 ---
7+
write: b'client data'
8+
drain
9+
(-30592, 'MBEDTLS_ERR_SSL_FATAL_ALERT_MESSAGE')
10+
client done

0 commit comments

Comments
 (0)