diff --git a/ios/PacketTunnelCore/Pinger/TunnelPinger.swift b/ios/PacketTunnelCore/Pinger/TunnelPinger.swift index 9b9162f2e516..47c5d8734ae0 100644 --- a/ios/PacketTunnelCore/Pinger/TunnelPinger.swift +++ b/ios/PacketTunnelCore/Pinger/TunnelPinger.swift @@ -14,24 +14,12 @@ import WireGuardKit public final class TunnelPinger: PingerProtocol { private var sequenceNumber: UInt16 = 0 - private let stateLock = NSRecursiveLock() + private let stateLock = NSLock() private let pingReceiveQueue: DispatchQueue private let replyQueue: DispatchQueue private var destAddress: IPv4Address? - private var _onReply: ((PingerReply) -> Void)? - public var onReply: ((PingerReply) -> Void)? { - get { - stateLock.withLock { - return _onReply - } - } - set { - stateLock.withLock { - _onReply = newValue - } - } - } - + /// Always accessed from the `replyQueue` and is assigned once, on the main thread of the PacketTunnel. It is thread safe. + public var onReply: ((PingerReply) -> Void)? private var pingProvider: ICMPPingProvider private let logger: Logger @@ -49,18 +37,26 @@ public final class TunnelPinger: PingerProtocol { public func openSocket(bindTo interfaceName: String?, destAddress: IPv4Address) throws { try pingProvider.openICMP(address: destAddress) - self.destAddress = destAddress + stateLock.withLock { + self.destAddress = destAddress + } pingReceiveQueue.async { [weak self] in while let self { do { let seq = try pingProvider.receiveICMP() replyQueue.async { [weak self] in - self?.onReply?(PingerReply.success(destAddress, UInt16(seq))) + self?.stateLock.withLock { + self?.onReply?(PingerReply.success(destAddress, UInt16(seq))) + } } } catch { replyQueue.async { [weak self] in - self?.onReply?(PingerReply.parseError(error)) + self?.stateLock.withLock { + if self?.destAddress != nil { + self?.onReply?(PingerReply.parseError(error)) + } + } } return } @@ -69,14 +65,18 @@ public final class TunnelPinger: PingerProtocol { } public func closeSocket() { - pingProvider.closeICMP() - self.destAddress = nil + stateLock.withLock { + self.destAddress = nil + pingProvider.closeICMP() + } } public func send() throws -> PingerSendResult { let sequenceNumber = nextSequenceNumber() - guard let destAddress else { throw WireGuardAdapterError.invalidState } + stateLock.lock() + defer { stateLock.unlock() } + guard destAddress != nil else { throw WireGuardAdapterError.invalidState } // NOTE: we cheat here by returning the destination address we were passed, rather than parsing it from the packet on the other side of the FFI boundary. try pingProvider.sendICMPPing(seqNumber: sequenceNumber)