Skip to content

Commit

Permalink
Better handle ECONNRESET on connected datagram sockets. (apple#2979)
Browse files Browse the repository at this point in the history
Motivation:

When a connected datagram socket sends a datagram to a port or host that
is not listening, an ICMP Destination Unreachable message may be
returned. That message triggers an ECONNRESET to be produced at the
socket layer.

On Darwin we handled this well, but on Linux it turned out that this
would push us into connection teardown and, eventually, into a crash.
Not so good.

To better handle this, we need to distinguish EPOLLERR from EPOLLHUP on
datagram sockets. In these cases, we should check whether the socket
error was fatal and, if it was not, we should continue our execution
having fired the error down the pipeline.

Modifications:

Modify the selector code to distinguish reset and error. Add support for
our channels to handle errors.
Have most channels handle errors as resets.
Override the logic for datagram channels to duplicate the logic in
readable.
Add a unit test.

Result:

Better datagrams for all.

(cherry picked from commit 2a3a333)
  • Loading branch information
Lukasa committed Nov 22, 2024
1 parent 4015e33 commit 70b2d6e
Show file tree
Hide file tree
Showing 12 changed files with 175 additions and 59 deletions.
13 changes: 11 additions & 2 deletions Sources/NIOPosix/BaseSocketChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1091,7 +1091,8 @@ class BaseSocketChannel<SocketType: BaseSocketProtocol>: SelectableChannel, Chan
let result: Int32 = try self.socket.getOption(level: .socket, name: .so_error)
if result != 0 {
// we have a socket error, let's forward
// this path will be executed on Linux (EPOLLERR) & Darwin (ev.fflags != 0)
// this path will be executed on Linux (EPOLLERR) & Darwin (ev.fflags != 0) for
// stream sockets, and most (but not all) errors on datagram sockets
error = IOError(errnoCode: result, reason: "connection reset (error set)")
} else {
// we don't have a socket error, this must be connection reset without an error then
Expand Down Expand Up @@ -1209,6 +1210,14 @@ class BaseSocketChannel<SocketType: BaseSocketProtocol>: SelectableChannel, Chan
true
}

/// Reacts to an error event that the selector reported for this channel.
///
/// The base implementation does not inspect the error at all: it tears the
/// channel down through `reset()` and tells the caller that the error was fatal.
///
/// - Returns: Always `ErrorResult.fatal` for this default implementation.
func error() -> ErrorResult {
    // An error is handled exactly like a reset here.
    self.reset()
    return ErrorResult.fatal
}

internal final func updateCachedAddressesFromSocket(updateLocal: Bool = true, updateRemote: Bool = true) {
self.eventLoop.assertInEventLoop()
assert(updateLocal || updateRemote)
Expand Down Expand Up @@ -1331,7 +1340,7 @@ class BaseSocketChannel<SocketType: BaseSocketProtocol>: SelectableChannel, Chan
// The initial set of interested events must not contain `.readEOF` because when connect doesn't return
// synchronously, kevent might send us a `readEOF` because the `writable` event that marks the connect as completed.
// See SocketChannelTest.testServerClosesTheConnectionImmediately for a regression test.
try self.safeRegister(interested: [.reset])
try self.safeRegister(interested: [.reset, .error])
self.lifecycleManager.finishRegistration()(nil, self.pipeline)
}

Expand Down
8 changes: 4 additions & 4 deletions Sources/NIOPosix/PipeChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ final class PipeChannel: BaseStreamSocketChannel<PipePair> {
if let inputFD = self.pipePair.inputFD {
try selector.register(
selectable: inputFD,
interested: interested.intersection([.read, .reset]),
interested: interested.intersection([.read, .reset, .error]),
makeRegistration: self.registrationForInput
)
}

if let outputFD = self.pipePair.outputFD {
try selector.register(
selectable: outputFD,
interested: interested.intersection([.write, .reset]),
interested: interested.intersection([.write, .reset, .error]),
makeRegistration: self.registrationForOutput
)
}
Expand All @@ -95,13 +95,13 @@ final class PipeChannel: BaseStreamSocketChannel<PipePair> {
if let inputFD = self.pipePair.inputFD, inputFD.isOpen {
try selector.reregister(
selectable: inputFD,
interested: interested.intersection([.read, .reset])
interested: interested.intersection([.read, .reset, .error])
)
}
if let outputFD = self.pipePair.outputFD, outputFD.isOpen {
try selector.reregister(
selectable: outputFD,
interested: interested.intersection([.write, .reset])
interested: interested.intersection([.write, .reset, .error])
)
}
}
Expand Down
8 changes: 8 additions & 0 deletions Sources/NIOPosix/SelectableChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,17 @@ internal protocol SelectableChannel: Channel {
/// Called when the `SelectableChannel` was reset (ie. is now unusable)
func reset()

/// Called when the `SelectableChannel` had an error reported on the selector.
func error() -> ErrorResult

func register(selector: Selector<NIORegistration>, interested: SelectorEventSet) throws

func deregister(selector: Selector<NIORegistration>, mode: CloseMode) throws

func reregister(selector: Selector<NIORegistration>, interested: SelectorEventSet) throws
}

/// Outcome reported by `SelectableChannel.error()` after a selector error event.
///
/// - `fatal`: the error was handled as fatal; the event loop stops processing
///   the current event for this channel.
/// - `nonFatal`: the error was handled without being treated as fatal; the event
///   loop goes on to the remaining events as long as the channel is still open.
internal enum ErrorResult {
    case fatal, nonFatal
}
18 changes: 15 additions & 3 deletions Sources/NIOPosix/SelectableEventLoop.swift
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,18 @@ internal final class SelectableEventLoop: EventLoop {
if ev.contains(.reset) {
channel.reset()
} else {
if ev.contains(.error) {
switch channel.error() {
case .fatal:
return
case .nonFatal:
break
}

guard channel.isOpen else {
return
}
}
if ev.contains(.writeEOF) {
channel.writeEOF()

Expand Down Expand Up @@ -746,10 +758,10 @@ internal final class SelectableEventLoop: EventLoop {
self.handleEvent(ev.io, channel: chan)
case .pipeChannel(let chan, let direction):
var ev = ev
if ev.io.contains(.reset) {
// .reset needs special treatment here because we're dealing with two separate pipes instead
if ev.io.contains(.reset) || ev.io.contains(.error) {
// .reset and .error need special treatment here because we're dealing with two separate pipes instead
// of one socket. So we turn .reset/.error into .readEOF/.writeEOF.
ev.io.subtract([.reset])
ev.io.subtract([.reset, .error])
ev.io.formUnion([direction == .input ? .readEOF : .writeEOF])
}
self.handleEvent(ev.io, channel: chan)
Expand Down
5 changes: 4 additions & 1 deletion Sources/NIOPosix/SelectorEpoll.swift
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,10 @@ extension SelectorEventSet {
if epollEvent.events & Epoll.EPOLLRDHUP != 0 {
selectorEventSet.formUnion(.readEOF)
}
if epollEvent.events & Epoll.EPOLLHUP != 0 || epollEvent.events & Epoll.EPOLLERR != 0 {
if epollEvent.events & Epoll.EPOLLERR != 0 {
selectorEventSet.formUnion(.error)
}
if epollEvent.events & Epoll.EPOLLHUP != 0 {
selectorEventSet.formUnion(.reset)
}
self = selectorEventSet
Expand Down
12 changes: 9 additions & 3 deletions Sources/NIOPosix/SelectorGeneric.swift
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ struct SelectorEventSet: OptionSet, Equatable {
/// of flags or to compare against spurious wakeups.
static let _none = SelectorEventSet([])

/// Connection reset or other errors.
/// Connection reset.
static let reset = SelectorEventSet(rawValue: 1 << 0)

/// EOF at the read/input end of a `Selectable`.
Expand All @@ -79,6 +79,9 @@ struct SelectorEventSet: OptionSet, Equatable {
/// - Note: This is rarely used because in many cases, there is no signal that this happened.
static let writeEOF = SelectorEventSet(rawValue: 1 << 4)

/// Error encountered.
static let error = SelectorEventSet(rawValue: 1 << 5)

init(rawValue: SelectorEventSet.RawValue) {
self.rawValue = rawValue
}
Expand Down Expand Up @@ -237,7 +240,7 @@ internal class Selector<R: Registration> {
makeRegistration: (SelectorEventSet, SelectorRegistrationID) -> R
) throws {
assert(self.myThread == NIOThread.current)
assert(interested.contains(.reset))
assert(interested.contains([.reset, .error]))
guard self.lifecycleState == .open else {
throw IOError(errnoCode: EBADF, reason: "can't register on selector as it's \(self.lifecycleState).")
}
Expand Down Expand Up @@ -265,7 +268,10 @@ internal class Selector<R: Registration> {
guard self.lifecycleState == .open else {
throw IOError(errnoCode: EBADF, reason: "can't re-register on selector as it's \(self.lifecycleState).")
}
assert(interested.contains(.reset), "must register for at least .reset but tried registering for \(interested)")
assert(
interested.contains([.reset, .error]),
"must register for at least .reset & .error but tried registering for \(interested)"
)
try selectable.withUnsafeHandle { fd in
var reg = registrations[Int(fd)]!
try self.reregister0(
Expand Down
2 changes: 1 addition & 1 deletion Sources/NIOPosix/SelectorKqueue.swift
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ extension Selector: _SelectorBackendProtocol {
) throws {
try kqueueUpdateEventNotifications(
selectable: selectable,
interested: .reset,
interested: [.reset, .error],
oldInterested: oldInterested,
registrationID: registrationID
)
Expand Down
31 changes: 27 additions & 4 deletions Sources/NIOPosix/SocketChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -840,10 +840,8 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
#endif
}

override func shouldCloseOnReadError(_ err: Error) -> Bool {
guard let err = err as? IOError else { return true }

switch err.errnoCode {
private func shouldCloseOnErrnoCode(_ errnoCode: CInt) -> Bool {
switch errnoCode {
// ECONNREFUSED can happen on linux if the previous sendto(...) failed.
// See also:
// - https://bugzilla.redhat.com/show_bug.cgi?id=1375
Expand All @@ -857,6 +855,31 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
}
}

/// Decides whether a failed read should close the channel.
///
/// - Parameter err: The error produced by the read.
/// - Returns: `true` for any error that is not an `IOError`; for an `IOError`,
///   whatever `shouldCloseOnErrnoCode(_:)` decides for its errno code.
override func shouldCloseOnReadError(_ err: Error) -> Bool {
    if let ioError = err as? IOError {
        return self.shouldCloseOnErrnoCode(ioError.errnoCode)
    }
    // Anything that is not an IOError always closes the channel.
    return true
}

/// Handles an error event reported by the selector for this datagram channel.
///
/// The pending socket error is read via the `so_error` socket option. If the
/// option cannot be read, or `shouldCloseOnErrnoCode(_:)` says the errno should
/// close the channel, the channel is reset. Otherwise the error is fired through
/// the pipeline as an `IOError` and no reset is performed here.
///
/// - Returns: `.fatal` when the channel was reset, `.nonFatal` when the error
///   was only fired through the pipeline.
override func error() -> ErrorResult {
    let socketErrno: CInt
    do {
        socketErrno = try self.socket.getOption(level: .socket, name: .so_error)
    } catch {
        // The error could not be read from the socket: unknown error, fatal.
        self.reset()
        return .fatal
    }

    guard !self.shouldCloseOnErrnoCode(socketErrno) else {
        self.reset()
        return .fatal
    }

    // Not an errno we close on: surface it to the pipeline and carry on.
    let surfaced = IOError(errnoCode: socketErrno, reason: "so_error")
    self.pipeline.syncOperations.fireErrorCaught(surfaced)
    return .nonFatal
}

/// Buffer a write in preparation for a flush.
///
/// When the channel is unconnected, `data` _must_ be of type `AddressedEnvelope<ByteBuffer>`.
Expand Down
55 changes: 55 additions & 0 deletions Tests/NIOPosixTests/DatagramChannelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ extension Channel {
}.wait()
}

/// Blocks until the handler named "ByteReadRecorder" has recorded `count` errors.
///
/// If that handler is not a `DatagramReadRecorder<ByteBuffer>`, a test failure
/// is recorded and an empty array is returned.
///
/// - Parameter count: The number of errors to wait for.
/// - Returns: The errors reported by the recorder, or `[]` on the failure path.
/// - Throws: Any error from looking up the handler context or from waiting on
///   the resulting future.
func waitForErrors(count: Int) throws -> [any Error] {
    let pending = self.pipeline.context(name: "ByteReadRecorder").flatMap {
        context -> EventLoopFuture<[any Error]> in
        guard let recorder = context.handler as? DatagramReadRecorder<ByteBuffer> else {
            XCTFail("Could not wait for errors")
            return self.eventLoop.makeSucceededFuture([])
        }
        return recorder.notifyForErrors(count)
    }
    return try pending.wait()
}

func readCompleteCount() throws -> Int {
try self.pipeline.context(name: "ByteReadRecorder").map { context in
(context.handler as! DatagramReadRecorder<ByteBuffer>).readCompleteCount
Expand Down Expand Up @@ -66,10 +77,12 @@ final class DatagramReadRecorder<DataType>: ChannelInboundHandler {
}

var reads: [AddressedEnvelope<DataType>] = []
var errors: [any Error] = []
var loop: EventLoop? = nil
var state: State = .fresh

var readWaiters: [Int: EventLoopPromise<[AddressedEnvelope<DataType>]>] = [:]
var errorWaiters: [Int: EventLoopPromise<[any Error]>] = [:]
var readCompleteCount = 0

func channelRegistered(context: ChannelHandlerContext) {
Expand All @@ -95,6 +108,16 @@ final class DatagramReadRecorder<DataType>: ChannelInboundHandler {
context.fireChannelRead(Self.wrapInboundOut(data))
}

/// Records an error travelling through the pipeline and forwards it onwards.
///
/// If a waiter was registered for exactly the number of errors recorded so far,
/// it is removed and completed with every error recorded up to this point.
func errorCaught(context: ChannelHandlerContext, error: any Error) {
    self.errors.append(error)

    // Waiters are keyed by the error count they are waiting for.
    let recordedSoFar = self.errors.count
    self.errorWaiters.removeValue(forKey: recordedSoFar)?.succeed(self.errors)

    context.fireErrorCaught(error)
}

func channelReadComplete(context: ChannelHandlerContext) {
self.readCompleteCount += 1
context.fireChannelReadComplete()
Expand All @@ -108,6 +131,15 @@ final class DatagramReadRecorder<DataType>: ChannelInboundHandler {
readWaiters[count] = loop!.makePromise()
return readWaiters[count]!.futureResult
}

/// Returns a future that yields the recorded errors once `count` of them exist.
///
/// When at least `count` errors have already been recorded, the future is
/// already succeeded with the first `count` errors. Otherwise a new promise is
/// stored under `count`, replacing any promise previously stored for that count.
///
/// - Parameter count: The number of errors the caller wants to wait for.
/// - Returns: A future for the recorded errors.
func notifyForErrors(_ count: Int) -> EventLoopFuture<[any Error]> {
    // `loop` must already be set; this traps otherwise, as on every path before.
    let eventLoop = self.loop!

    if self.errors.count >= count {
        return eventLoop.makeSucceededFuture(Array(self.errors.prefix(count)))
    }

    let waiter: EventLoopPromise<[any Error]> = eventLoop.makePromise()
    self.errorWaiters[count] = waiter
    return waiter.futureResult
}
}

class DatagramChannelTests: XCTestCase {
Expand Down Expand Up @@ -1715,6 +1747,29 @@ class DatagramChannelTests: XCTestCase {
}
}

/// Checks that a connected UDP channel reports `ECONNREFUSED` through
/// `errorCaught` when it writes to a peer that has been closed.
///
/// The first channel connects to the second and sends one datagram, which the
/// second channel must receive. The second channel is then closed and the same
/// datagram is sent again; the test waits for one error on the first channel
/// and expects it to be an `IOError` carrying `ECONNREFUSED`.
func testShutdownReadOnConnectedUDP() throws {
    var buffer = self.firstChannel.allocator.buffer(capacity: 256)
    buffer.writeStaticString("hello, world!")

    // Connect and write
    XCTAssertNoThrow(try self.firstChannel.connect(to: self.secondChannel.localAddress!).wait())

    let writeData = AddressedEnvelope(remoteAddress: self.secondChannel.localAddress!, data: buffer)
    XCTAssertNoThrow(try self.firstChannel.writeAndFlush(writeData).wait())
    _ = try self.secondChannel.waitForDatagrams(count: 1)

    // Ok, close on the second channel.
    XCTAssertNoThrow(try self.secondChannel.close(mode: .all).wait())

    // Write again.
    XCTAssertNoThrow(try self.firstChannel.writeAndFlush(writeData).wait())

    // This should trigger an error.
    let errors = try self.firstChannel.waitForErrors(count: 1)
    // `first` instead of `[0]`: an empty result then fails the assertion
    // rather than trapping and taking down the whole test process.
    XCTAssertEqual((errors.first as? IOError)?.errnoCode, ECONNREFUSED)
}

private func hasGoodGROSupport() throws -> Bool {
// Source code for UDP_GRO was added in Linux 5.0. However, this support is somewhat limited
// and some sources indicate support was actually added in 5.10 (perhaps more widely
Expand Down
Loading

0 comments on commit 70b2d6e

Please sign in to comment.