Skip to content

Commit

Permalink
feat: Handle Unauthorized errors (#69)
Browse files Browse the repository at this point in the history
* feat: Handle Unauthorized errors

* Add additional info to ConnectionProvider.other case

* fix: refactor error handling

* address PR comments

* update to .unauthorized and .unknown
  • Loading branch information
lawmicha committed Jul 11, 2022
1 parent 613ee2e commit 876f40f
Show file tree
Hide file tree
Showing 8 changed files with 177 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,10 @@ extension AppSyncSubscriptionConnection {
additionalInfo=\(String(describing: errorPayload))
"""
)
case .other:
AppSyncLogger.error("ConnectionProviderError.other")
case .unauthorized:
AppSyncLogger.error("ConnectionProviderError.unauthorized")
case .unknown:
AppSyncLogger.error("ConnectionProviderError.unknown")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ extension RealtimeConnectionProvider: AppSyncWebsocketDelegate {
connectionQueue.async { [weak self] in
self?.handleError(response: response)
}
case .connectionError:
AppSyncLogger.verbose("[RealtimeConnectionProvider] received error")
connectionQueue.async { [weak self] in
self?.handleError(response: response)
}
case .subscriptionAck, .unsubscriptionAck, .data:
if let appSyncResponse = response.toAppSyncResponse() {
updateCallback(event: .data(appSyncResponse))
Expand Down Expand Up @@ -104,49 +109,19 @@ extension RealtimeConnectionProvider: AppSyncWebsocketDelegate {
///
/// - Warning: This method must be invoked on the `connectionQueue`
func handleError(response: RealtimeConnectionProviderResponse) {
// If we get an error in connection inprogress state, return back as connection error.
guard status != .inProgress else {
// If we get an error while the connection was inProgress state,
let error = response.toConnectionProviderError(connectionState: status)
if status == .inProgress {
status = .notConnected
updateCallback(event: .error(ConnectionProviderError.connection))
return
}

if response.isLimitExceededError() {
let limitExceedError = ConnectionProviderError.limitExceeded(response.id)

guard response.id == nil else {
updateCallback(event: .error(limitExceedError))
return
}

if #available(iOS 13.0, *) {
self.limitExceededSubject.send(limitExceedError)
return
} else {
updateCallback(event: .error(limitExceedError))
return
}
}

if response.isMaxSubscriptionReachedError() {
let limitExceedError = ConnectionProviderError.limitExceeded(response.id)
updateCallback(event: .error(limitExceedError))
return
}

// If the type of error is not handled (by checking `isLimitExceededError`, `isMaxSubscriptionReachedError`,
// etc), and is not for a specific subscription, then return a generic error
guard let identifier = response.id else {
let genericError = ConnectionProviderError.other
updateCallback(event: .error(genericError))
return
// If limit exceeded is for a particular subscription identifier, throttle using `limitExceededSubject`
if case .limitExceeded(let id) = error, id == nil, #available(iOS 13.0, *) {
self.limitExceededSubject.send(error)
} else {
updateCallback(event: .error(error))
}

// Default scenario - return the error with subscription id and error payload.
let subscriptionError = ConnectionProviderError.subscription(identifier, response.payload)
updateCallback(event: .error(subscriptionError))
}

}

extension RealtimeConnectionProviderResponse {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ enum RealtimeConnectionProviderResponseType: String, Decodable {
case data

case error

case connectionError = "connection_error"
}

extension RealtimeConnectionProviderResponse: Decodable {
Expand All @@ -66,6 +68,33 @@ extension RealtimeConnectionProviderResponse: Decodable {
/// https://docs.aws.amazon.com/appsync/latest/devguide/real-time-websocket-client.html#error-message
extension RealtimeConnectionProviderResponse {

func toConnectionProviderError(connectionState: ConnectionState) -> ConnectionProviderError {
// If it is Unauthorized, return `.unauthorized` error.
guard !isUnauthorizationError() else {
return .unauthorized
}

// If it is in-progress, return `.connection`.
guard connectionState != .inProgress else {
return .connection
}

if isLimitExceededError() || isMaxSubscriptionReachedError() {
// Is it observed that LimitExceeded error does not have `id` while MaxSubscriptionReached does have a
// corresponding identifier. Both are mapped to `limitExceeded` with optional identifier.
return .limitExceeded(id)
}

// If the type of error is not handled (by checking `isLimitExceededError`, `isMaxSubscriptionReachedError`,
// etc), and is not for a specific subscription, then return unknown error
guard let identifier = id else {
return .unknown(message: nil, causedBy: nil, payload: payload)
}

// Default scenario - return the error with subscription id and error payload.
return .subscription(identifier, payload)
}

func isMaxSubscriptionReachedError() -> Bool {
// It is expected to contain payload with corresponding error information
guard let payload = payload else {
Expand Down Expand Up @@ -116,4 +145,33 @@ extension RealtimeConnectionProviderResponse {

return false
}

func isUnauthorizationError() -> Bool {
// It is expected to contain payload with corresponding error information
guard let payload = payload,
let errors = payload["errors"],
case let .array(errorsArray) = errors else {
return false
}

// The observed response from the service
// { "payload": {
// "errors": [{
// "errorType":"com.amazonaws.deepdish.graphql.auth#UnauthorizedException",
// "message":"You are not authorized to make this call.",
// "errorCode":400 }]},
// "type":"connection_error" }
return errorsArray.contains { error in
guard case let .object(errorObject) = error,
case let .string(errorObjectString) = errorObject["errorType"] else {
return false
}

if errorObjectString.contains("UnauthorizedException") {
return true
}

return false
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ public enum ConnectionProviderError: Error {
/// payload in dictionary format.
case subscription(String, [String: Any]?)

/// Any other error is identified by this type
case other
/// Caused when not authorized to establish the connection.
case unauthorized

/// Unknown error
case unknown(message: String? = nil, causedBy: Error? = nil, payload: [String: Any]?)
}
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,39 @@ class AppSyncRealTimeClientFailureTests: AppSyncRealTimeClientTestBase {
}
}

func testAPIKeyInvalid() {
apiKey = "invalid"
let subscribeFailed = expectation(description: "subscribe failed")
let authInterceptor = APIKeyAuthInterceptor(apiKey)
let connectionProvider = ConnectionProviderFactory.createConnectionProvider(
for: url,
authInterceptor: authInterceptor,
connectionType: .appSyncRealtime
)
let subscriptionConnection = AppSyncSubscriptionConnection(provider: connectionProvider)
_ = subscriptionConnection.subscribe(
requestString: requestString,
variables: nil
) { event, _ in

switch event {
case .connection:
break
case .data:
break
case .failed(let error):
guard let connectionError = error as? ConnectionProviderError,
case .unauthorized = connectionError else {
XCTFail("Should be `.unauthorized` error")
return
}
subscribeFailed.fulfill()
}
}

wait(for: [subscribeFailed], timeout: TestCommonConstants.networkTimeout)
}

class TestConnectionRetryHandler: ConnectionRetryHandler {
var count: Int = 0
func shouldRetryRequest(for error: ConnectionProviderError) -> RetryAdvice {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ class AppSyncSubscriptionConnectionErrorHandlerTests: XCTestCase {
XCTAssertNil(connectionProvider.listener)
}

func testOther() throws {
func testUnknown() throws {
let connection = AppSyncSubscriptionConnection(provider: connectionProvider)
let connectedMessageExpectation = expectation(description: "Connected event should be fired")
let failedEvent = expectation(description: "Failed event should be fired")
Expand All @@ -253,8 +253,8 @@ class AppSyncSubscriptionConnectionErrorHandlerTests: XCTestCase {
XCTFail("Data event should not be published")
case .failed(let error):
guard let connection = error as? ConnectionProviderError,
case .other = connection else {
XCTFail("Should be .other")
case .unknown = connection else {
XCTFail("Should be .unknown")
return
}
failedEvent.fulfill()
Expand All @@ -264,8 +264,8 @@ class AppSyncSubscriptionConnectionErrorHandlerTests: XCTestCase {
XCTAssertEqual(connection.subscriptionState, .subscribed)
XCTAssertNotNil(connectionProvider.listener)

let otherError = ConnectionProviderError.other
connection.handleError(error: otherError)
let unknownError = ConnectionProviderError.unknown(message: nil, causedBy: nil, payload: nil)
connection.handleError(error: unknownError)
wait(for: [failedEvent], timeout: 5)
XCTAssertEqual(connection.subscriptionState, .notSubscribed)
XCTAssertNil(connectionProvider.listener)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class RealTimeConnectionProviderResponseTests: XCTestCase {
)

XCTAssertFalse(response.isMaxSubscriptionReachedError())
XCTAssertEqual(response.toConnectionProviderError(connectionState: .connected), .subscription("id", nil))
}

func testIsMaxSubscriptionReached_MaxSubscriptionsReachedException() throws {
Expand All @@ -29,6 +30,7 @@ class RealTimeConnectionProviderResponseTests: XCTestCase {
)

XCTAssertTrue(response.isMaxSubscriptionReachedError())
XCTAssertEqual(response.toConnectionProviderError(connectionState: .connected), .limitExceeded("id"))
}

func testIsMaxSubscriptionReached_MaxSubscriptionsReachedError() throws {
Expand All @@ -40,6 +42,7 @@ class RealTimeConnectionProviderResponseTests: XCTestCase {
)

XCTAssertTrue(response.isMaxSubscriptionReachedError())
XCTAssertEqual(response.toConnectionProviderError(connectionState: .connected), .limitExceeded("id"))
}

func testIsLimitExceeded_EmptyPayload() throws {
Expand All @@ -49,6 +52,10 @@ class RealTimeConnectionProviderResponseTests: XCTestCase {
)

XCTAssertFalse(response.isLimitExceededError())
XCTAssertEqual(
response.toConnectionProviderError(connectionState: .connected),
.unknown(message: nil, causedBy: nil, payload: nil)
)
}

func testIsLimitExceeded_LimitExceededError() throws {
Expand All @@ -59,6 +66,56 @@ class RealTimeConnectionProviderResponseTests: XCTestCase {
)

XCTAssertTrue(response.isLimitExceededError())
XCTAssertEqual(response.toConnectionProviderError(connectionState: .connected), .limitExceeded(nil))
}

func testIsUnauthorized_EmptyPayload() throws {
let response = RealtimeConnectionProviderResponse(
payload: nil,
type: .error
)

XCTAssertFalse(response.isUnauthorizationError())
XCTAssertEqual(
response.toConnectionProviderError(connectionState: .connected),
.unknown(message: nil, causedBy: nil, payload: nil)
)
}

func testIsUnauthorized_UnauthorizedException() throws {
let payload = ["errors": AppSyncJSONValue.array([
["errorType": "com.amazonaws.deepdish.graphql.auth#UnauthorizedException"]
])]
let response = RealtimeConnectionProviderResponse(
payload: payload,
type: .error
)

XCTAssertTrue(response.isUnauthorizationError())
XCTAssertEqual(
response.toConnectionProviderError(connectionState: .connected),
.unauthorized
)
}
}

extension ConnectionProviderError: Equatable {
public static func == (lhs: ConnectionProviderError, rhs: ConnectionProviderError) -> Bool {
switch (lhs, rhs) {
case (.connection, .connection):
return true
case (.jsonParse, .jsonParse):
return true
case (.limitExceeded(let id1), .limitExceeded(let id2)):
return id1 == id2
case (.subscription(let id1, _), .subscription(let id2, _)):
return id1 == id2
case (.unauthorized, .unauthorized):
return true
case (.unknown(let message1, _, _), .unknown(let message2, _, _)):
return message1 == message2
default:
return false
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ class ConnectionProviderHandleErrorTests: XCTestCase {
XCTFail("Should have received error event")
return
}
guard case .other = connectionError else {
XCTFail("Should have received .other error")
guard case .unknown = connectionError else {
XCTFail("Should have received .unknown error")
return
}

Expand Down

0 comments on commit 876f40f

Please sign in to comment.