Skip to content

Commit 150911e

Browse files
authored
Merge pull request #129 from PADL/udp
2 parents 3de9f8a + 7731068 commit 150911e

15 files changed

+931
-76
lines changed

FlyingFox/Tests/AsyncSocketTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ extension AsyncSocket {
4040
}
4141

4242
static func make(pool: some AsyncSocketPool) throws -> AsyncSocket {
43-
let socket = try Socket(domain: AF_UNIX, type: Socket.stream)
43+
let socket = try Socket(domain: AF_UNIX, type: .stream)
4444
return try AsyncSocket(socket: socket, pool: pool)
4545
}
4646

FlyingSocks/Sources/AsyncSocket.swift

Lines changed: 145 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,25 @@ public extension AsyncSocketPool where Self == SocketPool<Poll> {
6262

6363
public struct AsyncSocket: Sendable {
6464

65+
public struct Message: Sendable {
66+
public let peerAddress: any SocketAddress
67+
public let bytes: [UInt8]
68+
public let interfaceIndex: UInt32?
69+
public let localAddress: (any SocketAddress)?
70+
71+
public init(
72+
peerAddress: any SocketAddress,
73+
bytes: [UInt8],
74+
interfaceIndex: UInt32? = nil,
75+
localAddress: (any SocketAddress)? = nil
76+
) {
77+
self.peerAddress = peerAddress
78+
self.bytes = bytes
79+
self.interfaceIndex = interfaceIndex
80+
self.localAddress = localAddress
81+
}
82+
}
83+
6584
public let socket: Socket
6685
let pool: any AsyncSocketPool
6786

@@ -83,7 +102,7 @@ public struct AsyncSocket: Sendable {
83102
pool: some AsyncSocketPool,
84103
timeout: TimeInterval = 5) async throws -> Self {
85104
try await withThrowingTimeout(seconds: timeout) {
86-
let socket = try Socket(domain: Int32(type(of: address).family), type: Socket.stream)
105+
let socket = try Socket(domain: Int32(type(of: address).family), type: .stream)
87106
let asyncSocket = try AsyncSocket(socket: socket, pool: pool)
88107
try await asyncSocket.connect(to: address)
89108
return asyncSocket
@@ -129,6 +148,37 @@ public struct AsyncSocket: Sendable {
129148
return buffer
130149
}
131150

151+
public func receive(atMost length: Int = 4096) async throws -> (any SocketAddress, [UInt8]) {
152+
try Task.checkCancellation()
153+
154+
repeat {
155+
do {
156+
return try socket.receive(length: length)
157+
} catch SocketError.blocked {
158+
try await pool.suspendSocket(socket, untilReadyFor: .read)
159+
} catch {
160+
throw error
161+
}
162+
} while true
163+
}
164+
165+
#if !canImport(WinSDK)
166+
public func receive(atMost length: Int) async throws -> Message {
167+
try Task.checkCancellation()
168+
169+
repeat {
170+
do {
171+
let (peerAddress, bytes, interfaceIndex, localAddress) = try socket.receive(length: length)
172+
return Message(peerAddress: peerAddress, bytes: bytes, interfaceIndex: interfaceIndex, localAddress: localAddress)
173+
} catch SocketError.blocked {
174+
try await pool.suspendSocket(socket, untilReadyFor: .read)
175+
} catch {
176+
throw error
177+
}
178+
} while true
179+
}
180+
#endif
181+
132182
/// Reads bytes from the socket up to by not over/
133183
/// - Parameter bytes: The max number of bytes to read
134184
/// - Returns: an array of the read bytes capped to the number of bytes provided.
@@ -163,6 +213,61 @@ public struct AsyncSocket: Sendable {
163213
}
164214
}
165215

216+
public func send(_ data: [UInt8], to address: some SocketAddress) async throws {
217+
let sent = try await pool.loopUntilReady(for: .write, on: socket) {
218+
try socket.send(data, to: address)
219+
}
220+
guard sent == data.count else {
221+
throw SocketError.disconnected
222+
}
223+
}
224+
225+
public func send(_ data: Data, to address: some SocketAddress) async throws {
226+
try await send(Array(data), to: address)
227+
}
228+
229+
#if !canImport(WinSDK)
230+
public func send(
231+
message: [UInt8],
232+
to peerAddress: some SocketAddress,
233+
interfaceIndex: UInt32? = nil,
234+
from localAddress: (some SocketAddress)? = nil
235+
) async throws {
236+
let sent = try await pool.loopUntilReady(for: .write, on: socket) {
237+
try socket.send(message: message, to: peerAddress, interfaceIndex: interfaceIndex, from: localAddress)
238+
}
239+
guard sent == message.count else {
240+
throw SocketError.disconnected
241+
}
242+
}
243+
244+
public func send(
245+
message: Data,
246+
to peerAddress: some SocketAddress,
247+
interfaceIndex: UInt32? = nil,
248+
from localAddress: (some SocketAddress)? = nil
249+
) async throws {
250+
try await send(message: Array(message), to: peerAddress, interfaceIndex: interfaceIndex, from: localAddress)
251+
}
252+
253+
public func send(message: Message) async throws {
254+
let localAddress: AnySocketAddress?
255+
256+
if let unwrappedLocalAddress = message.localAddress {
257+
localAddress = AnySocketAddress(unwrappedLocalAddress)
258+
} else {
259+
localAddress = nil
260+
}
261+
262+
try await send(
263+
message: message.bytes,
264+
to: AnySocketAddress(message.peerAddress),
265+
interfaceIndex: message.interfaceIndex,
266+
from: localAddress
267+
)
268+
}
269+
#endif
270+
166271
public func close() throws {
167272
try socket.close()
168273
}
@@ -174,12 +279,20 @@ public struct AsyncSocket: Sendable {
174279
public var sockets: AsyncSocketSequence {
175280
AsyncSocketSequence(socket: self)
176281
}
282+
283+
public var messages: AsyncSocketMessageSequence {
284+
AsyncSocketMessageSequence(socket: self)
285+
}
286+
287+
public func messages(maxMessageLength: Int) -> AsyncSocketMessageSequence {
288+
AsyncSocketMessageSequence(socket: self, maxMessageLength: maxMessageLength)
289+
}
177290
}
178291

179292
package extension AsyncSocket {
180293

181-
static func makePair(pool: some AsyncSocketPool) throws -> (AsyncSocket, AsyncSocket) {
182-
let (s1, s2) = try Socket.makePair()
294+
static func makePair(pool: some AsyncSocketPool, type: SocketType = .stream) throws -> (AsyncSocket, AsyncSocket) {
295+
let (s1, s2) = try Socket.makePair(type: type)
183296
let a1 = try AsyncSocket(socket: s1, pool: pool)
184297
let a2 = try AsyncSocket(socket: s2, pool: pool)
185298
return (a1, a2)
@@ -237,6 +350,35 @@ public struct AsyncSocketSequence: AsyncSequence, AsyncIteratorProtocol, Sendabl
237350
}
238351
}
239352

353+
public struct AsyncSocketMessageSequence: AsyncSequence, AsyncIteratorProtocol, Sendable {
354+
public static let DefaultMaxMessageLength: Int = 1500
355+
356+
// Windows has a different recvmsg() API signature which is presently unsupported
357+
public typealias Element = AsyncSocket.Message
358+
359+
private let socket: AsyncSocket
360+
private let maxMessageLength: Int
361+
362+
public func makeAsyncIterator() -> AsyncSocketMessageSequence { self }
363+
364+
init(socket: AsyncSocket, maxMessageLength: Int = Self.DefaultMaxMessageLength) {
365+
self.socket = socket
366+
self.maxMessageLength = maxMessageLength
367+
}
368+
369+
public mutating func next() async throws -> Element? {
370+
#if !canImport(WinSDK)
371+
try await socket.receive(atMost: maxMessageLength)
372+
#else
373+
let peerAddress: any SocketAddress
374+
let bytes: [UInt8]
375+
376+
(peerAddress, bytes) = try await socket.receive(atMost: maxMessageLength)
377+
return AsyncSocket.Message(peerAddress: peerAddress, bytes: bytes)
378+
#endif
379+
}
380+
}
381+
240382
private actor ClientPoolLoader {
241383
static let shared = ClientPoolLoader()
242384

FlyingSocks/Sources/Socket+Android.swift

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ let EPOLLET: UInt32 = 1 << 31;
3737

3838
public extension Socket {
3939
typealias FileDescriptorType = Int32
40+
typealias IovLengthType = UInt
41+
typealias ControlMessageHeaderLengthType = Int
42+
typealias IPv4InterfaceIndexType = Int32
43+
typealias IPv6InterfaceIndexType = Int32
4044
}
4145

4246
extension Socket.FileDescriptor {
@@ -45,7 +49,12 @@ extension Socket.FileDescriptor {
4549

4650
extension Socket {
4751
static let stream = Int32(SOCK_STREAM)
52+
static let datagram = Int32(SOCK_DGRAM)
4853
static let in_addr_any = Android.in_addr(s_addr: Android.in_addr_t(0))
54+
static let ipproto_ip = Int32(IPPROTO_IP)
55+
static let ipproto_ipv6 = Int32(IPPROTO_IPV6)
56+
static let ip_pktinfo = Int32(IP_PKTINFO)
57+
static let ipv6_pktinfo = Int32(IPV6_PKTINFO)
4958

5059
static func makeAddressINET(port: UInt16) -> Android.sockaddr_in {
5160
Android.sockaddr_in(
@@ -175,6 +184,22 @@ extension Socket {
175184
static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Android.pollfd {
176185
Android.pollfd(fd: fd, events: events, revents: revents)
177186
}
187+
188+
static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer<sockaddr>!, _ len: UnsafeMutablePointer<socklen_t>!) -> Int {
189+
Android.recvfrom(fd, buffer, nbyte, flags, addr, len)
190+
}
191+
192+
static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
193+
Android.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
194+
}
195+
196+
static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer<msghdr>, _ flags: Int32) -> Int {
197+
Android.recvmsg(fd, message, flags)
198+
}
199+
200+
static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer<msghdr>, _ flags: Int32) -> Int {
201+
Android.sendmsg(fd, message, flags)
202+
}
178203
}
179204

180205
#endif

FlyingSocks/Sources/Socket+Darwin.swift

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ import Darwin
3434

3535
public extension Socket {
3636
typealias FileDescriptorType = Int32
37+
typealias IovLengthType = Int
38+
typealias ControlMessageHeaderLengthType = UInt32
39+
typealias IPv4InterfaceIndexType = UInt32
40+
typealias IPv6InterfaceIndexType = UInt32
3741
}
3842

3943
extension Socket.FileDescriptor {
@@ -42,7 +46,12 @@ extension Socket.FileDescriptor {
4246

4347
extension Socket {
4448
static let stream = Int32(SOCK_STREAM)
49+
static let datagram = Int32(SOCK_DGRAM)
4550
static let in_addr_any = Darwin.in_addr(s_addr: Darwin.in_addr_t(0))
51+
static let ipproto_ip = Int32(IPPROTO_IP)
52+
static let ipproto_ipv6 = Int32(IPPROTO_IPV6)
53+
static let ip_pktinfo = Int32(IP_PKTINFO)
54+
static let ipv6_pktinfo = Int32(50) // __APPLE_USE_RFC_2292
4655

4756
static func makeAddressINET(port: UInt16) -> Darwin.sockaddr_in {
4857
Darwin.sockaddr_in(
@@ -176,6 +185,22 @@ extension Socket {
176185
static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Darwin.pollfd {
177186
Darwin.pollfd(fd: fd, events: events, revents: revents)
178187
}
188+
189+
static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer<sockaddr>!, _ len: UnsafeMutablePointer<socklen_t>!) -> Int {
190+
Darwin.recvfrom(fd, buffer, nbyte, flags, addr, len)
191+
}
192+
193+
static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
194+
Darwin.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
195+
}
196+
197+
static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer<msghdr>, _ flags: Int32) -> Int {
198+
Darwin.recvmsg(fd, message, flags)
199+
}
200+
201+
static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer<msghdr>, _ flags: Int32) -> Int {
202+
Darwin.sendmsg(fd, message, flags)
203+
}
179204
}
180205

181206
#endif

FlyingSocks/Sources/Socket+Glibc.swift

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ import Glibc
3434

3535
public extension Socket {
3636
typealias FileDescriptorType = Int32
37+
typealias IovLengthType = Int
38+
typealias ControlMessageHeaderLengthType = Int
39+
typealias IPv4InterfaceIndexType = Int32
40+
typealias IPv6InterfaceIndexType = UInt32
3741
}
3842

3943
extension Socket.FileDescriptor {
@@ -42,7 +46,12 @@ extension Socket.FileDescriptor {
4246

4347
extension Socket {
4448
static let stream = Int32(SOCK_STREAM.rawValue)
49+
static let datagram = Int32(SOCK_DGRAM.rawValue)
4550
static let in_addr_any = Glibc.in_addr(s_addr: Glibc.in_addr_t(0))
51+
static let ipproto_ip = Int32(IPPROTO_IP)
52+
static let ipproto_ipv6 = Int32(IPPROTO_IPV6)
53+
static let ip_pktinfo = Int32(IP_PKTINFO)
54+
static let ipv6_pktinfo = Int32(IPV6_PKTINFO)
4655

4756
static func makeAddressINET(port: UInt16) -> Glibc.sockaddr_in {
4857
Glibc.sockaddr_in(
@@ -172,6 +181,27 @@ extension Socket {
172181
static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Glibc.pollfd {
173182
Glibc.pollfd(fd: fd, events: events, revents: revents)
174183
}
184+
185+
static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer<sockaddr>!, _ len: UnsafeMutablePointer<socklen_t>!) -> Int {
186+
Glibc.recvfrom(fd, buffer, nbyte, flags, addr, len)
187+
}
188+
189+
static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
190+
Glibc.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
191+
}
192+
193+
static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer<msghdr>, _ flags: Int32) -> Int {
194+
Glibc.recvmsg(fd, message, flags)
195+
}
196+
197+
static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer<msghdr>, _ flags: Int32) -> Int {
198+
Glibc.sendmsg(fd, message, flags)
199+
}
200+
}
201+
202+
struct in6_pktinfo {
203+
var ipi6_addr: in6_addr
204+
var ipi6_ifindex: CUnsignedInt
175205
}
176206

177207
#endif

FlyingSocks/Sources/Socket+Musl.swift

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ import Musl
3434

3535
public extension Socket {
3636
typealias FileDescriptorType = Int32
37+
typealias IovLengthType = Int
38+
typealias ControlMessageHeaderLengthType = UInt32
39+
typealias IPv4InterfaceIndexType = Int32
40+
typealias IPv6InterfaceIndexType = UInt32
3741
}
3842

3943
extension Socket.FileDescriptor {
@@ -42,7 +46,12 @@ extension Socket.FileDescriptor {
4246

4347
extension Socket {
4448
static let stream = Int32(SOCK_STREAM)
49+
static let datagram = Int32(SOCK_DGRAM)
4550
static let in_addr_any = Musl.in_addr(s_addr: Musl.in_addr_t(0))
51+
static let ipproto_ip = Int32(IPPROTO_IP)
52+
static let ipproto_ipv6 = Int32(IPPROTO_IPV6)
53+
static let ip_pktinfo = Int32(IP_PKTINFO)
54+
static let ipv6_pktinfo = Int32(IPV6_PKTINFO)
4655

4756
static func makeAddressINET(port: UInt16) -> Musl.sockaddr_in {
4857
Musl.sockaddr_in(
@@ -172,6 +181,22 @@ extension Socket {
172181
static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Musl.pollfd {
173182
Musl.pollfd(fd: fd, events: events, revents: revents)
174183
}
184+
185+
static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer<sockaddr>!, _ len: UnsafeMutablePointer<socklen_t>!) -> Int {
186+
Musl.recvfrom(fd, buffer, nbyte, flags, addr, len)
187+
}
188+
189+
static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
190+
Musl.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
191+
}
192+
193+
static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer<msghdr>, _ flags: Int32) -> Int {
194+
Musl.recvmsg(fd, message, flags)
195+
}
196+
197+
static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer<msghdr>, _ flags: Int32) -> Int {
198+
Musl.sendmsg(fd, message, flags)
199+
}
175200
}
176201

177202
#endif

0 commit comments

Comments
 (0)