Skip to content

Commit

Permalink
Moves request claim processing to processor and adds support for simp…
Browse files Browse the repository at this point in the history
…le claims
  • Loading branch information
fkneier-bikeleasing committed Sep 20, 2024
1 parent dcb1710 commit daddfa6
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 19 deletions.
22 changes: 4 additions & 18 deletions src/main/kotlin/de/solugo/oauthmock/controller/TokenController.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ import de.solugo.oauthmock.ConfigurationProvider
import de.solugo.oauthmock.ServerProperties
import de.solugo.oauthmock.service.TokenService
import de.solugo.oauthmock.token.*
import de.solugo.oauthmock.util.*
import de.solugo.oauthmock.util.plus
import de.solugo.oauthmock.util.scopes
import de.solugo.oauthmock.util.sessionId
import de.solugo.oauthmock.util.uuid
import kotlinx.coroutines.reactor.awaitSingle
import org.jose4j.jwt.JwtClaims
import org.slf4j.LoggerFactory
import org.springframework.http.MediaType
import org.springframework.util.MultiValueMap
Expand Down Expand Up @@ -75,22 +77,6 @@ class TokenController(
scopes = parameters.getFirst("scope")?.split(" ")?.toSet(),
)

parameters["claims"]?.forEach {
context.commonClaims.put(JwtClaims.parse(it))
}

parameters["idClaims"]?.forEach {
context.idClaims.put(JwtClaims.parse(it))
}

parameters["accessClaims"]?.forEach {
context.accessClaims.put(JwtClaims.parse(it))
}

parameters["refreshClaims"]?.forEach {
context.refreshClaims.put(JwtClaims.parse(it))
}

tokenProcessors.process(TokenProcessor.Step.PRE, context)

grant.process(context)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package de.solugo.oauthmock.token.processor

import de.solugo.oauthmock.token.*
import de.solugo.oauthmock.util.put
import de.solugo.oauthmock.util.removePrefixOrNull
import org.jose4j.jwt.JwtClaims
import org.springframework.stereotype.Component

@Component
class RequestClaimsProcessor : TokenProcessor {

override val step = TokenProcessor.Step.CLAIMS

override suspend fun process(context: TokenContext) {
context.parameters["claims"]?.forEach {
context.commonClaims.put(JwtClaims.parse(it))
}
context.parameters["idClaims"]?.forEach {
context.idClaims.put(JwtClaims.parse(it))
}
context.parameters["accessClaims"]?.forEach {
context.accessClaims.put(JwtClaims.parse(it))
}
context.parameters["refreshClaims"]?.forEach {
context.refreshClaims.put(JwtClaims.parse(it))
}
context.parameters.entries.forEach { (key, values) ->
key.removePrefixOrNull("claim_")?.also {
context.commonClaims.setClaim(it, values.last())
}
key.removePrefixOrNull("accessClaim_")?.also {
context.accessClaims.setClaim(it, values.last())
}
key.removePrefixOrNull("idClaim_")?.also {
context.idClaims.setClaim(it, values.last())
}
key.removePrefixOrNull("refreshClaim_")?.also {
context.refreshClaims.setClaim(it, values.last())
}
}
}

}
6 changes: 6 additions & 0 deletions src/main/kotlin/de/solugo/oauthmock/util/Util.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,9 @@ fun uri(uri: String, block: UriComponentsBuilder.() -> Unit) = UriComponentsBuil
block()
toUriString()
}


