Skip to content

Commit

Permalink
Merge pull request #443 from auth0/ft-jwt
Browse files Browse the repository at this point in the history
Refactor JWT decoding logic
  • Loading branch information
lbalmaceda authored Jan 19, 2021
2 parents f268a87 + 577bdf5 commit d1446c7
Show file tree
Hide file tree
Showing 14 changed files with 388 additions and 78 deletions.
1 change: 0 additions & 1 deletion auth0/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ dependencies {
implementation 'com.squareup.okhttp3:okhttp:4.9.0'
implementation 'com.squareup.okhttp3:logging-interceptor:4.9.0'
implementation 'com.google.code.gson:gson:2.8.6'
implementation 'com.auth0.android:jwtdecode:1.3.0'

testImplementation 'junit:junit:4.13.1'
testImplementation 'org.hamcrest:java-hamcrest:2.0.0.0'
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.auth0.android.authentication.storage;

import com.auth0.android.jwt.JWT;
import com.auth0.android.request.internal.Jwt;

/**
* Bridge class for decoding JWTs.
Expand All @@ -11,7 +11,7 @@ class JWTDecoder {
JWTDecoder() {
}

JWT decode(String jwt) {
return new JWT(jwt);
Jwt decode(String jwt) {
return new Jwt(jwt);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import androidx.annotation.NonNull;

import com.auth0.android.jwt.JWT;
import com.auth0.android.request.internal.Jwt;

import java.util.Calendar;
import java.util.Date;
Expand All @@ -25,7 +25,7 @@ class IdTokenVerifier {
* @param verifyOptions the verification options, like audience, issuer, algorithm.
* @throws TokenValidationException If the ID Token is null, its signing algorithm not supported, its signature invalid or one of its claim invalid.
*/
void verify(@NonNull JWT token, @NonNull IdTokenVerificationOptions verifyOptions) throws TokenValidationException {
void verify(@NonNull Jwt token, @NonNull IdTokenVerificationOptions verifyOptions) throws TokenValidationException {
verifyOptions.getSignatureVerifier().verify(token);

if (isEmpty(token.getIssuer())) {
Expand Down Expand Up @@ -69,7 +69,7 @@ void verify(@NonNull JWT token, @NonNull IdTokenVerificationOptions verifyOption
}

if (verifyOptions.getNonce() != null) {
String nonceClaim = token.getClaim(NONCE_CLAIM).asString();
String nonceClaim = token.getNonce();
if (isEmpty(nonceClaim)) {
throw new TokenValidationException("Nonce (nonce) claim must be a string present in the ID token");
}
Expand All @@ -79,7 +79,7 @@ void verify(@NonNull JWT token, @NonNull IdTokenVerificationOptions verifyOption
}

if (audience.size() > 1) {
String azpClaim = token.getClaim(AZP_CLAIM).asString();
String azpClaim = token.getAuthorizedParty();
if (isEmpty(azpClaim)) {
throw new TokenValidationException("Authorized Party (azp) claim must be a string present in the ID token when Audience (aud) claim has multiple values");
}
Expand All @@ -89,7 +89,7 @@ void verify(@NonNull JWT token, @NonNull IdTokenVerificationOptions verifyOption
}

if (verifyOptions.getMaxAge() != null) {
Date authTime = token.getClaim(AUTH_TIME_CLAIM).asDate();
Date authTime = token.getAuthenticationTime();
if (authTime == null) {
throw new TokenValidationException("Authentication Time (auth_time) claim must be a number present in the ID token when Max Age (max_age) is specified");
}
Expand Down
13 changes: 6 additions & 7 deletions auth0/src/main/java/com/auth0/android/provider/OAuthManager.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ import com.auth0.android.Auth0Exception
import com.auth0.android.authentication.AuthenticationAPIClient
import com.auth0.android.authentication.AuthenticationException
import com.auth0.android.callback.Callback
import com.auth0.android.jwt.DecodeException
import com.auth0.android.jwt.JWT
import com.auth0.android.request.internal.Jwt
import com.auth0.android.result.Credentials
import java.security.SecureRandom
import java.util.*
Expand Down Expand Up @@ -134,10 +133,10 @@ internal class OAuthManager(
validationCallback.onFailure(TokenValidationException("ID token is required but missing"))
return
}
val decodedIdToken: JWT = try {
JWT(idToken!!)
} catch (ignored: DecodeException) {
validationCallback.onFailure(TokenValidationException("ID token could not be decoded"))
val decodedIdToken: Jwt = try {
Jwt(idToken!!)
} catch (error: Exception) {
validationCallback.onFailure(TokenValidationException("ID token could not be decoded", error))
return
}
val signatureVerifierCallback: Callback<SignatureVerifier, TokenValidationException> =
Expand Down Expand Up @@ -167,7 +166,7 @@ internal class OAuthManager(
}
}
}
val tokenKeyId = decodedIdToken.header["kid"]
val tokenKeyId = decodedIdToken.keyId
SignatureVerifier.forAsymmetricAlgorithm(tokenKeyId, apiClient, signatureVerifierCallback)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import com.auth0.android.authentication.AuthenticationException;
import com.auth0.android.callback.AuthenticationCallback;
import com.auth0.android.callback.Callback;
import com.auth0.android.jwt.JWT;
import com.auth0.android.request.internal.Jwt;

import java.security.InvalidKeyException;
import java.security.PublicKey;
Expand All @@ -31,12 +31,9 @@ abstract class SignatureVerifier {
* @param token the ID token to have its signature validated
* @throws TokenValidationException if the signature is not valid
*/
void verify(@NonNull JWT token) throws TokenValidationException {
String tokenAlg = token.getHeader().get("alg");
String[] tokenParts = token.toString().split("\\.");

checkAlgorithm(tokenAlg);
checkSignature(tokenParts);
void verify(@NonNull Jwt token) throws TokenValidationException {
checkAlgorithm(token.getAlgorithm());
checkSignature(token.getParts());
}

private void checkAlgorithm(String tokenAlgorithm) throws TokenValidationException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,8 @@ import com.auth0.android.Auth0Exception
/**
* Exception thrown when the validation of the ID token failed.
*/
internal class TokenValidationException(message: String) : Auth0Exception(message)
internal class TokenValidationException @JvmOverloads constructor(
message: String,
cause: Throwable? = null
) :
Auth0Exception(message, cause)
82 changes: 82 additions & 0 deletions auth0/src/main/java/com/auth0/android/request/internal/Jwt.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package com.auth0.android.request.internal

import android.util.Base64
import com.google.gson.reflect.TypeToken
import java.util.*


/**
* Internal class meant to decode the given token of type JWT and provide access to its claims.
*/
internal class Jwt(rawToken: String) {

private val decodedHeader: Map<String, Any>
private val decodedPayload: Map<String, Any>
val parts: Array<String>

// header
val algorithm: String
val keyId: String?

// payload
val subject: String?
val issuer: String?
val nonce: String?
val issuedAt: Date?
val expiresAt: Date?
val authorizedParty: String?
val authenticationTime: Date?
val audience: List<String>

init {
parts = splitToken(rawToken)
val jsonHeader = parts[0].decodeBase64()
val jsonPayload = parts[1].decodeBase64()
val mapAdapter = GsonProvider.gson.getAdapter(object : TypeToken<Map<String, Any>>() {})
decodedHeader = mapAdapter.fromJson(jsonHeader)
decodedPayload = mapAdapter.fromJson(jsonPayload)

// header claims
algorithm = decodedHeader["alg"] as String
keyId = decodedHeader["kid"] as String?

// payload claims
subject = decodedPayload["sub"] as String?
issuer = decodedPayload["iss"] as String?
nonce = decodedPayload["nonce"] as String?
issuedAt = (decodedPayload["iat"] as? Double)?.let { Date(it.toLong() * 1000) }
expiresAt = (decodedPayload["exp"] as? Double)?.let { Date(it.toLong() * 1000) }
authorizedParty = decodedPayload["azp"] as String?
authenticationTime =
(decodedPayload["auth_time"] as? Double)?.let { Date(it.toLong() * 1000) }
audience = when (val aud = decodedPayload["aud"]) {
is String -> listOf(aud)
is List<*> -> aud as List<String>
else -> emptyList()
}
}

private fun splitToken(token: String): Array<String> {
var parts = token.split(".").toTypedArray()
if (parts.size == 2 && token.endsWith(".")) {
// Tokens with alg='none' have empty String as Signature.
parts = arrayOf(parts[0], parts[1], "")
}
if (parts.size != 3) {
throw IllegalArgumentException(
String.format(
"The token was expected to have 3 parts, but got %s.",
parts.size
)
)
}
return parts
}

private fun String.decodeBase64(): String {
val bytes: ByteArray =
Base64.decode(this, Base64.URL_SAFE or Base64.NO_WRAP or Base64.NO_PADDING)
return String(bytes, Charsets.UTF_8)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package com.auth0.android.authentication.storage
import com.auth0.android.authentication.AuthenticationAPIClient
import com.auth0.android.authentication.AuthenticationException
import com.auth0.android.callback.Callback
import com.auth0.android.jwt.JWT
import com.auth0.android.request.Request
import com.auth0.android.request.internal.Jwt
import com.auth0.android.result.Credentials
import com.auth0.android.result.CredentialsMock
import com.auth0.android.util.Clock
Expand Down Expand Up @@ -355,7 +355,7 @@ public class CredentialsManagerTest {
client.renewAuth("refreshToken")
).thenReturn(request)
val newDate = Date(CredentialsMock.ONE_HOUR_AHEAD_MS + ONE_HOUR_SECONDS * 1000)
val jwtMock = mock<JWT>()
val jwtMock = mock<Jwt>()
Mockito.`when`(jwtMock.expiresAt).thenReturn(newDate)
Mockito.`when`(jwtDecoder.decode("newId")).thenReturn(jwtMock)
manager.getCredentials("some scope", 0, callback)
Expand Down Expand Up @@ -416,7 +416,7 @@ public class CredentialsManagerTest {
client.renewAuth("refreshToken")
).thenReturn(request)
val newDate = Date(CredentialsMock.ONE_HOUR_AHEAD_MS + ONE_HOUR_SECONDS * 1000)
val jwtMock = mock<JWT>()
val jwtMock = mock<Jwt>()
Mockito.`when`(jwtMock.expiresAt).thenReturn(newDate)
Mockito.`when`(jwtDecoder.decode("newId")).thenReturn(jwtMock)
manager.getCredentials("some scope", 0, callback)
Expand Down Expand Up @@ -478,7 +478,7 @@ public class CredentialsManagerTest {
).thenReturn(request)
val newDate =
Date(CredentialsMock.CURRENT_TIME_MS + 61 * 1000) // New token expires in minTTL + 1 second
val jwtMock = mock<JWT>()
val jwtMock = mock<Jwt>()
Mockito.`when`(jwtMock.expiresAt).thenReturn(newDate)
Mockito.`when`(jwtDecoder.decode("newId")).thenReturn(jwtMock)
manager.getCredentials(null, 60, callback) // 60 seconds of minTTL
Expand Down Expand Up @@ -539,7 +539,7 @@ public class CredentialsManagerTest {
client.renewAuth("refreshToken")
).thenReturn(request)
val newDate = Date(CredentialsMock.ONE_HOUR_AHEAD_MS)
val jwtMock = mock<JWT>()
val jwtMock = mock<Jwt>()
Mockito.`when`(jwtMock.expiresAt).thenReturn(newDate)
Mockito.`when`(jwtDecoder.decode("newId")).thenReturn(jwtMock)
manager.getCredentials(callback)
Expand Down Expand Up @@ -601,7 +601,7 @@ public class CredentialsManagerTest {
).thenReturn(request)
val newDate =
Date(CredentialsMock.CURRENT_TIME_MS + 59 * 1000) // New token expires in minTTL - 1 second
val jwtMock = mock<JWT>()
val jwtMock = mock<Jwt>()
Mockito.`when`(jwtMock.expiresAt).thenReturn(newDate)
Mockito.`when`(jwtDecoder.decode("newId")).thenReturn(jwtMock)
manager.getCredentials(null, 60, callback) // 60 seconds of minTTL
Expand Down Expand Up @@ -654,7 +654,7 @@ public class CredentialsManagerTest {
client.renewAuth("refreshToken")
).thenReturn(request)
val newDate = Date(CredentialsMock.ONE_HOUR_AHEAD_MS)
val jwtMock = mock<JWT>()
val jwtMock = mock<Jwt>()
Mockito.`when`(jwtMock.expiresAt).thenReturn(newDate)
Mockito.`when`(jwtDecoder.decode("newId")).thenReturn(jwtMock)
manager.getCredentials(callback)
Expand Down Expand Up @@ -863,7 +863,7 @@ public class CredentialsManagerTest {
}

private fun prepareJwtDecoderMock(expiresAt: Date?) {
val jwtMock = mock<JWT>()
val jwtMock = mock<Jwt>()
Mockito.`when`(jwtMock.expiresAt).thenReturn(expiresAt)
Mockito.`when`(jwtDecoder.decode("idToken")).thenReturn(jwtMock)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.auth0.android.authentication.storage;

import com.auth0.android.jwt.JWT;
import com.auth0.android.request.internal.Jwt;

import org.junit.Test;
import org.junit.runner.RunWith;
Expand All @@ -14,25 +14,24 @@ public class JWTDecoderTest {

@Test
public void shouldDecodeAToken() {
String token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";
JWT jwt1 = new JWT(token);
String token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6ImFsaWNlIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibm9uY2UiOiJyZWFsbHkgcmFuZG9tIHRleHQiLCJpYXQiOjE1MTYyMzkwMjJ9.rYG-HEs1EKKDhwQIoEg32_p-NQzNi5rB7akqGnH_q4k";
Jwt jwt1 = new Jwt(token);

JWT jwt2 = new JWTDecoder().decode(token);
Jwt jwt2 = new JWTDecoder().decode(token);

//Header claims
assertThat(jwt1.getHeader().get("alg"), is("HS256"));
assertThat(jwt1.getHeader().get("typ"), is("JWT"));

assertThat(jwt2.getHeader().get("typ"), is("JWT"));
assertThat(jwt2.getHeader().get("alg"), is("HS256"));
assertThat(jwt1.getAlgorithm(), is("HS256"));
assertThat(jwt1.getKeyId(), is("alice"));
assertThat(jwt2.getAlgorithm(), is("HS256"));
assertThat(jwt2.getKeyId(), is("alice"));

//Payload claims
assertThat(jwt1.getSubject(), is("1234567890"));
assertThat(jwt1.getIssuedAt().getTime(), is(1516239022000L));
assertThat(jwt1.getClaim("name").asString(), is("John Doe"));
assertThat(jwt1.getNonce(), is("really random text"));

assertThat(jwt2.getSubject(), is("1234567890"));
assertThat(jwt2.getIssuedAt().getTime(), is(1516239022000L));
assertThat(jwt2.getClaim("name").asString(), is("John Doe"));
assertThat(jwt2.getNonce(), is("really random text"));
}
}
Loading

0 comments on commit d1446c7

Please sign in to comment.