Skip to content
This repository was archived by the owner on Jan 23, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,299 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Diagnostics;
using System.Security.Cryptography;
using System.Text;
using System.Threading.Tasks;

namespace System.Net.Test.Common
{
public sealed partial class LoopbackServer
{
internal enum AuthenticationProtocols
{
Basic,
Digest,
None
}

public async Task<List<string>> AcceptConnectionPerformAuthenticationAndCloseAsync(string authenticateHeaders)
{
List<string> lines = null;
await AcceptConnectionAsync(async connection =>
{
await connection.ReadRequestHeaderAndSendResponseAsync(HttpStatusCode.Unauthorized, authenticateHeaders);

lines = await connection.ReadRequestHeaderAsync();
Debug.Assert(lines.Count > 0);

int index = lines[0] != null ? lines[0].IndexOf(' ') : -1;
string requestMethod = null;
if (index != -1)
{
requestMethod = lines[0].Substring(0, index);
}

// Read the authorization header from client.
AuthenticationProtocols protocol = AuthenticationProtocols.None;
string clientResponse = null;
for (int i = 1; i < lines.Count; i++)
{
if (lines[i].StartsWith("Authorization"))
{
clientResponse = lines[i];
if (lines[i].Contains(nameof(AuthenticationProtocols.Basic)))
{
protocol = AuthenticationProtocols.Basic;
break;
}
else if (lines[i].Contains(nameof(AuthenticationProtocols.Digest)))
{
protocol = AuthenticationProtocols.Digest;
break;
}
}
}

bool success = false;
switch (protocol)
{
case AuthenticationProtocols.Basic:
success = IsBasicAuthTokenValid(clientResponse, _options);
break;

case AuthenticationProtocols.Digest:
// Read the request content.
success = IsDigestAuthTokenValid(clientResponse, requestMethod, _options);
break;
}

if (success)
{
await connection.SendResponseAsync();
}
else
{
await connection.SendResponseAsync(HttpStatusCode.Unauthorized, authenticateHeaders);
}
});

return lines;
}

internal static bool IsBasicAuthTokenValid(string clientResponse, LoopbackServer.Options options)
{
string clientHash = clientResponse.Substring(clientResponse.IndexOf(nameof(AuthenticationProtocols.Basic), StringComparison.OrdinalIgnoreCase) +
nameof(AuthenticationProtocols.Basic).Length).Trim();
string userPass = string.IsNullOrEmpty(options.Domain) ? options.Username + ":" + options.Password : options.Domain + "\\" + options.Username + ":" + options.Password;
return clientHash == Convert.ToBase64String(Encoding.UTF8.GetBytes(userPass));
}

internal static bool IsDigestAuthTokenValid(string clientResponse, string requestMethod, LoopbackServer.Options options)
{
string clientHash = clientResponse.Substring(clientResponse.IndexOf(nameof(AuthenticationProtocols.Digest), StringComparison.OrdinalIgnoreCase) +
nameof(AuthenticationProtocols.Digest).Length).Trim();
string[] values = clientHash.Split(',');

string username = null, uri = null, realm = null, nonce = null, response = null, algorithm = null, cnonce = null, opaque = null, qop = null, nc = null;
bool userhash = false;
for (int i = 0; i < values.Length; i++)
{
string trimmedValue = values[i].Trim();
if (trimmedValue.Contains(nameof(username)))
{
// Username is a quoted string.
int startIndex = trimmedValue.IndexOf('"');

if (startIndex != -1)
{
startIndex += 1;
username = trimmedValue.Substring(startIndex, trimmedValue.Length - startIndex - 1);
}

// Username is mandatory.
if (string.IsNullOrEmpty(username))
return false;
}
else if (trimmedValue.Contains(nameof(userhash)) && trimmedValue.Contains("true"))
{
userhash = true;
}
else if (trimmedValue.Contains(nameof(uri)))
{
int startIndex = trimmedValue.IndexOf('"');
if (startIndex != -1)
{
startIndex += 1;
uri = trimmedValue.Substring(startIndex, trimmedValue.Length - startIndex - 1);
}

// Request uri is mandatory.
if (string.IsNullOrEmpty(uri))
return false;
}
else if (trimmedValue.Contains(nameof(realm)))
{
// Realm is a quoted string.
int startIndex = trimmedValue.IndexOf('"');
if (startIndex != -1)
{
startIndex += 1;
realm = trimmedValue.Substring(startIndex, trimmedValue.Length - startIndex - 1);
}

// Realm is mandatory.
if (string.IsNullOrEmpty(realm))
return false;
}
else if (trimmedValue.Contains(nameof(cnonce)))
{
// CNonce is a quoted string.
int startIndex = trimmedValue.IndexOf('"');
if (startIndex != -1)
{
startIndex += 1;
cnonce = trimmedValue.Substring(startIndex, trimmedValue.Length - startIndex - 1);
}
}
else if (trimmedValue.Contains(nameof(nonce)))
{
// Nonce is a quoted string.
int startIndex = trimmedValue.IndexOf('"');
if (startIndex != -1)
{
startIndex += 1;
nonce = trimmedValue.Substring(startIndex, trimmedValue.Length - startIndex - 1);
}

// Nonce is mandatory.
if (string.IsNullOrEmpty(nonce))
return false;
}
else if (trimmedValue.Contains(nameof(response)))
{
// response is a quoted string.
int startIndex = trimmedValue.IndexOf('"');
if (startIndex != -1)
{
startIndex += 1;
response = trimmedValue.Substring(startIndex, trimmedValue.Length - startIndex - 1);
}

// Response is mandatory.
if (string.IsNullOrEmpty(response))
return false;
}
else if (trimmedValue.Contains(nameof(algorithm)))
{
int startIndex = trimmedValue.IndexOf('=');
if (startIndex != -1)
{
startIndex += 1;
algorithm = trimmedValue.Substring(startIndex, trimmedValue.Length - startIndex).Trim();
}
}
else if (trimmedValue.Contains(nameof(opaque)))
{
// Opaque is a quoted string.
int startIndex = trimmedValue.IndexOf('"');
if (startIndex != -1)
{
startIndex += 1;
opaque = trimmedValue.Substring(startIndex, trimmedValue.Length - startIndex - 1);
}
}
else if (trimmedValue.Contains(nameof(qop)))
{
int startIndex = trimmedValue.IndexOf('"');
if (startIndex != -1)
{
startIndex += 1;
qop = trimmedValue.Substring(startIndex, trimmedValue.Length - startIndex - 1);
}
else if ((startIndex = trimmedValue.IndexOf('=')) != -1)
{
startIndex += 1;
qop = trimmedValue.Substring(startIndex, trimmedValue.Length - startIndex).Trim();
}
}
else if (trimmedValue.Contains(nameof(nc)))
{
int startIndex = trimmedValue.IndexOf('=');
if (startIndex != -1)
{
startIndex += 1;
nc = trimmedValue.Substring(startIndex, trimmedValue.Length - startIndex).Trim();
}
}
}

// Verify username.
if (userhash && ComputeHash(options.Username + ":" + realm, algorithm) != username)
{
return false;
}

if (!userhash && options.Username != username)
{
return false;
}

if (string.IsNullOrEmpty(algorithm))
algorithm = "sha-256";

// Calculate response and compare with the client response hash.
string a1 = options.Username + ":" + realm + ":" + options.Password;
if (algorithm.Contains("sess"))
{
a1 = ComputeHash(a1, algorithm) + ":" + nonce;

if (cnonce != null)
a1 += ":" + cnonce;
}

string a2 = requestMethod + ":" + uri;
if (!string.IsNullOrEmpty(qop) && qop.Equals("auth-int"))
{
// Request content is empty.
a2 = a2 + ":" + ComputeHash(string.Empty, algorithm);
}

string serverResponseHash = ComputeHash(a1, algorithm) + ":" + nonce + ":";

if (nc != null)
serverResponseHash += nc + ":";

if (cnonce != null)
serverResponseHash += cnonce + ":";

if (qop != null)
serverResponseHash += qop + ":";

serverResponseHash += ComputeHash(a2, algorithm);
serverResponseHash = ComputeHash(serverResponseHash, algorithm);

return response == serverResponseHash;
}

private static string ComputeHash(string data, string algorithm)
{
// Disable MD5 insecure warning.
#pragma warning disable CA5351
using (HashAlgorithm hash = algorithm.Contains("SHA-256") ? SHA256.Create() : (HashAlgorithm)MD5.Create())
#pragma warning restore CA5351
{
Encoding enc = Encoding.UTF8;
byte[] result = hash.ComputeHash(enc.GetBytes(data));

StringBuilder sb = new StringBuilder(result.Length * 2);
foreach (byte b in result)
sb.Append(b.ToString("x2"));

return sb.ToString();
}
}
}
}
Loading