Skip to content

Commit

Permalink
Algorithm validation: Remove exceptions (#2719)
Browse files Browse the repository at this point in the history
* Remove throws when validating algorithm used for signing a token.

* Adding small commits into one

* Remove the use of delegate check and move to ValidationParameters

* Replaced .Any() method call from check

---------

Co-authored-by: Franco Fung <francofung@microsoft.com>
  • Loading branch information
FuPingFranco and Franco Fung authored Jul 16, 2024
1 parent 2e7c701 commit 8c5e456
Show file tree
Hide file tree
Showing 9 changed files with 336 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ public struct JwtRegisteredClaimNames
/// </summary>
public const string Address = "address";

/// <summary>
/// See: <see href="https://datatracker.ietf.org/doc/html/rfc7519#section-4"/>.
/// </summary>
public const string Alg = "alg";

/// <summary>
/// See: <see href="https://openid.net/specs/openid-connect-core-1_0.html#IDToken"/>.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
#nullable enable
namespace Microsoft.IdentityModel.Tokens
{
/// <summary>
/// Contains the result of validating the Algorithm of a <see cref="SecurityToken"/>.
/// The <see cref="TokenValidationResult"/> contains a collection of <see cref="ValidationResult"/> for each step in the token validation.
/// </summary>
internal class AlgorithmValidationResult : ValidationResult
{
private Exception? _exception;
private const string TokenSource = "Microsoft.IdentityModel.Tokens";

/// <summary>
/// Creates an instance of <see cref="AlgorithmValidationResult"/>.
/// </summary>
/// <paramref name="algorithm"/>The algorithm to be validated.
public AlgorithmValidationResult(string? algorithm)
: base(ValidationFailureType.ValidationSucceeded)
{
Algorithm = algorithm;
IsValid = true;
}

/// <summary>
/// Creates an instance of <see cref=" AlgorithmValidationResult"/>
/// </summary>
/// <paramref name="algorithm"/>The algorithm to be validated.
/// <paramref name="validationFailure"/> is the <see cref="ValidationFailureType"/> that occurred during validation.
/// <paramref name="exceptionDetail"/> is the <see cref="ExceptionDetail"/> that occurred during validation.
public AlgorithmValidationResult(string? algorithm, ValidationFailureType validationFailure, ExceptionDetail exceptionDetail)
: base(validationFailure, exceptionDetail)
{
Algorithm = algorithm;
IsValid = false;
}

/// <summary>
/// Gets the <see cref="Exception"/> that occurred during validation.
/// </summary>
public override Exception? Exception
{
get
{
if (_exception != null || ExceptionDetail == null)
return _exception;

HasValidOrExceptionWasRead = true;
_exception = ExceptionDetail.GetException();
if (_exception is SecurityTokenInvalidAlgorithmException securityTokenInvalidAlgorithmException)
{
securityTokenInvalidAlgorithmException.InvalidAlgorithm = Algorithm;
securityTokenInvalidAlgorithmException.Source = TokenSource;
}

return _exception;
}
}

/// <summary>
/// Gets the algorithm used to sign the token.
/// </summary>
public string? Algorithm { get; }

}
}
#nullable restore
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ private class NullArgumentFailure : ValidationFailureType { internal NullArgumen
public static readonly ValidationFailureType IssuerValidationFailed = new IssuerValidationFailure("IssuerValidationFailed");
private class IssuerValidationFailure : ValidationFailureType { internal IssuerValidationFailure(string name) : base(name) { } }

/// <summary>
/// Defines a type that represents an algorithm validation failed.
/// </summary>
public static readonly ValidationFailureType AlgorithmValidationFailed = new AlgorithmValidationFailure("AlgorithmValidationFailed");
private class AlgorithmValidationFailure : ValidationFailureType { internal AlgorithmValidationFailure(string name) : base(name) { } }

/// <summary>
/// Defines a type that represents that audience validation failed.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ public string RoleClaimType
/// If set to a non-empty collection, only the algorithms listed will be considered valid.
/// The default is <c>null</c>.
/// </remarks>
public IList<string> ValidAlgorithms { get; }
public IList<string> ValidAlgorithms { get; set; }

/// <summary>
/// Gets the <see cref="IList{String}"/> that contains valid audiences that will be used to check against the token's audience.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,60 @@
// Licensed under the MIT License.

using System;
using System.Diagnostics;
using System.Linq;
using Microsoft.IdentityModel.Logging;

#nullable enable
namespace Microsoft.IdentityModel.Tokens
{
public static partial class Validators
{
/// <summary>
/// Validates a given algorithm for a <see cref="SecurityKey"/>.
/// </summary>
/// <param name="algorithm">The algorithm to be validated.</param>
/// <param name="securityKey">The <see cref="SecurityKey"/> that signed the <see cref="SecurityToken"/>.</param>
/// <param name="securityToken">The <see cref="SecurityToken"/> being validated.</param>
/// <param name="validationParameters"><see cref="TokenValidationParameters"/> required for validation.</param>
/// <param name="callContext"></param>
#pragma warning disable CA1801 // TODO: remove pragma disable once callContext is used for logging
internal static AlgorithmValidationResult ValidateAlgorithm(
string algorithm,
SecurityKey securityKey,
SecurityToken securityToken,
ValidationParameters validationParameters,
CallContext callContext)
#pragma warning restore CA1801 // TODO: remove pragma disable once callContext is used for logging
{
if (validationParameters == null)
{
return new AlgorithmValidationResult(
algorithm,
ValidationFailureType.NullArgument,
new ExceptionDetail(
new MessageDetail(
LogMessages.IDX10000,
LogHelper.MarkAsNonPII(nameof(validationParameters))),
typeof(ArgumentNullException),
new StackFrame(true)));
}

if (validationParameters.ValidAlgorithms != null && validationParameters.ValidAlgorithms.Count > 0 && !validationParameters.ValidAlgorithms.Contains(algorithm, StringComparer.Ordinal))
{
return new AlgorithmValidationResult(
algorithm,
ValidationFailureType.AlgorithmValidationFailed,
new ExceptionDetail(
new MessageDetail(
LogMessages.IDX10696,
LogHelper.MarkAsNonPII(algorithm)),
typeof(SecurityTokenInvalidAlgorithmException),
new StackFrame(true)));
}

return new AlgorithmValidationResult(algorithm);
}
}
}
#nullable restore
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ public struct JwtRegisteredClaimNames
/// </summary>
public const string Acr = Microsoft.IdentityModel.JsonWebTokens.JwtRegisteredClaimNames.Acr;

/// <summary>
/// https://datatracker.ietf.org/doc/html/rfc7519#section-4
/// </summary>
public const string Alg = Microsoft.IdentityModel.JsonWebTokens.JwtRegisteredClaimNames.Alg;

/// <summary>
/// http://openid.net/specs/openid-connect-core-1_0.html#IDToken
/// </summary>
Expand Down
5 changes: 5 additions & 0 deletions test/Microsoft.IdentityModel.TestUtils/ExpectedException.cs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,11 @@ public static ExpectedException SecurityTokenKeyWrapException(string substringEx
return new ExpectedException(typeof(SecurityTokenKeyWrapException), substringExpected, innerTypeExpected, propertiesExpected: propertiesExpected);
}

public static ExpectedException SecurityTokenInvalidAlgorithmException(string substringExpected = null, Type innerTypeExpected = null, Dictionary<string, object> propertiesExpected = null)
{
return new ExpectedException(typeof(SecurityTokenInvalidAlgorithmException), substringExpected, innerTypeExpected, propertiesExpected: propertiesExpected);
}

public static ExpectedException SecurityTokenInvalidLifetimeException(string substringExpected = null, Type innerTypeExpected = null, Dictionary<string, object> propertiesExpected = null)
{
return new ExpectedException(typeof(SecurityTokenInvalidLifetimeException), substringExpected, innerTypeExpected, propertiesExpected: propertiesExpected);
Expand Down
71 changes: 68 additions & 3 deletions test/Microsoft.IdentityModel.TestUtils/IdentityComparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using System.Text.Json;
using System.Xml.Linq;
using Microsoft.IdentityModel.JsonWebTokens;
using Microsoft.IdentityModel.Protocols;
using Microsoft.IdentityModel.Protocols.OpenIdConnect;
Expand Down Expand Up @@ -157,6 +156,73 @@ public class IdentityComparer
};

// Keep methods in alphabetical order

public static bool AreAlgorithmValidationResultsEqual(object object1, object object2, CompareContext context)
{
var localContext = new CompareContext(context);
if (!ContinueCheckingEquality(object1, object2, context))
return context.Merge(localContext);

return AreAlgorithmValidationResultsEqual(
object1 as AlgorithmValidationResult,
object2 as AlgorithmValidationResult,
"AlgorithmValidationResult1",
"AlgorithmValidationResult2",
null,
context);
}

internal static bool AreAlgorithmValidationResultsEqual(
AlgorithmValidationResult algorithmValidationResult1,
AlgorithmValidationResult algorithmValidationResult2,
string name1,
string name2,
string stackPrefix,
CompareContext context)
{
var localContext = new CompareContext(context);
if (!ContinueCheckingEquality(algorithmValidationResult1, algorithmValidationResult2, localContext))
return context.Merge(localContext);

if (algorithmValidationResult1.Algorithm != algorithmValidationResult2.Algorithm)
localContext.Diffs.Add($"AlgorithmValidationResult1.Algorithm: '{algorithmValidationResult1.Algorithm}' != AlgorithmValidationResult2.Algorithm: '{algorithmValidationResult2.Algorithm}'");

if (algorithmValidationResult1.IsValid != algorithmValidationResult2.IsValid)
localContext.Diffs.Add($"AlgorithmValidationResult1.IsValid: {algorithmValidationResult1.IsValid} != AlgorithmValidationResult2.IsValid: {algorithmValidationResult2.IsValid}");

if (algorithmValidationResult1.ValidationFailureType != algorithmValidationResult2.ValidationFailureType)
localContext.Diffs.Add($"AlgorithmValidationResult1.ValidationFailureType: {algorithmValidationResult1.ValidationFailureType} != AlgorithmValidationResult2.ValidationFailureType: {algorithmValidationResult2.ValidationFailureType}");

// true => both are not null.
if (ContinueCheckingEquality(algorithmValidationResult1.Exception, algorithmValidationResult2.Exception, localContext))
{
AreStringsEqual(
algorithmValidationResult1.Exception.Message,
algorithmValidationResult2.Exception.Message,
$"({name1}).Exception.Message",
$"({name2}).Exception.Message",
localContext);

AreStringsEqual(
algorithmValidationResult1.Exception.Source,
algorithmValidationResult2.Exception.Source,
$"({name1}).Exception.Source",
$"({name2}).Exception.Source",
localContext);

if (!string.IsNullOrEmpty(stackPrefix))
AreStringPrefixesEqual(
algorithmValidationResult1.Exception.StackTrace.Trim(),
algorithmValidationResult2.Exception.StackTrace.Trim(),
$"({name1}).Exception.StackTrace",
$"({name2}).Exception.StackTrace",
stackPrefix.Trim(),
localContext);
}

return context.Merge(localContext);
}

public static bool AreBoolsEqual(object object1, object object2, CompareContext context)
{
return AreBoolsEqual(object1, object2, "bool1", "bool2", context);
Expand Down Expand Up @@ -903,7 +969,7 @@ internal static bool AreTokenTypeValidationResultsEqual(
return context.Merge(localContext);

if (tokenTypeValidationResult1.Type != tokenTypeValidationResult2.Type)
localContext.Diffs.Add($"TokenTypeValidationResult1.Type: '{tokenTypeValidationResult1.Type}' != TokenTypeValidationResult2.ExpirationTime: '{tokenTypeValidationResult2.Type}'");
localContext.Diffs.Add($"TokenTypeValidationResult1.Type: '{tokenTypeValidationResult1.Type}' != TokenTypeValidationResult2.Type: '{tokenTypeValidationResult2.Type}'");

if (tokenTypeValidationResult1.IsValid != tokenTypeValidationResult2.IsValid)
localContext.Diffs.Add($"TokenTypeValidationResult1.IsValid: {tokenTypeValidationResult1.IsValid} != TokenTypeValidationResult2.IsValid: {tokenTypeValidationResult2.IsValid}");
Expand Down Expand Up @@ -1817,6 +1883,5 @@ public static bool AreDatesEqualWithEpsilon(DateTime? dateTime1, DateTime? dateT

return dateTime1 == dateTime2;
}

}
}
Loading

0 comments on commit 8c5e456

Please sign in to comment.