Skip to content

Commit

Permalink
net: allow close() to ignore SSL failures due to disconnections (#15120)
Browse files Browse the repository at this point in the history
* net: allow close() to ignore SSL failures due to disconnections

Comes with this PR is also a SIGPIPE handling contraption.

* net: don't do selectSigpipe() on macOS

macOS sockets have SO_NOSIGPIPE set, so an EPIPE doesn't necessary mean
that a SIGPIPE happened.

* net: fix alreadyBlocked logic

* net: WSAESHUTDOWN is also a disconnection error
  • Loading branch information
alaviss authored Aug 1, 2020
1 parent 3ce32a7 commit c619ced
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 15 deletions.
106 changes: 98 additions & 8 deletions lib/pure/net.nim
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ export Domain, SockType, Protocol
const useWinVersion = defined(Windows) or defined(nimdoc)
const defineSsl = defined(ssl) or defined(nimdoc)

when useWinVersion:
from winlean import WSAESHUTDOWN

when defineSsl:
import openssl

Expand Down Expand Up @@ -187,6 +190,7 @@ proc isDisconnectionError*(flags: set[SocketFlag],
lastError.int32 == WSAECONNABORTED or
lastError.int32 == WSAENETRESET or
lastError.int32 == WSAEDISCON or
lastError.int32 == WSAESHUTDOWN or
lastError.int32 == ERROR_NETNAME_DELETED)
else:
SocketFlag.SafeDisconn in flags and
Expand Down Expand Up @@ -1031,8 +1035,72 @@ proc accept*(server: Socket, client: var owned(Socket),
var addrDummy = ""
acceptAddr(server, client, addrDummy, flags)

proc close*(socket: Socket) =
when defined(posix):
from posix import Sigset, sigwait, sigismember, sigemptyset, sigaddset,
sigprocmask, pthread_sigmask, SIGPIPE, SIG_BLOCK, SIG_UNBLOCK

template blockSigpipe(body: untyped): untyped =
## Temporary block SIGPIPE within the provided code block. If SIGPIPE is
## raised for the duration of the code block, it will be queued and will be
## raised once the block ends.
##
## Within the block a `selectSigpipe()` template is provided which can be
## used to remove SIGPIPE from the queue. Note that if SIGPIPE is **not**
## raised at the time of call, it will block until SIGPIPE is raised.
##
## If SIGPIPE has already been blocked at the time of execution, the
## signal mask is left as-is and `selectSigpipe()` will become a no-op.
##
## For convenience, this template is also available for non-POSIX system,
## where `body` will be executed as-is.
when not defined(posix):
body
else:
template sigmask(how: cint, set, oset: var Sigset): untyped {.gensym.} =
## Alias for pthread_sigmask or sigprocmask depending on the status
## of --threads
when compileOption("threads"):
pthread_sigmask(how, set, oset)
else:
sigprocmask(how, set, oset)

var oldSet, watchSet: Sigset
if sigemptyset(oldSet) == -1:
raiseOSError(osLastError())
if sigemptyset(watchSet) == -1:
raiseOSError(osLastError())

if sigaddset(watchSet, SIGPIPE) == -1:
raiseOSError(osLastError(), "Couldn't add SIGPIPE to Sigset")

if sigmask(SIG_BLOCK, watchSet, oldSet) == -1:
raiseOSError(osLastError(), "Couldn't block SIGPIPE")

let alreadyBlocked = sigismember(oldSet, SIGPIPE) == 1

template selectSigpipe(): untyped {.used.} =
if not alreadyBlocked:
var signal: cint
let err = sigwait(watchSet, signal)
if err != 0:
raiseOSError(err.OSErrorCode, "Couldn't select SIGPIPE")
assert signal == SIGPIPE

try:
body
finally:
if not alreadyBlocked:
if sigmask(SIG_UNBLOCK, watchSet, oldSet) == -1:
raiseOSError(osLastError(), "Couldn't unblock SIGPIPE")

proc close*(socket: Socket, flags = {SocketFlag.SafeDisconn}) =
## Closes a socket.
##
## If `socket` is an SSL/TLS socket, this proc will also send a closure
## notification to the peer. If `SafeDisconn` is in `flags`, failure to do so
## due to disconnections will be ignored. This is generally safe in
## practice. See
## `here <https://security.stackexchange.com/a/82044>`_ for more details.
try:
when defineSsl:
if socket.isSsl and socket.sslHandle != nil:
Expand All @@ -1044,12 +1112,34 @@ proc close*(socket: Socket) =
# it is valid, under the TLS standard, to perform a unidirectional
# shutdown i.e not wait for the peers "close notify" alert with a second
# call to SSL_shutdown
ErrClearError()
let res = SSL_shutdown(socket.sslHandle)
if res == 0:
discard
elif res != 1:
socketError(socket, res)
blockSigpipe:
ErrClearError()
let res = SSL_shutdown(socket.sslHandle)
if res == 0:
discard
elif res != 1:
let
err = osLastError()
sslError = SSL_get_error(socket.sslHandle, res)

# If a close notification is received, failures outside of the
# protocol will be returned as SSL_ERROR_ZERO_RETURN instead
# of SSL_ERROR_SYSCALL. This fact is deduced by digging into
# SSL_get_error() source code.
if sslError == SSL_ERROR_ZERO_RETURN or
sslError == SSL_ERROR_SYSCALL:
when defined(posix) and not defined(macosx) and
not defined(nimdoc):
if err == EPIPE.OSErrorCode:
# Clear the SIGPIPE that's been raised due to
# the disconnection.
selectSigpipe()
else:
discard
if not flags.isDisconnectionError(err):
socketError(socket, res, lastError = err, flags = flags)
else:
socketError(socket, res, lastError = err, flags = flags)
finally:
when defineSsl:
if socket.isSsl and socket.sslHandle != nil:
Expand Down Expand Up @@ -1470,7 +1560,7 @@ proc recvFrom*(socket: Socket, data: var string, length: int,
var addrLen = sizeof(sockAddress).SockLen
result = recvfrom(socket.fd, cstring(data), length.cint, flags.cint,
cast[ptr SockAddr](addr(sockAddress)), addr(addrLen))

if result != -1:
data.setLen(result)
address = getAddrString(cast[ptr SockAddr](addr(sockAddress)))
Expand Down
1 change: 1 addition & 0 deletions lib/windows/winlean.nim
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,7 @@ const
WSAEINPROGRESS* = 10036
WSAEINTR* = 10004
WSAEWOULDBLOCK* = 10035
WSAESHUTDOWN* = 10058
ERROR_NETNAME_DELETED* = 64
STATUS_PENDING* = 0x103

Expand Down
85 changes: 78 additions & 7 deletions tests/stdlib/tssl.nim
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@ discard """
import net, nativesockets

when defined(posix): import os, posix
else:
import winlean
const SD_SEND = 1

when not defined(ssl):
{.error: "This test must be compiled with -d:ssl".}

const DummyData = "dummy data\n"

proc connector(port: Port) {.thread.} =
proc abruptShutdown(port: Port) {.thread.} =
let clientContext = newContext(verifyMode = CVerifyNone)
var client = newSocket(buffered = false)
clientContext.wrapSocket(client)
Expand All @@ -20,11 +23,16 @@ proc connector(port: Port) {.thread.} =
discard client.recvLine()
client.getFd.close()

proc main() =
let serverContext = newContext(verifyMode = CVerifyNone,
certFile = "tests/testdata/mycert.pem",
keyFile = "tests/testdata/mycert.pem")
proc notifiedShutdown(port: Port) {.thread.} =
let clientContext = newContext(verifyMode = CVerifyNone)
var client = newSocket(buffered = false)
clientContext.wrapSocket(client)
client.connect("localhost", port)

discard client.recvLine()
client.close()

proc main() =
when defined(posix):
var
ignoreAction = SigAction(sa_handler: SIG_IGN)
Expand All @@ -34,7 +42,11 @@ proc main() =
if sigaction(SIGPIPE, ignoreAction, oldSigPipeHandler) == -1:
raiseOSError(osLastError(), "Couldn't ignore SIGPIPE")

block peer_close_without_shutdown:
let serverContext = newContext(verifyMode = CVerifyNone,
certFile = "tests/testdata/mycert.pem",
keyFile = "tests/testdata/mycert.pem")

block peer_close_during_write_without_shutdown:
var server = newSocket(buffered = false)
defer: server.close()
serverContext.wrapSocket(server)
Expand All @@ -43,7 +55,7 @@ proc main() =
server.listen()

var clientThread: Thread[Port]
createThread(clientThread, connector, port)
createThread(clientThread, abruptShutdown, port)

var peer: Socket
try:
Expand All @@ -60,4 +72,63 @@ proc main() =
finally:
peer.close()

when defined(posix):
if sigaction(SIGPIPE, oldSigPipeHandler, nil) == -1:
raiseOSError(osLastError(), "Couldn't restore SIGPIPE handler")

block peer_close_before_received_shutdown:
var server = newSocket(buffered = false)
defer: server.close()
serverContext.wrapSocket(server)
server.bindAddr(address = "localhost")
let (_, port) = server.getLocalAddr()
server.listen()

var clientThread: Thread[Port]
createThread(clientThread, abruptShutdown, port)

var peer: Socket
try:
server.accept(peer)
peer.send(DummyData)

joinThread clientThread

# Tell the OS to close off the write side so shutdown attempts will
# be met with SIGPIPE.
when defined(posix):
discard peer.getFd.shutdown(SHUT_WR)
else:
discard peer.getFd.shutdown(SD_SEND)
finally:
peer.close()

block peer_close_after_received_shutdown:
var server = newSocket(buffered = false)
defer: server.close()
serverContext.wrapSocket(server)
server.bindAddr(address = "localhost")
let (_, port) = server.getLocalAddr()
server.listen()

var clientThread: Thread[Port]
createThread(clientThread, notifiedShutdown, port)

var peer: Socket
try:
server.accept(peer)
peer.send(DummyData)

doAssert peer.recv(1024) == "" # Get the shutdown notification
joinThread clientThread

# Tell the OS to close off the write side so shutdown attempts will
# be met with SIGPIPE.
when defined(posix):
discard peer.getFd.shutdown(SHUT_WR)
else:
discard peer.getFd.shutdown(SD_SEND)
finally:
peer.close()

when isMainModule: main()

0 comments on commit c619ced

Please sign in to comment.