Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(predictions): add web socket retry for clock skew #3816

Merged
merged 2 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -131,17 +131,17 @@ class AWSTranscribeStreamingAdapter: AWSTranscribeStreamingBehavior {
continuation.yield(transcribedPayload)
let isPartial = transcribedPayload.transcript?.results?.map(\.isPartial) ?? []
let shouldContinue = isPartial.allSatisfy { $0 }
return shouldContinue
return shouldContinue ? .continueToReceive : .stopAndInvalidateSession
phantumcode marked this conversation as resolved.
Show resolved Hide resolved
} catch {
return true
return .continueToReceive
}
case .success(.string):
return true
return .continueToReceive
case .failure(let error):
continuation.finish(throwing: error)
return false
return .stopAndInvalidateSession
@unknown default:
return true
return .continueToReceive
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ public final class FaceLivenessSession: LivenessService {
let baseURL: URL
var serverEventListeners: [LivenessEventKind.Server: (FaceLivenessSession.SessionConfiguration) -> Void] = [:]
var onComplete: (ServerDisconnection) -> Void = { _ in }
var serverDate: Date?
var savedURLForReconnect: URL?
var connectingState: ConnectingState = .normal

enum ConnectingState {
case normal
case reconnect
}

private let livenessServiceDispatchQueue = DispatchQueue(
label: "com.amazon.aws.amplify.liveness.service",
Expand All @@ -35,12 +43,16 @@ public final class FaceLivenessSession: LivenessService {
self.websocket = websocket

websocket.onMessageReceived { [weak self] result in
self?.receive(result: result) ?? false
self?.receive(result: result) ?? .stopAndInvalidateSession
}

websocket.onSocketClosed { [weak self] closeCode in
self?.onComplete(.unexpectedClosure(closeCode))
}

websocket.onServerDateReceived { [weak self] serverDate in
self?.serverDate = serverDate
}
}

public var onServiceException: (FaceLivenessSessionError) -> Void = { _ in }
Expand Down Expand Up @@ -75,6 +87,7 @@ public final class FaceLivenessSession: LivenessService {
guard let url = components?.url
else { throw FaceLivenessSessionError.invalidURL }

savedURLForReconnect = url
let signedConnectionURL = signer.sign(url: url)
websocket.open(url: signedConnectionURL)
}
Expand All @@ -93,17 +106,22 @@ public final class FaceLivenessSession: LivenessService {
]
)

let eventDate = eventDate()
let dateForSigning: Date
if let serverDate = serverDate {
dateForSigning = serverDate
} else {
dateForSigning = eventDate()
}

let signedPayload = self.signer.signWithPreviousSignature(
payload: encodedPayload,
dateHeader: (key: ":date", value: eventDate)
dateHeader: (key: ":date", value: dateForSigning)
)

let encodedEvent = self.eventStreamEncoder.encode(
payload: encodedPayload,
headers: [
":date": .timestamp(eventDate),
":date": .timestamp(dateForSigning),
":chunk-signature": .data(signedPayload)
]
)
Expand All @@ -115,7 +133,7 @@ public final class FaceLivenessSession: LivenessService {
}
}

