From 48c5b2113e2ab267d27c01a9a2f79c3dd8f16a40 Mon Sep 17 00:00:00 2001 From: Ahson Khan Date: Sun, 25 Mar 2018 15:15:00 -0700 Subject: [PATCH] Fix CompareTo/Equals when dealing with Empty Span or Span wrapping a null string (#17115) * Fix CompareTo/Equals when dealing with Span wrapping a null string * Removing unnecessary virtual keyword and address other feedback. * Remove unnecessary/incorrect Debug.Asserts. * Remove more unnecessary/incorrect Debug.Asserts. --- .../System/Globalization/CompareInfo.Unix.cs | 2 ++ .../Globalization/CompareInfo.Windows.cs | 14 +++++++++++++ .../System/Globalization/CompareInfo.cs | 20 +++++++++++++++---- .../shared/System/MemoryExtensions.cs | 13 ++++++++++++ 4 files changed, 45 insertions(+), 4 deletions(-) diff --git a/src/mscorlib/shared/System/Globalization/CompareInfo.Unix.cs b/src/mscorlib/shared/System/Globalization/CompareInfo.Unix.cs index bfe1f6918d59..5a68492c69d4 100644 --- a/src/mscorlib/shared/System/Globalization/CompareInfo.Unix.cs +++ b/src/mscorlib/shared/System/Globalization/CompareInfo.Unix.cs @@ -181,6 +181,8 @@ internal static unsafe int LastIndexOfOrdinalCore(string source, string value, i private static unsafe int CompareStringOrdinalIgnoreCase(char* string1, int count1, char* string2, int count2) { Debug.Assert(!GlobalizationMode.Invariant); + Debug.Assert(string1 != null); + Debug.Assert(string2 != null); return Interop.Globalization.CompareStringOrdinalIgnoreCase(string1, count1, string2, count2); } diff --git a/src/mscorlib/shared/System/Globalization/CompareInfo.Windows.cs b/src/mscorlib/shared/System/Globalization/CompareInfo.Windows.cs index edc7b03bf225..37ed9469d9b9 100644 --- a/src/mscorlib/shared/System/Globalization/CompareInfo.Windows.cs +++ b/src/mscorlib/shared/System/Globalization/CompareInfo.Windows.cs @@ -40,6 +40,8 @@ private static unsafe int FindStringOrdinal( bool bIgnoreCase) { Debug.Assert(!GlobalizationMode.Invariant); + Debug.Assert(stringSource != null); + Debug.Assert(value != null); fixed (char* pSource = stringSource) fixed (char* pValue = value) @@ -62,6 +64,8 @@ private static unsafe int FindStringOrdinal( bool bIgnoreCase) { Debug.Assert(!GlobalizationMode.Invariant); + Debug.Assert(!source.IsEmpty); + Debug.Assert(!value.IsEmpty); fixed (char* pSource = &MemoryMarshal.GetReference(source)) fixed (char* pValue = &MemoryMarshal.GetReference(value)) @@ -165,6 +169,8 @@ private unsafe int GetHashCodeOfStringCore(string source, CompareOptions options private static unsafe int CompareStringOrdinalIgnoreCase(char* string1, int count1, char* string2, int count2) { Debug.Assert(!GlobalizationMode.Invariant); + Debug.Assert(string1 != null); + Debug.Assert(string2 != null); // Use the OS to compare and then convert the result to expected value by subtracting 2 return Interop.Kernel32.CompareStringOrdinal(string1, count1, string2, count2, true) - 2; @@ -185,6 +191,7 @@ private unsafe int CompareString(ReadOnlySpan string1, string string2, Com fixed (char* pString1 = &MemoryMarshal.GetReference(string1)) fixed (char* pString2 = &string2.GetRawStringData()) { + Debug.Assert(pString1 != null); int result = Interop.Kernel32.CompareStringEx( pLocaleName, (uint)GetNativeCompareFlags(options), @@ -217,6 +224,8 @@ private unsafe int CompareString(ReadOnlySpan string1, ReadOnlySpan fixed (char* pString1 = &MemoryMarshal.GetReference(string1)) fixed (char* pString2 = &MemoryMarshal.GetReference(string2)) { + Debug.Assert(pString1 != null); + Debug.Assert(pString2 != null); int result = Interop.Kernel32.CompareStringEx( pLocaleName, (uint)GetNativeCompareFlags(options), @@ -245,6 +254,8 @@ private unsafe int FindString( int* pcchFound) { Debug.Assert(!_invariantMode); + Debug.Assert(!lpStringSource.IsEmpty); + Debug.Assert(!lpStringValue.IsEmpty); string localeName = _sortHandle != IntPtr.Zero ? null : _sortName; @@ -277,6 +288,8 @@ private unsafe int FindString( int* pcchFound) { Debug.Assert(!_invariantMode); + Debug.Assert(lpStringSource != null); + Debug.Assert(lpStringValue != null); string localeName = _sortHandle != IntPtr.Zero ? null : _sortName; @@ -572,6 +585,7 @@ private unsafe SortKey CreateSortKey(String source, CompareOptions options) private static unsafe bool IsSortable(char* text, int length) { Debug.Assert(!GlobalizationMode.Invariant); + Debug.Assert(text != null); return Interop.Kernel32.IsNLSDefinedString(Interop.Kernel32.COMPARE_STRING, 0, IntPtr.Zero, text, length); } diff --git a/src/mscorlib/shared/System/Globalization/CompareInfo.cs b/src/mscorlib/shared/System/Globalization/CompareInfo.cs index f54ecd91440e..71dc270bc299 100644 --- a/src/mscorlib/shared/System/Globalization/CompareInfo.cs +++ b/src/mscorlib/shared/System/Globalization/CompareInfo.cs @@ -391,15 +391,23 @@ internal int Compare(ReadOnlySpan string1, string string2, CompareOptions return CompareString(string1, string2, options); } - internal virtual int CompareOptionNone(ReadOnlySpan string1, ReadOnlySpan string2) + internal int CompareOptionNone(ReadOnlySpan string1, ReadOnlySpan string2) { + // Check for empty span or span from a null string + if (string1.Length == 0 || string2.Length == 0) + return string1.Length - string2.Length; + return _invariantMode ? string.CompareOrdinal(string1, string2) : CompareString(string1, string2, CompareOptions.None); } - internal virtual int CompareOptionIgnoreCase(ReadOnlySpan string1, ReadOnlySpan string2) + internal int CompareOptionIgnoreCase(ReadOnlySpan string1, ReadOnlySpan string2) { + // Check for empty span or span from a null string + if (string1.Length == 0 || string2.Length == 0) + return string1.Length - string2.Length; + return _invariantMode ? CompareOrdinalIgnoreCase(string1, string2) : CompareString(string1, string2, CompareOptions.IgnoreCase); @@ -892,15 +900,19 @@ public unsafe virtual int IndexOf(string source, string value, int startIndex, i return IndexOfCore(source, value, startIndex, count, options, null); } - internal virtual int IndexOfOrdinal(ReadOnlySpan source, ReadOnlySpan value, bool ignoreCase) + internal int IndexOfOrdinal(ReadOnlySpan source, ReadOnlySpan value, bool ignoreCase) { Debug.Assert(!_invariantMode); + Debug.Assert(!source.IsEmpty); + Debug.Assert(!value.IsEmpty); return IndexOfOrdinalCore(source, value, ignoreCase); } - internal unsafe virtual int IndexOf(ReadOnlySpan source, ReadOnlySpan value, CompareOptions options) + internal unsafe int IndexOf(ReadOnlySpan source, ReadOnlySpan value, CompareOptions options) { Debug.Assert(!_invariantMode); + Debug.Assert(!source.IsEmpty); + Debug.Assert(!value.IsEmpty); return IndexOfCore(source, value, options, null); } diff --git a/src/mscorlib/shared/System/MemoryExtensions.cs b/src/mscorlib/shared/System/MemoryExtensions.cs index 16ce76b4f236..46cf7a1d4b9a 100644 --- a/src/mscorlib/shared/System/MemoryExtensions.cs +++ b/src/mscorlib/shared/System/MemoryExtensions.cs @@ -113,6 +113,7 @@ public static ReadOnlySpan TrimEnd(this ReadOnlySpan span, char trim /// /// The source span from which the characters are removed. /// The span which contains the set of characters to remove. + /// If is empty, white-space characters are removed instead. public static ReadOnlySpan Trim(this ReadOnlySpan span, ReadOnlySpan trimChars) { return span.TrimStart(trimChars).TrimEnd(trimChars); @@ -124,8 +125,14 @@ public static ReadOnlySpan Trim(this ReadOnlySpan span, ReadOnlySpan /// /// The source span from which the characters are removed. /// The span which contains the set of characters to remove. + /// If is empty, white-space characters are removed instead. public static ReadOnlySpan TrimStart(this ReadOnlySpan span, ReadOnlySpan trimChars) { + if (trimChars.IsEmpty) + { + return span.TrimStart(); + } + int start = 0; for (; start < span.Length; start++) { @@ -147,8 +154,14 @@ public static ReadOnlySpan TrimStart(this ReadOnlySpan span, ReadOnl /// /// The source span from which the characters are removed. /// The span which contains the set of characters to remove. + /// If is empty, white-space characters are removed instead. public static ReadOnlySpan TrimEnd(this ReadOnlySpan span, ReadOnlySpan trimChars) { + if (trimChars.IsEmpty) + { + return span.TrimEnd(); + } + int end = span.Length - 1; for (; end >= 0; end--) {