Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Non-blocking sockets support (TLS, WebSocket) #4774

Merged
merged 11 commits into from
Nov 16, 2024
43 changes: 37 additions & 6 deletions Net/include/Poco/Net/WebSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ class Net_API WebSocket: public StreamSocket
/// Note that special frames like PING must be handled at
/// application level. In the case of a PING, a PONG message
/// must be returned.
///
/// Once connected, a WebSocket can be put into non-blocking
/// mode, by calling setBlocking(false).
/// Please refer to the sendFrame() and receiveFrame() documentation
/// for non-blocking behavior.
{
public:
enum Mode
Expand Down Expand Up @@ -221,15 +226,21 @@ class Net_API WebSocket: public StreamSocket

#endif //POCO_NEW_STATE_ON_MOVE

void shutdown();
int shutdown();
/// Sends a Close control frame to the server end of
/// the connection to initiate an orderly shutdown
/// of the connection.
///
/// Returns the number of bytes sent or -1 if the socket
/// is non-blocking and the frame cannot be sent at this time.

void shutdown(Poco::UInt16 statusCode, const std::string& statusMessage = "");
int shutdown(Poco::UInt16 statusCode, const std::string& statusMessage = "");
/// Sends a Close control frame to the server end of
/// the connection to initiate an orderly shutdown
/// of the connection.
///
/// Returns the number of bytes sent or -1 if the socket
/// is non-blocking and the frame cannot be sent at this time.

int sendFrame(const void* buffer, int length, int flags = FRAME_TEXT);
/// Sends the contents of the given buffer through
Expand All @@ -238,11 +249,15 @@ class Net_API WebSocket: public StreamSocket
/// Values from the FrameFlags, FrameOpcodes and SendFlags enumerations
/// can be specified in flags.
///
/// Returns the number of bytes sent, which may be
/// less than the number of bytes specified.
/// Returns the number of bytes sent.
///
/// Certain socket implementations may also return a negative
/// value denoting a certain condition.
/// If the WebSocket is non-blocking and the frame could not
/// be sent in full, returns -1. In this case, the call
/// to sendFrame() must be repeated with exactly the same
/// parameters as soon as the socket becomes writable again
/// (see select() or poll()).
/// The value of length is returned after the complete
/// frame has been sent.

int receiveFrame(void* buffer, int length, int& flags);
/// Receives a frame from the socket and stores it
Expand Down Expand Up @@ -272,6 +287,14 @@ class Net_API WebSocket: public StreamSocket
///
/// The frame flags and opcode (FrameFlags and FrameOpcodes)
/// is stored in flags.
///
/// In case of a non-blocking socket, may return -1, even
/// if a partial frame has been received.
/// In this case, receiveFrame() should be called again as
/// soon as more data becomes available (see select() or poll()).
/// Eventually, receiveFrame() will return the complete frame.
/// The given buffer will not be modified until the full frame has
/// been received.

int receiveFrame(Poco::Buffer<char>& buffer, int& flags);
/// Receives a frame from the socket and stores it
Expand Down Expand Up @@ -314,6 +337,14 @@ class Net_API WebSocket: public StreamSocket
/// called on the buffer beforehand, if the expectation is that
/// the received data is stored starting at the beginning of the
/// buffer.
///
/// In case of a non-blocking socket, may return -1, even
/// if a partial frame has been received.
/// In this case, receiveFrame() should be called again as
/// soon as more data becomes available (see select() or poll()).
/// Eventually, receiveFrame() will return the complete frame.
/// The given buffer will not be modified until the full frame has
/// been received.

Mode mode() const;
/// Returns WS_SERVER if the WebSocket is a server-side
Expand Down
52 changes: 45 additions & 7 deletions Net/include/Poco/Net/WebSocketImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,21 @@ class Net_API WebSocketImpl: public StreamSocketImpl
// StreamSocketImpl
virtual int sendBytes(const void* buffer, int length, int flags);
/// Sends a WebSocket protocol frame.
///
/// See WebSocket::sendFrame() for more information, including
/// behavior if set to non-blocking.

virtual int receiveBytes(void* buffer, int length, int flags);
/// Receives a WebSocket protocol frame.
///
/// See WebSocket::receiveFrame() for more information, including
/// behavior if set to non-blocking.

virtual int receiveBytes(Poco::Buffer<char>& buffer, int flags = 0, const Poco::Timespan& span = 0);
/// Receives a WebSocket protocol frame.
///
/// See WebSocket::receiveFrame() for more information, including
/// behavior if set to non-blocking.

virtual SocketImpl* acceptConnection(SocketAddress& clientAddr);
virtual void connect(const SocketAddress& address);
Expand All @@ -67,10 +76,16 @@ class Net_API WebSocketImpl: public StreamSocketImpl
virtual void sendUrgent(unsigned char data);
virtual int available();
virtual bool secure() const;
virtual void setSendBufferSize(int size);
virtual int getSendBufferSize();
virtual void setReceiveBufferSize(int size);
virtual int getReceiveBufferSize();
virtual void setSendTimeout(const Poco::Timespan& timeout);
virtual Poco::Timespan getSendTimeout();
virtual void setReceiveTimeout(const Poco::Timespan& timeout);
virtual Poco::Timespan getReceiveTimeout();
virtual void setBlocking(bool flag);
virtual bool getBlocking() const;

// Internal
int frameFlags() const;
Expand All @@ -93,13 +108,35 @@ class Net_API WebSocketImpl: public StreamSocketImpl
enum
{
FRAME_FLAG_MASK = 0x80,
MAX_HEADER_LENGTH = 14
MAX_HEADER_LENGTH = 14,
MASK_LENGTH = 4
};

struct ReceiveState
{
int frameFlags = 0;
bool useMask = false;
char mask[MASK_LENGTH];
int headerLength = 0;
int payloadLength = 0;
int remainingPayloadLength = 0;
Poco::Buffer<char> payload{0};
};

struct SendState
{
int length = 0;
int remainingPayloadOffset = 0;
int remainingPayloadLength = 0;
Poco::Buffer<char> payload{0};
};

int receiveHeader(char mask[4], bool& useMask);
int receivePayload(char *buffer, int payloadLength, char mask[4], bool useMask);
int receiveNBytes(void* buffer, int bytes);
int receiveSomeBytes(char* buffer, int bytes);
int peekHeader(ReceiveState& receiveState);
void skipHeader(int headerLength);
int receivePayload(char *buffer, int payloadLength, char mask[MASK_LENGTH], bool useMask);
int receiveNBytes(void* buffer, int length);
int receiveSomeBytes(char* buffer, int length);
int peekSomeBytes(char* buffer, int length);
virtual ~WebSocketImpl();

private:
Expand All @@ -109,8 +146,9 @@ class Net_API WebSocketImpl: public StreamSocketImpl
int _maxPayloadSize;
Poco::Buffer<char> _buffer;
int _bufferOffset;
int _frameFlags;
bool _mustMaskPayload;
ReceiveState _receiveState;
SendState _sendState;
Poco::Random _rnd;
};

