diff --git a/FlyingFox/Tests/AsyncSocketTests.swift b/FlyingFox/Tests/AsyncSocketTests.swift index 3de1df59..f9a8c641 100644 --- a/FlyingFox/Tests/AsyncSocketTests.swift +++ b/FlyingFox/Tests/AsyncSocketTests.swift @@ -40,7 +40,7 @@ extension AsyncSocket { } static func make(pool: some AsyncSocketPool) throws -> AsyncSocket { - let socket = try Socket(domain: AF_UNIX, type: Socket.stream) + let socket = try Socket(domain: AF_UNIX, type: .stream) return try AsyncSocket(socket: socket, pool: pool) } diff --git a/FlyingSocks/Sources/AsyncSocket.swift b/FlyingSocks/Sources/AsyncSocket.swift index 045f070b..76e5f55f 100644 --- a/FlyingSocks/Sources/AsyncSocket.swift +++ b/FlyingSocks/Sources/AsyncSocket.swift @@ -62,6 +62,25 @@ public extension AsyncSocketPool where Self == SocketPool { public struct AsyncSocket: Sendable { + public struct Message: Sendable { + public let peerAddress: any SocketAddress + public let bytes: [UInt8] + public let interfaceIndex: UInt32? + public let localAddress: (any SocketAddress)? + + public init( + peerAddress: any SocketAddress, + bytes: [UInt8], + interfaceIndex: UInt32? = nil, + localAddress: (any SocketAddress)? = nil + ) { + self.peerAddress = peerAddress + self.bytes = bytes + self.interfaceIndex = interfaceIndex + self.localAddress = localAddress + } + } + public let socket: Socket let pool: any AsyncSocketPool @@ -83,7 +102,7 @@ public struct AsyncSocket: Sendable { pool: some AsyncSocketPool, timeout: TimeInterval = 5) async throws -> Self { try await withThrowingTimeout(seconds: timeout) { - let socket = try Socket(domain: Int32(type(of: address).family), type: Socket.stream) + let socket = try Socket(domain: Int32(type(of: address).family), type: .stream) let asyncSocket = try AsyncSocket(socket: socket, pool: pool) try await asyncSocket.connect(to: address) return asyncSocket @@ -129,6 +148,37 @@ public struct AsyncSocket: Sendable { return buffer } + public func receive(atMost length: Int = 4096) async throws -> (any SocketAddress, [UInt8]) { + try Task.checkCancellation() + + repeat { + do { + return try socket.receive(length: length) + } catch SocketError.blocked { + try await pool.suspendSocket(socket, untilReadyFor: .read) + } catch { + throw error + } + } while true + } + +#if !canImport(WinSDK) + public func receive(atMost length: Int) async throws -> Message { + try Task.checkCancellation() + + repeat { + do { + let (peerAddress, bytes, interfaceIndex, localAddress) = try socket.receive(length: length) + return Message(peerAddress: peerAddress, bytes: bytes, interfaceIndex: interfaceIndex, localAddress: localAddress) + } catch SocketError.blocked { + try await pool.suspendSocket(socket, untilReadyFor: .read) + } catch { + throw error + } + } while true + } +#endif + /// Reads bytes from the socket up to by not over/ /// - Parameter bytes: The max number of bytes to read /// - Returns: an array of the read bytes capped to the number of bytes provided. @@ -163,6 +213,61 @@ public struct AsyncSocket: Sendable { } } + public func send(_ data: [UInt8], to address: some SocketAddress) async throws { + let sent = try await pool.loopUntilReady(for: .write, on: socket) { + try socket.send(data, to: address) + } + guard sent == data.count else { + throw SocketError.disconnected + } + } + + public func send(_ data: Data, to address: some SocketAddress) async throws { + try await send(Array(data), to: address) + } + +#if !canImport(WinSDK) + public func send( + message: [UInt8], + to peerAddress: some SocketAddress, + interfaceIndex: UInt32? = nil, + from localAddress: (some SocketAddress)? = nil + ) async throws { + let sent = try await pool.loopUntilReady(for: .write, on: socket) { + try socket.send(message: message, to: peerAddress, interfaceIndex: interfaceIndex, from: localAddress) + } + guard sent == message.count else { + throw SocketError.disconnected + } + } + + public func send( + message: Data, + to peerAddress: some SocketAddress, + interfaceIndex: UInt32? = nil, + from localAddress: (some SocketAddress)? = nil + ) async throws { + try await send(message: Array(message), to: peerAddress, interfaceIndex: interfaceIndex, from: localAddress) + } + + public func send(message: Message) async throws { + let localAddress: AnySocketAddress? + + if let unwrappedLocalAddress = message.localAddress { + localAddress = AnySocketAddress(unwrappedLocalAddress) + } else { + localAddress = nil + } + + try await send( + message: message.bytes, + to: AnySocketAddress(message.peerAddress), + interfaceIndex: message.interfaceIndex, + from: localAddress + ) + } +#endif + public func close() throws { try socket.close() } @@ -174,12 +279,20 @@ public struct AsyncSocket: Sendable { public var sockets: AsyncSocketSequence { AsyncSocketSequence(socket: self) } + + public var messages: AsyncSocketMessageSequence { + AsyncSocketMessageSequence(socket: self) + } + + public func messages(maxMessageLength: Int) -> AsyncSocketMessageSequence { + AsyncSocketMessageSequence(socket: self, maxMessageLength: maxMessageLength) + } } package extension AsyncSocket { - static func makePair(pool: some AsyncSocketPool) throws -> (AsyncSocket, AsyncSocket) { - let (s1, s2) = try Socket.makePair() + static func makePair(pool: some AsyncSocketPool, type: SocketType = .stream) throws -> (AsyncSocket, AsyncSocket) { + let (s1, s2) = try Socket.makePair(type: type) let a1 = try AsyncSocket(socket: s1, pool: pool) let a2 = try AsyncSocket(socket: s2, pool: pool) return (a1, a2) @@ -237,6 +350,35 @@ public struct AsyncSocketSequence: AsyncSequence, AsyncIteratorProtocol, Sendabl } } +public struct AsyncSocketMessageSequence: AsyncSequence, AsyncIteratorProtocol, Sendable { + public static let DefaultMaxMessageLength: Int = 1500 + + // Windows has a different recvmsg() API signature which is presently unsupported + public typealias Element = AsyncSocket.Message + + private let socket: AsyncSocket + private let maxMessageLength: Int + + public func makeAsyncIterator() -> AsyncSocketMessageSequence { self } + + init(socket: AsyncSocket, maxMessageLength: Int = Self.DefaultMaxMessageLength) { + self.socket = socket + self.maxMessageLength = maxMessageLength + } + + public mutating func next() async throws -> Element? { +#if !canImport(WinSDK) + try await socket.receive(atMost: maxMessageLength) +#else + let peerAddress: any SocketAddress + let bytes: [UInt8] + + (peerAddress, bytes) = try await socket.receive(atMost: maxMessageLength) + return AsyncSocket.Message(peerAddress: peerAddress, bytes: bytes) +#endif + } +} + private actor ClientPoolLoader { static let shared = ClientPoolLoader() diff --git a/FlyingSocks/Sources/Socket+Android.swift b/FlyingSocks/Sources/Socket+Android.swift index fb154a8e..9dab334b 100644 --- a/FlyingSocks/Sources/Socket+Android.swift +++ b/FlyingSocks/Sources/Socket+Android.swift @@ -37,6 +37,10 @@ let EPOLLET: UInt32 = 1 << 31; public extension Socket { typealias FileDescriptorType = Int32 + typealias IovLengthType = UInt + typealias ControlMessageHeaderLengthType = Int + typealias IPv4InterfaceIndexType = Int32 + typealias IPv6InterfaceIndexType = Int32 } extension Socket.FileDescriptor { @@ -45,7 +49,12 @@ extension Socket.FileDescriptor { extension Socket { static let stream = Int32(SOCK_STREAM) + static let datagram = Int32(SOCK_DGRAM) static let in_addr_any = Android.in_addr(s_addr: Android.in_addr_t(0)) + static let ipproto_ip = Int32(IPPROTO_IP) + static let ipproto_ipv6 = Int32(IPPROTO_IPV6) + static let ip_pktinfo = Int32(IP_PKTINFO) + static let ipv6_pktinfo = Int32(IPV6_PKTINFO) static func makeAddressINET(port: UInt16) -> Android.sockaddr_in { Android.sockaddr_in( @@ -175,6 +184,22 @@ extension Socket { static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Android.pollfd { Android.pollfd(fd: fd, events: events, revents: revents) } + + static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer!, _ len: UnsafeMutablePointer!) -> Int { + Android.recvfrom(fd, buffer, nbyte, flags, addr, len) + } + + static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { + Android.sendto(fd, buffer, nbyte, flags, destaddr, destlen) + } + + static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer, _ flags: Int32) -> Int { + Android.recvmsg(fd, message, flags) + } + + static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer, _ flags: Int32) -> Int { + Android.sendmsg(fd, message, flags) + } } #endif diff --git a/FlyingSocks/Sources/Socket+Darwin.swift b/FlyingSocks/Sources/Socket+Darwin.swift index efa3a643..d0b26e52 100644 --- a/FlyingSocks/Sources/Socket+Darwin.swift +++ b/FlyingSocks/Sources/Socket+Darwin.swift @@ -34,6 +34,10 @@ import Darwin public extension Socket { typealias FileDescriptorType = Int32 + typealias IovLengthType = Int + typealias ControlMessageHeaderLengthType = UInt32 + typealias IPv4InterfaceIndexType = UInt32 + typealias IPv6InterfaceIndexType = UInt32 } extension Socket.FileDescriptor { @@ -42,7 +46,12 @@ extension Socket.FileDescriptor { extension Socket { static let stream = Int32(SOCK_STREAM) + static let datagram = Int32(SOCK_DGRAM) static let in_addr_any = Darwin.in_addr(s_addr: Darwin.in_addr_t(0)) + static let ipproto_ip = Int32(IPPROTO_IP) + static let ipproto_ipv6 = Int32(IPPROTO_IPV6) + static let ip_pktinfo = Int32(IP_PKTINFO) + static let ipv6_pktinfo = Int32(50) // __APPLE_USE_RFC_2292 static func makeAddressINET(port: UInt16) -> Darwin.sockaddr_in { Darwin.sockaddr_in( @@ -176,6 +185,22 @@ extension Socket { static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Darwin.pollfd { Darwin.pollfd(fd: fd, events: events, revents: revents) } + + static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer!, _ len: UnsafeMutablePointer!) -> Int { + Darwin.recvfrom(fd, buffer, nbyte, flags, addr, len) + } + + static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { + Darwin.sendto(fd, buffer, nbyte, flags, destaddr, destlen) + } + + static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer, _ flags: Int32) -> Int { + Darwin.recvmsg(fd, message, flags) + } + + static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer, _ flags: Int32) -> Int { + Darwin.sendmsg(fd, message, flags) + } } #endif diff --git a/FlyingSocks/Sources/Socket+Glibc.swift b/FlyingSocks/Sources/Socket+Glibc.swift index 9880ef8c..2cec8ca1 100644 --- a/FlyingSocks/Sources/Socket+Glibc.swift +++ b/FlyingSocks/Sources/Socket+Glibc.swift @@ -34,6 +34,10 @@ import Glibc public extension Socket { typealias FileDescriptorType = Int32 + typealias IovLengthType = Int + typealias ControlMessageHeaderLengthType = Int + typealias IPv4InterfaceIndexType = Int32 + typealias IPv6InterfaceIndexType = UInt32 } extension Socket.FileDescriptor { @@ -42,7 +46,12 @@ extension Socket.FileDescriptor { extension Socket { static let stream = Int32(SOCK_STREAM.rawValue) + static let datagram = Int32(SOCK_DGRAM.rawValue) static let in_addr_any = Glibc.in_addr(s_addr: Glibc.in_addr_t(0)) + static let ipproto_ip = Int32(IPPROTO_IP) + static let ipproto_ipv6 = Int32(IPPROTO_IPV6) + static let ip_pktinfo = Int32(IP_PKTINFO) + static let ipv6_pktinfo = Int32(IPV6_PKTINFO) static func makeAddressINET(port: UInt16) -> Glibc.sockaddr_in { Glibc.sockaddr_in( @@ -172,6 +181,27 @@ extension Socket { static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Glibc.pollfd { Glibc.pollfd(fd: fd, events: events, revents: revents) } + + static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer!, _ len: UnsafeMutablePointer!) -> Int { + Glibc.recvfrom(fd, buffer, nbyte, flags, addr, len) + } + + static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { + Glibc.sendto(fd, buffer, nbyte, flags, destaddr, destlen) + } + + static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer, _ flags: Int32) -> Int { + Glibc.recvmsg(fd, message, flags) + } + + static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer, _ flags: Int32) -> Int { + Glibc.sendmsg(fd, message, flags) + } +} + +struct in6_pktinfo { + var ipi6_addr: in6_addr + var ipi6_ifindex: CUnsignedInt } #endif diff --git a/FlyingSocks/Sources/Socket+Musl.swift b/FlyingSocks/Sources/Socket+Musl.swift index 5f285fd4..4d823717 100644 --- a/FlyingSocks/Sources/Socket+Musl.swift +++ b/FlyingSocks/Sources/Socket+Musl.swift @@ -34,6 +34,10 @@ import Musl public extension Socket { typealias FileDescriptorType = Int32 + typealias IovLengthType = Int + typealias ControlMessageHeaderLengthType = UInt32 + typealias IPv4InterfaceIndexType = Int32 + typealias IPv6InterfaceIndexType = UInt32 } extension Socket.FileDescriptor { @@ -42,7 +46,12 @@ extension Socket.FileDescriptor { extension Socket { static let stream = Int32(SOCK_STREAM) + static let datagram = Int32(SOCK_DGRAM) static let in_addr_any = Musl.in_addr(s_addr: Musl.in_addr_t(0)) + static let ipproto_ip = Int32(IPPROTO_IP) + static let ipproto_ipv6 = Int32(IPPROTO_IPV6) + static let ip_pktinfo = Int32(IP_PKTINFO) + static let ipv6_pktinfo = Int32(IPV6_PKTINFO) static func makeAddressINET(port: UInt16) -> Musl.sockaddr_in { Musl.sockaddr_in( @@ -172,6 +181,22 @@ extension Socket { static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Musl.pollfd { Musl.pollfd(fd: fd, events: events, revents: revents) } + + static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer!, _ len: UnsafeMutablePointer!) -> Int { + Musl.recvfrom(fd, buffer, nbyte, flags, addr, len) + } + + static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { + Musl.sendto(fd, buffer, nbyte, flags, destaddr, destlen) + } + + static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer, _ flags: Int32) -> Int { + Musl.recvmsg(fd, message, flags) + } + + static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer, _ flags: Int32) -> Int { + Musl.sendmsg(fd, message, flags) + } } #endif diff --git a/FlyingSocks/Sources/Socket+WinSock2.swift b/FlyingSocks/Sources/Socket+WinSock2.swift index 0c2c086f..d2b87397 100755 --- a/FlyingSocks/Sources/Socket+WinSock2.swift +++ b/FlyingSocks/Sources/Socket+WinSock2.swift @@ -44,6 +44,10 @@ public typealias sa_family_t = UInt8 public extension Socket { typealias FileDescriptorType = UInt64 + typealias IovLengthType = UInt + typealias ControlMessageHeaderLengthType = DWORD + typealias IPv4InterfaceIndexType = ULONG + typealias IPv6InterfaceIndexType = ULONG } extension Socket.FileDescriptor { @@ -52,7 +56,12 @@ extension Socket.FileDescriptor { extension Socket { static let stream = Int32(SOCK_STREAM) + static let datagram = Int32(SOCK_DGRAM) static let in_addr_any = WinSDK.in_addr() + static let ipproto_ip = Int32(IPPROTO_IP) + static let ipproto_ipv6 = Int32(IPPROTO_IPV6) + static let ip_pktinfo = Int32(IP_PKTINFO) + static let ipv6_pktinfo = Int32(IPV6_PKTINFO) static func makeAddressINET(port: UInt16) -> WinSDK.sockaddr_in { WinSDK.sockaddr_in( @@ -184,6 +193,22 @@ extension Socket { static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> WinSDK.WSAPOLLFD { WinSDK.WSAPOLLFD(fd: fd, events: events, revents: revents) } + + static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer!, _ len: UnsafeMutablePointer!) -> Int { + WinSDK.recvfrom(fd, buffer, nbyte, flags, addr, len) + } + + static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { + WinSDK.sendto(fd, buffer, nbyte, flags, destaddr, destlen) + } + + static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer, _ flags: Int32) -> Int { + WinSDK.recvmsg(fd, message, flags) + } + + static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer, _ flags: Int32) -> Int { + WinSDK.sendmsg(fd, message, flags) + } } #endif diff --git a/FlyingSocks/Sources/Socket.swift b/FlyingSocks/Sources/Socket.swift index 7de00526..61fabbeb 100644 --- a/FlyingSocks/Sources/Socket.swift +++ b/FlyingSocks/Sources/Socket.swift @@ -36,6 +36,22 @@ import WinSDK.WinSock2 #endif import Foundation +public enum SocketType: Sendable { + case stream + case datagram +} + +extension SocketType { + var rawValue: Int32 { + switch self { + case .stream: + Socket.stream + case .datagram: + Socket.datagram + } + } +} + public struct Socket: Sendable, Hashable { public let file: FileDescriptor @@ -53,15 +69,23 @@ public struct Socket: Sendable, Hashable { } public init(domain: Int32) throws { - try self.init(domain: domain, type: Socket.stream) + try self.init(domain: domain, type: .stream) } + @available(*, deprecated, message: "type is now SocketType") public init(domain: Int32, type: Int32) throws { let descriptor = FileDescriptor(rawValue: Socket.socket(domain, type, 0)) guard descriptor != .invalid else { throw SocketError.makeFailed("CreateSocket") } self.file = descriptor + if type == SocketType.datagram.rawValue { + try setPktInfo(domain: domain) + } + } + + public init(domain: Int32, type: SocketType) throws { + try self.init(domain: domain, type: type.rawValue) } public var flags: Flags { @@ -80,6 +104,29 @@ public struct Socket: Sendable, Hashable { } } + // enable return of ip_pktinfo/ipv6_pktinfo on recvmsg() + private func setPktInfo(domain: Int32) throws { + var enable = Int32(1) + let level: Int32 + let name: Int32 + + switch domain { + case AF_INET: + level = Socket.ipproto_ip + name = Self.ip_pktinfo + case AF_INET6: + level = Socket.ipproto_ipv6 + name = Self.ipv6_pktinfo + default: + return + } + + let result = Socket.setsockopt(file.rawValue, level, name, &enable, socklen_t(MemoryLayout.size)) + guard result >= 0 else { + throw SocketError.makeFailed("SetPktInfoOption") + } + } + public func setValue(_ value: O.Value, for option: O) throws { var value = option.makeSocketValue(from: value) let result = withUnsafeBytes(of: &value) { @@ -99,12 +146,9 @@ public struct Socket: Sendable, Hashable { return option.makeValue(from: valuePtr.pointee) } - public func bind(to address: A) throws { - var addr = address - let result = withUnsafePointer(to: &addr) { - $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { - Socket.bind(file.rawValue, $0, socklen_t(MemoryLayout.size)) - } + public func bind(to address: some SocketAddress) throws { + let result = address.withSockAddr { + Socket.bind(file.rawValue, $0, address.size) } guard result >= 0 else { throw SocketError.makeFailed("Bind") @@ -123,10 +167,8 @@ public struct Socket: Sendable, Hashable { var addr = sockaddr_storage() var len = socklen_t(MemoryLayout.size) - let result = withUnsafeMutablePointer(to: &addr) { - $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { - Socket.getpeername(file.rawValue, $0, &len) - } + let result = addr.withMutableSockAddr { + Socket.getpeername(file.rawValue, $0, &len) } if result != 0 { throw SocketError.makeFailed("GetPeerName") @@ -138,10 +180,8 @@ public struct Socket: Sendable, Hashable { var addr = sockaddr_storage() var len = socklen_t(MemoryLayout.size) - let result = withUnsafeMutablePointer(to: &addr) { - $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { - Socket.getsockname(file.rawValue, $0, &len) - } + let result = addr.withMutableSockAddr { + Socket.getsockname(file.rawValue, $0, &len) } if result != 0 { throw SocketError.makeFailed("GetSockName") @@ -153,10 +193,8 @@ public struct Socket: Sendable, Hashable { var addr = sockaddr_storage() var len = socklen_t(MemoryLayout.size) - let newFile = withUnsafeMutablePointer(to: &addr) { - $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { - FileDescriptor(rawValue: Socket.accept(file.rawValue, $0, &len)) - } + let newFile = addr.withMutableSockAddr { + FileDescriptor(rawValue: Socket.accept(file.rawValue, $0, &len)) } guard newFile != .invalid else { @@ -170,12 +208,9 @@ public struct Socket: Sendable, Hashable { return (newFile, addr) } - public func connect(to address: A) throws { - var addr = address - let result = withUnsafePointer(to: &addr) { - $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { - Socket.connect(file.rawValue, $0, socklen_t(MemoryLayout.size)) - } + public func connect(to address: some SocketAddress) throws { + let result = address.withSockAddr { + Socket.connect(file.rawValue, $0, address.size) } guard result >= 0 || errno == EISCONN else { if errno == EINPROGRESS { @@ -214,6 +249,100 @@ public struct Socket: Sendable, Hashable { return count } + public func receive(length: Int) throws -> (any SocketAddress, [UInt8]) { + var address: (any SocketAddress)? + let bytes = try [UInt8](unsafeUninitializedCapacity: length) { buffer, count in + (address, count) = try receive(into: buffer.baseAddress!, length: length) + } + + return (address!, bytes) + } + + private func receive(into buffer: UnsafeMutablePointer, length: Int) throws -> (any SocketAddress, Int) { + var addr = sockaddr_storage() + var size = socklen_t(MemoryLayout.size) + let count = addr.withMutableSockAddr { + Socket.recvfrom(file.rawValue, buffer, length, 0, $0, &size) + } + guard count > 0 else { + if errno == EWOULDBLOCK { + throw SocketError.blocked + } else if errno == EBADF || count == 0 { + throw SocketError.disconnected + } else { + throw SocketError.makeFailed("RecvFrom") + } + } + return (addr, count) + } + +#if !canImport(WinSDK) + public func receive(length: Int) throws -> (any SocketAddress, [UInt8], UInt32?, (any SocketAddress)?) { + var peerAddress: (any SocketAddress)? + var interfaceIndex: UInt32? + var localAddress: (any SocketAddress)? + + let bytes = try [UInt8](unsafeUninitializedCapacity: length) { buffer, count in + (peerAddress, count, interfaceIndex, localAddress) = try receive(into: buffer.baseAddress!, length: length, flags: 0) + } + + return (peerAddress!, bytes, interfaceIndex, localAddress) + } + + private static let ControlMsgBufferSize = MemoryLayout.size + max(MemoryLayout.size, MemoryLayout.size) + + private func receive( + into buffer: UnsafeMutablePointer, + length: Int, + flags: Int32 + ) throws -> (any SocketAddress, Int, UInt32?, (any SocketAddress)?) { + var iov = iovec() + var msg = msghdr() + var peerAddress = sockaddr_storage() + var localAddress: sockaddr_storage? + var interfaceIndex: UInt32? + var controlMsgBuffer = [UInt8](repeating: 0, count: Socket.ControlMsgBufferSize) + + iov.iov_base = UnsafeMutableRawPointer(buffer) + iov.iov_len = IovLengthType(length) + + let count = withUnsafeMutablePointer(to: &iov) { iov in + msg.msg_iov = iov + msg.msg_iovlen = 1 + msg.msg_namelen = socklen_t(MemoryLayout.size) + + return withUnsafeMutablePointer(to: &peerAddress) { peerAddress in + msg.msg_name = UnsafeMutableRawPointer(peerAddress) + + return controlMsgBuffer.withUnsafeMutableBytes { controlMsgBuffer in + msg.msg_control = UnsafeMutableRawPointer(controlMsgBuffer.baseAddress) + msg.msg_controllen = ControlMessageHeaderLengthType(controlMsgBuffer.count) + + let count = Socket.recvmsg(file.rawValue, &msg, flags) + + if count > 0, msg.msg_controllen != 0 { + (interfaceIndex, localAddress) = Socket.getPacketInfoControl(msghdr: msg) + } + + return count + } + } + } + + guard count > 0 else { + if errno == EWOULDBLOCK || errno == EAGAIN { + throw SocketError.blocked + } else if errno == EBADF || count == 0 { + throw SocketError.disconnected + } else { + throw SocketError.makeFailed("RecvMsg") + } + } + + return (peerAddress, count, interfaceIndex, localAddress) + } +#endif + public func write(_ data: Data, from index: Data.Index = 0) throws -> Data.Index { precondition(index >= 0) guard index < data.endIndex else { return data.endIndex } @@ -237,6 +366,95 @@ public struct Socket: Sendable, Hashable { return sent } + public func send(_ bytes: [UInt8], to address: some SocketAddress) throws -> Int { + try bytes.withUnsafeBytes { buffer in + try send(buffer.baseAddress!, length: bytes.count, to: address) + } + } + + private func send(_ pointer: UnsafeRawPointer, length: Int, to address: some SocketAddress) throws -> Int { + let sent = address.withSockAddr { + Socket.sendto(file.rawValue, pointer, length, 0, $0, address.size) + } + guard sent >= 0 else { + if errno == EWOULDBLOCK || errno == EAGAIN { + throw SocketError.blocked + } else { + throw SocketError.makeFailed("SendTo") + } + } + return sent + } + +#if !canImport(WinSDK) + public func send( + message: [UInt8], + to peerAddress: some SocketAddress, + interfaceIndex: UInt32? = nil, + from localAddress: (some SocketAddress)? = nil + ) throws -> Int { + try message.withUnsafeBytes { buffer in + try send( + buffer.baseAddress!, + length: buffer.count, + flags: 0, + to: peerAddress, + interfaceIndex: interfaceIndex, + from: localAddress + ) + } + } + + private func send( + _ pointer: UnsafeRawPointer, + length: Int, + flags: Int32, + to peerAddress: some SocketAddress, + interfaceIndex: UInt32? = nil, + from localAddress: (some SocketAddress)? = nil + ) throws -> Int { + var iov = iovec() + var msg = msghdr() + let family = peerAddress.family + + iov.iov_base = UnsafeMutableRawPointer(mutating: pointer) + iov.iov_len = IovLengthType(length) + + let sent = withUnsafeMutablePointer(to: &iov) { iov in + var peerAddress = peerAddress + + msg.msg_iov = iov + msg.msg_iovlen = 1 + msg.msg_namelen = peerAddress.size + + return withUnsafeMutablePointer(to: &peerAddress) { peerAddress in + msg.msg_name = UnsafeMutableRawPointer(peerAddress) + + return Socket.withPacketInfoControl( + family: family, + interfaceIndex: interfaceIndex, + address: localAddress) { control, controllen in + if let control { + msg.msg_control = UnsafeMutableRawPointer(mutating: control) + msg.msg_controllen = controllen + } + return Socket.sendmsg(file.rawValue, &msg, flags) + } + } + } + + guard sent >= 0 else { + if errno == EWOULDBLOCK || errno == EAGAIN { + throw SocketError.blocked + } else { + throw SocketError.makeFailed("SendMsg") + } + } + + return sent + } +#endif + public func close() throws { if Socket.close(file.rawValue) == -1 { throw SocketError.makeFailed("Close") @@ -349,8 +567,8 @@ public extension SocketOption where Self == Int32SocketOption { package extension Socket { - static func makePair(flags: Flags? = nil) throws -> (Socket, Socket) { - let (file1, file2) = Socket.socketpair(AF_UNIX, Socket.stream, 0) + static func makePair(flags: Flags? = nil, type: SocketType = .stream) throws -> (Socket, Socket) { + let (file1, file2) = Socket.socketpair(AF_UNIX, type.rawValue, 0) guard file1 > -1, file2 > -1 else { throw SocketError.makeFailed("SocketPair") } @@ -364,7 +582,134 @@ package extension Socket { return (s1, s2) } - static func makeNonBlockingPair() throws -> (Socket, Socket) { - try Socket.makePair(flags: .nonBlocking) + static func makeNonBlockingPair(type: SocketType = .stream) throws -> (Socket, Socket) { + try Socket.makePair(flags: .nonBlocking, type: type) } } + +#if !canImport(WinSDK) +fileprivate extension Socket { + // https://github.com/swiftlang/swift-evolution/blob/main/proposals/0138-unsaferawbufferpointer.md + private static func withControlMessage( + control: UnsafeRawPointer, + controllen: ControlMessageHeaderLengthType, + _ body: (cmsghdr, UnsafeRawBufferPointer) -> () + ) { + let controlBuffer = UnsafeRawBufferPointer(start: control, count: Int(controllen)) + var cmsgHeaderIndex = 0 + + while true { + let cmsgDataIndex = cmsgHeaderIndex + MemoryLayout.stride + + if cmsgDataIndex > controllen { + break + } + + let header = controlBuffer.load(fromByteOffset: cmsgHeaderIndex, as: cmsghdr.self) + if Int(header.cmsg_len) < MemoryLayout.stride { + break + } + + cmsgHeaderIndex = cmsgDataIndex + cmsgHeaderIndex += Int(header.cmsg_len) - MemoryLayout.stride + if cmsgHeaderIndex > controlBuffer.count { + break + } + body(header, UnsafeRawBufferPointer(rebasing: controlBuffer[cmsgDataIndex...alignment - 1 + cmsgHeaderIndex &= ~(MemoryLayout.alignment - 1) + } + } + + static func getPacketInfoControl( + msghdr: msghdr + ) -> (UInt32?, sockaddr_storage?) { + var interfaceIndex: UInt32? + var localAddress = sockaddr_storage() + + withControlMessage(control: msghdr.msg_control, controllen: msghdr.msg_controllen) { cmsghdr, cmsgdata in + switch cmsghdr.cmsg_level { + case Socket.ipproto_ip: + guard cmsghdr.cmsg_type == Socket.ip_pktinfo else { break } + cmsgdata.baseAddress!.withMemoryRebound(to: in_pktinfo.self, capacity: 1) { pktinfo in + var sin = sockaddr_in() + sin.sin_addr = pktinfo.pointee.ipi_addr + interfaceIndex = UInt32(pktinfo.pointee.ipi_ifindex) + localAddress = sin.makeStorage() + } + case Socket.ipproto_ipv6: + guard cmsghdr.cmsg_type == Socket.ipv6_pktinfo else { break } + cmsgdata.baseAddress!.withMemoryRebound(to: in6_pktinfo.self, capacity: 1) { pktinfo in + var sin6 = sockaddr_in6() + sin6.sin6_addr = pktinfo.pointee.ipi6_addr + interfaceIndex = UInt32(pktinfo.pointee.ipi6_ifindex) + localAddress = sin6.makeStorage() + } + default: + break + } + } + + return (interfaceIndex, interfaceIndex != nil ? localAddress : nil) + } + + static func withPacketInfoControl( + family: sa_family_t, + interfaceIndex: UInt32?, + address: (some SocketAddress)?, + _ body: (UnsafePointer?, ControlMessageHeaderLengthType) -> T + ) -> T { + switch Int32(family) { + case AF_INET: + let buffer = ManagedBuffer.create(minimumCapacity: 1) { buffer in + buffer.withUnsafeMutablePointers { header, element in + header.pointee.cmsg_len = ControlMessageHeaderLengthType(MemoryLayout.size + MemoryLayout.size) + header.pointee.cmsg_level = SOL_SOCKET + header.pointee.cmsg_type = Socket.ipproto_ip + element.pointee.ipi_ifindex = IPv4InterfaceIndexType(interfaceIndex ?? 0) + if let address { + var address = address + withUnsafePointer(to: &address) { + $0.withMemoryRebound(to: sockaddr_in.self, capacity: 1) { + element.pointee.ipi_addr = $0.pointee.sin_addr + } + } + } else { + element.pointee.ipi_addr.s_addr = 0 + } + + return header.pointee + } + } + + return buffer.withUnsafeMutablePointerToHeader { body($0, ControlMessageHeaderLengthType($0.pointee.cmsg_len)) } + case AF_INET6: + let buffer = ManagedBuffer.create(minimumCapacity: 1) { buffer in + buffer.withUnsafeMutablePointers { header, element in + header.pointee.cmsg_len = ControlMessageHeaderLengthType(MemoryLayout.size + MemoryLayout.size) + header.pointee.cmsg_level = SOL_SOCKET + header.pointee.cmsg_type = Socket.ipproto_ipv6 + element.pointee.ipi6_ifindex = IPv6InterfaceIndexType(interfaceIndex ?? 0) + if let address { + var address = address + withUnsafePointer(to: &address) { + $0.withMemoryRebound(to: sockaddr_in6.self, capacity: 1) { + element.pointee.ipi6_addr = $0.pointee.sin6_addr + } + } + } else { + element.pointee.ipi6_addr = in6_addr() + } + + return header.pointee + } + } + + return buffer.withUnsafeMutablePointerToHeader { body($0, ControlMessageHeaderLengthType($0.pointee.cmsg_len)) } + default: + return body(nil, 0) + } + } +} +#endif diff --git a/FlyingSocks/Sources/SocketAddress.swift b/FlyingSocks/Sources/SocketAddress.swift index 68839421..91da616e 100644 --- a/FlyingSocks/Sources/SocketAddress.swift +++ b/FlyingSocks/Sources/SocketAddress.swift @@ -44,6 +44,38 @@ public protocol SocketAddress: Sendable { static var family: sa_family_t { get } } +extension SocketAddress { + public var family: sa_family_t { + withSockAddr { $0.pointee.sa_family } + } + + var size: socklen_t { + // this needs to work with sockaddr_storage, hence the switch + switch Int32(family) { + case AF_INET: + socklen_t(MemoryLayout.size) + case AF_INET6: + socklen_t(MemoryLayout.size) + case AF_UNIX: + socklen_t(MemoryLayout.size) + default: + 0 + } + } + + public func makeStorage() -> sockaddr_storage { + var storage = sockaddr_storage() + + withUnsafeMutablePointer(to: &storage) { + $0.withMemoryRebound(to: Self.self, capacity: 1) { + $0.pointee = self + } + } + + return storage + } +} + public extension SocketAddress where Self == sockaddr_in { static func inet(port: UInt16) -> Self { @@ -82,6 +114,10 @@ public extension SocketAddress where Self == sockaddr_un { } #if compiler(>=6.0) +extension sockaddr_storage: SocketAddress, @retroactive @unchecked Sendable { + public static let family = sa_family_t(AF_UNSPEC) +} + extension sockaddr_in: SocketAddress, @retroactive @unchecked Sendable { public static let family = sa_family_t(AF_INET) } @@ -94,6 +130,10 @@ extension sockaddr_un: SocketAddress, @retroactive @unchecked Sendable { public static let family = sa_family_t(AF_UNIX) } #else +extension sockaddr_storage: SocketAddress, @unchecked Sendable { + public static let family = sa_family_t(AF_UNSPEC) +} + extension sockaddr_in: SocketAddress, @unchecked Sendable { public static let family = sa_family_t(AF_INET) } @@ -109,7 +149,7 @@ extension sockaddr_un: SocketAddress, @unchecked Sendable { public extension SocketAddress { static func make(from storage: sockaddr_storage) throws -> Self { - guard storage.ss_family == family else { + guard self is sockaddr_storage.Type || storage.ss_family == family else { throw SocketError.unsupportedAddress } var storage = storage @@ -186,3 +226,33 @@ extension Socket { } } } + +public struct AnySocketAddress: Sendable, SocketAddress { + public static var family: sa_family_t { + sa_family_t(AF_UNSPEC) + } + + private var storage: sockaddr_storage + + public init(_ sa: any SocketAddress) { + storage = sa.makeStorage() + } +} + +public extension SocketAddress { + func withSockAddr(_ body: (_ sa: UnsafePointer) throws -> T) rethrows -> T { + try withUnsafePointer(to: self) { + try $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { sa in + try body(sa) + } + } + } + + mutating func withMutableSockAddr(_ body: (_ sa: UnsafeMutablePointer) throws -> T) rethrows -> T { + try withUnsafeMutablePointer(to: &self) { + try $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { sa in + try body(sa) + } + } + } +} diff --git a/FlyingSocks/Sources/SocketPool.swift b/FlyingSocks/Sources/SocketPool.swift index 0f8a03e3..c42feadf 100644 --- a/FlyingSocks/Sources/SocketPool.swift +++ b/FlyingSocks/Sources/SocketPool.swift @@ -93,7 +93,7 @@ public final actor SocketPool: AsyncSocketPool { public func run() async throws { guard state == .ready else { throw Error("Not Ready") } state = .running - defer { cancellAll() } + defer { cancelAll() } repeat { if waiting.isEmpty { @@ -151,11 +151,11 @@ public final actor SocketPool: AsyncSocketPool { case complete } - private func cancellAll() { - logger.logInfo("SocketPoll cancellAll") + private func cancelAll() { + logger.logInfo("SocketPoll cancelAll") try? queue.stop() state = .complete - waiting.cancellAll() + waiting.cancelAll() waiting = Waiting() if let loop { self.loop = nil @@ -270,7 +270,7 @@ public final actor SocketPool: AsyncSocketPool { } } - mutating func cancellAll() { + mutating func cancelAll() { let continuations = storage.values.flatMap(\.values).map(\.continuation) storage = [:] for continuation in continuations { diff --git a/FlyingSocks/Tests/AsyncSocketTests.swift b/FlyingSocks/Tests/AsyncSocketTests.swift index ac2d872f..3cd54274 100644 --- a/FlyingSocks/Tests/AsyncSocketTests.swift +++ b/FlyingSocks/Tests/AsyncSocketTests.swift @@ -187,31 +187,111 @@ struct AsyncSocketTests { try await sockets.next() == nil ) } + + @Test + func datagramSocketReceivesChunk_WhenAvailable() async throws { + let (s1, s2, addr) = try await AsyncSocket.makeDatagramPair() + + async let d2: (any SocketAddress, [UInt8]) = s2.receive(atMost: 100) + // TODO: calling send() on Darwin to an unconnected datagram domain + // socket returns EISCONN +#if canImport(Darwin) + try await s1.write("Swift".data(using: .utf8)!) +#else + try await s1.send("Swift".data(using: .utf8)!, to: addr) +#endif + let v2 = try await d2 + #expect(String(data: Data(v2.1), encoding: .utf8) == "Swift") + + try s1.close() + try s2.close() + try? Socket.unlink(addr) + } + +#if !canImport(WinSDK) + @Test + func datagramSocketReceivesMessageTupleAPI_WhenAvailable() async throws { + let (s1, s2, addr) = try await AsyncSocket.makeDatagramPair() + + async let d2: AsyncSocket.Message = s2.receive(atMost: 100) +#if canImport(Darwin) + try await s1.write("Swift".data(using: .utf8)!) +#else + try await s1.send(message: "Swift".data(using: .utf8)!, to: addr, from: addr) +#endif + let v2 = try await d2 + #expect(String(data: Data(v2.bytes), encoding: .utf8) == "Swift") + + try s1.close() + try s2.close() + try? Socket.unlink(addr) + } +#endif + +#if !canImport(WinSDK) + @Test + func datagramSocketReceivesMessageStructAPI_WhenAvailable() async throws { + let (s1, s2, addr) = try await AsyncSocket.makeDatagramPair() + let messageToSend = AsyncSocket.Message( + peerAddress: addr, + bytes: Array("Swift".data(using: .utf8)!), + localAddress: addr + ) + + async let d2: AsyncSocket.Message = s2.receive(atMost: 100) +#if canImport(Darwin) + try await s1.write("Swift".data(using: .utf8)!) +#else + try await s1.send(message: messageToSend) +#endif + let v2 = try await d2 + #expect(String(data: Data(v2.bytes), encoding: .utf8) == "Swift") + + try s1.close() + try s2.close() + try? Socket.unlink(addr) + } +#endif } extension AsyncSocket { - static func make() async throws -> AsyncSocket { - try await make(pool: .client) + static func make(type: SocketType = .stream) async throws -> AsyncSocket { + try await make(pool: .client, type: type) } static func makeListening(pool: some AsyncSocketPool) throws -> AsyncSocket { let address = sockaddr_un.unix(path: #function) try? Socket.unlink(address) - let socket = try Socket(domain: AF_UNIX, type: Socket.stream) + let socket = try Socket(domain: AF_UNIX, type: .stream) try socket.setValue(true, for: .localAddressReuse) try socket.bind(to: address) try socket.listen() return try AsyncSocket(socket: socket, pool: pool) } - static func make(pool: some AsyncSocketPool) throws -> AsyncSocket { - let socket = try Socket(domain: AF_UNIX, type: Socket.stream) + static func make(pool: some AsyncSocketPool, type: SocketType = .stream) throws -> AsyncSocket { + let socket = try Socket(domain: AF_UNIX, type: type) return try AsyncSocket(socket: socket, pool: pool) } + static func makeDatagramPair() async throws -> (AsyncSocket, AsyncSocket, sockaddr_un) { + let socketPair = try await makePair(pool: .client, type: .datagram) + guard let endpoint = FileManager.default.makeTemporaryFile() else { + throw SocketError.makeFailed("MakeTemporaryFile") + } + let addr = sockaddr_un.unix(path: endpoint.path) + + try socketPair.1.socket.bind(to: addr) +#if canImport(Darwin) + try await socketPair.0.connect(to: addr) +#endif + + return (socketPair.0, socketPair.1, addr) + } + static func makePair() async throws -> (AsyncSocket, AsyncSocket) { - try await makePair(pool: .client) + try await makePair(pool: .client, type: .stream) } func writeString(_ string: String) async throws { diff --git a/FlyingSocks/Tests/FileManager+TemporaryFile.swift b/FlyingSocks/Tests/FileManager+TemporaryFile.swift new file mode 100644 index 00000000..3d9837b2 --- /dev/null +++ b/FlyingSocks/Tests/FileManager+TemporaryFile.swift @@ -0,0 +1,42 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +@testable import FlyingSocks +import Foundation + +extension FileManager { + func makeTemporaryFile() -> URL? { + let dirPath = temporaryDirectory.appendingPathComponent("FlyingSocks.XXXXXX") + return dirPath.withUnsafeFileSystemRepresentation { maybePath in + guard let path = maybePath else { return nil } + var mutablePath = Array(repeating: Int8(0), count: Int(PATH_MAX)) + mutablePath.withUnsafeMutableBytes { mutablePathBufferPtr in + mutablePathBufferPtr.baseAddress!.copyMemory( + from: path, byteCount: Int(strlen(path)) + 1) + } + guard mktemp(&mutablePath) != nil else { return nil } + return URL( + fileURLWithFileSystemRepresentation: mutablePath, isDirectory: false, + relativeTo: nil) + } + } +} + +func withTemporaryFile(f: (URL) -> ()) throws { + guard let tmp = FileManager.default.makeTemporaryFile() else { + throw SocketError.makeFailed("MakeTemporaryFile") + } + defer { try? FileManager.default.removeItem(atPath: tmp.path) } + f(tmp) +} diff --git a/FlyingSocks/Tests/SocketAddressTests.swift b/FlyingSocks/Tests/SocketAddressTests.swift index 37fc5056..03f183cc 100644 --- a/FlyingSocks/Tests/SocketAddressTests.swift +++ b/FlyingSocks/Tests/SocketAddressTests.swift @@ -159,6 +159,40 @@ struct SocketAddressTests { } } + @Test + func INET4_CheckSize() throws { + let sin = sockaddr_in.inet(port: 8001) + #expect( + sin.size == socklen_t(MemoryLayout.size) + ) + } + + @Test + func INET6_CheckSize() throws { + let sin6 = sockaddr_in6.inet6(port: 8001) + #expect( + sin6.size == socklen_t(MemoryLayout.size) + ) + } + + @Test + func unix_CheckSize() throws { + let sun = sockaddr_un.unix(path: "/var/foo") + #expect( + sun.size == socklen_t(MemoryLayout.size) + ) + } + + @Test + func unknown_CheckSize() throws { + var sa = sockaddr() + sa.sa_family = sa_family_t(AF_UNSPEC) + + #expect( + sa.size == 0 + ) + } + @Test func unlinkUnix_Throws_WhenPathIsInvalid() { #expect(throws: SocketError.self) { @@ -298,25 +332,44 @@ struct SocketAddressTests { } } } -} + @Test + func testTypeErasedSockAddress() throws { + var addrIn6 = sockaddr_in6() + addrIn6.sin6_family = sa_family_t(AF_INET6) + addrIn6.sin6_port = UInt16(9090).bigEndian + addrIn6.sin6_addr = try Socket.makeInAddr(fromIP6: "fe80::1") -private extension SocketAddress { + let storage = AnySocketAddress(addrIn6) - func makeStorage() -> sockaddr_storage { - var storage = sockaddr_storage() - var addr = self - let addrSize = MemoryLayout.size - let storageSize = MemoryLayout.size + #expect(storage.family == sa_family_t(AF_INET6)) + #expect(storage.family == addrIn6.sin6_family) - withUnsafePointer(to: &addr) { addrPtr in - let addrRawPtr = UnsafeRawPointer(addrPtr) - withUnsafeMutablePointer(to: &storage) { storagePtr in - let storageRawPtr = UnsafeMutableRawPointer(storagePtr) - let copySize = min(addrSize, storageSize) - storageRawPtr.copyMemory(from: addrRawPtr, byteCount: copySize) + withUnsafeBytes(of: addrIn6) { addrIn6Ptr in + let storagePtr = addrIn6Ptr.bindMemory(to: sockaddr_storage.self) + #expect(storagePtr.baseAddress!.pointee.ss_family == sa_family_t(AF_INET6)) + let sockaddrIn6Ptr = addrIn6Ptr.bindMemory(to: sockaddr_in6.self) + #expect(sockaddrIn6Ptr.baseAddress!.pointee.sin6_port == addrIn6.sin6_port) + let addrArray = withUnsafeBytes(of: sockaddrIn6Ptr.baseAddress!.pointee.sin6_addr) { + Array($0.bindMemory(to: UInt8.self)) + } + let expectedArray = withUnsafeBytes(of: addrIn6.sin6_addr) { + Array($0.bindMemory(to: UInt8.self)) } + #expect(addrArray == expectedArray) + } + + storage.withSockAddr { sa in + #expect(sa.pointee.sa_family == sa_family_t(AF_INET6)) + } + + addrIn6.withSockAddr { sa in + #expect(sa.pointee.sa_family == sa_family_t(AF_INET6)) } - return storage } } + +// this is a bit ugly but necessary to get unknown_CheckSize() to function +extension sockaddr: SocketAddress, @unchecked Sendable { + public static var family: sa_family_t { sa_family_t(AF_UNSPEC) } +} diff --git a/FlyingSocks/Tests/SocketPoolTests.swift b/FlyingSocks/Tests/SocketPoolTests.swift index 24e39a2b..51a588f7 100644 --- a/FlyingSocks/Tests/SocketPoolTests.swift +++ b/FlyingSocks/Tests/SocketPoolTests.swift @@ -85,7 +85,7 @@ struct SocketPoolTests { try await pool.prepare() let task = Task { - let socket = try Socket(domain: AF_UNIX, type: Socket.stream) + let socket = try Socket(domain: AF_UNIX, type: .stream) try await pool.suspendSocket(socket, untilReadyFor: .read) } @@ -125,7 +125,7 @@ struct SocketPoolTests { try? await task.value - let socket = try Socket(domain: AF_UNIX, type: Socket.stream) + let socket = try Socket(domain: AF_UNIX, type: .stream) await #expect(throws: (any Error).self) { try await pool.suspendSocket(socket, untilReadyFor: .read) } @@ -138,7 +138,7 @@ struct SocketPoolTests { let task = Task { try await pool.run() } defer { task.cancel() } - let socket = try Socket(domain: AF_UNIX, type: Socket.stream) + let socket = try Socket(domain: AF_UNIX, type: .stream) let suspension = Task { try await pool.suspendSocket(socket, untilReadyFor: .read) } @@ -160,7 +160,7 @@ struct SocketPoolTests { let task = Task { try await pool.run() } defer { task.cancel() } - let socket = try Socket(domain: AF_UNIX, type: Socket.stream) + let socket = try Socket(domain: AF_UNIX, type: .stream) let suspension = Task { try await pool.suspendSocket(socket, untilReadyFor: .read) } diff --git a/FlyingSocks/Tests/SocketTests.swift b/FlyingSocks/Tests/SocketTests.swift index 2cbe27cc..68b6dc01 100644 --- a/FlyingSocks/Tests/SocketTests.swift +++ b/FlyingSocks/Tests/SocketTests.swift @@ -98,7 +98,7 @@ struct SocketTests { @Test func socketWrite_Throws_WhenSocketIsNotConnected() async throws { - let s1 = try Socket(domain: AF_UNIX, type: Socket.stream) + let s1 = try Socket(domain: AF_UNIX, type: .stream) let data = Data(repeating: 0x01, count: 100) #expect(throws: SocketError.self) { try s1.write(data, from: data.startIndex) @@ -108,7 +108,7 @@ struct SocketTests { @Test func socket_Sets_And_Gets_ReceiveBufferSize() throws { - let socket = try Socket(domain: AF_UNIX, type: Socket.stream) + let socket = try Socket(domain: AF_UNIX, type: .stream) try socket.setValue(2048, for: .receiveBufferSize) #if canImport(Darwin) @@ -121,7 +121,7 @@ struct SocketTests { @Test func socket_Sets_And_Gets_SendBufferSizeOption() throws { - let socket = try Socket(domain: AF_UNIX, type: Socket.stream) + let socket = try Socket(domain: AF_UNIX, type: .stream) try socket.setValue(2048, for: .sendBufferSize) #if canImport(Darwin) @@ -134,7 +134,7 @@ struct SocketTests { @Test func socket_Sets_And_Gets_BoolOption() throws { - let socket = try Socket(domain: AF_UNIX, type: Socket.stream) + let socket = try Socket(domain: AF_UNIX, type: .stream) try socket.setValue(true, for: .localAddressReuse) #expect(try socket.getValue(for: .localAddressReuse)) @@ -145,20 +145,13 @@ struct SocketTests { @Test func socket_Sets_And_Gets_Flags() throws { - let socket = try Socket(domain: AF_UNIX, type: Socket.stream) + let socket = try Socket(domain: AF_UNIX, type: .stream) #expect(try socket.flags.contains(.append) == false) try socket.setFlags(.append) #expect(try socket.flags.contains(.append)) } - @Test - func socketInit_ThrowsError_WhenInvalid() { - #expect(throws: SocketError.self) { - _ = try Socket(domain: -1, type: -1) - } - } - @Test func socketAccept_ThrowsError_WhenInvalid() { let socket = Socket(file: -1) @@ -193,7 +186,7 @@ struct SocketTests { @Test func socketBind_ToINET() throws { - let socket = try Socket(domain: AF_INET, type: Socket.stream) + let socket = try Socket(domain: AF_INET, type: .stream) try socket.setValue(true, for: .localAddressReuse) let address = Socket.makeAddressINET(port:5050) #expect(throws: Never.self) {