diff --git a/lib/pure/net.nim b/lib/pure/net.nim index f5309585471ce..c0bfc39cc7d5a 100644 --- a/lib/pure/net.nim +++ b/lib/pure/net.nim @@ -75,6 +75,9 @@ export Domain, SockType, Protocol const useWinVersion = defined(Windows) or defined(nimdoc) const defineSsl = defined(ssl) or defined(nimdoc) +when useWinVersion: + from winlean import WSAESHUTDOWN + when defineSsl: import openssl @@ -187,6 +190,7 @@ proc isDisconnectionError*(flags: set[SocketFlag], lastError.int32 == WSAECONNABORTED or lastError.int32 == WSAENETRESET or lastError.int32 == WSAEDISCON or + lastError.int32 == WSAESHUTDOWN or lastError.int32 == ERROR_NETNAME_DELETED) else: SocketFlag.SafeDisconn in flags and @@ -1031,8 +1035,72 @@ proc accept*(server: Socket, client: var owned(Socket), var addrDummy = "" acceptAddr(server, client, addrDummy, flags) -proc close*(socket: Socket) = +when defined(posix): + from posix import Sigset, sigwait, sigismember, sigemptyset, sigaddset, + sigprocmask, pthread_sigmask, SIGPIPE, SIG_BLOCK, SIG_UNBLOCK + +template blockSigpipe(body: untyped): untyped = + ## Temporary block SIGPIPE within the provided code block. If SIGPIPE is + ## raised for the duration of the code block, it will be queued and will be + ## raised once the block ends. + ## + ## Within the block a `selectSigpipe()` template is provided which can be + ## used to remove SIGPIPE from the queue. Note that if SIGPIPE is **not** + ## raised at the time of call, it will block until SIGPIPE is raised. + ## + ## If SIGPIPE has already been blocked at the time of execution, the + ## signal mask is left as-is and `selectSigpipe()` will become a no-op. + ## + ## For convenience, this template is also available for non-POSIX system, + ## where `body` will be executed as-is. + when not defined(posix): + body + else: + template sigmask(how: cint, set, oset: var Sigset): untyped {.gensym.} = + ## Alias for pthread_sigmask or sigprocmask depending on the status + ## of --threads + when compileOption("threads"): + pthread_sigmask(how, set, oset) + else: + sigprocmask(how, set, oset) + + var oldSet, watchSet: Sigset + if sigemptyset(oldSet) == -1: + raiseOSError(osLastError()) + if sigemptyset(watchSet) == -1: + raiseOSError(osLastError()) + + if sigaddset(watchSet, SIGPIPE) == -1: + raiseOSError(osLastError(), "Couldn't add SIGPIPE to Sigset") + + if sigmask(SIG_BLOCK, watchSet, oldSet) == -1: + raiseOSError(osLastError(), "Couldn't block SIGPIPE") + + let alreadyBlocked = sigismember(oldSet, SIGPIPE) == 1 + + template selectSigpipe(): untyped {.used.} = + if not alreadyBlocked: + var signal: cint + let err = sigwait(watchSet, signal) + if err != 0: + raiseOSError(err.OSErrorCode, "Couldn't select SIGPIPE") + assert signal == SIGPIPE + + try: + body + finally: + if not alreadyBlocked: + if sigmask(SIG_UNBLOCK, watchSet, oldSet) == -1: + raiseOSError(osLastError(), "Couldn't unblock SIGPIPE") + +proc close*(socket: Socket, flags = {SocketFlag.SafeDisconn}) = ## Closes a socket. + ## + ## If `socket` is an SSL/TLS socket, this proc will also send a closure + ## notification to the peer. If `SafeDisconn` is in `flags`, failure to do so + ## due to disconnections will be ignored. This is generally safe in + ## practice. See + ## `here `_ for more details. try: when defineSsl: if socket.isSsl and socket.sslHandle != nil: @@ -1044,12 +1112,34 @@ proc close*(socket: Socket) = # it is valid, under the TLS standard, to perform a unidirectional # shutdown i.e not wait for the peers "close notify" alert with a second # call to SSL_shutdown - ErrClearError() - let res = SSL_shutdown(socket.sslHandle) - if res == 0: - discard - elif res != 1: - socketError(socket, res) + blockSigpipe: + ErrClearError() + let res = SSL_shutdown(socket.sslHandle) + if res == 0: + discard + elif res != 1: + let + err = osLastError() + sslError = SSL_get_error(socket.sslHandle, res) + + # If a close notification is received, failures outside of the + # protocol will be returned as SSL_ERROR_ZERO_RETURN instead + # of SSL_ERROR_SYSCALL. This fact is deduced by digging into + # SSL_get_error() source code. + if sslError == SSL_ERROR_ZERO_RETURN or + sslError == SSL_ERROR_SYSCALL: + when defined(posix) and not defined(macosx) and + not defined(nimdoc): + if err == EPIPE.OSErrorCode: + # Clear the SIGPIPE that's been raised due to + # the disconnection. + selectSigpipe() + else: + discard + if not flags.isDisconnectionError(err): + socketError(socket, res, lastError = err, flags = flags) + else: + socketError(socket, res, lastError = err, flags = flags) finally: when defineSsl: if socket.isSsl and socket.sslHandle != nil: @@ -1470,7 +1560,7 @@ proc recvFrom*(socket: Socket, data: var string, length: int, var addrLen = sizeof(sockAddress).SockLen result = recvfrom(socket.fd, cstring(data), length.cint, flags.cint, cast[ptr SockAddr](addr(sockAddress)), addr(addrLen)) - + if result != -1: data.setLen(result) address = getAddrString(cast[ptr SockAddr](addr(sockAddress))) diff --git a/lib/windows/winlean.nim b/lib/windows/winlean.nim index bd4cfdca7920f..c28700374c883 100644 --- a/lib/windows/winlean.nim +++ b/lib/windows/winlean.nim @@ -813,6 +813,7 @@ const WSAEINPROGRESS* = 10036 WSAEINTR* = 10004 WSAEWOULDBLOCK* = 10035 + WSAESHUTDOWN* = 10058 ERROR_NETNAME_DELETED* = 64 STATUS_PENDING* = 0x103 diff --git a/tests/stdlib/tssl.nim b/tests/stdlib/tssl.nim index cda3cb5906b07..3c2b6be2bca7c 100644 --- a/tests/stdlib/tssl.nim +++ b/tests/stdlib/tssl.nim @@ -5,13 +5,16 @@ discard """ import net, nativesockets when defined(posix): import os, posix +else: + import winlean + const SD_SEND = 1 when not defined(ssl): {.error: "This test must be compiled with -d:ssl".} const DummyData = "dummy data\n" -proc connector(port: Port) {.thread.} = +proc abruptShutdown(port: Port) {.thread.} = let clientContext = newContext(verifyMode = CVerifyNone) var client = newSocket(buffered = false) clientContext.wrapSocket(client) @@ -20,11 +23,16 @@ proc connector(port: Port) {.thread.} = discard client.recvLine() client.getFd.close() -proc main() = - let serverContext = newContext(verifyMode = CVerifyNone, - certFile = "tests/testdata/mycert.pem", - keyFile = "tests/testdata/mycert.pem") +proc notifiedShutdown(port: Port) {.thread.} = + let clientContext = newContext(verifyMode = CVerifyNone) + var client = newSocket(buffered = false) + clientContext.wrapSocket(client) + client.connect("localhost", port) + + discard client.recvLine() + client.close() +proc main() = when defined(posix): var ignoreAction = SigAction(sa_handler: SIG_IGN) @@ -34,7 +42,11 @@ proc main() = if sigaction(SIGPIPE, ignoreAction, oldSigPipeHandler) == -1: raiseOSError(osLastError(), "Couldn't ignore SIGPIPE") - block peer_close_without_shutdown: + let serverContext = newContext(verifyMode = CVerifyNone, + certFile = "tests/testdata/mycert.pem", + keyFile = "tests/testdata/mycert.pem") + + block peer_close_during_write_without_shutdown: var server = newSocket(buffered = false) defer: server.close() serverContext.wrapSocket(server) @@ -43,7 +55,7 @@ proc main() = server.listen() var clientThread: Thread[Port] - createThread(clientThread, connector, port) + createThread(clientThread, abruptShutdown, port) var peer: Socket try: @@ -60,4 +72,63 @@ proc main() = finally: peer.close() + when defined(posix): + if sigaction(SIGPIPE, oldSigPipeHandler, nil) == -1: + raiseOSError(osLastError(), "Couldn't restore SIGPIPE handler") + + block peer_close_before_received_shutdown: + var server = newSocket(buffered = false) + defer: server.close() + serverContext.wrapSocket(server) + server.bindAddr(address = "localhost") + let (_, port) = server.getLocalAddr() + server.listen() + + var clientThread: Thread[Port] + createThread(clientThread, abruptShutdown, port) + + var peer: Socket + try: + server.accept(peer) + peer.send(DummyData) + + joinThread clientThread + + # Tell the OS to close off the write side so shutdown attempts will + # be met with SIGPIPE. + when defined(posix): + discard peer.getFd.shutdown(SHUT_WR) + else: + discard peer.getFd.shutdown(SD_SEND) + finally: + peer.close() + + block peer_close_after_received_shutdown: + var server = newSocket(buffered = false) + defer: server.close() + serverContext.wrapSocket(server) + server.bindAddr(address = "localhost") + let (_, port) = server.getLocalAddr() + server.listen() + + var clientThread: Thread[Port] + createThread(clientThread, notifiedShutdown, port) + + var peer: Socket + try: + server.accept(peer) + peer.send(DummyData) + + doAssert peer.recv(1024) == "" # Get the shutdown notification + joinThread clientThread + + # Tell the OS to close off the write side so shutdown attempts will + # be met with SIGPIPE. + when defined(posix): + discard peer.getFd.shutdown(SHUT_WR) + else: + discard peer.getFd.shutdown(SD_SEND) + finally: + peer.close() + when isMainModule: main()