Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion FlyingFox/Tests/AsyncSocketTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ extension AsyncSocket {
}

static func make(pool: some AsyncSocketPool) throws -> AsyncSocket {
let socket = try Socket(domain: AF_UNIX, type: Socket.stream)
let socket = try Socket(domain: AF_UNIX, type: .stream)
return try AsyncSocket(socket: socket, pool: pool)
}

Expand Down
148 changes: 145 additions & 3 deletions FlyingSocks/Sources/AsyncSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,25 @@

public struct AsyncSocket: Sendable {

public struct Message: Sendable {
public let peerAddress: any SocketAddress
public let bytes: [UInt8]
public let interfaceIndex: UInt32?
public let localAddress: (any SocketAddress)?

public init(
peerAddress: any SocketAddress,
bytes: [UInt8],
interfaceIndex: UInt32? = nil,
localAddress: (any SocketAddress)? = nil
) {
self.peerAddress = peerAddress
self.bytes = bytes
self.interfaceIndex = interfaceIndex
self.localAddress = localAddress
}
}

public let socket: Socket
let pool: any AsyncSocketPool

Expand All @@ -83,7 +102,7 @@
pool: some AsyncSocketPool,
timeout: TimeInterval = 5) async throws -> Self {
try await withThrowingTimeout(seconds: timeout) {
let socket = try Socket(domain: Int32(type(of: address).family), type: Socket.stream)
let socket = try Socket(domain: Int32(type(of: address).family), type: .stream)
let asyncSocket = try AsyncSocket(socket: socket, pool: pool)
try await asyncSocket.connect(to: address)
return asyncSocket
Expand Down Expand Up @@ -129,6 +148,37 @@
return buffer
}

public func receive(atMost length: Int = 4096) async throws -> (any SocketAddress, [UInt8]) {
try Task.checkCancellation()

repeat {
do {
return try socket.receive(length: length)
} catch SocketError.blocked {
try await pool.suspendSocket(socket, untilReadyFor: .read)
} catch {
throw error
}
} while true
}

Check warning on line 163 in FlyingSocks/Sources/AsyncSocket.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/AsyncSocket.swift#L158-L163

Added lines #L158 - L163 were not covered by tests

#if !canImport(WinSDK)
public func receive(atMost length: Int) async throws -> Message {
try Task.checkCancellation()

repeat {
do {
let (peerAddress, bytes, interfaceIndex, localAddress) = try socket.receive(length: length)
return Message(peerAddress: peerAddress, bytes: bytes, interfaceIndex: interfaceIndex, localAddress: localAddress)
} catch SocketError.blocked {
try await pool.suspendSocket(socket, untilReadyFor: .read)
} catch {
throw error
}
} while true
}

Check warning on line 179 in FlyingSocks/Sources/AsyncSocket.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/AsyncSocket.swift#L174-L179

Added lines #L174 - L179 were not covered by tests
#endif

/// Reads bytes from the socket up to by not over/
/// - Parameter bytes: The max number of bytes to read
/// - Returns: an array of the read bytes capped to the number of bytes provided.
Expand Down Expand Up @@ -163,6 +213,61 @@
}
}

public func send(_ data: [UInt8], to address: some SocketAddress) async throws {
let sent = try await pool.loopUntilReady(for: .write, on: socket) {
try socket.send(data, to: address)
}
guard sent == data.count else {
throw SocketError.disconnected
}
}

Check warning on line 223 in FlyingSocks/Sources/AsyncSocket.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/AsyncSocket.swift#L216-L223

Added lines #L216 - L223 were not covered by tests

public func send(_ data: Data, to address: some SocketAddress) async throws {
try await send(Array(data), to: address)
}

Check warning on line 227 in FlyingSocks/Sources/AsyncSocket.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/AsyncSocket.swift#L225-L227

Added lines #L225 - L227 were not covered by tests

#if !canImport(WinSDK)
public func send(
message: [UInt8],
to peerAddress: some SocketAddress,
interfaceIndex: UInt32? = nil,
from localAddress: (some SocketAddress)? = nil
) async throws {
let sent = try await pool.loopUntilReady(for: .write, on: socket) {
try socket.send(message: message, to: peerAddress, interfaceIndex: interfaceIndex, from: localAddress)
}
guard sent == message.count else {
throw SocketError.disconnected
}
}

