Skip to content

feat: support custom TimeProvider when validating tokens (introspect, userinfo) #730

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -3,15 +3,11 @@ package no.nav.security.mock.oauth2.introspect
import com.fasterxml.jackson.annotation.JsonInclude
import com.fasterxml.jackson.annotation.JsonProperty
import com.nimbusds.jwt.JWTClaimsSet
import com.nimbusds.jwt.SignedJWT
import com.nimbusds.oauth2.sdk.OAuth2Error
import com.nimbusds.oauth2.sdk.id.Issuer
import mu.KotlinLogging
import no.nav.security.mock.oauth2.OAuth2Exception
import no.nav.security.mock.oauth2.extensions.OAuth2Endpoints.INTROSPECT
import no.nav.security.mock.oauth2.extensions.issuerId
import no.nav.security.mock.oauth2.extensions.toIssuerUrl
import no.nav.security.mock.oauth2.extensions.verifySignatureAndIssuer
import no.nav.security.mock.oauth2.http.OAuth2HttpRequest
import no.nav.security.mock.oauth2.http.Route
import no.nav.security.mock.oauth2.http.json
@@ -51,12 +47,10 @@ internal fun Route.Builder.introspect(tokenProvider: OAuth2TokenProvider) =
}

private fun OAuth2HttpRequest.verifyToken(tokenProvider: OAuth2TokenProvider): JWTClaimsSet? {
val tokenString = this.formParameters.get("token")
val issuer = url.toIssuerUrl()
val jwkSet = tokenProvider.publicJwkSet(issuer.issuerId())
val algorithm = tokenProvider.getAlgorithm()
return try {
SignedJWT.parse(tokenString).verifySignatureAndIssuer(Issuer(issuer.toString()), jwkSet, algorithm)
this.formParameters.get("token")?.let {
tokenProvider.verify(url.toIssuerUrl(), it)
}
} catch (e: Exception) {
log.debug("token_introspection: failed signature validation")
return null
Original file line number Diff line number Diff line change
@@ -3,9 +3,12 @@ package no.nav.security.mock.oauth2.token
import com.nimbusds.jose.JWSAlgorithm
import com.nimbusds.jose.jwk.ECKey
import com.nimbusds.jose.jwk.JWK
import com.nimbusds.jose.jwk.JWKSelector
import com.nimbusds.jose.jwk.JWKSet
import com.nimbusds.jose.jwk.KeyType
import com.nimbusds.jose.jwk.RSAKey
import com.nimbusds.jose.jwk.source.JWKSource
import com.nimbusds.jose.proc.SecurityContext
import no.nav.security.mock.oauth2.OAuth2Exception
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.LinkedBlockingDeque
@@ -15,7 +18,7 @@ open class KeyProvider
constructor(
private val initialKeys: List<JWK> = keysFromFile(INITIAL_KEYS_FILE),
private val algorithm: String = JWSAlgorithm.RS256.name,
) {
) : JWKSource<SecurityContext> {
private val signingKeys: ConcurrentHashMap<String, JWK> = ConcurrentHashMap()

private var generator: KeyGenerator = KeyGenerator(JWSAlgorithm.parse(algorithm))
@@ -35,9 +38,11 @@ open class KeyProvider
KeyType.RSA.value -> {
RSAKey.Builder(polledJwk.toRSAKey()).keyID(keyId).build()
}

KeyType.EC.value -> {
ECKey.Builder(polledJwk.toECKey()).keyID(keyId).build()
}

else -> {
throw OAuth2Exception("Unsupported key type: ${polledJwk.keyType.value}")
}
@@ -63,4 +68,10 @@ open class KeyProvider
return emptyList()
}
}

override fun get(
jwkSelector: JWKSelector?,
context: SecurityContext?,
): MutableList<JWK> = jwkSelector?.select(JWKSet(signingKeys.values.toList()).toPublicJWKSet()) ?: mutableListOf()

}
Original file line number Diff line number Diff line change
@@ -7,8 +7,13 @@ import com.nimbusds.jose.crypto.ECDSASigner
import com.nimbusds.jose.crypto.RSASSASigner
import com.nimbusds.jose.jwk.JWKSet
import com.nimbusds.jose.jwk.KeyType
import com.nimbusds.jose.proc.DefaultJOSEObjectTypeVerifier
import com.nimbusds.jose.proc.JWSVerificationKeySelector
import com.nimbusds.jose.proc.SecurityContext
import com.nimbusds.jwt.JWTClaimsSet
import com.nimbusds.jwt.SignedJWT
import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier
import com.nimbusds.jwt.proc.DefaultJWTProcessor
import com.nimbusds.oauth2.sdk.TokenRequest
import no.nav.security.mock.oauth2.OAuth2Exception
import no.nav.security.mock.oauth2.extensions.clientIdAsString
@@ -106,6 +111,11 @@ class OAuth2TokenProvider
builder.build()
}.sign(issuerId, JOSEObjectType.JWT.type)