Expand All @@ -120,7 +158,7 @@ class Net_API WebSocketImpl: public StreamSocketImpl
//
inline int WebSocketImpl::frameFlags() const
{
return _frameFlags;
return _receiveState.frameFlags;
}


Expand Down
88 changes: 69 additions & 19 deletions Net/src/SocketImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,24 +361,37 @@ void SocketImpl::checkBrokenTimeout(SelectMode mode)

int SocketImpl::sendBytes(const void* buffer, int length, int flags)
{
checkBrokenTimeout(SELECT_WRITE);

if (_blocking)
{
checkBrokenTimeout(SELECT_WRITE);
}
int rc;
do
{
if (_sockfd == POCO_INVALID_SOCKET) throw InvalidSocketException();
rc = ::send(_sockfd, reinterpret_cast<const char*>(buffer), length, flags);
}
while (_blocking && rc < 0 && lastError() == POCO_EINTR);
if (rc < 0) error();
if (rc < 0)
{
int err = lastError();
if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK))
;
else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT)
throw TimeoutException(err);
else
error(err);
}
return rc;
}


int SocketImpl::sendBytes(const SocketBufVec& buffers, int flags)
{
checkBrokenTimeout(SELECT_WRITE);

if (_blocking)
{
checkBrokenTimeout(SELECT_WRITE);
}
int rc = 0;
do
{
Expand All @@ -395,15 +408,26 @@ int SocketImpl::sendBytes(const SocketBufVec& buffers, int flags)
#endif
}
while (_blocking && rc < 0 && lastError() == POCO_EINTR);
if (rc < 0) error();
if (rc < 0)
{
int err = lastError();
if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK))
;
else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT)
throw TimeoutException(err);
else
error(err);
}
return rc;
}


