diff --git a/Sources/Base/OAuth2Base.swift b/Sources/Base/OAuth2Base.swift index 70db5ea9..42477267 100644 --- a/Sources/Base/OAuth2Base.swift +++ b/Sources/Base/OAuth2Base.swift @@ -417,12 +417,15 @@ open class OAuth2Base: OAuth2Securable { This method checks `state`, throws `OAuth2Error.missingState` or `OAuth2Error.invalidState`. Resets state if it matches. */ public final func assureMatchesState(_ params: OAuth2JSON) throws { - guard let state = params["state"] as? String, !state.isEmpty else { - throw OAuth2Error.missingState - } - logger?.trace("OAuth2", msg: "Checking state, got “\(state)”, expecting “\(context.state)”") - if !context.matchesState(state) { - throw OAuth2Error.invalidState + if let state = params["state"] as? String, !state.isEmpty { + logger?.trace("OAuth2", msg: "Checking state, got “\(state)”, expecting “\(context.state)”") + if !context.matchesState(state) { + throw OAuth2Error.invalidState + } + } else { + if !clientConfig.stateParameterOptional { + throw OAuth2Error.missingState + } } context.resetState() } diff --git a/Sources/Base/OAuth2ClientConfig.swift b/Sources/Base/OAuth2ClientConfig.swift index 2a128875..668b0e68 100644 --- a/Sources/Base/OAuth2ClientConfig.swift +++ b/Sources/Base/OAuth2ClientConfig.swift @@ -71,6 +71,9 @@ open class OAuth2ClientConfig { /// Add custom parameters to the authorization request. public var customParameters: [String: String]? = nil + + /// Whether the state parameter is optional + public var stateParameterOptional = false /// Most servers use UTF-8 encoding for Authorization headers, but that's not 100% true: make it configurable (see https://github.com/p2/OAuth2/issues/165). open var authStringEncoding = String.Encoding.utf8 @@ -134,7 +137,9 @@ open class OAuth2ClientConfig { if let params = settings["parameters"] as? OAuth2StringDict { customParameters = params } - + if let stateOptional = settings["state_parameter_optional"] as? Bool { + stateParameterOptional = stateOptional + } // access token options if let assume = settings["token_assume_unexpired"] as? Bool { accessTokenAssumeUnexpired = assume diff --git a/Tests/FlowTests/OAuth2CodeGrantTests.swift b/Tests/FlowTests/OAuth2CodeGrantTests.swift index c656de8f..7fd298ca 100644 --- a/Tests/FlowTests/OAuth2CodeGrantTests.swift +++ b/Tests/FlowTests/OAuth2CodeGrantTests.swift @@ -179,6 +179,28 @@ class OAuth2CodeGrantTests: XCTestCase { XCTAssertTrue(false, "Should not throw, but threw \(error)") } } + + func testRedirectURINoStateParameterAllowed() { + let settings: OAuth2JSON = [ + "client_id": "abc", + "client_secret": "xyz", + "authorize_uri": "https://auth.ful.io", + "token_uri": "https://token.ful.io", + "keychain": false, + "state_parameter_optional": true + ] + let oauth = OAuth2CodeGrant(settings: settings) + oauth.redirect = "oauth2://callback" + oauth.context.redirectURL = oauth.redirect + // parse no state + let redirect = URL(string: "oauth2://callback?code=C0D3")! + do { + _ = try oauth.validateRedirectURL(redirect) + } + catch let error { + XCTAssertTrue(false, "Must not end up here with \(error)") + } + } func testTokenRequest() { let oauth = OAuth2CodeGrant(settings: [