fun verify(
issuerUrl: HttpUrl,
token: String,
): JWTClaimsSet = SignedJWT.parse(token).verify(issuerUrl)

private fun JWTClaimsSet.sign(
issuerId: String,
type: String,
@@ -124,6 +134,7 @@ class OAuth2TokenProvider
sign(RSASSASigner(key.toRSAKey().toPrivateKey()))
}
}

supported && keyType == KeyType.EC.value -> {
SignedJWT(
jwsHeader(key.keyID, type, algorithm),
@@ -132,6 +143,7 @@ class OAuth2TokenProvider
sign(ECDSASigner(key.toECKey().toECPrivateKey()))
}
}

else -> {
throw OAuth2Exception("Unsupported algorithm: ${algorithm.name}")
}
@@ -178,4 +190,20 @@ class OAuth2TokenProvider
}

private fun Instant?.orNow(): Instant = this ?: Instant.now()

private fun SignedJWT.verify(issuerUrl: HttpUrl): JWTClaimsSet {
val jwtProcessor =
DefaultJWTProcessor<SecurityContext?>().apply {
jwsTypeVerifier = DefaultJOSEObjectTypeVerifier(JOSEObjectType("JWT"))
jwsKeySelector = JWSVerificationKeySelector(keyProvider.algorithm(), keyProvider)
jwtClaimsSetVerifier =
object : DefaultJWTClaimsVerifier<SecurityContext?>(
JWTClaimsSet.Builder().issuer(issuerUrl.toString()).build(),
HashSet(listOf("iat", "exp")),
) {
override fun currentTime(): Date = Date.from(timeProvider().orNow())
}
}
return jwtProcessor.process(this, null)
}
}
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
package no.nav.security.mock.oauth2.userinfo

import com.nimbusds.jwt.JWTClaimsSet
import com.nimbusds.jwt.SignedJWT
import com.nimbusds.oauth2.sdk.ErrorObject
import com.nimbusds.oauth2.sdk.http.HTTPResponse
import com.nimbusds.oauth2.sdk.id.Issuer
import mu.KotlinLogging
import no.nav.security.mock.oauth2.OAuth2Exception
import no.nav.security.mock.oauth2.extensions.OAuth2Endpoints.USER_INFO
import no.nav.security.mock.oauth2.extensions.issuerId
import no.nav.security.mock.oauth2.extensions.toIssuerUrl
import no.nav.security.mock.oauth2.extensions.verifySignatureAndIssuer
import no.nav.security.mock.oauth2.http.OAuth2HttpRequest
import no.nav.security.mock.oauth2.http.Route
import no.nav.security.mock.oauth2.http.json
@@ -26,17 +22,12 @@ internal fun Route.Builder.userInfo(tokenProvider: OAuth2TokenProvider) =
json(claims)
}

