34
34
35
35
#include "py/runtime.h"
36
36
#include "py/stream.h"
37
+ #include "py/objstr.h"
37
38
38
39
// mbedtls_time_t
39
40
#include "mbedtls/platform.h"
43
44
#include "mbedtls/entropy.h"
44
45
#include "mbedtls/ctr_drbg.h"
45
46
#include "mbedtls/debug.h"
47
+ #include "mbedtls/error.h"
46
48
47
49
typedef struct _mp_obj_ssl_socket_t {
48
50
mp_obj_base_t base ;
@@ -74,8 +76,48 @@ STATIC void mbedtls_debug(void *ctx, int level, const char *file, int line, cons
74
76
}
75
77
#endif
76
78
79
+ STATIC NORETURN void mbedtls_raise_error (int err ) {
80
+ // _mbedtls_ssl_send and _mbedtls_ssl_recv (below) turn positive error codes from the
81
+ // underlying socket into negative codes to pass them through mbedtls. Here we turn them
82
+ // positive again so they get interpreted as the OSError they really are. The
83
+ // cut-off of -256 is a bit hacky, sigh.
84
+ if (err < 0 && err > -256 ) {
85
+ mp_raise_OSError (- err );
86
+ }
87
+
88
+ #if defined(MBEDTLS_ERROR_C )
89
+ // Including mbedtls_strerror takes about 1.5KB due to the error strings.
90
+ // MBEDTLS_ERROR_C is the define used by mbedtls to conditionally include mbedtls_strerror.
91
+ // It is set/unset in the MBEDTLS_CONFIG_FILE which is defined in the Makefile.
92
+
93
+ // Try to allocate memory for the message
94
+ #define ERR_STR_MAX 80 // mbedtls_strerror truncates if it doesn't fit
95
+ mp_obj_str_t * o_str = m_new_obj_maybe (mp_obj_str_t );
96
+ byte * o_str_buf = m_new_maybe (byte , ERR_STR_MAX );
97
+ if (o_str == NULL || o_str_buf == NULL ) {
98
+ mp_raise_OSError (err );
99
+ }
100
+
101
+ // print the error message into the allocated buffer
102
+ mbedtls_strerror (err , (char * )o_str_buf , ERR_STR_MAX );
103
+ size_t len = strlen ((char * )o_str_buf );
104
+
105
+ // Put the exception object together
106
+ o_str -> base .type = & mp_type_str ;
107
+ o_str -> data = o_str_buf ;
108
+ o_str -> len = len ;
109
+ o_str -> hash = qstr_compute_hash (o_str -> data , o_str -> len );
110
+ // raise
111
+ mp_obj_t args [2 ] = { MP_OBJ_NEW_SMALL_INT (err ), MP_OBJ_FROM_PTR (o_str )};
112
+ nlr_raise (mp_obj_exception_make_new (& mp_type_OSError , 2 , 0 , args ));
113
+ #else
114
+ // mbedtls is compiled without error strings so we simply return the err number
115
+ mp_raise_OSError (err ); // err is typically a large negative number
116
+ #endif
117
+ }
118
+
77
119
STATIC int _mbedtls_ssl_send (void * ctx , const byte * buf , size_t len ) {
78
- mp_obj_t sock = * (mp_obj_t * )ctx ;
120
+ mp_obj_t sock = * (mp_obj_t * )ctx ;
79
121
80
122
const mp_stream_p_t * sock_stream = mp_get_stream (sock );
81
123
int err ;
@@ -85,14 +127,14 @@ STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
85
127
if (mp_is_nonblocking_error (err )) {
86
128
return MBEDTLS_ERR_SSL_WANT_WRITE ;
87
129
}
88
- return - err ;
130
+ return - err ; // convert an MP_ERRNO to something mbedtls passes through as error
89
131
} else {
90
132
return out_sz ;
91
133
}
92
134
}
93
135
94
136
STATIC int _mbedtls_ssl_recv (void * ctx , byte * buf , size_t len ) {
95
- mp_obj_t sock = * (mp_obj_t * )ctx ;
137
+ mp_obj_t sock = * (mp_obj_t * )ctx ;
96
138
97
139
const mp_stream_p_t * sock_stream = mp_get_stream (sock );
98
140
int err ;
@@ -113,11 +155,11 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
113
155
// Verify the socket object has the full stream protocol
114
156
mp_get_stream_raise (sock , MP_STREAM_OP_READ | MP_STREAM_OP_WRITE | MP_STREAM_OP_IOCTL );
115
157
116
- #if MICROPY_PY_USSL_FINALISER
158
+ #if MICROPY_PY_USSL_FINALISER
117
159
mp_obj_ssl_socket_t * o = m_new_obj_with_finaliser (mp_obj_ssl_socket_t );
118
- #else
160
+ #else
119
161
mp_obj_ssl_socket_t * o = m_new_obj (mp_obj_ssl_socket_t );
120
- #endif
162
+ #endif
121
163
o -> base .type = & ussl_socket_type ;
122
164
o -> sock = sock ;
123
165
@@ -141,9 +183,9 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
141
183
}
142
184
143
185
ret = mbedtls_ssl_config_defaults (& o -> conf ,
144
- args -> server_side .u_bool ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT ,
145
- MBEDTLS_SSL_TRANSPORT_STREAM ,
146
- MBEDTLS_SSL_PRESET_DEFAULT );
186
+ args -> server_side .u_bool ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT ,
187
+ MBEDTLS_SSL_TRANSPORT_STREAM ,
188
+ MBEDTLS_SSL_PRESET_DEFAULT );
147
189
if (ret != 0 ) {
148
190
goto cleanup ;
149
191
}
@@ -171,7 +213,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
171
213
172
214
if (args -> key .u_obj != mp_const_none ) {
173
215
size_t key_len ;
174
- const byte * key = (const byte * )mp_obj_str_get_data (args -> key .u_obj , & key_len );
216
+ const byte * key = (const byte * )mp_obj_str_get_data (args -> key .u_obj , & key_len );
175
217
// len should include terminating null
176
218
ret = mbedtls_pk_parse_key (& o -> pkey , key , key_len + 1 , NULL , 0 );
177
219
if (ret != 0 ) {
@@ -180,7 +222,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
180
222
}
181
223
182
224
size_t cert_len ;
183
- const byte * cert = (const byte * )mp_obj_str_get_data (args -> cert .u_obj , & cert_len );
225
+ const byte * cert = (const byte * )mp_obj_str_get_data (args -> cert .u_obj , & cert_len );
184
226
// len should include terminating null
185
227
ret = mbedtls_x509_crt_parse (& o -> cert , cert , cert_len + 1 );
186
228
if (ret != 0 ) {
@@ -197,7 +239,6 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
197
239
if (args -> do_handshake .u_bool ) {
198
240
while ((ret = mbedtls_ssl_handshake (& o -> ssl )) != 0 ) {
199
241
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE ) {
200
- printf ("mbedtls_ssl_handshake error: -%x\n" , - ret );
201
242
goto cleanup ;
202
243
}
203
244
}
@@ -217,11 +258,11 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
217
258
if (ret == MBEDTLS_ERR_SSL_ALLOC_FAILED ) {
218
259
mp_raise_OSError (MP_ENOMEM );
219
260
} else if (ret == MBEDTLS_ERR_PK_BAD_INPUT_DATA ) {
220
- mp_raise_ValueError ("invalid key" );
261
+ mp_raise_ValueError (MP_ERROR_TEXT ( "invalid key" ) );
221
262
} else if (ret == MBEDTLS_ERR_X509_BAD_INPUT_DATA ) {
222
- mp_raise_ValueError ("invalid cert" );
263
+ mp_raise_ValueError (MP_ERROR_TEXT ( "invalid cert" ) );
223
264
} else {
224
- mp_raise_OSError ( MP_EIO );
265
+ mbedtls_raise_error ( ret );
225
266
}
226
267
}
227
268
@@ -230,7 +271,7 @@ STATIC mp_obj_t mod_ssl_getpeercert(mp_obj_t o_in, mp_obj_t binary_form) {
230
271
if (!mp_obj_is_true (binary_form )) {
231
272
mp_raise_NotImplementedError (NULL );
232
273
}
233
- const mbedtls_x509_crt * peer_cert = mbedtls_ssl_get_peer_cert (& o -> ssl );
274
+ const mbedtls_x509_crt * peer_cert = mbedtls_ssl_get_peer_cert (& o -> ssl );
234
275
if (peer_cert == NULL ) {
235
276
return mp_const_none ;
236
277
}
@@ -318,9 +359,9 @@ STATIC const mp_rom_map_elem_t ussl_socket_locals_dict_table[] = {
318
359
{ MP_ROM_QSTR (MP_QSTR_write ), MP_ROM_PTR (& mp_stream_write_obj ) },
319
360
{ MP_ROM_QSTR (MP_QSTR_setblocking ), MP_ROM_PTR (& socket_setblocking_obj ) },
320
361
{ MP_ROM_QSTR (MP_QSTR_close ), MP_ROM_PTR (& mp_stream_close_obj ) },
321
- #if MICROPY_PY_USSL_FINALISER
362
+ #if MICROPY_PY_USSL_FINALISER
322
363
{ MP_ROM_QSTR (MP_QSTR___del__ ), MP_ROM_PTR (& mp_stream_close_obj ) },
323
- #endif
364
+ #endif
324
365
{ MP_ROM_QSTR (MP_QSTR_getpeercert ), MP_ROM_PTR (& mod_ssl_getpeercert_obj ) },
325
366
};
326
367
@@ -340,7 +381,7 @@ STATIC const mp_obj_type_t ussl_socket_type = {
340
381
.getiter = NULL ,
341
382
.iternext = NULL ,
342
383
.protocol = & ussl_socket_stream_p ,
343
- .locals_dict = (void * )& ussl_socket_locals_dict ,
384
+ .locals_dict = (void * )& ussl_socket_locals_dict ,
344
385
};
345
386
346
387
STATIC mp_obj_t mod_ssl_wrap_socket (size_t n_args , const mp_obj_t * pos_args , mp_map_t * kw_args ) {
@@ -358,7 +399,7 @@ STATIC mp_obj_t mod_ssl_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_
358
399
359
400
struct ssl_args args ;
360
401
mp_arg_parse_all (n_args - 1 , pos_args + 1 , kw_args ,
361
- MP_ARRAY_SIZE (allowed_args ), allowed_args , (mp_arg_val_t * )& args );
402
+ MP_ARRAY_SIZE (allowed_args ), allowed_args , (mp_arg_val_t * )& args );
362
403
363
404
return MP_OBJ_FROM_PTR (socket_new (sock , & args ));
364
405
}
@@ -373,7 +414,7 @@ STATIC MP_DEFINE_CONST_DICT(mp_module_ssl_globals, mp_module_ssl_globals_table);
373
414
374
415
const mp_obj_module_t mp_module_ussl = {
375
416
.base = { & mp_type_module },
376
- .globals = (mp_obj_dict_t * )& mp_module_ssl_globals ,
417
+ .globals = (mp_obj_dict_t * )& mp_module_ssl_globals ,
377
418
};
378
419
379
420
#endif // MICROPY_PY_USSL
0 commit comments