Check warning on line 242 in FlyingSocks/Sources/AsyncSocket.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/AsyncSocket.swift#L235-L242

Added lines #L235 - L242 were not covered by tests

public func send(
message: Data,
to peerAddress: some SocketAddress,
interfaceIndex: UInt32? = nil,
from localAddress: (some SocketAddress)? = nil
) async throws {
try await send(message: Array(message), to: peerAddress, interfaceIndex: interfaceIndex, from: localAddress)
}

Check warning on line 251 in FlyingSocks/Sources/AsyncSocket.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/AsyncSocket.swift#L249-L251

Added lines #L249 - L251 were not covered by tests

public func send(message: Message) async throws {
let localAddress: AnySocketAddress?

if let unwrappedLocalAddress = message.localAddress {
localAddress = AnySocketAddress(unwrappedLocalAddress)
} else {
localAddress = nil
}

try await send(
message: message.bytes,
to: AnySocketAddress(message.peerAddress),
interfaceIndex: message.interfaceIndex,
from: localAddress
)
}

Check warning on line 268 in FlyingSocks/Sources/AsyncSocket.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/AsyncSocket.swift#L253-L268

Added lines #L253 - L268 were not covered by tests
#endif

public func close() throws {
try socket.close()
}
Expand All @@ -174,12 +279,20 @@
public var sockets: AsyncSocketSequence {
AsyncSocketSequence(socket: self)
}

public var messages: AsyncSocketMessageSequence {
AsyncSocketMessageSequence(socket: self)
}

Check warning on line 285 in FlyingSocks/Sources/AsyncSocket.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/AsyncSocket.swift#L283-L285

Added lines #L283 - L285 were not covered by tests

public func messages(maxMessageLength: Int) -> AsyncSocketMessageSequence {
AsyncSocketMessageSequence(socket: self, maxMessageLength: maxMessageLength)
}

Check warning on line 289 in FlyingSocks/Sources/AsyncSocket.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/AsyncSocket.swift#L287-L289

Added lines #L287 - L289 were not covered by tests
}

