diff --git a/templates/apple/Package.swift.twig b/templates/apple/Package.swift.twig index 132ae387f..a8cd6504e 100644 --- a/templates/apple/Package.swift.twig +++ b/templates/apple/Package.swift.twig @@ -22,7 +22,7 @@ let package = Package( ), ], dependencies: [ - .package(url: "https://github.com/swift-server/async-http-client.git", from: "1.9.0"), + .package(url: "https://github.com/swift-server/async-http-client.git", from: "1.17.0"), .package(url: "https://github.com/apple/swift-nio.git", from: "2.32.0"), ], targets: [ diff --git a/templates/swift/Sources/Models/RealtimeModels.swift.twig b/templates/swift/Sources/Models/RealtimeModels.swift.twig index 3a9c3e298..5129b4cd8 100644 --- a/templates/swift/Sources/Models/RealtimeModels.swift.twig +++ b/templates/swift/Sources/Models/RealtimeModels.swift.twig @@ -1,11 +1,15 @@ import Foundation public class RealtimeSubscription { - public var close: () -> Void + private var close: () async throws -> Void - init(close: @escaping () -> Void) { + init(close: @escaping () async throws-> Void) { self.close = close } + + public func close() async throws { + try await self.close() + } } public class RealtimeCallback { @@ -14,7 +18,7 @@ public class RealtimeCallback { init( for channels: Set, - and callback: @escaping (RealtimeResponseEvent) -> Void + with callback: @escaping (RealtimeResponseEvent) -> Void ) { self.channels = channels self.callback = callback diff --git a/templates/swift/Sources/Services/Realtime.swift.twig b/templates/swift/Sources/Services/Realtime.swift.twig index 543e1ffc8..7838fc9fa 100644 --- a/templates/swift/Sources/Services/Realtime.swift.twig +++ b/templates/swift/Sources/Services/Realtime.swift.twig @@ -7,24 +7,23 @@ open class Realtime : Service { private let TYPE_ERROR = "error" private let TYPE_EVENT = "event" - private let DEBOUNCE_MILLIS = 1 + private let DEBOUNCE_NANOS = 1_000_000 private var socketClient: WebSocketClient? = nil private var activeChannels = Set() private var activeSubscriptions = [Int: RealtimeCallback]() let connectSync = DispatchQueue(label: "ConnectSync") - let callbackSync = DispatchQueue(label: "CallbackSync") private var subCallDepth = 0 private var reconnectAttempts = 0 private var subscriptionsCounter = 0 private var reconnect = true - private func createSocket() { + private func createSocket() async throws { guard activeChannels.count > 0 else { reconnect = false - closeSocket() + try await closeSocket() return } @@ -38,17 +37,31 @@ open class Realtime : Service { if (socketClient != nil) { reconnect = false - closeSocket() - } else { - socketClient = WebSocketClient(url, tlsEnabled: !client.selfSigned, delegate: self)! + try await closeSocket() } - try! socketClient?.connect() + socketClient = WebSocketClient( + url, + tlsEnabled: !client.selfSigned, + delegate: self + ) + + try await socketClient?.connect() } - private func closeSocket() { - socketClient?.close() - //socket?.close(RealtimeCode.POLICY_VIOLATION.value, null) + private func closeSocket() async throws { + guard let client = socketClient, + let group = client.threadGroup else { + return + } + + if (client.isConnected) { + let promise = group.any().makePromise(of: Void.self) + client.close(promise: promise) + try await promise.futureResult.get() + } + + try await group.shutdownGracefully() } private func getTimeout() -> Int { @@ -63,8 +76,8 @@ open class Realtime : Service { public func subscribe( channel: String, callback: @escaping (RealtimeResponseEvent) -> Void - ) -> RealtimeSubscription { - return subscribe( + ) async throws -> RealtimeSubscription { + return try await subscribe( channels: [channel], payloadType: String.self, callback: callback @@ -74,8 +87,8 @@ open class Realtime : Service { public func subscribe( channels: Set, callback: @escaping (RealtimeResponseEvent) -> Void - ) -> RealtimeSubscription { - return subscribe( + ) async throws -> RealtimeSubscription { + return try await subscribe( channels: channels, payloadType: String.self, callback: callback @@ -86,8 +99,8 @@ open class Realtime : Service { channel: String, payloadType: T.Type, callback: @escaping (RealtimeResponseEvent) -> Void - ) -> RealtimeSubscription { - return subscribe( + ) async throws -> RealtimeSubscription { + return try await subscribe( channels: [channel], payloadType: T.self, callback: callback @@ -98,36 +111,38 @@ open class Realtime : Service { channels: Set, payloadType: T.Type, callback: @escaping (RealtimeResponseEvent) -> Void - ) -> RealtimeSubscription { + ) async throws -> RealtimeSubscription { subscriptionsCounter += 1 - let counter = subscriptionsCounter + + let count = subscriptionsCounter channels.forEach { activeChannels.insert($0) } - activeSubscriptions[counter] = RealtimeCallback( + activeSubscriptions[count] = RealtimeCallback( for: Set(channels), - and: callback + with: callback ) connectSync.sync { subCallDepth+=1 } - DispatchQueue.main.asyncAfter(deadline: .now() + .milliseconds(DEBOUNCE_MILLIS)) { - if (self.subCallDepth == 1) { - self.createSocket() - } - self.connectSync.sync { - self.subCallDepth-=1 - } + try await Task.sleep(nanoseconds: UInt64(DEBOUNCE_NANOS)) + + if self.subCallDepth == 1 { + try await self.createSocket() + } + + connectSync.sync { + self.subCallDepth -= 1 } return RealtimeSubscription { - self.activeSubscriptions[counter] = nil + self.activeSubscriptions[count] = nil self.cleanUp(channels: channels) - self.createSocket() + try await self.createSocket() } } @@ -163,7 +178,7 @@ extension Realtime: WebSocketClientDelegate { } } - public func onClose(channel: Channel, data: Data) { + public func onClose(channel: Channel, data: Data) async throws { if (!reconnect) { reconnect = true return @@ -173,10 +188,11 @@ extension Realtime: WebSocketClientDelegate { print("Realtime disconnected. Re-connecting in \(timeout / 1000) seconds.") - DispatchQueue.main.asyncAfter(deadline: .now() + .milliseconds(timeout)) { - self.reconnectAttempts += 1 - self.createSocket() - } + try await Task.sleep(nanoseconds: UInt64(timeout * 1_000_000)) + + self.reconnectAttempts += 1 + + try await self.createSocket() } public func onError(error: Swift.Error?, status: HTTPResponseStatus?) { @@ -188,16 +204,10 @@ extension Realtime: WebSocketClientDelegate { } func handleResponseEvent(from json: [String: Any]) { - guard let data = json["data"] as? [String: Any] else { - return - } - guard let channels = data["channels"] as? Array else { - return - } - guard let events = data["events"] as? Array else { - return - } - guard let payload = data["payload"] as? [String: Any] else { + guard let data = json["data"] as? [String: Any], + let channels = data["channels"] as? [String], + let events = data["events"] as? [String], + let payload = data["payload"] as? [String: Any] else { return } guard channels.contains(where: { channel in diff --git a/templates/swift/Sources/WebSockets/WebSocketClient.swift.twig b/templates/swift/Sources/WebSockets/WebSocketClient.swift.twig index 7e8e26fa1..72322b784 100644 --- a/templates/swift/Sources/WebSockets/WebSocketClient.swift.twig +++ b/templates/swift/Sources/WebSockets/WebSocketClient.swift.twig @@ -7,6 +7,8 @@ import NIOFoundationCompat import NIOSSL public let WEBSOCKET_LOCKER_QUEUE = "SyncLocker" +public let WEBSOCKET_THREAD_QUEUE = "ThreadLocker" +public let WEBSOCKET_CHANNEL_QUEUE = "ChannelLocker" /// Creates and manages connections to a WebSocket server. /// @@ -20,16 +22,35 @@ public class WebSocketClient { let query: String let headers: HTTPHeaders let frameKey: String - + public private(set) var maxFrameSize: Int - - var channel: Channel? = nil + var tlsEnabled: Bool = false var closeSent: Bool = false - let locker = DispatchQueue(label: WEBSOCKET_LOCKER_QUEUE, qos: .background) + private let locker = DispatchQueue(label: WEBSOCKET_LOCKER_QUEUE, qos: .background) + private let channelQueue = DispatchQueue(label: WEBSOCKET_CHANNEL_QUEUE) + private let threadGroupQueue = DispatchQueue(label: WEBSOCKET_THREAD_QUEUE) - var threadGroup: MultiThreadedEventLoopGroup? = nil + var channel: Channel? { + get { + return channelQueue.sync { _channel } + } + set { + channelQueue.sync { _channel = newValue } + } + } + private var _channel: Channel? = nil + + var threadGroup: MultiThreadedEventLoopGroup? { + get { + return threadGroupQueue.sync { _threadGroup } + } + set { + threadGroupQueue.sync { _threadGroup = newValue } + } + } + private var _threadGroup: MultiThreadedEventLoopGroup? weak var delegate: WebSocketClientDelegate? = nil @@ -216,45 +237,45 @@ public class WebSocketClient { self.threadGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) } } - - deinit { - try! threadGroup!.syncShutdownGracefully() - } // MARK: - Open connection - + /// Open a connection to the configured host and attempt to upgrade the connection to a WebSocket. If successful the `onOpen` callback will fire, otherwise a connection error will be thrown from here. - public func connect() throws { + public func connect() async throws { let socketOptions = ChannelOptions.socket( SocketOptionLevel(SOL_SOCKET), SO_REUSEPORT ) - while(threadGroup == nil) {} - + while(threadGroup == nil) { + try? await Task.sleep(nanoseconds: 10_000_000) + } + let bootstrap = ClientBootstrap(group: threadGroup!) .channelOption(socketOptions, value: 1) - .channelInitializer(self.openChannel) - - _ = try bootstrap - .connect(host: self.host, port: self.port) - .wait() + .channelInitializer { + self.openChannel(channel: $0) + } + + _ = try await bootstrap + .connect(host: self.host,port: self.port) + .get() } private func openChannel(channel: Channel) -> EventLoopFuture { let httpHandler = HTTPHandler(client: self, headers: headers) - + let basicUpgrader = NIOWebSocketClientUpgrader( requestKey: self.frameKey, upgradePipelineHandler: { channel, response in self.upgradePipelineHandler(channel: channel, response: response) } ) - + let config: NIOHTTPClientUpgradeConfiguration = (upgraders: [basicUpgrader], completionHandler: { context in context.channel.pipeline.removeHandler(httpHandler, promise: nil) }) - + return channel.pipeline.addHTTPClientHandlers(withClientUpgrade: config).flatMap { _ in return channel.pipeline.addHandler(httpHandler).flatMap { _ in if self.tlsEnabled { @@ -269,39 +290,43 @@ public class WebSocketClient { } } - @Sendable private func upgradePipelineHandler(channel: Channel, response: HTTPResponseHead) -> EventLoopFuture { + private func upgradePipelineHandler(channel: Channel, response: HTTPResponseHead) -> EventLoopFuture { let handler = MessageHandler(client: self) - + if response.status == .switchingProtocols { self.channel = channel } - + return channel.pipeline.addHandler(handler) } // MARK: - Close connection - + /// Closes the connection /// /// - parameters: /// - data: Close frame payload - public func close(data: Data = Data()) { + public func close( + data: Data = Data(), + promise: EventLoopPromise? = nil + ) { closeSent = true - + var buffer = ByteBufferAllocator() .buffer(capacity: data.count) - + buffer.writeBytes(data) - + send( data: buffer, opcode: .connectionClose, - finalFrame: true + finalFrame: true, + promise: promise ) } - + // MARK: - Send data - + /// Sends binary-formatted data to the connected server in multiple frames. /// /// - parameters: @@ -311,21 +336,23 @@ public class WebSocketClient { public func send( data: Data, opcode: WebSocketOpcode, - finalFrame: Bool = true + finalFrame: Bool = true, + promise: EventLoopPromise? = nil ) { var buffer = ByteBufferAllocator() .buffer(capacity: data.count) - + buffer.writeBytes(data) - + if opcode == .connectionClose { self.closeSent = true } - + send( data: buffer, opcode: opcode, - finalFrame: finalFrame + finalFrame: finalFrame, + promise: promise ) } @@ -338,21 +365,23 @@ public class WebSocketClient { public func send( text: String, opcode: WebSocketOpcode = .text, - finalFrame: Bool = true + finalFrame: Bool = true, + promise: EventLoopPromise? = nil ) { var buffer = ByteBufferAllocator() .buffer(capacity: text.count) - + buffer.writeString(text) - + send( data: buffer, opcode: opcode, - finalFrame: finalFrame + finalFrame: finalFrame, + promise: promise ) } - + /// Sends the JSON representation of the given model to the connected server in multiple frames. /// /// - parameters: @@ -362,7 +391,8 @@ public class WebSocketClient { public func send( model: T, opcode: WebSocketOpcode = .text, - finalFrame: Bool = true + finalFrame: Bool = true, + promise: EventLoopPromise? = nil ) { let jsonEncoder = JSONEncoder() do { @@ -370,13 +400,14 @@ public class WebSocketClient { let string = String(data: jsonData, encoding: .utf8)! var buffer = ByteBufferAllocator() .buffer(capacity: string.count) - + buffer.writeString(string) - + send( data: buffer, opcode: opcode, - finalFrame: finalFrame + finalFrame: finalFrame, + promise: promise ) } catch let error { print(error) @@ -392,7 +423,8 @@ public class WebSocketClient { public func send( data: ByteBuffer, opcode: WebSocketOpcode, - finalFrame: Bool + finalFrame: Bool, + promise: EventLoopPromise? = nil ) { let frame = WebSocketFrame( fin: finalFrame, @@ -400,17 +432,19 @@ public class WebSocketClient { maskKey: nil, data: data ) + guard let channel = channel else { return } + if finalFrame { - channel.writeAndFlush(frame, promise: nil) + channel.writeAndFlush(frame, promise: promise) } else { - channel.write(frame, promise: nil) + channel.write(frame, promise: promise) } - + if opcode == .connectionClose { - channel.close(mode: .all, promise: nil) + channel.close(mode: .all, promise: promise) } } } diff --git a/templates/swift/example-swiftui/Shared/ExampleView.swift b/templates/swift/example-swiftui/Shared/ExampleView.swift index 792f3b308..0c3d2de7a 100644 --- a/templates/swift/example-swiftui/Shared/ExampleView.swift +++ b/templates/swift/example-swiftui/Shared/ExampleView.swift @@ -20,6 +20,9 @@ struct ExampleView: View { TextField("", text: $viewModel.response, axis: .vertical) .padding() + TextField("", text: $viewModel.response2, axis: .vertical) + .padding() + Button("Login") { Task { await viewModel.login() } } @@ -41,7 +44,7 @@ struct ExampleView: View { } Button("Subscribe") { - viewModel.subscribe() + Task { await viewModel.subscribe() } } } #if os(macOS) diff --git a/templates/swift/example-swiftui/Shared/ExampleViewModel.swift b/templates/swift/example-swiftui/Shared/ExampleViewModel.swift index 817a9064f..3c6de090c 100644 --- a/templates/swift/example-swiftui/Shared/ExampleViewModel.swift +++ b/templates/swift/example-swiftui/Shared/ExampleViewModel.swift @@ -15,8 +15,10 @@ extension ExampleView { @Published public var fileId: String = "test" @Published public var databaseId: String = "test" @Published public var collectionId: String = "test" + @Published public var collectionId2: String = "test2" @Published public var isShowPhotoLibrary = false @Published public var response: String = "" + @Published public var response2: String = "" func register() async { do { @@ -127,13 +129,25 @@ extension ExampleView { } } } - - func subscribe() { - _ = realtime.subscribe(channels: ["databases.\(databaseId).collections.\(collectionId).documents"]) { event in + + func subscribe() async { + let sub1 = try? await realtime.subscribe(channels: ["databases.\(databaseId).collections.\(collectionId).documents"]) { event in DispatchQueue.main.async { self.response = String(describing: event.payload!) } } + + try? await Task.sleep(nanoseconds: UInt64(500_000_000)) + + _ = try? await realtime.subscribe(channels: ["databases.\(databaseId).collections.\(collectionId2).documents"]) { event in + DispatchQueue.main.async { + self.response2 = String(describing: event.payload!) + } + } + + try? await Task.sleep(nanoseconds: UInt64(500_000_000)) + + try? await sub1?.close() } } } diff --git a/tests/languages/apple/Tests.swift b/tests/languages/apple/Tests.swift index 86dde4c46..b7c09c7e4 100644 --- a/tests/languages/apple/Tests.swift +++ b/tests/languages/apple/Tests.swift @@ -34,7 +34,7 @@ class Tests: XCTestCase { let expectation = XCTestExpectation(description: "realtime server") - realtime.subscribe(channels: ["tests"]) { message in + try await realtime.subscribe(channels: ["tests"]) { message in realtimeResponse = message.payload!["response"] as! String expectation.fulfill() }