private func fallbackDecoding(_ message: EventStream.Message) -> Bool {
private func fallbackDecoding(_ message: EventStream.Message) -> WebSocketSession.WebSocketMessageResult {
// We only care about two events above.
// Just in case the header value changes (it shouldn't)
// We'll try to decode each of these events
Expand All @@ -124,12 +142,12 @@ public final class FaceLivenessSession: LivenessService {
self.serverEventListeners[.challenge]?(sessionConfiguration)
} else if (try? JSONDecoder().decode(DisconnectEvent.self, from: message.payload)) != nil {
onComplete(.disconnectionEvent)
return false
return .stopAndInvalidateSession
}
return true
return .continueToReceive
}

private func receive(result: Result<URLSessionWebSocketTask.Message, Error>) -> Bool {
private func receive(result: Result<URLSessionWebSocketTask.Message, Error>) -> WebSocketSession.WebSocketMessageResult {
switch result {
case .success(.data(let data)):
do {
Expand All @@ -145,28 +163,41 @@ public final class FaceLivenessSession: LivenessService {
)
let sessionConfiguration = sessionConfiguration(from: payload)
serverEventListeners[.challenge]?(sessionConfiguration)
return true
return .continueToReceive
case .disconnect:
// :event-type DisconnectionEvent
onComplete(.disconnectionEvent)
return false
return .stopAndInvalidateSession
default:
return true
return .continueToReceive
}
} else if let exceptionType = message.headers.first(where: { $0.name == ":exception-type" }) {
let exceptionEvent = LivenessEventKind.Exception(rawValue: exceptionType.value)
onServiceException(.init(event: exceptionEvent))
return false
Amplify.log.verbose("\(#function): Received exception: \(exceptionEvent)")
guard exceptionEvent == .invalidSignature,
connectingState == .normal,
let savedURLForReconnect = savedURLForReconnect,
let serverDate = serverDate else {
onServiceException(.init(event: exceptionEvent))
return .stopAndInvalidateSession
}

connectingState = .reconnect
let signedConnectionURL = signer.sign(
url: savedURLForReconnect,
date: { serverDate }
)
return .invalidateSessionAndRetry(url: signedConnectionURL)
} else {
return fallbackDecoding(message)
}
} catch {
return false
return .stopAndInvalidateSession
}
case .success:
return true
return .continueToReceive
case .failure:
return false
return .stopAndInvalidateSession
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
//

import Foundation
import Amplify

final class WebSocketSession {
private let urlSessionWebSocketDelegate: Delegate
private let session: URLSession
private var task: URLSessionWebSocketTask?
private var receiveMessage: ((Result<URLSessionWebSocketTask.Message, Error>) -> Bool)?
private var receiveMessage: ((Result<URLSessionWebSocketTask.Message, Error>) -> WebSocketMessageResult)?
private var onSocketClosed: ((URLSessionWebSocketTask.CloseCode) -> Void)?
private var onServerDateReceived: ((Date?) -> Void)?

init() {
self.urlSessionWebSocketDelegate = Delegate()
Expand All @@ -23,7 +25,7 @@ final class WebSocketSession {
)
}

func onMessageReceived(_ receive: @escaping (Result<URLSessionWebSocketTask.Message, Error>) -> Bool) {
func onMessageReceived(_ receive: @escaping (Result<URLSessionWebSocketTask.Message, Error>) -> WebSocketMessageResult) {
self.receiveMessage = receive
}

Expand All @@ -34,25 +36,32 @@ final class WebSocketSession {
func onSocketOpened(_ onOpen: @escaping () -> Void) {
urlSessionWebSocketDelegate.onOpen = onOpen
}

func onServerDateReceived(_ onServerDateReceived: @escaping (Date?) -> Void) {
urlSessionWebSocketDelegate.onServerDateReceived = onServerDateReceived
}

func receive(shouldContinue: Bool) {
guard shouldContinue else {
func receive(result: WebSocketMessageResult) {
switch result {
case .continueToReceive:
task?.receive(completionHandler: { [weak self] result in
if let webSocketResult = self?.receiveMessage?(result) {
self?.receive(result: webSocketResult)
}
})
case .stopAndInvalidateSession:
session.finishTasksAndInvalidate()
case .invalidateSessionAndRetry(let url):
session.finishTasksAndInvalidate()
return
open(url: url)
}

task?.receive(completionHandler: { [weak self] result in
if let shouldContinue = self?.receiveMessage?(result) {
self?.receive(shouldContinue: shouldContinue)
}
})
}

func open(url: URL) {
var request = URLRequest(url: url)
request.setValue("no-store", forHTTPHeaderField: "Cache-Control")
task = session.webSocketTask(with: request)
receive(shouldContinue: true)
receive(result: .continueToReceive)
task?.resume()
}

Expand All @@ -77,10 +86,12 @@ final class WebSocketSession {
)
}

final class Delegate: NSObject, URLSessionWebSocketDelegate {
final class Delegate: NSObject, URLSessionWebSocketDelegate, URLSessionTaskDelegate {
var onClose: (URLSessionWebSocketTask.CloseCode) -> Void = { _ in }
var onOpen: () -> Void = {}
var onServerDateReceived: (Date?) -> Void = { _ in }

// MARK: - URLSessionWebSocketDelegate methods
func urlSession(
_ session: URLSession,
webSocketTask: URLSessionWebSocketTask,
Expand All @@ -97,5 +108,34 @@ final class WebSocketSession {
) {
onClose(closeCode)
}

// MARK: - URLSessionTaskDelegate methods
func urlSession(_ session: URLSession,
task: URLSessionTask,
didFinishCollecting metrics: URLSessionTaskMetrics
) {
guard let httpResponse = metrics.transactionMetrics.first?.response as? HTTPURLResponse,
let dateString = httpResponse.value(forHTTPHeaderField: "Date") else {
Amplify.log.verbose("\(#function): Couldn't find Date header in URLSession metrics")
onServerDateReceived(nil)
return
}

let dateFormatter = DateFormatter()
dateFormatter.dateFormat = "EEE, d MMM yyyy HH:mm:ss z"
guard let serverDate = dateFormatter.date(from: dateString) else {
Amplify.log.verbose("\(#function): Error parsing Date header in expected format")
thisisabhash marked this conversation as resolved.
Show resolved Hide resolved
onServerDateReceived(nil)
return
}

onServerDateReceived(serverDate)
}
}

enum WebSocketMessageResult {
case continueToReceive
case stopAndInvalidateSession
case invalidateSessionAndRetry(url: URL)
}
}
Loading