Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor request token to be a JWT #125

Merged
merged 15 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jiyoontbd thoughts on doing the same as this here?

Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package tbdex.sdk.httpclient

/**
* Request token verification exception
*
* @param cause the underlying exception
* @param message the exception message detailing the error
*/
class RequestTokenVerificationException(cause: Throwable, message: String? = null)
: RuntimeException(message, cause)

/**
* Request token audience mismatch exception
*
* @param message the exception message detailing the error
*/
class RequestTokenAudMismatchException(message: String? = null)
: RuntimeException(message)

/**
* Request token missing claims exception
*
* @param message the exception message detailing the error
*/
class RequestTokenMissingClaimsException(message: String? = null)
: RuntimeException(message)

/**
* Request token expired exception
*
* @param message the exception message detailing the error
*/
class RequestTokenExpiredException(message: String? = null)
: RuntimeException(message)
126 changes: 126 additions & 0 deletions httpclient/src/main/kotlin/tbdex/sdk/httpclient/RequestToken.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package tbdex.sdk.httpclient

import com.nimbusds.jose.JOSEObjectType
import com.nimbusds.jose.JWSAlgorithm
import com.nimbusds.jose.JWSHeader
import com.nimbusds.jose.jwk.JWK
import com.nimbusds.jose.util.Base64URL
import com.nimbusds.jwt.JWTClaimsSet
import com.nimbusds.jwt.SignedJWT
import foundation.identity.did.VerificationMethod
import web5.sdk.common.Convert
import web5.sdk.dids.Did
import web5.sdk.dids.DidResolvers
import web5.sdk.dids.findAssertionMethodById
import java.time.Instant
import java.util.Date
import java.util.UUID

