Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correct ssl recv #129

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 33 additions & 17 deletions src/DebugPrintMacros.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,34 @@ inline struct _DEBUG_TIME_STAMP debugTimeStamp(void) {
}
#endif

#if defined(DEBUG_ESP_PORT) && !defined(DEBUG_ESP_PORT_PRINTF)

#ifdef __cplusplus
#define DEBUG_ESP_PORT_PRINTF(format, ...) DEBUG_ESP_PORT.printf((format), ##__VA_ARGS__)
#define DEBUG_ESP_PORT_PRINTF_F(format, ...) DEBUG_ESP_PORT.printf_P(PSTR(format), ##__VA_ARGS__)
#define DEBUG_ESP_PORT_FLUSH DEBUG_ESP_PORT.flush
#else
// Handle debug printing from .c without CPP Stream, Print, ... classes
// Cannot handle flash strings in this setting
#define DEBUG_ESP_PORT_PRINTF ets_uart_printf
#define DEBUG_ESP_PORT_PRINTF_F ets_uart_printf
#define DEBUG_ESP_PORT_FLUSH (void)0
#endif

#endif

#if defined(DEBUG_ESP_PORT) && !defined(DEBUG_GENERIC)
#define DEBUG_GENERIC( module, format, ... ) \
do { \
struct _DEBUG_TIME_STAMP st = debugTimeStamp(); \
DEBUG_ESP_PORT.printf( DEBUG_TIME_STAMP_FMT module " " format, st.whole, st.dec, ##__VA_ARGS__ ); \
DEBUG_ESP_PORT_PRINTF( (DEBUG_TIME_STAMP_FMT module " " format), st.whole, st.dec, ##__VA_ARGS__ ); \
} while(false)
#endif
#if defined(DEBUG_ESP_PORT) && !defined(DEBUG_GENERIC_P)
#define DEBUG_GENERIC_P( module, format, ... ) \
#if defined(DEBUG_ESP_PORT) && !defined(DEBUG_GENERIC_F)
#define DEBUG_GENERIC_F( module, format, ... ) \
do { \
struct _DEBUG_TIME_STAMP st = debugTimeStamp(); \
DEBUG_ESP_PORT.printf_P(PSTR( DEBUG_TIME_STAMP_FMT module " " format ), st.whole, st.dec, ##__VA_ARGS__ ); \
DEBUG_ESP_PORT_PRINTF_F( (DEBUG_TIME_STAMP_FMT module " " format), st.whole, st.dec, ##__VA_ARGS__ ); \
} while(false)
#endif

Expand All @@ -47,16 +63,16 @@ inline struct _DEBUG_TIME_STAMP debugTimeStamp(void) {
do { \
if ( !(a) ) { \
DEBUG_GENERIC( module, "%s:%s:%u: ASSERT("#a") failed!\n", __FILE__, __func__, __LINE__); \
DEBUG_ESP_PORT.flush(); \
DEBUG_ESP_PORT_FLUSH(); \
} \
} while(false)
#endif
#if defined(DEBUG_GENERIC_P) && !defined(ASSERT_GENERIC_P)
#define ASSERT_GENERIC_P( a, module ) \
#if defined(DEBUG_GENERIC_F) && !defined(ASSERT_GENERIC_F)
#define ASSERT_GENERIC_F( a, module ) \
do { \
if ( !(a) ) { \
DEBUG_GENERIC_P( module, "%s:%s:%u: ASSERT("#a") failed!\n", __FILE__, __func__, __LINE__); \
DEBUG_ESP_PORT.flush(); \
DEBUG_GENERIC_F( module, "%s:%s:%u: ASSERT("#a") failed!\n", __FILE__, __func__, __LINE__); \
DEBUG_ESP_PORT_FLUSH(); \
} \
} while(false)
#endif
Expand All @@ -65,32 +81,32 @@ inline struct _DEBUG_TIME_STAMP debugTimeStamp(void) {
#define DEBUG_GENERIC(...) do { (void)0;} while(false)
#endif

#ifndef DEBUG_GENERIC_P
#define DEBUG_GENERIC_P(...) do { (void)0;} while(false)
#ifndef DEBUG_GENERIC_F
#define DEBUG_GENERIC_F(...) do { (void)0;} while(false)
#endif

#ifndef ASSERT_GENERIC
#define ASSERT_GENERIC(...) do { (void)0;} while(false)
#endif

#ifndef ASSERT_GENERIC_P
#define ASSERT_GENERIC_P(...) do { (void)0;} while(false)
#ifndef ASSERT_GENERIC_F
#define ASSERT_GENERIC_F(...) do { (void)0;} while(false)
#endif

#ifndef DEBUG_ESP_PRINTF
#define DEBUG_ESP_PRINTF( format, ...) DEBUG_GENERIC_P("[%s]", format, &_FILENAME_[1], ##__VA_ARGS__)
#define DEBUG_ESP_PRINTF( format, ...) DEBUG_GENERIC_F("[%s]", format, &_FILENAME_[1], ##__VA_ARGS__)
#endif

#if defined(DEBUG_ESP_ASYNC_TCP) && !defined(ASYNC_TCP_DEBUG)
#define ASYNC_TCP_DEBUG( format, ...) DEBUG_GENERIC_P("[ASYNC_TCP]", format, ##__VA_ARGS__)
#define ASYNC_TCP_DEBUG( format, ...) DEBUG_GENERIC_F("[ASYNC_TCP]", format, ##__VA_ARGS__)
#endif

#ifndef ASYNC_TCP_ASSERT
#define ASYNC_TCP_ASSERT( a ) ASSERT_GENERIC_P( (a), "[ASYNC_TCP]")
#define ASYNC_TCP_ASSERT( a ) ASSERT_GENERIC_F( (a), "[ASYNC_TCP]")
#endif

#if defined(DEBUG_ESP_TCP_SSL) && !defined(TCP_SSL_DEBUG)
#define TCP_SSL_DEBUG( format, ...) DEBUG_GENERIC_P("[TCP_SSL]", format, ##__VA_ARGS__)
#define TCP_SSL_DEBUG( format, ...) DEBUG_GENERIC_F("[TCP_SSL]", format, ##__VA_ARGS__)
#endif

#endif //_DEBUG_PRINT_MACROS_H
41 changes: 32 additions & 9 deletions src/ESPAsyncTCP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ yield(), etc.


*/

#include "Arduino.h"

#include "ESPAsyncTCP.h"
Expand Down Expand Up @@ -355,10 +356,12 @@ void AsyncClient::abort(){
void AsyncClient::close(bool now){
if(_pcb)
tcp_recved(_pcb, _rx_ack_len);
if(now)
if(now) {
ASYNC_TCP_DEBUG("close[%u]: AsyncClient 0x%" PRIXPTR "\n", getConnectionId(), uintptr_t(this));
_close();
else
} else {
_close_pcb = true;
}
}

void AsyncClient::stop() {
Expand Down Expand Up @@ -503,6 +506,7 @@ void AsyncClient::_close(){
err_t err = tcp_close(_pcb);
if(ERR_OK == err) {
setCloseError(err);
ASYNC_TCP_DEBUG("_close[%u]: AsyncClient 0x%" PRIXPTR "\n", getConnectionId(), uintptr_t(this));
} else {
ASYNC_TCP_DEBUG("_close[%u]: abort() called for AsyncClient 0x%" PRIXPTR "\n", getConnectionId(), uintptr_t(this));
abort();
Expand Down Expand Up @@ -664,6 +668,7 @@ void AsyncClient::_poll(std::shared_ptr<ACErrorTracker>& errorTracker, tcp_pcb*

// Close requested
if(_close_pcb){
ASYNC_TCP_DEBUG("_poll[%u]: Process _close_pcb.\n", errorTracker->getConnectionId() );
_close_pcb = false;
_close();
return;
Expand All @@ -679,12 +684,14 @@ void AsyncClient::_poll(std::shared_ptr<ACErrorTracker>& errorTracker, tcp_pcb*
}
// RX Timeout
if(_rx_since_timeout && (now - _rx_last_packet) >= (_rx_since_timeout * 1000)){
ASYNC_TCP_DEBUG("_poll[%u]: RX Timeout.\n", errorTracker->getConnectionId() );
_close();
return;
}
#if ASYNC_TCP_SSL_ENABLED
// SSL Handshake Timeout
if(_pcb_secure && !_handshake_done && (now - _rx_last_packet) >= 2000){
ASYNC_TCP_DEBUG("_poll[%u]: SSL Handshake Timeout.\n", errorTracker->getConnectionId() );
_close();
return;
}
Expand Down Expand Up @@ -762,19 +769,28 @@ err_t AsyncClient::_s_connected(void* arg, void* tpcb, err_t err){

#if ASYNC_TCP_SSL_ENABLED
void AsyncClient::_s_data(void *arg, struct tcp_pcb *tcp, uint8_t * data, size_t len){
(void)tcp;
AsyncClient *c = reinterpret_cast<AsyncClient*>(arg);
if(c->_recv_cb)
c->_recv_cb(c->_recv_cb_arg, c, data, len);
}

void AsyncClient::_s_handshake(void *arg, struct tcp_pcb *tcp, SSL *ssl){
(void)tcp;
(void)ssl;
AsyncClient *c = reinterpret_cast<AsyncClient*>(arg);
c->_handshake_done = true;
if(c->_connect_cb)
c->_connect_cb(c->_connect_cb_arg, c);
}

void AsyncClient::_s_ssl_error(void *arg, struct tcp_pcb *tcp, int8_t err){
(void)tcp;
#ifdef DEBUG_ESP_ASYNC_TCP
AsyncClient *c = reinterpret_cast<AsyncClient*>(arg);
auto errorTracker = c->getACErrorTracker();
ASYNC_TCP_DEBUG("_ssl_error[%u] err = %d\n", errorTracker->getConnectionId(), err);
#endif
reinterpret_cast<AsyncClient*>(arg)->_ssl_error(err);
}
#endif
Expand Down Expand Up @@ -1230,7 +1246,7 @@ err_t AsyncServer::_accept(tcp_pcb* pcb, err_t err){
}
return ERR_OK;
}
ASYNC_TCP_DEBUG("### put to wait: %d\n", _clients_waiting);
//1 ASYNC_TCP_DEBUG("### put to wait: %d\n", _clients_waiting);
new_item->pcb = pcb;
new_item->pb = NULL;
new_item->next = NULL;
Expand All @@ -1252,6 +1268,7 @@ err_t AsyncServer::_accept(tcp_pcb* pcb, err_t err){
if(c){
ASYNC_TCP_DEBUG("_accept[%u]: SSL connected\n", c->getConnectionId());
c->onConnect([this](void * arg, AsyncClient *c){
(void)arg;
_connect_cb(_connect_cb_arg, c);
}, this);
} else {
Expand Down Expand Up @@ -1303,6 +1320,7 @@ err_t AsyncServer::_s_accept(void *arg, tcp_pcb* pcb, err_t err){

#if ASYNC_TCP_SSL_ENABLED
err_t AsyncServer::_poll(tcp_pcb* pcb){
err_t err = ERR_OK;
if(!tcp_ssl_has_client() && _pending){
struct pending_pcb * p = _pending;
if(p->pcb == pcb){
Expand All @@ -1314,29 +1332,34 @@ err_t AsyncServer::_poll(tcp_pcb* pcb){
p->next = b->next;
p = b;
}
ASYNC_TCP_DEBUG("### remove from wait: %d\n", _clients_waiting);
//1 ASYNC_TCP_DEBUG("### remove from wait: %d\n", _clients_waiting);
AsyncClient *c = new (std::nothrow) AsyncClient(pcb, _ssl_ctx);
if(c){
c->onConnect([this](void * arg, AsyncClient *c){
(void)arg;
_connect_cb(_connect_cb_arg, c);
}, this);
if(p->pb)
c->_recv(pcb, p->pb, 0);
if(p->pb) {
auto errorTracker = c->getACErrorTracker();
c->_recv(errorTracker, pcb, p->pb, 0);
err = errorTracker->getCallbackCloseError();
}
}
// Should there be error handling for when "new AsynClient" fails??
free(p);
}
return ERR_OK;
return err;
}

err_t AsyncServer::_recv(struct tcp_pcb *pcb, struct pbuf *pb, err_t err){
(void)err;
if(!_pending)
return ERR_OK;

struct pending_pcb * p;

if(!pb){
ASYNC_TCP_DEBUG("### close from wait: %d\n", _clients_waiting);
//1 ASYNC_TCP_DEBUG("### close from wait: %d\n", _clients_waiting);
p = _pending;
if(p->pcb == pcb){
_pending = _pending->next;
Expand All @@ -1357,7 +1380,7 @@ err_t AsyncServer::_recv(struct tcp_pcb *pcb, struct pbuf *pb, err_t err){
return ERR_ABRT;
}
} else {
ASYNC_TCP_DEBUG("### wait _recv: %u %d\n", pb->tot_len, _clients_waiting);
//1 ASYNC_TCP_DEBUG("### wait _recv: %u %d\n", pb->tot_len, _clients_waiting);
p = _pending;
while(p && p->pcb != pcb)
p = p->next;
Expand Down
6 changes: 5 additions & 1 deletion src/async_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
// Starting with Arduino Core 2.4.0 and up the define of DEBUG_ESP_PORT
// can be handled through the Arduino IDE Board options instead of here.
// #define DEBUG_ESP_PORT Serial

// #define DEBUG_ESP_ASYNC_TCP 1
// #define DEBUG_ESP_TCP_SSL 1

#ifndef DEBUG_SKIP__DEBUG_PRINT_MACROS

#include <DebugPrintMacros.h>

#ifndef ASYNC_TCP_ASSERT
Expand All @@ -35,4 +37,6 @@
#define TCP_SSL_DEBUG(...) do { (void)0;} while(false)
#endif

#endif

#endif /* LIBRARIES_ESPASYNCTCP_SRC_ASYNC_CONFIG_H_ */
45 changes: 35 additions & 10 deletions src/tcp_axtls.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@
* Compatibility for AxTLS with LWIP raw tcp mode (http://lwip.wikia.com/wiki/Raw/TCP)
* Original Code and Inspiration: Slavey Karadzhov
*/

// To handle all the definitions needed for debug printing, we need to delay
// macro definitions till later.
#define DEBUG_SKIP__DEBUG_PRINT_MACROS 1
#include <async_config.h>
#undef DEBUG_SKIP__DEBUG_PRINT_MACROS

#if ASYNC_TCP_SSL_ENABLED

#include "lwip/opt.h"
Expand All @@ -34,6 +40,13 @@
#include <stdbool.h>
#include <tcp_axtls.h>

// ets_uart_printf is defined in esp8266_undocumented.h, in newer Arduino ESP8266 Core.
extern int ets_uart_printf(const char *format, ...) __attribute__ ((format (printf, 1, 2)));
#include <DebugPrintMacros.h>
#ifndef TCP_SSL_DEBUG
#define TCP_SSL_DEBUG(...) do { (void)0;} while(false)
#endif

uint8_t * default_private_key = NULL;
uint16_t default_private_key_len = 0;

Expand Down Expand Up @@ -377,7 +390,8 @@ int tcp_ssl_read(struct tcp_pcb *tcp, struct pbuf *p) {

do {
read_bytes = ssl_read(fd_data->ssl, &read_buf);
//TCP_SSL_DEBUG("tcp_ssl_ssl_read: %d\n", read_bytes);
TCP_SSL_DEBUG("tcp_ssl_ssl_read: %d\n", read_bytes);

if(read_bytes < SSL_OK) {
if(read_bytes != SSL_CLOSE_NOTIFY) {
TCP_SSL_DEBUG("tcp_ssl_read: read error: %d\n", read_bytes);
Expand All @@ -387,20 +401,31 @@ int tcp_ssl_read(struct tcp_pcb *tcp, struct pbuf *p) {
} else if(read_bytes > 0){
if(fd_data->on_data){
fd_data->on_data(fd_data->arg, tcp, read_buf, read_bytes);
// fd_data may have been freed in callback
fd_data = tcp_ssl_get(tcp);
if(NULL == fd_data)
return SSL_CLOSE_NOTIFY;
}
total_bytes+= read_bytes;
} else {
if(fd_data->handshake != SSL_OK) {
fd_data->handshake = ssl_handshake_status(fd_data->ssl);
if(fd_data->handshake == SSL_OK){
//TCP_SSL_DEBUG("tcp_ssl_read: handshake OK\n");
// fd_data may be freed in callbacks.
int handshake = fd_data->handshake = ssl_handshake_status(fd_data->ssl);
if(handshake == SSL_OK){
TCP_SSL_DEBUG("tcp_ssl_read: handshake OK\n");
if(fd_data->on_handshake)
fd_data->on_handshake(fd_data->arg, fd_data->tcp, fd_data->ssl);
} else if(fd_data->handshake != SSL_NOT_OK){
TCP_SSL_DEBUG("tcp_ssl_read: handshake error: %d\n", fd_data->handshake);
fd_data = tcp_ssl_get(tcp);
if(NULL == fd_data)
return SSL_CLOSE_NOTIFY;
} else if(handshake != SSL_NOT_OK){
TCP_SSL_DEBUG("tcp_ssl_read: handshake error: %d\n", handshake);
if(fd_data->on_error)
fd_data->on_error(fd_data->arg, fd_data->tcp, fd_data->handshake);
return fd_data->handshake;
fd_data->on_error(fd_data->arg, fd_data->tcp, handshake);
return handshake;
// With current code APP gets called twice at onError handler.
// Once here and again after return when handshake != SSL_CLOSE_NOTIFY.
// As always APP must never free resources at onError only at onDisconnect.
}
}
}
Expand Down Expand Up @@ -525,13 +550,13 @@ int ax_port_write(int fd, uint8_t *data, uint16_t len) {
TCP_SSL_DEBUG("ax_port_write: No memory %d (%d)\n", tcp_len, len);
return err;
}
TCP_SSL_DEBUG("ax_port_write: tcp_write error: %d\n", err);
TCP_SSL_DEBUG("ax_port_write: tcp_write error: %ld\n", err);
return err;
} else if (err == ERR_OK) {
//TCP_SSL_DEBUG("ax_port_write: tcp_output: %d / %d\n", tcp_len, len);
err = tcp_output(fd_data->tcp);
if(err != ERR_OK) {
TCP_SSL_DEBUG("ax_port_write: tcp_output err: %d\n", err);
TCP_SSL_DEBUG("ax_port_write: tcp_output err: %ld\n", err);
return err;
}
}
Expand Down