Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using Xunit;

namespace System.Collections.Tests
Expand Down Expand Up @@ -245,13 +246,33 @@ public static void Xor_With_Resize(BitArray left, BitArray right, int newLeftLen

public static IEnumerable<object[]> Shift_Data()
{
foreach (int size in new[] { 0, 1, BitsPerInt32 / 2, BitsPerInt32, BitsPerInt32 + 1, 2 * BitsPerInt32, 2 * BitsPerInt32 + 1 })
Random random = new Random(0);
foreach (int size in new[] {
0,
1,
BitsPerInt32 / 2,
BitsPerInt32,
BitsPerInt32 + 1,
2 * BitsPerInt32 - 1,
2 * BitsPerInt32 + 1,
1023,
1024,
1025,
})
{
foreach (int shift in new[] { 0, 1, size / 2, size - 1, size }.Where(s => s >= 0).Distinct())
foreach (int shift in new[] {
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
size / 3, size / 2, size / 2 + 1, size - 1, size,
}.Where(s => s >= 0).Distinct())
{
yield return new object[] { size, new int[] { /* deliberately empty */ }, shift };
yield return new object[] { size, Enumerable.Range(0, size), shift };

int[] nums = Enumerable.Range(0, size).ToArray();
random.Shuffle(nums);
yield return new object[] { size, nums.Take(size / 2), shift };

if (size > 1)
{
foreach (int position in new[] { 0, size / 2, size - 1 })
Expand Down Expand Up @@ -290,6 +311,14 @@ public static void LeftShift(int length, IEnumerable<int> set, int shift)

int index = 0;
Assert.All(ba.Cast<bool>(), bit => Assert.Equal(expected[index++], bit));

(int byteIndex, int bitOffeset) = Math.DivRem(length, BitsPerByte);
if (bitOffeset != 0)
{
Span<byte> bs = CollectionsMarshal.AsBytes(ba);
Assert.Equal(byteIndex + 1, bs.Length);
Assert.Equal(0, bs[byteIndex] >> bitOffeset);
}
}

private static bool[] GetBoolArray(int length, IEnumerable<int> set)
Expand Down
181 changes: 139 additions & 42 deletions src/libraries/System.Private.CoreLib/src/System/Collections/BitArray.cs
Original file line number Diff line number Diff line change
Expand Up @@ -514,51 +514,94 @@ public BitArray RightShift(int count)
return this;
}

Span<int> intSpan = MemoryMarshal.Cast<byte, int>((Span<byte>)_array);

Span<byte> thisSpan = new Span<byte>(_array, 0, GetByteArrayLengthFromBitLength(_bitLength));
int toIndex = 0;
int ints = GetInt32ArrayLengthFromBitLength(_bitLength);

if (count < _bitLength)
{
// We can not use Math.DivRem without taking a dependency on System.Runtime.Extensions
(int fromIndex, int shiftCount) = Math.DivRem(count, 32);
int extraBits = (int)((uint)_bitLength % 32);
(int fromIndex, int shiftCount) = Math.DivRem(count, BitsPerByte);
if (shiftCount == 0)
{
// Cannot use `(1u << extraBits) - 1u` as the mask
// because for extraBits == 0, we need the mask to be 111...111, not 0.
// In that case, we are shifting a uint by 32, which could be considered undefined.
// The result of a shift operation is undefined ... if the right operand
// is greater than or equal to the width in bits of the promoted left operand,
// https://learn.microsoft.com/cpp/c-language/bitwise-shift-operators?view=vs-2017
// However, the compiler protects us from undefined behaviour by constraining the
// right operand to between 0 and width - 1 (inclusive), i.e. right_operand = (right_operand % width).
uint mask = uint.MaxValue >> (BitsPerInt32 - extraBits);
intSpan[ints - 1] &= ReverseIfBE((int)mask);

intSpan.Slice((int)fromIndex, ints - fromIndex).CopyTo(intSpan);
toIndex = ints - fromIndex;
thisSpan.Slice(fromIndex).CopyTo(thisSpan);
toIndex = thisSpan.Length - fromIndex;
}
else
{
int lastIndex = ints - 1;
if (Vector512.IsHardwareAccelerated)
{
toIndex = Apply<Vector512<byte>>(shiftCount, fromIndex, thisSpan);
}
else if (Vector256.IsHardwareAccelerated)
{
toIndex = Apply<Vector256<byte>>(shiftCount, fromIndex, thisSpan);
}
else if (Vector128.IsHardwareAccelerated)
{
toIndex = Apply<Vector128<byte>>(shiftCount, fromIndex, thisSpan);
}
fromIndex += toIndex;

int carryCount = BitsPerByte - shiftCount;

ref byte p = ref MemoryMarshal.GetReference(thisSpan);

const uint shiftUnit = 0x01010101u;
uint shiftMask = (shiftUnit << carryCount) - shiftUnit;
uint carryMask = ~shiftMask;

while (fromIndex < lastIndex)
while (fromIndex < thisSpan.Length - 4)
{
uint right = (uint)ReverseIfBE(intSpan[fromIndex]) >> shiftCount;
int left = ReverseIfBE(intSpan[++fromIndex]) << (BitsPerInt32 - shiftCount);
intSpan[toIndex++] = ReverseIfBE(left | (int)right);
uint lo = (Unsafe.ReadUnaligned<uint>(ref Unsafe.AddByteOffset(ref p, (uint)fromIndex)) >>> shiftCount) & shiftMask;
uint hi = (Unsafe.ReadUnaligned<uint>(ref Unsafe.AddByteOffset(ref p, (uint)(fromIndex + 1))) << carryCount) & carryMask;
uint result = hi | lo;
Unsafe.WriteUnaligned(ref Unsafe.AddByteOffset(ref p, toIndex), result);

fromIndex += 4;
toIndex += 4;
}

uint mask = uint.MaxValue >> (BitsPerInt32 - extraBits);
mask &= (uint)ReverseIfBE(intSpan[fromIndex]);
intSpan[toIndex++] = ReverseIfBE((int)(mask >> shiftCount));
while (fromIndex < thisSpan.Length)
{
int lo = thisSpan[fromIndex] >>> shiftCount;
int hi =
fromIndex + 1 < thisSpan.Length
? thisSpan[fromIndex + 1] << carryCount
: 0;

thisSpan[toIndex] = (byte)(hi | lo);

fromIndex++;
toIndex++;
}
}
}

intSpan.Slice(toIndex, ints - toIndex).Clear();
thisSpan.Slice(toIndex).Clear();
_version++;
return this;

[MethodImpl(MethodImplOptions.AggressiveInlining)]
static int Apply<TVector>(int shiftCount, int fromIndex, Span<byte> thisSpan)
where TVector : ISimdVector<TVector, byte>
{
ref byte p = ref MemoryMarshal.GetReference(thisSpan);
int carryCount = BitsPerByte - shiftCount;

int toIndex = 0;

while (fromIndex <= thisSpan.Length - (TVector.ElementCount + 1))
{
TVector lo = TVector.LoadUnsafe(ref p, (uint)fromIndex) >>> shiftCount;
TVector hi = TVector.LoadUnsafe(ref p, (uint)(fromIndex + 1)) << carryCount;
TVector result = lo | hi;
result.StoreUnsafe(ref p, (uint)toIndex);

fromIndex += TVector.ElementCount;
toIndex += TVector.ElementCount;
}

return toIndex;
}
}

/// <summary>
Expand All @@ -577,41 +620,95 @@ public BitArray LeftShift(int count)
return this;
}

Span<int> intSpan = MemoryMarshal.Cast<byte, int>((Span<byte>)_array);
Span<byte> thisSpan = new Span<byte>(_array, 0, GetByteArrayLengthFromBitLength(_bitLength));

int lengthToClear;
if (count < _bitLength)
{
int lastIndex = (int)((uint)(_bitLength - 1) / BitsPerInt32);

(lengthToClear, int shiftCount) = Math.DivRem(count, BitsPerInt32);
(lengthToClear, int shiftCount) = Math.DivRem(count, BitsPerByte);

if (shiftCount == 0)
{
intSpan.Slice(0, lastIndex + 1 - lengthToClear).CopyTo(intSpan.Slice(lengthToClear));
thisSpan.Slice(0, thisSpan.Length - lengthToClear).CopyTo(thisSpan.Slice(lengthToClear));
}
else
{
int fromindex = lastIndex - lengthToClear;
int toIndex = thisSpan.Length;
int fromIndex = toIndex - lengthToClear;

if (Vector512.IsHardwareAccelerated)
{
toIndex = Apply<Vector512<byte>>(shiftCount, fromIndex, thisSpan);
}
else if (Vector256.IsHardwareAccelerated)
{
toIndex = Apply<Vector256<byte>>(shiftCount, fromIndex, thisSpan);
}
else if (Vector128.IsHardwareAccelerated)
{
toIndex = Apply<Vector128<byte>>(shiftCount, fromIndex, thisSpan);
}
fromIndex = toIndex - lengthToClear;

int carryCount = BitsPerByte - shiftCount;

ref byte p = ref MemoryMarshal.GetReference(thisSpan);

const uint shiftUnit = 0x01010101u;
uint carryMask = (shiftUnit << shiftCount) - shiftUnit;
uint shiftMask = ~carryMask;

while (fromIndex >= 5)
{
uint lo = (Unsafe.ReadUnaligned<uint>(ref Unsafe.AddByteOffset(ref p, (uint)(fromIndex -= 4))) << shiftCount) & shiftMask;
uint hi = (Unsafe.ReadUnaligned<uint>(ref Unsafe.AddByteOffset(ref p, (uint)(fromIndex - 1))) >>> carryCount) & carryMask;
uint result = hi | lo;
Unsafe.WriteUnaligned(ref Unsafe.AddByteOffset(ref p, toIndex -= 4), result);
}

while (fromindex > 0)
while (--fromIndex >= 0)
{
int left = ReverseIfBE(intSpan[fromindex]) << shiftCount;
uint right = (uint)ReverseIfBE(intSpan[--fromindex]) >> (BitsPerInt32 - shiftCount);
intSpan[lastIndex] = ReverseIfBE(left | (int)right);
lastIndex--;
int hi = thisSpan[fromIndex] << shiftCount;
int lo =
fromIndex > 0
? thisSpan[fromIndex - 1] >>> carryCount
: 0;

thisSpan[--toIndex] = (byte)(hi | lo);
}
intSpan[lastIndex] = ReverseIfBE(ReverseIfBE(intSpan[fromindex]) << shiftCount);

Debug.Assert(toIndex == lengthToClear);
}
}
else
{
lengthToClear = GetInt32ArrayLengthFromBitLength(_bitLength); // Clear all
lengthToClear = thisSpan.Length; // Clear all
}

intSpan.Slice(0, lengthToClear).Clear();
thisSpan.Slice(0, lengthToClear).Clear();
ClearHighExtraBits();
_version++;
return this;

[MethodImpl(MethodImplOptions.AggressiveInlining)]
static int Apply<TVector>(int shiftCount, int fromIndex, Span<byte> thisSpan)
where TVector : ISimdVector<TVector, byte>
{
ref byte p = ref MemoryMarshal.GetReference(thisSpan);
int carryCount = BitsPerByte - shiftCount;

int toIndex = thisSpan.Length;

while (fromIndex >= TVector.ElementCount + 1)
{
TVector hi = TVector.LoadUnsafe(ref p, (nuint)(fromIndex -= TVector.ElementCount)) << shiftCount;
TVector lo = TVector.LoadUnsafe(ref p, (nuint)(fromIndex - 1)) >>> carryCount;
TVector result = hi | lo;
result.StoreUnsafe(ref p, (nuint)(toIndex -= TVector.ElementCount));
}

return toIndex;
}
}

/// <summary>
Expand Down
Loading