diff --git a/src/Containers/Microsoft.NET.Build.Containers/AuthHandshakeMessageHandler.cs b/src/Containers/Microsoft.NET.Build.Containers/AuthHandshakeMessageHandler.cs index 8e4bc29b7656..46761e94df7c 100644 --- a/src/Containers/Microsoft.NET.Build.Containers/AuthHandshakeMessageHandler.cs +++ b/src/Containers/Microsoft.NET.Build.Containers/AuthHandshakeMessageHandler.cs @@ -71,18 +71,19 @@ private static bool TryParseAuthenticationInfo(HttpResponseMessage msg, [NotNull if (header.Scheme is not null) { scheme = header.Scheme; - var keyValues = ParseBearerArgs(header.Parameter); - if (keyValues is null) - { - return false; - } if (header.Scheme.Equals(BasicAuthScheme, StringComparison.OrdinalIgnoreCase)) { - return TryParseBasicAuthInfo(keyValues, msg.RequestMessage!.RequestUri!, out bearerAuthInfo); + bearerAuthInfo = null; + return true; } else if (header.Scheme.Equals(BearerAuthScheme, StringComparison.OrdinalIgnoreCase)) { + var keyValues = ParseBearerArgs(header.Parameter); + if (keyValues is null) + { + return false; + } return TryParseBearerAuthInfo(keyValues, out bearerAuthInfo); } else @@ -110,12 +111,6 @@ static bool TryParseBearerAuthInfo(Dictionary authValues, [NotNu } } - static bool TryParseBasicAuthInfo(Dictionary authValues, Uri requestUri, out AuthInfo? authInfo) - { - authInfo = null; - return true; - } - static Dictionary? ParseBearerArgs(string? bearerHeaderArgs) { if (bearerHeaderArgs is null) @@ -159,7 +154,6 @@ public DateTimeOffset ResolvedExpiration /// private async Task<(AuthenticationHeaderValue, DateTimeOffset)?> GetAuthenticationAsync(string registry, string scheme, AuthInfo? bearerAuthInfo, CancellationToken cancellationToken) { - DockerCredentials? privateRepoCreds; // Allow overrides for auth via environment variables if (GetDockerCredentialsFromEnvironment(_registryMode) is (string credU, string credP)) @@ -180,14 +174,20 @@ public DateTimeOffset ResolvedExpiration { Debug.Assert(bearerAuthInfo is not null); - var authenticationValueAndDuration = await TryOAuthPostAsync(privateRepoCreds, bearerAuthInfo, cancellationToken).ConfigureAwait(false); - if (authenticationValueAndDuration is not null) + // Obtain a Bearer token, when the credentials are: + // - an identity token: use it for OAuth + // - a username/password: use them for Basic auth, and fall back to OAuth + + if (string.IsNullOrWhiteSpace(privateRepoCreds.IdentityToken)) { - return authenticationValueAndDuration; + var authenticationValueAndDuration = await TryTokenGetAsync(privateRepoCreds, bearerAuthInfo, cancellationToken).ConfigureAwait(false); + if (authenticationValueAndDuration is not null) + { + return authenticationValueAndDuration; + } } - authenticationValueAndDuration = await TryTokenGetAsync(privateRepoCreds, bearerAuthInfo, cancellationToken).ConfigureAwait(false); - return authenticationValueAndDuration; + return await TryOAuthPostAsync(privateRepoCreds, bearerAuthInfo, cancellationToken).ConfigureAwait(false); } else { @@ -293,8 +293,7 @@ internal static (string credU, string credP)? GetDockerCredentialsFromEnvironmen if (!postResponse.IsSuccessStatusCode) { await postResponse.LogHttpResponseAsync(_logger, cancellationToken).ConfigureAwait(false); - //return null to try HTTP GET instead - return null; + return null; // try next method } _logger.LogTrace("Received '{statuscode}'.", postResponse.StatusCode); TokenResponse? tokenResponse = JsonSerializer.Deserialize(postResponse.Content.ReadAsStream(cancellationToken)); @@ -306,8 +305,7 @@ internal static (string credU, string credP)? GetDockerCredentialsFromEnvironmen else { _logger.LogTrace(Resource.GetString(nameof(Strings.CouldntDeserializeJsonToken))); - // logging and returning null to try HTTP GET instead - return null; + return null; // try next method } } @@ -318,9 +316,7 @@ internal static (string credU, string credP)? GetDockerCredentialsFromEnvironmen { // this doesn't seem to be called out in the spec, but actual username/password auth information should be converted into Basic auth here, // even though the overall Scheme we're authenticating for is Bearer - var header = privateRepoCreds.Username == "" - ? new AuthenticationHeaderValue(BearerAuthScheme, privateRepoCreds.Password) - : new AuthenticationHeaderValue(BasicAuthScheme, Convert.ToBase64String(Encoding.ASCII.GetBytes($"{privateRepoCreds.Username}:{privateRepoCreds.Password}"))); + var header = new AuthenticationHeaderValue(BasicAuthScheme, Convert.ToBase64String(Encoding.ASCII.GetBytes($"{privateRepoCreds.Username}:{privateRepoCreds.Password}"))); var builder = new UriBuilder(new Uri(bearerAuthInfo.Realm)); _logger.LogTrace("Attempting to authenticate on {uri} using GET.", bearerAuthInfo.Realm); @@ -340,7 +336,8 @@ internal static (string credU, string credP)? GetDockerCredentialsFromEnvironmen using var tokenResponse = await base.SendAsync(message, cancellationToken).ConfigureAwait(false); if (!tokenResponse.IsSuccessStatusCode) { - throw new UnableToAccessRepositoryException(_registryName); + await tokenResponse.LogHttpResponseAsync(_logger, cancellationToken).ConfigureAwait(false); + return null; // try next method } TokenResponse? token = JsonSerializer.Deserialize(tokenResponse.Content.ReadAsStream(cancellationToken)); @@ -412,7 +409,8 @@ protected override async Task SendAsync(HttpRequestMessage request.Headers.Authorization = authHeader; return await base.SendAsync(request, cancellationToken).ConfigureAwait(false); } - return response; + + throw new UnableToAccessRepositoryException(_registryName); } else { diff --git a/test/Microsoft.NET.Build.Containers.UnitTests/AuthHandshakeMessageHandlerTests.cs b/test/Microsoft.NET.Build.Containers.UnitTests/AuthHandshakeMessageHandlerTests.cs index 5089658bd859..bbb0376c550d 100644 --- a/test/Microsoft.NET.Build.Containers.UnitTests/AuthHandshakeMessageHandlerTests.cs +++ b/test/Microsoft.NET.Build.Containers.UnitTests/AuthHandshakeMessageHandlerTests.cs @@ -1,9 +1,20 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using System.Web; +using System.Net.Http.Headers; +using System.Collections.Specialized; +using Microsoft.Extensions.Logging.Abstractions; + namespace Microsoft.NET.Build.Containers.UnitTests { public class AuthHandshakeMessageHandlerTests { + private const string TestRegistryName = "registry.test"; + private const string RequestUrl = $"https://{TestRegistryName}/v2"; + private const string BearerRealmUrl = $"https://bearer.test/token"; + [Theory] [InlineData("SDK_CONTAINER_REGISTRY_UNAME", "SDK_CONTAINER_REGISTRY_PWORD", (int)RegistryMode.Push)] [InlineData("DOTNET_CONTAINER_PUSH_REGISTRY_UNAME", "DOTNET_CONTAINER_PUSH_REGISTRY_PWORD", (int)RegistryMode.Push)] @@ -33,5 +44,265 @@ public void GetDockerCredentialsFromEnvironment_ReturnsCorrectValues(string unam Environment.SetEnvironmentVariable(unameVarName, originalUnameValue); Environment.SetEnvironmentVariable(pwordVarName, originalPwordValue); } + + [Theory] + [MemberData(nameof(GetAuthenticateTestData))] + public async Task Authenticate(string authConf, Func server) + { + string authFile = Path.GetTempFileName(); + try + { + File.WriteAllText(authFile, authConf); + Environment.SetEnvironmentVariable("REGISTRY_AUTH_FILE", authFile); + + var authHandler = new AuthHandshakeMessageHandler(TestRegistryName, new ServerMessageHandler(server), NullLogger.Instance, RegistryMode.Push); + using var httpClient = new HttpClient(authHandler); + + var response = await httpClient.GetAsync(RequestUrl); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + finally + { + try + { + File.Delete(authFile); + } + catch + { } + } + } + + public static IEnumerable GetAuthenticateTestData() + { + // Check auth with username and password. + // The '' username has a special meaning that is already handled by the docker-creds-provider library. + // We cover it it in the test to verify the SDK doesn't handled it special. + string password = "pass"; + string username = "user"; + foreach (string user in new[] { username, ""}) + { + // Basic auth + yield return new object[] { + ConfigAuthWithUserAndPassword(user, password), + ServerWithBasicAuth(user, password) + }; + + // Basic auth for token + yield return new object[] { + ConfigAuthWithUserAndPassword(user, password), + ServerWithBasicAuthForToken($"realm=\"{BearerRealmUrl}\"", BearerRealmUrl, user, password, + queryParameters: new()) + }; + + // OAuth password auth + yield return new object[] { + ConfigAuthWithUserAndPassword(user, password), + ServerWithOAuthForToken($"realm=\"{BearerRealmUrl}\"", BearerRealmUrl, + formParameters: new() + { + { "client_id", "netsdkcontainers" }, + { "grant_type", "password" }, + { "username", user }, + { "password", password } + }) + }; + } + + // Check auth with an identity token. + string identityToken = "my-identity-token"; + yield return new object[] { + ConfigAuthWithIdentityToken(identityToken), + ServerWithOAuthForToken($"realm=\"{BearerRealmUrl}\"", BearerRealmUrl, + formParameters: new() + { + { "client_id", "netsdkcontainers" }, + { "grant_type", "refresh_token" }, + { "refresh_token", identityToken } + }) + }; + + // Verify the bearer parameters (service/scope) are passed. + // With OAuth auth as form parameters + string scope = "my-scope"; + string service = "my-service"; + yield return new object[] { + ConfigAuthWithIdentityToken(identityToken), + ServerWithOAuthForToken($"realm=\"{BearerRealmUrl}\", service={service}, scope={scope}", BearerRealmUrl, + formParameters: new() + { + { "client_id", "netsdkcontainers" }, + { "grant_type", "refresh_token" }, + { "refresh_token", identityToken }, + { "service", service }, + { "scope", scope } + }) + }; + // With Basic auth as query parameters + yield return new object[] { + ConfigAuthWithUserAndPassword(username, password), + ServerWithBasicAuthForToken($"realm=\"{BearerRealmUrl}\", service={service}, scope={scope}", BearerRealmUrl, username, password, + queryParameters: new() + { + { "service", service }, + { "scope", scope } + }) + }; + + static string ConfigAuthWithUserAndPassword(string username, string password) => + $$""" + { + "auths": { + "{{TestRegistryName}}": { + "auth": "{{GetUserPasswordBase64(username, password)}}" + } + } + } + """; + + static string ConfigAuthWithIdentityToken(string identityToken) => + $$""" + { + "auths": { + "{{TestRegistryName}}": { + "identitytoken": {{identityToken}}, + "auth": "{{GetUserPasswordBase64("__", "__")}}" + } + } + } + """; + } + + static string GetUserPasswordBase64(string username, string password) + => Convert.ToBase64String(Encoding.ASCII.GetBytes($"{username}:{password}")); + + static Func ServerWithBasicAuth(string username, string password) + { + return (HttpRequestMessage request) => + { + if (request.RequestUri?.ToString() == RequestUrl && + IsBasicAuthenticated(request, username, password)) + { + return new HttpResponseMessage(HttpStatusCode.OK); + } + + return CreateRequestAuthenticateResponse("Basic", ""); + }; + + static bool IsBasicAuthenticated(HttpRequestMessage requestMessage, string username, string password) + { + AuthenticationHeaderValue? header = requestMessage.Headers.Authorization; + if (header is null) + { + return false; + } + return header.Scheme == "Basic" && header.Parameter == GetUserPasswordBase64(username, password); + } + } + + static Func ServerWithBasicAuthForToken(string authenticateParameters, string requestUri, string username, string password, Dictionary queryParameters) + => ServerWithBearerAuth(authenticateParameters, requestUri, HttpMethod.Get, queryParameters, new(), new AuthenticationHeaderValue("Basic", GetUserPasswordBase64(username, password))); + + static Func ServerWithOAuthForToken(string authenticateParameters, string requestUri, Dictionary formParameters) + => ServerWithBearerAuth(authenticateParameters, requestUri, HttpMethod.Post, new(), formParameters, null); + + static Func ServerWithBearerAuth(string authenticateParameters, string requestUri, HttpMethod method, Dictionary queryParameters, Dictionary formParameters, AuthenticationHeaderValue? authHeader) + { + const string BearerToken = "my-bearer-token"; + + return (HttpRequestMessage request) => + { + if (request.RequestUri?.ToString() == RequestUrl && + IsBearerAuthenticated(request, BearerToken)) + { + return new HttpResponseMessage(HttpStatusCode.OK); + } + + if (request.RequestUri?.ToString() == BearerRealmUrl) + { + // Verify the method is the expected one. + Assert.Equal(method, request.Method); + + // Verify the query parameter are the expected ones. + AssertParametersAreEqual(queryParameters, request.RequestUri.Query); + + // Verify the auth header is the expected one. + AuthenticationHeaderValue? header = request.Headers.Authorization; + if (authHeader is not null) + { + Assert.NotNull(header); + Assert.Equal(header.Scheme, authHeader.Scheme); + Assert.Equal(header.Parameter, authHeader.Parameter); + } + else + { + Assert.Null(header); + } + + // Verify the content. + string content = request.Content is null ? "" : request.Content.ReadAsStringAsync().Result; + AssertParametersAreEqual(formParameters, content); + + // Issue the token. + return CreateBearerTokenResponse(BearerToken); + } + + return CreateRequestAuthenticateResponse("Bearer", authenticateParameters); + }; + + static bool IsBearerAuthenticated(HttpRequestMessage requestMessage, string bearerToken) + { + AuthenticationHeaderValue? header = requestMessage.Headers.Authorization; + if (header is null) + { + return false; + } + return header.Scheme == "Bearer" && header.Parameter == bearerToken; + } + + static void AssertParametersAreEqual(Dictionary expected, string actual) + { + NameValueCollection parsedParameters = HttpUtility.ParseQueryString(actual); + foreach (var parameter in expected) + { + Assert.Equal(parameter.Value, parsedParameters.Get(parameter.Key)); + } + Assert.Equal(expected.Count, parsedParameters.AllKeys.Length); + } + } + + static HttpResponseMessage CreateRequestAuthenticateResponse(string scheme, string parameter) + { + var response = new HttpResponseMessage(HttpStatusCode.Unauthorized); + response.Headers.WwwAuthenticate.Add(new AuthenticationHeaderValue(scheme, parameter)); + return response; + } + + static HttpResponseMessage CreateBearerTokenResponse(string bearerToken) + { + var response = new HttpResponseMessage(HttpStatusCode.OK); + string json = + $$""" + { + "token": "{{bearerToken}}" + } + """; + response.Content = new ByteArrayContent(Encoding.UTF8.GetBytes(json)); + return response; + } + + private sealed class ServerMessageHandler : HttpMessageHandler + { + private readonly Func _server; + + public ServerMessageHandler(Func server) + { + _server = server; + } + + protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + return Task.FromResult(_server(request)); + } + } } }