Skip to content

Commit

Permalink
Tests added
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderSher committed Aug 26, 2020
1 parent 19d3705 commit bd6412f
Show file tree
Hide file tree
Showing 2 changed files with 281 additions and 0 deletions.
34 changes: 34 additions & 0 deletions sdk/identity/Azure.Identity/src/ChainedTokenCredential.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ public class ChainedTokenCredential : TokenCredential

private readonly TokenCredential[] _sources;

internal ChainedTokenCredential()
{
_sources = Array.Empty<TokenCredential>();
}

/// <summary>
/// Creates an instance with the specified <see cref="TokenCredential"/> sources.
/// </summary>
Expand Down Expand Up @@ -107,5 +112,34 @@ public override async ValueTask<AccessToken> GetTokenAsync(TokenRequestContext r

throw AuthenticationFailedException.CreateAggregateException(AggregateAllUnavailableErrorMessage, exceptions);
}

private async ValueTask<AccessToken> GetTokenImplAsync(bool async, TokenRequestContext requestContext, CancellationToken cancellationToken)
{
using CredentialDiagnosticScope scope = _pipeline.StartGetTokenScopeGroup("DefaultAzureCredential.GetToken", requestContext);

try
{
using var asyncLock = await _credentialLock.GetLockOrValueAsync(async, cancellationToken).ConfigureAwait(false);

AccessToken token;
if (asyncLock.HasValue)
{
token = await GetTokenFromCredentialAsync(asyncLock.Value, requestContext, async, cancellationToken).ConfigureAwait(false);
}
else
{
TokenCredential credential;
(token, credential) = await GetTokenFromSourcesAsync(_sources, requestContext, async, cancellationToken).ConfigureAwait(false);
_sources = default;
asyncLock.SetValue(credential);
}

return scope.Succeeded(token);
}
catch (Exception e)
{
throw scope.FailWrapAndThrow(e);
}
}
}
}
247 changes: 247 additions & 0 deletions sdk/identity/Azure.Identity/tests/ChainedTokenCredentialLiveTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Core.TestFramework;
using Azure.Core.Tests;
using NUnit.Framework;

