Skip to content

Commit

Permalink
feat: pass URLRequest instead of URL to interfaces (#110)
Browse files Browse the repository at this point in the history
* feat: pass URLRequest instead of URL to interfaces

* return error on missing URL

* fix integration tests

* fix RTConnectionProvider URLRequest should not be updated on connect
  • Loading branch information
lawmicha authored Jan 24, 2023
1 parent 779f548 commit 0d22315
Show file tree
Hide file tree
Showing 16 changed files with 100 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public class RealtimeConnectionProvider: ConnectionProvider {
/// message before we consider it stale and force a disconnect
static let staleConnectionTimeout: TimeInterval = 5 * 60

private let url: URL
private let urlRequest: URLRequest
var listeners: [String: ConnectionProviderCallback]

let websocket: AppSyncWebsocketProvider
Expand Down Expand Up @@ -58,12 +58,12 @@ public class RealtimeConnectionProvider: ConnectionProvider {
return iLimitExceededSubject as! PassthroughSubject<ConnectionProviderError, Never> // swiftlint:disable:this force_cast line_length
}

public convenience init(for url: URL, websocket: AppSyncWebsocketProvider) {
self.init(url: url, websocket: websocket)
public convenience init(for urlRequest: URLRequest, websocket: AppSyncWebsocketProvider) {
self.init(urlRequest: urlRequest, websocket: websocket)
}

init(
url: URL,
urlRequest: URLRequest,
websocket: AppSyncWebsocketProvider,
connectionQueue: DispatchQueue = DispatchQueue(
label: "com.amazonaws.AppSyncRealTimeConnectionProvider.serialQueue"
Expand All @@ -73,7 +73,7 @@ public class RealtimeConnectionProvider: ConnectionProvider {
),
connectivityMonitor: ConnectivityMonitor = ConnectivityMonitor()
) {
self.url = url
self.urlRequest = urlRequest
self.websocket = websocket
self.listeners = [:]
self.status = .notConnected
Expand Down Expand Up @@ -103,13 +103,25 @@ public class RealtimeConnectionProvider: ConnectionProvider {
self.updateCallback(event: .connection(self.status))
return
}

guard let url = self.urlRequest.url else {
self.updateCallback(event: .error(ConnectionProviderError.unknown(
message: "Missing URL",
payload: nil
)))
return
}
self.status = .inProgress
self.updateCallback(event: .connection(self.status))
let request = AppSyncConnectionRequest(url: self.url)
let signedRequest = self.interceptConnection(request, for: self.url)

let request = AppSyncConnectionRequest(url: url)
let signedRequest = self.interceptConnection(request, for: url)
var urlRequest = self.urlRequest
urlRequest.url = signedRequest.url

DispatchQueue.global().async {
self.websocket.connect(
url: signedRequest.url,
urlRequest: urlRequest,
protocols: ["graphql-ws"],
delegate: self
)
Expand All @@ -123,8 +135,14 @@ public class RealtimeConnectionProvider: ConnectionProvider {
guard let self = self else {
return
}

let signedMessage = self.interceptMessage(message, for: self.url)
guard let url = self.urlRequest.url else {
self.updateCallback(event: .error(ConnectionProviderError.unknown(
message: "Missing URL",
payload: nil
)))
return
}
let signedMessage = self.interceptMessage(message, for: url)
let jsonEncoder = JSONEncoder()
do {
let jsonData = try jsonEncoder.encode(signedMessage)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public class RealtimeConnectionProviderAsync: ConnectionProvider {
/// message before we consider it stale and force a disconnect
static let staleConnectionTimeout: TimeInterval = 5 * 60

let url: URL
let urlRequest: URLRequest
var listeners: [String: ConnectionProviderCallback]

let websocket: AppSyncWebsocketProvider
Expand Down Expand Up @@ -63,15 +63,15 @@ public class RealtimeConnectionProviderAsync: ConnectionProvider {
}

init(
url: URL,
urlRequest: URLRequest,
websocket: AppSyncWebsocketProvider,

serialCallbackQueue: DispatchQueue = DispatchQueue(
label: "com.amazonaws.AppSyncRealTimeConnectionProvider.callbackQueue"
),
connectivityMonitor: ConnectivityMonitor = ConnectivityMonitor()
) {
self.url = url
self.urlRequest = urlRequest
self.websocket = websocket
self.listeners = [:]
self.status = .notConnected
Expand All @@ -84,8 +84,8 @@ public class RealtimeConnectionProviderAsync: ConnectionProvider {
subscribeToLimitExceededThrottle()
}

public convenience init(for url: URL, websocket: AppSyncWebsocketProvider) {
self.init(url: url, websocket: websocket)
public convenience init(for urlRequest: URLRequest, websocket: AppSyncWebsocketProvider) {
self.init(urlRequest: urlRequest, websocket: websocket)
}

// MARK: - ConnectionProvider methods
Expand All @@ -99,13 +99,21 @@ public class RealtimeConnectionProviderAsync: ConnectionProvider {
self.updateCallback(event: .connection(self.status))
return
}
guard let url = self.urlRequest.url else {
self.updateCallback(event: .error(ConnectionProviderError.unknown(
message: "Missing URL",
payload: nil
)))
return
}
self.status = .inProgress
self.updateCallback(event: .connection(self.status))
let request = AppSyncConnectionRequest(url: self.url)

let signedRequest = await self.interceptConnection(request, for: self.url)
let request = AppSyncConnectionRequest(url: url)
let signedRequest = await self.interceptConnection(request, for: url)
var urlRequest = self.urlRequest
urlRequest.url = signedRequest.url
self.websocket.connect(
url: signedRequest.url,
urlRequest: urlRequest,
protocols: ["graphql-ws"],
delegate: self
)
Expand All @@ -117,8 +125,14 @@ public class RealtimeConnectionProviderAsync: ConnectionProvider {
guard let self = self else {
return
}

let signedMessage = await self.interceptMessage(message, for: self.url)
guard let url = self.urlRequest.url else {
self.updateCallback(event: .error(ConnectionProviderError.unknown(
message: "Missing URL",
payload: nil
)))
return
}
let signedMessage = await self.interceptMessage(message, for: url)

let jsonEncoder = JSONEncoder()
do {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import Foundation
public enum ConnectionProviderFactory {

public static func createConnectionProvider(
for url: URL,
for urlRequest: URLRequest,
authInterceptor: AuthInterceptor,
connectionType: SubscriptionConnectionType
) -> ConnectionProvider {
Expand All @@ -20,7 +20,7 @@ public enum ConnectionProviderFactory {
switch connectionType {
case .appSyncRealtime:
let websocketProvider = StarscreamAdapter()
provider = RealtimeConnectionProvider(for: url, websocket: websocketProvider)
provider = RealtimeConnectionProvider(for: urlRequest, websocket: websocketProvider)
}

if let messageInterceptable = provider as? MessageInterceptable {
Expand All @@ -37,7 +37,7 @@ public enum ConnectionProviderFactory {
#if swift(>=5.5.2)
@available(iOS 13.0, macOS 10.15, tvOS 13.0, watchOS 6.0, *)
public static func createConnectionProviderAsync(
for url: URL,
for urlRequest: URLRequest,
authInterceptor: AuthInterceptorAsync,
connectionType: SubscriptionConnectionType
) -> ConnectionProvider {
Expand All @@ -46,7 +46,7 @@ public enum ConnectionProviderFactory {
switch connectionType {
case .appSyncRealtime:
let websocketProvider = StarscreamAdapter()
provider = RealtimeConnectionProviderAsync(for: url, websocket: websocketProvider)
provider = RealtimeConnectionProviderAsync(for: urlRequest, websocket: websocketProvider)
}

if let messageInterceptable = provider as? MessageInterceptableAsync {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public protocol AppSyncWebsocketProvider {
///
/// This is an async call. After the connection is succesfully established, the delegate
/// will receive the callback on `websocketDidConnect(:)`
func connect(url: URL, protocols: [String], delegate: AppSyncWebsocketDelegate?)
func connect(urlRequest: URLRequest, protocols: [String], delegate: AppSyncWebsocketDelegate?)

/// Disconnects the websocket.
func disconnect()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@ public class StarscreamAdapter: AppSyncWebsocketProvider {
self.callbackQueue = callbackQueue
}

public func connect(url: URL, protocols: [String], delegate: AppSyncWebsocketDelegate?) {
public func connect(urlRequest: URLRequest, protocols: [String], delegate: AppSyncWebsocketDelegate?) {
serialQueue.async {
AppSyncLogger.verbose("[StarscreamAdapter] connect. Connecting to url")
var urlRequest = URLRequest(url: url)
var urlRequest = urlRequest

urlRequest.setValue("no-store", forHTTPHeaderField: "Cache-Control")

let protocolHeaderValue = protocols.joined(separator: ", ")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class AppSyncRealTimeClientAsyncFailureTests: AppSyncRealTimeClientTestBase {
subscribeSuccess.expectedFulfillmentCount = 100
let authInterceptor = APIKeyAuthInterceptor(apiKey)
let connectionProvider = ConnectionProviderFactory.createConnectionProviderAsync(
for: url,
for: urlRequest,
authInterceptor: authInterceptor,
connectionType: .appSyncRealtime
)
Expand Down Expand Up @@ -91,7 +91,7 @@ class AppSyncRealTimeClientAsyncFailureTests: AppSyncRealTimeClientTestBase {
subscribeSuccess.expectedFulfillmentCount = 100
let authInterceptor = APIKeyAuthInterceptor(apiKey)
let connectionProvider = ConnectionProviderFactory.createConnectionProviderAsync(
for: url,
for: urlRequest,
authInterceptor: authInterceptor,
connectionType: .appSyncRealtime
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class AppSyncRealTimeClientAsyncIntegrationTests: AppSyncRealTimeClientTestBase
let subscribeSuccess = expectation(description: "subscribe successfully")
let authInterceptor = APIKeyAuthInterceptor(apiKey)
let connectionProvider = ConnectionProviderFactory.createConnectionProviderAsync(
for: url,
for: urlRequest,
authInterceptor: authInterceptor,
connectionType: .appSyncRealtime
)
Expand Down Expand Up @@ -75,7 +75,7 @@ class AppSyncRealTimeClientAsyncIntegrationTests: AppSyncRealTimeClientTestBase

let authInterceptor = APIKeyAuthInterceptor(apiKey)
let connectionProvider = ConnectionProviderFactory.createConnectionProviderAsync(
for: url,
for: urlRequest,
authInterceptor: authInterceptor,
connectionType: .appSyncRealtime
)
Expand Down Expand Up @@ -170,7 +170,7 @@ class AppSyncRealTimeClientAsyncIntegrationTests: AppSyncRealTimeClientTestBase
func testSubscribeUnsubscribeRepeat() {
let authInterceptor = APIKeyAuthInterceptor(apiKey)
let connectionProvider = ConnectionProviderFactory.createConnectionProviderAsync(
for: url,
for: urlRequest,
authInterceptor: authInterceptor,
connectionType: .appSyncRealtime
)
Expand All @@ -197,7 +197,7 @@ class AppSyncRealTimeClientAsyncIntegrationTests: AppSyncRealTimeClientTestBase
func testMultipleThreadsSubscribeUnsubscribe() {
let authInterceptor = APIKeyAuthInterceptor(apiKey)
let connectionProvider = ConnectionProviderFactory.createConnectionProviderAsync(
for: url,
for: urlRequest,
authInterceptor: authInterceptor,
connectionType: .appSyncRealtime
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class AppSyncRealTimeClientFailureTests: AppSyncRealTimeClientTestBase {
subscribeSuccess.expectedFulfillmentCount = 100
let authInterceptor = APIKeyAuthInterceptor(apiKey)
let connectionProvider = ConnectionProviderFactory.createConnectionProvider(
for: url,
for: urlRequest,
authInterceptor: authInterceptor,
connectionType: .appSyncRealtime
)
Expand Down Expand Up @@ -91,7 +91,7 @@ class AppSyncRealTimeClientFailureTests: AppSyncRealTimeClientTestBase {
subscribeSuccess.expectedFulfillmentCount = 100
let authInterceptor = APIKeyAuthInterceptor(apiKey)
let connectionProvider = ConnectionProviderFactory.createConnectionProvider(
for: url,
for: urlRequest,
authInterceptor: authInterceptor,
connectionType: .appSyncRealtime
)
Expand Down Expand Up @@ -171,7 +171,7 @@ class AppSyncRealTimeClientFailureTests: AppSyncRealTimeClientTestBase {
let subscribeFailed = expectation(description: "subscribe failed")
let authInterceptor = APIKeyAuthInterceptor(apiKey)
let connectionProvider = ConnectionProviderFactory.createConnectionProvider(
for: url,
for: urlRequest,
authInterceptor: authInterceptor,
connectionType: .appSyncRealtime
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class AppSyncRealTimeClientIntegrationTests: AppSyncRealTimeClientTestBase {
let subscribeSuccess = expectation(description: "subscribe successfully")
let authInterceptor = APIKeyAuthInterceptor(apiKey)
let connectionProvider = ConnectionProviderFactory.createConnectionProvider(
for: url,
for: urlRequest,
authInterceptor: authInterceptor,
connectionType: .appSyncRealtime
)
Expand Down Expand Up @@ -75,7 +75,7 @@ class AppSyncRealTimeClientIntegrationTests: AppSyncRealTimeClientTestBase {

let authInterceptor = APIKeyAuthInterceptor(apiKey)
let connectionProvider = ConnectionProviderFactory.createConnectionProvider(
for: url,
for: urlRequest,
authInterceptor: authInterceptor,
connectionType: .appSyncRealtime
)
Expand Down Expand Up @@ -170,7 +170,7 @@ class AppSyncRealTimeClientIntegrationTests: AppSyncRealTimeClientTestBase {
func testSubscribeUnsubscribeRepeat() {
let authInterceptor = APIKeyAuthInterceptor(apiKey)
let connectionProvider = ConnectionProviderFactory.createConnectionProvider(
for: url,
for: urlRequest,
authInterceptor: authInterceptor,
connectionType: .appSyncRealtime
)
Expand All @@ -197,7 +197,7 @@ class AppSyncRealTimeClientIntegrationTests: AppSyncRealTimeClientTestBase {
func testMultipleThreadsSubscribeUnsubscribe() {
let authInterceptor = APIKeyAuthInterceptor(apiKey)
let connectionProvider = ConnectionProviderFactory.createConnectionProvider(
for: url,
for: urlRequest,
authInterceptor: authInterceptor,
connectionType: .appSyncRealtime
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import XCTest

class AppSyncRealTimeClientTestBase: XCTestCase {

var url: URL!
var urlRequest: URLRequest!
var apiKey: String!
let requestString = """
subscription onCreate {
Expand All @@ -35,7 +35,8 @@ class AppSyncRealTimeClientTestBase: XCTestCase {
let endpoint = apiName["endpoint"] as? String,
let apiKey = apiName["apiKey"] as? String {

url = URL(string: endpoint)
urlRequest = URLRequest(url: URL(string: endpoint)!)

self.apiKey = apiKey
} else {
throw "Could not retrieve endpoint"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ class StarscreamAdapterTests: AppSyncRealTimeClientTestBase {
func testConnectDisconnect() throws {
let starscreamAdapter = StarscreamAdapter()
let apiKeyAuthInterceptor = APIKeyAuthInterceptor(apiKey)
let request = AppSyncConnectionRequest(url: url)
let signedRequest = apiKeyAuthInterceptor.interceptConnection(request, for: url)
let request = AppSyncConnectionRequest(url: urlRequest.url!)
let signedRequest = apiKeyAuthInterceptor.interceptConnection(request, for: urlRequest.url!)
urlRequest.url = signedRequest.url
let expectedPerforms = expectation(description: "total performs")
expectedPerforms.expectedFulfillmentCount = 1_000
DispatchQueue.concurrentPerform(iterations: 1_000) { _ in
starscreamAdapter.connect(
url: signedRequest.url,
urlRequest: urlRequest,
protocols: ["graphql-ws"],
delegate: nil
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import XCTest

class RealtimeConnectionProviderAsyncTestBase: XCTestCase {

let url = URL(string: "https://www.appsyncrealtimeclient.test/")!
let urlRequest = URLRequest(url: URL(string: "https://www.appsyncrealtimeclient.test/")!)

var websocket: MockWebsocketProvider!

Expand Down Expand Up @@ -47,7 +47,7 @@ class RealtimeConnectionProviderAsyncTestBase: XCTestCase {
connectivityMonitor: ConnectivityMonitor = ConnectivityMonitor()
) -> RealtimeConnectionProvider {
let provider = RealtimeConnectionProvider(
url: url,
urlRequest: urlRequest,
websocket: websocket,
serialCallbackQueue: serialCallbackQueue,
connectivityMonitor: connectivityMonitor
Expand Down
Loading

0 comments on commit 0d22315

Please sign in to comment.