package extension AsyncSocket {

static func makePair(pool: some AsyncSocketPool) throws -> (AsyncSocket, AsyncSocket) {
let (s1, s2) = try Socket.makePair()
static func makePair(pool: some AsyncSocketPool, type: SocketType = .stream) throws -> (AsyncSocket, AsyncSocket) {
let (s1, s2) = try Socket.makePair(type: type)
let a1 = try AsyncSocket(socket: s1, pool: pool)
let a2 = try AsyncSocket(socket: s2, pool: pool)
return (a1, a2)
Expand Down Expand Up @@ -237,6 +350,35 @@
}
}

public struct AsyncSocketMessageSequence: AsyncSequence, AsyncIteratorProtocol, Sendable {
public static let DefaultMaxMessageLength: Int = 1500

// Windows has a different recvmsg() API signature which is presently unsupported
public typealias Element = AsyncSocket.Message

private let socket: AsyncSocket
private let maxMessageLength: Int

public func makeAsyncIterator() -> AsyncSocketMessageSequence { self }

Check warning on line 362 in FlyingSocks/Sources/AsyncSocket.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/AsyncSocket.swift#L362

Added line #L362 was not covered by tests

init(socket: AsyncSocket, maxMessageLength: Int = Self.DefaultMaxMessageLength) {
self.socket = socket
self.maxMessageLength = maxMessageLength
}

Check warning on line 367 in FlyingSocks/Sources/AsyncSocket.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/AsyncSocket.swift#L364-L367

Added lines #L364 - L367 were not covered by tests

public mutating func next() async throws -> Element? {

Check warning on line 369 in FlyingSocks/Sources/AsyncSocket.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/AsyncSocket.swift#L369

Added line #L369 was not covered by tests
#if !canImport(WinSDK)
try await socket.receive(atMost: maxMessageLength)

Check warning on line 371 in FlyingSocks/Sources/AsyncSocket.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/AsyncSocket.swift#L371

Added line #L371 was not covered by tests
#else
let peerAddress: any SocketAddress
let bytes: [UInt8]

(peerAddress, bytes) = try await socket.receive(atMost: maxMessageLength)
return AsyncSocket.Message(peerAddress: peerAddress, bytes: bytes)
#endif
}

Check warning on line 379 in FlyingSocks/Sources/AsyncSocket.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/AsyncSocket.swift#L379

Added line #L379 was not covered by tests
}

private actor ClientPoolLoader {
static let shared = ClientPoolLoader()

Expand Down
25 changes: 25 additions & 0 deletions FlyingSocks/Sources/Socket+Android.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ let EPOLLET: UInt32 = 1 << 31;

public extension Socket {
typealias FileDescriptorType = Int32
typealias IovLengthType = UInt
typealias ControlMessageHeaderLengthType = Int
typealias IPv4InterfaceIndexType = Int32
typealias IPv6InterfaceIndexType = Int32
}

extension Socket.FileDescriptor {
Expand All @@ -45,7 +49,12 @@ extension Socket.FileDescriptor {

extension Socket {
static let stream = Int32(SOCK_STREAM)
static let datagram = Int32(SOCK_DGRAM)
static let in_addr_any = Android.in_addr(s_addr: Android.in_addr_t(0))
static let ipproto_ip = Int32(IPPROTO_IP)
static let ipproto_ipv6 = Int32(IPPROTO_IPV6)
static let ip_pktinfo = Int32(IP_PKTINFO)
static let ipv6_pktinfo = Int32(IPV6_PKTINFO)

static func makeAddressINET(port: UInt16) -> Android.sockaddr_in {
Android.sockaddr_in(
Expand Down Expand Up @@ -175,6 +184,22 @@ extension Socket {
static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Android.pollfd {
Android.pollfd(fd: fd, events: events, revents: revents)
}

static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer<sockaddr>!, _ len: UnsafeMutablePointer<socklen_t>!) -> Int {
Android.recvfrom(fd, buffer, nbyte, flags, addr, len)
}

static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
Android.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
}

static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer<msghdr>, _ flags: Int32) -> Int {
Android.recvmsg(fd, message, flags)
}

static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer<msghdr>, _ flags: Int32) -> Int {
Android.sendmsg(fd, message, flags)
}
}

#endif
25 changes: 25 additions & 0 deletions FlyingSocks/Sources/Socket+Darwin.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@

public extension Socket {
typealias FileDescriptorType = Int32
typealias IovLengthType = Int
typealias ControlMessageHeaderLengthType = UInt32
typealias IPv4InterfaceIndexType = UInt32
typealias IPv6InterfaceIndexType = UInt32
}

extension Socket.FileDescriptor {
Expand All @@ -42,7 +46,12 @@

extension Socket {
static let stream = Int32(SOCK_STREAM)
static let datagram = Int32(SOCK_DGRAM)
static let in_addr_any = Darwin.in_addr(s_addr: Darwin.in_addr_t(0))
static let ipproto_ip = Int32(IPPROTO_IP)
static let ipproto_ipv6 = Int32(IPPROTO_IPV6)
static let ip_pktinfo = Int32(IP_PKTINFO)
static let ipv6_pktinfo = Int32(50) // __APPLE_USE_RFC_2292

static func makeAddressINET(port: UInt16) -> Darwin.sockaddr_in {
Darwin.sockaddr_in(
Expand Down Expand Up @@ -176,6 +185,22 @@
static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Darwin.pollfd {
Darwin.pollfd(fd: fd, events: events, revents: revents)
}

static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer<sockaddr>!, _ len: UnsafeMutablePointer<socklen_t>!) -> Int {
Darwin.recvfrom(fd, buffer, nbyte, flags, addr, len)
}

static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
Darwin.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
}

Check warning on line 195 in FlyingSocks/Sources/Socket+Darwin.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/Socket+Darwin.swift#L193-L195

Added lines #L193 - L195 were not covered by tests

static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer<msghdr>, _ flags: Int32) -> Int {
Darwin.recvmsg(fd, message, flags)
}

static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer<msghdr>, _ flags: Int32) -> Int {
Darwin.sendmsg(fd, message, flags)
}

Check warning on line 203 in FlyingSocks/Sources/Socket+Darwin.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/Socket+Darwin.swift#L201-L203

Added lines #L201 - L203 were not covered by tests
}

#endif
30 changes: 30 additions & 0 deletions FlyingSocks/Sources/Socket+Glibc.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ import Glibc

public extension Socket {
typealias FileDescriptorType = Int32
typealias IovLengthType = Int
typealias ControlMessageHeaderLengthType = Int
typealias IPv4InterfaceIndexType = Int32
typealias IPv6InterfaceIndexType = UInt32
}

extension Socket.FileDescriptor {
Expand All @@ -42,7 +46,12 @@ extension Socket.FileDescriptor {

extension Socket {
static let stream = Int32(SOCK_STREAM.rawValue)
static let datagram = Int32(SOCK_DGRAM.rawValue)
static let in_addr_any = Glibc.in_addr(s_addr: Glibc.in_addr_t(0))
static let ipproto_ip = Int32(IPPROTO_IP)
static let ipproto_ipv6 = Int32(IPPROTO_IPV6)
static let ip_pktinfo = Int32(IP_PKTINFO)
static let ipv6_pktinfo = Int32(IPV6_PKTINFO)

static func makeAddressINET(port: UInt16) -> Glibc.sockaddr_in {
Glibc.sockaddr_in(
Expand Down Expand Up @@ -172,6 +181,27 @@ extension Socket {
static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Glibc.pollfd {
Glibc.pollfd(fd: fd, events: events, revents: revents)
}

static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer<sockaddr>!, _ len: UnsafeMutablePointer<socklen_t>!) -> Int {
Glibc.recvfrom(fd, buffer, nbyte, flags, addr, len)
}

static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
Glibc.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
}

static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer<msghdr>, _ flags: Int32) -> Int {
Glibc.recvmsg(fd, message, flags)
}

static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer<msghdr>, _ flags: Int32) -> Int {
Glibc.sendmsg(fd, message, flags)
}
}

struct in6_pktinfo {
var ipi6_addr: in6_addr
var ipi6_ifindex: CUnsignedInt
}

#endif
25 changes: 25 additions & 0 deletions FlyingSocks/Sources/Socket+Musl.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ import Musl

public extension Socket {
typealias FileDescriptorType = Int32
typealias IovLengthType = Int
typealias ControlMessageHeaderLengthType = UInt32
typealias IPv4InterfaceIndexType = Int32
typealias IPv6InterfaceIndexType = UInt32
}

extension Socket.FileDescriptor {
Expand All @@ -42,7 +46,12 @@ extension Socket.FileDescriptor {

extension Socket {
static let stream = Int32(SOCK_STREAM)
static let datagram = Int32(SOCK_DGRAM)
static let in_addr_any = Musl.in_addr(s_addr: Musl.in_addr_t(0))
static let ipproto_ip = Int32(IPPROTO_IP)
static let ipproto_ipv6 = Int32(IPPROTO_IPV6)
static let ip_pktinfo = Int32(IP_PKTINFO)
static let ipv6_pktinfo = Int32(IPV6_PKTINFO)

static func makeAddressINET(port: UInt16) -> Musl.sockaddr_in {
Musl.sockaddr_in(
Expand Down Expand Up @@ -172,6 +181,22 @@ extension Socket {
static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Musl.pollfd {
Musl.pollfd(fd: fd, events: events, revents: revents)
}

static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer<sockaddr>!, _ len: UnsafeMutablePointer<socklen_t>!) -> Int {
Musl.recvfrom(fd, buffer, nbyte, flags, addr, len)
}

static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
Musl.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
}

static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer<msghdr>, _ flags: Int32) -> Int {
Musl.recvmsg(fd, message, flags)
}

static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer<msghdr>, _ flags: Int32) -> Int {
Musl.sendmsg(fd, message, flags)
}
}

#endif
Loading