diff --git a/src/Microsoft.Identity.Web/ClientInfo.cs b/src/Microsoft.Identity.Web/ClientInfo.cs index fd9d0e210..88f32e745 100644 --- a/src/Microsoft.Identity.Web/ClientInfo.cs +++ b/src/Microsoft.Identity.Web/ClientInfo.cs @@ -23,7 +23,8 @@ internal class ClientInfo throw new ArgumentNullException(nameof(clientInfo), IDWebErrorMessage.ClientInfoReturnedFromServerIsNull); } - return DeserializeFromJson(Base64UrlHelpers.DecodeToBytes(clientInfo)); + var bytes = Base64UrlHelpers.DecodeBytes(clientInfo); + return bytes != null ? DeserializeFromJson(bytes) : null; } internal static ClientInfo? DeserializeFromJson(byte[] jsonByteArray) diff --git a/src/Microsoft.Identity.Web/Microsoft.Identity.Web.csproj b/src/Microsoft.Identity.Web/Microsoft.Identity.Web.csproj index 32e69ec76..2e8e6ea70 100644 --- a/src/Microsoft.Identity.Web/Microsoft.Identity.Web.csproj +++ b/src/Microsoft.Identity.Web/Microsoft.Identity.Web.csproj @@ -58,6 +58,7 @@ ../../build/MSAL.snk true enable + true diff --git a/src/Microsoft.Identity.Web/Util/Base64UrlHelpers.cs b/src/Microsoft.Identity.Web/Util/Base64UrlHelpers.cs index 351090b98..e62999a20 100644 --- a/src/Microsoft.Identity.Web/Util/Base64UrlHelpers.cs +++ b/src/Microsoft.Identity.Web/Util/Base64UrlHelpers.cs @@ -6,22 +6,36 @@ namespace Microsoft.Identity.Web.Util { + // Based on https://github.com/AzureAD/azure-activedirectory-identitymodel-extensions-for-dotnet/pull/1698/files internal static class Base64UrlHelpers { private const char Base64PadCharacter = '='; + private const char Base64Character62 = '+'; private const char Base64Character63 = '/'; private const char Base64UrlCharacter62 = '-'; private const char Base64UrlCharacter63 = '_'; - private static readonly Encoding TextEncoding = Encoding.UTF8; - private static readonly string DoubleBase64PadCharacter = new string(Base64PadCharacter, 2); + /// + /// Encoding table. + /// + private static readonly char[] s_base64Table = + { + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', + Base64UrlCharacter62, + Base64UrlCharacter63, + }; - // The following functions perform Base64URL encoding which differs from regular Base64 encoding: - // * Padding is skipped so the pad character '=' doesn't have to be percent encoded. - // * The 62nd and 63rd regular Base64 encoding characters ('+' and '/') are replaced with ('-' and '_'). - // The changes make the encoding alphabet file and URL safe. - // See RFC4648, section 5 for more info. + /// + /// The following functions perform base64url encoding which differs from regular base64 encoding as follows + /// * padding is skipped so the pad character '=' doesn't have to be percent encoded + /// * the 62nd and 63rd regular base64 encoding characters ('+' and '/') are replace with ('-' and '_') + /// The changes make the encoding alphabet file and URL safe. + /// + /// string to encode. + /// Base64Url encoding of the UTF8 bytes. public static string? Encode(string arg) { if (arg == null) @@ -29,57 +43,231 @@ internal static class Base64UrlHelpers return null; } - return Encode(TextEncoding.GetBytes(arg)); + return Encode(Encoding.UTF8.GetBytes(arg)); + } + + /// + /// Converts a subset of an array of 8-bit unsigned integers to its equivalent string representation that is encoded with base-64-url digits. Parameters specify + /// the subset as an offset in the input array, and the number of elements in the array to convert. + /// + /// An array of 8-bit unsigned integers. + /// The number of elements of inArray to convert. + /// An offset in inArray. + /// The string representation in base 64 url encoding of length elements of inArray, starting at position offset. + /// 'inArray' is null. + /// offset or length is negative OR offset plus length is greater than the length of inArray. + private static string Encode(byte[] inArray, int offset, int length) + { + _ = inArray ?? throw new ArgumentNullException(nameof(inArray)); + + if (length == 0) + { + return string.Empty; + } + + if (length < 0) + { + throw new ArgumentOutOfRangeException(nameof(length)); + } + + if (offset < 0 || inArray.Length < offset) + { + throw new ArgumentOutOfRangeException(nameof(offset)); + } + + if (inArray.Length < offset + length) + { + throw new ArgumentOutOfRangeException(nameof(length)); + } + + int lengthmod3 = length % 3; + int limit = offset + (length - lengthmod3); + char[] output = new char[(length + 2) / 3 * 4]; + char[] table = s_base64Table; + int i, j = 0; + + // takes 3 bytes from inArray and insert 4 bytes into output + for (i = offset; i < limit; i += 3) + { + byte d0 = inArray[i]; + byte d1 = inArray[i + 1]; + byte d2 = inArray[i + 2]; + + output[j + 0] = table[d0 >> 2]; + output[j + 1] = table[((d0 & 0x03) << 4) | (d1 >> 4)]; + output[j + 2] = table[((d1 & 0x0f) << 2) | (d2 >> 6)]; + output[j + 3] = table[d2 & 0x3f]; + j += 4; + } + + // Where we left off before + i = limit; + + switch (lengthmod3) + { + case 2: + { + byte d0 = inArray[i]; + byte d1 = inArray[i + 1]; + + output[j + 0] = table[d0 >> 2]; + output[j + 1] = table[((d0 & 0x03) << 4) | (d1 >> 4)]; + output[j + 2] = table[(d1 & 0x0f) << 2]; + j += 3; + } + + break; + + case 1: + { + byte d0 = inArray[i]; + + output[j + 0] = table[d0 >> 2]; + output[j + 1] = table[(d0 & 0x03) << 4]; + j += 2; + } + + break; + + // default or case 0: no further operations are needed. + } + + return new string(output, 0, j); + } + + /// + /// Converts a subset of an array of 8-bit unsigned integers to its equivalent string representation that is encoded with base-64-url digits. Parameters specify + /// the subset as an offset in the input array, and the number of elements in the array to convert. + /// + /// An array of 8-bit unsigned integers. + /// The string representation in base 64 url encoding of length elements of inArray, starting at position offset. + /// 'inArray' is null. + /// offset or length is negative OR offset plus length is greater than the length of inArray. + public static string? Encode(byte[] inArray) + { + if (inArray == null) + { + return null; + } + + return Encode(inArray, 0, inArray.Length); } - public static string DecodeToString(string arg) + internal static string? EncodeString(string str) { - byte[] decoded = DecodeToBytes(arg); - return CreateString(decoded); + if (str == null) + { + return null; + } + + return Encode(Encoding.UTF8.GetBytes(str)); } - public static string CreateString(byte[] bytes) + /// + /// Converts the specified string, which encodes binary data as base-64-url digits, to an equivalent 8-bit unsigned integer array. + /// base64Url encoded string. + /// UTF8 bytes. + public static byte[]? DecodeBytes(string str) { - return Encoding.UTF8.GetString(bytes, 0, bytes.Length); + if (str == null) + { + return null; + } + + return UnsafeDecode(str); } - public static byte[] DecodeToBytes(string arg) + private static unsafe byte[] UnsafeDecode(string str) { - string s = arg; - s = s.Replace(Base64UrlCharacter62, Base64Character62); // 62nd char of encoding - s = s.Replace(Base64UrlCharacter63, Base64Character63); // 63rd char of encoding + int mod = str.Length % 4; + if (mod == 1) + { + throw new ArgumentException(IDWebErrorMessage.InvalidBase64UrlString, nameof(str)); + } - switch (s.Length % 4) + bool needReplace = false; + int decodedLength = str.Length + ((4 - mod) % 4); + + for (int i = 0; i < str.Length; i++) { - // Pad - case 0: - break; // No pad chars in this case - case 2: - s += DoubleBase64PadCharacter; - break; // Two pad chars - case 3: - s += Base64PadCharacter; - break; // One pad char - default: - throw new ArgumentException(IDWebErrorMessage.InvalidBase64UrlString, nameof(arg)); + if (str[i] == Base64UrlCharacter62 || str[i] == Base64UrlCharacter63) + { + needReplace = true; + break; + } } - return Convert.FromBase64String(s); // Standard Base64 decoder + if (needReplace) + { + string decodedString = new string(char.MinValue, decodedLength); + fixed (char* dest = decodedString) + { + int i = 0; + for (; i < str.Length; i++) + { + if (str[i] == Base64UrlCharacter62) + { + dest[i] = Base64Character62; + } + else if (str[i] == Base64UrlCharacter63) + { + dest[i] = Base64Character63; + } + else + { + dest[i] = str[i]; + } + } + + for (; i < decodedLength; i++) + { + dest[i] = Base64PadCharacter; + } + } + + return Convert.FromBase64String(decodedString); + } + else + { + if (decodedLength == str.Length) + { + return Convert.FromBase64String(str); + } + else + { + string decodedString = new string(char.MinValue, decodedLength); + fixed (char* src = str) + { + fixed (char* dest = decodedString) + { + Buffer.MemoryCopy(src, dest, str.Length * 2, str.Length * 2); + dest[str.Length] = Base64PadCharacter; + if (str.Length + 2 == decodedLength) + { + dest[str.Length + 1] = Base64PadCharacter; + } + } + } + + return Convert.FromBase64String(decodedString); + } + } } - internal static string? Encode(byte[] arg) + /// + /// Decodes the string from Base64UrlEncoded to UTF8. + /// + /// string to decode. + /// UTF8 string. + public static string? Decode(string arg) { - if (arg == null) + byte[]? bytes = DecodeBytes(arg); + if (bytes == null) { return null; } - string s = Convert.ToBase64String(arg); - s = s.Split(Base64PadCharacter)[0]; // Remove any trailing padding - s = s.Replace(Base64Character62, Base64UrlCharacter62); // 62nd char of encoding - s = s.Replace(Base64Character63, Base64UrlCharacter63); // 63rd char of encoding - - return s; + return Encoding.UTF8.GetString(bytes); } } } diff --git a/tests/Microsoft.Identity.Web.Test/Base64UrlHelpersTests.cs b/tests/Microsoft.Identity.Web.Test/Base64UrlHelpersTests.cs index c2ef2c593..b86bb2982 100644 --- a/tests/Microsoft.Identity.Web.Test/Base64UrlHelpersTests.cs +++ b/tests/Microsoft.Identity.Web.Test/Base64UrlHelpersTests.cs @@ -61,34 +61,11 @@ public void Encode_DecodedString_ReturnsEncodedString(string stringToEncode, str [InlineData("", "")] // Empty string public void DecodeToString_ValidBase64UrlString_ReturnsDecodedString(string stringToDecode, string expectedDecodedString) { - var actualDecodedString = Base64UrlHelpers.DecodeToString(stringToDecode); + var actualDecodedString = Base64UrlHelpers.Decode(stringToDecode); Assert.Equal(expectedDecodedString, actualDecodedString); } - [Theory] - [InlineData("123456")] - [InlineData("")] - public void CreateString_UTF8Bytes_ReturnsValidString(string stringToCreate) - { - var resultString = Base64UrlHelpers.CreateString(Encoding.UTF8.GetBytes(stringToCreate)); - - Assert.Equal(stringToCreate, resultString); - } - - [Theory] - [InlineData("123456")] - public void CreateString_NonUTF8Bytes_ReturnsInvalidString(string stringToCreate) - { - var resultString = Base64UrlHelpers.CreateString(Encoding.UTF32.GetBytes(stringToCreate)); - - Assert.NotEqual(stringToCreate, resultString); - - resultString = Base64UrlHelpers.CreateString(Encoding.Unicode.GetBytes(stringToCreate)); - - Assert.NotEqual(stringToCreate, resultString); - } - [Theory] [InlineData("MTIzNDU2", "123456")] // No padding [InlineData("MTIzNDU2Nzg", "12345678")] // 1 padding @@ -100,7 +77,7 @@ public void DecodeToBytes_ValidBase64UrlString_ReturnsByteArray(string stringToD { var expectedDecodedByteArray = Encoding.UTF8.GetBytes(expectedDecodedString); - var actualDecodedByteArray = Base64UrlHelpers.DecodeToBytes(stringToDecode); + var actualDecodedByteArray = Base64UrlHelpers.DecodeBytes(stringToDecode); Assert.Equal(expectedDecodedByteArray, actualDecodedByteArray); } @@ -110,10 +87,10 @@ public void DecodeToBytes_InvalidBase64UrlStringLength_ThrowsException() { var stringToDecodeWithInvalidLength = "MTIzNDU21"; - Action decodeAction = () => Base64UrlHelpers.DecodeToBytes(stringToDecodeWithInvalidLength); + Action decodeAction = () => Base64UrlHelpers.DecodeBytes(stringToDecodeWithInvalidLength); var exception = Assert.Throws(decodeAction); - Assert.Equal(IDWebErrorMessage.InvalidBase64UrlString + " (Parameter 'arg')", exception.Message); + Assert.Equal(IDWebErrorMessage.InvalidBase64UrlString + " (Parameter 'str')", exception.Message); } } }