From af1345a92142c19cf125101e219999debf52bd5c Mon Sep 17 00:00:00 2001 From: Luke Howard Date: Mon, 11 Nov 2024 12:19:54 +1100 Subject: [PATCH 1/8] Spelling fix: s/cancellAll/cancelAll/ --- FlyingSocks/Sources/SocketPool.swift | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/FlyingSocks/Sources/SocketPool.swift b/FlyingSocks/Sources/SocketPool.swift index 0f8a03e3..c42feadf 100644 --- a/FlyingSocks/Sources/SocketPool.swift +++ b/FlyingSocks/Sources/SocketPool.swift @@ -93,7 +93,7 @@ public final actor SocketPool: AsyncSocketPool { public func run() async throws { guard state == .ready else { throw Error("Not Ready") } state = .running - defer { cancellAll() } + defer { cancelAll() } repeat { if waiting.isEmpty { @@ -151,11 +151,11 @@ public final actor SocketPool: AsyncSocketPool { case complete } - private func cancellAll() { - logger.logInfo("SocketPoll cancellAll") + private func cancelAll() { + logger.logInfo("SocketPoll cancelAll") try? queue.stop() state = .complete - waiting.cancellAll() + waiting.cancelAll() waiting = Waiting() if let loop { self.loop = nil @@ -270,7 +270,7 @@ public final actor SocketPool: AsyncSocketPool { } } - mutating func cancellAll() { + mutating func cancelAll() { let continuations = storage.values.flatMap(\.values).map(\.continuation) storage = [:] for continuation in continuations { From 4a4cad865ac29b44b32d47b26bb17ce246b15b20 Mon Sep 17 00:00:00 2001 From: Luke Howard Date: Tue, 12 Nov 2024 12:50:34 +1100 Subject: [PATCH 2/8] Add SocketType enum for public API Insulate callers from libc types by adding a `SocketType` enumerated type, that can either be `stream` or `datagram`. --- FlyingFox/Tests/AsyncSocketTests.swift | 2 +- FlyingSocks/Sources/AsyncSocket.swift | 6 +++--- FlyingSocks/Sources/Socket.swift | 27 +++++++++++++++++++++--- FlyingSocks/Tests/AsyncSocketTests.swift | 4 ++-- FlyingSocks/Tests/SocketPoolTests.swift | 8 +++---- FlyingSocks/Tests/SocketTests.swift | 19 ++++++----------- 6 files changed, 40 insertions(+), 26 deletions(-) diff --git a/FlyingFox/Tests/AsyncSocketTests.swift b/FlyingFox/Tests/AsyncSocketTests.swift index 3de1df59..f9a8c641 100644 --- a/FlyingFox/Tests/AsyncSocketTests.swift +++ b/FlyingFox/Tests/AsyncSocketTests.swift @@ -40,7 +40,7 @@ extension AsyncSocket { } static func make(pool: some AsyncSocketPool) throws -> AsyncSocket { - let socket = try Socket(domain: AF_UNIX, type: Socket.stream) + let socket = try Socket(domain: AF_UNIX, type: .stream) return try AsyncSocket(socket: socket, pool: pool) } diff --git a/FlyingSocks/Sources/AsyncSocket.swift b/FlyingSocks/Sources/AsyncSocket.swift index 045f070b..9dc29a74 100644 --- a/FlyingSocks/Sources/AsyncSocket.swift +++ b/FlyingSocks/Sources/AsyncSocket.swift @@ -83,7 +83,7 @@ public struct AsyncSocket: Sendable { pool: some AsyncSocketPool, timeout: TimeInterval = 5) async throws -> Self { try await withThrowingTimeout(seconds: timeout) { - let socket = try Socket(domain: Int32(type(of: address).family), type: Socket.stream) + let socket = try Socket(domain: Int32(type(of: address).family), type: .stream) let asyncSocket = try AsyncSocket(socket: socket, pool: pool) try await asyncSocket.connect(to: address) return asyncSocket @@ -178,8 +178,8 @@ public struct AsyncSocket: Sendable { package extension AsyncSocket { - static func makePair(pool: some AsyncSocketPool) throws -> (AsyncSocket, AsyncSocket) { - let (s1, s2) = try Socket.makePair() + static func makePair(pool: some AsyncSocketPool, type: SocketType = .stream) throws -> (AsyncSocket, AsyncSocket) { + let (s1, s2) = try Socket.makePair(type: type) let a1 = try AsyncSocket(socket: s1, pool: pool) let a2 = try AsyncSocket(socket: s2, pool: pool) return (a1, a2) diff --git a/FlyingSocks/Sources/Socket.swift b/FlyingSocks/Sources/Socket.swift index 7de00526..469cf116 100644 --- a/FlyingSocks/Sources/Socket.swift +++ b/FlyingSocks/Sources/Socket.swift @@ -36,6 +36,22 @@ import WinSDK.WinSock2 #endif import Foundation +public enum SocketType: Sendable { + case stream + case datagram +} + +extension SocketType { + var rawValue: Int32 { + switch self { + case .stream: + Socket.stream + case .datagram: + Socket.datagram + } + } +} + public struct Socket: Sendable, Hashable { public let file: FileDescriptor @@ -53,9 +69,10 @@ public struct Socket: Sendable, Hashable { } public init(domain: Int32) throws { - try self.init(domain: domain, type: Socket.stream) + try self.init(domain: domain, type: .stream) } + @available(*, deprecated, message: "type is now SocketType") public init(domain: Int32, type: Int32) throws { let descriptor = FileDescriptor(rawValue: Socket.socket(domain, type, 0)) guard descriptor != .invalid else { @@ -64,6 +81,10 @@ public struct Socket: Sendable, Hashable { self.file = descriptor } + public init(domain: Int32, type: SocketType) throws { + try self.init(domain: domain, type: type.rawValue) + } + public var flags: Flags { get throws { let flags = Socket.fcntl(file.rawValue, F_GETFL) @@ -364,7 +385,7 @@ package extension Socket { return (s1, s2) } - static func makeNonBlockingPair() throws -> (Socket, Socket) { - try Socket.makePair(flags: .nonBlocking) + static func makeNonBlockingPair(type: SocketType = .stream) throws -> (Socket, Socket) { + try Socket.makePair(flags: .nonBlocking, type: type) } } diff --git a/FlyingSocks/Tests/AsyncSocketTests.swift b/FlyingSocks/Tests/AsyncSocketTests.swift index ac2d872f..62152048 100644 --- a/FlyingSocks/Tests/AsyncSocketTests.swift +++ b/FlyingSocks/Tests/AsyncSocketTests.swift @@ -198,7 +198,7 @@ extension AsyncSocket { static func makeListening(pool: some AsyncSocketPool) throws -> AsyncSocket { let address = sockaddr_un.unix(path: #function) try? Socket.unlink(address) - let socket = try Socket(domain: AF_UNIX, type: Socket.stream) + let socket = try Socket(domain: AF_UNIX, type: .stream) try socket.setValue(true, for: .localAddressReuse) try socket.bind(to: address) try socket.listen() @@ -206,7 +206,7 @@ extension AsyncSocket { } static func make(pool: some AsyncSocketPool) throws -> AsyncSocket { - let socket = try Socket(domain: AF_UNIX, type: Socket.stream) + let socket = try Socket(domain: AF_UNIX, type: .stream) return try AsyncSocket(socket: socket, pool: pool) } diff --git a/FlyingSocks/Tests/SocketPoolTests.swift b/FlyingSocks/Tests/SocketPoolTests.swift index 24e39a2b..51a588f7 100644 --- a/FlyingSocks/Tests/SocketPoolTests.swift +++ b/FlyingSocks/Tests/SocketPoolTests.swift @@ -85,7 +85,7 @@ struct SocketPoolTests { try await pool.prepare() let task = Task { - let socket = try Socket(domain: AF_UNIX, type: Socket.stream) + let socket = try Socket(domain: AF_UNIX, type: .stream) try await pool.suspendSocket(socket, untilReadyFor: .read) } @@ -125,7 +125,7 @@ struct SocketPoolTests { try? await task.value - let socket = try Socket(domain: AF_UNIX, type: Socket.stream) + let socket = try Socket(domain: AF_UNIX, type: .stream) await #expect(throws: (any Error).self) { try await pool.suspendSocket(socket, untilReadyFor: .read) } @@ -138,7 +138,7 @@ struct SocketPoolTests { let task = Task { try await pool.run() } defer { task.cancel() } - let socket = try Socket(domain: AF_UNIX, type: Socket.stream) + let socket = try Socket(domain: AF_UNIX, type: .stream) let suspension = Task { try await pool.suspendSocket(socket, untilReadyFor: .read) } @@ -160,7 +160,7 @@ struct SocketPoolTests { let task = Task { try await pool.run() } defer { task.cancel() } - let socket = try Socket(domain: AF_UNIX, type: Socket.stream) + let socket = try Socket(domain: AF_UNIX, type: .stream) let suspension = Task { try await pool.suspendSocket(socket, untilReadyFor: .read) } diff --git a/FlyingSocks/Tests/SocketTests.swift b/FlyingSocks/Tests/SocketTests.swift index 2cbe27cc..68b6dc01 100644 --- a/FlyingSocks/Tests/SocketTests.swift +++ b/FlyingSocks/Tests/SocketTests.swift @@ -98,7 +98,7 @@ struct SocketTests { @Test func socketWrite_Throws_WhenSocketIsNotConnected() async throws { - let s1 = try Socket(domain: AF_UNIX, type: Socket.stream) + let s1 = try Socket(domain: AF_UNIX, type: .stream) let data = Data(repeating: 0x01, count: 100) #expect(throws: SocketError.self) { try s1.write(data, from: data.startIndex) @@ -108,7 +108,7 @@ struct SocketTests { @Test func socket_Sets_And_Gets_ReceiveBufferSize() throws { - let socket = try Socket(domain: AF_UNIX, type: Socket.stream) + let socket = try Socket(domain: AF_UNIX, type: .stream) try socket.setValue(2048, for: .receiveBufferSize) #if canImport(Darwin) @@ -121,7 +121,7 @@ struct SocketTests { @Test func socket_Sets_And_Gets_SendBufferSizeOption() throws { - let socket = try Socket(domain: AF_UNIX, type: Socket.stream) + let socket = try Socket(domain: AF_UNIX, type: .stream) try socket.setValue(2048, for: .sendBufferSize) #if canImport(Darwin) @@ -134,7 +134,7 @@ struct SocketTests { @Test func socket_Sets_And_Gets_BoolOption() throws { - let socket = try Socket(domain: AF_UNIX, type: Socket.stream) + let socket = try Socket(domain: AF_UNIX, type: .stream) try socket.setValue(true, for: .localAddressReuse) #expect(try socket.getValue(for: .localAddressReuse)) @@ -145,20 +145,13 @@ struct SocketTests { @Test func socket_Sets_And_Gets_Flags() throws { - let socket = try Socket(domain: AF_UNIX, type: Socket.stream) + let socket = try Socket(domain: AF_UNIX, type: .stream) #expect(try socket.flags.contains(.append) == false) try socket.setFlags(.append) #expect(try socket.flags.contains(.append)) } - @Test - func socketInit_ThrowsError_WhenInvalid() { - #expect(throws: SocketError.self) { - _ = try Socket(domain: -1, type: -1) - } - } - @Test func socketAccept_ThrowsError_WhenInvalid() { let socket = Socket(file: -1) @@ -193,7 +186,7 @@ struct SocketTests { @Test func socketBind_ToINET() throws { - let socket = try Socket(domain: AF_INET, type: Socket.stream) + let socket = try Socket(domain: AF_INET, type: .stream) try socket.setValue(true, for: .localAddressReuse) let address = Socket.makeAddressINET(port:5050) #expect(throws: Never.self) { From a8eb609fd164cec28222de00c7b98a667d759ebb Mon Sep 17 00:00:00 2001 From: Luke Howard Date: Mon, 11 Nov 2024 12:20:14 +1100 Subject: [PATCH 3/8] Support for datagram sockets (UDP) Add support for datagram sockets, wrapping sendto() and recvfrom(). Note: send() and recv() are not supported, so a destination address must be supplied; however write() and read() can be used with datagram sockets. Fixes: #128 --- FlyingSocks/Sources/AsyncSocket.swift | 55 ++++++++++++++++++ FlyingSocks/Sources/Socket+Android.swift | 9 +++ FlyingSocks/Sources/Socket+Darwin.swift | 9 +++ FlyingSocks/Sources/Socket+Glibc.swift | 9 +++ FlyingSocks/Sources/Socket+Musl.swift | 9 +++ FlyingSocks/Sources/Socket+WinSock2.swift | 9 +++ FlyingSocks/Sources/Socket.swift | 56 ++++++++++++++++++- FlyingSocks/Sources/SocketAddress.swift | 25 +++++++++ FlyingSocks/Tests/AsyncSocketTests.swift | 45 +++++++++++++-- .../Tests/FileManager+TemporaryFile.swift | 42 ++++++++++++++ FlyingSocks/Tests/SocketAddressTests.swift | 56 ++++++++++++------- 11 files changed, 298 insertions(+), 26 deletions(-) create mode 100644 FlyingSocks/Tests/FileManager+TemporaryFile.swift diff --git a/FlyingSocks/Sources/AsyncSocket.swift b/FlyingSocks/Sources/AsyncSocket.swift index 9dc29a74..4afe2209 100644 --- a/FlyingSocks/Sources/AsyncSocket.swift +++ b/FlyingSocks/Sources/AsyncSocket.swift @@ -129,6 +129,20 @@ public struct AsyncSocket: Sendable { return buffer } + public func receive(atMost length: Int = 4096) async throws -> (sockaddr_storage, [UInt8]) { + try Task.checkCancellation() + + repeat { + do { + return try socket.receive(length: length) + } catch SocketError.blocked { + try await pool.suspendSocket(socket, untilReadyFor: .read) + } catch { + throw error + } + } while true + } + /// Reads bytes from the socket up to by not over/ /// - Parameter bytes: The max number of bytes to read /// - Returns: an array of the read bytes capped to the number of bytes provided. @@ -163,6 +177,19 @@ public struct AsyncSocket: Sendable { } } + public func send(_ data: [UInt8], to address: some SocketAddress) async throws { + let sent = try await pool.loopUntilReady(for: .write, on: socket) { + try socket.send(data, to: address) + } + guard sent == data.count else { + throw SocketError.disconnected + } + } + + public func send(_ data: Data, to address: some SocketAddress) async throws { + try await send(Array(data), to: address) + } + public func close() throws { try socket.close() } @@ -174,6 +201,14 @@ public struct AsyncSocket: Sendable { public var sockets: AsyncSocketSequence { AsyncSocketSequence(socket: self) } + + public var messages: AsyncSocketMessageSequence { + AsyncSocketMessageSequence(socket: self) + } + + public func messages(maxMessageLength: Int) -> AsyncSocketMessageSequence { + AsyncSocketMessageSequence(socket: self, maxMessageLength: maxMessageLength) + } } package extension AsyncSocket { @@ -237,6 +272,26 @@ public struct AsyncSocketSequence: AsyncSequence, AsyncIteratorProtocol, Sendabl } } +public struct AsyncSocketMessageSequence: AsyncSequence, AsyncIteratorProtocol, Sendable { + public static let DefaultMaxMessageLength: Int = 1500 + + public typealias Element = (sockaddr_storage, [UInt8]) + + private let socket: AsyncSocket + private let maxMessageLength: Int + + public func makeAsyncIterator() -> AsyncSocketMessageSequence { self } + + init(socket: AsyncSocket, maxMessageLength: Int = Self.DefaultMaxMessageLength) { + self.socket = socket + self.maxMessageLength = maxMessageLength + } + + public mutating func next() async throws -> Element? { + return try await socket.receive(atMost: maxMessageLength) + } +} + private actor ClientPoolLoader { static let shared = ClientPoolLoader() diff --git a/FlyingSocks/Sources/Socket+Android.swift b/FlyingSocks/Sources/Socket+Android.swift index fb154a8e..4574267d 100644 --- a/FlyingSocks/Sources/Socket+Android.swift +++ b/FlyingSocks/Sources/Socket+Android.swift @@ -45,6 +45,7 @@ extension Socket.FileDescriptor { extension Socket { static let stream = Int32(SOCK_STREAM) + static let datagram = Int32(SOCK_DGRAM) static let in_addr_any = Android.in_addr(s_addr: Android.in_addr_t(0)) static func makeAddressINET(port: UInt16) -> Android.sockaddr_in { @@ -175,6 +176,14 @@ extension Socket { static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Android.pollfd { Android.pollfd(fd: fd, events: events, revents: revents) } + + static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer!, _ len: UnsafeMutablePointer!) -> Int { + Android.recvfrom(fd, buffer, nbyte, flags, addr, len) + } + + static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { + Android.sendto(fd, buffer, nbyte, flags, destaddr, destlen) + } } #endif diff --git a/FlyingSocks/Sources/Socket+Darwin.swift b/FlyingSocks/Sources/Socket+Darwin.swift index efa3a643..f4829f4b 100644 --- a/FlyingSocks/Sources/Socket+Darwin.swift +++ b/FlyingSocks/Sources/Socket+Darwin.swift @@ -42,6 +42,7 @@ extension Socket.FileDescriptor { extension Socket { static let stream = Int32(SOCK_STREAM) + static let datagram = Int32(SOCK_DGRAM) static let in_addr_any = Darwin.in_addr(s_addr: Darwin.in_addr_t(0)) static func makeAddressINET(port: UInt16) -> Darwin.sockaddr_in { @@ -176,6 +177,14 @@ extension Socket { static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Darwin.pollfd { Darwin.pollfd(fd: fd, events: events, revents: revents) } + + static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer!, _ len: UnsafeMutablePointer!) -> Int { + Darwin.recvfrom(fd, buffer, nbyte, flags, addr, len) + } + + static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { + Darwin.sendto(fd, buffer, nbyte, flags, destaddr, destlen) + } } #endif diff --git a/FlyingSocks/Sources/Socket+Glibc.swift b/FlyingSocks/Sources/Socket+Glibc.swift index 9880ef8c..cc4aeafc 100644 --- a/FlyingSocks/Sources/Socket+Glibc.swift +++ b/FlyingSocks/Sources/Socket+Glibc.swift @@ -42,6 +42,7 @@ extension Socket.FileDescriptor { extension Socket { static let stream = Int32(SOCK_STREAM.rawValue) + static let datagram = Int32(SOCK_DGRAM.rawValue) static let in_addr_any = Glibc.in_addr(s_addr: Glibc.in_addr_t(0)) static func makeAddressINET(port: UInt16) -> Glibc.sockaddr_in { @@ -172,6 +173,14 @@ extension Socket { static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Glibc.pollfd { Glibc.pollfd(fd: fd, events: events, revents: revents) } + + static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer!, _ len: UnsafeMutablePointer!) -> Int { + Glibc.recvfrom(fd, buffer, nbyte, flags, addr, len) + } + + static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { + Glibc.sendto(fd, buffer, nbyte, flags, destaddr, destlen) + } } #endif diff --git a/FlyingSocks/Sources/Socket+Musl.swift b/FlyingSocks/Sources/Socket+Musl.swift index 5f285fd4..2de01f32 100644 --- a/FlyingSocks/Sources/Socket+Musl.swift +++ b/FlyingSocks/Sources/Socket+Musl.swift @@ -42,6 +42,7 @@ extension Socket.FileDescriptor { extension Socket { static let stream = Int32(SOCK_STREAM) + static let datagram = Int32(SOCK_DGRAM) static let in_addr_any = Musl.in_addr(s_addr: Musl.in_addr_t(0)) static func makeAddressINET(port: UInt16) -> Musl.sockaddr_in { @@ -172,6 +173,14 @@ extension Socket { static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Musl.pollfd { Musl.pollfd(fd: fd, events: events, revents: revents) } + + static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer!, _ len: UnsafeMutablePointer!) -> Int { + Musl.recvfrom(fd, buffer, nbyte, flags, addr, len) + } + + static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { + Musl.sendto(fd, buffer, nbyte, flags, destaddr, destlen) + } } #endif diff --git a/FlyingSocks/Sources/Socket+WinSock2.swift b/FlyingSocks/Sources/Socket+WinSock2.swift index 0c2c086f..177670d3 100755 --- a/FlyingSocks/Sources/Socket+WinSock2.swift +++ b/FlyingSocks/Sources/Socket+WinSock2.swift @@ -52,6 +52,7 @@ extension Socket.FileDescriptor { extension Socket { static let stream = Int32(SOCK_STREAM) + static let datagram = Int32(SOCK_DGRAM) static let in_addr_any = WinSDK.in_addr() static func makeAddressINET(port: UInt16) -> WinSDK.sockaddr_in { @@ -184,6 +185,14 @@ extension Socket { static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> WinSDK.WSAPOLLFD { WinSDK.WSAPOLLFD(fd: fd, events: events, revents: revents) } + + static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer!, _ len: UnsafeMutablePointer!) -> Int { + WinSDK.recvfrom(fd, buffer, nbyte, flags, addr, len) + } + + static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { + WinSDK.sendto(fd, buffer, nbyte, flags, destaddr, destlen) + } } #endif diff --git a/FlyingSocks/Sources/Socket.swift b/FlyingSocks/Sources/Socket.swift index 469cf116..9c7574d8 100644 --- a/FlyingSocks/Sources/Socket.swift +++ b/FlyingSocks/Sources/Socket.swift @@ -235,6 +235,35 @@ public struct Socket: Sendable, Hashable { return count } + public func receive(length: Int) throws -> (sockaddr_storage, [UInt8]) { + var address: sockaddr_storage? + let bytes = try [UInt8](unsafeUninitializedCapacity: length) { buffer, count in + (address, count) = try receive(into: buffer.baseAddress!, length: length) + } + + return (address!, bytes) + } + + private func receive(into buffer: UnsafeMutablePointer, length: Int) throws -> (sockaddr_storage, Int) { + var addr = sockaddr_storage() + var size = socklen_t(MemoryLayout.size) + let count = withUnsafeMutablePointer(to: &addr) { + $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { + Socket.recvfrom(file.rawValue, buffer, length, 0, $0, &size) + } + } + guard count > 0 else { + if errno == EWOULDBLOCK || errno == EAGAIN { + throw SocketError.blocked + } else if errno == EBADF || count == 0 { + throw SocketError.disconnected + } else { + throw SocketError.makeFailed("RecvFrom") + } + } + return (addr, count) + } + public func write(_ data: Data, from index: Data.Index = 0) throws -> Data.Index { precondition(index >= 0) guard index < data.endIndex else { return data.endIndex } @@ -258,6 +287,29 @@ public struct Socket: Sendable, Hashable { return sent } + public func send(_ bytes: [UInt8], to address: some SocketAddress) throws -> Int { + try bytes.withUnsafeBytes { buffer in + try send(buffer.baseAddress!, length: bytes.count, to: address) + } + } + + private func send(_ pointer: UnsafeRawPointer, length: Int, to address: some SocketAddress) throws -> Int { + var addr = address + let sent = withUnsafePointer(to: &addr) { + $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { + Socket.sendto(file.rawValue, pointer, length, 0, $0, address.size) + } + } + guard sent >= 0 else { + if errno == EWOULDBLOCK || errno == EAGAIN { + throw SocketError.blocked + } else { + throw SocketError.makeFailed("SendTo") + } + } + return sent + } + public func close() throws { if Socket.close(file.rawValue) == -1 { throw SocketError.makeFailed("Close") @@ -370,8 +422,8 @@ public extension SocketOption where Self == Int32SocketOption { package extension Socket { - static func makePair(flags: Flags? = nil) throws -> (Socket, Socket) { - let (file1, file2) = Socket.socketpair(AF_UNIX, Socket.stream, 0) + static func makePair(flags: Flags? = nil, type: SocketType = .stream) throws -> (Socket, Socket) { + let (file1, file2) = Socket.socketpair(AF_UNIX, type.rawValue, 0) guard file1 > -1, file2 > -1 else { throw SocketError.makeFailed("SocketPair") } diff --git a/FlyingSocks/Sources/SocketAddress.swift b/FlyingSocks/Sources/SocketAddress.swift index 68839421..e43ae962 100644 --- a/FlyingSocks/Sources/SocketAddress.swift +++ b/FlyingSocks/Sources/SocketAddress.swift @@ -44,6 +44,31 @@ public protocol SocketAddress: Sendable { static var family: sa_family_t { get } } +extension SocketAddress { + public var family: sa_family_t { + var this = self + return withUnsafePointer(to: &this) { + $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { + $0.pointee.sa_family + } + } + } + + var size: socklen_t { + // this needs to work with sockaddr_storage, hence the switch + switch Int32(family) { + case AF_INET: + socklen_t(MemoryLayout.size) + case AF_INET6: + socklen_t(MemoryLayout.size) + case AF_UNIX: + socklen_t(MemoryLayout.size) + default: + 0 + } + } +} + public extension SocketAddress where Self == sockaddr_in { static func inet(port: UInt16) -> Self { diff --git a/FlyingSocks/Tests/AsyncSocketTests.swift b/FlyingSocks/Tests/AsyncSocketTests.swift index 62152048..0d5e0ea7 100644 --- a/FlyingSocks/Tests/AsyncSocketTests.swift +++ b/FlyingSocks/Tests/AsyncSocketTests.swift @@ -187,12 +187,32 @@ struct AsyncSocketTests { try await sockets.next() == nil ) } + + @Test + func datagramSocketReceivesChunk_WhenAvailable() async throws { + let (s1, s2, addr) = try await AsyncSocket.makeDatagramPair() + + async let d2: (sockaddr_storage, [UInt8]) = s2.receive(atMost: 100) + // TODO: calling send() on Darwin to an unconnected datagram domain + // socket returns EISCONN +#if canImport(Darwin) + try await s1.write("Swift".data(using: .utf8)!) +#else + try await s1.send("Swift".data(using: .utf8)!, to: addr) +#endif + let v2 = try await d2 + #expect(String(data: Data(v2.1), encoding: .utf8) == "Swift") + + try s1.close() + try s2.close() + try? Socket.unlink(addr) + } } extension AsyncSocket { - static func make() async throws -> AsyncSocket { - try await make(pool: .client) + static func make(type: SocketType = .stream) async throws -> AsyncSocket { + try await make(pool: .client, type: type) } static func makeListening(pool: some AsyncSocketPool) throws -> AsyncSocket { @@ -205,13 +225,28 @@ extension AsyncSocket { return try AsyncSocket(socket: socket, pool: pool) } - static func make(pool: some AsyncSocketPool) throws -> AsyncSocket { - let socket = try Socket(domain: AF_UNIX, type: .stream) + static func make(pool: some AsyncSocketPool, type: SocketType = .stream) throws -> AsyncSocket { + let socket = try Socket(domain: AF_UNIX, type: type) return try AsyncSocket(socket: socket, pool: pool) } + static func makeDatagramPair() async throws -> (AsyncSocket, AsyncSocket, sockaddr_un) { + let socketPair = try await makePair(pool: .client, type: .datagram) + guard let endpoint = FileManager.default.makeTemporaryFile() else { + throw SocketError.makeFailed("MakeTemporaryFile") + } + let addr = sockaddr_un.unix(path: endpoint.path) + + try socketPair.1.socket.bind(to: addr) +#if canImport(Darwin) + try await socketPair.0.connect(to: addr) +#endif + + return (socketPair.0, socketPair.1, addr) + } + static func makePair() async throws -> (AsyncSocket, AsyncSocket) { - try await makePair(pool: .client) + try await makePair(pool: .client, type: .stream) } func writeString(_ string: String) async throws { diff --git a/FlyingSocks/Tests/FileManager+TemporaryFile.swift b/FlyingSocks/Tests/FileManager+TemporaryFile.swift new file mode 100644 index 00000000..3d9837b2 --- /dev/null +++ b/FlyingSocks/Tests/FileManager+TemporaryFile.swift @@ -0,0 +1,42 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +@testable import FlyingSocks +import Foundation + +extension FileManager { + func makeTemporaryFile() -> URL? { + let dirPath = temporaryDirectory.appendingPathComponent("FlyingSocks.XXXXXX") + return dirPath.withUnsafeFileSystemRepresentation { maybePath in + guard let path = maybePath else { return nil } + var mutablePath = Array(repeating: Int8(0), count: Int(PATH_MAX)) + mutablePath.withUnsafeMutableBytes { mutablePathBufferPtr in + mutablePathBufferPtr.baseAddress!.copyMemory( + from: path, byteCount: Int(strlen(path)) + 1) + } + guard mktemp(&mutablePath) != nil else { return nil } + return URL( + fileURLWithFileSystemRepresentation: mutablePath, isDirectory: false, + relativeTo: nil) + } + } +} + +func withTemporaryFile(f: (URL) -> ()) throws { + guard let tmp = FileManager.default.makeTemporaryFile() else { + throw SocketError.makeFailed("MakeTemporaryFile") + } + defer { try? FileManager.default.removeItem(atPath: tmp.path) } + f(tmp) +} diff --git a/FlyingSocks/Tests/SocketAddressTests.swift b/FlyingSocks/Tests/SocketAddressTests.swift index 37fc5056..ba7ea7b8 100644 --- a/FlyingSocks/Tests/SocketAddressTests.swift +++ b/FlyingSocks/Tests/SocketAddressTests.swift @@ -159,6 +159,40 @@ struct SocketAddressTests { } } + @Test + func INET4_CheckSize() throws { + let sin = sockaddr_in.inet(port: 8001) + #expect( + sin.size == socklen_t(MemoryLayout.size) + ) + } + + @Test + func INET6_CheckSize() throws { + let sin6 = sockaddr_in6.inet6(port: 8001) + #expect( + sin6.size == socklen_t(MemoryLayout.size) + ) + } + + @Test + func unix_CheckSize() throws { + let sun = sockaddr_un.unix(path: "/var/foo") + #expect( + sun.size == socklen_t(MemoryLayout.size) + ) + } + + @Test + func unknown_CheckSize() throws { + var sa = sockaddr() + sa.sa_family = sa_family_t(AF_UNSPEC) + + #expect( + sa.size == 0 + ) + } + @Test func unlinkUnix_Throws_WhenPathIsInvalid() { #expect(throws: SocketError.self) { @@ -300,23 +334,7 @@ struct SocketAddressTests { } } - -private extension SocketAddress { - - func makeStorage() -> sockaddr_storage { - var storage = sockaddr_storage() - var addr = self - let addrSize = MemoryLayout.size - let storageSize = MemoryLayout.size - - withUnsafePointer(to: &addr) { addrPtr in - let addrRawPtr = UnsafeRawPointer(addrPtr) - withUnsafeMutablePointer(to: &storage) { storagePtr in - let storageRawPtr = UnsafeMutableRawPointer(storagePtr) - let copySize = min(addrSize, storageSize) - storageRawPtr.copyMemory(from: addrRawPtr, byteCount: copySize) - } - } - return storage - } +// this is a bit ugly but necessary to get unknown_CheckSize() to function +extension sockaddr: SocketAddress, @unchecked Sendable { + public static var family: sa_family_t { sa_family_t(AF_UNSPEC) } } From a7a3d2105c529adcb8f4b0c04c9acb5ddb00c591 Mon Sep 17 00:00:00 2001 From: Luke Howard Date: Tue, 12 Nov 2024 11:44:01 +1100 Subject: [PATCH 4/8] accept sockaddr_storage in bind() and connect() --- FlyingSocks/Sources/Socket.swift | 8 ++++---- FlyingSocks/Sources/SocketAddress.swift | 10 +++++++++- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/FlyingSocks/Sources/Socket.swift b/FlyingSocks/Sources/Socket.swift index 9c7574d8..ba8e3a9a 100644 --- a/FlyingSocks/Sources/Socket.swift +++ b/FlyingSocks/Sources/Socket.swift @@ -120,11 +120,11 @@ public struct Socket: Sendable, Hashable { return option.makeValue(from: valuePtr.pointee) } - public func bind(to address: A) throws { + public func bind(to address: some SocketAddress) throws { var addr = address let result = withUnsafePointer(to: &addr) { $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { - Socket.bind(file.rawValue, $0, socklen_t(MemoryLayout.size)) + Socket.bind(file.rawValue, $0, address.size) } } guard result >= 0 else { @@ -191,11 +191,11 @@ public struct Socket: Sendable, Hashable { return (newFile, addr) } - public func connect(to address: A) throws { + public func connect(to address: some SocketAddress) throws { var addr = address let result = withUnsafePointer(to: &addr) { $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { - Socket.connect(file.rawValue, $0, socklen_t(MemoryLayout.size)) + Socket.connect(file.rawValue, $0, address.size) } } guard result >= 0 || errno == EISCONN else { diff --git a/FlyingSocks/Sources/SocketAddress.swift b/FlyingSocks/Sources/SocketAddress.swift index e43ae962..516ba228 100644 --- a/FlyingSocks/Sources/SocketAddress.swift +++ b/FlyingSocks/Sources/SocketAddress.swift @@ -107,6 +107,10 @@ public extension SocketAddress where Self == sockaddr_un { } #if compiler(>=6.0) +extension sockaddr_storage: SocketAddress, @retroactive @unchecked Sendable { + public static let family = sa_family_t(AF_UNSPEC) +} + extension sockaddr_in: SocketAddress, @retroactive @unchecked Sendable { public static let family = sa_family_t(AF_INET) } @@ -119,6 +123,10 @@ extension sockaddr_un: SocketAddress, @retroactive @unchecked Sendable { public static let family = sa_family_t(AF_UNIX) } #else +extension sockaddr_storage: SocketAddress, @unchecked Sendable { + public static let family = sa_family_t(AF_UNSPEC) +} + extension sockaddr_in: SocketAddress, @unchecked Sendable { public static let family = sa_family_t(AF_INET) } @@ -134,7 +142,7 @@ extension sockaddr_un: SocketAddress, @unchecked Sendable { public extension SocketAddress { static func make(from storage: sockaddr_storage) throws -> Self { - guard storage.ss_family == family else { + guard self is sockaddr_storage || storage.ss_family == family else { throw SocketError.unsupportedAddress } var storage = storage From aeb18ed254dc75f84b8ee8eb804f94d1a7c08ac9 Mon Sep 17 00:00:00 2001 From: Luke Howard Date: Wed, 13 Nov 2024 14:56:15 +1100 Subject: [PATCH 5/8] Support for multihoming with unbound sockets Multihomed servers (i.e. those with multiple network interfaces) need to take care to send UDP responses from the interface on which they were received, otherwise the client may never receive the packet. Historically this was done by binding separate sockets to each interface, however this does not adapt well to dynamic interface changes (without extra code to monitor for this, which is impossible to do in a portable manner). Modern operating systems provide `IP_PKTINFO` and `IPV6_PKTINFO` which allow the local interface index and address to be set and reported on unbound sockets. This commit adds support for this. Note: currently unavailable on Windows. --- FlyingSocks/Sources/AsyncSocket.swift | 74 +++++- FlyingSocks/Sources/Socket+Android.swift | 16 ++ FlyingSocks/Sources/Socket+Darwin.swift | 16 ++ FlyingSocks/Sources/Socket+Glibc.swift | 21 ++ FlyingSocks/Sources/Socket+Musl.swift | 16 ++ FlyingSocks/Sources/Socket+WinSock2.swift | 16 ++ FlyingSocks/Sources/Socket.swift | 291 +++++++++++++++++++++- FlyingSocks/Sources/SocketAddress.swift | 12 + FlyingSocks/Tests/AsyncSocketTests.swift | 20 ++ 9 files changed, 479 insertions(+), 3 deletions(-) diff --git a/FlyingSocks/Sources/AsyncSocket.swift b/FlyingSocks/Sources/AsyncSocket.swift index 4afe2209..339a3625 100644 --- a/FlyingSocks/Sources/AsyncSocket.swift +++ b/FlyingSocks/Sources/AsyncSocket.swift @@ -62,6 +62,25 @@ public extension AsyncSocketPool where Self == SocketPool { public struct AsyncSocket: Sendable { + public struct Message: Sendable { + public let peerAddress: sockaddr_storage + public let bytes: [UInt8] + public let interfaceIndex: UInt32? + public let localAddress: sockaddr_storage? + + public init( + peerAddress: sockaddr_storage, + bytes: [UInt8], + interfaceIndex: UInt32? = nil, + localAddress: sockaddr_storage? = nil + ) { + self.peerAddress = peerAddress + self.bytes = bytes + self.interfaceIndex = interfaceIndex + self.localAddress = localAddress + } + } + public let socket: Socket let pool: any AsyncSocketPool @@ -143,6 +162,23 @@ public struct AsyncSocket: Sendable { } while true } +#if !canImport(WinSDK) + public func receive(atMost length: Int) async throws -> Message { + try Task.checkCancellation() + + repeat { + do { + let (peerAddress, bytes, interfaceIndex, localAddress) = try socket.receive(length: length) + return Message(peerAddress: peerAddress, bytes: bytes, interfaceIndex: interfaceIndex, localAddress: localAddress) + } catch SocketError.blocked { + try await pool.suspendSocket(socket, untilReadyFor: .read) + } catch { + throw error + } + } while true + } +#endif + /// Reads bytes from the socket up to by not over/ /// - Parameter bytes: The max number of bytes to read /// - Returns: an array of the read bytes capped to the number of bytes provided. @@ -190,6 +226,31 @@ public struct AsyncSocket: Sendable { try await send(Array(data), to: address) } +#if !canImport(WinSDK) + public func send( + message: [UInt8], + to peerAddress: some SocketAddress, + interfaceIndex: UInt32? = nil, + from localAddress: (some SocketAddress)? = nil + ) async throws { + let sent = try await pool.loopUntilReady(for: .write, on: socket) { + try socket.send(message: message, to: peerAddress, interfaceIndex: interfaceIndex, from: localAddress) + } + guard sent == message.count else { + throw SocketError.disconnected + } + } + + public func send( + message: Data, + to peerAddress: some SocketAddress, + interfaceIndex: UInt32? = nil, + from localAddress: (some SocketAddress)? = nil + ) async throws { + try await send(message: Array(message), to: peerAddress, interfaceIndex: interfaceIndex, from: localAddress) + } +#endif + public func close() throws { try socket.close() } @@ -275,7 +336,8 @@ public struct AsyncSocketSequence: AsyncSequence, AsyncIteratorProtocol, Sendabl public struct AsyncSocketMessageSequence: AsyncSequence, AsyncIteratorProtocol, Sendable { public static let DefaultMaxMessageLength: Int = 1500 - public typealias Element = (sockaddr_storage, [UInt8]) + // Windows has a different recvmsg() API signature which is presently unsupported + public typealias Element = AsyncSocket.Message private let socket: AsyncSocket private let maxMessageLength: Int @@ -288,7 +350,15 @@ public struct AsyncSocketMessageSequence: AsyncSequence, AsyncIteratorProtocol, } public mutating func next() async throws -> Element? { - return try await socket.receive(atMost: maxMessageLength) +#if !canImport(WinSDK) + try await socket.receive(atMost: maxMessageLength) +#else + let peerAddress: sockaddr_storage + let bytes: [UInt8] + + (peerAddress, bytes) = try await socket.receive(atMost: maxMessageLength) + return AsyncSocket.Message(peerAddress: peerAddress, bytes: bytes) +#endif } } diff --git a/FlyingSocks/Sources/Socket+Android.swift b/FlyingSocks/Sources/Socket+Android.swift index 4574267d..9dab334b 100644 --- a/FlyingSocks/Sources/Socket+Android.swift +++ b/FlyingSocks/Sources/Socket+Android.swift @@ -37,6 +37,10 @@ let EPOLLET: UInt32 = 1 << 31; public extension Socket { typealias FileDescriptorType = Int32 + typealias IovLengthType = UInt + typealias ControlMessageHeaderLengthType = Int + typealias IPv4InterfaceIndexType = Int32 + typealias IPv6InterfaceIndexType = Int32 } extension Socket.FileDescriptor { @@ -47,6 +51,10 @@ extension Socket { static let stream = Int32(SOCK_STREAM) static let datagram = Int32(SOCK_DGRAM) static let in_addr_any = Android.in_addr(s_addr: Android.in_addr_t(0)) + static let ipproto_ip = Int32(IPPROTO_IP) + static let ipproto_ipv6 = Int32(IPPROTO_IPV6) + static let ip_pktinfo = Int32(IP_PKTINFO) + static let ipv6_pktinfo = Int32(IPV6_PKTINFO) static func makeAddressINET(port: UInt16) -> Android.sockaddr_in { Android.sockaddr_in( @@ -184,6 +192,14 @@ extension Socket { static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { Android.sendto(fd, buffer, nbyte, flags, destaddr, destlen) } + + static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer, _ flags: Int32) -> Int { + Android.recvmsg(fd, message, flags) + } + + static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer, _ flags: Int32) -> Int { + Android.sendmsg(fd, message, flags) + } } #endif diff --git a/FlyingSocks/Sources/Socket+Darwin.swift b/FlyingSocks/Sources/Socket+Darwin.swift index f4829f4b..d0b26e52 100644 --- a/FlyingSocks/Sources/Socket+Darwin.swift +++ b/FlyingSocks/Sources/Socket+Darwin.swift @@ -34,6 +34,10 @@ import Darwin public extension Socket { typealias FileDescriptorType = Int32 + typealias IovLengthType = Int + typealias ControlMessageHeaderLengthType = UInt32 + typealias IPv4InterfaceIndexType = UInt32 + typealias IPv6InterfaceIndexType = UInt32 } extension Socket.FileDescriptor { @@ -44,6 +48,10 @@ extension Socket { static let stream = Int32(SOCK_STREAM) static let datagram = Int32(SOCK_DGRAM) static let in_addr_any = Darwin.in_addr(s_addr: Darwin.in_addr_t(0)) + static let ipproto_ip = Int32(IPPROTO_IP) + static let ipproto_ipv6 = Int32(IPPROTO_IPV6) + static let ip_pktinfo = Int32(IP_PKTINFO) + static let ipv6_pktinfo = Int32(50) // __APPLE_USE_RFC_2292 static func makeAddressINET(port: UInt16) -> Darwin.sockaddr_in { Darwin.sockaddr_in( @@ -185,6 +193,14 @@ extension Socket { static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { Darwin.sendto(fd, buffer, nbyte, flags, destaddr, destlen) } + + static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer, _ flags: Int32) -> Int { + Darwin.recvmsg(fd, message, flags) + } + + static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer, _ flags: Int32) -> Int { + Darwin.sendmsg(fd, message, flags) + } } #endif diff --git a/FlyingSocks/Sources/Socket+Glibc.swift b/FlyingSocks/Sources/Socket+Glibc.swift index cc4aeafc..2cec8ca1 100644 --- a/FlyingSocks/Sources/Socket+Glibc.swift +++ b/FlyingSocks/Sources/Socket+Glibc.swift @@ -34,6 +34,10 @@ import Glibc public extension Socket { typealias FileDescriptorType = Int32 + typealias IovLengthType = Int + typealias ControlMessageHeaderLengthType = Int + typealias IPv4InterfaceIndexType = Int32 + typealias IPv6InterfaceIndexType = UInt32 } extension Socket.FileDescriptor { @@ -44,6 +48,10 @@ extension Socket { static let stream = Int32(SOCK_STREAM.rawValue) static let datagram = Int32(SOCK_DGRAM.rawValue) static let in_addr_any = Glibc.in_addr(s_addr: Glibc.in_addr_t(0)) + static let ipproto_ip = Int32(IPPROTO_IP) + static let ipproto_ipv6 = Int32(IPPROTO_IPV6) + static let ip_pktinfo = Int32(IP_PKTINFO) + static let ipv6_pktinfo = Int32(IPV6_PKTINFO) static func makeAddressINET(port: UInt16) -> Glibc.sockaddr_in { Glibc.sockaddr_in( @@ -181,6 +189,19 @@ extension Socket { static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { Glibc.sendto(fd, buffer, nbyte, flags, destaddr, destlen) } + + static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer, _ flags: Int32) -> Int { + Glibc.recvmsg(fd, message, flags) + } + + static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer, _ flags: Int32) -> Int { + Glibc.sendmsg(fd, message, flags) + } +} + +struct in6_pktinfo { + var ipi6_addr: in6_addr + var ipi6_ifindex: CUnsignedInt } #endif diff --git a/FlyingSocks/Sources/Socket+Musl.swift b/FlyingSocks/Sources/Socket+Musl.swift index 2de01f32..4d823717 100644 --- a/FlyingSocks/Sources/Socket+Musl.swift +++ b/FlyingSocks/Sources/Socket+Musl.swift @@ -34,6 +34,10 @@ import Musl public extension Socket { typealias FileDescriptorType = Int32 + typealias IovLengthType = Int + typealias ControlMessageHeaderLengthType = UInt32 + typealias IPv4InterfaceIndexType = Int32 + typealias IPv6InterfaceIndexType = UInt32 } extension Socket.FileDescriptor { @@ -44,6 +48,10 @@ extension Socket { static let stream = Int32(SOCK_STREAM) static let datagram = Int32(SOCK_DGRAM) static let in_addr_any = Musl.in_addr(s_addr: Musl.in_addr_t(0)) + static let ipproto_ip = Int32(IPPROTO_IP) + static let ipproto_ipv6 = Int32(IPPROTO_IPV6) + static let ip_pktinfo = Int32(IP_PKTINFO) + static let ipv6_pktinfo = Int32(IPV6_PKTINFO) static func makeAddressINET(port: UInt16) -> Musl.sockaddr_in { Musl.sockaddr_in( @@ -181,6 +189,14 @@ extension Socket { static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { Musl.sendto(fd, buffer, nbyte, flags, destaddr, destlen) } + + static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer, _ flags: Int32) -> Int { + Musl.recvmsg(fd, message, flags) + } + + static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer, _ flags: Int32) -> Int { + Musl.sendmsg(fd, message, flags) + } } #endif diff --git a/FlyingSocks/Sources/Socket+WinSock2.swift b/FlyingSocks/Sources/Socket+WinSock2.swift index 177670d3..d2b87397 100755 --- a/FlyingSocks/Sources/Socket+WinSock2.swift +++ b/FlyingSocks/Sources/Socket+WinSock2.swift @@ -44,6 +44,10 @@ public typealias sa_family_t = UInt8 public extension Socket { typealias FileDescriptorType = UInt64 + typealias IovLengthType = UInt + typealias ControlMessageHeaderLengthType = DWORD + typealias IPv4InterfaceIndexType = ULONG + typealias IPv6InterfaceIndexType = ULONG } extension Socket.FileDescriptor { @@ -54,6 +58,10 @@ extension Socket { static let stream = Int32(SOCK_STREAM) static let datagram = Int32(SOCK_DGRAM) static let in_addr_any = WinSDK.in_addr() + static let ipproto_ip = Int32(IPPROTO_IP) + static let ipproto_ipv6 = Int32(IPPROTO_IPV6) + static let ip_pktinfo = Int32(IP_PKTINFO) + static let ipv6_pktinfo = Int32(IPV6_PKTINFO) static func makeAddressINET(port: UInt16) -> WinSDK.sockaddr_in { WinSDK.sockaddr_in( @@ -193,6 +201,14 @@ extension Socket { static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { WinSDK.sendto(fd, buffer, nbyte, flags, destaddr, destlen) } + + static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer, _ flags: Int32) -> Int { + WinSDK.recvmsg(fd, message, flags) + } + + static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer, _ flags: Int32) -> Int { + WinSDK.sendmsg(fd, message, flags) + } } #endif diff --git a/FlyingSocks/Sources/Socket.swift b/FlyingSocks/Sources/Socket.swift index ba8e3a9a..d83421b5 100644 --- a/FlyingSocks/Sources/Socket.swift +++ b/FlyingSocks/Sources/Socket.swift @@ -79,6 +79,9 @@ public struct Socket: Sendable, Hashable { throw SocketError.makeFailed("CreateSocket") } self.file = descriptor + if type == SocketType.datagram.rawValue { + try setPktInfo(domain: domain) + } } public init(domain: Int32, type: SocketType) throws { @@ -101,6 +104,29 @@ public struct Socket: Sendable, Hashable { } } + // enable return of ip_pktinfo/ipv6_pktinfo on recvmsg() + private func setPktInfo(domain: Int32) throws { + var enable = Int32(1) + let level: Int32 + let name: Int32 + + switch domain { + case AF_INET: + level = Socket.ipproto_ip + name = Self.ip_pktinfo + case AF_INET6: + level = Socket.ipproto_ipv6 + name = Self.ipv6_pktinfo + default: + return + } + + let result = Socket.setsockopt(file.rawValue, level, name, &enable, socklen_t(MemoryLayout.size)) + guard result >= 0 else { + throw SocketError.makeFailed("SetPktInfoOption") + } + } + public func setValue(_ value: O.Value, for option: O) throws { var value = option.makeSocketValue(from: value) let result = withUnsafeBytes(of: &value) { @@ -253,7 +279,7 @@ public struct Socket: Sendable, Hashable { } } guard count > 0 else { - if errno == EWOULDBLOCK || errno == EAGAIN { + if errno == EWOULDBLOCK { throw SocketError.blocked } else if errno == EBADF || count == 0 { throw SocketError.disconnected @@ -264,6 +290,73 @@ public struct Socket: Sendable, Hashable { return (addr, count) } +#if !canImport(WinSDK) + public func receive(length: Int) throws -> (sockaddr_storage, [UInt8], UInt32?, sockaddr_storage?) { + var peerAddress: sockaddr_storage? + var interfaceIndex: UInt32? + var localAddress: sockaddr_storage? + + let bytes = try [UInt8](unsafeUninitializedCapacity: length) { buffer, count in + (peerAddress, count, interfaceIndex, localAddress) = try receive(into: buffer.baseAddress!, length: length, flags: 0) + } + + return (peerAddress!, bytes, interfaceIndex, localAddress) + } + + private static let ControlMsgBufferSize = MemoryLayout.size + max(MemoryLayout.size, MemoryLayout.size) + + private func receive( + into buffer: UnsafeMutablePointer, + length: Int, + flags: Int32 + ) throws -> (sockaddr_storage, Int, UInt32?, sockaddr_storage?) { + var iov = iovec() + var msg = msghdr() + var peerAddress = sockaddr_storage() + var localAddress: sockaddr_storage? + var interfaceIndex: UInt32? + var controlMsgBuffer = [UInt8](repeating: 0, count: Socket.ControlMsgBufferSize) + + iov.iov_base = UnsafeMutableRawPointer(buffer) + iov.iov_len = IovLengthType(length) + + let count = withUnsafeMutablePointer(to: &iov) { iov in + msg.msg_iov = iov + msg.msg_iovlen = 1 + msg.msg_namelen = socklen_t(MemoryLayout.size) + + return withUnsafeMutablePointer(to: &peerAddress) { peerAddress in + msg.msg_name = UnsafeMutableRawPointer(peerAddress) + + return controlMsgBuffer.withUnsafeMutableBytes { controlMsgBuffer in + msg.msg_control = UnsafeMutableRawPointer(controlMsgBuffer.baseAddress) + msg.msg_controllen = ControlMessageHeaderLengthType(controlMsgBuffer.count) + + let count = Socket.recvmsg(file.rawValue, &msg, flags) + + if count > 0, msg.msg_controllen != 0 { + (interfaceIndex, localAddress) = Socket.getPacketInfoControl(msghdr: msg) + } + + return count + } + } + } + + guard count > 0 else { + if errno == EWOULDBLOCK || errno == EAGAIN { + throw SocketError.blocked + } else if errno == EBADF || count == 0 { + throw SocketError.disconnected + } else { + throw SocketError.makeFailed("RecvMsg") + } + } + + return (peerAddress, count, interfaceIndex, localAddress) + } +#endif + public func write(_ data: Data, from index: Data.Index = 0) throws -> Data.Index { precondition(index >= 0) guard index < data.endIndex else { return data.endIndex } @@ -310,6 +403,75 @@ public struct Socket: Sendable, Hashable { return sent } +#if !canImport(WinSDK) + public func send( + message: [UInt8], + to peerAddress: some SocketAddress, + interfaceIndex: UInt32? = nil, + from localAddress: (some SocketAddress)? = nil + ) throws -> Int { + try message.withUnsafeBytes { buffer in + try send( + buffer.baseAddress!, + length: buffer.count, + flags: 0, + to: peerAddress, + interfaceIndex: interfaceIndex, + from: localAddress + ) + } + } + + private func send( + _ pointer: UnsafeRawPointer, + length: Int, + flags: Int32, + to peerAddress: some SocketAddress, + interfaceIndex: UInt32? = nil, + from localAddress: (some SocketAddress)? = nil + ) throws -> Int { + var iov = iovec() + var msg = msghdr() + let family = peerAddress.family + + iov.iov_base = UnsafeMutableRawPointer(mutating: pointer) + iov.iov_len = IovLengthType(length) + + let sent = withUnsafeMutablePointer(to: &iov) { iov in + var peerAddress = peerAddress + + msg.msg_iov = iov + msg.msg_iovlen = 1 + msg.msg_namelen = peerAddress.size + + return withUnsafeMutablePointer(to: &peerAddress) { peerAddress in + msg.msg_name = UnsafeMutableRawPointer(peerAddress) + + return Socket.withPacketInfoControl( + family: family, + interfaceIndex: interfaceIndex, + address: localAddress) { control, controllen in + if let control { + msg.msg_control = UnsafeMutableRawPointer(mutating: control) + msg.msg_controllen = controllen + } + return Socket.sendmsg(file.rawValue, &msg, flags) + } + } + } + + guard sent >= 0 else { + if errno == EWOULDBLOCK || errno == EAGAIN { + throw SocketError.blocked + } else { + throw SocketError.makeFailed("SendMsg") + } + } + + return sent + } +#endif + public func close() throws { if Socket.close(file.rawValue) == -1 { throw SocketError.makeFailed("Close") @@ -441,3 +603,130 @@ package extension Socket { try Socket.makePair(flags: .nonBlocking, type: type) } } + +#if !canImport(WinSDK) +fileprivate extension Socket { + // https://github.com/swiftlang/swift-evolution/blob/main/proposals/0138-unsaferawbufferpointer.md + private static func withControlMessage( + control: UnsafeRawPointer, + controllen: ControlMessageHeaderLengthType, + _ body: (cmsghdr, UnsafeRawBufferPointer) -> () + ) { + let controlBuffer = UnsafeRawBufferPointer(start: control, count: Int(controllen)) + var cmsgHeaderIndex = 0 + + while true { + let cmsgDataIndex = cmsgHeaderIndex + MemoryLayout.stride + + if cmsgDataIndex > controllen { + break + } + + let header = controlBuffer.load(fromByteOffset: cmsgHeaderIndex, as: cmsghdr.self) + if Int(header.cmsg_len) < MemoryLayout.stride { + break + } + + cmsgHeaderIndex = cmsgDataIndex + cmsgHeaderIndex += Int(header.cmsg_len) - MemoryLayout.stride + if cmsgHeaderIndex > controlBuffer.count { + break + } + body(header, UnsafeRawBufferPointer(rebasing: controlBuffer[cmsgDataIndex...alignment - 1 + cmsgHeaderIndex &= ~(MemoryLayout.alignment - 1) + } + } + + static func getPacketInfoControl( + msghdr: msghdr + ) -> (UInt32?, sockaddr_storage?) { + var interfaceIndex: UInt32? + var localAddress = sockaddr_storage() + + withControlMessage(control: msghdr.msg_control, controllen: msghdr.msg_controllen) { cmsghdr, cmsgdata in + switch cmsghdr.cmsg_level { + case Socket.ipproto_ip: + guard cmsghdr.cmsg_type == Socket.ip_pktinfo else { break } + cmsgdata.baseAddress!.withMemoryRebound(to: in_pktinfo.self, capacity: 1) { pktinfo in + var sin = sockaddr_in() + sin.sin_addr = pktinfo.pointee.ipi_addr + interfaceIndex = UInt32(pktinfo.pointee.ipi_ifindex) + localAddress = sin.makeStorage() + } + case Socket.ipproto_ipv6: + guard cmsghdr.cmsg_type == Socket.ipv6_pktinfo else { break } + cmsgdata.baseAddress!.withMemoryRebound(to: in6_pktinfo.self, capacity: 1) { pktinfo in + var sin6 = sockaddr_in6() + sin6.sin6_addr = pktinfo.pointee.ipi6_addr + interfaceIndex = UInt32(pktinfo.pointee.ipi6_ifindex) + localAddress = sin6.makeStorage() + } + default: + break + } + } + + return (interfaceIndex, interfaceIndex != nil ? localAddress : nil) + } + + static func withPacketInfoControl( + family: sa_family_t, + interfaceIndex: UInt32?, + address: (some SocketAddress)?, + _ body: (UnsafePointer?, ControlMessageHeaderLengthType) -> T + ) -> T { + switch Int32(family) { + case AF_INET: + let buffer = ManagedBuffer.create(minimumCapacity: 1) { buffer in + buffer.withUnsafeMutablePointers { header, element in + header.pointee.cmsg_len = ControlMessageHeaderLengthType(MemoryLayout.size + MemoryLayout.size) + header.pointee.cmsg_level = SOL_SOCKET + header.pointee.cmsg_type = Socket.ipproto_ip + element.pointee.ipi_ifindex = IPv4InterfaceIndexType(interfaceIndex ?? 0) + if let address { + var address = address + withUnsafePointer(to: &address) { + $0.withMemoryRebound(to: sockaddr_in.self, capacity: 1) { + element.pointee.ipi_addr = $0.pointee.sin_addr + } + } + } else { + element.pointee.ipi_addr.s_addr = 0 + } + + return header.pointee + } + } + + return buffer.withUnsafeMutablePointerToHeader { body($0, ControlMessageHeaderLengthType($0.pointee.cmsg_len)) } + case AF_INET6: + let buffer = ManagedBuffer.create(minimumCapacity: 1) { buffer in + buffer.withUnsafeMutablePointers { header, element in + header.pointee.cmsg_len = ControlMessageHeaderLengthType(MemoryLayout.size + MemoryLayout.size) + header.pointee.cmsg_level = SOL_SOCKET + header.pointee.cmsg_type = Socket.ipproto_ipv6 + element.pointee.ipi6_ifindex = IPv6InterfaceIndexType(interfaceIndex ?? 0) + if let address { + var address = address + withUnsafePointer(to: &address) { + $0.withMemoryRebound(to: sockaddr_in6.self, capacity: 1) { + element.pointee.ipi6_addr = $0.pointee.sin6_addr + } + } + } else { + element.pointee.ipi6_addr = in6_addr() + } + + return header.pointee + } + } + + return buffer.withUnsafeMutablePointerToHeader { body($0, ControlMessageHeaderLengthType($0.pointee.cmsg_len)) } + default: + return body(nil, 0) + } + } +} +#endif diff --git a/FlyingSocks/Sources/SocketAddress.swift b/FlyingSocks/Sources/SocketAddress.swift index 516ba228..ca43f0ea 100644 --- a/FlyingSocks/Sources/SocketAddress.swift +++ b/FlyingSocks/Sources/SocketAddress.swift @@ -67,6 +67,18 @@ extension SocketAddress { 0 } } + + public func makeStorage() -> sockaddr_storage { + var storage = sockaddr_storage() + + withUnsafeMutablePointer(to: &storage) { + $0.withMemoryRebound(to: Self.self, capacity: 1) { + $0.pointee = self + } + } + + return storage + } } public extension SocketAddress where Self == sockaddr_in { diff --git a/FlyingSocks/Tests/AsyncSocketTests.swift b/FlyingSocks/Tests/AsyncSocketTests.swift index 0d5e0ea7..969e1b18 100644 --- a/FlyingSocks/Tests/AsyncSocketTests.swift +++ b/FlyingSocks/Tests/AsyncSocketTests.swift @@ -207,6 +207,26 @@ struct AsyncSocketTests { try s2.close() try? Socket.unlink(addr) } + +#if !canImport(WinSDK) + @Test + func datagramSocketReceivesMessage_WhenAvailable() async throws { + let (s1, s2, addr) = try await AsyncSocket.makeDatagramPair() + + async let d2: AsyncSocket.Message = s2.receive(atMost: 100) +#if canImport(Darwin) + try await s1.write("Swift".data(using: .utf8)!) +#else + try await s1.send(message: "Swift".data(using: .utf8)!, to: addr, from: addr) +#endif + let v2 = try await d2 + #expect(String(data: Data(v2.bytes), encoding: .utf8) == "Swift") + + try s1.close() + try s2.close() + try? Socket.unlink(addr) + } +#endif } extension AsyncSocket { From 409c3e859d3b9793aa6f59ca0533e01c2bf656e9 Mon Sep 17 00:00:00 2001 From: Luke Howard Date: Thu, 14 Nov 2024 16:41:53 +1100 Subject: [PATCH 6/8] Add type-erased AnySocketAddress --- FlyingSocks/Sources/Socket.swift | 45 +++++++--------------- FlyingSocks/Sources/SocketAddress.swift | 39 +++++++++++++++---- FlyingSocks/Tests/SocketAddressTests.swift | 35 +++++++++++++++++ 3 files changed, 81 insertions(+), 38 deletions(-) diff --git a/FlyingSocks/Sources/Socket.swift b/FlyingSocks/Sources/Socket.swift index d83421b5..9ca423cc 100644 --- a/FlyingSocks/Sources/Socket.swift +++ b/FlyingSocks/Sources/Socket.swift @@ -147,11 +147,8 @@ public struct Socket: Sendable, Hashable { } public func bind(to address: some SocketAddress) throws { - var addr = address - let result = withUnsafePointer(to: &addr) { - $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { - Socket.bind(file.rawValue, $0, address.size) - } + let result = address.withSockAddr { + Socket.bind(file.rawValue, $0, address.size) } guard result >= 0 else { throw SocketError.makeFailed("Bind") @@ -170,10 +167,8 @@ public struct Socket: Sendable, Hashable { var addr = sockaddr_storage() var len = socklen_t(MemoryLayout.size) - let result = withUnsafeMutablePointer(to: &addr) { - $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { - Socket.getpeername(file.rawValue, $0, &len) - } + let result = addr.withMutableSockAddr { + Socket.getpeername(file.rawValue, $0, &len) } if result != 0 { throw SocketError.makeFailed("GetPeerName") @@ -185,10 +180,8 @@ public struct Socket: Sendable, Hashable { var addr = sockaddr_storage() var len = socklen_t(MemoryLayout.size) - let result = withUnsafeMutablePointer(to: &addr) { - $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { - Socket.getsockname(file.rawValue, $0, &len) - } + let result = addr.withMutableSockAddr { + Socket.getsockname(file.rawValue, $0, &len) } if result != 0 { throw SocketError.makeFailed("GetSockName") @@ -200,10 +193,8 @@ public struct Socket: Sendable, Hashable { var addr = sockaddr_storage() var len = socklen_t(MemoryLayout.size) - let newFile = withUnsafeMutablePointer(to: &addr) { - $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { - FileDescriptor(rawValue: Socket.accept(file.rawValue, $0, &len)) - } + let newFile = addr.withMutableSockAddr { + FileDescriptor(rawValue: Socket.accept(file.rawValue, $0, &len)) } guard newFile != .invalid else { @@ -218,11 +209,8 @@ public struct Socket: Sendable, Hashable { } public func connect(to address: some SocketAddress) throws { - var addr = address - let result = withUnsafePointer(to: &addr) { - $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { - Socket.connect(file.rawValue, $0, address.size) - } + let result = address.withSockAddr { + Socket.connect(file.rawValue, $0, address.size) } guard result >= 0 || errno == EISCONN else { if errno == EINPROGRESS { @@ -273,10 +261,8 @@ public struct Socket: Sendable, Hashable { private func receive(into buffer: UnsafeMutablePointer, length: Int) throws -> (sockaddr_storage, Int) { var addr = sockaddr_storage() var size = socklen_t(MemoryLayout.size) - let count = withUnsafeMutablePointer(to: &addr) { - $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { - Socket.recvfrom(file.rawValue, buffer, length, 0, $0, &size) - } + let count = addr.withMutableSockAddr { + Socket.recvfrom(file.rawValue, buffer, length, 0, $0, &size) } guard count > 0 else { if errno == EWOULDBLOCK { @@ -387,11 +373,8 @@ public struct Socket: Sendable, Hashable { } private func send(_ pointer: UnsafeRawPointer, length: Int, to address: some SocketAddress) throws -> Int { - var addr = address - let sent = withUnsafePointer(to: &addr) { - $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { - Socket.sendto(file.rawValue, pointer, length, 0, $0, address.size) - } + let sent = address.withSockAddr { + Socket.sendto(file.rawValue, pointer, length, 0, $0, address.size) } guard sent >= 0 else { if errno == EWOULDBLOCK || errno == EAGAIN { diff --git a/FlyingSocks/Sources/SocketAddress.swift b/FlyingSocks/Sources/SocketAddress.swift index ca43f0ea..91da616e 100644 --- a/FlyingSocks/Sources/SocketAddress.swift +++ b/FlyingSocks/Sources/SocketAddress.swift @@ -46,12 +46,7 @@ public protocol SocketAddress: Sendable { extension SocketAddress { public var family: sa_family_t { - var this = self - return withUnsafePointer(to: &this) { - $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { - $0.pointee.sa_family - } - } + withSockAddr { $0.pointee.sa_family } } var size: socklen_t { @@ -154,7 +149,7 @@ extension sockaddr_un: SocketAddress, @unchecked Sendable { public extension SocketAddress { static func make(from storage: sockaddr_storage) throws -> Self { - guard self is sockaddr_storage || storage.ss_family == family else { + guard self is sockaddr_storage.Type || storage.ss_family == family else { throw SocketError.unsupportedAddress } var storage = storage @@ -231,3 +226,33 @@ extension Socket { } } } + +public struct AnySocketAddress: Sendable, SocketAddress { + public static var family: sa_family_t { + sa_family_t(AF_UNSPEC) + } + + private var storage: sockaddr_storage + + public init(_ sa: any SocketAddress) { + storage = sa.makeStorage() + } +} + +public extension SocketAddress { + func withSockAddr(_ body: (_ sa: UnsafePointer) throws -> T) rethrows -> T { + try withUnsafePointer(to: self) { + try $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { sa in + try body(sa) + } + } + } + + mutating func withMutableSockAddr(_ body: (_ sa: UnsafeMutablePointer) throws -> T) rethrows -> T { + try withUnsafeMutablePointer(to: &self) { + try $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { sa in + try body(sa) + } + } + } +} diff --git a/FlyingSocks/Tests/SocketAddressTests.swift b/FlyingSocks/Tests/SocketAddressTests.swift index ba7ea7b8..03f183cc 100644 --- a/FlyingSocks/Tests/SocketAddressTests.swift +++ b/FlyingSocks/Tests/SocketAddressTests.swift @@ -332,6 +332,41 @@ struct SocketAddressTests { } } } + + @Test + func testTypeErasedSockAddress() throws { + var addrIn6 = sockaddr_in6() + addrIn6.sin6_family = sa_family_t(AF_INET6) + addrIn6.sin6_port = UInt16(9090).bigEndian + addrIn6.sin6_addr = try Socket.makeInAddr(fromIP6: "fe80::1") + + let storage = AnySocketAddress(addrIn6) + + #expect(storage.family == sa_family_t(AF_INET6)) + #expect(storage.family == addrIn6.sin6_family) + + withUnsafeBytes(of: addrIn6) { addrIn6Ptr in + let storagePtr = addrIn6Ptr.bindMemory(to: sockaddr_storage.self) + #expect(storagePtr.baseAddress!.pointee.ss_family == sa_family_t(AF_INET6)) + let sockaddrIn6Ptr = addrIn6Ptr.bindMemory(to: sockaddr_in6.self) + #expect(sockaddrIn6Ptr.baseAddress!.pointee.sin6_port == addrIn6.sin6_port) + let addrArray = withUnsafeBytes(of: sockaddrIn6Ptr.baseAddress!.pointee.sin6_addr) { + Array($0.bindMemory(to: UInt8.self)) + } + let expectedArray = withUnsafeBytes(of: addrIn6.sin6_addr) { + Array($0.bindMemory(to: UInt8.self)) + } + #expect(addrArray == expectedArray) + } + + storage.withSockAddr { sa in + #expect(sa.pointee.sa_family == sa_family_t(AF_INET6)) + } + + addrIn6.withSockAddr { sa in + #expect(sa.pointee.sa_family == sa_family_t(AF_INET6)) + } + } } // this is a bit ugly but necessary to get unknown_CheckSize() to function From 05724f1d34af00d4e180fe8b6eb1eaf3bb57f479 Mon Sep 17 00:00:00 2001 From: Luke Howard Date: Thu, 14 Nov 2024 16:51:06 +1100 Subject: [PATCH 7/8] Return type-erased any SocketAddress for new UDP APIs --- FlyingSocks/Sources/AsyncSocket.swift | 12 ++++++------ FlyingSocks/Sources/Socket.swift | 14 +++++++------- FlyingSocks/Tests/AsyncSocketTests.swift | 2 +- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/FlyingSocks/Sources/AsyncSocket.swift b/FlyingSocks/Sources/AsyncSocket.swift index 339a3625..8945c731 100644 --- a/FlyingSocks/Sources/AsyncSocket.swift +++ b/FlyingSocks/Sources/AsyncSocket.swift @@ -63,16 +63,16 @@ public extension AsyncSocketPool where Self == SocketPool { public struct AsyncSocket: Sendable { public struct Message: Sendable { - public let peerAddress: sockaddr_storage + public let peerAddress: any SocketAddress public let bytes: [UInt8] public let interfaceIndex: UInt32? - public let localAddress: sockaddr_storage? + public let localAddress: (any SocketAddress)? public init( - peerAddress: sockaddr_storage, + peerAddress: any SocketAddress, bytes: [UInt8], interfaceIndex: UInt32? = nil, - localAddress: sockaddr_storage? = nil + localAddress: (any SocketAddress)? = nil ) { self.peerAddress = peerAddress self.bytes = bytes @@ -148,7 +148,7 @@ public struct AsyncSocket: Sendable { return buffer } - public func receive(atMost length: Int = 4096) async throws -> (sockaddr_storage, [UInt8]) { + public func receive(atMost length: Int = 4096) async throws -> (any SocketAddress, [UInt8]) { try Task.checkCancellation() repeat { @@ -353,7 +353,7 @@ public struct AsyncSocketMessageSequence: AsyncSequence, AsyncIteratorProtocol, #if !canImport(WinSDK) try await socket.receive(atMost: maxMessageLength) #else - let peerAddress: sockaddr_storage + let peerAddress: any SocketAddress let bytes: [UInt8] (peerAddress, bytes) = try await socket.receive(atMost: maxMessageLength) diff --git a/FlyingSocks/Sources/Socket.swift b/FlyingSocks/Sources/Socket.swift index 9ca423cc..61fabbeb 100644 --- a/FlyingSocks/Sources/Socket.swift +++ b/FlyingSocks/Sources/Socket.swift @@ -249,8 +249,8 @@ public struct Socket: Sendable, Hashable { return count } - public func receive(length: Int) throws -> (sockaddr_storage, [UInt8]) { - var address: sockaddr_storage? + public func receive(length: Int) throws -> (any SocketAddress, [UInt8]) { + var address: (any SocketAddress)? let bytes = try [UInt8](unsafeUninitializedCapacity: length) { buffer, count in (address, count) = try receive(into: buffer.baseAddress!, length: length) } @@ -258,7 +258,7 @@ public struct Socket: Sendable, Hashable { return (address!, bytes) } - private func receive(into buffer: UnsafeMutablePointer, length: Int) throws -> (sockaddr_storage, Int) { + private func receive(into buffer: UnsafeMutablePointer, length: Int) throws -> (any SocketAddress, Int) { var addr = sockaddr_storage() var size = socklen_t(MemoryLayout.size) let count = addr.withMutableSockAddr { @@ -277,10 +277,10 @@ public struct Socket: Sendable, Hashable { } #if !canImport(WinSDK) - public func receive(length: Int) throws -> (sockaddr_storage, [UInt8], UInt32?, sockaddr_storage?) { - var peerAddress: sockaddr_storage? + public func receive(length: Int) throws -> (any SocketAddress, [UInt8], UInt32?, (any SocketAddress)?) { + var peerAddress: (any SocketAddress)? var interfaceIndex: UInt32? - var localAddress: sockaddr_storage? + var localAddress: (any SocketAddress)? let bytes = try [UInt8](unsafeUninitializedCapacity: length) { buffer, count in (peerAddress, count, interfaceIndex, localAddress) = try receive(into: buffer.baseAddress!, length: length, flags: 0) @@ -295,7 +295,7 @@ public struct Socket: Sendable, Hashable { into buffer: UnsafeMutablePointer, length: Int, flags: Int32 - ) throws -> (sockaddr_storage, Int, UInt32?, sockaddr_storage?) { + ) throws -> (any SocketAddress, Int, UInt32?, (any SocketAddress)?) { var iov = iovec() var msg = msghdr() var peerAddress = sockaddr_storage() diff --git a/FlyingSocks/Tests/AsyncSocketTests.swift b/FlyingSocks/Tests/AsyncSocketTests.swift index 969e1b18..2fb1e524 100644 --- a/FlyingSocks/Tests/AsyncSocketTests.swift +++ b/FlyingSocks/Tests/AsyncSocketTests.swift @@ -192,7 +192,7 @@ struct AsyncSocketTests { func datagramSocketReceivesChunk_WhenAvailable() async throws { let (s1, s2, addr) = try await AsyncSocket.makeDatagramPair() - async let d2: (sockaddr_storage, [UInt8]) = s2.receive(atMost: 100) + async let d2: (any SocketAddress, [UInt8]) = s2.receive(atMost: 100) // TODO: calling send() on Darwin to an unconnected datagram domain // socket returns EISCONN #if canImport(Darwin) From 77310686d7cadfd15e058b7bd602046e473e77be Mon Sep 17 00:00:00 2001 From: Luke Howard Date: Thu, 14 Nov 2024 23:05:36 +1100 Subject: [PATCH 8/8] Add new send(message:) API that takes AsyncSocket.Message --- FlyingSocks/Sources/AsyncSocket.swift | 17 +++++++++++++++ FlyingSocks/Tests/AsyncSocketTests.swift | 27 +++++++++++++++++++++++- 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/FlyingSocks/Sources/AsyncSocket.swift b/FlyingSocks/Sources/AsyncSocket.swift index 8945c731..76e5f55f 100644 --- a/FlyingSocks/Sources/AsyncSocket.swift +++ b/FlyingSocks/Sources/AsyncSocket.swift @@ -249,6 +249,23 @@ public struct AsyncSocket: Sendable { ) async throws { try await send(message: Array(message), to: peerAddress, interfaceIndex: interfaceIndex, from: localAddress) } + + public func send(message: Message) async throws { + let localAddress: AnySocketAddress? + + if let unwrappedLocalAddress = message.localAddress { + localAddress = AnySocketAddress(unwrappedLocalAddress) + } else { + localAddress = nil + } + + try await send( + message: message.bytes, + to: AnySocketAddress(message.peerAddress), + interfaceIndex: message.interfaceIndex, + from: localAddress + ) + } #endif public func close() throws { diff --git a/FlyingSocks/Tests/AsyncSocketTests.swift b/FlyingSocks/Tests/AsyncSocketTests.swift index 2fb1e524..3cd54274 100644 --- a/FlyingSocks/Tests/AsyncSocketTests.swift +++ b/FlyingSocks/Tests/AsyncSocketTests.swift @@ -210,7 +210,7 @@ struct AsyncSocketTests { #if !canImport(WinSDK) @Test - func datagramSocketReceivesMessage_WhenAvailable() async throws { + func datagramSocketReceivesMessageTupleAPI_WhenAvailable() async throws { let (s1, s2, addr) = try await AsyncSocket.makeDatagramPair() async let d2: AsyncSocket.Message = s2.receive(atMost: 100) @@ -227,6 +227,31 @@ struct AsyncSocketTests { try? Socket.unlink(addr) } #endif + +#if !canImport(WinSDK) + @Test + func datagramSocketReceivesMessageStructAPI_WhenAvailable() async throws { + let (s1, s2, addr) = try await AsyncSocket.makeDatagramPair() + let messageToSend = AsyncSocket.Message( + peerAddress: addr, + bytes: Array("Swift".data(using: .utf8)!), + localAddress: addr + ) + + async let d2: AsyncSocket.Message = s2.receive(atMost: 100) +#if canImport(Darwin) + try await s1.write("Swift".data(using: .utf8)!) +#else + try await s1.send(message: messageToSend) +#endif + let v2 = try await d2 + #expect(String(data: Data(v2.bytes), encoding: .utf8) == "Swift") + + try s1.close() + try s2.close() + try? Socket.unlink(addr) + } +#endif } extension AsyncSocket {