diff --git a/Socket/SecureSocket.cpp b/Socket/SecureSocket.cpp index 3821400..a380cbf 100644 --- a/Socket/SecureSocket.cpp +++ b/Socket/SecureSocket.cpp @@ -1,221 +1,318 @@ /** -* @file SecureSocket.cpp -* @brief implementation of the Secure Socket class -* @author Mohamed Amine Mzoughi -*/ + * @file SecureSocket.cpp + * @brief implementation of the Secure Socket class + * @author Mohamed Amine Mzoughi + */ #ifdef OPENSSL - #include "SecureSocket.h" #include -#ifndef LINUX +#ifdef _WIN32 // to avoid link problems in prod/test program // Update : with the newer versions of OpenSSL, there's no need to include it //#include #endif -ASecureSocket::SecureSocketGlobalInitializer& ASecureSocket::SecureSocketGlobalInitializer::instance() +ASecureSocket::SSLSocket::SSLSocket() + : m_SockFd(INVALID_SOCKET) + , m_pSSL(nullptr) + , m_pCTXSSL(nullptr) + , m_pMTHDSSL(nullptr) { - static SecureSocketGlobalInitializer inst{}; - return inst; } -ASecureSocket::SecureSocketGlobalInitializer::SecureSocketGlobalInitializer() +ASecureSocket::SSLSocket::SSLSocket(SSLSocket&& Sockother) + : m_SockFd(Sockother.m_SockFd) + , m_pSSL(Sockother.m_pSSL) + , m_pCTXSSL(Sockother.m_pCTXSSL) + , m_pMTHDSSL(Sockother.m_pMTHDSSL) { - InitializeSSL(); + Sockother.m_SockFd = INVALID_SOCKET; + Sockother.m_pSSL = nullptr; + Sockother.m_pCTXSSL = nullptr; + Sockother.m_pMTHDSSL = nullptr; } -ASecureSocket::SecureSocketGlobalInitializer::~SecureSocketGlobalInitializer() +ASecureSocket::SSLSocket& ASecureSocket::SSLSocket::operator=(SSLSocket&& Sockother) { - DestroySSL(); + if (this != &Sockother) + { + m_SockFd = Sockother.m_SockFd; + m_pSSL = Sockother.m_pSSL; + m_pCTXSSL = Sockother.m_pCTXSSL; + m_pMTHDSSL = Sockother.m_pMTHDSSL; + + // reset Sockother + Sockother.m_SockFd = INVALID_SOCKET; + Sockother.m_pSSL = nullptr; + Sockother.m_pCTXSSL = nullptr; + Sockother.m_pMTHDSSL = nullptr; + } + return *this; } -/** -* @brief constructor of the Secure Socket -* -* @param oLogger - a callabck to a logger function void(const std::string&) -* @param eSSLVersion - SSL/TLS protocol version -* -*/ -ASecureSocket::ASecureSocket(const LogFnCallback& oLogger, - const OpenSSLProtocol eSSLVersion, - const SettingsFlag eSettings /*= ALL_FLAGS*/) : - ASocket(oLogger, eSettings), - m_eOpenSSLProtocol(eSSLVersion), - m_globalInitializer(SecureSocketGlobalInitializer::instance()) +ASecureSocket::SSLSocket::~SSLSocket() { + Disconnect(); } -/** -* @brief destructor of the secure socket object -* It's a pure virtual destructor but an implementation is provided below. -* this to avoid creating a dummy pure virtual method to transform the class -* to an abstract one. -*/ -ASecureSocket::~ASecureSocket() +void ASecureSocket::SSLSocket::Disconnect() { + SocketClose(m_SockFd); + + if (m_pSSL != nullptr) + { + /* send the close_notify alert to the peer. */ + SSL_shutdown(m_pSSL); // must be called before SSL_free + SSL_free(m_pSSL); + m_pSSL = nullptr; + } + + if (m_pCTXSSL != nullptr) + { + SSL_CTX_free(m_pCTXSSL); + m_pCTXSSL = nullptr; + } } -void ASecureSocket::SetUpCtxClient(SSLSocket& Socket) +bool ASecureSocket::SSLSocket::HasPending() const { - switch (m_eOpenSSLProtocol) - { - default: - case OpenSSLProtocol::TLS: - // Standard Protocol as of 11/2018, OpenSSL will choose highest possible TLS standard between peers - Socket.m_pMTHDSSL = const_cast(TLS_client_method()); - break; + return SSL_has_pending(m_pSSL) == 1; +} - case OpenSSLProtocol::SSL_V23: - Socket.m_pMTHDSSL = const_cast(SSLv23_client_method()); - break; +int ASecureSocket::SSLSocket::PendingBytes() const +{ + return SSL_pending(m_pSSL); +} - #ifndef LINUX - // deprecated in newer versions of OpenSSL - //case OpenSSLProtocol::SSL_V2: - //Socket.m_pMTHDSSL = const_cast(SSLv2_client_method()); - //break; - #endif +std::atomic ASecureSocket::s_iSecureSocketCount = ATOMIC_VAR_INIT(0); - // deprecated - /*case OpenSSLProtocol::SSL_V3: - Socket.m_pMTHDSSL = const_cast(SSLv3_client_method()); - break;*/ +/** + * @brief constructor of the Secure Socket + * + * @param oLogger - a callabck to a logger function void(const std::string&) + * @param eSSLVersion - SSL/TLS protocol version + */ +ASecureSocket::ASecureSocket(const LogFnCallback& oLogger, + const OpenSSLProtocol& eSSLVersion, + const SettingsFlag& eSettings /*= ALL_FLAGS*/) + : ASocket(oLogger, eSettings) + , m_eOpenSSLProtocol(eSSLVersion) +{ + int expected = 0; + if (s_iSecureSocketCount.compare_exchange_strong(expected, 1)) + { + // Initialize OpenSSL + InitializeSSL(); + } + else + { + s_iSecureSocketCount.fetch_add(1, std::memory_order_relaxed); + } +} - case OpenSSLProtocol::TLS_V1: - Socket.m_pMTHDSSL = const_cast(TLSv1_client_method()); - break; - } - Socket.m_pCTXSSL = SSL_CTX_new(Socket.m_pMTHDSSL); +/** + * @brief destructor of the secure socket object + * It's a pure virtual destructor but an implementation is provided below. + * this to avoid creating a dummy pure virtual method to transform the class + * to an abstract one. + */ +ASecureSocket::~ASecureSocket() +{ + int value = s_iSecureSocketCount.load(std::memory_order_relaxed); + + do + { + if (value == 0) + { + return; + } + + if (s_iSecureSocketCount.compare_exchange_weak(value, value - 1)) + { + if (value == 1) + { + DestroySSL(); + } + return; + } + } while (true); } -void ASecureSocket::SetUpCtxServer(SSLSocket& Socket) +void ASecureSocket::InitializeSSL() { - switch (m_eOpenSSLProtocol) - { - default: - case OpenSSLProtocol::TLS: - // Standard Protocol as of 11/2018, OpenSSL will choose highest possible TLS standard between peers - Socket.m_pMTHDSSL = const_cast(TLS_server_method()); - break; + /* Initialize malloc, free, etc for OpenSSL's use. */ + //CRYPTO_malloc_init(); - #ifndef LINUX - //case OpenSSLProtocol::SSL_V2: - //Socket.m_pMTHDSSL = const_cast(SSLv2_server_method()); - //break; - #endif + /* Initialize OpenSSL's SSL libraries: load encryption & hash algorithms for SSL */ + (void)SSL_library_init(); //always returns 1 - // deprecated - /*case OpenSSLProtocol::SSL_V3: - Socket.m_pMTHDSSL = const_cast(SSLv3_server_method()); - break;*/ + /* Load the error strings for good error reporting */ + SSL_load_error_strings(); - case OpenSSLProtocol::TLS_V1: - Socket.m_pMTHDSSL = const_cast(TLSv1_server_method()); - break; + /* Load BIO error strings. */ + //ERR_load_BIO_strings(); - case OpenSSLProtocol::SSL_V23: - Socket.m_pMTHDSSL = const_cast(SSLv23_server_method()); - break; - } - Socket.m_pCTXSSL = SSL_CTX_new(Socket.m_pMTHDSSL); + /* Load all available encryption algorithms. */ + OpenSSL_add_all_algorithms(); } -void ASecureSocket::InitializeSSL() +void ASecureSocket::DestroySSL() { - /* Initialize malloc, free, etc for OpenSSL's use. */ - //CRYPTO_malloc_init(); - - /* Initialize OpenSSL's SSL libraries: load encryption & hash algorithms for SSL */ - SSL_library_init(); + ERR_free_strings(); + EVP_cleanup(); +} - /* Load the error strings for good error reporting */ - SSL_load_error_strings(); +bool ASecureSocket::SetUpCtxClient(SSLSocket& Socket) +{ + switch (m_eOpenSSLProtocol) + { + default: + case OpenSSLProtocol::TLS: + // Standard Protocol as of 11/2018, OpenSSL will choose highest possible TLS standard between peers + Socket.m_pMTHDSSL = const_cast(TLS_client_method()); + break; + + case OpenSSLProtocol::SSL_V23: + Socket.m_pMTHDSSL = const_cast(SSLv23_client_method()); + break; + +#if 0 +#ifdef _WIN32 + // deprecated in newer versions of OpenSSL + case OpenSSLProtocol::SSL_V2: + Socket.m_pMTHDSSL = const_cast(SSLv2_client_method()); + break; +#endif - /* Load BIO error strings. */ - //ERR_load_BIO_strings(); + // deprecated + case OpenSSLProtocol::SSL_V3: + Socket.m_pMTHDSSL = const_cast(SSLv3_client_method()); + break; +#endif - /* Load all available encryption algorithms. */ - OpenSSL_add_all_algorithms(); + case OpenSSLProtocol::TLS_V1: + Socket.m_pMTHDSSL = const_cast(TLSv1_client_method()); + break; + } + + if (Socket.m_pMTHDSSL == nullptr) + { + //SocketLog("[WARN ]ASecureSocket, XXX_client_method failed[%lu:%s]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr)); + } + + Socket.m_pCTXSSL = SSL_CTX_new(Socket.m_pMTHDSSL); + if (Socket.m_pCTXSSL == nullptr) + { + //SocketLog("[ERROR]ASecureSocket, client SSL_CTX_new failed[%lu:%s]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr)); + //ERR_print_errors_fp(stdout); + return false; + } + + return true; } -void ASecureSocket::DestroySSL() +bool ASecureSocket::SetUpCtxServer(SSLSocket& Socket) { - ERR_free_strings(); - EVP_cleanup(); + switch (m_eOpenSSLProtocol) + { + default: + case OpenSSLProtocol::TLS: + // Standard Protocol as of 11/2018, OpenSSL will choose highest possible TLS standard between peers + Socket.m_pMTHDSSL = const_cast(TLS_server_method()); + break; + +#if 0 +#ifdef _WIN32 + case OpenSSLProtocol::SSL_V2: + Socket.m_pMTHDSSL = const_cast(SSLv2_server_method()); + break; +#endif + + // deprecated + case OpenSSLProtocol::SSL_V3: + Socket.m_pMTHDSSL = const_cast(SSLv3_server_method()); + break; +#endif + + case OpenSSLProtocol::TLS_V1: + Socket.m_pMTHDSSL = const_cast(TLSv1_server_method()); + break; + + case OpenSSLProtocol::SSL_V23: + Socket.m_pMTHDSSL = const_cast(SSLv23_server_method()); + break; + } + + if (Socket.m_pMTHDSSL == nullptr) + { + //SocketLog("[WARN ]ASecureSocket, XXX_server_method failed[%lu:%s]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr)); + } + + Socket.m_pCTXSSL = SSL_CTX_new(Socket.m_pMTHDSSL); + if (Socket.m_pCTXSSL == nullptr) + { + //SocketLog("[ERROR]ASecureSocket, server SSL_CTX_new failed[%lu:%s]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr)); + return false; + } + + //SSL_CTX_set_verify(Socket.m_pCTXSSL, SSL_VERIFY_NONE, nullptr); + return true; } void ASecureSocket::ShutdownSSL(SSLSocket& SSLSock) { - if (SSLSock.m_pSSL != nullptr) - { - /* send the close_notify alert to the peer. */ - SSL_shutdown(SSLSock.m_pSSL); // must be called before SSL_free - SSL_free(SSLSock.m_pSSL); - SSL_CTX_free(SSLSock.m_pCTXSSL); - - SSLSock.m_pSSL = nullptr; - } + SSLSock.Disconnect(); } const char* ASecureSocket::GetSSLErrorString(int iErrorCode) { - switch (iErrorCode) - { - case SSL_ERROR_NONE: - return "The TLS/SSL I/O operation completed."; - break; - - case SSL_ERROR_ZERO_RETURN: - return "The TLS/SSL connection has been closed."; - break; - - case SSL_ERROR_WANT_READ: - return "The read operation did not complete; " - "the same TLS/SSL I/O function should be called again later."; - break; - - case SSL_ERROR_WANT_WRITE: - return "The write operation did not complete; " - "the same TLS/SSL I/O function should be called again later."; - break; - - case SSL_ERROR_WANT_CONNECT: - return "The connect operation did not complete; " - "the same TLS/SSL I/O function should be called again later."; - break; - - case SSL_ERROR_WANT_ACCEPT: - return "The accept operation did not complete; " - "the same TLS/SSL I/O function should be called again later."; - break; - - case SSL_ERROR_WANT_X509_LOOKUP: - return "The operation did not complete because an application callback set" - " by SSL_CTX_set_client_cert_cb() has asked to be called again. " - "The TLS/SSL I/O function should be called again later."; - break; - - case SSL_ERROR_SYSCALL: - return "Some I/O error occurred. The OpenSSL error queue may contain" - " more information on the error."; - break; - - case SSL_ERROR_SSL: - return "A failure in the SSL library occurred, usually a protocol error. " - "The OpenSSL error queue contains more information on the error."; - break; - - default: - return "Unknown error !"; - break; - } + switch (iErrorCode) + { + case SSL_ERROR_NONE: + return "The TLS/SSL I/O operation completed."; + + case SSL_ERROR_ZERO_RETURN: + return "The TLS/SSL connection has been closed."; + + case SSL_ERROR_WANT_READ: + return "The read operation did not complete; " + "the same TLS/SSL I/O function should be called again later."; + + case SSL_ERROR_WANT_WRITE: + return "The write operation did not complete; " + "the same TLS/SSL I/O function should be called again later."; + + case SSL_ERROR_WANT_CONNECT: + return "The connect operation did not complete; " + "the same TLS/SSL I/O function should be called again later."; + + case SSL_ERROR_WANT_ACCEPT: + return "The accept operation did not complete; " + "the same TLS/SSL I/O function should be called again later."; + + case SSL_ERROR_WANT_X509_LOOKUP: + return "The operation did not complete because an application callback set" + " by SSL_CTX_set_client_cert_cb() has asked to be called again. " + "The TLS/SSL I/O function should be called again later."; + + case SSL_ERROR_SYSCALL: + return "Some I/O error occurred. The OpenSSL error queue may contain" + " more information on the error."; + + case SSL_ERROR_SSL: + return "A failure in the SSL library occurred, usually a protocol error. " + "The OpenSSL error queue contains more information on the error."; + + default: + return "Unknown error !"; + } } int ASecureSocket::AlwaysTrueCallback(X509_STORE_CTX* pCTX, void* pArg) { - return 1; + return 1; } #endif diff --git a/Socket/SecureSocket.h b/Socket/SecureSocket.h index dfac9ad..2c2fcb1 100644 --- a/Socket/SecureSocket.h +++ b/Socket/SecureSocket.h @@ -1,15 +1,16 @@ -/* -* @file SecureSocket.h -* @brief Abstract class to perform OpenSSL API global operations -* -* @author Mohamed Amine Mzoughi -* @date 2017-02-16 -*/ +/** + * @file SecureSocket.h + * @brief Abstract class to perform OpenSSL API global operations + * + * @author Mohamed Amine Mzoughi + * @date 2017-02-16 + */ #ifdef OPENSSL #ifndef INCLUDE_ASECURESOCKET_H_ #define INCLUDE_ASECURESOCKET_H_ +#include #ifdef OPENSSL #include #include @@ -21,137 +22,97 @@ class ASecureSocket : public ASocket { public: - enum class OpenSSLProtocol - { - #ifndef LINUX - //SSL_V2, // deprecated - #endif - //SSL_V3, // deprecated - TLS_V1, - SSL_V23, /* There is no SSL protocol version named SSLv23. The SSLv23_method() API - and its variants choose SSLv2, SSLv3, or TLSv1 for compatibility with the peer. */ - TLS // Standard Protocol as of 11/2018, OpenSSL will choose highest possible TLS standard between peers - }; - - struct SSLSocket - { - SSLSocket() : - m_SockFd(INVALID_SOCKET), - m_pSSL(nullptr), - m_pCTXSSL(nullptr), - m_pMTHDSSL(nullptr) - { - } - - // copy constructor and assignment operator are disabled - SSLSocket(const SSLSocket&) = delete; - SSLSocket& operator=(const SSLSocket&) = delete; - - // move constructor - SSLSocket(SSLSocket&& Sockother) : - m_SockFd(Sockother.m_SockFd), - m_pSSL(Sockother.m_pSSL), - m_pCTXSSL(Sockother.m_pCTXSSL), - m_pMTHDSSL(Sockother.m_pMTHDSSL) - { - Sockother.m_SockFd = INVALID_SOCKET; - Sockother.m_pSSL = nullptr; - Sockother.m_pCTXSSL = nullptr; - Sockother.m_pMTHDSSL = nullptr; - } - - // move assignment operator - SSLSocket& operator=(SSLSocket&& Sockother) - { - if (this != &Sockother) - { - m_SockFd = Sockother.m_SockFd; - m_pSSL = Sockother.m_pSSL; - m_pCTXSSL = Sockother.m_pCTXSSL; - m_pMTHDSSL = Sockother.m_pMTHDSSL; - - // reset Sockother - Sockother.m_SockFd = INVALID_SOCKET; - Sockother.m_pSSL = nullptr; - Sockother.m_pCTXSSL = nullptr; - Sockother.m_pMTHDSSL = nullptr; - } - return *this; - } - - Socket m_SockFd; - SSL* m_pSSL; - SSL_CTX* m_pCTXSSL; // SSL Context Structure - SSL_METHOD* m_pMTHDSSL; // used to create an SSL_CTX - }; - - /* Please provide your logger thread-safe routine, otherwise, you can turn off - * error log messages printing by not using the flag ALL_FLAGS or ENABLE_LOG */ - explicit ASecureSocket(const LogFnCallback& oLogger, - const OpenSSLProtocol eSSLVersion = OpenSSLProtocol::TLS, - const SettingsFlag eSettings = ALL_FLAGS); - virtual ~ASecureSocket() = 0; - - /* - * For the SSL server: - * Server's own certificate (mandatory) - * CA certificate (optional) - * - * For the SSL client: - * CA certificate (mandatory) - * Client's own certificate (optional) - */ - inline const std::string& GetSSLCertAuth() { return m_strCAFile; } - inline void SetSSLCerthAuth(const std::string& strPath) { m_strCAFile = strPath; } - - inline void SetSSLCertFile(const std::string& strPath) { m_strSSLCertFile = strPath; } - inline const std::string& GetSSLCertFile() const { return m_strSSLCertFile; } - - inline void SetSSLKeyFile(const std::string& strPath) { m_strSSLKeyFile = strPath; } - inline const std::string& GetSSLKeyFile() const { return m_strSSLKeyFile; } - - //void SetSSLKeyPassword(const std::string& strPwd) { m_strSSLKeyPwd = strPwd; } - //const std::string& GetSSLKeyPwd() const { return m_strSSLKeyPwd; } + enum class OpenSSLProtocol + { +#ifdef _WIN32 + //SSL_V2, // deprecated +#endif + //SSL_V3, // deprecated + TLS_V1, + SSL_V23, /* There is no SSL protocol version named SSLv23. The SSLv23_method() API + and its variants choose SSLv2, SSLv3, or TLSv1 for compatibility with the peer. */ + TLS // Standard Protocol as of 11/2018, OpenSSL will choose highest possible TLS standard between peers + }; + + struct SSLSocket + { + SSLSocket(); + ~SSLSocket(); + + // copy constructor and assignment operator are disabled + SSLSocket(const SSLSocket&) = delete; + SSLSocket& operator=(const SSLSocket&) = delete; + + // move constructor + SSLSocket(SSLSocket&& Sockother); + // move assignment operator + SSLSocket& operator=(SSLSocket&& Sockother); + + void Disconnect(); + + bool HasPending() const; + int PendingBytes() const; + + Socket m_SockFd; + SSL* m_pSSL; + SSL_CTX* m_pCTXSSL; // SSL Context Structure + SSL_METHOD* m_pMTHDSSL; // used to create an SSL_CTX + }; + + /** + * Please provide your logger thread-safe routine, otherwise, you can turn off + * error log messages printing by not using the flag ALL_FLAGS or ENABLE_LOG + */ + explicit ASecureSocket(const LogFnCallback& oLogger, + const OpenSSLProtocol& eSSLVersion = OpenSSLProtocol::TLS, + const SettingsFlag& eSettings = ALL_FLAGS); + virtual ~ASecureSocket(); + + /** + * For the SSL server: + * Server's own certificate (mandatory) + * CA certificate (optional) + * + * For the SSL client: + * CA certificate (mandatory) + * Client's own certificate (optional) + */ + inline const std::string& GetSSLCertAuth() { return m_strCAFile; } + inline void SetSSLCerthAuth(const std::string& strPath) { m_strCAFile = strPath; } + + inline void SetSSLCertFile(const std::string& strPath) { m_strSSLCertFile = strPath; } + inline const std::string& GetSSLCertFile() const { return m_strSSLCertFile; } + + inline void SetSSLKeyFile(const std::string& strPath) { m_strSSLKeyFile = strPath; } + inline const std::string& GetSSLKeyFile() const { return m_strSSLKeyFile; } + + //void SetSSLKeyPassword(const std::string& strPwd) { m_strSSLKeyPwd = strPwd; } + //const std::string& GetSSLKeyPwd() const { return m_strSSLKeyPwd; } protected: - // object methods - void SetUpCtxClient(SSLSocket& Socket); - void SetUpCtxServer(SSLSocket& Socket); - //void SetUpCtxCombined(SSLSocket& Socket); - - // class methods - static void ShutdownSSL(SSLSocket& SSLSocket); - static const char* GetSSLErrorString(int iErrorCode); - static int AlwaysTrueCallback(X509_STORE_CTX* pCTX, void* pArg); - - // non-static/object members - OpenSSLProtocol m_eOpenSSLProtocol; - std::string m_strCAFile; - std::string m_strSSLCertFile; - std::string m_strSSLKeyFile; - //std::string m_strSSLKeyPwd; - -private: - friend class SecureSocketGlobalInitializer; - class SecureSocketGlobalInitializer { - public: - static SecureSocketGlobalInitializer& instance(); - - SecureSocketGlobalInitializer(SecureSocketGlobalInitializer const&) = delete; - SecureSocketGlobalInitializer(SecureSocketGlobalInitializer&&) = delete; + // object methods + bool SetUpCtxClient(SSLSocket& Socket); + bool SetUpCtxServer(SSLSocket& Socket); + //void SetUpCtxCombined(SSLSocket& Socket); - SecureSocketGlobalInitializer& operator=(SecureSocketGlobalInitializer const&) = delete; - SecureSocketGlobalInitializer& operator=(SecureSocketGlobalInitializer&&) = delete; + // class methods + static void ShutdownSSL(SSLSocket& SSLSocket); + static const char* GetSSLErrorString(int iErrorCode); + static int AlwaysTrueCallback(X509_STORE_CTX* pCTX, void* pArg); - ~SecureSocketGlobalInitializer(); - - private: - SecureSocketGlobalInitializer(); - }; - SecureSocketGlobalInitializer& m_globalInitializer; +private: + static void InitializeSSL(); + static void DestroySSL(); - static void InitializeSSL(); - static void DestroySSL(); +protected: + // non-static/object members + OpenSSLProtocol m_eOpenSSLProtocol; + std::string m_strCAFile; + std::string m_strSSLCertFile; + std::string m_strSSLKeyFile; + //std::string m_strSSLKeyPwd; + + static std::atomic s_iSecureSocketCount; // Count of the actual secure socket sessions }; #endif diff --git a/Socket/Socket.cpp b/Socket/Socket.cpp index 15ec3bc..70e0454 100644 --- a/Socket/Socket.cpp +++ b/Socket/Socket.cpp @@ -1,209 +1,394 @@ /** -* @file Socket.cpp -* @brief implementation of the Socket class -* @author Mohamed Amine Mzoughi -*/ + * @file Socket.cpp + * @brief implementation of the Socket class + * @author Mohamed Amine Mzoughi + */ #include "Socket.h" - +#include +#include // va_start, etc. +#include // snprintf #include #include -#ifdef WINDOWS -WSADATA ASocket::s_wsaData; +#ifdef _WIN32 +// Static members initialization +std::atomic ASocket::s_iSocketCount = ATOMIC_VAR_INIT(0); +WSADATA ASocket::s_wsaData{}; #endif -ASocket::SocketGlobalInitializer& ASocket::SocketGlobalInitializer::instance() +/** + * @brief constructor of the Socket + * + * @param Logger - a callabck to a logger function void(const std::string&) + * + */ +ASocket::ASocket(const LogFnCallback& oLogger, SettingsFlag eSettings /*= ALL_FLAGS*/) + : m_oLog(oLogger) + , m_eSettingsFlags(eSettings) { - static SocketGlobalInitializer inst{}; - return inst; +#ifdef _WIN32 + int expected = 0; + if (s_iSocketCount.compare_exchange_strong(expected, 1)) + { + InitializeEnvironment(); + } + else + { + s_iSocketCount.fetch_add(1, std::memory_order_relaxed); + } +#endif } -ASocket::SocketGlobalInitializer::SocketGlobalInitializer() +/** + * @brief destructor of the socket object + * It's a pure virtual destructor but an implementation is provided below. + * this to avoid creating a dummy pure virtual method to transform the class + * to an abstract one. + */ +ASocket::~ASocket() { - // In windows, this will init the winsock DLL stuff -#ifdef WINDOWS - // MAKEWORD(2,2) version 2.2 of Winsock - int iWinSockInitResult = WSAStartup(MAKEWORD(2, 2), &s_wsaData); - - if (iWinSockInitResult != 0) - { - std::cerr << ASocket::StringFormat("[TCPClient][Error] WSAStartup failed : %d", iWinSockInitResult); - } +#ifdef _WIN32 + int value = s_iSocketCount.load(std::memory_order_relaxed); + + do + { + if (value == 0) + { + return; + } + + if (s_iSocketCount.compare_exchange_weak(value, value - 1)) + { + if (value == 1) + { + UnInitializeEnvironment(); + } + return; + } + } while(true); #endif } -ASocket::SocketGlobalInitializer::~SocketGlobalInitializer() +#ifdef _WIN32 +bool ASocket::InitializeEnvironment() +{ + // In windows, this will init the winsock DLL stuff + // MAKEWORD(2,2) version 2.2 of Winsock + int iWinSockInitResult = WSAStartup(MAKEWORD(2, 2), &s_wsaData); + if (iWinSockInitResult != NO_ERROR) + { + //SocketLog("[ERROR]ASocket, WSAStartup failed[%d:%s]", iWinSockInitResult, strerror(iWinSockInitResult)); + return false; + } + + if (LOBYTE(s_wsaData.wVersion) != 2 || HIBYTE(s_wsaData.wVersion) != 2) + { + //SocketLog("[ERROR]ASocket, could not find a usable version of winsock.dll[%x]", s_wsaData.wVersion); + return false; + } + + return true; +} + +void ASocket::UnInitializeEnvironment() +{ + /* call WSACleanup when done using the Winsock dll */ + WSACleanup(); +} +#endif + +int ASocket::GetSocketError() { -#ifdef WINDOWS - /* call WSACleanup when done using the Winsock dll */ - WSACleanup(); +#ifdef _WIN32 + return WSAGetLastError(); +#else + return errno; #endif } -/** -* @brief constructor of the Socket -* -* @param Logger - a callabck to a logger function void(const std::string&) -* -*/ -ASocket::ASocket(const LogFnCallback& oLogger, - const SettingsFlag eSettings /*= ALL_FLAGS*/) : - m_oLog(oLogger), - m_eSettingsFlags(eSettings), - m_globalInitializer(SocketGlobalInitializer::instance()) +char* ASocket::GaiStrerror(int ecode) { +#ifdef _WIN32 + return gai_strerrorA(ecode); +#else + return gai_strerror(ecode); +#endif +} +void ASocket::SocketClose(Socket& sd) +{ + if (sd == INVALID_SOCKET) + { +#ifdef _WIN32 + closesocket(sd); +#else + close(sd); +#endif + sd = INVALID_SOCKET; + } } /** -* @brief destructor of the socket object -* It's a pure virtual destructor but an implementation is provided below. -* this to avoid creating a dummy pure virtual method to transform the class -* to an abstract one. -*/ -ASocket::~ASocket() + * @brief returns a formatted string + * + * @param [in] strFormat string with one or many format specifiers + * @param [in] parameters to be placed in the format specifiers of strFormat + * + * @retval string formatted string + */ +std::string ASocket::StringFormat(const char* fmt, ...) { - + if (fmt == NULL) + { + return std::string(); + } + + va_list args; + va_start(args, fmt); + size_t len = std::vsnprintf(NULL, 0, fmt, args); + va_end(args); + std::vector vec(len + 1); + va_start(args, fmt); + std::vsnprintf(&vec[0], len + 1, fmt, args); + vec[len] = '\0'; + va_end(args); + return std::string(vec.data()); } /** -* @brief returns a formatted string -* -* @param [in] strFormat string with one or many format specifiers -* @param [in] parameters to be placed in the format specifiers of strFormat -* -* @retval string formatted string + * @brief waits for a socket's read status change + * + * @param [in] sd socket descriptor to be selected + * @param [in] msec waiting period in milliseconds, a value of 0 implies no timeout + * + * @retval int 0 on timeout, -1 on error and 1 on success. */ -std::string ASocket::StringFormat(const std::string strFormat, ...) +int ASocket::SelectSocket(Socket sd, size_t msec/* = ACCEPT_WAIT_INF_DELAY*/) { - va_list args; - va_start (args, strFormat); - size_t len = std::vsnprintf(NULL, 0, strFormat.c_str(), args); - va_end (args); - std::vector vec(len + 1); - va_start (args, strFormat); - std::vsnprintf(&vec[0], len + 1, strFormat.c_str(), args); - va_end (args); - return &vec[0]; + if (sd == INVALID_SOCKET) + { + return -1; + } + + struct timeval tval{}; + struct timeval* tvalptr = nullptr; + + if (msec != ACCEPT_WAIT_INF_DELAY) + { + tval.tv_sec = (long)msec / 1000; + tval.tv_usec = (msec % 1000) * 1000; + tvalptr = &tval; + } + + fd_set fd_reads{}; + FD_ZERO(&fd_reads); + FD_SET(sd, &fd_reads); + +#ifdef _WIN32 + Socket max_fd = 0; +#else + Socket max_fd = sd + 1; +#endif + + // block until socket is readable. + int res = select((int)max_fd, &fd_reads, nullptr, nullptr, tvalptr); + if (res == SOCKET_ERROR) + { + if (SOCKET_ERR_SELECT_RETRIABLE(GetSocketError())) + { + return 0; + } + + //SocketLog("[ERROR]ASocket, select failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), max_fd); + return -1; + } + + if (res == 0) + { + return 0; + } + +#if defined(__unix__) && defined(BSD) + if (!FD_ISSET(sd, &fd_reads)) + { + return 0; + } +#endif + + assert(FD_ISSET(sd, &fd_reads)); + assert(res == 1); + return 1; } /** -* @brief waits for a socket's read status change -* -* @param [in] sd socket descriptor to be selected -* @param [in] msec waiting period in milliseconds, a value of 0 implies no timeout -* -* @retval int 0 on timeout, -1 on error and 1 on success. -*/ -int ASocket::SelectSocket(const ASocket::Socket sd, const size_t msec) + * @brief waits for a set of sockets read status change + * + * @param [in] pSocketsToSelect pointer to an array of socket descriptors to be selected + * @param [in] count elements count of pSocketsToSelect + * @param [in] msec waiting period in milliseconds, a value of 0 implies no timeout + * @param [out] selectedIndex index of the socket that is ready to be read + * + * @retval int 0 on timeout, -1 on error and 1 on success. + */ +int ASocket::SelectSockets(const Socket* pSocketsToSelect, size_t count, size_t& selectedIndex, size_t msec/* = ACCEPT_WAIT_INF_DELAY*/) { - if (sd < 0) - { - return -1; - } + if (pSocketsToSelect == nullptr || count == 0) + { + return -1; + } + + struct timeval tval{}; + struct timeval* tvalptr = nullptr; + if (msec != ACCEPT_WAIT_INF_DELAY) + { + tval.tv_sec = (long)msec / 1000; + tval.tv_usec = (msec % 1000) * 1000; + tvalptr = &tval; + } + + fd_set fd_reads{}; + FD_ZERO(&fd_reads); + +#ifdef _WIN32 + Socket max_fd = 0; +#else + Socket max_fd = -1; +#endif - struct timeval tval; - struct timeval* tvalptr = nullptr; - fd_set rset; - int res; +#ifndef _WIN32 + for (size_t i = 0; i < count; i++) + { + if (pSocketsToSelect[i] != INVALID_SOCKET) + { + FD_SET(pSocketsToSelect[i], &fd_reads); + + if (pSocketsToSelect[i] > max_fd) + { + max_fd = pSocketsToSelect[i]; + } + } + } +#endif - if (msec > 0) - { - tval.tv_sec = msec / 1000; - tval.tv_usec = (msec % 1000) * 1000; - tvalptr = &tval; - } +#ifndef _WIN32 + max_fd += 1; +#endif - FD_ZERO(&rset); - FD_SET(sd, &rset); + // block until one socket is ready to read. + int res = select((int)max_fd, &fd_reads, nullptr, nullptr, tvalptr); + if (res == SOCKET_ERROR) + { + if (SOCKET_ERR_SELECT_RETRIABLE(GetSocketError())) + { + return 0; + } + + //SocketLog("[ERROR]ASocket, select failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), max_fd); + return -1; + } + + if (res == 0) + { + return 0; + } + + // find the first socket which has some activity. +#if defined(__unix__) && defined(BSD) + Socket firstSocket = INVALID_SOCKET; +#endif + for (size_t i = 0; i < count; ++i) + { + if (FD_ISSET(pSocketsToSelect[i], &fd_reads)) + { + selectedIndex = i; +#if defined(__unix__) && defined(BSD) + firstSocket = pSocketsToSelect[i]; +#endif + break; + } + } + +#if defined(__unix__) && defined(BSD) + if (firstSocket == INVALID_SOCKET) + { + return 0; + } +#endif - // block until socket is readable. - res = select(sd + 1, &rset, nullptr, nullptr, tvalptr); + return 1; +} - if (res <= 0) - return res; +bool ASocket::SetRcvTimeout(Socket sd, unsigned int msec_timeout) +{ +#ifndef _WIN32 + struct timeval t = TimevalFromMsec(msec_timeout); + return SetRcvTimeout(t); +#else + if (setsockopt(sd, SOL_SOCKET, SO_RCVTIMEO, (char*)&msec_timeout, sizeof(msec_timeout)) == SOCKET_ERROR) + { + //SocketLog("[ERROR]ASocket, setsockopt SOL_SOCKET SO_RCVTIMEO failed[%d:%s][%d][%u]", GetSocketError(), strerror(GetSocketError()), sd, msec_timeout); + return false; + } + + return true; +#endif +} - if (!FD_ISSET(sd, &rset)) - return -1; +#ifndef _WIN32 +bool ASocket::SetRcvTimeout(Socket sd, const struct timeval& timeout) +{ + if (setsockopt(sd, SOL_SOCKET, SO_RCVTIMEO, (char*)&timeout, sizeof(timeout)) == SOCKET_ERROR) + { + //SocketLog("[ERROR]ASocket, setsockopt SOL_SOCKET SO_RCVTIMEO failed[%d:%s][%d][%u:%u]", GetSocketError(), strerror(GetSocketError()), sd, timeout.tv_sec, timeout.tv_usec); + return false; + } - return 1; + return true; } +#endif -/** -* @brief waits for a set of sockets read status change -* -* @param [in] pSocketsToSelect pointer to an array of socket descriptors to be selected -* @param [in] count elements count of pSocketsToSelect -* @param [in] msec waiting period in milliseconds, a value of 0 implies no timeout -* @param [out] selectedIndex index of the socket that is ready to be read -* -* @retval int 0 on timeout, -1 on error and 1 on success. -*/ -int ASocket::SelectSockets(const ASocket::Socket* pSocketsToSelect, const size_t count, - const size_t msec, size_t& selectedIndex) +bool ASocket::SetSndTimeout(Socket sd, unsigned int msec_timeout) { - if (!pSocketsToSelect || count == 0) - { - return -1; - } - - fd_set rset; - int res = -1; - - struct timeval tval; - struct timeval* tvalptr = nullptr; - if (msec > 0) - { - tval.tv_sec = msec / 1000; - tval.tv_usec = (msec % 1000) * 1000; - tvalptr = &tval; - } - - FD_ZERO(&rset); - - int max_fd = -1; - for (size_t i = 0; i < count; i++) - { - FD_SET(pSocketsToSelect[i], &rset); - - if (pSocketsToSelect[i] > max_fd) - { - max_fd = pSocketsToSelect[i]; - } - } - - // block until one socket is ready to read. - res = select(max_fd + 1, &rset, nullptr, nullptr, tvalptr); - - if (res <= 0) - return res; - - // find the first socket which has some activity. - for (size_t i = 0; i < count; i++) - { - if (FD_ISSET(pSocketsToSelect[i], &rset)) - { - selectedIndex = i; - return 1; - } - } - - return -1; +#ifndef _WIN32 + struct timeval t = TimevalFromMsec(msec_timeout); + return SetSndTimeout(t); +#else + if (setsockopt(sd, SOL_SOCKET, SO_SNDTIMEO, (char*)&msec_timeout, sizeof(msec_timeout)) == SOCKET_ERROR) + { + //SocketLog("[ERROR]ASocket, setsockopt SOL_SOCKET SO_SNDTIMEO failed[%d:%s][%d][%u]", GetSocketError(), strerror(GetSocketError()), sd, msec_timeout); + return false; + } + + return true; +#endif } -/** -* @brief converts a value representing milliseconds into a struct timeval -* -* @param [time_msec] a time value in milliseconds -* -* @retval time_msec converted to struct timeval -*/ -struct timeval ASocket::TimevalFromMsec(unsigned int time_msec){ - struct timeval t; +#ifndef _WIN32 +bool ASocket::SetSndTimeout(Socket sd, const struct timeval& timeout) +{ + if (setsockopt(sd, SOL_SOCKET, SO_SNDTIMEO, (char*)&timeout, sizeof(timeout)) == SOCKET_ERROR) + { + //SocketLog("[ERROR]ASocket, setsockopt SOL_SOCKET SO_SNDTIMEO failed[%d:%s][%d][%u:%u]", GetSocketError(), strerror(GetSocketError()), sd, timeout.tv_sec, timeout.tv_usec); + return false; + } + + return true; +} +#endif - t.tv_sec = time_msec / 1000; - t.tv_usec = (time_msec % 1000) * 1000; +/** + * @brief converts a value representing milliseconds into a struct timeval + * + * @param [time_msec] a time value in milliseconds + * + * @retval time_msec converted to struct timeval + */ +struct timeval ASocket::TimevalFromMsec(unsigned int time_msec) +{ + struct timeval t{}; + t.tv_sec = (long)time_msec / 1000; + t.tv_usec = ((long)time_msec % 1000) * 1000; - return t; + return t; } diff --git a/Socket/Socket.h b/Socket/Socket.h index 0271d16..1ef8cad 100644 --- a/Socket/Socket.h +++ b/Socket/Socket.h @@ -1,117 +1,173 @@ -/* -* @file Socket.h -* @brief Abstract class to perform API global operations -* -* @author Mohamed Amine Mzoughi -* @date 2017-02-10 -*/ +/** + * @file Socket.h + * @brief Abstract class to perform API global operations + * + * @author Mohamed Amine Mzoughi + * @date 2017-02-10 + */ #ifndef INCLUDE_ASOCKET_H_ #define INCLUDE_ASOCKET_H_ -#include // snprintf +#include +#include +#ifdef _WIN32 +#include +#endif +#include #include #include -#include -#include // va_start, etc. +#include #include - -#ifdef WINDOWS -#include -#include +#ifdef _WIN32 +#include +#include // Need to link with Ws2_32.lib #pragma comment(lib,"WS2_32.lib") #else #include -#include #include #include -#include -#include -#include -#include #include #include #include #endif -#include -#define ACCEPT_WAIT_INF_DELAY std::numeric_limits::max() +#define ACCEPT_WAIT_INF_DELAY (std::numeric_limits::max)() -class ASocket -{ -public: - // Public definitions - //typedef std::function ProgressFnCallback; - typedef std::function LogFnCallback; +#ifndef _WIN32 +#if EAGAIN == EWOULDBLOCK +#define SOCKET_ERR_IS_EAGAIN(e) ((e) == EAGAIN) +#else +#define SOCKET_ERR_IS_EAGAIN(e) ((e) == EAGAIN || (e) == EWOULDBLOCK) +#endif - // socket file descriptor id - #ifdef WINDOWS - typedef SOCKET Socket; - #else - typedef int Socket; - #define INVALID_SOCKET -1 - #endif +#define SOCKET_ERR_SELECT_RETRIABLE(e) \ + (SOCKET_ERR_IS_EAGAIN(e) || (e) == EINTR || (e) == ENOMEM) - enum SettingsFlag - { - NO_FLAGS = 0x00, - ENABLE_LOG = 0x01, - ALL_FLAGS = 0xFF - }; +#define SOCKET_ERR_RW_RETRIABLE(e) \ + (SOCKET_ERR_IS_EAGAIN(e) || (e) == EINTR || (e) == ENOMEM) - /* Please provide your logger thread-safe routine, otherwise, you can turn off - * error log messages printing by not using the flag ALL_FLAGS or ENABLE_LOG */ - explicit ASocket(const LogFnCallback& oLogger, - const SettingsFlag eSettings = ALL_FLAGS); - virtual ~ASocket() = 0; +#define SOCKET_ERR_CONNECT_RETRIABLE(e) \ + (SOCKET_ERR_IS_EAGAIN(e) || (e) == EINTR || (e) == EINPROGRESS || (e) == EALREADY) - static int SelectSockets(const Socket* pSocketsToSelect, const size_t count, - const size_t msec, size_t& selectedIndex); +#define SOCKET_ERR_ACCEPT_RETRIABLE(e) \ + (SOCKET_ERR_IS_EAGAIN(e) || (e) == EINTR || (e) == ECONNABORTED || (e) == EPROTO) - static int SelectSocket(const Socket sd, const size_t msec); +#define SOCKET_ERR_CONNECT_REFUSED(e) \ + ((e) == ECONNREFUSED) - static struct timeval TimevalFromMsec(unsigned int time_msec); +#define SOCKET_ERR_ADDR_INUSE(e) \ + ((e) == EADDRINUSE) - // String Helpers - static std::string StringFormat(const std::string strFormat, ...); +#else +#define SOCKET_ERR_IS_EAGAIN(e) ((e) == WSAEWOULDBLOCK) -protected: - // Log printer callback - /*mutable*/const LogFnCallback m_oLog; +#define SOCKET_ERR_SELECT_RETRIABLE(e) \ + (SOCKET_ERR_IS_EAGAIN(e) || (e) == WSAEINTR || (e) == WSAEFAULT || (e) == WSAEINPROGRESS) - SettingsFlag m_eSettingsFlags; +#define SOCKET_ERR_RW_RETRIABLE(e) \ + (SOCKET_ERR_IS_EAGAIN(e) || (e) == WSAEINTR || (e) == WSAENOBUFS) - #ifdef WINDOWS - static WSADATA s_wsaData; - #endif +#define SOCKET_ERR_CONNECT_RETRIABLE(e) \ + (SOCKET_ERR_IS_EAGAIN(e) || (e) == WSAEINTR || (e) == WSAEINPROGRESS || (e) == WSAEALREADY) -private: - friend class SocketGlobalInitializer; - class SocketGlobalInitializer { - public: - static SocketGlobalInitializer& instance(); +#define SOCKET_ERR_ACCEPT_RETRIABLE(e) \ + (SOCKET_ERR_IS_EAGAIN(e) || (e) == WSAEINTR || (e) == WSAECONNABORTED || (e) == WSAEPROTONOSUPPORT) + +#define SOCKET_ERR_CONNECT_REFUSED(e) \ + ((e) == WSAECONNREFUSED) + +#define SOCKET_ERR_ADDR_INUSE(e) \ + ((e) == WSAEADDRINUSE) + +#endif + +class ASocket +{ +public: + // Public definitions + //typedef std::function ProgressFnCallback; + typedef std::function LogFnCallback; - SocketGlobalInitializer(SocketGlobalInitializer const&) = delete; - SocketGlobalInitializer(SocketGlobalInitializer&&) = delete; + // socket file descriptor id +#ifdef _WIN32 + typedef SOCKET Socket; +#else + typedef int Socket; +#define INVALID_SOCKET -1 +#define SOCKET_ERROR -1 +#endif - SocketGlobalInitializer& operator=(SocketGlobalInitializer const&) = delete; - SocketGlobalInitializer& operator=(SocketGlobalInitializer&&) = delete; + static int GetSocketError(); + static char* GaiStrerror(int ecode); + static void SocketClose(Socket& sd); + + enum SettingsFlag + { + NO_FLAGS = 0x00, + ENABLE_LOG = 0x01, + ALL_FLAGS = 0xFF + }; + + /** + * Please provide your logger thread-safe routine, otherwise, you can turn off + * error log messages printing by not using the flag ALL_FLAGS or ENABLE_LOG + */ + explicit ASocket(const LogFnCallback& oLogger, SettingsFlag eSettings = ALL_FLAGS); + virtual ~ASocket(); + + static int SelectSockets(const Socket* pSocketsToSelect, size_t count, size_t& selectedIndex, size_t msec = ACCEPT_WAIT_INF_DELAY); + static int SelectSocket(Socket sd, size_t msec = ACCEPT_WAIT_INF_DELAY); + + // To disable timeout, set msec_timeout to 0. + static bool SetRcvTimeout(Socket sd, unsigned int msec_timeout); + static bool SetSndTimeout(Socket sd, unsigned int msec_timeout); + +#ifndef _WIN32 + static bool SetRcvTimeout(Socket sd, const struct timeval& timeout); + static bool SetSndTimeout(Socket sd, const struct timeval& timeout); +#endif + + static struct timeval TimevalFromMsec(unsigned int time_msec); - ~SocketGlobalInitializer(); +protected: + // String Helpers + static std::string StringFormat(const char* fmt, ...); - private: - SocketGlobalInitializer(); - }; - SocketGlobalInitializer& m_globalInitializer; +#ifdef _WIN32 +private: + static bool InitializeEnvironment(); + static void UnInitializeEnvironment(); +#endif + +protected: + // Log printer callback + /*mutable*/const LogFnCallback m_oLog; + + SettingsFlag m_eSettingsFlags; + +#ifdef _WIN32 +private: + static WSADATA s_wsaData; + static std::atomic s_iSocketCount; +#endif }; +#define SocketLog(fmt, ...) \ +do { \ + if (m_oLog && (m_eSettingsFlags & ENABLE_LOG)) \ + { \ + m_oLog(StringFormat(fmt, ##__VA_ARGS__)); \ + } \ +} while(0) + class EResolveError : public std::logic_error { public: - explicit EResolveError(const std::string &strMsg) : std::logic_error(strMsg) {} + explicit EResolveError(const std::string& strMsg) : std::logic_error(strMsg) {} }; #endif diff --git a/Socket/TCPClient.cpp b/Socket/TCPClient.cpp index e2438a6..5d266e8 100644 --- a/Socket/TCPClient.cpp +++ b/Socket/TCPClient.cpp @@ -5,419 +5,289 @@ */ #include "TCPClient.h" - -CTCPClient::CTCPClient(const LogFnCallback oLogger, - const SettingsFlag eSettings /*= ALL_FLAGS*/) : - ASocket(oLogger, eSettings), - m_eStatus(DISCONNECTED), - m_pResultAddrInfo(nullptr), - m_ConnectSocket(INVALID_SOCKET) - //m_uRetryCount(0), - //m_uRetryPeriod(0) +#include + +CTCPClient::CTCPClient(const LogFnCallback& oLogger, const SettingsFlag eSettings /*= ALL_FLAGS*/) + : ASocket(oLogger, eSettings) + , m_eStatus(DISCONNECTED) + , m_ConnectSocket(INVALID_SOCKET) + //, m_uRetryCount(0) + //, m_uRetryPeriod(0) + , m_pResultAddrInfo(nullptr) + , m_HintsAddrInfo() { +} +CTCPClient::~CTCPClient() +{ + Disconnect(); } // Method for setting receive timeout. Can be called after Connect -bool CTCPClient::SetRcvTimeout(unsigned int msec_timeout) { -#ifndef WINDOWS - struct timeval t = ASocket::TimevalFromMsec(msec_timeout); - - return this->SetRcvTimeout(t); -#else - int iErr; +bool CTCPClient::SetRcvTimeout(unsigned int msec_timeout) +{ + bool ret_val = ASocket::SetRcvTimeout(m_ConnectSocket, msec_timeout); + if (!ret_val) + { + SocketLog("[ERROR]TCPClient, setsockopt SOL_SOCKET SO_RCVTIMEO failed[%d:%s][%d][%u]", GetSocketError(), strerror(GetSocketError()), m_ConnectSocket, msec_timeout); + } - // it's expecting an int but it doesn't matter... - iErr = setsockopt(m_ConnectSocket, SOL_SOCKET, SO_RCVTIMEO, (char*)&msec_timeout, sizeof(struct timeval)); - if (iErr < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPClient::SetRcvTimeout : Socket error in SO_RCVTIMEO call to setsockopt."); + return ret_val; +} - return false; +#ifndef _WIN32 +bool CTCPClient::SetRcvTimeout(const struct timeval& timeout) +{ + bool ret_val = ASocket::SetRcvTimeout(m_ConnectSocket, timeout); + if (!ret_val) + { + SocketLog("[ERROR]TCPClient, setsockopt SOL_SOCKET SO_RCVTIMEO failed[%d:%s][%d][%u:%u]", GetSocketError(), strerror(GetSocketError()), m_ConnectSocket, timeout.tv_sec, timeout.tv_usec); } - return true; -#endif + return ret_val; } +#endif -#ifndef WINDOWS -bool CTCPClient::SetRcvTimeout(struct timeval timeout) { - int iErr; +// Method for setting send timeout. Can be called after Connect +bool CTCPClient::SetSndTimeout(unsigned int msec_timeout) +{ + bool ret_val = ASocket::SetSndTimeout(m_ConnectSocket, msec_timeout); + if (!ret_val) + { + SocketLog("[ERROR]TCPClient, setsockopt SOL_SOCKET SO_SNDTIMEO failed[%d:%s][%d][%u]", GetSocketError(), strerror(GetSocketError()), m_ConnectSocket, msec_timeout); + } - iErr = setsockopt(m_ConnectSocket, SOL_SOCKET, SO_RCVTIMEO, (char*) &timeout, sizeof(struct timeval)); - if (iErr < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPClient::SetRcvTimeout : Socket error in SO_RCVTIMEO call to setsockopt."); + return ret_val; +} - return false; - } +#ifndef _WIN32 +bool CTCPClient::SetSndTimeout(const struct timeval& timeout) +{ + bool ret_val = ASocket::SetSndTimeout(m_ConnectSocket, timeout); + if (!ret_val) + { + SocketLog("[ERROR]TCPClient, setsockopt SOL_SOCKET SO_SNDTIMEO failed[%d:%s][%d][%u:%u]", GetSocketError(), strerror(GetSocketError()), m_ConnectSocket, timeout.tv_sec, timeout.tv_usec); + } - return true; + return ret_val; } #endif -// Method for setting send timeout. Can be called after Connect -bool CTCPClient::SetSndTimeout(unsigned int msec_timeout) { -#ifndef WINDOWS - struct timeval t = ASocket::TimevalFromMsec(msec_timeout); +// Connexion au serveur +bool CTCPClient::Connect(const std::string& strServer, const std::string& strPort) +{ + if (m_eStatus == CONNECTED) + { + Disconnect(); + SocketLog("[WARN ]TCPClient, opening a new connexion. the last one was automatically closed[%s:%s]", strServer.c_str(), strPort.c_str()); + } + + memset(&m_HintsAddrInfo, 0, sizeof m_HintsAddrInfo); + /* AF_INET is used to specify the IPv4 address family. */ + m_HintsAddrInfo.ai_family = AF_INET; + /* SOCK_STREAM is used to specify a stream socket. */ + m_HintsAddrInfo.ai_socktype = SOCK_STREAM; + /* IPPROTO_TCP is used to specify the TCP protocol. */ + m_HintsAddrInfo.ai_protocol = IPPROTO_TCP; + + /* Resolve the server address and port */ + int iResult = getaddrinfo(strServer.c_str(), strPort.c_str(), &m_HintsAddrInfo, &m_pResultAddrInfo); + if (iResult != 0) + { + SocketLog("[ERROR]TCPClient, getaddrinfo failed[%d:%s][%s:%s]", iResult, GaiStrerror(iResult), strServer.c_str(), strPort.c_str()); + return false; + } - return this->SetSndTimeout(t); -#else - int iErr; + bool isOK = false; + + /* getaddrinfo() returns a list of address structures. + * Try each address until we successfully connect(2). + * If socket(2) (or connect(2)) fails, we (close the socket + * and) try the next address. */ + for (struct addrinfo* pResPtr = m_pResultAddrInfo; pResPtr != nullptr; pResPtr = pResPtr->ai_next) + { + // create socket + m_ConnectSocket = socket(pResPtr->ai_family, pResPtr->ai_socktype, pResPtr->ai_protocol); + if (m_ConnectSocket == INVALID_SOCKET) + { + SocketLog("[WARN ]TCPClient, create socket failed[%d:%s]", GetSocketError(), strerror(GetSocketError())); + continue; + } + + // Fixes windows 0.2 second delay sending (buffering) data. + int on = 1; + if (setsockopt(m_ConnectSocket, IPPROTO_TCP, TCP_NODELAY, (char*)&on, sizeof(on)) == SOCKET_ERROR) + { + SocketLog("[WARN ]TCPClient, setsockopt IPPROTO_TCP TCP_NODELAY failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), m_ConnectSocket); + } + + // connexion to the server + if (connect(m_ConnectSocket, pResPtr->ai_addr, static_cast(pResPtr->ai_addrlen)) == SOCKET_ERROR) + { + int iErrCode = GetSocketError(); + if (!SOCKET_ERR_CONNECT_RETRIABLE(iErrCode)) + { + SocketLog("[WARN ]TCPClient, connect failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), m_ConnectSocket); + assert(m_ConnectSocket != INVALID_SOCKET); + SocketClose(m_ConnectSocket); + continue; + } + } + + isOK = true; + m_eStatus = CONNECTED; + SocketLog("[INFO ]TCPClient, connected[%d]", m_ConnectSocket); + break; + } - // it's expecting an int but it doesn't matter... - iErr = setsockopt(m_ConnectSocket, SOL_SOCKET, SO_SNDTIMEO, (char*)&msec_timeout, sizeof(struct timeval)); - if (iErr < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPClient::SetSndTimeout : Socket error in SO_SNDTIMEO call to setsockopt."); + if (m_pResultAddrInfo != nullptr) + { + freeaddrinfo(m_pResultAddrInfo); /* No longer needed */ + m_pResultAddrInfo = nullptr; + } - return false; + /* No address succeeded */ + if (!isOK) + { + SocketLog("[ERROR]TCPClient, Connect failed[%s:%s]", strServer.c_str(), strPort.c_str()); } - return true; -#endif + return isOK; } -#ifndef WINDOWS -bool CTCPClient::SetSndTimeout(struct timeval timeout) { - int iErr; +/* ret > 0 : bytes received + * ret == 0 : connection closed + * ret < 0 : recv failed + */ +int CTCPClient::Receive(char* pData, size_t uSize, bool bReadFully /*= true*/) const +{ + if (m_eStatus != CONNECTED || m_ConnectSocket == INVALID_SOCKET) + { + SocketLog("[ERROR]TCPClient, recv failed[not connected to a server.]"); + return -1; + } - iErr = setsockopt(m_ConnectSocket, SOL_SOCKET, SO_SNDTIMEO, (char*) &timeout, sizeof(struct timeval)); - if (iErr < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPClient::SetSndTimeout : Socket error in SO_SNDTIMEO call to setsockopt."); + if (pData == nullptr && uSize != 0) + { + SocketLog("[ERROR]TCPClient, recv failed[%d][%p:%zu]", m_ConnectSocket, pData, uSize); + return -2; + } - return false; - } +#if 0 +#ifdef _WIN32 + int tries = 0; +#endif +#endif - return true; -} + int total = 0; + bool isOK = true; + do + { + isOK = true; + int nRecvd = recv(m_ConnectSocket, pData + total, (int)uSize - total, 0); + if (nRecvd == SOCKET_ERROR) + { + isOK = false; + int iErrCode = GetSocketError(); + if (SOCKET_ERR_RW_RETRIABLE(iErrCode)) + { + continue; + } +#if 0 +#ifdef _WIN32 + // On long messages, Windows recv sometimes fails with WSAENOBUFS, but + // will work if you try again. + if (WSAGetLastError() == WSAENOBUFS && (tries++ < 1000)) + { + Sleep(1); + continue; + } #endif +#endif + SocketLog("[ERROR]TCPClient, recv failed[%d:%s][%d]", iErrCode, strerror(iErrCode), m_ConnectSocket); + break; + } -// Connexion au serveur -bool CTCPClient::Connect(const std::string& strServer, const std::string& strPort) -{ - if (m_eStatus == CONNECTED) - { - Disconnect(); - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPClient][Warning] Opening a new connexion. The last one was automatically closed."); - } - - #ifdef WINDOWS - ZeroMemory(&m_HintsAddrInfo, sizeof(m_HintsAddrInfo)); - /* AF_INET is used to specify the IPv4 address family. */ - m_HintsAddrInfo.ai_family = AF_INET; - /* SOCK_STREAM is used to specify a stream socket. */ - m_HintsAddrInfo.ai_socktype = SOCK_STREAM; - /* IPPROTO_TCP is used to specify the TCP protocol. */ - m_HintsAddrInfo.ai_protocol = IPPROTO_TCP; - - /* Resolve the server address and port */ - int iResult = getaddrinfo(strServer.c_str(), strPort.c_str(), &m_HintsAddrInfo, &m_pResultAddrInfo); - if (iResult != 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPClient][Error] getaddrinfo failed : %d", iResult)); - - if (m_pResultAddrInfo != nullptr) - { - freeaddrinfo(m_pResultAddrInfo); - m_pResultAddrInfo = nullptr; - } - - return false; - } - - // socket creation - m_ConnectSocket = socket(m_pResultAddrInfo->ai_family, // AF_INET - m_pResultAddrInfo->ai_socktype, // SOCK_STREAM - m_pResultAddrInfo->ai_protocol);// IPPROTO_TCP - - if (m_ConnectSocket == INVALID_SOCKET) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPClient][Error] socket failed : %d", WSAGetLastError())); - - freeaddrinfo(m_pResultAddrInfo); - m_pResultAddrInfo = nullptr; - return false; - } - - // Fixes windows 0.2 second delay sending (buffering) data. - int on = 1; - int iErr; - - iErr = setsockopt(m_ConnectSocket, IPPROTO_TCP, TCP_NODELAY, (char*)&on, sizeof(on)); - if (iErr == INVALID_SOCKET) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPClient][Error] Socket error in call to setsockopt"); - - closesocket(m_ConnectSocket); - freeaddrinfo(m_pResultAddrInfo); m_pResultAddrInfo = nullptr; - - return false; - } - - /* - SOCKET ConnectSocket = INVALID_SOCKET; - struct sockaddr_in clientService; - - ConnectSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); - if (ConnectSocket == INVALID_SOCKET) { - printf("Error at socket(): %ld\n", WSAGetLastError()); - WSACleanup(); - return 1; - } - - // The sockaddr_in structure specifies the address family, - // IP address, and port of the server to be connected to. - clientService.sin_family = AF_INET; - clientService.sin_addr.s_addr = inet_addr("127.0.0.1"); - clientService.sin_port = htons(27015); - */ - - // connexion to the server - //unsigned uRetry = 0; - //do - //{ - iResult = connect(m_ConnectSocket, - m_pResultAddrInfo->ai_addr, - static_cast(m_pResultAddrInfo->ai_addrlen)); -//iResult = connect(m_ConnectSocket, (SOCKADDR*)&clientService, sizeof(clientService)); - - //if (iResult != SOCKET_ERROR) - //break; - - // retry mechanism - //if (uRetry < m_uRetryCount) - //if (m_eSettingsFlags & ENABLE_LOG) - /*m_oLog(StringFormat("[TCPClient][Error] connect retry %u after %u second(s)", - m_uRetryCount + 1, m_uRetryPeriod));*/ - - //if (m_uRetryPeriod > 0) - //{ - //for (unsigned uSec = 0; uSec < m_uRetryPeriod; uSec++) - //Sleep(1000); - //} - //} while (iResult == SOCKET_ERROR && ++uRetry < m_uRetryCount); - - freeaddrinfo(m_pResultAddrInfo); - m_pResultAddrInfo = nullptr; - - if (iResult != SOCKET_ERROR) - { - m_eStatus = CONNECTED; - return true; - } - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPClient][Error] Unable to connect to server : %d", WSAGetLastError())); - - #else - memset(&m_HintsAddrInfo, 0, sizeof m_HintsAddrInfo); - m_HintsAddrInfo.ai_family = AF_INET; // AF_INET or AF_INET6 to force version or use AF_UNSPEC - m_HintsAddrInfo.ai_socktype = SOCK_STREAM; - //m_HintsAddrInfo.ai_flags = 0; - //m_HintsAddrInfo.ai_protocol = 0; /* Any protocol */ - - int iAddrInfoRet = getaddrinfo(strServer.c_str(), strPort.c_str(), &m_HintsAddrInfo, &m_pResultAddrInfo); - if (iAddrInfoRet != 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPClient][Error] getaddrinfo failed : %s", gai_strerror(iAddrInfoRet))); - - if (m_pResultAddrInfo != nullptr) - { - freeaddrinfo(m_pResultAddrInfo); - m_pResultAddrInfo = nullptr; - } - - return false; - } - - /* getaddrinfo() returns a list of address structures. - * Try each address until we successfully connect(2). - * If socket(2) (or connect(2)) fails, we (close the socket - * and) try the next address. */ - struct addrinfo* pResPtr = m_pResultAddrInfo; - for (pResPtr = m_pResultAddrInfo; pResPtr != nullptr; pResPtr = pResPtr->ai_next) - { - // create socket - m_ConnectSocket = socket(pResPtr->ai_family, pResPtr->ai_socktype, pResPtr->ai_protocol); - if (m_ConnectSocket < 0) // or == -1 - continue; - - // connexion to the server - int iConRet = connect(m_ConnectSocket, pResPtr->ai_addr, pResPtr->ai_addrlen); - if (iConRet >= 0) // or != -1 - { - /* Success */ - m_eStatus = CONNECTED; - - if (m_pResultAddrInfo != nullptr) - { - freeaddrinfo(m_pResultAddrInfo); - m_pResultAddrInfo = nullptr; - } - - return true; - } - - close(m_ConnectSocket); - } - - if (m_pResultAddrInfo != nullptr) - { - freeaddrinfo(m_pResultAddrInfo); /* No longer needed */ - m_pResultAddrInfo = nullptr; - } - - /* No address succeeded */ - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPClient][Error] no such host."); - - #endif - - return false; -} + if (nRecvd == 0) + { + SocketLog("[INFO ]TCPClient, peer shut down[%d]", m_ConnectSocket); + break; + } -bool CTCPClient::Send(const char* pData, const size_t uSize) const -{ - if (!pData || !uSize) - return false; - - if (m_eStatus != CONNECTED) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPClient][Error] send failed : not connected to a server."); - - return false; - } - - int total = 0; - do - { - const int flags = 0; - int nSent; - - nSent = send(m_ConnectSocket, pData + total, uSize - total, flags); - - if (nSent < 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPClient][Error] Socket error in call to send."); - - return false; - } - total += nSent; - } while(total < uSize); - - return true; -} + total += nRecvd; -bool CTCPClient::Send(const std::string& strData) const -{ - return Send(strData.c_str(), strData.length()); + } while (bReadFully && (total < (int)uSize)); + + if (!isOK && total == 0) + { + return -1; + } + + return (int)total; } -bool CTCPClient::Send(const std::vector& Data) const +int CTCPClient::Send(const char* pData, size_t uSize) const { - return Send(Data.data(), Data.size()); + if (m_eStatus != CONNECTED || m_ConnectSocket == INVALID_SOCKET) + { + SocketLog("[ERROR]TCPClient, send failed[not connected to a server.]"); + return -1; + } + + if (pData == nullptr && uSize != 0) + { + SocketLog("[WARN ]TCPClient, send failed[%d][%p:%zu]", m_ConnectSocket, pData, uSize); + return -1; + } + + int total = 0; + do + { + int nSent = send(m_ConnectSocket, pData + total, (int)uSize - total, 0); + if (nSent == SOCKET_ERROR) + { + int iErrCode = GetSocketError(); + if (SOCKET_ERR_RW_RETRIABLE(iErrCode)) + { + continue; + } + + SocketLog("[ERROR]TCPClient, send failed[%d:%s][%d]", iErrCode, strerror(iErrCode), m_ConnectSocket); + return -1; + } + + total += nSent; + } while (total < (int)uSize); + + return (int)total; } -/* ret > 0 : bytes received - * ret == 0 : connection closed - * ret < 0 : recv failed - */ -int CTCPClient::Receive(char* pData, const size_t uSize, bool bReadFully /*= true*/) const +int CTCPClient::Send(const std::string& strData) const { - if (!pData || !uSize) - return -2; - - if (m_eStatus != CONNECTED) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPClient][Error] recv failed : not connected to a server."); - - return -1; - } - - #ifdef WINDOWS - int tries = 0; - #endif - - int total = 0; - do - { - int nRecvd = recv(m_ConnectSocket, pData + total, uSize - total, 0); - - if (nRecvd == 0) - { - // peer shut down - break; - } - - #ifdef WINDOWS - if ((nRecvd < 0) && (WSAGetLastError() == WSAENOBUFS)) - { - // On long messages, Windows recv sometimes fails with WSAENOBUFS, but - // will work if you try again. - if ((tries++ < 1000)) - { - Sleep(1); - continue; - } - - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPClient][Error] Socket error in call to recv."); - - break; - } - #endif - - total += nRecvd; - - } while (bReadFully && (total < uSize)); - - return total; + return Send(strData.c_str(), strData.length()); } -bool CTCPClient::Disconnect() +int CTCPClient::Send(const std::vector& Data) const { - if (m_eStatus != CONNECTED) - return true; - - m_eStatus = DISCONNECTED; - - #ifdef WINDOWS - // shutdown the connection since no more data will be sent - int iResult = shutdown(m_ConnectSocket, SD_SEND); - if (iResult == SOCKET_ERROR) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPClient][Error] shutdown failed : %d", WSAGetLastError())); - - return false; - } - closesocket(m_ConnectSocket); - - if (m_pResultAddrInfo != nullptr) - { - freeaddrinfo(m_pResultAddrInfo); - m_pResultAddrInfo = nullptr; - } - #else - close(m_ConnectSocket); - #endif - - m_ConnectSocket = INVALID_SOCKET; - - return true; + return Send(Data.data(), Data.size()); } -CTCPClient::~CTCPClient() +void CTCPClient::Disconnect() { - if (m_eStatus == CONNECTED) - Disconnect(); + if (m_eStatus != CONNECTED) + { + m_eStatus = DISCONNECTED; + } + + if (m_ConnectSocket != INVALID_SOCKET) + { +#if 0//defined(_WIN32) + // shutdown the connection since no more data will be sent + if (shutdown(m_ConnectSocket, SD_SEND) == SOCKET_ERROR) + { + SocketLog("[ERROR]TCPClient, shutdown SD_SEND failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), m_ConnectSocket); + } +#endif + SocketClose(m_ConnectSocket); + } } diff --git a/Socket/TCPClient.h b/Socket/TCPClient.h index 1c5db98..d78961d 100644 --- a/Socket/TCPClient.h +++ b/Socket/TCPClient.h @@ -9,14 +9,6 @@ #ifndef INCLUDE_TCPCLIENT_H_ #define INCLUDE_TCPCLIENT_H_ -#include -#include // size_t -#include -#include // strerror, strlen, memcpy, strcpy -#include -#include -#include -#include #include #include @@ -26,62 +18,51 @@ class CTCPSSLClient; class CTCPClient : public ASocket { - friend class CTCPSSLClient; - + friend class CTCPSSLClient; public: - explicit CTCPClient(const LogFnCallback oLogger, const SettingsFlag eSettings = ALL_FLAGS); - ~CTCPClient() override; - - // copy constructor and assignment operator are disabled - CTCPClient(const CTCPClient&) = delete; - CTCPClient& operator=(const CTCPClient&) = delete; - - // Setters - Getters (for unit tests) - /*inline*/// void SetProgressFnCallback(void* pOwner, const ProgressFnCallback& fnCallback); - /*inline*/// void SetProxy(const std::string& strProxy); - /*inline auto GetProgressFnCallback() const - { - return m_fnProgressCallback.target(); - } - inline void* GetProgressFnCallbackOwner() const { return m_ProgressStruct.pOwner; }*/ - //inline const std::string& GetProxy() const { return m_strProxy; } - //inline const unsigned char GetSettingsFlags() const { return m_eSettingsFlags; } - - // Session - bool Connect(const std::string& strServer, const std::string& strPort); // connect to a TCP server - bool Disconnect(); // disconnect from the TCP server - bool Send(const char* pData, const size_t uSize) const; // send data to a TCP server - bool Send(const std::string& strData) const; - bool Send(const std::vector& Data) const; - int Receive(char* pData, const size_t uSize, bool bReadFully = true) const; - - // To disable timeout, set msec_timeout to 0. - bool SetRcvTimeout(unsigned int msec_timeout); - bool SetSndTimeout(unsigned int msec_timeout); - -#ifndef WINDOWS - bool SetRcvTimeout(struct timeval Timeout); - bool SetSndTimeout(struct timeval Timeout); + explicit CTCPClient(const LogFnCallback& oLogger, const SettingsFlag eSettings = ALL_FLAGS); + ~CTCPClient() override; + + // copy constructor and assignment operator are disabled + CTCPClient(const CTCPClient&) = delete; + CTCPClient& operator=(const CTCPClient&) = delete; + + // Session + bool Connect(const std::string& strServer, const std::string& strPort); // connect to a TCP server + void Disconnect(); // disconnect from the TCP server + + int Receive(char* pData, size_t uSize, bool bReadFully = true) const; + int Send(const char* pData, size_t uSize) const; // send data to a TCP server + int Send(const std::string& strData) const; + int Send(const std::vector& Data) const; + + // To disable timeout, set msec_timeout to 0. + bool SetRcvTimeout(unsigned int msec_timeout); + bool SetSndTimeout(unsigned int msec_timeout); + +#ifndef _WIN32 + bool SetRcvTimeout(const struct timeval& timeout); + bool SetSndTimeout(const struct timeval& timeout); #endif - bool IsConnected() const { return m_eStatus == CONNECTED; } + bool IsConnected() const { return m_eStatus == CONNECTED; } - Socket GetSocketDescriptor() const { return m_ConnectSocket; } + Socket GetSocketDescriptor() const { return m_ConnectSocket; } protected: - enum SocketStatus - { - CONNECTED, - DISCONNECTED - }; - - SocketStatus m_eStatus; - Socket m_ConnectSocket; // ConnectSocket - //unsigned m_uRetryCount; - //unsigned m_uRetryPeriod; - - struct addrinfo* m_pResultAddrInfo; - struct addrinfo m_HintsAddrInfo; + enum SocketStatus + { + CONNECTED, + DISCONNECTED + }; + + SocketStatus m_eStatus; + Socket m_ConnectSocket; // ConnectSocket + //unsigned m_uRetryCount; + //unsigned m_uRetryPeriod; + + struct addrinfo* m_pResultAddrInfo; + struct addrinfo m_HintsAddrInfo; }; #endif diff --git a/Socket/TCPSSLClient.cpp b/Socket/TCPSSLClient.cpp index f8e8514..be78dc7 100644 --- a/Socket/TCPSSLClient.cpp +++ b/Socket/TCPSSLClient.cpp @@ -7,257 +7,281 @@ #ifdef OPENSSL #include "TCPSSLClient.h" -CTCPSSLClient::CTCPSSLClient(const LogFnCallback oLogger, - const OpenSSLProtocol eSSLVersion, - const SettingsFlag eSettings /*= ALL_FLAGS*/) : - ASecureSocket(oLogger, eSSLVersion, eSettings), - m_TCPClient(oLogger, eSettings) +CTCPSSLClient::CTCPSSLClient(const LogFnCallback& oLogger, + const OpenSSLProtocol& eSSLVersion, + const SettingsFlag eSettings /*= ALL_FLAGS*/) + : ASecureSocket(oLogger, eSSLVersion, eSettings) + , m_TCPClient(oLogger, eSettings) { - } -bool CTCPSSLClient::SetRcvTimeout(unsigned int msec_timeout){ - return m_TCPClient.SetRcvTimeout(msec_timeout); +CTCPSSLClient::~CTCPSSLClient() +{ + Disconnect(); } -bool CTCPSSLClient::SetSndTimeout(unsigned int msec_timeout){ - return m_TCPClient.SetSndTimeout(msec_timeout); +bool CTCPSSLClient::SetRcvTimeout(unsigned int msec_timeout) +{ + return m_TCPClient.SetRcvTimeout(msec_timeout); } -#ifndef WINDOWS -bool CTCPSSLClient::SetRcvTimeout(struct timeval timeout) { +#ifndef _WIN32 +bool CTCPSSLClient::SetRcvTimeout(struct timeval timeout) +{ return m_TCPClient.SetRcvTimeout(timeout); } +#endif + +bool CTCPSSLClient::SetSndTimeout(unsigned int msec_timeout) +{ + return m_TCPClient.SetSndTimeout(msec_timeout); +} -bool CTCPSSLClient::SetSndTimeout(struct timeval timeout){ - return m_TCPClient.SetSndTimeout(timeout); +#ifndef _WIN32 +bool CTCPSSLClient::SetSndTimeout(struct timeval timeout) +{ + return m_TCPClient.SetSndTimeout(timeout); } #endif // Connexion au serveur bool CTCPSSLClient::Connect(const std::string& strServer, const std::string& strPort) { - if (m_TCPClient.Connect(strServer, strPort)) - { - m_SSLConnectSocket.m_SockFd = m_TCPClient.m_ConnectSocket; - SetUpCtxClient(m_SSLConnectSocket); - - if (m_SSLConnectSocket.m_pCTXSSL == nullptr) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLClient][Error] SSL_CTX_new failed."); - //ERR_print_errors_fp(stdout); - return false; - } - - /* process SSL certificates */ - /* Load a client certificate into the SSL_CTX structure. */ - if (!m_strSSLCertFile.empty()) - { - if (SSL_CTX_use_certificate_file(m_SSLConnectSocket.m_pCTXSSL, - m_strSSLCertFile.c_str(), SSL_FILETYPE_PEM) <= 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLClient][Error] Loading cert file failed."); - - return false; - } - } - /* Load trusted CA. Mandatory to verify server's certificate */ - if (!m_strCAFile.empty()) - { - if (!SSL_CTX_load_verify_locations(m_SSLConnectSocket.m_pCTXSSL, m_strCAFile.c_str(), nullptr)) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLClient][Error] Loading CA file failed."); - - return false; - } - SSL_CTX_set_verify_depth(m_SSLConnectSocket.m_pCTXSSL, 1); - } - /* Load a private-key into the SSL_CTX structure. - * set key file that corresponds to the server or client certificate. - * In the SSL handshake, a certificate (which contains the public key) is transmitted to allow - * the peer to use it for encryption. The encrypted message sent from the peer can be decrypted - * only using the private key. */ - if (!m_strSSLKeyFile.empty()) - { - if (SSL_CTX_use_PrivateKey_file(m_SSLConnectSocket.m_pCTXSSL, - m_strSSLKeyFile.c_str(), SSL_FILETYPE_PEM) <= 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLClient][Error] Loading key file failed."); + if (!m_TCPClient.Connect(strServer, strPort)) + { + SocketLog("[ERROR]TCPSSLClient, m_TCPClient Connect failed[:Unable to establish a TCP connection with the server.][%s:%s]", strServer.c_str(), strPort.c_str()); + return false; + } + + do + { + m_SSLConnectSocket.m_SockFd = m_TCPClient.m_ConnectSocket; + if (!SetUpCtxClient(m_SSLConnectSocket)) + { + SocketLog("[ERROR]TCPSSLClient, SSL_CTX_new failed[%s:%s][%d]", strServer.c_str(), strPort.c_str(), m_SSLConnectSocket.m_SockFd); //ERR_print_errors_fp(stdout); - return false; - } - - /* verify private key */ - /*if (!SSL_CTX_check_private_key(m_SSLConnectSocket.m_pCTXSSL)) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLClient][Error] Private key does not match the public certificate."); - return false; - }*/ - } - //SSL_CTX_set_cert_verify_callback(m_SSLConnectSocket.m_pCTXSSL, AlwaysTrueCallback, nullptr); - - /* create new SSL connection state */ - m_SSLConnectSocket.m_pSSL = SSL_new(m_SSLConnectSocket.m_pCTXSSL); - SSL_set_fd(m_SSLConnectSocket.m_pSSL, m_SSLConnectSocket.m_SockFd); - - /* initiate the TLS/SSL handshake with an TLS/SSL server */ - int iResult = SSL_connect(m_SSLConnectSocket.m_pSSL); - if (iResult > 0) - { - /* The data can now be transmitted securely over this connection. */ - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPSSLClient][Info] Connected with '%s' encryption.", - SSL_get_cipher(m_SSLConnectSocket.m_pSSL))); - - /*if (SSL_get_peer_certificate(m_SSLConnectSocket.m_pSSL) != nullptr) - { - if (SSL_get_verify_result(m_SSLConnectSocket.m_pSSL) == X509_V_OK) + break; + } + + /* process SSL certificates */ + /* Load a client certificate into the SSL_CTX structure. */ + if (!m_strSSLCertFile.empty()) + { + if (SSL_CTX_use_certificate_file(m_SSLConnectSocket.m_pCTXSSL, m_strSSLCertFile.c_str(), SSL_FILETYPE_PEM) != 1) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("client verification with SSL_get_verify_result() succeeded."); - } - else + SocketLog("[ERROR]TCPSSLClient, SSL_CTX_use_certificate_file failed[Loading cert file failed.][%lu:%s][%d][%s]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), m_SSLConnectSocket.m_SockFd, m_strSSLCertFile.c_str()); + break; + } + } + + /* Load trusted CA. Mandatory to verify server's certificate */ + if (!m_strCAFile.empty()) + { + if (SSL_CTX_load_verify_locations(m_SSLConnectSocket.m_pCTXSSL, m_strCAFile.c_str(), nullptr) != 1) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("client verification with SSL_get_verify_result() failed.\n"); - - return false; + SocketLog("[ERROR]TCPSSLClient, SSL_CTX_load_verify_locations failed[Loading CA file failed.][%lu:%s][%d][%s]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), m_SSLConnectSocket.m_SockFd, m_strCAFile.c_str()); + break; } - } - else if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("the peer certificate was not presented.");*/ - return true; - } - // under Windows it creates problems - #ifdef LINUX - ERR_print_errors_fp(stdout); - #endif + SSL_CTX_set_verify_depth(m_SSLConnectSocket.m_pCTXSSL, 1); + } + + /** + * Load a private-key into the SSL_CTX structure. + * set key file that corresponds to the server or client certificate. + * In the SSL handshake, a certificate (which contains the public key) is transmitted to allow + * the peer to use it for encryption. The encrypted message sent from the peer can be decrypted + * only using the private key. + */ + if (!m_strSSLKeyFile.empty()) + { + if (SSL_CTX_use_PrivateKey_file(m_SSLConnectSocket.m_pCTXSSL, m_strSSLKeyFile.c_str(), SSL_FILETYPE_PEM) != 1) + { + SocketLog("[ERROR]TCPSSLClient, SSL_CTX_use_PrivateKey_file failed[Loading key file failed.][%lu:%s][%d][%s]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), m_SSLConnectSocket.m_SockFd, m_strSSLKeyFile.c_str()); + //ERR_print_errors_fp(stdout); + break; + } - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPSSLClient][Error] SSL_connect failed (Error=%d | %s)", - iResult, GetSSLErrorString(SSL_get_error(m_SSLConnectSocket.m_pSSL, iResult)))); +#if 0 + /* verify private key */ + if (SSL_CTX_check_private_key(m_SSLConnectSocket.m_pCTXSSL) != 1) + { + SocketLog("[ERROR]TCPSSLClient, SSL_CTX_check_private_key failed[Private key does not match the public certificate.][%lu:%s][%d]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), m_SSLConnectSocket.m_SockFd); + break; + } +#endif + } + //SSL_CTX_set_cert_verify_callback(m_SSLConnectSocket.m_pCTXSSL, AlwaysTrueCallback, nullptr); + + /* create new SSL connection state */ + m_SSLConnectSocket.m_pSSL = SSL_new(m_SSLConnectSocket.m_pCTXSSL); + if (m_SSLConnectSocket.m_pSSL == nullptr) + { + SocketLog("[ERROR]TCPSSLClient, SSL_new failed[%lu:%s][%d]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), m_SSLConnectSocket.m_SockFd); + break; + } + + if (SSL_set_fd(m_SSLConnectSocket.m_pSSL, (int)m_SSLConnectSocket.m_SockFd) != 1) + { + SocketLog("[ERROR]TCPSSLClient, SSL_set_fd failed[%lu:%s][%d]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), m_SSLConnectSocket.m_SockFd); + break; + } + + bool connectOK = false; + do + { + /* initiate the TLS/SSL handshake with an TLS/SSL server */ + int iResult = SSL_connect(m_SSLConnectSocket.m_pSSL); + if (iResult == 1) + { + connectOK = true; + break; + } - return false; - } + int iErrCode = SSL_get_error(m_SSLConnectSocket.m_pSSL, iResult); + if (iErrCode != SSL_ERROR_WANT_CONNECT) + { + // under Windows it creates problems + SocketLog("[ERROR]TCPSSLClient, SSL_connect failed[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), m_SSLConnectSocket.m_SockFd); +#ifndef _WIN32 + ERR_print_errors_fp(stdout); +#endif + break; + } + } while (1); + + if (!connectOK) + { + break; + } + + /* The data can now be transmitted securely over this connection. */ + SocketLog("[INFO ]TCPSSLClient, SSL_connect with '%s' encryption.", SSL_get_cipher(m_SSLConnectSocket.m_pSSL)); + +#if 0 + if (SSL_get_peer_certificate(m_SSLConnectSocket.m_pSSL) == nullptr) + { + SocketLog("[WARN ]TCPSSLClient, SSL_get_peer_certificate failed[the peer certificate was not presented.][%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), m_SSLConnectSocket.m_SockFd); + } + else + { + if (SSL_get_verify_result(m_SSLConnectSocket.m_pSSL) != X509_V_OK) + { + SocketLog("[ERROR]TCPSSLClient, client verification with SSL_get_verify_result failed.[%d]", m_SSLConnectSocket.m_SockFd); + break; + } - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLClient][Error] Unable to establish a TCP connection with the server."); + SocketLog("[WARN ]TCPSSLClient, client verification with SSL_get_verify_result succeeded.[%d]", m_SSLConnectSocket.m_SockFd); + } +#endif - return false; -} + return true; + } while (0); -bool CTCPSSLClient::Send(const char* pData, const size_t uSize) const -{ - if (m_TCPClient.m_eStatus != CTCPClient::CONNECTED) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLClient][Error] SSL send failed : not connected to an SSL server."); - - return false; - } - - int total = 0; - do - { - /* encrypt & send message */ - int nSent = SSL_write(m_SSLConnectSocket.m_pSSL, pData + total, uSize - total); - if (nSent <= 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPSSLClient][Error] SSL_write failed (Error=%d | %s)", - nSent, GetSSLErrorString(SSL_get_error(m_SSLConnectSocket.m_pSSL, nSent)))); - - return false; - } - - total += nSent; - } while (total < uSize); - - return true; + Disconnect(); + return false; } -bool CTCPSSLClient::Send(const std::string& strData) const +int CTCPSSLClient::Receive(char* pData, size_t uSize, bool bReadFully /*= true*/) const { - return Send(strData.c_str(), strData.length()); -} + if (m_TCPClient.m_eStatus != CTCPClient::CONNECTED || m_TCPClient.m_ConnectSocket == INVALID_SOCKET) + { + SocketLog("[ERROR]TCPSSLClient, SSL_read failed[not connected to a SSL server.]"); + return -1; + } + + if (pData == nullptr && uSize != 0) + { + SocketLog("[ERROR]TCPSSLClient, SSL_read failed[%d][%p:%zu]", m_SSLConnectSocket.m_SockFd, pData, uSize); + return -2; + } + + int total = 0; + do + { + int nRecvd = SSL_read(m_SSLConnectSocket.m_pSSL, pData + total, (int)uSize - total); + if (nRecvd <= 0) + { + int iErrCode = SSL_get_error(m_SSLConnectSocket.m_pSSL, nRecvd); + if (iErrCode == SSL_ERROR_WANT_READ) + { + continue; + } -bool CTCPSSLClient::Send(const std::vector& Data) const -{ - return Send(Data.data(), Data.size()); -} + SocketLog("[ERROR]TCPSSLClient, SSL_read failed[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), m_SSLConnectSocket.m_SockFd); + break; + } -bool CTCPSSLClient::HasPending() -{ - int pend; + total += nRecvd; - pend = SSL_has_pending(m_SSLConnectSocket.m_pSSL); + } while (bReadFully && (total < (int)uSize)); - return pend == 1; + return total; } -int CTCPSSLClient::PendingBytes() +int CTCPSSLClient::Send(const char* pData, size_t uSize) const { - int nPend; + if (m_TCPClient.m_eStatus != CTCPClient::CONNECTED || m_TCPClient.m_ConnectSocket == INVALID_SOCKET) + { + SocketLog("[ERROR]TCPSSLClient, SSL_write failed[not connected to a SSL server.]"); + return -1; + } + + //OpenSSL 1.1.1 + if (pData == nullptr && uSize != 0) + { + SocketLog("[ERROR]TCPSSLClient, send failed[%d][%p:%zu]", m_SSLConnectSocket.m_SockFd, pData, uSize); + return -2; + } + + int total = 0; + do + { + /* encrypt & send message */ + int nSent = SSL_write(m_SSLConnectSocket.m_pSSL, pData + total, (int)uSize - total); + if (nSent <= 0) + { + int iErrCode = SSL_get_error(m_SSLConnectSocket.m_pSSL, nSent); + if (iErrCode == SSL_ERROR_WANT_WRITE) + { + continue; + } - nPend = SSL_pending(m_SSLConnectSocket.m_pSSL); + SocketLog("[ERROR]TCPSSLClient, SSL_write failed[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), m_SSLConnectSocket.m_SockFd); + return -1; + } - return nPend; + total += nSent; + } while (total < (int)uSize); + + return (int)total; } -int CTCPSSLClient::Receive(char* pData, const size_t uSize, bool bReadFully /*= true*/) const +int CTCPSSLClient::Send(const std::string& strData) const { - if (m_TCPClient.m_eStatus != CTCPClient::CONNECTED) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLClient][Error] SSL recv failed : not connected to a server."); - - return -1; - } - - int total = 0; - do - { - int nRecvd = SSL_read(m_SSLConnectSocket.m_pSSL, pData + total, uSize - total); - - if (nRecvd <= 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPSSLClient][Error] SSL_read failed (Error=%d | %s)", - nRecvd, GetSSLErrorString(SSL_get_error(m_SSLConnectSocket.m_pSSL, nRecvd)))); - - break; - } - - total += nRecvd; - - } while (bReadFully && (total < uSize)); - - return total; + return Send(strData.c_str(), strData.length()); } -bool CTCPSSLClient::Disconnect() +int CTCPSSLClient::Send(const std::vector& Data) const { - if (m_TCPClient.m_eStatus != CTCPClient::CONNECTED) - return true; + return Send(Data.data(), Data.size()); +} - // send close_notify message to notify peer of the SSL closure. - ShutdownSSL(m_SSLConnectSocket); +bool CTCPSSLClient::HasPending() +{ + return m_SSLConnectSocket.HasPending(); +} - return m_TCPClient.Disconnect(); +int CTCPSSLClient::PendingBytes() +{ + return m_SSLConnectSocket.PendingBytes(); } -CTCPSSLClient::~CTCPSSLClient() +void CTCPSSLClient::Disconnect() { - if (m_TCPClient.m_eStatus == CTCPClient::CONNECTED) - { - Disconnect(); - m_TCPClient.Disconnect(); - } + // send close_notify message to notify peer of the SSL closure. + m_SSLConnectSocket.Disconnect(); + //ShutdownSSL(m_SSLConnectSocket); + m_TCPClient.Disconnect(); } #endif diff --git a/Socket/TCPSSLClient.h b/Socket/TCPSSLClient.h index 332ecce..633a01e 100644 --- a/Socket/TCPSSLClient.h +++ b/Socket/TCPSSLClient.h @@ -16,43 +16,40 @@ class CTCPSSLClient : public ASecureSocket { public: - explicit CTCPSSLClient(const LogFnCallback oLogger, - const OpenSSLProtocol eSSLVersion = OpenSSLProtocol::TLS, - const SettingsFlag eSettings = ALL_FLAGS); - ~CTCPSSLClient() override; - - CTCPSSLClient(const CTCPSSLClient&) = delete; - CTCPSSLClient& operator=(const CTCPSSLClient&) = delete; - - /* connect to a TCP SSL server */ - bool Connect(const std::string& strServer, const std::string& strPort); - - bool SetRcvTimeout(unsigned int timeout); - bool SetSndTimeout(unsigned int timeout); - -#ifndef WINDOWS - bool SetRcvTimeout(struct timeval timeout); - bool SetSndTimeout(struct timeval timeout); + explicit CTCPSSLClient(const LogFnCallback& oLogger, + const OpenSSLProtocol& eSSLVersion = OpenSSLProtocol::TLS, + const SettingsFlag eSettings = ALL_FLAGS); + ~CTCPSSLClient() override; + + CTCPSSLClient(const CTCPSSLClient&) = delete; + CTCPSSLClient& operator=(const CTCPSSLClient&) = delete; + + /* connect to a TCP SSL server */ + bool Connect(const std::string& strServer, const std::string& strPort); + /* disconnect from the SSL TCP server */ + void Disconnect(); + + int Receive(char* pData, size_t uSize, bool bReadFully = true) const; + /* send data to a TCP SSL server */ + int Send(const char* pData, size_t uSize) const; + int Send(const std::string& strData) const; + int Send(const std::vector& Data) const; + + bool SetRcvTimeout(unsigned int timeout); + bool SetSndTimeout(unsigned int timeout); + +#ifndef _WIN32 + bool SetRcvTimeout(struct timeval timeout); + bool SetSndTimeout(struct timeval timeout); #endif - /* disconnect from the SSL TCP server */ - bool Disconnect(); - - /* send data to a TCP SSL server */ - bool Send(const char* pData, const size_t uSize) const; - bool Send(const std::string& strData) const; - bool Send(const std::vector& Data) const; - - /* receive data from a TCP SSL server */ - bool HasPending(); - int PendingBytes(); - - int Receive(char* pData, const size_t uSize, bool bReadFully = true) const; + /* receive data from a TCP SSL server */ + bool HasPending(); + int PendingBytes(); protected: - CTCPClient m_TCPClient; - SSLSocket m_SSLConnectSocket; - + CTCPClient m_TCPClient; + SSLSocket m_SSLConnectSocket; }; #endif diff --git a/Socket/TCPSSLServer.cpp b/Socket/TCPSSLServer.cpp index 2658014..9a3324a 100644 --- a/Socket/TCPSSLServer.cpp +++ b/Socket/TCPSSLServer.cpp @@ -8,239 +8,279 @@ #include "TCPSSLServer.h" CTCPSSLServer::CTCPSSLServer(const LogFnCallback oLogger, - const std::string& strPort, - const OpenSSLProtocol eSSLVersion, - const SettingsFlag eSettings /*= ALL_FLAGS*/) - /*throw (EResolveError)*/ : - ASecureSocket(oLogger, eSSLVersion, eSettings), - m_TCPServer(oLogger, strPort, eSettings) + const std::string& strPort, + const OpenSSLProtocol eSSLVersion, + const SettingsFlag eSettings /*= ALL_FLAGS*/) /*throw (EResolveError)*/ + : ASecureSocket(oLogger, eSSLVersion, eSettings) + , m_TCPServer(oLogger, strPort, eSettings) { - } -bool CTCPSSLServer::SetRcvTimeout(SSLSocket& ClientSocket, unsigned int msec_timeout){ - return m_TCPServer.SetRcvTimeout(ClientSocket.m_SockFd, msec_timeout); +CTCPSSLServer::~CTCPSSLServer() +{ + SocketClose(m_TCPServer.m_ListenSocket); } -bool CTCPSSLServer::SetSndTimeout(SSLSocket& ClientSocket, unsigned int msec_timeout){ - return m_TCPServer.SetSndTimeout(ClientSocket.m_SockFd, msec_timeout); +bool CTCPSSLServer::SetRcvTimeout(SSLSocket& ClientSocket, unsigned int msec_timeout) +{ + return m_TCPServer.SetRcvTimeout(ClientSocket.m_SockFd, msec_timeout); } -#ifndef WINDOWS -bool CTCPSSLServer::SetRcvTimeout(SSLSocket& ClientSocket, struct timeval timeout) { +#ifndef _WIN32 +bool CTCPSSLServer::SetRcvTimeout(SSLSocket& ClientSocket, struct timeval timeout) +{ return m_TCPServer.SetRcvTimeout(ClientSocket.m_SockFd, timeout); } - -bool CTCPSSLServer::SetSndTimeout(SSLSocket& ClientSocket, struct timeval timeout){ - return m_TCPServer.SetSndTimeout(ClientSocket.m_SockFd, timeout); -} #endif -// returns the socket of the accepted client -bool CTCPSSLServer::Listen(SSLSocket& ClientSocket, size_t msec /*= ACCEPT_WAIT_INF_DELAY*/) +bool CTCPSSLServer::SetSndTimeout(SSLSocket& ClientSocket, unsigned int msec_timeout) { - if (m_TCPServer.Listen(ClientSocket.m_SockFd, msec)) - { - SetUpCtxServer(ClientSocket); - - if (ClientSocket.m_pCTXSSL == nullptr) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLServer][Error] SSL CTX failed."); - //ERR_print_errors_fp(stdout); - return false; - } - - //SSL_CTX_set_options(ClientSocket.m_pCTXSSL, SSL_OP_SINGLE_DH_USE); - //SSL_CTX_set_cert_verify_callback(ClientSocket.m_pCTXSSL, AlwaysTrueCallback, nullptr); - - /* Load server certificate into the SSL context. */ - if (!m_strSSLCertFile.empty()) - { - if (SSL_CTX_use_certificate_file(ClientSocket.m_pCTXSSL, - m_strSSLCertFile.c_str(), SSL_FILETYPE_PEM) <= 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLServer][Error] Loading cert file failed."); - //ERR_print_errors_fp(stdout); - return false; - } - } - /* Load trusted CA file. */ - if (!m_strCAFile.empty()) - { - if (!SSL_CTX_load_verify_locations(ClientSocket.m_pCTXSSL, m_strCAFile.c_str(), nullptr)) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLServer][Error] Loading CA file failed."); - - return false; - } - /* Set to require peer (client) certificate verification. */ - //SSL_CTX_set_verify(m_SSLConnectSocket.m_pCTXSSL, SSL_VERIFY_PEER, VerifyCallback); - /* Set the verification depth to 1 */ - SSL_CTX_set_verify_depth(ClientSocket.m_pCTXSSL, 1); - } - /* Load the server private-key into the SSL context. */ - if (!m_strSSLKeyFile.empty()) - { - if (SSL_CTX_use_PrivateKey_file(ClientSocket.m_pCTXSSL, - m_strSSLKeyFile.c_str(), SSL_FILETYPE_PEM) <= 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLServer][Error] Loading key file failed."); - //ERR_print_errors_fp(stdout); - return false; - } - - // verify private key - /*if (!SSL_CTX_check_private_key(ClientSocket.m_pCTXSSL)) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLServer][Error] Private key does not match the public certificate."); - - return false; - }*/ - } - - ClientSocket.m_pSSL = SSL_new(ClientSocket.m_pCTXSSL); - // set the socket directly into the SSL structure or we can use a BIO structure - SSL_set_fd(ClientSocket.m_pSSL, ClientSocket.m_SockFd); - - /* wait for a TLS/SSL client to initiate a TLS/SSL handshake */ - int iSSLErr = SSL_accept(ClientSocket.m_pSSL); - if (iSSLErr <= 0) - { - //Error occurred, log and close down ssl - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPSSLServer][Error] accept failed. (Error=%d | %s)", - iSSLErr, GetSSLErrorString(SSL_get_error(ClientSocket.m_pSSL, iSSLErr)))); - - //if (iSSLErr < 0) - // under Windows it creates problems - #ifdef LINUX - ERR_print_errors_fp(stdout); - #endif - - ShutdownSSL(ClientSocket); - - return false; - } - - /* The TLS/SSL handshake is successfully completed and a TLS/SSL connection - * has been established. Now all reads and writes must use SSL. */ - // peer_cert = SSL_get_peer_certificate(ClientSocket.m_pSSL); - return true; - } - - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLServer][Error] Unable to accept an incoming TCP connection with a client."); - - return false; + return m_TCPServer.SetSndTimeout(ClientSocket.m_SockFd, msec_timeout); } -bool CTCPSSLServer::HasPending(const SSLSocket& ClientSocket) +#ifndef _WIN32 +bool CTCPSSLServer::SetSndTimeout(SSLSocket& ClientSocket, struct timeval timeout) { - int pend; - - pend = SSL_has_pending(ClientSocket.m_pSSL); - - return pend == 1; + return m_TCPServer.SetSndTimeout(ClientSocket.m_SockFd, timeout); } +#endif -int CTCPSSLServer::PendingBytes(const SSLSocket& ClientSocket) +// returns the socket of the accepted client +int CTCPSSLServer::Listen(SSLSocket& ClientSocket, size_t msec /*= ACCEPT_WAIT_INF_DELAY*/) { - int nPend; - - nPend = SSL_pending(ClientSocket.m_pSSL); - - return nPend; + int ret_val = m_TCPServer.Listen(ClientSocket.m_SockFd, msec); + if (ret_val < 0) + { + SocketLog("[ERROR]TCPSSLServer, m_TCPServer Listen failed[:Unable to accept an incoming TCP connection with a client.][:%s]", m_TCPServer.m_strPort.c_str()); + return -1; + } + + if (ret_val == 0) + { + return 0; + } + + ret_val = 0; + do + { + if (!SetUpCtxServer(ClientSocket)) + { + SocketLog("[ERROR]TCPSSLServer, SSL_CTX_new failed[:%s][%d]", m_TCPServer.m_strPort.c_str(), ClientSocket.m_SockFd); + //ERR_print_errors_fp(stdout); + break; + } + + //SSL_CTX_set_options(ClientSocket.m_pCTXSSL, SSL_OP_SINGLE_DH_USE); + //SSL_CTX_set_cert_verify_callback(ClientSocket.m_pCTXSSL, AlwaysTrueCallback, nullptr); + + /* Load server certificate into the SSL context. */ + if (!m_strSSLCertFile.empty()) + { + if (SSL_CTX_use_certificate_file(ClientSocket.m_pCTXSSL, m_strSSLCertFile.c_str(), SSL_FILETYPE_PEM) != 1) + { + SocketLog("[ERROR]TCPSSLServer, SSL_CTX_use_certificate_file failed[Loading cert file failed.][%lu:%s][%d][%s]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), ClientSocket.m_SockFd, m_strSSLCertFile.c_str()); + //ERR_print_errors_fp(stdout); + break; + } + } + + /* Load trusted CA file. */ + if (!m_strCAFile.empty()) + { + if (SSL_CTX_load_verify_locations(ClientSocket.m_pCTXSSL, m_strCAFile.c_str(), nullptr) != 1) + { + SocketLog("[ERROR]TCPSSLServer, SSL_CTX_load_verify_locations failed[Loading CA file failed.][%lu:%s][%d][%s]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), ClientSocket.m_SockFd, m_strCAFile.c_str()); + break; + } + + /* Set to require peer (client) certificate verification. */ + //SSL_CTX_set_verify(m_SSLConnectSocket.m_pCTXSSL, SSL_VERIFY_PEER, VerifyCallback); + + /* Set the verification depth to 1 */ + SSL_CTX_set_verify_depth(ClientSocket.m_pCTXSSL, 1); + } + + /* Load the server private-key into the SSL context. */ + if (!m_strSSLKeyFile.empty()) + { + if (SSL_CTX_use_PrivateKey_file(ClientSocket.m_pCTXSSL, m_strSSLKeyFile.c_str(), SSL_FILETYPE_PEM) != 1) + { + SocketLog("[ERROR]TCPSSLServer, SSL_CTX_use_PrivateKey_file failed[Loading key file failed.][%lu:%s][%d][%s]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), ClientSocket.m_SockFd, m_strSSLKeyFile.c_str()); + //ERR_print_errors_fp(stdout); + break; + } + +#if 0 + // verify private key + if (SSL_CTX_check_private_key(ClientSocket.m_pCTXSSL) != 1) + { + SocketLog("[ERROR]TCPSSLServer, SSL_CTX_check_private_key failed[Private key does not match the public certificate.][%lu:%s][%d]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), ClientSocket.m_SockFd); + break; + } +#endif + } + + ClientSocket.m_pSSL = SSL_new(ClientSocket.m_pCTXSSL); + if (ClientSocket.m_pSSL == nullptr) + { + SocketLog("[ERROR]TCPSSLServer, SSL_new failed[%lu:%s][%d]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), ClientSocket.m_SockFd); + break; + } + + // set the socket directly into the SSL structure or we can use a BIO structure + if (SSL_set_fd(ClientSocket.m_pSSL, (int)ClientSocket.m_SockFd) != 1) + { + SocketLog("[ERROR]TCPSSLServer, SSL_set_fd failed[%lu:%s][%d]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), ClientSocket.m_SockFd); + break; + } + + bool acceptOK = false; + do + { + /* wait for a TLS/SSL client to initiate a TLS/SSL handshake */ + int iSSLErr = SSL_accept(ClientSocket.m_pSSL); + if (iSSLErr == 1) + { + acceptOK = true; + break; + } + + //Error occurred, log and close down ssl + int iErrCode = SSL_get_error(ClientSocket.m_pSSL, iSSLErr); + if (iErrCode != SSL_ERROR_WANT_ACCEPT) + { + SocketLog("[ERROR]TCPSSLServer, SSL_accept failed[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), ClientSocket.m_SockFd); + // under Windows it creates problems +#ifndef _WIN32 + ERR_print_errors_fp(stdout); +#endif + break; + } + } while (1); + + if (!acceptOK) + { + break; + } + + /* The TLS/SSL handshake is successfully completed and a TLS/SSL connection + * has been established. Now all reads and writes must use SSL. */ + // peer_cert = SSL_get_peer_certificate(ClientSocket.m_pSSL); + SocketLog("[ERROR]TCPSSLServer, SSL_accept accepted[%d]", ClientSocket.m_SockFd); + return 1; + } while (0); + + Disconnect(ClientSocket); + return ret_val; } -/* When an SSL_read() operation has to be repeated because of SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, - * it must be repeated with the same arguments.*/ -int CTCPSSLServer::Receive(const SSLSocket& ClientSocket, - char* pData, - const size_t uSize, - bool bReadFully /*= true*/) const +/** + * When an SSL_read() operation has to be repeated because of SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, + * it must be repeated with the same arguments. + */ +int CTCPSSLServer::Receive(const SSLSocket& ClientSocket, char* pData, size_t uSize, bool bReadFully /*= true*/) const { - int total = 0; - do - { - int nRecvd = SSL_read(ClientSocket.m_pSSL, pData + total, uSize - total); - - if (nRecvd <= 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPSSLServer][Error] SSL_read failed (Error=%d | %s)", - nRecvd, GetSSLErrorString(SSL_get_error(ClientSocket.m_pSSL, nRecvd)))); - - //ERR_print_errors_fp(stdout); - - break; - } + if (ClientSocket.m_SockFd == INVALID_SOCKET) + { + SocketLog("[ERROR]TCPSSLServer, SSL_read failed[not a connection to SSL server.]"); + return -1; + } + + if (pData == nullptr && uSize != 0) + { + SocketLog("[ERROR]TCPSSLServer, SSL_read failed[%d][%p:%zu]", ClientSocket.m_SockFd, pData, uSize); + return -2; + } + + int total = 0; + do + { + int nRecvd = SSL_read(ClientSocket.m_pSSL, pData + total, (int)uSize - total); + if (nRecvd <= 0) + { + int iErrCode = SSL_get_error(ClientSocket.m_pSSL, nRecvd); + if (iErrCode == SSL_ERROR_WANT_READ) + { + continue; + } + + SocketLog("[ERROR]TCPSSLServer, SSL_read failed[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), ClientSocket.m_SockFd); + //ERR_print_errors_fp(stdout); + break; + } - total += nRecvd; + total += nRecvd; - } while(bReadFully && (total < uSize)); + } while (bReadFully && (total < (int)uSize)); - return total; + return total; } /* When an SSL_write() operation has to be repeated because of SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, * it must be repeated with the same arguments. * When calling SSL_write() with uSize=0 bytes to be sent the behaviour is undefined. */ -bool CTCPSSLServer::Send(const SSLSocket& ClientSocket, const char* pData, const size_t uSize) const +int CTCPSSLServer::Send(const SSLSocket& ClientSocket, const char* pData, size_t uSize) const { - int total = 0; - do - { - int nSent; - - nSent = SSL_write(ClientSocket.m_pSSL, pData + total, uSize - total); - - if (nSent <= 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPSSLServer][Error] SSL_write failed (Error=%d | %s).", - nSent, GetSSLErrorString(SSL_get_error(ClientSocket.m_pSSL, nSent)))); - - return false; - } - total += nSent; - } while (total < uSize); - - return true; + if (ClientSocket.m_pSSL == nullptr || ClientSocket.m_SockFd == INVALID_SOCKET) + { + SocketLog("[ERROR]TCPSSLServer, SSL_write failed[not a connection to SSL server.]"); + return -1; + } + + //OpenSSL 1.1.1 + if (pData == nullptr && uSize != 0) + { + SocketLog("[ERROR]TCPSSLServer, SSL_write failed[%d][%p:%zu]", ClientSocket.m_SockFd, pData, uSize); + return -2; + } + + int total = 0; + do + { + int nSent = SSL_write(ClientSocket.m_pSSL, pData + total, (int)uSize - total); + if (nSent <= 0) + { + int iErrCode = SSL_get_error(ClientSocket.m_pSSL, nSent); + if (iErrCode == SSL_ERROR_WANT_WRITE) + { + continue; + } + + SocketLog("[ERROR]TCPSSLServer, SSL_write failed[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), ClientSocket.m_SockFd); + return -1; + } + + total += nSent; + } while (total < (int)uSize); + + return (int)total; } -bool CTCPSSLServer::Send(const SSLSocket& ClientSocket, const std::string& strData) const +int CTCPSSLServer::Send(const SSLSocket& ClientSocket, const std::string& strData) const { - bool ret; - - ret = Send(ClientSocket, strData.c_str(), strData.length()); - - return ret; + return Send(ClientSocket, strData.c_str(), strData.length()); } -bool CTCPSSLServer::Send(const SSLSocket& ClientSocket, const std::vector& Data) const +int CTCPSSLServer::Send(const SSLSocket& ClientSocket, const std::vector& Data) const { - bool ret; - - ret = Send(ClientSocket, Data.data(), Data.size()); - - return ret; + return Send(ClientSocket, Data.data(), Data.size()); } -bool CTCPSSLServer::Disconnect(SSLSocket& ClientSocket) const +bool CTCPSSLServer::HasPending(const SSLSocket& ClientSocket) { - // send close_notify message to notify peer of the SSL closure. - ShutdownSSL(ClientSocket); - - return m_TCPServer.Disconnect(ClientSocket.m_SockFd); + return ClientSocket.HasPending(); } -CTCPSSLServer::~CTCPSSLServer() +int CTCPSSLServer::PendingBytes(const SSLSocket& ClientSocket) { + return ClientSocket.PendingBytes(); +} +void CTCPSSLServer::Disconnect(SSLSocket& ClientSocket) const +{ + // send close_notify message to notify peer of the SSL closure. + ClientSocket.Disconnect(); + //ShutdownSSL(ClientSocket); + m_TCPServer.Disconnect(ClientSocket.m_SockFd); } #endif diff --git a/Socket/TCPSSLServer.h b/Socket/TCPSSLServer.h index 69c8b3e..ab0ba71 100644 --- a/Socket/TCPSSLServer.h +++ b/Socket/TCPSSLServer.h @@ -13,46 +13,42 @@ #include "SecureSocket.h" #include "TCPServer.h" -/* private inheritance from CTCPServer is replaced with composition to avoid +/* private inheritance from CTCPServer is replaced with composition to avoid * ambiguity on the log callable object */ class CTCPSSLServer : public ASecureSocket { public: - explicit CTCPSSLServer(const LogFnCallback oLogger, - const std::string& strPort, - const OpenSSLProtocol eSSLVersion = OpenSSLProtocol::TLS, - const SettingsFlag eSettings = ALL_FLAGS) - /*throw (EResolveError)*/; + explicit CTCPSSLServer(const LogFnCallback oLogger, + const std::string& strPort, + const OpenSSLProtocol eSSLVersion = OpenSSLProtocol::TLS, + const SettingsFlag eSettings = ALL_FLAGS) /*throw (EResolveError)*/; + ~CTCPSSLServer() override; - ~CTCPSSLServer() override; + CTCPSSLServer(const CTCPSSLServer&) = delete; + CTCPSSLServer& operator=(const CTCPSSLServer&) = delete; - CTCPSSLServer(const CTCPSSLServer&) = delete; - CTCPSSLServer& operator=(const CTCPSSLServer&) = delete; + int Listen(SSLSocket& ClientSocket, size_t msec = ACCEPT_WAIT_INF_DELAY); + void Disconnect(SSLSocket& ClientSocket) const; - bool Listen(SSLSocket& ClientSocket, size_t msec = ACCEPT_WAIT_INF_DELAY); + int Receive(const SSLSocket& ClientSocket, char* pData, size_t uSize, bool bReadFully = true) const; + int Send(const SSLSocket& ClientSocket, const char* pData, size_t uSize) const; + int Send(const SSLSocket& ClientSocket, const std::string& strData) const; + int Send(const SSLSocket& ClientSocket, const std::vector& Data) const; - bool SetRcvTimeout(SSLSocket& ClientSocket, unsigned int msec_timeout); - bool SetSndTimeout(SSLSocket& ClientSocket, unsigned int timeout); - -#ifndef WINDOWS - bool SetRcvTimeout(SSLSocket& ClientSocket, struct timeval timeout); - bool SetSndTimeout(SSLSocket& ClientSocket, struct timeval timeout); -#endif - - bool HasPending(const SSLSocket& ClientSocket); - int PendingBytes(const SSLSocket& ClientSocket); - int Receive(const SSLSocket& ClientSocket, char* pData, - const size_t uSize, bool bReadFully = true) const; + bool SetRcvTimeout(SSLSocket& ClientSocket, unsigned int msec_timeout); + bool SetSndTimeout(SSLSocket& ClientSocket, unsigned int msec_timeout); - bool Send(const SSLSocket& ClientSocket, const char* pData, const size_t uSize) const; - bool Send(const SSLSocket& ClientSocket, const std::string& strData) const; - bool Send(const SSLSocket& ClientSocket, const std::vector& Data) const; +#ifndef _WIN32 + bool SetRcvTimeout(SSLSocket& ClientSocket, struct timeval timeout); + bool SetSndTimeout(SSLSocket& ClientSocket, struct timeval timeout); +#endif - bool Disconnect(SSLSocket& ClientSocket) const; + bool HasPending(const SSLSocket& ClientSocket); + int PendingBytes(const SSLSocket& ClientSocket); protected: - CTCPServer m_TCPServer; + CTCPServer m_TCPServer; }; diff --git a/Socket/TCPServer.cpp b/Socket/TCPServer.cpp index dff5eef..0a0a8a7 100644 --- a/Socket/TCPServer.cpp +++ b/Socket/TCPServer.cpp @@ -7,471 +7,370 @@ #include "TCPServer.h" CTCPServer::CTCPServer(const LogFnCallback oLogger, - /*const std::string& strAddr,*/ - const std::string& strPort, - const SettingsFlag eSettings /*= ALL_FLAGS*/) - /*throw (EResolveError)*/ : - ASocket(oLogger, eSettings), - m_ListenSocket(INVALID_SOCKET), -#ifdef WINDOWS - m_pResultAddrInfo(nullptr), -#endif - //m_strHost(strAddr), - m_strPort(strPort) { -#ifdef WINDOWS - // Resolve the server address and port - ZeroMemory(&m_HintsAddrInfo, sizeof(m_HintsAddrInfo)); - /* AF_INET is used to specify the IPv4 address family. */ - m_HintsAddrInfo.ai_family = AF_INET; - /* SOCK_STREAM is used to specify a stream socket. */ - m_HintsAddrInfo.ai_socktype = SOCK_STREAM; - /* IPPROTO_TCP is used to specify the TCP protocol. */ - m_HintsAddrInfo.ai_protocol = IPPROTO_TCP; - /* AI_PASSIVE flag indicates the caller intends to use the returned socket - * address structure in a call to the bind function.*/ - m_HintsAddrInfo.ai_flags = AI_PASSIVE; - - int iResult = getaddrinfo(nullptr, strPort.c_str(), &m_HintsAddrInfo, &m_pResultAddrInfo); - if (iResult != 0) - { - if (m_pResultAddrInfo != nullptr) - { - freeaddrinfo(m_pResultAddrInfo); - m_pResultAddrInfo = nullptr; - } - - throw EResolveError(StringFormat("[TCPServer][Error] getaddrinfo failed : %d", iResult)); - } -#else - // clear address structure - bzero((char*) &m_ServAddr, sizeof(m_ServAddr)); - - int iPort = atoi(strPort.c_str()); - - /* setup the host_addr structure for use in bind call */ - // server byte order - m_ServAddr.sin_family = AF_INET; - - // automatically be filled with current host's IP address - m_ServAddr.sin_addr.s_addr = INADDR_ANY; - //m_ServAddr.sin_addr.s_addr = inet_addr(strAddr.c_str()); // doesn't work ! - - // convert short integer value for port must be converted into network byte order - m_ServAddr.sin_port = htons(iPort); -#endif + /*const std::string& strAddr,*/ const std::string& strPort, + const SettingsFlag eSettings /*= ALL_FLAGS*/) /*throw (EResolveError)*/ + : ASocket(oLogger, eSettings) + , m_ListenSocket(INVALID_SOCKET) + //, m_strHost(strAddr) + , m_strPort(strPort) + , m_pResultAddrInfo(nullptr) + , m_HintsAddrInfo() +{ } -// Method for setting receive timeout. Can be called after Listen, using the previously created ClientSocket -bool CTCPServer::SetRcvTimeout(ASocket::Socket& ClientSocket, unsigned int msec_timeout) { -#ifndef WINDOWS - struct timeval t = ASocket::TimevalFromMsec(msec_timeout); - - return this->SetRcvTimeout(ClientSocket, t); -#else - int iErr; - - iErr = setsockopt(ClientSocket, SOL_SOCKET, SO_RCVTIMEO, (char*)&msec_timeout, sizeof(struct timeval)); - if (iErr < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPServer::SetRcvTimeout : Socket error in SO_RCVTIMEO call to setsockopt."); - - return false; - } - - return true; -#endif +CTCPServer::~CTCPServer() +{ + SocketClose(m_ListenSocket); } -// Method for setting send timeout. Can be called after Listen, using the previously created ClientSocket -bool CTCPServer::SetSndTimeout(ASocket::Socket& ClientSocket, unsigned int msec_timeout) { -#ifndef WINDOWS - struct timeval t = ASocket::TimevalFromMsec(msec_timeout); - - return this->SetRcvTimeout(ClientSocket, t); -#else - int iErr; - - iErr = setsockopt(ClientSocket, SOL_SOCKET, SO_SNDTIMEO, (char*)&msec_timeout, sizeof(struct timeval)); - if (iErr < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPServer::SetSndTimeout : Socket error in SO_SNDTIMEO call to setsockopt."); - - return false; - } - - return true; -#endif +bool CTCPServer::InitAddrInfo() +{ + // Resolve the server address and port + memset(&m_HintsAddrInfo, 0, sizeof(m_HintsAddrInfo)); + /* AF_INET is used to specify the IPv4 address family. */ + m_HintsAddrInfo.ai_family = AF_INET; + /* SOCK_STREAM is used to specify a stream socket. */ + m_HintsAddrInfo.ai_socktype = SOCK_STREAM; + /* IPPROTO_TCP is used to specify the TCP protocol. */ + m_HintsAddrInfo.ai_protocol = IPPROTO_TCP; + /* AI_PASSIVE flag indicates the caller intends to use the returned socket + * address structure in a call to the bind function.*/ + m_HintsAddrInfo.ai_flags = AI_PASSIVE; + + int iResult = getaddrinfo(nullptr, m_strPort.c_str(), &m_HintsAddrInfo, &m_pResultAddrInfo); + if (iResult != 0) + { + SocketLog("[ERROR]TCPServer, getaddrinfo failed[%s:%s][:%s]", iResult, GaiStrerror(iResult), m_strPort.c_str()); + return false; + } + + return true; } -#ifndef WINDOWS -bool CTCPServer::SetRcvTimeout(ASocket::Socket& ClientSocket, struct timeval Timeout) { - int iErr; - - iErr = setsockopt(ClientSocket, SOL_SOCKET, SO_RCVTIMEO, (char*)&Timeout, sizeof(struct timeval)); - if (iErr < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPServer::SetRcvTimeout : Socket error in SO_RCVTIMEO call to setsockopt."); +// Method for setting receive timeout. Can be called after Listen, using the previously created ClientSocket +bool CTCPServer::SetRcvTimeout(Socket ClientSocket, unsigned int msec_timeout) +{ + bool ret_val = ASocket::SetRcvTimeout(ClientSocket, msec_timeout); + if (!ret_val) + { + SocketLog("[ERROR]TCPServer, setsockopt SOL_SOCKET SO_RCVTIMEO failed[%d:%s][%d][%u]", GetSocketError(), strerror(GetSocketError()), ClientSocket, msec_timeout); + } + + return ret_val; +} - return false; - } +#ifndef _WIN32 +bool CTCPServer::SetRcvTimeout(Socket ClientSocket, struct timeval timeout) +{ + bool ret_val = ASocket::SetRcvTimeout(ClientSocket, timeout); + if (!ret_val) + { + SocketLog("[ERROR]TCPServer, setsockopt SOL_SOCKET SO_RCVTIMEO failed[%d:%s][%d][%u:%u]", GetSocketError(), strerror(GetSocketError()), ClientSocket, timeout.tv_sec, timeout.tv_usec); + } - return true; + return ret_val; } +#endif -bool CTCPServer::SetSndTimeout(ASocket::Socket& ClientSocket, struct timeval Timeout) { - int iErr; - - iErr = setsockopt(ClientSocket, SOL_SOCKET, SO_SNDTIMEO, (char*) &Timeout, sizeof(struct timeval)); - if (iErr < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPServer::SetSndTimeout : Socket error in SO_SNDTIMEO call to setsockopt."); +// Method for setting send timeout. Can be called after Listen, using the previously created ClientSocket +bool CTCPServer::SetSndTimeout(Socket ClientSocket, unsigned int msec_timeout) +{ + bool ret_val = ASocket::SetSndTimeout(ClientSocket, msec_timeout); + if (!ret_val) + { + SocketLog("[ERROR]TCPServer, setsockopt SOL_SOCKET SO_SNDTIMEO failed[%d:%s][%d][%u]", GetSocketError(), strerror(GetSocketError()), ClientSocket, msec_timeout); + } + + return ret_val; +} - return false; - } +#ifndef _WIN32 +bool CTCPServer::SetSndTimeout(Socket ClientSocket, struct timeval timeout) +{ + bool ret_val = ASocket::SetSndTimeout(ClientSocket, timeout); + if (!ret_val) + { + SocketLog("[ERROR]TCPServer, setsockopt SOL_SOCKET SO_SNDTIMEO failed[%d:%s][%d][%u:%u]", GetSocketError(), strerror(GetSocketError()), ClientSocket, timeout.tv_sec, timeout.tv_usec); + } - return true; + return ret_val; } #endif // returns the socket of the accepted client // maxRcvTime and maxSendTime define timeouts in µs for receiving and sending over the socket. Using a negative value // will deactivate the timeout. 0 will set a zero timeout. -bool CTCPServer::Listen(ASocket::Socket& ClientSocket, size_t msec /*= ACCEPT_WAIT_INF_DELAY*/) { - ClientSocket = INVALID_SOCKET; - - // creates a socket to listen for incoming client connections if it doesn't already exist - if (m_ListenSocket == INVALID_SOCKET) { -#ifdef WINDOWS - m_ListenSocket = socket(m_pResultAddrInfo->ai_family, - m_pResultAddrInfo->ai_socktype, - m_pResultAddrInfo->ai_protocol); - - if (m_ListenSocket == INVALID_SOCKET) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPServer][Error] socket failed : %d", WSAGetLastError())); - freeaddrinfo(m_pResultAddrInfo); - m_pResultAddrInfo = nullptr; - return false; - } - - // Allow the socket to be bound to an address that is already in use - int opt = 1; - int iErr = 0; - - iErr = setsockopt(m_ListenSocket, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&opt), sizeof(int)); - if (iErr < 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPServer::Listen : Socket error in call to setsockopt."); - - closesocket(m_ListenSocket); - freeaddrinfo(m_pResultAddrInfo); m_pResultAddrInfo = nullptr; - - m_ListenSocket = INVALID_SOCKET; - - return false; - } - - // bind the listen socket to the host address:port - int iResult = bind(m_ListenSocket, - m_pResultAddrInfo->ai_addr, - static_cast(m_pResultAddrInfo->ai_addrlen)); - - freeaddrinfo(m_pResultAddrInfo); // free memory allocated by getaddrinfo - m_pResultAddrInfo = nullptr; - - if (iResult == SOCKET_ERROR) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPServer][Error] bind failed : %d", WSAGetLastError())); - closesocket(m_ListenSocket); - m_ListenSocket = INVALID_SOCKET; - return false; - } -#else - - // create a socket - // socket(int domain, int type, int protocol) - m_ListenSocket = socket(AF_INET, SOCK_STREAM, 0/*IPPROTO_TCP*/); - if (m_ListenSocket < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPServer][Error] opening socket : %s", strerror(errno))); - - m_ListenSocket = INVALID_SOCKET; - return false; - } - - // Allow the socket to be bound to an address that is already in use - int opt = 1; - int iErr = 0; - - iErr = setsockopt(m_ListenSocket, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&opt), sizeof(int)); - if (iErr < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPServer::Listen : Socket error in SO_REUSEADDR call to setsockopt."); - - close(m_ListenSocket); - m_ListenSocket = INVALID_SOCKET; - - return false; - } - - /* - iErr = setsockopt(m_ListenSocket, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast(&opt), sizeof(int)); - if (iErr < 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPServer::Listen : Socket error in SO_KEEPALIVE call to setsockopt."); - - close(m_ListenSocket); - m_ListenSocket = INVALID_SOCKET; - - return false; - } - */ - - // bind(int fd, struct sockaddr *local_addr, socklen_t addr_length) - // bind() passes file descriptor, the address structure, - // and the length of the address structure - // This bind() call will bind the socket to the current IP address on port, portno - int iResult = bind(m_ListenSocket, - reinterpret_cast(&m_ServAddr), - sizeof(m_ServAddr)); - if (iResult < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPServer][Error] bind failed : %s", strerror(errno))); - return false; - } +int CTCPServer::Listen(Socket& ClientSocket, size_t msec /*= ACCEPT_WAIT_INF_DELAY*/) +{ + ClientSocket = INVALID_SOCKET; + + bool isOK = false; + // creates a socket to listen for incoming client connections if it doesn't already exist + if (m_ListenSocket == INVALID_SOCKET) + { + if (m_pResultAddrInfo == nullptr && !InitAddrInfo()) + { + return -1; + } + + do + { + m_ListenSocket = socket(m_pResultAddrInfo->ai_family, m_pResultAddrInfo->ai_socktype, m_pResultAddrInfo->ai_protocol); + if (m_ListenSocket == INVALID_SOCKET) + { + SocketLog("[ERROR]TCPServer, create socket failed[%d:%s]", GetSocketError(), strerror(GetSocketError())); + break; + } + + // Allow the socket to be bound to an address that is already in use + int opt = 1; + if (setsockopt(m_ListenSocket, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&opt), sizeof(opt)) == SOCKET_ERROR) + { + SocketLog("[ERROR]TCPServer, setsockopt SOL_SOCKET SO_REUSEADDR failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), m_ListenSocket); + break; + } + +#if 0 +#ifdef SO_KEEPALIVE + if (setsockopt(m_ListenSocket, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast(&opt), sizeof(opt)) == SOCKET_ERROR) + { + SocketLog("[ERROR]TCPServer, setsockopt SOL_SOCKET SO_KEEPALIVE failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), m_ListenSocket); + break; + } +#endif #endif - } - -#ifdef WINDOWS - sockaddr addrClient; - int iResult; - /* SOMAXCONN = allow max number of connexions in waiting */ - iResult = listen(m_ListenSocket, SOMAXCONN); - if (iResult == SOCKET_ERROR) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPServer][Error] listen failed : %d", WSAGetLastError())); - closesocket(m_ListenSocket); - m_ListenSocket = INVALID_SOCKET; - return false; - } - - if (msec != ACCEPT_WAIT_INF_DELAY) - { - int ret = SelectSocket(m_ListenSocket, msec); - if (ret == 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPServer::Listen : Timed out."); - - return false; - } - - if (ret == -1) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPServer::Listen : Error selecting socket."); - - return false; - } - } - - // accept client connection, the returned socket will be used for I/O operations - int iAddrLen = sizeof(addrClient); - ClientSocket = accept(m_ListenSocket, &addrClient, &iAddrLen); - if (ClientSocket == INVALID_SOCKET) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPServer][Error] accept failed : %d", WSAGetLastError())); - - return false; - } - - { - if (m_eSettingsFlags & ENABLE_LOG) - // TODO : a version that handles IPv6 - m_oLog( StringFormat("[TCPServer][Info] Incoming connection from '%s' port '%d'", - (addrClient.sa_family == AF_INET) ? inet_ntoa(((struct sockaddr_in*)&addrClient)->sin_addr) : "", - (addrClient.sa_family == AF_INET) ? ntohs(((struct sockaddr_in*)&addrClient)->sin_port) : 0)); - } - - //char buf1[256]; - //unsigned long len2 = 256UL; - //if (!WSAAddressToStringA(&addrClient, lenAddr, NULL, buf1, &len2)) - //if (m_eSettingsFlags & ENABLE_LOG) - //m_oLog(StringFormat("[TCPServer][Info] Connection from %s", buf1)); - -#else - // This listen() call tells the socket to listen to the incoming connections. - // The listen() function places all incoming connection into a backlog queue - // until accept() call accepts the connection. - // Here, we set the maximum size for the backlog queue to SOMAXCONN. - int iResult = listen(m_ListenSocket, SOMAXCONN); - if (iResult < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPServer][Error] listen failed : %s", strerror(errno))); - - return false; - } - - if (msec != ACCEPT_WAIT_INF_DELAY) { - int ret = SelectSocket(m_ListenSocket, msec); - if (ret == 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPServer::Listen : Timed out."); - - return false; - } - - if (ret == -1) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPServer::Listen : Error selecting socket."); - - return false; - } - } - - struct sockaddr_in ClientAddr; - // The accept() call actually accepts an incoming connection - socklen_t uClientLen = sizeof(ClientAddr); - - // This accept() function will write the connecting client's address info - // into the the address structure and the size of that structure is uClientLen. - // The accept() returns a new socket file descriptor for the accepted connection. - // So, the original socket file descriptor can continue to be used - // for accepting new connections while the new socker file descriptor is used for - // communicating with the connected client. - ClientSocket = accept(m_ListenSocket, - reinterpret_cast(&ClientAddr), - &uClientLen); - - if (ClientSocket < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPServer][Error] accept failed : %s", strerror(errno))); - - return false; - } - - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPServer][Info] Incoming connection from '%s' port '%d'", - inet_ntoa(ClientAddr.sin_addr), ntohs(ClientAddr.sin_port))); + + // bind the listen socket to the host address:port + if (bind(m_ListenSocket, m_pResultAddrInfo->ai_addr, static_cast(m_pResultAddrInfo->ai_addrlen)) == SOCKET_ERROR) + { + SocketLog("[ERROR]TCPServer, bind failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), m_ListenSocket); + break; + } + + // This listen() call tells the socket to listen to the incoming connections. + // The listen() function places all incoming connection into a backlog queue + // until accept() call accepts the connection. + // Here, we set the maximum size for the backlog queue to SOMAXCONN. + /* SOMAXCONN = allow max number of connexions in waiting */ + if (listen(m_ListenSocket, SOMAXCONN) == SOCKET_ERROR) + { + SocketLog("[ERROR]TCPServer, listen failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), m_ListenSocket); + break; + } + + SocketLog("[INFO ]TCPServer, listen succeed[%d]", m_ListenSocket); + isOK = true; + } while (0); + + // free memory allocated by getaddrinfo + //if (m_pResultAddrInfo != nullptr) + { + freeaddrinfo(m_pResultAddrInfo); + m_pResultAddrInfo = nullptr; + } + + if (!isOK/* && m_ListenSocket != INVALID_SOCKET*/) + { + SocketClose(m_ListenSocket); + return -1; + } + } + + do + { + //if (msec != ACCEPT_WAIT_INF_DELAY) + { + int ret = SelectSocket(m_ListenSocket, msec); + if (ret == SOCKET_ERROR) + { + SocketLog("[ERROR]TCPServer, select failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), m_ListenSocket); + break; + } + + if (ret == 0) + { + //SocketLog("[INFO ]TCPServer, select timeout[%d]", m_ListenSocket); + return 0; + } + } + + // This accept() function will write the connecting client's address info + // into the the address structure and the size of that structure is uClientLen. + // The accept() returns a new socket file descriptor for the accepted connection. + // So, the original socket file descriptor can continue to be used + // for accepting new connections while the new socker file descriptor is used for + // communicating with the connected client. + // accept client connection, the returned socket will be used for I/O operations + + sockaddr addrClient{}; + int iAddrLen = (int)sizeof(addrClient); + ClientSocket = accept(m_ListenSocket, &addrClient, &iAddrLen); + if (ClientSocket == INVALID_SOCKET) + { + int iErrCode = GetSocketError(); + if (SOCKET_ERR_ACCEPT_RETRIABLE(iErrCode)) + { + return 0; + } + + SocketLog("[ERROR]TCPServer, accept failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), m_ListenSocket); + break; + } + +#if 0 + { + SocketLog("[INFO ]TCPServer, Incoming connection from[%s:%u]", + (addrClient.sa_family == AF_INET) ? inet_ntoa(((struct sockaddr_in*)&addrClient)->sin_addr) : "", + (addrClient.sa_family == AF_INET) ? ntohs(((struct sockaddr_in*)&addrClient)->sin_port) : 0); + } +#endif + +#if 0 + char buf1[256] = {}; + unsigned long len2 = 256UL; + if (WSAAddressToStringA(&addrClient, lenAddr, NULL, buf1, &len2) != SOCKET_ERROR) + { + SocketLog("[INFO ]TCPServer, Connection from[%s]", buf1); + } #endif - return true; + SocketLog("[INFO ]TCPServer, client_sock be accepted[%d]", ClientSocket); + return 1; + } while (0); + + SocketClose(m_ListenSocket); + return -1; } /* ret > 0 : bytes received * ret == 0 : connection closed * ret < 0 : recv failed */ -int CTCPServer::Receive(const CTCPServer::Socket ClientSocket, - char* pData, - const size_t uSize, - bool bReadFully /*= true*/) const { - if (ClientSocket < 0 || !pData || !uSize) - return -1; - -#ifdef WINDOWS - int tries = 0; +int CTCPServer::Receive(Socket ClientSocket, char* pData, size_t uSize, bool bReadFully /*= true*/) const +{ + if (ClientSocket == INVALID_SOCKET) + { + SocketLog("[ERROR]TCPServer, recv failed[not a connection to server.]"); + return -1; + } + + if (pData == nullptr || uSize == 0) + { + SocketLog("[ERROR]TCPServer, recv failed[%d][%p:%zu]", ClientSocket, pData, uSize); + return -2; + } + +#if 0 +#ifdef _WIN32 + int tries = 0; #endif - - int total = 0; - do { - int nRecvd = recv(ClientSocket, pData + total, uSize - total, 0); - - if (nRecvd == 0) { - // peer shut down - break; - } - -#ifdef WINDOWS - if ((nRecvd < 0) && (WSAGetLastError() == WSAENOBUFS)) - { - // On long messages, Windows recv sometimes fails with WSAENOBUFS, but - // will work if you try again. - if ((tries++ < 1000)) - { - Sleep(1); - continue; - } - - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] Socket error in call to recv."); - - break; - } #endif - total += nRecvd; - - } while (bReadFully && (total < uSize)); - - return total; -} + int total = 0; + bool isOK = true; + do + { + isOK = true; + int nRecvd = recv(ClientSocket, pData + total, (int)uSize - total, 0); + if (nRecvd == SOCKET_ERROR) + { + isOK = false; + int iErrCode = GetSocketError(); + if (SOCKET_ERR_RW_RETRIABLE(iErrCode)) + { + continue; + } +#if 0 +#ifdef _WIN32 + // On long messages, Windows recv sometimes fails with WSAENOBUFS, but + // will work if you try again. + if (WSAGetLastError() == WSAENOBUFS && (tries++ < 1000)) + { + Sleep(1); + continue; + } +#endif +#endif -bool CTCPServer::Send(const Socket ClientSocket, const char* pData, size_t uSize) const { - if (ClientSocket < 0 || !pData || !uSize) - return false; + SocketLog("[ERROR]TCPServer, recv failed[%d:%s][%d]", iErrCode, strerror(iErrCode), ClientSocket); + break; + } - int total = 0; - do { - const int flags = 0; - int nSent; + if (nRecvd == 0) + { + SocketLog("[INFO ]TCPServer, peer shut down[%d]", ClientSocket); + break; + } - nSent = send(ClientSocket, pData + total, uSize - total, flags); + total += nRecvd; - if (nSent < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] Socket error in call to send."); + } while (bReadFully && (total < (int)uSize)); - return false; - } - total += nSent; - } while (total < uSize); + if (!isOK && total == 0) + { + return -1; + } - return true; + return (int)total; } -bool CTCPServer::Send(const Socket ClientSocket, const std::string& strData) const { - return Send(ClientSocket, strData.c_str(), strData.length()); +int CTCPServer::Send(const Socket ClientSocket, const char* pData, size_t uSize) const +{ + if (ClientSocket == INVALID_SOCKET) + { + SocketLog("[ERROR]TCPServer, send failed[not a connection to server.]"); + return -1; + } + + if (pData == nullptr && uSize != 0) + { + SocketLog("[ERROR]TCPServer, send failed[%d][%p:%zu]", ClientSocket, pData, uSize); + return 0; + } + + int total = 0; + do + { + int nSent = send(ClientSocket, pData + total, (int)uSize - total, 0); + if (nSent == SOCKET_ERROR) + { + int iErrCode = GetSocketError(); + if (SOCKET_ERR_RW_RETRIABLE(iErrCode)) + { + continue; + } + + SocketLog("[ERROR]TCPServer, send failed[%d:%s][%d]", iErrCode, strerror(iErrCode), ClientSocket); + return -1; + } + + total += nSent; + } while (total < (int)uSize); + + return (int)total; } -bool CTCPServer::Send(const Socket ClientSocket, const std::vector& Data) const { - return Send(ClientSocket, Data.data(), Data.size()); +int CTCPServer::Send(const Socket ClientSocket, const std::string& strData) const +{ + return Send(ClientSocket, strData.c_str(), strData.length()); } -bool CTCPServer::Disconnect(const CTCPServer::Socket ClientSocket) const { -#ifdef WINDOWS - // The shutdown function disables sends or receives on a socket. - int iResult = shutdown(ClientSocket, SD_RECEIVE); - - if (iResult == SOCKET_ERROR) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPServer][Error] shutdown failed : %d", WSAGetLastError())); - - return false; - } - - closesocket(ClientSocket); -#else - - close(ClientSocket); - -#endif - - return true; +int CTCPServer::Send(const Socket ClientSocket, const std::vector& Data) const +{ + return Send(ClientSocket, Data.data(), Data.size()); } -CTCPServer::~CTCPServer() { -#ifdef WINDOWS - // close listen socket - closesocket(m_ListenSocket); -#else - close(m_ListenSocket); +void CTCPServer::Disconnect(Socket& ClientSocket) const +{ + if (ClientSocket != INVALID_SOCKET) + { +#if 0//defined(_WIN32) + // The shutdown function disables sends or receives on a socket. + if (shutdown(ClientSocket, SD_RECEIVE) == SOCKET_ERROR) + { + SocketLog("[ERROR]TCPServer, shutdown SD_RECEIVE failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), ClientSocket); + } #endif + SocketClose(ClientSocket); + } } diff --git a/Socket/TCPServer.h b/Socket/TCPServer.h index 311242a..11e409b 100644 --- a/Socket/TCPServer.h +++ b/Socket/TCPServer.h @@ -9,74 +9,54 @@ #ifndef INCLUDE_TCPSERVER_H_ #define INCLUDE_TCPSERVER_H_ -#include -#include // size_t -#include -#include // strerror, strlen, memcpy, strcpy -#include -#include -#include -#include #include #include #include "Socket.h" -#ifdef WINDOWS -#undef min -#undef max -#endif +class CTCPSSLServer; class CTCPServer : public ASocket { + friend class CTCPSSLServer; public: - explicit CTCPServer(const LogFnCallback oLogger, - /*const std::string& strAddr,*/ - const std::string& strPort, - const SettingsFlag eSettings = ALL_FLAGS) - /*throw (EResolveError)*/; - - ~CTCPServer() override; - - // copy constructor and assignment operator are disabled - CTCPServer(const CTCPServer&) = delete; - CTCPServer& operator=(const CTCPServer&) = delete; - - /* returns the socket of the accepted client, the waiting period can be set */ - bool Listen(Socket& ClientSocket, size_t msec = ACCEPT_WAIT_INF_DELAY); - - int Receive(const Socket ClientSocket, - char* pData, - const size_t uSize, - bool bReadFully = true) const; - - bool Send(const Socket ClientSocket, const char* pData, const size_t uSize) const; - bool Send(const Socket ClientSocket, const std::string& strData) const; - bool Send(const Socket ClientSocket, const std::vector& Data) const; - - bool Disconnect(const Socket ClientSocket) const; - - bool SetRcvTimeout(ASocket::Socket& ClientSocket, unsigned int msec_timeout); - bool SetSndTimeout(ASocket::Socket& ClientSocket, unsigned int msec_timeout); - -#ifndef WINDOWS - bool SetRcvTimeout(ASocket::Socket& ClientSocket, struct timeval Timeout); - bool SetSndTimeout(ASocket::Socket& ClientSocket, struct timeval Timeout); + explicit CTCPServer(const LogFnCallback oLogger, + /*const std::string& strAddr,*/ const std::string& strPort, + const SettingsFlag eSettings = ALL_FLAGS) /*throw (EResolveError)*/; + ~CTCPServer() override; + + // copy constructor and assignment operator are disabled + CTCPServer(const CTCPServer&) = delete; + CTCPServer& operator=(const CTCPServer&) = delete; + + /* returns the socket of the accepted client, the waiting period can be set */ + int Listen(Socket& ClientSocket, size_t msec = ACCEPT_WAIT_INF_DELAY); + void Disconnect(Socket& ClientSocket) const; + + int Receive(Socket ClientSocket, char* pData, size_t uSize, bool bReadFully = true) const; + int Send(Socket ClientSocket, const char* pData, size_t uSize) const; + int Send(Socket ClientSocket, const std::string& strData) const; + int Send(Socket ClientSocket, const std::vector& Data) const; + + bool SetRcvTimeout(Socket ClientSocket, unsigned int msec_timeout); + bool SetSndTimeout(Socket ClientSocket, unsigned int msec_timeout); + +#ifndef _WIN32 + bool SetRcvTimeout(Socket ClientSocket, struct timeval timeout); + bool SetSndTimeout(Socket ClientSocket, struct timeval timeout); #endif -protected: - Socket m_ListenSocket; +private: + bool InitAddrInfo(); - //std::string m_strHost; - std::string m_strPort; +protected: + Socket m_ListenSocket; - #ifdef WINDOWS - struct addrinfo* m_pResultAddrInfo; - struct addrinfo m_HintsAddrInfo; - #else - struct sockaddr_in m_ServAddr; - #endif + //std::string m_strHost; + std::string m_strPort; + struct addrinfo* m_pResultAddrInfo; + struct addrinfo m_HintsAddrInfo; }; #endif