/**
* Request token object
*
* Exposes methods for generating and verifying request tokens
*/
object RequestToken {

/**
* List of required JWT claims keys
*/
val requiredClaimKeys = listOf("aud", "iss", "exp", "jti", "iat")


/**
* Generate request token.
*
* @param did DID of the token creator
* @param pfiDid DID of the PFI
* @param assertionMethodId
* @return signed request token to be included as Authorization header for sending to PFI endpoints
*
*/
fun generate(did: Did, pfiDid: String, assertionMethodId: String? = null): String {

val didResolutionResult = DidResolvers.resolve(did.uri)
val assertionMethod: VerificationMethod = didResolutionResult.didDocument.findAssertionMethodById(assertionMethodId)

// TODO: ensure that publicKeyJwk is not null
val publicKeyJwk = JWK.parse(assertionMethod.publicKeyJwk)
val keyAlias = did.keyManager.getDeterministicAlias(publicKeyJwk)

// TODO: figure out how to make more reliable since algorithm is technically not a required property of a JWK
val algorithm = publicKeyJwk.algorithm
val jwsAlgorithm = JWSAlgorithm.parse(algorithm.toString())

val kid = when (assertionMethod.id.isAbsolute) {
true -> assertionMethod.id.toString()
false -> "${did.uri}${assertionMethod.id}"
}

val jwtHeader = JWSHeader.Builder(jwsAlgorithm)
.type(JOSEObjectType.JWT)
.keyID(kid)
.build()

val now = Instant.now()
val exp = now.plusSeconds(60)
val jwtPayload = JWTClaimsSet.Builder()
.audience(pfiDid)
.issuer(did.uri)
.expirationTime(Date.from(exp))
.issueTime(Date.from(now))
.jwtID(UUID.randomUUID().toString())
.build()

val jwtObject = SignedJWT(jwtHeader, jwtPayload)
val toSign = jwtObject.signingInput
val signatureBytes = did.keyManager.sign(keyAlias, toSign)

val base64UrlEncodedHeader = jwtHeader.toBase64URL()
val base64UrlEncodedPayload = jwtPayload.toPayload().toBase64URL()
val base64UrlEncodedSignature = Base64URL(Convert(signatureBytes).toBase64Url(padding = false))

return "$base64UrlEncodedHeader.$base64UrlEncodedPayload.$base64UrlEncodedSignature"
}

/**
* Verify request token
*
* @param token JWT bearer token received from the requester
* @param pfiDid DID of the PFI
* @return DID of the requester/JWT token issuer
*/
fun verify(token: String, pfiDid: String): String {
val claimsSet: JWTClaimsSet
try {
claimsSet = SignedJWT.parse(token).jwtClaimsSet
// todo: resolving header.kid against a didresolver
// todo: getting the verificationMethod and publicKeyJwk and algorithmId
// todo: checking if signature is valid `signer.verify({...})`
} catch (e: Exception) {
throw RequestTokenVerificationException(e, "Failed to parse request token")
}

val issuer = claimsSet.issuer
val audience = claimsSet.audience
val expirationTime = claimsSet.expirationTime

requiredClaimKeys.forEach { key ->
if (!claimsSet.claims.containsKey(key)) {
throw RequestTokenMissingClaimsException("Missing required claim for key $key")
}
}

require(Instant.now().isBefore(expirationTime.toInstant())) {
throw RequestTokenExpiredException("Request Token is expired.")
}

require(audience.contains(pfiDid)) {
throw RequestTokenAudMismatchException(
"Request token contains invalid audience. " +
"Expected aud property to be PFI DID."
)
}

return issuer
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ object TbdexHttpClient {
fun getExchange(pfiDid: String, requesterDid: Did, exchangeId: String): Exchange {
val pfiServiceEndpoint = getPfiServiceEndpoint(pfiDid)
val baseUrl = "$pfiServiceEndpoint/exchanges/$exchangeId"
val requestToken = generateRequestToken(requesterDid)
val requestToken = generateRequestToken(requesterDid, pfiDid)

val request = Request.Builder()
.url(baseUrl)
Expand Down Expand Up @@ -146,7 +146,7 @@ object TbdexHttpClient {
fun getExchanges(pfiDid: String, requesterDid: Did, filter: GetExchangesFilter? = null): List<Exchange> {
val pfiServiceEndpoint = getPfiServiceEndpoint(pfiDid)
val baseUrl = "$pfiServiceEndpoint/exchanges/"
val requestToken = generateRequestToken(requesterDid)
val requestToken = generateRequestToken(requesterDid, pfiDid)

// compose query param
val httpUrlBuilder = baseUrl.toHttpUrl().newBuilder()
Expand Down
47 changes: 0 additions & 47 deletions httpclient/src/main/kotlin/tbdex/sdk/httpclient/Utils.kt
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would separate out the request token related methods into a RequestTokenUtils file

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
package tbdex.sdk.httpclient

import com.nimbusds.jose.JWSAlgorithm
import com.nimbusds.jose.JWSHeader
import com.nimbusds.jose.JWSObject
import com.nimbusds.jose.Payload
import com.nimbusds.jose.jwk.JWK
import com.nimbusds.jose.util.Base64URL
import tbdex.sdk.protocol.CryptoUtils
import web5.sdk.common.Convert
import web5.sdk.dids.Did
import web5.sdk.dids.DidResolvers
import java.time.Instant

/**
* Get pfi service endpoint
Expand All @@ -27,41 +17,4 @@ fun getPfiServiceEndpoint(pfiDid: String): String {
}

return service.serviceEndpoint.toString()
}


/**
* Generate request token.
*
* @param did
* @param assertionMethodId
* @return
*/
fun generateRequestToken(did: Did, assertionMethodId: String? = null): String {
val assertionMethod = CryptoUtils.getAssertionMethod(did, assertionMethodId)

// TODO: ensure that publicKeyJwk is not null
val publicKeyJwk = JWK.parse(assertionMethod.publicKeyJwk)
val keyAlias = did.keyManager.getDeterministicAlias(publicKeyJwk)

val algorithm = publicKeyJwk.algorithm
val jwsAlgorithm = JWSAlgorithm.parse(algorithm.toString())

val jwsHeader = JWSHeader.Builder(jwsAlgorithm)
.keyID(assertionMethod.id.toString())
.build()

val payload = mapOf("timestamp" to Instant.now().toString())
val jwsPayload = Payload(payload)
val base64UrlEncodedPayload = jwsPayload.toBase64URL().toString()

val jwsObject = JWSObject(jwsHeader, jwsPayload)
val toSign = jwsObject.signingInput

val signedBytes = did.keyManager.sign(keyAlias, toSign)
val base64UrlEncodedSignature = Base64URL(Convert(signedBytes).toBase64Url(padding = false))
val base64UrlEncodedHeader = jwsHeader.toBase64URL()


return "$base64UrlEncodedHeader.$base64UrlEncodedPayload.$base64UrlEncodedSignature"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package tbdex.sdk.httpclient

import assertk.assertThat
import assertk.assertions.containsExactlyInAnyOrder
import com.nimbusds.jwt.SignedJWT
import web5.sdk.crypto.InMemoryKeyManager
import web5.sdk.dids.methods.dht.DidDht
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import kotlin.test.assertTrue

class RequestTokenTest {

@Test
fun `generateRequestToken() generates a JWT`() {
val did = DidDht.create(InMemoryKeyManager())
val pfiDid = "did:ion:123"

val token = RequestToken.generate(did, pfiDid)
assertNotNull(token)
}

@Test
fun `generateRequestToken() generates JWT with all required fields`() {
val did = DidDht.create(InMemoryKeyManager())
val pfiDid = "did:ion:123"

val token = RequestToken.generate(did, pfiDid)
val claimsSet = SignedJWT.parse(token).jwtClaimsSet

assertThat(claimsSet.claims.keys)
.containsExactlyInAnyOrder(RequestToken.requiredClaimKeys)
}

@Test
fun `generateRequestToken() generates JWT with fields containing correct values`() {
val did = DidDht.create(InMemoryKeyManager())
val pfiDid = "did:ion:123"

val token = RequestToken.generate(did, pfiDid)
val claimsSet = SignedJWT.parse(token).jwtClaimsSet

assertTrue(claimsSet.issuer.contains(did.uri))
assertTrue(claimsSet.audience.contains(pfiDid))
assertEquals(60000, claimsSet.expirationTime.time - claimsSet.issueTime.time)
}

@Test
fun `verifyRequestToken() validates given JWT token`() {
val did = DidDht.create(InMemoryKeyManager())
val pfiDid = "did:ion:123"

val token = RequestToken.generate(did, pfiDid)

val verificationResult = RequestToken.verify(token, pfiDid)

assertEquals(did.uri, verificationResult)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import io.ktor.http.HttpHeaders
import io.ktor.http.HttpStatusCode
import io.ktor.server.application.ApplicationCall
import io.ktor.server.response.respond
import tbdex.sdk.httpclient.RequestToken
import tbdex.sdk.httpclient.models.ErrorDetail
import tbdex.sdk.httpserver.models.ErrorResponse
import tbdex.sdk.httpserver.models.ExchangesApi
Expand Down Expand Up @@ -62,11 +63,9 @@ suspend fun getExchanges(
return
}

// todo: verify JWT token using new CryptoUtils.verify() method
// to be written to address these issues:
// generating JWT token: https://github.com/TBD54566975/tbdex-kt/issues/121
// verifying JWT token: https://github.com/TBD54566975/tbdex/issues/210

val token = arr[1]
// TODO: how to access pfiDid here?
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@phoebe-lew so i want to actually call RequestToken.verify() in this protected endpoint, but not sure how to access pfiDid from here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

punting it to a separate issue #135

val requesterDid = RequestToken.verify(token, "")
val exchanges = exchangesApi.getExchanges()

if (callback != null) {
Expand Down
Loading