Skip to content

Commit

Permalink
Adds the option to supply custom claims and fixes refresh token gener…
Browse files Browse the repository at this point in the history
…ation and processing.
  • Loading branch information
frederic-kneier authored and fkneier-bikeleasing committed Sep 20, 2024
1 parent bb8665a commit dcb1710
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 30 deletions.
45 changes: 37 additions & 8 deletions src/main/kotlin/de/solugo/oauthmock/controller/TokenController.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@ import de.solugo.oauthmock.ConfigurationProvider
import de.solugo.oauthmock.ServerProperties
import de.solugo.oauthmock.service.TokenService
import de.solugo.oauthmock.token.*
import de.solugo.oauthmock.util.plus
import de.solugo.oauthmock.util.scopes
import de.solugo.oauthmock.util.sessionId
import de.solugo.oauthmock.util.uuid
import de.solugo.oauthmock.util.*
import kotlinx.coroutines.reactor.awaitSingle
import org.jose4j.jwt.JwtClaims
import org.slf4j.LoggerFactory
import org.springframework.http.MediaType
import org.springframework.util.MultiValueMap
Expand Down Expand Up @@ -77,6 +75,22 @@ class TokenController(
scopes = parameters.getFirst("scope")?.split(" ")?.toSet(),
)

parameters["claims"]?.forEach {
context.commonClaims.put(JwtClaims.parse(it))
}

parameters["idClaims"]?.forEach {
context.idClaims.put(JwtClaims.parse(it))
}

parameters["accessClaims"]?.forEach {
context.accessClaims.put(JwtClaims.parse(it))
}

parameters["refreshClaims"]?.forEach {
context.refreshClaims.put(JwtClaims.parse(it))
}

tokenProcessors.process(TokenProcessor.Step.PRE, context)

grant.process(context)
Expand All @@ -96,24 +110,39 @@ class TokenController(
claims.sessionId = claims.sessionId ?: uuid()
}

commonClaims + context.refreshClaims