int SocketImpl::receiveBytes(void* buffer, int length, int flags)
{
checkBrokenTimeout(SELECT_READ);

if (_blocking)
{
checkBrokenTimeout(SELECT_READ);
}
int rc;
do
{
Expand All @@ -414,7 +438,7 @@ int SocketImpl::receiveBytes(void* buffer, int length, int flags)
if (rc < 0)
{
int err = lastError();
if (err == POCO_EAGAIN && !_blocking)
if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK))
;
else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT)
throw TimeoutException(err);
Expand All @@ -427,8 +451,10 @@ int SocketImpl::receiveBytes(void* buffer, int length, int flags)

int SocketImpl::receiveBytes(SocketBufVec& buffers, int flags)
{
checkBrokenTimeout(SELECT_READ);

if (_blocking)
{
checkBrokenTimeout(SELECT_READ);
}
int rc = 0;
do
{
Expand All @@ -448,7 +474,7 @@ int SocketImpl::receiveBytes(SocketBufVec& buffers, int flags)
if (rc < 0)
{
int err = lastError();
if (err == POCO_EAGAIN && !_blocking)
if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK))
;
else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT)
throw TimeoutException(err);
Expand Down Expand Up @@ -476,7 +502,7 @@ int SocketImpl::receiveBytes(Poco::Buffer<char>& buffer, int flags, const Poco::
if (rc < 0)
{
int err = lastError();
if (err == POCO_EAGAIN && !_blocking)
if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK))
;
else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT)
throw TimeoutException(err);
Expand All @@ -502,7 +528,16 @@ int SocketImpl::sendTo(const void* buffer, int length, const SocketAddress& addr
#endif
}
while (_blocking && rc < 0 && lastError() == POCO_EINTR);
if (rc < 0) error();
if (rc < 0)
{
int err = lastError();
if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK))
;
else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT)
throw TimeoutException(err);
else
error(err);
}
return rc;
}

Expand Down Expand Up @@ -534,7 +569,16 @@ int SocketImpl::sendTo(const SocketBufVec& buffers, const SocketAddress& address
#endif
}
while (_blocking && rc < 0 && lastError() == POCO_EINTR);
if (rc < 0) error();
if (rc < 0)
{
int err = lastError();
if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK))
;
else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT)
throw TimeoutException(err);
else
error(err);
}
return rc;
}

Expand All @@ -556,7 +600,10 @@ int SocketImpl::receiveFrom(void* buffer, int length, SocketAddress& address, in

int SocketImpl::receiveFrom(void* buffer, int length, struct sockaddr** ppSA, poco_socklen_t** ppSALen, int flags)
{
checkBrokenTimeout(SELECT_READ);
if (_blocking)
{
checkBrokenTimeout(SELECT_READ);
}
int rc;
do
{
Expand All @@ -567,7 +614,7 @@ int SocketImpl::receiveFrom(void* buffer, int length, struct sockaddr** ppSA, po
if (rc < 0)
{
int err = lastError();
if (err == POCO_EAGAIN && !_blocking)
if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK))
;
else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT)
throw TimeoutException(err);
Expand Down Expand Up @@ -595,7 +642,10 @@ int SocketImpl::receiveFrom(SocketBufVec& buffers, SocketAddress& address, int f

int SocketImpl::receiveFrom(SocketBufVec& buffers, struct sockaddr** pSA, poco_socklen_t** ppSALen, int flags)
{
checkBrokenTimeout(SELECT_READ);
if (_blocking)
{
checkBrokenTimeout(SELECT_READ);
}
int rc = 0;
do
{
Expand Down Expand Up @@ -624,7 +674,7 @@ int SocketImpl::receiveFrom(SocketBufVec& buffers, struct sockaddr** pSA, poco_s
if (rc < 0)
{
int err = lastError();
if (err == POCO_EAGAIN && !_blocking)
if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK))
;
else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT)
throw TimeoutException(err);
Expand Down
2 changes: 1 addition & 1 deletion Net/src/StreamSocketImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ int StreamSocketImpl::sendBytes(const void* buffer, int length, int flags)
while (remaining > 0)
{
int n = SocketImpl::sendBytes(p, remaining, flags);
poco_assert_dbg (n >= 0);
poco_assert_dbg (!blocking || n >= 0);
p += n;
sent += n;
remaining -= n;
Expand Down
Loading
Loading