Skip to content

Commit

Permalink
fix(api): Change the getToken to async (#2856)
Browse files Browse the repository at this point in the history
  • Loading branch information
royjit authored Apr 13, 2023
1 parent c8ab069 commit 50b556e
Show file tree
Hide file tree
Showing 10 changed files with 143 additions and 65 deletions.
17 changes: 12 additions & 5 deletions Amplify/Categories/API/Operation/RetryableGraphQLOperation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public protocol RetryableGraphQLOperationBehavior: Operation, DefaultLogger {
/// GraphQLOperation concrete type
associatedtype OperationType: AnyGraphQLOperation

typealias RequestFactory = () -> GraphQLRequest<Payload>
typealias RequestFactory = (@escaping (GraphQLRequest<Payload>) -> Void) -> Void
typealias OperationFactory = (GraphQLRequest<Payload>, @escaping OperationResultListener) -> OperationType
typealias OperationResultListener = OperationType.ResultListener

Expand Down Expand Up @@ -64,7 +64,10 @@ extension RetryableGraphQLOperationBehavior {
let wrappedResultListener: OperationResultListener = { result in
if case let .failure(error) = result, self.shouldRetry(error: error as? APIError) {
self.log.debug("\(error)")
self.start(request: self.requestFactory())
self.requestFactory { [weak self] request in
self?.start(request: request)
}

return
}

Expand Down Expand Up @@ -95,7 +98,7 @@ public final class RetryableGraphQLOperation<Payload: Decodable>: Operation, Ret
public var resultListener: OperationResultListener
public var operationFactory: OperationFactory

public init(requestFactory: @escaping () -> GraphQLRequest<Payload>,
public init(requestFactory: @escaping RetryableGraphQLOperation<Payload>.RequestFactory,
maxRetries: Int,
resultListener: @escaping OperationResultListener,
_ operationFactory: @escaping OperationFactory) {
Expand All @@ -106,7 +109,9 @@ public final class RetryableGraphQLOperation<Payload: Decodable>: Operation, Ret
self.resultListener = resultListener
}
public override func main() {
start(request: requestFactory())
requestFactory { [weak self] request in
self?.start(request: request)
}
}

public override func cancel() {
Expand Down Expand Up @@ -154,7 +159,9 @@ public final class RetryableGraphQLSubscriptionOperation<Payload: Decodable>: Op
self.resultListener = resultListener
}
public override func main() {
start(request: requestFactory())
requestFactory { [weak self] request in
self?.start(request: request)
}
}

public override func cancel() {
Expand Down
1 change: 0 additions & 1 deletion Amplify/Core/Support/Optional+Extension.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
// SPDX-License-Identifier: Apache-2.0
//


import Foundation

extension Optional {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,4 @@ final public class AWSRESTOperation: AmplifyOperation<
task.resume()
}
}

2 changes: 1 addition & 1 deletion AmplifyPlugins/API/Podfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,4 @@ SPEC CHECKSUMS:

PODFILE CHECKSUM: 5170578806036f2ba018abb8868d56e448fb0ada

COCOAPODS: 1.11.3
COCOAPODS: 1.12.0
Original file line number Diff line number Diff line change
Expand Up @@ -140,15 +140,16 @@ final class InitialSyncOperation: AsynchronousOperation {
}

var authTypes = authModeStrategy.authTypesFor(schema: modelSchema,
operation: .read)

RetryableGraphQLOperation(requestFactory: {
GraphQLRequest<SyncQueryResult>.syncQuery(modelSchema: self.modelSchema,
where: queryPredicate,
limit: limit,
nextToken: nextToken,
lastSync: lastSyncTime,
authType: authTypes.next())
operation: .read)

RetryableGraphQLOperation(requestFactory: { completion in
completion(GraphQLRequest<SyncQueryResult>.syncQuery(modelSchema: self.modelSchema,
where: queryPredicate,
limit: limit,
nextToken: nextToken,
lastSync: lastSyncTime,
authType: authTypes.next()))

},
maxRetries: authTypes.count,
resultListener: completionListener, { nextRequest, wrappedCompletionListener in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,38 +196,97 @@ final class IncomingAsyncSubscriptionEventPublisher: AmplifyCancellable {
}

// swiftlint:disable:next function_parameter_count
static func makeAPIRequest(for modelSchema: ModelSchema,
subscriptionType: GraphQLSubscriptionType,
api: APICategoryGraphQLBehavior,
auth: AuthCategoryBehavior?,
authType: AWSAuthorizationType?,
awsAuthService: AWSAuthServiceBehavior) -> GraphQLRequest<Payload> {
let request: GraphQLRequest<Payload>
if modelSchema.hasAuthenticationRules,
auth != nil,
case .success(let tokenString) = awsAuthService.getToken(),
case .success(let claims) = awsAuthService.getTokenClaims(tokenString: tokenString) {
request = GraphQLRequest<Payload>.subscription(to: modelSchema,
subscriptionType: subscriptionType,
claims: claims,
authType: authType)
} else if modelSchema.hasAuthenticationRules,
let oidcAuthProvider = hasOIDCAuthProviderAvailable(api: api),
case .success(let tokenString) = oidcAuthProvider.getLatestAuthToken(),
case .success(let claims) = awsAuthService.getTokenClaims(tokenString: tokenString) {
request = GraphQLRequest<Payload>.subscription(to: modelSchema,
subscriptionType: subscriptionType,
claims: claims,
authType: authType)
static func makeAPIRequest(
for modelSchema: ModelSchema,
subscriptionType: GraphQLSubscriptionType,
api: APICategoryGraphQLBehavior,
auth: AuthCategoryBehavior?,
authType: AWSAuthorizationType?,
awsAuthService: AWSAuthServiceBehavior,
completion: @escaping (GraphQLRequest<Payload>) -> Void) {

let requestWithOutClaims = GraphQLRequest<Payload>.subscription(
to: modelSchema,
subscriptionType: subscriptionType,
authType: authType)

guard modelSchema.hasAuthenticationRules else {
completion(requestWithOutClaims)
return
}

getClaims(api: api,
auth: auth,
awsAuthService: awsAuthService) { claims in

guard let claims = claims else {
completion(requestWithOutClaims)
return
}
let request = GraphQLRequest<Payload>.subscription(
to: modelSchema,
subscriptionType: subscriptionType,
claims: claims,
authType: authType)
completion(request)
return
}

}

static func getClaims(api: APICategoryGraphQLBehavior,
auth: AuthCategoryBehavior?,
awsAuthService: AWSAuthServiceBehavior,
completion: @escaping ([String: AnyObject]?) -> Void) {
if auth != nil {
getClaimsFromUserPool(awsAuthService: awsAuthService) { claims in
if let claims = claims {
completion(claims)
} else {
getClaimsFromOIDCProvider(
api: api,
awsAuthService: awsAuthService,
completion: completion)
}
}
} else {
request = GraphQLRequest<Payload>.subscription(to: modelSchema,
subscriptionType: subscriptionType,
authType: authType)
getClaimsFromOIDCProvider(
api: api,
awsAuthService: awsAuthService,
completion: completion)
}

return request
}

static func getClaimsFromUserPool(
awsAuthService: AWSAuthServiceBehavior,
completion: @escaping ([String: AnyObject]?) -> Void) {

awsAuthService.getUserPoolAccessToken { result in
if case .success(let tokenString) = result,
case .success(let claims) = awsAuthService.getTokenClaims(tokenString: tokenString) {
completion(claims)
} else {
completion(nil)
}
}
}

static func getClaimsFromOIDCProvider(
api: APICategoryGraphQLBehavior,
awsAuthService: AWSAuthServiceBehavior,
completion: @escaping ([String: AnyObject]?) -> Void) {

guard let oidcAuthProvider = hasOIDCAuthProviderAvailable(api: api),
case .success(let tokenString) = oidcAuthProvider.getLatestAuthToken(),
case .success(let claims) = awsAuthService.getTokenClaims(tokenString: tokenString)
else {
completion(nil)
return
}
completion(claims)
}

static func hasOIDCAuthProviderAvailable(api: APICategoryGraphQLBehavior) -> AmplifyOIDCAuthProvider? {
if let apiPlugin = api as? APICategoryAuthProviderFactoryBehavior,
let oidcAuthProvider = apiPlugin.apiAuthProviderFactory().oidcAuthProvider() {
Expand Down Expand Up @@ -292,16 +351,20 @@ extension IncomingAsyncSubscriptionEventPublisher {
api: APICategoryGraphQLBehavior,
auth: AuthCategoryBehavior?,
awsAuthService: AWSAuthServiceBehavior,
authTypeProvider: AWSAuthorizationTypeIterator) -> RetryableGraphQLOperation<Payload>.RequestFactory {
authTypeProvider: AWSAuthorizationTypeIterator)
-> RetryableGraphQLOperation<Payload>.RequestFactory {

// swiftlint:disable:previous line_length
var authTypes = authTypeProvider
return {
return IncomingAsyncSubscriptionEventPublisher.makeAPIRequest(for: modelSchema,
subscriptionType: subscriptionType,
api: api,
auth: auth,
authType: authTypes.next(),
awsAuthService: awsAuthService)
return { completion in
return IncomingAsyncSubscriptionEventPublisher.makeAPIRequest(
for: modelSchema,
subscriptionType: subscriptionType,
api: api,
auth: auth,
authType: authTypes.next(),
awsAuthService: awsAuthService,
completion: completion)
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion AmplifyPlugins/DataStore/Podfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,4 @@ SPEC CHECKSUMS:

PODFILE CHECKSUM: 0bab7193bebdf470839514f327440893b0d26090

COCOAPODS: 1.11.3
COCOAPODS: 1.12.0
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,9 @@ class RetryableGraphQLOperationTests: XCTestCase {
resultExpectation.fulfill()
}

let requestFactory: RequestFactory = {
let requestFactory: RequestFactory = { completion in
requestFactoryExpectation.fulfill()
return self.makeTestRequest()

self.makeTestRequestAsync(completion: completion)
}

let operation = RetryableGraphQLOperation<Payload>(requestFactory: requestFactory,
Expand Down Expand Up @@ -69,10 +68,9 @@ class RetryableGraphQLOperationTests: XCTestCase {
resultExpectation.fulfill()
}

let requestFactory: RequestFactory = {
let requestFactory: RequestFactory = { completion in
requestFactoryExpectation.fulfill()
return self.makeTestRequest()

completion(self.makeTestRequest())
}

let operation = RetryableGraphQLOperation<Payload>(requestFactory: requestFactory,
Expand Down Expand Up @@ -103,10 +101,9 @@ class RetryableGraphQLOperationTests: XCTestCase {
resultExpectation.fulfill()
}

let requestFactory: RequestFactory = {
let requestFactory: RequestFactory = { completion in
requestFactoryExpectation.fulfill()
return self.makeTestRequest()

completion(self.makeTestRequest())
}

let operation = RetryableGraphQLOperation<Payload>(requestFactory: requestFactory,
Expand All @@ -133,6 +130,16 @@ extension RetryableGraphQLOperationTests {
responseType: Payload.self)
}

private func makeTestRequestAsync(completion: @escaping (GraphQLRequest<Payload>) -> Void ) {
DispatchQueue.global().asyncAfter(deadline: .now() + 2) {
let request = GraphQLRequest<Payload>(apiName: self.testApiName,
document: "",
responseType: Payload.self)
completion(request)
}

}

private func makeTestOperation() -> GraphQLOperation<Payload> {
let requestOptions = GraphQLOperationRequest<Payload>.Options(pluginOptions: nil)
let operationRequest = GraphQLOperationRequest<Payload>(apiName: testApiName,
Expand Down
2 changes: 1 addition & 1 deletion AmplifyTests/CoreTests/Optional+ExtensionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,6 @@ class OptionalExtensionTests: XCTestCase {

}

fileprivate struct TestRuntimeError: Error, Equatable {
private struct TestRuntimeError: Error, Equatable {
let id = UUID()
}
6 changes: 3 additions & 3 deletions Podfile.lock
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
PODS:
- AWSCore (2.30.1)
- AWSCore (2.30.4)
- CwlCatchException (2.1.1):
- CwlCatchExceptionSupport (~> 2.1.1)
- CwlCatchExceptionSupport (2.1.1)
Expand Down Expand Up @@ -39,7 +39,7 @@ CHECKOUT OPTIONS:
:tag: 2.1.0

SPEC CHECKSUMS:
AWSCore: 493e49f8118e04fa57d927ceb117ba24a9b5ca02
AWSCore: 19b8233fe2d0ed3ccf5cff833a615814282cdc90
CwlCatchException: 86760545af2a490a23e964d76d7c77442dbce79b
CwlCatchExceptionSupport: a004322095d7101b945442c86adc7cec0650f676
CwlMachBadInstructionHandler: aa1fe9f2d08b29507c150d099434b2890247e7f8
Expand All @@ -50,4 +50,4 @@ SPEC CHECKSUMS:

PODFILE CHECKSUM: 5e20e56b8ef40444b018a3736b7b726ff9772f00

COCOAPODS: 1.11.3
COCOAPODS: 1.12.0

0 comments on commit 50b556e

Please sign in to comment.