buildMap {
put("token_type", "Bearer")

refreshClaims.setClaim("common_claims", context.commonClaims.claimsMap)

accessClaims.takeIf { it.claimsMap.isNotEmpty() }?.also { claims ->
put("access_token", tokenService.encodeJwt(context.issuer, claims))

refreshClaims.setClaim("access_claims", context.accessClaims.claimsMap)

claims.expirationTime?.also {
put("expires_in", it.value - (claims.issuedAt?.value ?: Instant.now().epochSecond))
}
}
refreshClaims.takeIf { it.claimsMap.isNotEmpty() }?.also { claims ->
if (claims.scopes?.contains("offline_access") != true) return@also
put("refresh_token", tokenService.encodeJwt(context.issuer, claims))
}

idClaims.takeIf { it.claimsMap.isNotEmpty() }?.also { claims ->
if (claims.scopes?.contains("openid") != true) return@also

refreshClaims.setClaim("id_claims", context.idClaims.claimsMap)

put("id_token", tokenService.encodeJwt(context.issuer, claims))
}

refreshClaims.takeIf { it.claimsMap.isNotEmpty() }?.also { claims ->
if (claims.scopes?.contains("offline_access") != true) return@also

refreshClaims.setClaim("refresh_claims", context.refreshClaims.claimsMap)

put("refresh_token", tokenService.encodeJwt(context.issuer, claims))
}

}
} catch (ex: Exception) {
logger.error("Error processing token request", ex)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package de.solugo.oauthmock.token.grant
import de.solugo.oauthmock.service.TokenService
import de.solugo.oauthmock.token.*
import de.solugo.oauthmock.util.clientId
import de.solugo.oauthmock.util.put
import de.solugo.oauthmock.util.scopes
import org.springframework.stereotype.Component

Expand All @@ -23,20 +24,31 @@ class RefreshTokenGrant(
description = "Request is missing refresh_token parameter",
)

val refreshContext = tokenService.decodeJwt(refreshToken)
val refreshTokenClaims = tokenService.decodeJwt(refreshToken).jwtClaims

if (refreshContext.jwtClaims.clientId != client.id) throw TokenException(
if (refreshTokenClaims.clientId != client.id) throw TokenException(
error = TokenError.AccessDenied,
description = "Client is not allowed to use this refresh token",
)

val refreshScopes = refreshContext.jwtClaims.scopes ?: emptySet()
val refreshScopes = refreshTokenClaims.scopes ?: emptySet()

context.scopes = context.scopes?.filter { refreshScopes.contains(it) }?.toSet() ?: refreshScopes
(refreshTokenClaims.getClaimValue("common_claims") as? Map<*, *>)?.also { claims ->
context.commonClaims.put(claims)
}

(refreshTokenClaims.getClaimValue("refresh_claims") as? Map<*, *>)?.also { claims ->
context.refreshClaims.put(claims)
}

context.commonClaims.apply {
refreshContext.jwtClaims.claimsMap.forEach { (key, value) -> setClaim(key, value) }
(refreshTokenClaims.getClaimValue("access_claims") as? Map<*, *>)?.also { claims ->
context.accessClaims.put(claims)
}

(refreshTokenClaims.getClaimValue("id_claims") as? Map<*, *>)?.also { claims ->
context.idClaims.put(claims)
}

context.scopes = context.scopes?.filter { refreshScopes.contains(it) }?.toSet() ?: refreshScopes
}
}
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
package de.solugo.oauthmock.token.processor

import de.solugo.oauthmock.token.*
import de.solugo.oauthmock.token.TokenContext
import de.solugo.oauthmock.token.TokenProcessor
import de.solugo.oauthmock.token.commonClaims
import de.solugo.oauthmock.token.user
import de.solugo.oauthmock.util.preferredUsername
import org.springframework.stereotype.Component

@Component
class UserClaimsProcessor : TokenProcessor {

override val step = TokenProcessor.Step.CLAIMS
override val step = TokenProcessor.Step.CLAIMS

override suspend fun process(context: TokenContext) {
val user = context.user ?: return

context.commonClaims.apply {
subject = user.id
preferredUsername = user.username
subject = subject ?: user.id
preferredUsername = preferredUsername ?: user.username
}
}

Expand Down
12 changes: 12 additions & 0 deletions src/main/kotlin/de/solugo/oauthmock/util/ClaimsUtil.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,18 @@ private const val CLAIM_SESSION_ID = "sid"
private const val CLAIM_CLIENT_ID = "client_id"
private const val CLAIM_PREFERRED_USERNAME = "preferred_username"

fun JwtClaims.put(claims: JwtClaims) {
put(claims.claimsMap)
}

fun JwtClaims.put(claims: Map<*, *>) {
claims.forEach { (key, value) ->
if (key !is String) throw RuntimeException("Key $key is not a string")
setClaim(key, value)
}
}


var JwtClaims.scopes: Set<String>?
get() = run {
getStringClaimValue(CLAIM_SCOPE)?.split(" ")?.toSet()
Expand Down
121 changes: 109 additions & 12 deletions src/test/kotlin/de/solugo/oauthmock/controller/TokenControllerTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,22 @@ package de.solugo.oauthmock.controller

import IntegrationTest
import com.fasterxml.jackson.databind.node.ObjectNode
import io.kotest.matchers.longs.beGreaterThan
import io.kotest.matchers.should
import de.solugo.oauthmock.util.clientId
import de.solugo.oauthmock.util.scopes
import io.kotest.matchers.shouldBe
import io.kotest.matchers.shouldNotBe
import io.ktor.client.call.*
import io.ktor.client.request.*
import io.ktor.client.request.forms.*
import io.ktor.http.*
import kotlinx.coroutines.test.runTest
import org.jose4j.jwt.consumer.JwtConsumerBuilder
import org.junit.jupiter.api.Test

class TokenControllerTest : IntegrationTest() {

private val consumer = JwtConsumerBuilder().setSkipSignatureVerification().setSkipAllValidators().build()

@Test
fun `Get openid configuration`() = runTest {
rest.get(".well-known/openid-configuration").apply {
Expand All @@ -25,6 +28,28 @@ class TokenControllerTest : IntegrationTest() {
}
}

@Test
fun `Create only access token using password grant`() = runTest {
val parameters = parametersOf(
"grant_type" to listOf("password"),
"client_id" to listOf("client_test"),
"username" to listOf("test"),
)

rest.post("token") {
setBody(FormDataContent(parameters))
}.apply {
status shouldBe HttpStatusCode.OK
body<ObjectNode>().apply {
at("/token_type").textValue() shouldBe "Bearer"
at("/access_token").textValue() shouldNotBe null

at("/id_token").textValue() shouldBe null
at("/refresh_token").textValue() shouldBe null
}
}
}


@Test
fun `Create token using password grant`() = runTest {
Expand All @@ -33,6 +58,10 @@ class TokenControllerTest : IntegrationTest() {
"client_id" to listOf("client_test"),
"username" to listOf("test"),
"scope" to listOf("openid offline_access"),
"claims" to listOf("""{"qcc": "common"}"""),
"idClaims" to listOf("""{"qic": "id"}"""),
"accessClaims" to listOf("""{"qac": "access"}"""),
"refreshClaims" to listOf("""{"qrc": "refresh"}"""),
)

rest.post("token") {
Expand All @@ -41,10 +70,36 @@ class TokenControllerTest : IntegrationTest() {
status shouldBe HttpStatusCode.OK
body<ObjectNode>().apply {
at("/token_type").textValue() shouldBe "Bearer"
at("/expires_in").longValue() should beGreaterThan(3500L)
at("/access_token").textValue() shouldNotBe null
at("/id_token").textValue() shouldNotBe null
at("/refresh_token").textValue() shouldNotBe null
at("/access_token").textValue().also { token ->
val claims = consumer.processToClaims(token)
claims.subject shouldNotBe null
claims.clientId shouldBe "client_test"
claims.scopes shouldBe setOf("openid", "offline_access")
claims.getClaimValueAsString("qcc") shouldBe "common"
claims.getClaimValueAsString("qac") shouldBe "access"
claims.hasClaim("qic") shouldBe false
claims.hasClaim("qrc") shouldBe false
}
at("/id_token").textValue().also { token ->
val claims = consumer.processToClaims(token)
claims.subject shouldNotBe null
claims.clientId shouldBe "client_test"
claims.scopes shouldBe setOf("openid", "offline_access")
claims.getClaimValueAsString("qcc") shouldBe "common"
claims.getClaimValueAsString("qic") shouldBe "id"
claims.hasClaim("qac") shouldBe false
claims.hasClaim("qrc") shouldBe false
}
at("/refresh_token").textValue().also { token ->
val claims = consumer.processToClaims(token)
claims.subject shouldNotBe null
claims.clientId shouldBe "client_test"
claims.scopes shouldBe setOf("openid", "offline_access")
claims.getClaimValueAsString("qcc") shouldBe "common"
claims.getClaimValueAsString("qrc") shouldBe "refresh"
claims.hasClaim("qac") shouldBe false
claims.hasClaim("qic") shouldBe false
}
}
}
}
Expand All @@ -55,6 +110,10 @@ class TokenControllerTest : IntegrationTest() {
"grant_type" to listOf("client_credentials"),
"client_id" to listOf("client_test"),
"scope" to listOf("custom"),
"claims" to listOf("""{"qcc": "common"}"""),
"idClaims" to listOf("""{"qic": "id"}"""),
"accessClaims" to listOf("""{"qac": "access"}"""),
"refreshClaims" to listOf("""{"qrc": "refresh"}"""),
)

rest.post("token") {
Expand All @@ -63,8 +122,16 @@ class TokenControllerTest : IntegrationTest() {
status shouldBe HttpStatusCode.OK
body<ObjectNode>().apply {
at("/token_type").textValue() shouldBe "Bearer"
at("/expires_in").numberValue() shouldBe 3600
at("/access_token").textValue() shouldNotBe null
at("/access_token").textValue().also { token ->
val claims = consumer.processToClaims(token)
claims.subject shouldNotBe null
claims.clientId shouldBe "client_test"
claims.scopes shouldBe setOf("custom")
claims.getClaimValueAsString("qcc") shouldBe "common"
claims.getClaimValueAsString("qac") shouldBe "access"
claims.hasClaim("qic") shouldBe false
claims.hasClaim("qrc") shouldBe false
}
at("/id_token").textValue() shouldBe null
at("/refresh_token").textValue() shouldBe null
}
Expand All @@ -78,6 +145,10 @@ class TokenControllerTest : IntegrationTest() {
"client_id" to listOf("client_test"),
"username" to listOf("test"),
"scope" to listOf("openid offline_access"),
"claims" to listOf("""{"qcc": "common"}"""),
"idClaims" to listOf("""{"qic": "id"}"""),
"accessClaims" to listOf("""{"qac": "access"}"""),
"refreshClaims" to listOf("""{"qrc": "refresh"}"""),
)

val refreshToken = rest.post("token") {
Expand All @@ -98,10 +169,36 @@ class TokenControllerTest : IntegrationTest() {
status shouldBe HttpStatusCode.OK
body<ObjectNode>().apply {
at("/token_type").textValue() shouldBe "Bearer"
at("/expires_in").numberValue() shouldBe 3600
at("/access_token").textValue() shouldNotBe null
at("/id_token").textValue() shouldNotBe null
at("/refresh_token").textValue() shouldNotBe null
at("/access_token").textValue().also { token ->
val claims = consumer.processToClaims(token)
claims.subject shouldNotBe null
claims.clientId shouldBe "client_test"
claims.scopes shouldBe setOf("openid", "offline_access")
claims.getClaimValueAsString("qcc") shouldBe "common"
claims.getClaimValueAsString("qac") shouldBe "access"
claims.hasClaim("qic") shouldBe false
claims.hasClaim("qrc") shouldBe false
}
at("/id_token").textValue().also { token ->
val claims = consumer.processToClaims(token)
claims.subject shouldNotBe null
claims.clientId shouldBe "client_test"
claims.scopes shouldBe setOf("openid", "offline_access")
claims.getClaimValueAsString("qcc") shouldBe "common"
claims.getClaimValueAsString("qic") shouldBe "id"
claims.hasClaim("qac") shouldBe false
claims.hasClaim("qrc") shouldBe false
}
at("/refresh_token").textValue().also { token ->
val claims = consumer.processToClaims(token)
claims.subject shouldNotBe null
claims.clientId shouldBe "client_test"
claims.scopes shouldBe setOf("openid", "offline_access")
claims.getClaimValueAsString("qcc") shouldBe "common"
claims.getClaimValueAsString("qrc") shouldBe "refresh"
claims.hasClaim("qac") shouldBe false
claims.hasClaim("qic") shouldBe false
}
}
}
}
Expand Down

0 comments on commit dcb1710

Please sign in to comment.