diff --git a/src/Microsoft.Identity.Web/ClientInfo.cs b/src/Microsoft.Identity.Web/ClientInfo.cs
index fd9d0e210..88f32e745 100644
--- a/src/Microsoft.Identity.Web/ClientInfo.cs
+++ b/src/Microsoft.Identity.Web/ClientInfo.cs
@@ -23,7 +23,8 @@ internal class ClientInfo
throw new ArgumentNullException(nameof(clientInfo), IDWebErrorMessage.ClientInfoReturnedFromServerIsNull);
}
- return DeserializeFromJson(Base64UrlHelpers.DecodeToBytes(clientInfo));
+ var bytes = Base64UrlHelpers.DecodeBytes(clientInfo);
+ return bytes != null ? DeserializeFromJson(bytes) : null;
}
internal static ClientInfo? DeserializeFromJson(byte[] jsonByteArray)
diff --git a/src/Microsoft.Identity.Web/Microsoft.Identity.Web.csproj b/src/Microsoft.Identity.Web/Microsoft.Identity.Web.csproj
index 32e69ec76..2e8e6ea70 100644
--- a/src/Microsoft.Identity.Web/Microsoft.Identity.Web.csproj
+++ b/src/Microsoft.Identity.Web/Microsoft.Identity.Web.csproj
@@ -58,6 +58,7 @@
../../build/MSAL.snk
true
enable
+ true
diff --git a/src/Microsoft.Identity.Web/Util/Base64UrlHelpers.cs b/src/Microsoft.Identity.Web/Util/Base64UrlHelpers.cs
index 351090b98..e62999a20 100644
--- a/src/Microsoft.Identity.Web/Util/Base64UrlHelpers.cs
+++ b/src/Microsoft.Identity.Web/Util/Base64UrlHelpers.cs
@@ -6,22 +6,36 @@
namespace Microsoft.Identity.Web.Util
{
+ // Based on https://github.com/AzureAD/azure-activedirectory-identitymodel-extensions-for-dotnet/pull/1698/files
internal static class Base64UrlHelpers
{
private const char Base64PadCharacter = '=';
+
private const char Base64Character62 = '+';
private const char Base64Character63 = '/';
private const char Base64UrlCharacter62 = '-';
private const char Base64UrlCharacter63 = '_';
- private static readonly Encoding TextEncoding = Encoding.UTF8;
- private static readonly string DoubleBase64PadCharacter = new string(Base64PadCharacter, 2);
+ ///
+ /// Encoding table.
+ ///
+ private static readonly char[] s_base64Table =
+ {
+ 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
+ Base64UrlCharacter62,
+ Base64UrlCharacter63,
+ };
- // The following functions perform Base64URL encoding which differs from regular Base64 encoding:
- // * Padding is skipped so the pad character '=' doesn't have to be percent encoded.
- // * The 62nd and 63rd regular Base64 encoding characters ('+' and '/') are replaced with ('-' and '_').
- // The changes make the encoding alphabet file and URL safe.
- // See RFC4648, section 5 for more info.
+ ///
+ /// The following functions perform base64url encoding which differs from regular base64 encoding as follows
+ /// * padding is skipped so the pad character '=' doesn't have to be percent encoded
+ /// * the 62nd and 63rd regular base64 encoding characters ('+' and '/') are replace with ('-' and '_')
+ /// The changes make the encoding alphabet file and URL safe.
+ ///
+ /// string to encode.
+ /// Base64Url encoding of the UTF8 bytes.
public static string? Encode(string arg)
{
if (arg == null)
@@ -29,57 +43,231 @@ internal static class Base64UrlHelpers
return null;
}
- return Encode(TextEncoding.GetBytes(arg));
+ return Encode(Encoding.UTF8.GetBytes(arg));
+ }
+
+ ///
+ /// Converts a subset of an array of 8-bit unsigned integers to its equivalent string representation that is encoded with base-64-url digits. Parameters specify
+ /// the subset as an offset in the input array, and the number of elements in the array to convert.
+ ///
+ /// An array of 8-bit unsigned integers.
+ /// The number of elements of inArray to convert.
+ /// An offset in inArray.
+ /// The string representation in base 64 url encoding of length elements of inArray, starting at position offset.
+ /// 'inArray' is null.
+ /// offset or length is negative OR offset plus length is greater than the length of inArray.
+ private static string Encode(byte[] inArray, int offset, int length)
+ {
+ _ = inArray ?? throw new ArgumentNullException(nameof(inArray));
+
+ if (length == 0)
+ {
+ return string.Empty;
+ }
+
+ if (length < 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(length));
+ }
+
+ if (offset < 0 || inArray.Length < offset)
+ {
+ throw new ArgumentOutOfRangeException(nameof(offset));
+ }
+
+ if (inArray.Length < offset + length)
+ {
+ throw new ArgumentOutOfRangeException(nameof(length));
+ }
+
+ int lengthmod3 = length % 3;
+ int limit = offset + (length - lengthmod3);
+ char[] output = new char[(length + 2) / 3 * 4];
+ char[] table = s_base64Table;
+ int i, j = 0;
+
+ // takes 3 bytes from inArray and insert 4 bytes into output
+ for (i = offset; i < limit; i += 3)
+ {
+ byte d0 = inArray[i];
+ byte d1 = inArray[i + 1];
+ byte d2 = inArray[i + 2];
+
+ output[j + 0] = table[d0 >> 2];
+ output[j + 1] = table[((d0 & 0x03) << 4) | (d1 >> 4)];
+ output[j + 2] = table[((d1 & 0x0f) << 2) | (d2 >> 6)];
+ output[j + 3] = table[d2 & 0x3f];
+ j += 4;
+ }
+
+ // Where we left off before
+ i = limit;
+
+ switch (lengthmod3)
+ {
+ case 2:
+ {
+ byte d0 = inArray[i];
+ byte d1 = inArray[i + 1];
+
+ output[j + 0] = table[d0 >> 2];
+ output[j + 1] = table[((d0 & 0x03) << 4) | (d1 >> 4)];
+ output[j + 2] = table[(d1 & 0x0f) << 2];
+ j += 3;
+ }
+
+ break;
+
+ case 1:
+ {
+ byte d0 = inArray[i];
+
+ output[j + 0] = table[d0 >> 2];
+ output[j + 1] = table[(d0 & 0x03) << 4];
+ j += 2;
+ }
+
+ break;
+
+ // default or case 0: no further operations are needed.
+ }
+
+ return new string(output, 0, j);
+ }
+
+ ///
+ /// Converts a subset of an array of 8-bit unsigned integers to its equivalent string representation that is encoded with base-64-url digits. Parameters specify
+ /// the subset as an offset in the input array, and the number of elements in the array to convert.
+ ///
+ /// An array of 8-bit unsigned integers.
+ /// The string representation in base 64 url encoding of length elements of inArray, starting at position offset.
+ /// 'inArray' is null.
+ /// offset or length is negative OR offset plus length is greater than the length of inArray.
+ public static string? Encode(byte[] inArray)
+ {
+ if (inArray == null)
+ {
+ return null;
+ }
+
+ return Encode(inArray, 0, inArray.Length);
}
- public static string DecodeToString(string arg)
+ internal static string? EncodeString(string str)
{
- byte[] decoded = DecodeToBytes(arg);
- return CreateString(decoded);
+ if (str == null)
+ {
+ return null;
+ }
+
+ return Encode(Encoding.UTF8.GetBytes(str));
}
- public static string CreateString(byte[] bytes)
+ ///
+ /// Converts the specified string, which encodes binary data as base-64-url digits, to an equivalent 8-bit unsigned integer array.
+ /// base64Url encoded string.
+ /// UTF8 bytes.
+ public static byte[]? DecodeBytes(string str)
{
- return Encoding.UTF8.GetString(bytes, 0, bytes.Length);
+ if (str == null)
+ {
+ return null;
+ }
+
+ return UnsafeDecode(str);
}
- public static byte[] DecodeToBytes(string arg)
+ private static unsafe byte[] UnsafeDecode(string str)
{
- string s = arg;
- s = s.Replace(Base64UrlCharacter62, Base64Character62); // 62nd char of encoding
- s = s.Replace(Base64UrlCharacter63, Base64Character63); // 63rd char of encoding
+ int mod = str.Length % 4;
+ if (mod == 1)
+ {
+ throw new ArgumentException(IDWebErrorMessage.InvalidBase64UrlString, nameof(str));
+ }
- switch (s.Length % 4)
+ bool needReplace = false;
+ int decodedLength = str.Length + ((4 - mod) % 4);
+
+ for (int i = 0; i < str.Length; i++)
{
- // Pad
- case 0:
- break; // No pad chars in this case
- case 2:
- s += DoubleBase64PadCharacter;
- break; // Two pad chars
- case 3:
- s += Base64PadCharacter;
- break; // One pad char
- default:
- throw new ArgumentException(IDWebErrorMessage.InvalidBase64UrlString, nameof(arg));
+ if (str[i] == Base64UrlCharacter62 || str[i] == Base64UrlCharacter63)
+ {
+ needReplace = true;
+ break;
+ }
}
- return Convert.FromBase64String(s); // Standard Base64 decoder
+ if (needReplace)
+ {
+ string decodedString = new string(char.MinValue, decodedLength);
+ fixed (char* dest = decodedString)
+ {
+ int i = 0;
+ for (; i < str.Length; i++)
+ {
+ if (str[i] == Base64UrlCharacter62)
+ {
+ dest[i] = Base64Character62;
+ }
+ else if (str[i] == Base64UrlCharacter63)
+ {
+ dest[i] = Base64Character63;
+ }
+ else
+ {
+ dest[i] = str[i];
+ }
+ }
+
+ for (; i < decodedLength; i++)
+ {
+ dest[i] = Base64PadCharacter;
+ }
+ }
+
+ return Convert.FromBase64String(decodedString);
+ }
+ else
+ {
+ if (decodedLength == str.Length)
+ {
+ return Convert.FromBase64String(str);
+ }
+ else
+ {
+ string decodedString = new string(char.MinValue, decodedLength);
+ fixed (char* src = str)
+ {
+ fixed (char* dest = decodedString)
+ {
+ Buffer.MemoryCopy(src, dest, str.Length * 2, str.Length * 2);
+ dest[str.Length] = Base64PadCharacter;
+ if (str.Length + 2 == decodedLength)
+ {
+ dest[str.Length + 1] = Base64PadCharacter;
+ }
+ }
+ }
+
+ return Convert.FromBase64String(decodedString);
+ }
+ }
}
- internal static string? Encode(byte[] arg)
+ ///
+ /// Decodes the string from Base64UrlEncoded to UTF8.
+ ///
+ /// string to decode.
+ /// UTF8 string.
+ public static string? Decode(string arg)
{
- if (arg == null)
+ byte[]? bytes = DecodeBytes(arg);
+ if (bytes == null)
{
return null;
}
- string s = Convert.ToBase64String(arg);
- s = s.Split(Base64PadCharacter)[0]; // Remove any trailing padding
- s = s.Replace(Base64Character62, Base64UrlCharacter62); // 62nd char of encoding
- s = s.Replace(Base64Character63, Base64UrlCharacter63); // 63rd char of encoding
-
- return s;
+ return Encoding.UTF8.GetString(bytes);
}
}
}
diff --git a/tests/Microsoft.Identity.Web.Test/Base64UrlHelpersTests.cs b/tests/Microsoft.Identity.Web.Test/Base64UrlHelpersTests.cs
index c2ef2c593..b86bb2982 100644
--- a/tests/Microsoft.Identity.Web.Test/Base64UrlHelpersTests.cs
+++ b/tests/Microsoft.Identity.Web.Test/Base64UrlHelpersTests.cs
@@ -61,34 +61,11 @@ public void Encode_DecodedString_ReturnsEncodedString(string stringToEncode, str
[InlineData("", "")] // Empty string
public void DecodeToString_ValidBase64UrlString_ReturnsDecodedString(string stringToDecode, string expectedDecodedString)
{
- var actualDecodedString = Base64UrlHelpers.DecodeToString(stringToDecode);
+ var actualDecodedString = Base64UrlHelpers.Decode(stringToDecode);
Assert.Equal(expectedDecodedString, actualDecodedString);
}
- [Theory]
- [InlineData("123456")]
- [InlineData("")]
- public void CreateString_UTF8Bytes_ReturnsValidString(string stringToCreate)
- {
- var resultString = Base64UrlHelpers.CreateString(Encoding.UTF8.GetBytes(stringToCreate));
-
- Assert.Equal(stringToCreate, resultString);
- }
-
- [Theory]
- [InlineData("123456")]
- public void CreateString_NonUTF8Bytes_ReturnsInvalidString(string stringToCreate)
- {
- var resultString = Base64UrlHelpers.CreateString(Encoding.UTF32.GetBytes(stringToCreate));
-
- Assert.NotEqual(stringToCreate, resultString);
-
- resultString = Base64UrlHelpers.CreateString(Encoding.Unicode.GetBytes(stringToCreate));
-
- Assert.NotEqual(stringToCreate, resultString);
- }
-
[Theory]
[InlineData("MTIzNDU2", "123456")] // No padding
[InlineData("MTIzNDU2Nzg", "12345678")] // 1 padding
@@ -100,7 +77,7 @@ public void DecodeToBytes_ValidBase64UrlString_ReturnsByteArray(string stringToD
{
var expectedDecodedByteArray = Encoding.UTF8.GetBytes(expectedDecodedString);
- var actualDecodedByteArray = Base64UrlHelpers.DecodeToBytes(stringToDecode);
+ var actualDecodedByteArray = Base64UrlHelpers.DecodeBytes(stringToDecode);
Assert.Equal(expectedDecodedByteArray, actualDecodedByteArray);
}
@@ -110,10 +87,10 @@ public void DecodeToBytes_InvalidBase64UrlStringLength_ThrowsException()
{
var stringToDecodeWithInvalidLength = "MTIzNDU21";
- Action decodeAction = () => Base64UrlHelpers.DecodeToBytes(stringToDecodeWithInvalidLength);
+ Action decodeAction = () => Base64UrlHelpers.DecodeBytes(stringToDecodeWithInvalidLength);
var exception = Assert.Throws(decodeAction);
- Assert.Equal(IDWebErrorMessage.InvalidBase64UrlString + " (Parameter 'arg')", exception.Message);
+ Assert.Equal(IDWebErrorMessage.InvalidBase64UrlString + " (Parameter 'str')", exception.Message);
}
}
}