Skip to content
This repository was archived by the owner on Jan 23, 2023. It is now read-only.

Commit 5efbacc

Browse files
author
Lakshmi Priya Sekar
committed
Add test infra for auth testing.
1 parent 38ff6d1 commit 5efbacc

File tree

4 files changed

+403
-15
lines changed

4 files changed

+403
-15
lines changed

src/Common/tests/System/Net/Http/LoopbackServer.cs

Lines changed: 259 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ public class LoopbackServer
2222
{
2323
public static Func<HttpRequestMessage, X509Certificate2, X509Chain, SslPolicyErrors, bool> AllowAllCertificates = (_, __, ___, ____) => true;
2424

25+
private enum AuthenticationProtocols
26+
{
27+
Basic,
28+
Digest,
29+
None
30+
}
31+
2532
public class Options
2633
{
2734
public IPAddress Address { get; set; } = IPAddress.Loopback;
@@ -30,6 +37,9 @@ public class Options
3037
public SslProtocols SslProtocols { get; set; } = SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12;
3138
public bool WebSocketEndpoint { get; set; } = false;
3239
public Func<Stream, Stream> ResponseStreamWrapper { get; set; }
40+
public string Domain { get; set; }
41+
public string Username { get; set; }
42+
public string Password { get; set; }
3343
}
3444

3545
public static Task CreateServerAsync(Func<Socket, Uri, Task> funcAsync, Options options = null)
@@ -49,7 +59,7 @@ public static Task CreateServerAsync(Func<Socket, Uri, Task> funcAsync, out IPEn
4959
server.Listen(options.ListenBacklog);
5060

5161
localEndPoint = (IPEndPoint)server.LocalEndPoint;
52-
string host = options.Address.AddressFamily == AddressFamily.InterNetworkV6 ?
62+
string host = options.Address.AddressFamily == AddressFamily.InterNetworkV6 ?
5363
$"[{localEndPoint.Address}]" :
5464
localEndPoint.Address.ToString();
5565

@@ -100,6 +110,11 @@ public static Task<List<string>> ReadRequestAndSendResponseAsync(Socket server,
100110
return AcceptSocketAsync(server, (s, stream, reader, writer) => ReadWriteAcceptedAsync(s, reader, writer, response), options);
101111
}
102112

113+
public static Task<List<string>> ReadRequestAndAuthenticateAsync(Socket server, string response, Options options)
114+
{
115+
return AcceptSocketAsync(server, (s, stream, reader, writer) => ValidateAuthenticationAsync(s, reader, writer, response, options), options);
116+
}
117+
103118
public static async Task<List<string>> ReadWriteAcceptedAsync(Socket s, StreamReader reader, StreamWriter writer, string response = null)
104119
{
105120
// Read request line and headers. Skip any request body.
@@ -115,6 +130,247 @@ public static async Task<List<string>> ReadWriteAcceptedAsync(Socket s, StreamRe
115130
return lines;
116131
}
117132

133+
public static async Task<List<string>> ValidateAuthenticationAsync(Socket s, StreamReader reader, StreamWriter writer, string response, Options options)
134+
{
135+
// Send unauthorized response from server.
136+
await ReadWriteAcceptedAsync(s, reader, writer, response);
137+
138+
// Read the request method.
139+
string line = await reader.ReadLineAsync().ConfigureAwait(false);
140+
int index = line != null ? line.IndexOf(' ') : -1;
141+
string requestMethod = null;
142+
if (index != -1)
143+
{
144+
requestMethod = line.Substring(0, index);
145+
}
146+
147+
// Read the authorization header from client.
148+
AuthenticationProtocols protocol = AuthenticationProtocols.None;
149+
string clientResponse = null;
150+
while (!string.IsNullOrEmpty(line = await reader.ReadLineAsync().ConfigureAwait(false)))
151+
{
152+
if (line.StartsWith("Authorization"))
153+
{
154+
clientResponse = line;
155+
if (line.Contains(nameof(AuthenticationProtocols.Basic), StringComparison.OrdinalIgnoreCase))
156+
{
157+
protocol = AuthenticationProtocols.Basic;
158+
break;
159+
}
160+
else if (line.Contains(nameof(AuthenticationProtocols.Digest), StringComparison.OrdinalIgnoreCase))
161+
{
162+
protocol = AuthenticationProtocols.Digest;
163+
break;
164+
}
165+
}
166+
}
167+
168+
bool success = false;
169+
switch (protocol)
170+
{
171+
case AuthenticationProtocols.Basic:
172+
success = IsBasicAuthTokenValid(line, options);
173+
break;
174+
175+
case AuthenticationProtocols.Digest:
176+
// Read the request content.
177+
string requestContent = null;
178+
while (!string.IsNullOrEmpty(line = await reader.ReadLineAsync().ConfigureAwait(false)))
179+
{
180+
if (line.Contains("Content-Length"))
181+
{
182+
line = await reader.ReadLineAsync().ConfigureAwait(false);
183+
while (!string.IsNullOrEmpty(line = await reader.ReadLineAsync().ConfigureAwait(false)))
184+
{
185+
requestContent += line;
186+
}
187+
}
188+
}
189+
190+
success = IsDigestAuthTokenValid(clientResponse, requestContent, requestMethod, options);
191+
break;
192+
}
193+
194+
if (success)
195+
{
196+
await writer.WriteAsync(DefaultHttpResponse).ConfigureAwait(false);
197+
}
198+
else
199+
{
200+
await writer.WriteAsync(response).ConfigureAwait(false);
201+
}
202+
203+
return null;
204+
}
205+
206+
private static bool IsBasicAuthTokenValid(string clientResponse, Options options)
207+
{
208+
string clientHash = clientResponse.Substring(clientResponse.IndexOf(nameof(AuthenticationProtocols.Basic), StringComparison.OrdinalIgnoreCase) +
209+
nameof(AuthenticationProtocols.Basic).Length).Trim();
210+
string userPass = string.IsNullOrEmpty(options.Domain) ? options.Username + ":" + options.Password : options.Domain + "\\" + options.Username + ":" + options.Password;
211+
return clientHash == Convert.ToBase64String(Encoding.UTF8.GetBytes(userPass));
212+
}
213+
214+
private static bool IsDigestAuthTokenValid(string clientResponse, string requestContent, string requestMethod, Options options)
215+
{
216+
string clientHash = clientResponse.Substring(clientResponse.IndexOf(nameof(AuthenticationProtocols.Digest), StringComparison.OrdinalIgnoreCase) +
217+
nameof(AuthenticationProtocols.Digest).Length).Trim();
218+
string[] values = clientHash.Split(',');
219+
220+
string username = null, uri = null, realm = null, nonce = null, response = null, algorithm = null, cnonce = null, opaque = null, qop = null, nc = null;
221+
bool userhash = false;
222+
for (int i = 0; i < values.Length; i++)
223+
{
224+
string trimmedValue = values[i].Trim();
225+
if (trimmedValue.Contains(nameof(username)))
226+
{
227+
// Username is a quoted string.
228+
int startIndex = trimmedValue.IndexOf('"') + 1;
229+
230+
if (startIndex != -1)
231+
username = trimmedValue.Substring(startIndex, trimmedValue.Length - startIndex - 1);
232+
233+
// Username is mandatory.
234+
if (string.IsNullOrEmpty(username))
235+
return false;
236+
}
237+
if (trimmedValue.Contains(nameof(userhash)) && trimmedValue.Contains("true"))
238+
{
239+
userhash = true;
240+
}
241+
else if (trimmedValue.Contains(nameof(uri)))
242+
{
243+
int startIndex = trimmedValue.IndexOf('"') + 1;
244+
if (startIndex != -1)
245+
uri = trimmedValue.Substring(startIndex, trimmedValue.Length - startIndex - 1);
246+
247+
// Request uri is mandatory.
248+
if (string.IsNullOrEmpty(uri))
249+
return false;
250+
}
251+
else if (trimmedValue.Contains(nameof(realm)))
252+
{
253+
// Realm is a quoted string.
254+
int startIndex = trimmedValue.IndexOf('"') + 1;
255+
if (startIndex != -1)
256+
realm = trimmedValue.Substring(startIndex, trimmedValue.Length - startIndex - 1);
257+
258+
// Realm is mandatory.
259+
if (string.IsNullOrEmpty(realm))
260+
return false;
261+
}
262+
else if (trimmedValue.Contains(nameof(cnonce)))
263+
{
264+
// CNonce is a quoted string.
265+
int startIndex = trimmedValue.IndexOf('"') + 1;
266+
if (startIndex != -1)
267+
cnonce = trimmedValue.Substring(startIndex, trimmedValue.Length - startIndex - 1);
268+
}
269+
else if (trimmedValue.Contains(nameof(nonce)))
270+
{
271+
// Nonce is a quoted string.
272+
int startIndex = trimmedValue.IndexOf('"') + 1;
273+
if (startIndex != -1)
274+
nonce = trimmedValue.Substring(startIndex, trimmedValue.Length - startIndex - 1);
275+
276+
// Nonce is mandatory.
277+
if (string.IsNullOrEmpty(nonce))
278+
return false;
279+
}
280+
else if (trimmedValue.Contains(nameof(response)))
281+
{
282+
// response is a quoted string.
283+
int startIndex = trimmedValue.IndexOf('"') + 1;
284+
if (startIndex != -1)
285+
response = trimmedValue.Substring(startIndex, trimmedValue.Length - startIndex - 1);
286+
287+
// Response is mandatory.
288+
if (string.IsNullOrEmpty(response))
289+
return false;
290+
}
291+
else if (trimmedValue.Contains(nameof(algorithm)))
292+
{
293+
int startIndex = trimmedValue.IndexOf('=') + 1;
294+
if (startIndex != -1)
295+
algorithm = trimmedValue.Substring(startIndex, trimmedValue.Length - startIndex).Trim();
296+
297+
if (string.IsNullOrEmpty(algorithm))
298+
algorithm = "sha-256";
299+
}
300+
else if (trimmedValue.Contains(nameof(opaque)))
301+
{
302+
// Opaque is a quoted string.
303+
int startIndex = trimmedValue.IndexOf('"') + 1;
304+
if (startIndex != -1)
305+
opaque = trimmedValue.Substring(startIndex, trimmedValue.Length - startIndex - 1);
306+
}
307+
else if (trimmedValue.Contains(nameof(qop)))
308+
{
309+
int startIndex = trimmedValue.IndexOf('=') + 1;
310+
if (startIndex != -1)
311+
qop = trimmedValue.Substring(startIndex, trimmedValue.Length - startIndex).Trim();
312+
}
313+
else if (trimmedValue.Contains(nameof(nc)))
314+
{
315+
int startIndex = trimmedValue.IndexOf('=') + 1;
316+
if (startIndex != -1)
317+
nc = trimmedValue.Substring(startIndex, trimmedValue.Length - startIndex).Trim();
318+
}
319+
}
320+
321+
// Verify username.
322+
if (userhash && ComputeHash(options.Username + ":" + realm, algorithm) != username)
323+
{
324+
return false;
325+
}
326+
327+
if (!userhash && options.Username != username)
328+
{
329+
return false;
330+
}
331+
332+
// Calculate response and compare with the client response hash.
333+
string a1 = options.Username + ":" + realm + ":" + options.Password;
334+
if (algorithm.Contains("sess"))
335+
{
336+
a1 = ComputeHash(a1, algorithm) + ":" + nonce + ":" + cnonce ?? string.Empty;
337+
}
338+
339+
string a2 = requestMethod + ":" + uri;
340+
if (qop.Equals("auth-int"))
341+
{
342+
string content = requestContent ?? string.Empty;
343+
a2 = a2 + ":" + ComputeHash(content, algorithm);
344+
}
345+
346+
string serverResponseHash = ComputeHash(ComputeHash(a1, algorithm) + ":" +
347+
nonce + ":" +
348+
nc + ":" +
349+
cnonce + ":" +
350+
qop + ":" +
351+
ComputeHash(a2, algorithm), algorithm);
352+
353+
return response == serverResponseHash;
354+
}
355+
356+
private static string ComputeHash(string data, string algorithm)
357+
{
358+
// Disable MD5 insecure warning.
359+
#pragma warning disable CA5351
360+
using (HashAlgorithm hash = algorithm.Contains("SHA-256") ? SHA256.Create() : (HashAlgorithm)MD5.Create())
361+
#pragma warning restore CA5351
362+
{
363+
Encoding enc = Encoding.UTF8;
364+
byte[] result = hash.ComputeHash(enc.GetBytes(data));
365+
366+
StringBuilder sb = new StringBuilder(result.Length * 2);
367+
foreach (byte b in result)
368+
sb.Append(b.ToString("x2"));
369+
370+
return sb.ToString();
371+
}
372+
}
373+
118374
public static async Task<bool> WebSocketHandshakeAsync(Socket s, StreamReader reader, StreamWriter writer)
119375
{
120376
string serverResponse = null;
@@ -226,7 +482,7 @@ public static Task StartTransferTypeAndErrorServer(
226482
{
227483
// Read past request headers.
228484
string line;
229-
while (!string.IsNullOrEmpty(line = reader.ReadLine())) ;
485+
while (!string.IsNullOrEmpty(line = reader.ReadLine()));
230486

231487
// Determine response transfer headers.
232488
string transferHeader = null;
@@ -272,6 +528,6 @@ public static Task StartTransferTypeAndErrorServer(
272528

273529
return null;
274530
}), out localEndPoint);
275-
}
531+
}
276532
}
277533
}

0 commit comments

Comments
 (0)