Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance: Adds Authorization Helper improvements #2000

Merged
merged 17 commits into from
Nov 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 111 additions & 15 deletions Microsoft.Azure.Cosmos/src/Authorization/AuthorizationHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ namespace Microsoft.Azure.Cosmos
{
using System;
using System.Buffers;
using System.Buffers.Text;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
Expand Down Expand Up @@ -126,7 +127,7 @@ public static string GenerateKeyAuthorizationSignature(string verb,
INameValueCollection headers,
IComputeHash stringHMACSHA256Helper)
{
string authorizationToken = AuthorizationHelper.GenerateAuthorizationTokenWithHashCore(
string authorizationToken = AuthorizationHelper.GenerateUrlEncodedAuthorizationTokenWithHashCore(
verb,
resourceId,
resourceType,
Expand All @@ -135,7 +136,7 @@ public static string GenerateKeyAuthorizationSignature(string verb,
out ArrayOwner payloadStream);
using (payloadStream)
{
return AuthorizationHelper.AuthorizationFormatPrefixUrlEncoded + HttpUtility.UrlEncode(authorizationToken);
return AuthorizationHelper.AuthorizationFormatPrefixUrlEncoded + authorizationToken;
}
}

Expand All @@ -148,7 +149,7 @@ public static string GenerateKeyAuthorizationSignature(string verb,
IComputeHash stringHMACSHA256Helper,
out string payload)
{
string authorizationToken = AuthorizationHelper.GenerateAuthorizationTokenWithHashCore(
string authorizationToken = AuthorizationHelper.GenerateUrlEncodedAuthorizationTokenWithHashCore(
verb,
resourceId,
resourceType,
Expand All @@ -158,7 +159,7 @@ public static string GenerateKeyAuthorizationSignature(string verb,
using (payloadStream)
{
payload = AuthorizationHelper.AuthorizationEncoding.GetString(payloadStream.Buffer.Array, payloadStream.Buffer.Offset, (int)payloadStream.Buffer.Count);
return AuthorizationHelper.AuthorizationFormatPrefixUrlEncoded + HttpUtility.UrlEncode(authorizationToken);
return AuthorizationHelper.AuthorizationFormatPrefixUrlEncoded + authorizationToken;
}
}

Expand All @@ -171,16 +172,16 @@ public static string GenerateKeyAuthorizationSignature(string verb,
IComputeHash stringHMACSHA256Helper,
out ArrayOwner payload)
{
string authorizationToken = AuthorizationHelper.GenerateAuthorizationTokenWithHashCore(
verb,
resourceId,
resourceType,
headers,
stringHMACSHA256Helper,
out payload);
string authorizationToken = AuthorizationHelper.GenerateUrlEncodedAuthorizationTokenWithHashCore(
verb: verb,
resourceId: resourceId,
resourceType: resourceType,
headers: headers,
stringHMACSHA256Helper: stringHMACSHA256Helper,
payload: out payload);
try
{
return AuthorizationHelper.AuthorizationFormatPrefixUrlEncoded + HttpUtility.UrlEncode(authorizationToken);
return AuthorizationHelper.AuthorizationFormatPrefixUrlEncoded + authorizationToken;
}
catch
{
Expand Down Expand Up @@ -652,7 +653,7 @@ private static void ValidateInputRequestTime(
AuthorizationHelper.CheckTimeRangeIsCurrent(allowedClockSkewInSeconds, utcStartTime, utcEndTime);
}

private static string GenerateAuthorizationTokenWithHashCore(
private static string GenerateUrlEncodedAuthorizationTokenWithHashCore(
string verb,
string resourceId,
string resourceType,
Expand Down Expand Up @@ -704,8 +705,7 @@ private static string GenerateAuthorizationTokenWithHashCore(

payload = new ArrayOwner(ArrayPool<byte>.Shared, new ArraySegment<byte>(buffer, 0, length));
byte[] hashPayLoad = stringHMACSHA256Helper.ComputeHash(payload.Buffer);
string authorizationToken = Convert.ToBase64String(hashPayLoad);
return authorizationToken;
return AuthorizationHelper.OptimizedConvertToBase64string(hashPayLoad);
}
catch
{
Expand All @@ -714,6 +714,43 @@ private static string GenerateAuthorizationTokenWithHashCore(
}
}

/// <summary>
/// This an optimized version of doing HttpUtility.UrlEncode(Convert.ToBase64String(hashPayLoad)).
/// This avoids the over head of converting it to a string and back to a byte[].
/// </summary>
private static unsafe string OptimizedConvertToBase64string(byte[] hashPayLoad)
{
// Create a large enough buffer that URL encode can use it.
// Increase the buffer by 3x so it can be used for the URL encoding
int capacity = Base64.GetMaxEncodedToUtf8Length(hashPayLoad.Length) * 3;
byte[] rentedBuffer = ArrayPool<byte>.Shared.Rent(capacity);

try
{
Span<byte> encodingBuffer = rentedBuffer;
// This replaces the Convert.ToBase64String
OperationStatus status = Base64.EncodeToUtf8(
hashPayLoad,
encodingBuffer,
out int _,
out int bytesWritten);

if (status != OperationStatus.Done)
{
throw new ArgumentException($"Authorization key payload is invalid. {status}");
}

return AuthorizationHelper.UrlEncodeBase64SpanInPlace(encodingBuffer, bytesWritten);
}
finally
{
if (rentedBuffer != null)
{
ArrayPool<byte>.Shared.Return(rentedBuffer);
}
}
}

private static int ComputeMemoryCapacity(string verbInput, string authResourceId, string resourceTypeInput)
{
return
Expand Down Expand Up @@ -786,6 +823,65 @@ private static string GenerateKeyAuthorizationCore(
}
}

/// <summary>
/// This does HttpUtility.UrlEncode functionality with Span buffer. It does an in place update to avoid
/// creating the new buffer.
/// </summary>
/// <param name="base64Bytes">The buffer that include the bytes to url encode.</param>
/// <param name="length">The length of bytes used in the buffer</param>
/// <returns>The URLEncoded string of the bytes in the buffer</returns>
public unsafe static string UrlEncodeBase64SpanInPlace(Span<byte> base64Bytes, int length)
{
if (base64Bytes == default)
{
throw new ArgumentNullException(nameof(base64Bytes));
}

if (base64Bytes.Length < length * 3)
{
throw new ArgumentException($"{nameof(base64Bytes)} should be 3x to avoid running out of space in worst case scenario where all characters are special");
}

if (length == 0)
{
return string.Empty;
}

int escapeBufferPosition = base64Bytes.Length - 1;
for (int i = length - 1; i >= 0; i--)
{
byte curr = base64Bytes[i];
// Base64 is limited to Alphanumeric characters and '/' '=' '+'
switch (curr)
{
case (byte)'/':
base64Bytes[escapeBufferPosition--] = (byte)'f';
base64Bytes[escapeBufferPosition--] = (byte)'2';
base64Bytes[escapeBufferPosition--] = (byte)'%';
break;
case (byte)'=':
base64Bytes[escapeBufferPosition--] = (byte)'d';
base64Bytes[escapeBufferPosition--] = (byte)'3';
base64Bytes[escapeBufferPosition--] = (byte)'%';
break;
case (byte)'+':
base64Bytes[escapeBufferPosition--] = (byte)'b';
base64Bytes[escapeBufferPosition--] = (byte)'2';
base64Bytes[escapeBufferPosition--] = (byte)'%';
break;
default:
base64Bytes[escapeBufferPosition--] = curr;
break;
}
}

Span<byte> endSlice = base64Bytes.Slice(escapeBufferPosition + 1);
fixed (byte* bp = endSlice)
{
return Encoding.UTF8.GetString(bp, endSlice.Length);
}
}

private static int Write(this Span<byte> stream, string contentToWrite)
{
int actualByteCount = AuthorizationHelper.AuthorizationEncoding.GetBytes(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ public MasterKeyAuthorizationBenchmark()
}

[Benchmark]
[BenchmarkCategory("GateBenchmark")]
public void CreateSignatureGeneration()
{
this.TestSignature("POST", "dbs/testdb/colls/testcollection/dbs", "dbs");
}

[Benchmark]
[BenchmarkCategory("GateBenchmark")]
public void ReadSignatureGeneration()
{
this.TestSignature("GET", "dbs/testdb/colls/testcollection/dbs/item1", "dbs");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
{
"MasterKeyAuthorizationBenchmark.CreateSignatureGeneration;": 546.0,
"MasterKeyAuthorizationBenchmark.ReadSignatureGeneration;": 536.0,
"MockedItemBenchmark.CreateItem;[Type=Stream]": 27406.0,
"MockedItemBenchmark.DeleteItemExists;[Type=Stream]": 27438.0,
"MockedItemBenchmark.DeleteItemNotExists;[Type=Stream]": 44810.0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ public static int ValidateSummaryResultsAgainstBaseline(Dictionary<string, doubl
}
}

Console.WriteLine("Current benchmark results: " + currentBenchmarkResults);

return 0;
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
Sample usage pattern

> dotnet run -c Release --framework netcoreapp3.1 -- -j Medium -f *MockedItemBenchmark* -m --allStats --join
> dotnet run -c Release --framework netcoreapp3.1 -- -j Medium -f *MockedItemBenchmark* -m --allStats --join

Run all benchmarks for gates:
dotnet run -c Release --framework netcoreapp3.1 --allCategories=GateBenchmark -- -j Medium -m --BaselineValidation
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace Microsoft.Azure.Cosmos.Tests
{
using System;
using System.Net.Http;
using System.Text;
using Microsoft.Azure.Documents;
using Microsoft.Azure.Documents.Collections;
using Microsoft.VisualStudio.TestTools.UnitTesting;
Expand Down Expand Up @@ -227,5 +228,82 @@ public void AuthorizationBaselineTests()
Assert.IsTrue(AuthorizationHelper.CheckPayloadUsingKey(tokenOutput3, baseline[2], baseline[1], baseline[3], nvc, key));
}
}

[TestMethod]
public void Base64UrlEncoderFuzzTest()
{
Random random = new Random();
for(int i = 0; i < 2000; i++)
{
Span<byte> randomBytes = new byte[random.Next(1, 500)];
random.NextBytes(randomBytes);
string randomBase64String = Convert.ToBase64String(randomBytes);
byte[] randomBase64Bytes = Encoding.UTF8.GetBytes(randomBase64String);
Span<byte> buffered = new byte[randomBase64Bytes.Length * 3];
randomBase64Bytes.CopyTo(buffered);

string baseline = null;
string newResults = null;
try
{
baseline = HttpUtility.UrlEncode(randomBase64Bytes);
newResults = AuthorizationHelper.UrlEncodeBase64SpanInPlace(buffered, randomBase64Bytes.Length);
}
catch(Exception e)
{
Assert.Fail($"Url encode failed with string {randomBase64String} ; Exception:{e}");
}

Assert.AreEqual(baseline, newResults);
}
}

[TestMethod]
public void Base64UrlEncoderEdgeCasesTest()
{
{
Span<byte> singleInvalidChar = new byte[3];
singleInvalidChar[0] = (byte)'=';
string urlEncoded = AuthorizationHelper.UrlEncodeBase64SpanInPlace(singleInvalidChar, 1);
Assert.AreEqual("%3d", urlEncoded);
}

{
Span<byte> singleInvalidChar = new byte[3];
singleInvalidChar[0] = (byte)'+';
string urlEncoded = AuthorizationHelper.UrlEncodeBase64SpanInPlace(singleInvalidChar, 1);
Assert.AreEqual("%2b", urlEncoded);
}

{
Span<byte> singleInvalidChar = new byte[3];
singleInvalidChar[0] = (byte)'/';
string urlEncoded = AuthorizationHelper.UrlEncodeBase64SpanInPlace(singleInvalidChar, 1);
Assert.AreEqual("%2f", urlEncoded);
}

{
Span<byte> multipleInvalidChar = new byte[9];
multipleInvalidChar[0] = (byte)'=';
multipleInvalidChar[1] = (byte)'+';
multipleInvalidChar[2] = (byte)'/';
string urlEncoded = AuthorizationHelper.UrlEncodeBase64SpanInPlace(multipleInvalidChar, 3);
Assert.AreEqual("%3d%2b%2f", urlEncoded);
}

{
Span<byte> singleValidChar = new byte[3];
singleValidChar[0] = (byte)'a';
string urlEncoded = AuthorizationHelper.UrlEncodeBase64SpanInPlace(singleValidChar, 1);
Assert.AreEqual("a", urlEncoded);
}

{
byte[] singleInvalidChar = new byte[0];
string result = HttpUtility.UrlEncode(singleInvalidChar);
string urlEncoded = AuthorizationHelper.UrlEncodeBase64SpanInPlace(singleInvalidChar, 0);
Assert.AreEqual(result, urlEncoded);
}
}
}
}