Skip to content

Commit 9d21557

Browse files
committed
Support for datagram (UDP, message) sockets
1 parent af1345a commit 9d21557

File tree

8 files changed

+225
-2
lines changed

8 files changed

+225
-2
lines changed
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
//
2+
// AsyncMessageSocket.swift
3+
// FlyingFox
4+
//
5+
// Created by Luke Howard on 11/11/2024.
6+
// Copyright © 2024 PADL Software Pty Ltd. All rights reserved.
7+
//
8+
// Distributed under the permissive MIT license
9+
// Get the latest version from here:
10+
//
11+
// https://github.com/swhitty/FlyingFox
12+
//
13+
// Permission is hereby granted, free of charge, to any person obtaining a copy
14+
// of this software and associated documentation files (the "Software"), to deal
15+
// in the Software without restriction, including without limitation the rights
16+
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
17+
// copies of the Software, and to permit persons to whom the Software is
18+
// furnished to do so, subject to the following conditions:
19+
//
20+
// The above copyright notice and this permission notice shall be included in all
21+
// copies or substantial portions of the Software.
22+
//
23+
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
24+
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
25+
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
26+
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
27+
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
28+
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
29+
// SOFTWARE.
30+
//
31+
32+
import Foundation
33+
34+
public struct AsyncMessageSocket<A: SocketAddress>: Sendable {
35+
public let socket: Socket
36+
37+
let pool: any AsyncSocketPool
38+
39+
public struct Message: Sendable {
40+
public let address: A
41+
public let data: Data
42+
43+
public init(address: A, data: Data) {
44+
self.address = address
45+
self.data = data
46+
}
47+
}
48+
49+
public init(socket: Socket, pool: some AsyncSocketPool) throws {
50+
self.socket = socket
51+
self.pool = pool
52+
try socket.setFlags(.nonBlocking)
53+
}
54+
55+
public static func connected(to address: A, timeout: TimeInterval = 5) async throws -> Self {
56+
try await connected(
57+
to: address,
58+
pool: ClientPoolLoader.shared.getPool(),
59+
timeout: timeout
60+
)
61+
}
62+
63+
public static func connected(to address: A,
64+
pool: some AsyncSocketPool,
65+
timeout: TimeInterval = 5) async throws -> Self {
66+
try await withThrowingTimeout(seconds: timeout) {
67+
let socket = try Socket(domain: Int32(type(of: address).family), type: Socket.datagram)
68+
let asyncMessageSocket = try AsyncMessageSocket(socket: socket, pool: pool)
69+
try await asyncMessageSocket.connect(to: address)
70+
return asyncMessageSocket
71+
}
72+
}
73+
74+
public func connect(to address: A) async throws {
75+
return try await pool.loopUntilReady(for: [.write], on: socket) {
76+
try socket.connect(to: address)
77+
}
78+
}
79+
80+
public func receive(atMost length: Int = 4096) async throws -> Message {
81+
try Task.checkCancellation()
82+
83+
repeat {
84+
do {
85+
let (address, bytes): (A, [UInt8]) = try socket.receive(length: length)
86+
return Message(address: address, data: Data(bytes))
87+
} catch SocketError.blocked {
88+
try await pool.suspendSocket(socket, untilReadyFor: .read)
89+
} catch {
90+
throw error
91+
}
92+
} while true
93+
}
94+
95+
public func send(message: Message) async throws {
96+
let sent = try await pool.loopUntilReady(for: .write, on: socket) {
97+
try socket.send(Array(message.data), to: message.address)
98+
}
99+
guard sent == message.data.count else {
100+
throw SocketError.disconnected
101+
}
102+
}
103+
104+
public func close() throws {
105+
try socket.close()
106+
}
107+
108+
public var messages: AsyncSocketMessageSequence<A> {
109+
AsyncSocketMessageSequence(socket: self)
110+
}
111+
}
112+
113+
public struct AsyncSocketMessageSequence<A: SocketAddress>: AsyncSequence, AsyncIteratorProtocol, Sendable {
114+
public typealias Element = AsyncMessageSocket<A>.Message
115+
116+
let socket: AsyncMessageSocket<A>
117+
118+
public func makeAsyncIterator() -> Self { self }
119+
120+
public mutating func next() async throws -> Element? {
121+
return try await socket.receive()
122+
}
123+
124+
public func nextMessage(atMost length: Int) async throws -> Element? {
125+
return try await socket.receive(atMost: length)
126+
}
127+
}