namespace Azure.Identity.Tests
{
public class ChainedTokenCredentialLiveTests : RecordedTestBase<IdentityTestEnvironment>
{
private const string ExpectedServiceName = "VS Code Azure";

public ChainedTokenCredentialLiveTests(bool isAsync) : base(isAsync)
{
Matcher.ExcludeHeaders.Add("Content-Length");
Matcher.ExcludeHeaders.Add("client-request-id");
Matcher.ExcludeHeaders.Add("x-client-OS");
Matcher.ExcludeHeaders.Add("x-client-SKU");
Matcher.ExcludeHeaders.Add("x-client-CPU");

Sanitizer = new IdentityRecordedTestSanitizer();
TestDiagnostics = false;
}

[Test]
[RunOnlyOnPlatforms(Windows = true)] // VisualStudioCredential works only on Windows
public async Task ChainedTokenCredential_UseVisualStudioCredential()
{
var pipeline = CredentialPipeline.GetInstance(null);
var fileSystem = CredentialTestHelpers.CreateFileSystemForVisualStudio();
var (expectedToken, expectedExpiresOn, processOutput) = CredentialTestHelpers.CreateTokenForVisualStudio();
var processService = new TestProcessService(new TestProcess { Output = processOutput });

var miCredential = new ManagedIdentityCredential(EnvironmentVariables.ClientId, pipeline);
var vsCredential = new VisualStudioCredential(default, pipeline, fileSystem, processService);
var credential = InstrumentClient(new ChainedTokenCredential(miCredential, vsCredential));

AccessToken token;
List<ClientDiagnosticListener.ProducedDiagnosticScope> scopes;

using (ClientDiagnosticListener diagnosticListener = new ClientDiagnosticListener(s => s.StartsWith("Azure.Identity")))
{
token = await credential.GetTokenAsync(new TokenRequestContext(new[] {"https://vault.azure.net/.default"}), CancellationToken.None);
scopes = diagnosticListener.Scopes;
}

Assert.AreEqual(token.Token, expectedToken);
Assert.AreEqual(token.ExpiresOn, expectedExpiresOn);

Assert.AreEqual(2, scopes.Count);
Assert.AreEqual($"{nameof(ChainedTokenCredential)}.{nameof(ChainedTokenCredential.GetToken)}", scopes[0].Name);
Assert.AreEqual($"{nameof(VisualStudioCredential)}.{nameof(VisualStudioCredential.GetToken)}", scopes[1].Name);
}

[Test]
[RunOnlyOnPlatforms(Windows = true, OSX = true, ContainerNames = new[] { "ubuntu_netcore2_keyring" })]
public async Task ChainedTokenCredential_UseVisualStudioCodeCredential()
{
var pipeline = CredentialPipeline.GetInstance(null);
var cloudName = Guid.NewGuid().ToString();
var fileSystem = CredentialTestHelpers.CreateFileSystemForVisualStudioCode(TestEnvironment, cloudName);
var processService = new TestProcessService(new TestProcess { Error = "Error" });

var miCredential = new ManagedIdentityCredential(EnvironmentVariables.ClientId, pipeline);
var vsCredential = new VisualStudioCredential(default, pipeline, fileSystem, processService);
var vscCredential = new VisualStudioCodeCredential(new VisualStudioCodeCredentialOptions { TenantId = TestEnvironment.TestTenantId }, pipeline, default, fileSystem, default);

var credential = InstrumentClient(new ChainedTokenCredential(miCredential, vsCredential, vscCredential));

AccessToken token;
List<ClientDiagnosticListener.ProducedDiagnosticScope> scopes;

using (await CredentialTestHelpers.CreateRefreshTokenFixtureAsync(TestEnvironment, Mode, ExpectedServiceName, cloudName))
using (ClientDiagnosticListener diagnosticListener = new ClientDiagnosticListener(s => s.StartsWith("Azure.Identity")))
{
token = await credential.GetTokenAsync(new TokenRequestContext(new[] {"https://vault.azure.net/.default"}), CancellationToken.None);
scopes = diagnosticListener.Scopes;
}

Assert.IsNotNull(token.Token);

Assert.AreEqual(2, scopes.Count);
Assert.AreEqual($"{nameof(ChainedTokenCredential)}.{nameof(ChainedTokenCredential.GetToken)}", scopes[0].Name);
Assert.AreEqual($"{nameof(VisualStudioCodeCredential)}.{nameof(VisualStudioCodeCredential.GetToken)}", scopes[1].Name);
}

[Test]
[RunOnlyOnPlatforms(Windows = true, OSX = true, ContainerNames = new[] { "ubuntu_netcore2_keyring" })]
public async Task ChainedTokenCredential_UseVisualStudioCodeCredential_ParallelCalls()
{
var pipeline = CredentialPipeline.GetInstance(null);
var cloudName = Guid.NewGuid().ToString();
var fileSystem = CredentialTestHelpers.CreateFileSystemForVisualStudioCode(TestEnvironment, cloudName);
var processService = new TestProcessService { CreateHandler = psi => new TestProcess { Error = "Error" }};

var miCredential = new ManagedIdentityCredential(EnvironmentVariables.ClientId, pipeline);
var vsCredential = new VisualStudioCredential(default, pipeline, fileSystem, processService);
var vscCredential = new VisualStudioCodeCredential(new VisualStudioCodeCredentialOptions { TenantId = TestEnvironment.TestTenantId }, pipeline, default, fileSystem, default);
var credential = InstrumentClient(new ChainedTokenCredential(miCredential, vsCredential, vscCredential));

var tasks = new List<Task<AccessToken>>();
using (await CredentialTestHelpers.CreateRefreshTokenFixtureAsync(TestEnvironment, Mode, ExpectedServiceName, cloudName))
{
for (int i = 0; i < 10; i++)
{
tasks.Add(Task.Run(async () => await credential.GetTokenAsync(new TokenRequestContext(new[] {"https://vault.azure.net/.default"}), CancellationToken.None)));
}

await Task.WhenAll(tasks);
}

foreach (Task<AccessToken> task in tasks)
{
Assert.IsNotNull(task.Result.Token);
}
}

[Test]
public async Task ChainedTokenCredential_UseAzureCliCredential()
{
var pipeline = CredentialPipeline.GetInstance(null);
var (expectedToken, expectedExpiresOn, processOutput) = CredentialTestHelpers.CreateTokenForAzureCli();
var vscAdapter = new TestVscAdapter(ExpectedServiceName, "Azure", null);
var fileSystem = CredentialTestHelpers.CreateFileSystemForVisualStudioCode(TestEnvironment);
var processService = new TestProcessService(new TestProcess { Output = processOutput });

var miCredential = new ManagedIdentityCredential(EnvironmentVariables.ClientId, pipeline);
var vsCredential = new VisualStudioCredential(default, pipeline, fileSystem, processService);
var vscCredential = new VisualStudioCodeCredential(new VisualStudioCodeCredentialOptions { TenantId = TestEnvironment.TestTenantId }, pipeline, default, fileSystem, vscAdapter);
var azureCliCredential = new AzureCliCredential(pipeline, processService);

var credential = InstrumentClient(new ChainedTokenCredential(miCredential, vsCredential, vscCredential, azureCliCredential));

AccessToken token;
List<ClientDiagnosticListener.ProducedDiagnosticScope> scopes;

using (ClientDiagnosticListener diagnosticListener = new ClientDiagnosticListener(s => s.StartsWith("Azure.Identity")))
{
token = await credential.GetTokenAsync(new TokenRequestContext(new[] {"https://vault.azure.net/.default"}), CancellationToken.None);
scopes = diagnosticListener.Scopes;
}

Assert.AreEqual(token.Token, expectedToken);
Assert.AreEqual(token.ExpiresOn, expectedExpiresOn);

Assert.AreEqual(2, scopes.Count);
Assert.AreEqual($"{nameof(ChainedTokenCredential)}.{nameof(ChainedTokenCredential.GetToken)}", scopes[0].Name);
Assert.AreEqual($"{nameof(AzureCliCredential)}.{nameof(AzureCliCredential.GetToken)}", scopes[1].Name);
}

[Test]
public async Task ChainedTokenCredential_UseAzureCliCredential_ParallelCalls()
{
var pipeline = CredentialPipeline.GetInstance(null);
var (expectedToken, expectedExpiresOn, processOutput) = CredentialTestHelpers.CreateTokenForAzureCli();
var vscAdapter = new TestVscAdapter(ExpectedServiceName, "Azure", null);
var fileSystem = CredentialTestHelpers.CreateFileSystemForVisualStudioCode(TestEnvironment);
var processService = new TestProcessService(new TestProcess { Output = processOutput });

var miCredential = new ManagedIdentityCredential(EnvironmentVariables.ClientId, pipeline);
var vsCredential = new VisualStudioCredential(default, pipeline, fileSystem, processService);
var vscCredential = new VisualStudioCodeCredential(new VisualStudioCodeCredentialOptions { TenantId = TestEnvironment.TestTenantId }, pipeline, default, fileSystem, vscAdapter);
var azureCliCredential = new AzureCliCredential(pipeline, processService);

var credential = InstrumentClient(new ChainedTokenCredential(miCredential, vsCredential, vscCredential, azureCliCredential));

var tasks = new List<Task<AccessToken>>();
for (int i = 0; i < 10; i++)
{
tasks.Add(Task.Run(async () => await credential.GetTokenAsync(new TokenRequestContext(new[] {"https://vault.azure.net/.default"}), CancellationToken.None)));
}

await Task.WhenAll(tasks);

foreach (Task<AccessToken> task in tasks)
{
Assert.AreEqual(task.Result.Token, expectedToken);
Assert.AreEqual(task.Result.ExpiresOn, expectedExpiresOn);
}
}

[Test]
public void ChainedTokenCredential_AllCredentialsHaveFailed_CredentialUnavailableException()
{
var vscAdapter = new TestVscAdapter(ExpectedServiceName, "Azure", "{}");

var pipeline = CredentialPipeline.GetInstance(null);
var fileSystem = new TestFileSystemService();
var processService = new TestProcessService(new TestProcess { Error = "'az' is not recognized" });

var vsCredential = new VisualStudioCredential(default, pipeline, fileSystem, processService);
var vscCredential = new VisualStudioCodeCredential(new VisualStudioCodeCredentialOptions { TenantId = TestEnvironment.TestTenantId }, pipeline, default, fileSystem, vscAdapter);
var azureCliCredential = new AzureCliCredential(pipeline, processService);

var credential = InstrumentClient(new ChainedTokenCredential(vsCredential, vscCredential, azureCliCredential));

List<ClientDiagnosticListener.ProducedDiagnosticScope> scopes;
using (ClientDiagnosticListener diagnosticListener = new ClientDiagnosticListener(s => s.StartsWith("Azure.Identity")))
{
Assert.CatchAsync<CredentialUnavailableException>(async () => await credential.GetTokenAsync(new TokenRequestContext(new[] {"https://vault.azure.net/.default"}), CancellationToken.None));
scopes = diagnosticListener.Scopes;
}

Assert.AreEqual(4, scopes.Count);
Assert.AreEqual($"{nameof(ChainedTokenCredential)}.{nameof(ChainedTokenCredential.GetToken)}", scopes[0].Name);
Assert.AreEqual($"{nameof(VisualStudioCredential)}.{nameof(VisualStudioCredential.GetToken)}", scopes[1].Name);
Assert.AreEqual($"{nameof(VisualStudioCodeCredential)}.{nameof(VisualStudioCodeCredential.GetToken)}", scopes[2].Name);
Assert.AreEqual($"{nameof(AzureCliCredential)}.{nameof(AzureCliCredential.GetToken)}", scopes[3].Name);
}

[Test]
public void ChainedTokenCredential_AllCredentialsHaveFailed_AuthenticationFailedException()
{
var pipeline = CredentialPipeline.GetInstance(null);
var vscAdapter = new TestVscAdapter(ExpectedServiceName, "Azure", null);
var fileSystem = new TestFileSystemService();
var processService = new TestProcessService(new TestProcess {Error = "Error"});

var miCredential = new ManagedIdentityCredential(EnvironmentVariables.ClientId, pipeline);
var vsCredential = new VisualStudioCredential(default, pipeline, fileSystem, processService);
var vscCredential = new VisualStudioCodeCredential(new VisualStudioCodeCredentialOptions { TenantId = TestEnvironment.TestTenantId }, pipeline, default, fileSystem, vscAdapter);
var azureCliCredential = new AzureCliCredential(pipeline, processService);

var credential = InstrumentClient(new ChainedTokenCredential(miCredential, vsCredential, vscCredential, azureCliCredential));

List<ClientDiagnosticListener.ProducedDiagnosticScope> scopes;
using (ClientDiagnosticListener diagnosticListener = new ClientDiagnosticListener(s => s.StartsWith("Azure.Identity")))
{
Assert.CatchAsync<AuthenticationFailedException>(async () => await credential.GetTokenAsync(new TokenRequestContext(new[] {"https://vault.azure.net/.default"}), CancellationToken.None));
scopes = diagnosticListener.Scopes;
}

Assert.AreEqual(5, scopes.Count);
Assert.AreEqual($"{nameof(ChainedTokenCredential)}.{nameof(ChainedTokenCredential.GetToken)}", scopes[0].Name);
Assert.AreEqual($"{nameof(ManagedIdentityCredential)}.{nameof(ManagedIdentityCredential.GetToken)}", scopes[1].Name);
Assert.AreEqual($"{nameof(VisualStudioCredential)}.{nameof(VisualStudioCredential.GetToken)}", scopes[2].Name);
Assert.AreEqual($"{nameof(VisualStudioCodeCredential)}.{nameof(VisualStudioCodeCredential.GetToken)}", scopes[3].Name);
Assert.AreEqual($"{nameof(AzureCliCredential)}.{nameof(AzureCliCredential.GetToken)}", scopes[4].Name);
}
}
}

0 comments on commit bd6412f

Please sign in to comment.