diff --git a/buildSrc/src/main/kotlin/test/server/tests/Auth.kt b/buildSrc/src/main/kotlin/test/server/tests/Auth.kt index 9f0f7d22e2c..20aa582e312 100644 --- a/buildSrc/src/main/kotlin/test/server/tests/Auth.kt +++ b/buildSrc/src/main/kotlin/test/server/tests/Auth.kt @@ -129,7 +129,8 @@ internal fun Application.authTestServer() { val token = call.request.headers["Authorization"] if (token.isNullOrEmpty() || token.contains("invalid")) { call.response.header(HttpHeaders.WWWAuthenticate, "Bearer realm=\"TestServer\"") - call.respond(HttpStatusCode.Unauthorized) + val status = call.request.queryParameters["status"]?.toIntOrNull() ?: 401 + call.respond(HttpStatusCode.fromValue(status)) return@get } diff --git a/ktor-client/ktor-client-plugins/ktor-client-auth/api/ktor-client-auth.api b/ktor-client/ktor-client-plugins/ktor-client-auth/api/ktor-client-auth.api index 9198fe3d958..e2c695df713 100644 --- a/ktor-client/ktor-client-plugins/ktor-client-auth/api/ktor-client-auth.api +++ b/ktor-client/ktor-client-plugins/ktor-client-auth/api/ktor-client-auth.api @@ -1,6 +1,8 @@ public final class io/ktor/client/plugins/auth/AuthConfig { public fun ()V public final fun getProviders ()Ljava/util/List; + public final fun isUnauthorizedResponse ()Lkotlin/jvm/functions/Function2; + public final fun reAuthorizeOnResponse (Lkotlin/jvm/functions/Function2;)V } public final class io/ktor/client/plugins/auth/AuthKt { diff --git a/ktor-client/ktor-client-plugins/ktor-client-auth/api/ktor-client-auth.klib.api b/ktor-client/ktor-client-plugins/ktor-client-auth/api/ktor-client-auth.klib.api index afe2c92a770..a5f0797997e 100644 --- a/ktor-client/ktor-client-plugins/ktor-client-auth/api/ktor-client-auth.klib.api +++ b/ktor-client/ktor-client-plugins/ktor-client-auth/api/ktor-client-auth.klib.api @@ -155,6 +155,11 @@ final class io.ktor.client.plugins.auth/AuthConfig { // io.ktor.client.plugins.a final val providers // io.ktor.client.plugins.auth/AuthConfig.providers|{}providers[0] final fun (): kotlin.collections/MutableList // io.ktor.client.plugins.auth/AuthConfig.providers.|(){}[0] + + final var isUnauthorizedResponse // io.ktor.client.plugins.auth/AuthConfig.isUnauthorizedResponse|{}isUnauthorizedResponse[0] + final fun (): kotlin.coroutines/SuspendFunction1 // io.ktor.client.plugins.auth/AuthConfig.isUnauthorizedResponse.|(){}[0] + + final fun reAuthorizeOnResponse(kotlin.coroutines/SuspendFunction1) // io.ktor.client.plugins.auth/AuthConfig.reAuthorizeOnResponse|reAuthorizeOnResponse(kotlin.coroutines.SuspendFunction1){}[0] } final val io.ktor.client.plugins.auth/Auth // io.ktor.client.plugins.auth/Auth|{}Auth[0] diff --git a/ktor-client/ktor-client-plugins/ktor-client-auth/common/src/io/ktor/client/plugins/auth/Auth.kt b/ktor-client/ktor-client-plugins/ktor-client-auth/common/src/io/ktor/client/plugins/auth/Auth.kt index 1fd57fa6d25..3cbdef92196 100644 --- a/ktor-client/ktor-client-plugins/ktor-client-auth/common/src/io/ktor/client/plugins/auth/Auth.kt +++ b/ktor-client/ktor-client-plugins/ktor-client-auth/common/src/io/ktor/client/plugins/auth/Auth.kt @@ -1,14 +1,14 @@ /* - * Copyright 2014-2019 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + * Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. */ package io.ktor.client.plugins.auth import io.ktor.client.* import io.ktor.client.call.* -import io.ktor.client.plugins.* import io.ktor.client.plugins.api.* import io.ktor.client.request.* +import io.ktor.client.statement.* import io.ktor.http.* import io.ktor.http.auth.* import io.ktor.util.* @@ -23,9 +23,36 @@ private class AtomicCounter { val atomic = atomic(0) } +/** + * Configuration used by [Auth] plugin. + */ @KtorDsl public class AuthConfig { + /** + * [AuthProvider] list to use. + */ public val providers: MutableList = mutableListOf() + + /** + * The currently set function to control whether a response is unauthorized and should trigger a refresh / re-auth. + * + * By default checks against HTTP status 401. + * + * You can set this value via [reAuthorizeOnResponse]. + */ + @InternalAPI + public var isUnauthorizedResponse: suspend (HttpResponse) -> Boolean = { it.status == HttpStatusCode.Unauthorized } + private set + + /** + * Sets a custom function to control whether a response is unauthorized and should trigger a refresh / re-auth. + * + * Use this to change the value of [isUnauthorizedResponse]. + */ + public fun reAuthorizeOnResponse(block: suspend (HttpResponse) -> Boolean) { + @OptIn(InternalAPI::class) + isUnauthorizedResponse = block + } } /** @@ -39,8 +66,9 @@ public val AuthCircuitBreaker: AttributeKey = AttributeKey("auth-request") * * You can learn more from [Authentication and authorization](https://ktor.io/docs/auth.html). * - * [providers] - list of auth providers to use. + * @see [AuthConfig] for configuration options. */ +@OptIn(InternalAPI::class) public val Auth: ClientPlugin = createClientPlugin("Auth", ::AuthConfig) { val providers = pluginConfig.providers.toList() @@ -50,7 +78,6 @@ public val Auth: ClientPlugin = createClientPlugin("Auth", ::AuthCon val tokenVersionsAttributeKey = AttributeKey>("ProviderVersionAttributeKey") - @OptIn(InternalAPI::class) fun findProvider( call: HttpClientCall, candidateProviders: Set @@ -64,10 +91,10 @@ public val Auth: ClientPlugin = createClientPlugin("Auth", ::AuthCon } authHeaders.isEmpty() -> { - LOGGER.trace( - "401 response ${call.request.url} has no or empty \"WWW-Authenticate\" header. " + + LOGGER.trace { + "Unauthorized response ${call.request.url} has no or empty \"WWW-Authenticate\" header. " + "Can not add or refresh token" - ) + } null } @@ -88,9 +115,9 @@ public val Auth: ClientPlugin = createClientPlugin("Auth", ::AuthCon val requestTokenVersion = requestTokenVersions[provider] if (requestTokenVersion != null && requestTokenVersion >= tokenVersion.atomic.value) { - LOGGER.trace("Refreshing token for ${call.request.url}") + LOGGER.trace { "Refreshing token for ${call.request.url}" } if (!provider.refreshToken(call.response)) { - LOGGER.trace("Refreshing token failed for ${call.request.url}") + LOGGER.trace { "Refreshing token failed for ${call.request.url}" } return false } else { requestTokenVersions[provider] = tokenVersion.atomic.incrementAndGet() @@ -99,7 +126,6 @@ public val Auth: ClientPlugin = createClientPlugin("Auth", ::AuthCon return true } - @OptIn(InternalAPI::class) suspend fun Send.Sender.executeWithNewToken( call: HttpClientCall, provider: AuthProvider, @@ -111,13 +137,13 @@ public val Auth: ClientPlugin = createClientPlugin("Auth", ::AuthCon provider.addRequestHeaders(request, authHeader) request.attributes.put(AuthCircuitBreaker, Unit) - LOGGER.trace("Sending new request to ${call.request.url}") + LOGGER.trace { "Sending new request to ${call.request.url}" } return proceed(request) } onRequest { request, _ -> providers.filter { it.sendWithoutRequest(request) }.forEach { provider -> - LOGGER.trace("Adding auth headers for ${request.url} from provider $provider") + LOGGER.trace { "Adding auth headers for ${request.url} from provider $provider" } val tokenVersion = tokenVersions.computeIfAbsent(provider) { AtomicCounter() } val requestTokenVersions = request.attributes .computeIfAbsent(tokenVersionsAttributeKey) { mutableMapOf() } @@ -128,22 +154,22 @@ public val Auth: ClientPlugin = createClientPlugin("Auth", ::AuthCon on(Send) { originalRequest -> val origin = proceed(originalRequest) - if (origin.response.status != HttpStatusCode.Unauthorized) return@on origin + if (!pluginConfig.isUnauthorizedResponse(origin.response)) return@on origin if (origin.request.attributes.contains(AuthCircuitBreaker)) return@on origin var call = origin val candidateProviders = HashSet(providers) - while (call.response.status == HttpStatusCode.Unauthorized) { - LOGGER.trace("Received 401 for ${call.request.url}") + while (pluginConfig.isUnauthorizedResponse(call.response)) { + LOGGER.trace { "Unauthorized response for ${call.request.url}" } val (provider, authHeader) = findProvider(call, candidateProviders) ?: run { - LOGGER.trace("Can not find auth provider for ${call.request.url}") + LOGGER.trace { "Can not find auth provider for ${call.request.url}" } return@on call } - LOGGER.trace("Using provider $provider for ${call.request.url}") + LOGGER.trace { "Using provider $provider for ${call.request.url}" } candidateProviders.remove(provider) if (!refreshTokenIfNeeded(call, provider, originalRequest)) return@on call diff --git a/ktor-client/ktor-client-plugins/ktor-client-auth/common/test/io/ktor/client/plugins/auth/AuthTest.kt b/ktor-client/ktor-client-plugins/ktor-client-auth/common/test/io/ktor/client/plugins/auth/AuthTest.kt index 828573ec60f..20107f77f51 100644 --- a/ktor-client/ktor-client-plugins/ktor-client-auth/common/test/io/ktor/client/plugins/auth/AuthTest.kt +++ b/ktor-client/ktor-client-plugins/ktor-client-auth/common/test/io/ktor/client/plugins/auth/AuthTest.kt @@ -403,6 +403,27 @@ class AuthTest : ClientLoader() { } } + @Test + fun testForbiddenBearerAuthWithInvalidAccessAndValidRefreshTokens() = clientTests { + config { + install(Auth) { + reAuthorizeOnResponse { it.status == HttpStatusCode.Forbidden } + bearer { + refreshTokens { BearerTokens("valid", "refresh") } + loadTokens { BearerTokens("invalid", "refresh") } + } + } + + expectSuccess = false + } + + test { client -> + client.prepareGet("$TEST_SERVER/auth/bearer/test-refresh?status=403").execute { + assertEquals(HttpStatusCode.OK, it.status) + } + } + } + // The return of refreshTokenFun is null, cause it should not be called at all, if loadTokensFun returns valid tokens @Test fun testUnauthorizedBearerAuthWithValidAccessTokenAndInvalidRefreshToken() = clientTests {