FlyingSocks/Sources/AsyncSocket.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ package extension AsyncSocket {
186186
}
187187
}
188188

189-
private extension AsyncSocketPool {
189+
extension AsyncSocketPool {
190190

191191
func loopUntilReady<T>(for events: Socket.Events, on socket: Socket, body: () throws -> T) async throws -> T {
192192
var result: T?
@@ -237,7 +237,7 @@ public struct AsyncSocketSequence: AsyncSequence, AsyncIteratorProtocol, Sendabl
237237
}
238238
}
239239

240-
private actor ClientPoolLoader {
240+
actor ClientPoolLoader {
241241
static let shared = ClientPoolLoader()
242242

243243
private let pool: some AsyncSocketPool = SocketPool.make()

FlyingSocks/Sources/Socket+Android.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ extension Socket.FileDescriptor {
4545

4646
extension Socket {
4747
static let stream = Int32(SOCK_STREAM)
48+
static let datagram = Int32(SOCK_DGRAM)
4849
static let in_addr_any = Android.in_addr(s_addr: Android.in_addr_t(0))
4950

5051
static func makeAddressINET(port: UInt16) -> Android.sockaddr_in {
@@ -175,6 +176,14 @@ extension Socket {
175176
static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Android.pollfd {
176177
Android.pollfd(fd: fd, events: events, revents: revents)
177178
}
179+
180+
static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer<sockaddr>!, _ len: UnsafeMutablePointer<socklen_t>!) -> Int {
181+
Android.recvfrom(fd, buffer, nbyte, flags, addr, len)
182+
}
183+
184+
static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
185+
Android.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
186+
}
178187
}
179188

180189
#endif

FlyingSocks/Sources/Socket+Darwin.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ extension Socket.FileDescriptor {
4242

4343
extension Socket {
4444
static let stream = Int32(SOCK_STREAM)
45+
static let datagram = Int32(SOCK_DGRAM)
4546
static let in_addr_any = Darwin.in_addr(s_addr: Darwin.in_addr_t(0))
4647

4748
static func makeAddressINET(port: UInt16) -> Darwin.sockaddr_in {
@@ -176,6 +177,14 @@ extension Socket {
176177
static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Darwin.pollfd {
177178
Darwin.pollfd(fd: fd, events: events, revents: revents)
178179
}
180+
181+
static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer<sockaddr>!, _ len: UnsafeMutablePointer<socklen_t>!) -> Int {
182+
Darwin.recvfrom(fd, buffer, nbyte, flags, addr, len)
183+
}
184+
185+
static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
186+
Darwin.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
187+
}
179188
}
180189

181190
#endif

FlyingSocks/Sources/Socket+Glibc.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ extension Socket.FileDescriptor {
4242

4343
extension Socket {
4444
static let stream = Int32(SOCK_STREAM.rawValue)
45+
static let datagram = Int32(SOCK_DGRAM)
4546
static let in_addr_any = Glibc.in_addr(s_addr: Glibc.in_addr_t(0))
4647

4748
static func makeAddressINET(port: UInt16) -> Glibc.sockaddr_in {
@@ -172,6 +173,14 @@ extension Socket {
172173
static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Glibc.pollfd {
173174
Glibc.pollfd(fd: fd, events: events, revents: revents)
174175
}
176+
177+
static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer<sockaddr>!, _ len: UnsafeMutablePointer<socklen_t>!) -> Int {
178+
Glibc.recvfrom(fd, buffer, nbyte, flags, addr, len)
179+
}
180+
181+
static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
182+
Glibc.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
183+
}
175184
}
176185

177186
#endif

FlyingSocks/Sources/Socket+Musl.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ extension Socket.FileDescriptor {
4242

4343
extension Socket {
4444
static let stream = Int32(SOCK_STREAM)
45+
static let datagram = Int32(SOCK_DGRAM)
4546
static let in_addr_any = Musl.in_addr(s_addr: Musl.in_addr_t(0))
4647

4748
static func makeAddressINET(port: UInt16) -> Musl.sockaddr_in {
@@ -172,6 +173,14 @@ extension Socket {
172173
static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Musl.pollfd {
173174
Musl.pollfd(fd: fd, events: events, revents: revents)
174175
}
176+
177+
static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer<sockaddr>!, _ len: UnsafeMutablePointer<socklen_t>!) -> Int {
178+
Musl.recvfrom(fd, buffer, nbyte, flags, addr, len)
179+
}
180+
181+
static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
182+
Musl.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
183+
}
175184
}
176185

177186
#endif

FlyingSocks/Sources/Socket+WinSock2.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ extension Socket.FileDescriptor {
5252

5353
extension Socket {
5454
static let stream = Int32(SOCK_STREAM)
55+
static let datagram = Int32(SOCK_DGRAM)
5556
static let in_addr_any = WinSDK.in_addr()
5657

5758
static func makeAddressINET(port: UInt16) -> WinSDK.sockaddr_in {
@@ -184,6 +185,14 @@ extension Socket {
184185
static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> WinSDK.WSAPOLLFD {
185186
WinSDK.WSAPOLLFD(fd: fd, events: events, revents: revents)
186187
}
188+
189+
static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer<sockaddr>!, _ len: UnsafeMutablePointer<socklen_t>!) -> Int {
190+
WinSDK.recvfrom(fd, buffer, nbyte, flags, addr, len)
191+
}
192+
193+
static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
194+
WinSDK.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
195+
}
187196
}
188197

189198
#endif

FlyingSocks/Sources/Socket.swift

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,57 @@ public struct Socket: Sendable, Hashable {
237237
return sent
238238
}
239239

240+
public func receive<A: SocketAddress>(length: Int) throws -> (A, [UInt8]) {
241+
var address: A?
242+
let bytes = try [UInt8](unsafeUninitializedCapacity: length) { buffer, count in
243+
(address, count) = try receive(into: buffer.baseAddress!, length: length)
244+
}
245+
return (address!, bytes)
246+
}
247+
248+
private func receive<A: SocketAddress>(into buffer: UnsafeMutablePointer<UInt8>, length: Int) throws -> (A, Int) {
249+
var addr = sockaddr_storage()
250+
var size = socklen_t(MemoryLayout<sockaddr_storage>.size)
251+
let count = withUnsafeMutablePointer(to: &addr) {
252+
$0.withMemoryRebound(to: sockaddr.self, capacity: 1) {
253+
Socket.recvfrom(file.rawValue, buffer, length, 0, $0, &size)
254+
}
255+
}
256+
guard count > 0 else {
257+
if errno == EWOULDBLOCK {
258+
throw SocketError.blocked
259+
} else if errno == EBADF || count == 0 {
260+
throw SocketError.disconnected
261+
} else {
262+
throw SocketError.makeFailed("RecvFrom")
263+
}
264+
}
265+
return (try A.make(from: addr), count)
266+
}
267+
268+
public func send(_ bytes: [UInt8], to address: some SocketAddress) throws -> Int {
269+
try bytes.withUnsafeBytes { buffer in
270+
try send(buffer.baseAddress!, length: bytes.count, to: address)
271+
}
272+
}
273+
274+
private func send<A: SocketAddress>(_ pointer: UnsafeRawPointer, length: Int, to address: A) throws -> Int {
275+
var addr = address
276+
let sent = withUnsafePointer(to: &addr) {
277+
$0.withMemoryRebound(to: sockaddr.self, capacity: 1) {
278+
Socket.sendto(file.rawValue, pointer, length, 0, $0, socklen_t(MemoryLayout<A>.size))
279+
}
280+
}
281+
guard sent >= 0 || errno == EISCONN else {
282+
if errno == EINPROGRESS {
283+
throw SocketError.blocked
284+
} else {
285+
throw SocketError.makeFailed("SendTo")
286+
}
287+
}
288+
return sent
289+
}
290+
240291
public func close() throws {
241292
if Socket.close(file.rawValue) == -1 {
242293
throw SocketError.makeFailed("Close")

0 commit comments

Comments
 (0)