private fun OAuth2HttpRequest.verifyBearerToken(tokenProvider: OAuth2TokenProvider): JWTClaimsSet {
val tokenString = this.headers.bearerToken()
val issuer = url.toIssuerUrl()
val jwkSet = tokenProvider.publicJwkSet(issuer.issuerId())
val algorithm = tokenProvider.getAlgorithm()
return try {
SignedJWT.parse(tokenString).verifySignatureAndIssuer(Issuer(issuer.toString()), jwkSet, algorithm)
private fun OAuth2HttpRequest.verifyBearerToken(tokenProvider: OAuth2TokenProvider): JWTClaimsSet =
try {
tokenProvider.verify(url.toIssuerUrl(), this.headers.bearerToken())
} catch (e: Exception) {
throw invalidToken(e.message ?: "could not verify bearer token")
}
}

private fun Headers.bearerToken(): String =
this["Authorization"]
Original file line number Diff line number Diff line change
@@ -19,6 +19,8 @@ import no.nav.security.mock.oauth2.token.OAuth2TokenProvider
import okhttp3.Headers
import okhttp3.HttpUrl.Companion.toHttpUrl
import org.junit.jupiter.api.Test
import java.time.Instant
import java.time.temporal.ChronoUnit

internal class IntrospectTest {
private val rs384TokenProvider = OAuth2TokenProvider(keyProvider = KeyProvider(initialKeys = emptyList(), algorithm = JWSAlgorithm.RS384.name))
@@ -66,6 +68,29 @@ internal class IntrospectTest {
}
}

@Test
fun `introspect should return active and claims from token when using a custom timeProvider in the OAuth2TokenProvider`() {
val issuerUrl = "http://localhost/default"
val yesterday = Instant.now().minus(1, ChronoUnit.DAYS)
val tokenProvider = OAuth2TokenProvider(timeProvider = { yesterday })
val claims =
mapOf(
"iss" to issuerUrl,
"client_id" to "yolo",
"token_type" to "token",
"sub" to "foo",
)
val token = tokenProvider.jwt(claims)
val request = request("$issuerUrl$INTROSPECT", token.serialize())

routes { introspect(tokenProvider) }.invoke(request).asClue {
it.status shouldBe 200
val response = it.parse<Map<String, Any>>()
response shouldContainAll claims
response shouldContain ("active" to true)
}
}

@Test
fun `introspect should return active false when token is missing`() {
val url = "http://localhost/default$INTROSPECT"
Original file line number Diff line number Diff line change
@@ -16,9 +16,7 @@ import okhttp3.HttpUrl.Companion.toHttpUrl
import org.junit.jupiter.api.Test
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ValueSource
import java.time.Clock
import java.time.Instant
import java.time.ZoneId
import java.time.temporal.ChronoUnit
import java.util.Date

@@ -106,87 +104,71 @@ internal class OAuth2TokenProviderRSATest {
val yesterday = Instant.now().minus(1, ChronoUnit.DAYS)
val tokenProvider = OAuth2TokenProvider(systemTime = yesterday)

tokenProvider
.exchangeAccessToken(
tokenRequest =
nimbusTokenRequest(
"id",
"grant_type" to GrantType.CLIENT_CREDENTIALS.value,
"scope" to "scope1",
),
issuerUrl = "http://default_if_not_overridden".toHttpUrl(),
claimsSet = tokenProvider.jwt(mapOf()).jwtClaimsSet,
oAuth2TokenCallback = DefaultOAuth2TokenCallback(),
).asClue {
it.jwtClaimsSet.issueTime shouldBe Date.from(tokenProvider.systemTime)
println(it.serialize())
}
tokenProvider.clientCredentialsToken("http://localhost/default").asClue {
it.jwtClaimsSet.issueTime shouldBe Date.from(tokenProvider.systemTime)
}

val now = Instant.now().minus(1, ChronoUnit.SECONDS)
OAuth2TokenProvider().clientCredentialsToken("http://localhost/default").asClue {
it.jwtClaimsSet.issueTime shouldBeAfter now
}
}

@Test
fun `token should have issuedAt set dynamically according to timeProvider`() {
val clock =
object : Clock() {
private var clock = systemDefaultZone()
val timeProvider =
object : TimeProvider {
var time = Instant.now()

override fun instant() = clock.instant()

override fun withZone(zone: ZoneId) = clock.withZone(zone)

override fun getZone() = clock.zone

fun fixed(instant: Instant) {
clock = fixed(instant, zone)
}
override fun invoke(): Instant = time
}

val tokenProvider = OAuth2TokenProvider { clock.instant() }
val tokenProvider = OAuth2TokenProvider(timeProvider = timeProvider)

val instant1 = Instant.parse("2000-12-03T10:15:30.00Z")
val instant2 = Instant.parse("2020-01-21T00:00:00.00Z")
instant1 shouldNotBe instant2

run {
clock.fixed(instant1)
tokenProvider.systemTime shouldBe instant1
timeProvider.time = instant1
tokenProvider.systemTime shouldBe instant1

tokenProvider.exchangeAccessToken(
tokenRequest =
nimbusTokenRequest(
"id",
"grant_type" to GrantType.CLIENT_CREDENTIALS.value,
"scope" to "scope1",
),
issuerUrl = "http://default_if_not_overridden".toHttpUrl(),
claimsSet = tokenProvider.jwt(mapOf()).jwtClaimsSet,
oAuth2TokenCallback = DefaultOAuth2TokenCallback(),
)
}.asClue {
tokenProvider.clientCredentialsToken("http://localhost/default").asClue {
it.jwtClaimsSet.issueTime shouldBe Date.from(instant1)
println(it.serialize())
}

run {
clock.fixed(instant2)
tokenProvider.systemTime shouldBe instant2
timeProvider.time = instant2
tokenProvider.systemTime shouldBe instant2

tokenProvider.exchangeAccessToken(
tokenRequest =
nimbusTokenRequest(
"id",
"grant_type" to GrantType.CLIENT_CREDENTIALS.value,
"scope" to "scope1",
),
issuerUrl = "http://default_if_not_overridden".toHttpUrl(),
claimsSet = tokenProvider.jwt(mapOf()).jwtClaimsSet,
oAuth2TokenCallback = DefaultOAuth2TokenCallback(),
)
}.asClue {
tokenProvider.clientCredentialsToken("http://localhost/default").asClue {
it.jwtClaimsSet.issueTime shouldBe Date.from(instant2)
println(it.serialize())
}
}

@Test
fun `token with issueTime set to yesterday should be able to validate with the verify function using the same timeprovider`() {
val yesterday = Instant.now().minus(1, ChronoUnit.DAYS)
val tokenProvider = OAuth2TokenProvider(timeProvider = { yesterday })

val token = tokenProvider.clientCredentialsToken("http://localhost/default")

token.jwtClaimsSet.issueTime shouldBe Date.from(tokenProvider.systemTime)

tokenProvider.verify("http://localhost/default".toHttpUrl(), token.serialize()).toJSONObject().asClue {
it shouldBe token.jwtClaimsSet.toJSONObject()
}
}

private fun OAuth2TokenProvider.clientCredentialsToken(issuerUrl: String): SignedJWT =
accessToken(
tokenRequest =
nimbusTokenRequest(
"client1",
"grant_type" to "client_credentials",
"scope" to "scope1",
),
issuerUrl = issuerUrl.toHttpUrl(),
oAuth2TokenCallback = DefaultOAuth2TokenCallback(),
)

private fun idToken(issuerUrl: String): SignedJWT =
tokenProvider.idToken(
tokenRequest =
@@ -198,4 +180,6 @@ internal class OAuth2TokenProviderRSATest {
issuerUrl = issuerUrl.toHttpUrl(),
oAuth2TokenCallback = DefaultOAuth2TokenCallback(),
)

private infix fun Date.shouldBeAfter(instant: Instant?) = this.after(Date.from(instant)) shouldBe true
}
Original file line number Diff line number Diff line change
@@ -17,6 +17,8 @@ import no.nav.security.mock.oauth2.token.OAuth2TokenProvider
import okhttp3.Headers
import okhttp3.HttpUrl.Companion.toHttpUrl
import org.junit.jupiter.api.Test
import java.time.Instant
import java.time.temporal.ChronoUnit

internal class UserInfoTest {
@Test
@@ -38,6 +40,26 @@ internal class UserInfoTest {
}
}

@Test
fun `userinfo should return claims from bearer token when using a custom timeProvider in OAuth2TokenProvider`() {
val issuerUrl = "http://localhost/default"
val yesterday = Instant.now().minus(1, ChronoUnit.DAYS)
val tokenProvider = OAuth2TokenProvider(timeProvider = { yesterday })
val claims =
mapOf(
"iss" to issuerUrl,
"sub" to "foo",
"extra" to "bar",
)
val bearerToken = tokenProvider.jwt(claims)
val request = request("$issuerUrl$USER_INFO", bearerToken.serialize())

routes { userInfo(tokenProvider) }.invoke(request).asClue {
it.status shouldBe 200
it.parse<Map<String, Any>>() shouldContainAll claims
}
}

@Test
fun `userinfo should throw OAuth2Exception when algorithm does not match`() {
val issuerUrl = "http://localhost/default"