fun String.removePrefixOrNull(prefix: String) = when {
startsWith(prefix) -> substring(prefix.length)
else -> null
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ class TokenControllerTest : IntegrationTest() {
body<ObjectNode>().apply {
at("/token_type").textValue() shouldBe "Bearer"
at("/access_token").textValue() shouldNotBe null

at("/id_token").textValue() shouldBe null
at("/refresh_token").textValue() shouldBe null
}
Expand All @@ -58,6 +57,10 @@ class TokenControllerTest : IntegrationTest() {
"client_id" to listOf("client_test"),
"username" to listOf("test"),
"scope" to listOf("openid offline_access"),
"claim_qcc_simple" to listOf("common"),
"idClaim_qic_simple" to listOf("id"),
"accessClaim_qac_simple" to listOf("access"),
"refreshClaim_qrc_simple" to listOf("refresh"),
"claims" to listOf("""{"qcc": "common"}"""),
"idClaims" to listOf("""{"qic": "id"}"""),
"accessClaims" to listOf("""{"qac": "access"}"""),
Expand All @@ -76,29 +79,41 @@ class TokenControllerTest : IntegrationTest() {
claims.clientId shouldBe "client_test"
claims.scopes shouldBe setOf("openid", "offline_access")
claims.getClaimValueAsString("qcc") shouldBe "common"
claims.getClaimValueAsString("qcc_simple") shouldBe "common"
claims.getClaimValueAsString("qac") shouldBe "access"
claims.getClaimValueAsString("qac_simple") shouldBe "access"
claims.hasClaim("qic") shouldBe false
claims.hasClaim("qic_simple") shouldBe false
claims.hasClaim("qrc") shouldBe false
claims.hasClaim("qrc_simple") shouldBe false
}
at("/id_token").textValue().also { token ->
val claims = consumer.processToClaims(token)
claims.subject shouldNotBe null
claims.clientId shouldBe "client_test"
claims.scopes shouldBe setOf("openid", "offline_access")
claims.getClaimValueAsString("qcc") shouldBe "common"
claims.getClaimValueAsString("qcc_simple") shouldBe "common"
claims.getClaimValueAsString("qic") shouldBe "id"
claims.getClaimValueAsString("qic_simple") shouldBe "id"
claims.hasClaim("qac") shouldBe false
claims.hasClaim("qac_simple") shouldBe false
claims.hasClaim("qrc") shouldBe false
claims.hasClaim("qrc_simple") shouldBe false
}
at("/refresh_token").textValue().also { token ->
val claims = consumer.processToClaims(token)
claims.subject shouldNotBe null
claims.clientId shouldBe "client_test"
claims.scopes shouldBe setOf("openid", "offline_access")
claims.getClaimValueAsString("qcc") shouldBe "common"
claims.getClaimValueAsString("qcc_simple") shouldBe "common"
claims.getClaimValueAsString("qrc") shouldBe "refresh"
claims.getClaimValueAsString("qrc_simple") shouldBe "refresh"
claims.hasClaim("qac") shouldBe false
claims.hasClaim("qac_simple") shouldBe false
claims.hasClaim("qic") shouldBe false
claims.hasClaim("qic_simple") shouldBe false
}
}
}
Expand All @@ -110,6 +125,10 @@ class TokenControllerTest : IntegrationTest() {
"grant_type" to listOf("client_credentials"),
"client_id" to listOf("client_test"),
"scope" to listOf("custom"),
"claim_qcc_simple" to listOf("common"),
"idClaim_qic_simple" to listOf("id"),
"accessClaim_qac_simple" to listOf("access"),
"refreshClaim_qrc_simple" to listOf("refresh"),
"claims" to listOf("""{"qcc": "common"}"""),
"idClaims" to listOf("""{"qic": "id"}"""),
"accessClaims" to listOf("""{"qac": "access"}"""),
Expand All @@ -128,9 +147,13 @@ class TokenControllerTest : IntegrationTest() {
claims.clientId shouldBe "client_test"
claims.scopes shouldBe setOf("custom")
claims.getClaimValueAsString("qcc") shouldBe "common"
claims.getClaimValueAsString("qcc_simple") shouldBe "common"
claims.getClaimValueAsString("qac") shouldBe "access"
claims.getClaimValueAsString("qac_simple") shouldBe "access"
claims.hasClaim("qic") shouldBe false
claims.hasClaim("qic_simple") shouldBe false
claims.hasClaim("qrc") shouldBe false
claims.hasClaim("qrc_simple") shouldBe false
}
at("/id_token").textValue() shouldBe null
at("/refresh_token").textValue() shouldBe null
Expand All @@ -145,6 +168,10 @@ class TokenControllerTest : IntegrationTest() {
"client_id" to listOf("client_test"),
"username" to listOf("test"),
"scope" to listOf("openid offline_access"),
"claim_qcc_simple" to listOf("common"),
"idClaim_qic_simple" to listOf("id"),
"accessClaim_qac_simple" to listOf("access"),
"refreshClaim_qrc_simple" to listOf("refresh"),
"claims" to listOf("""{"qcc": "common"}"""),
"idClaims" to listOf("""{"qic": "id"}"""),
"accessClaims" to listOf("""{"qac": "access"}"""),
Expand Down Expand Up @@ -175,29 +202,41 @@ class TokenControllerTest : IntegrationTest() {
claims.clientId shouldBe "client_test"
claims.scopes shouldBe setOf("openid", "offline_access")
claims.getClaimValueAsString("qcc") shouldBe "common"
claims.getClaimValueAsString("qcc_simple") shouldBe "common"
claims.getClaimValueAsString("qac") shouldBe "access"
claims.getClaimValueAsString("qac_simple") shouldBe "access"
claims.hasClaim("qic") shouldBe false
claims.hasClaim("qic_simple") shouldBe false
claims.hasClaim("qrc") shouldBe false
claims.hasClaim("qrc_simple") shouldBe false
}
at("/id_token").textValue().also { token ->
val claims = consumer.processToClaims(token)
claims.subject shouldNotBe null
claims.clientId shouldBe "client_test"
claims.scopes shouldBe setOf("openid", "offline_access")
claims.getClaimValueAsString("qcc") shouldBe "common"
claims.getClaimValueAsString("qcc_simple") shouldBe "common"
claims.getClaimValueAsString("qic") shouldBe "id"
claims.getClaimValueAsString("qic_simple") shouldBe "id"
claims.hasClaim("qac") shouldBe false
claims.hasClaim("qac_simple") shouldBe false
claims.hasClaim("qrc") shouldBe false
claims.hasClaim("qrc_simple") shouldBe false
}
at("/refresh_token").textValue().also { token ->
val claims = consumer.processToClaims(token)
claims.subject shouldNotBe null
claims.clientId shouldBe "client_test"
claims.scopes shouldBe setOf("openid", "offline_access")
claims.getClaimValueAsString("qcc") shouldBe "common"
claims.getClaimValueAsString("qcc_simple") shouldBe "common"
claims.getClaimValueAsString("qrc") shouldBe "refresh"
claims.getClaimValueAsString("qrc_simple") shouldBe "refresh"
claims.hasClaim("qac") shouldBe false
claims.hasClaim("qac_simple") shouldBe false
claims.hasClaim("qic") shouldBe false
claims.hasClaim("qic_simple") shouldBe false
}
}
}
Expand Down

0 comments on commit daddfa6

Please sign in to comment.