Skip to content

Commit

Permalink
Fixes ##2251
Browse files Browse the repository at this point in the history
and the same issue for WIF for MSI
and adds loggin in WIF for AKS
  • Loading branch information
jmprieur committed May 19, 2023
1 parent 49f2ceb commit 4e44fdd
Show file tree
Hide file tree
Showing 9 changed files with 237 additions and 33 deletions.
33 changes: 30 additions & 3 deletions src/Microsoft.Identity.Web.Certificate/DefaultCertificateLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Security.Cryptography.X509Certificates;
using Microsoft.Extensions.Logging;
using Microsoft.Identity.Abstractions;

namespace Microsoft.Identity.Web
Expand All @@ -26,6 +27,22 @@ namespace Microsoft.Identity.Web
/// </summary>
public class DefaultCertificateLoader : DefaultCredentialsLoader, ICertificateLoader
{
/// <summary>
/// Constructor with a logger.
/// </summary>
/// <param name="logger"></param>
public DefaultCertificateLoader(ILogger<DefaultCertificateLoader>? logger) : base(logger)
{

}

/// <summary>
/// Default constuctor.
/// </summary>
//[Obsolete("Rather use the constructor with a logger")]
public DefaultCertificateLoader() : this(null)
{
}

/// <summary>
/// This default is overridable at the level of the credential description (for the certificate from KeyVault).
Expand All @@ -50,7 +67,7 @@ public static string? UserAssignedManagedIdentityClientId
/// <returns>First certificate in the certificate description list.</returns>
public static X509Certificate2? LoadFirstCertificate(IEnumerable<CertificateDescription> certificateDescriptions)
{
DefaultCertificateLoader defaultCertificateLoader = new();
DefaultCertificateLoader defaultCertificateLoader = new(null);
CertificateDescription? certDescription = certificateDescriptions.FirstOrDefault(c =>
{
defaultCertificateLoader.LoadCredentialsIfNeededAsync(c).GetAwaiter().GetResult();
Expand All @@ -67,12 +84,22 @@ public static string? UserAssignedManagedIdentityClientId
/// <returns>All the certificates in the certificate description list.</returns>
public static IEnumerable<X509Certificate2?> LoadAllCertificates(IEnumerable<CertificateDescription> certificateDescriptions)
{
DefaultCertificateLoader defaultCertificateLoader = new();
DefaultCertificateLoader defaultCertificateLoader = new(null);
return defaultCertificateLoader.LoadCertificates(certificateDescriptions);
}

/// <summary>
/// Load the certificates from the certificate description list.
/// </summary>
/// <param name="certificateDescriptions"></param>
/// <returns>a collection of certificates</returns>
private IEnumerable<X509Certificate2?> LoadCertificates(IEnumerable<CertificateDescription> certificateDescriptions)
{
if (certificateDescriptions != null)
{
foreach (var certDescription in certificateDescriptions)
{
defaultCertificateLoader.LoadCredentialsIfNeededAsync(certDescription).GetAwaiter().GetResult();
LoadCredentialsIfNeededAsync(certDescription).GetAwaiter().GetResult();
if (certDescription.Certificate != null)
{
yield return certDescription.Certificate;
Expand Down
42 changes: 32 additions & 10 deletions src/Microsoft.Identity.Web.Certificate/DefaultCredentialsLoader.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Identity.Abstractions;

namespace Microsoft.Identity.Web
Expand All @@ -12,20 +14,40 @@ namespace Microsoft.Identity.Web
/// </summary>
public class DefaultCredentialsLoader : ICredentialsLoader
{
ILogger<DefaultCredentialsLoader>? _logger;

/// <summary>
/// Constructor with a logger
/// </summary>
/// <param name="logger"></param>
public DefaultCredentialsLoader(ILogger<DefaultCredentialsLoader>? logger)
{
_logger = logger;
CredentialSourceLoaders = new Dictionary<CredentialSource, ICredentialSourceLoader>
{
{ CredentialSource.KeyVault, new KeyVaultCertificateLoader() },
{ CredentialSource.Path, new FromPathCertificateLoader() },
{ CredentialSource.StoreWithThumbprint, new StoreWithThumbprintCertificateLoader() },
{ CredentialSource.StoreWithDistinguishedName, new StoreWithDistinguishedNameCertificateLoader() },
{ CredentialSource.Base64Encoded, new Base64EncodedCertificateLoader() },
{ CredentialSource.SignedAssertionFromManagedIdentity, new SignedAssertionFromManagedIdentityCredentialLoader() },
{ CredentialSource.SignedAssertionFilePath, new SignedAssertionFilePathCredentialsLoader(_logger) }
};
}

/// <summary>
/// Default constructor (for backward compatibility)
/// </summary>
//[Obsolete("Rather use the constructor with a logger.")]
public DefaultCredentialsLoader() : this(null)
{
}

/// <summary>
/// Dictionary of credential loaders per credential source. The application can add more to
/// process additional credential sources(like dSMS).
/// </summary>
public IDictionary<CredentialSource, ICredentialSourceLoader> CredentialSourceLoaders { get; } = new Dictionary<CredentialSource, ICredentialSourceLoader>
{
{ CredentialSource.KeyVault, new KeyVaultCertificateLoader() },
{ CredentialSource.Path, new FromPathCertificateLoader() },
{ CredentialSource.StoreWithThumbprint, new StoreWithThumbprintCertificateLoader() },
{ CredentialSource.StoreWithDistinguishedName, new StoreWithDistinguishedNameCertificateLoader() },
{ CredentialSource.Base64Encoded, new Base64EncodedCertificateLoader() },
{ CredentialSource.SignedAssertionFromManagedIdentity, new SignedAssertionFromManagedIdentityCredentialLoader() },
{ CredentialSource.SignedAssertionFilePath, new SignedAssertionFilePathCredentialsLoader() },
};
public IDictionary<CredentialSource, ICredentialSourceLoader> CredentialSourceLoaders { get; }

/// <inheritdoc/>
/// Load the credentials from the description, if needed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,22 @@
using System.Threading;
using Microsoft.Identity.Abstractions;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;

namespace Microsoft.Identity.Web
{
internal class SignedAssertionFilePathCredentialsLoader : ICredentialSourceLoader
{
ILogger? _logger;

/// <summary>
/// Constructor
/// </summary>
/// <param name="logger">Optional logger.</param>
public SignedAssertionFilePathCredentialsLoader(ILogger? logger)
{
_logger = logger;
}
public CredentialSource CredentialSource => CredentialSource.SignedAssertionFilePath;

public async Task LoadIfNeededAsync(CredentialDescription credentialDescription, CredentialSourceLoaderParameters? credentialSourceLoaderParameters)
Expand All @@ -19,17 +30,17 @@ public async Task LoadIfNeededAsync(CredentialDescription credentialDescription,
AzureIdentityForKubernetesClientAssertion? signedAssertion = credentialDescription.CachedValue as AzureIdentityForKubernetesClientAssertion;
if (credentialDescription.CachedValue == null)
{
signedAssertion = new AzureIdentityForKubernetesClientAssertion(credentialDescription.SignedAssertionFileDiskPath);
signedAssertion = new AzureIdentityForKubernetesClientAssertion(credentialDescription.SignedAssertionFileDiskPath, _logger);
}
try
{
// Given that managed identity can be not available locally, we need to try to get a
// signed assertion, and if it fails, move to the next credentials
_= await signedAssertion!.GetSignedAssertion(CancellationToken.None);
credentialDescription.CachedValue = signedAssertion;
}
catch (Exception)
{
credentialDescription.CachedValue = signedAssertion;
credentialDescription.Skip = true;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ public async Task LoadIfNeededAsync(CredentialDescription credentialDescription,
// Given that managed identity can be not available locally, we need to try to get a
// signed assertion, and if it fails, move to the next credentials
_= await managedIdentityClientAssertion!.GetSignedAssertion(CancellationToken.None);
credentialDescription.CachedValue = managedIdentityClientAssertion;
}
catch (AuthenticationFailedException)
{
credentialDescription.CachedValue = managedIdentityClientAssertion;
credentialDescription.Skip = true;
throw;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using Microsoft.Extensions.Logging;

namespace Microsoft.Identity.Web
{
public partial class AzureIdentityForKubernetesClientAssertion
{
/*
// High performance logger messages (before generation).
#pragma warning disable SYSLIB1009 // Logging methods must be static
[LoggerMessage(EventId = 1, Level = LogLevel.Information, Message = "SignedAssertionFileDiskPath not provided. Falling back to the content of the AZURE_FEDERATED_TOKEN_FILE environment variable.")]
partial void SignedAssertionFileDiskPathNotProvided(ILogger logger);
[LoggerMessage(EventId = 2, Level = LogLevel.Information, Message = "The `{environmentVariableName}` environment variable not provided.")]
partial void SignedAssertionEnvironmentVariableNotProvided(ILogger logger, string environmentVariableName);
[LoggerMessage(EventId = 3, Level = LogLevel.Error, Message = "The environment variable AZURE_FEDERATED_TOKEN_FILE or AZURE_ACCESS_TOKEN_FILE or the 'SignedAssertionFileDiskPath' must be set to the path of the file containing the signed assertion.")]
partial void NoSignedAssertionParameterProvided(ILogger logger);
[LoggerMessage(EventId = 4, Level = LogLevel.Error, Message = "The file `{filePath}` containing the signed assertion was not found.")]
partial void FileAssertionPathNotFound(ILogger logger, string filePath);
[LoggerMessage(EventId = 5, Level = LogLevel.Information, Message = "Successfully read the signed assertion for `{filePath}`. Expires at {expiry}.")]
partial void SuccessFullyReadSignedAssertion(ILogger logger, string filePath, DateTime expiry);
[LoggerMessage(EventId = 6, Level = LogLevel.Error, Message = "The file `{filePath} does not contain a valid signed assertion. {message}.")]
partial void FileDoesNotContainValidAssertion(ILogger logger, string filePath, string message);
#pragma warning restore SYSLIB1009 // Logging methods must be static
*/

/// <summary>
/// Performant logging messages.
/// </summary>
static class Log
{
private static readonly Action<ILogger, Exception?> __SignedAssertionFileDiskPathNotProvidedCallback =
LoggerMessage.Define(LogLevel.Information, new EventId(1, nameof(SignedAssertionFileDiskPathNotProvided)), "SignedAssertionFileDiskPath not provided. Falling back to the content of the AZURE_FEDERATED_TOKEN_FILE environment variable.");

public static void SignedAssertionFileDiskPathNotProvided(ILogger? logger)
{
if (logger != null && logger.IsEnabled(LogLevel.Information))
{
__SignedAssertionFileDiskPathNotProvidedCallback(logger, null);
}
}
private static readonly Action<ILogger, string, Exception?> __SignedAssertionEnvironmentVariableNotProvidedCallback =
LoggerMessage.Define<string>(LogLevel.Information, new EventId(2, nameof(SignedAssertionEnvironmentVariableNotProvided)), "The `{environmentVariableName}` environment variable not provided.");

public static void SignedAssertionEnvironmentVariableNotProvided(ILogger? logger, string environmentVariableName)
{
if (logger != null && logger.IsEnabled(LogLevel.Information))
{
__SignedAssertionEnvironmentVariableNotProvidedCallback(logger, environmentVariableName, null);
}
}
private static readonly Action<ILogger, Exception?> __NoSignedAssertionParameterProvidedCallback =
LoggerMessage.Define(LogLevel.Error, new EventId(3, nameof(NoSignedAssertionParameterProvided)), "The environment variable AZURE_FEDERATED_TOKEN_FILE or AZURE_ACCESS_TOKEN_FILE or the 'SignedAssertionFileDiskPath' must be set to the path of the file containing the signed assertion.");

public static void NoSignedAssertionParameterProvided(ILogger? logger)
{
if (logger != null && logger.IsEnabled(LogLevel.Error))
{
__NoSignedAssertionParameterProvidedCallback(logger, null);
}
}
private static readonly Action<ILogger, string, Exception?> __FileAssertionPathNotFoundCallback =
LoggerMessage.Define<string>(LogLevel.Error, new EventId(4, nameof(FileAssertionPathNotFound)), "The file `{filePath}` containing the signed assertion was not found.");

public static void FileAssertionPathNotFound(ILogger? logger, string filePath)
{
if (logger != null && logger.IsEnabled(LogLevel.Error))
{
__FileAssertionPathNotFoundCallback(logger, filePath, null);
}
}
private static readonly Action<ILogger, string, DateTime, Exception?> __SuccessFullyReadSignedAssertionCallback =
LoggerMessage.Define<string, DateTime>(LogLevel.Information, new EventId(5, nameof(SuccessFullyReadSignedAssertion)), "Successfully read the signed assertion for `{filePath}`. Expires at {expiry}.");

public static void SuccessFullyReadSignedAssertion(ILogger? logger, string filePath, DateTime expiry)
{
if (logger != null && logger.IsEnabled(LogLevel.Information))
{
__SuccessFullyReadSignedAssertionCallback(logger, filePath, expiry, null);
}
}
private static readonly Action<ILogger, string, string, Exception?> __FileDoesNotContainValidAssertionCallback =
LoggerMessage.Define<string, string>(LogLevel.Error, new EventId(6, nameof(FileDoesNotContainValidAssertion)), "The file `{filePath} does not contain a valid signed assertion. {message}.");

public static void FileDoesNotContainValidAssertion(ILogger? logger, string filePath, string message)
{
if (logger != null && logger.IsEnabled(LogLevel.Error))
{
__FileDoesNotContainValidAssertionCallback(logger, filePath, message, null);
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.IdentityModel.JsonWebTokens;

namespace Microsoft.Identity.Web
Expand All @@ -14,40 +15,79 @@ namespace Microsoft.Identity.Web
/// in Azure Kubernetes Services. See https://aka.ms/ms-id-web/certificateless and
/// https://learn.microsoft.com/azure/aks/workload-identity-overview
/// </summary>
public class AzureIdentityForKubernetesClientAssertion : ClientAssertionProviderBase
public partial class AzureIdentityForKubernetesClientAssertion : ClientAssertionProviderBase
{
/// <summary>
/// Gets a signed assertion from Azure workload identity for kubernetes. The file path is provided
/// by an environment variable ("AZURE_FEDERATED_TOKEN_FILE")
/// See https://aka.ms/ms-id-web/certificateless.
/// </summary>
public AzureIdentityForKubernetesClientAssertion() : this(null)
public AzureIdentityForKubernetesClientAssertion(ILogger? logger = null) : this(null, logger)
{
}

/// <summary>
/// Gets a signed assertion from a file.
/// See https://aka.ms/ms-id-web/certificateless.
/// </summary>
/// <param name="filePath"></param>
public AzureIdentityForKubernetesClientAssertion(string? filePath)
/// <param name="filePath">Path to a file containing the signed assertion.</param>
/// <param name="logger">Logger.</param>
public AzureIdentityForKubernetesClientAssertion(string? filePath, ILogger? logger = null)
{
_logger = logger;

if (filePath == null)
{
Log.SignedAssertionFileDiskPathNotProvided(_logger);
}

_filePath = _filePath ?? Environment.GetEnvironmentVariable("AZURE_ACCESS_TOKEN_FILE");
if (filePath == null)
{
Log.SignedAssertionEnvironmentVariableNotProvided(_logger, "AZURE_ACCESS_TOKEN_FILE");
}

// See https://blog.identitydigest.com/azuread-federate-k8s/
_filePath = filePath ?? Environment.GetEnvironmentVariable("AZURE_FEDERATED_TOKEN_FILE");
if (_filePath == null)
{
Log.SignedAssertionEnvironmentVariableNotProvided(_logger, "AZURE_FEDERATED_TOKEN_FILE");
Log.NoSignedAssertionParameterProvided(_logger);
}
}

private readonly string _filePath;
private readonly string? _filePath;

private readonly ILogger? _logger;

/// <summary>
/// Get the signed assertion from a file.
/// </summary>
/// <returns>The signed assertion.</returns>
protected override Task<ClientAssertion> GetClientAssertion(CancellationToken cancellationToken)
{
if (_filePath != null && !File.Exists(_filePath))
{
Log.FileAssertionPathNotFound(_logger, _filePath);
throw new FileNotFoundException($"The file '{_filePath}' containing the signed assertion was not found.");

}
string signedAssertion = File.ReadAllText(_filePath);
// Compute the expiry
JsonWebToken jwt = new JsonWebToken(signedAssertion);
return Task.FromResult(new ClientAssertion(signedAssertion, jwt.ValidTo));

// Verify that the assertion is a JWS, JWE, and computes the expiry
try
{
JsonWebToken jwt = new JsonWebToken(signedAssertion);

Log.SuccessFullyReadSignedAssertion(_logger, _filePath!, jwt.ValidTo);

return Task.FromResult(new ClientAssertion(signedAssertion, jwt.ValidTo));
}
catch (ArgumentException ex)
{
Log.FileDoesNotContainValidAssertion(_logger, _filePath!, ex.Message);
throw;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

<ItemGroup>
<PackageReference Include="Azure.Identity" Version="$(AzureIdentityVersion)" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="$(MicrosoftExtensionsLoggingVersion)" />
<PackageReference Include="System.Text.Encodings.Web" Version="$(SystemTextEncodingsWebVersion)" />
<PackageReference Include="Microsoft.IdentityModel.JsonWebTokens " Version="$(IdentityModelVersion)" />
</ItemGroup>
Expand Down
Loading

0 comments on commit 4e44fdd

Please sign in to comment.