Skip to content

Commit e05c3b8

Browse files
committed
Issue #24343 Vector Ctor using Span
1 parent 8ec2cf2 commit e05c3b8

File tree

5 files changed

+370
-2
lines changed

5 files changed

+370
-2
lines changed

src/Common/src/CoreLib/System/Numerics/Vector.cs

+251
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,257 @@ public unsafe Vector(T[] values, int index)
623623
}
624624
}
625625

626+
/// <summary>
627+
/// Constructs a vector from the given span.
628+
/// The span must contain at least Vector'T.Count elements.
629+
/// </summary>
630+
public unsafe Vector(Span<T> values)
631+
: this()
632+
{
633+
if (values == null)
634+
{
635+
// Match the JIT's exception type here. For perf, a NullReference is thrown instead of an ArgumentNull.
636+
throw new NullReferenceException(SR.Arg_NullArgumentNullRef);
637+
}
638+
if (values.Length < Count)
639+
{
640+
throw new IndexOutOfRangeException();
641+
}
642+
643+
if (Vector.IsHardwareAccelerated)
644+
{
645+
if (typeof(T) == typeof(Byte))
646+
{
647+
fixed (Byte* basePtr = &this.register.byte_0)
648+
{
649+
for (int g = 0; g < Count; g++)
650+
{
651+
*(basePtr + g) = (Byte)(object)values[g];
652+
}
653+
}
654+
}
655+
else if (typeof(T) == typeof(SByte))
656+
{
657+
fixed (SByte* basePtr = &this.register.sbyte_0)
658+
{
659+
for (int g = 0; g < Count; g++)
660+
{
661+
*(basePtr + g) = (SByte)(object)values[g];
662+
}
663+
}
664+
}
665+
else if (typeof(T) == typeof(UInt16))
666+
{
667+
fixed (UInt16* basePtr = &this.register.uint16_0)
668+
{
669+
for (int g = 0; g < Count; g++)
670+
{
671+
*(basePtr + g) = (UInt16)(object)values[g];
672+
}
673+
}
674+
}
675+
else if (typeof(T) == typeof(Int16))
676+
{
677+
fixed (Int16* basePtr = &this.register.int16_0)
678+
{
679+
for (int g = 0; g < Count; g++)
680+
{
681+
*(basePtr + g) = (Int16)(object)values[g];
682+
}
683+
}
684+
}
685+
else if (typeof(T) == typeof(UInt32))
686+
{
687+
fixed (UInt32* basePtr = &this.register.uint32_0)
688+
{
689+
for (int g = 0; g < Count; g++)
690+
{
691+
*(basePtr + g) = (UInt32)(object)values[g];
692+
}
693+
}
694+
}
695+
else if (typeof(T) == typeof(Int32))
696+
{
697+
fixed (Int32* basePtr = &this.register.int32_0)
698+
{
699+
for (int g = 0; g < Count; g++)
700+
{
701+
*(basePtr + g) = (Int32)(object)values[g];
702+
}
703+
}
704+
}
705+
else if (typeof(T) == typeof(UInt64))
706+
{
707+
fixed (UInt64* basePtr = &this.register.uint64_0)
708+
{
709+
for (int g = 0; g < Count; g++)
710+
{
711+
*(basePtr + g) = (UInt64)(object)values[g];
712+
}
713+
}
714+
}
715+
else if (typeof(T) == typeof(Int64))
716+
{
717+
fixed (Int64* basePtr = &this.register.int64_0)
718+
{
719+
for (int g = 0; g < Count; g++)
720+
{
721+
*(basePtr + g) = (Int64)(object)values[g];
722+
}
723+
}
724+
}
725+
else if (typeof(T) == typeof(Single))
726+
{
727+
fixed (Single* basePtr = &this.register.single_0)
728+
{
729+
for (int g = 0; g < Count; g++)
730+
{
731+
*(basePtr + g) = (Single)(object)values[g];
732+
}
733+
}
734+
}
735+
else if (typeof(T) == typeof(Double))
736+
{
737+
fixed (Double* basePtr = &this.register.double_0)
738+
{
739+
for (int g = 0; g < Count; g++)
740+
{
741+
*(basePtr + g) = (Double)(object)values[g];
742+
}
743+
}
744+
}
745+
}
746+
else
747+
{
748+
if (typeof(T) == typeof(Byte))
749+
{
750+
fixed (Byte* basePtr = &this.register.byte_0)
751+
{
752+
*(basePtr + 0) = (Byte)(object)values[0];
753+
*(basePtr + 1) = (Byte)(object)values[1];
754+
*(basePtr + 2) = (Byte)(object)values[2];
755+
*(basePtr + 3) = (Byte)(object)values[3];
756+
*(basePtr + 4) = (Byte)(object)values[4];
757+
*(basePtr + 5) = (Byte)(object)values[5];
758+
*(basePtr + 6) = (Byte)(object)values[6];
759+
*(basePtr + 7) = (Byte)(object)values[7];
760+
*(basePtr + 8) = (Byte)(object)values[8];
761+
*(basePtr + 9) = (Byte)(object)values[9];
762+
*(basePtr + 10) = (Byte)(object)values[10];
763+
*(basePtr + 11) = (Byte)(object)values[11];
764+
*(basePtr + 12) = (Byte)(object)values[12];
765+
*(basePtr + 13) = (Byte)(object)values[13];
766+
*(basePtr + 14) = (Byte)(object)values[14];
767+
*(basePtr + 15) = (Byte)(object)values[15];
768+
}
769+
}
770+
else if (typeof(T) == typeof(SByte))
771+
{
772+
fixed (SByte* basePtr = &this.register.sbyte_0)
773+
{
774+
*(basePtr + 0) = (SByte)(object)values[0];
775+
*(basePtr + 1) = (SByte)(object)values[1];
776+
*(basePtr + 2) = (SByte)(object)values[2];
777+
*(basePtr + 3) = (SByte)(object)values[3];
778+
*(basePtr + 4) = (SByte)(object)values[4];
779+
*(basePtr + 5) = (SByte)(object)values[5];
780+
*(basePtr + 6) = (SByte)(object)values[6];
781+
*(basePtr + 7) = (SByte)(object)values[7];
782+
*(basePtr + 8) = (SByte)(object)values[8];
783+
*(basePtr + 9) = (SByte)(object)values[9];
784+
*(basePtr + 10) = (SByte)(object)values[10];
785+
*(basePtr + 11) = (SByte)(object)values[11];
786+
*(basePtr + 12) = (SByte)(object)values[12];
787+
*(basePtr + 13) = (SByte)(object)values[13];
788+
*(basePtr + 14) = (SByte)(object)values[14];
789+
*(basePtr + 15) = (SByte)(object)values[15];
790+
}
791+
}
792+
else if (typeof(T) == typeof(UInt16))
793+
{
794+
fixed (UInt16* basePtr = &this.register.uint16_0)
795+
{
796+
*(basePtr + 0) = (UInt16)(object)values[0];
797+
*(basePtr + 1) = (UInt16)(object)values[1];
798+
*(basePtr + 2) = (UInt16)(object)values[2];
799+
*(basePtr + 3) = (UInt16)(object)values[3];
800+
*(basePtr + 4) = (UInt16)(object)values[4];
801+
*(basePtr + 5) = (UInt16)(object)values[5];
802+
*(basePtr + 6) = (UInt16)(object)values[6];
803+
*(basePtr + 7) = (UInt16)(object)values[7];
804+
}
805+
}
806+
else if (typeof(T) == typeof(Int16))
807+
{
808+
fixed (Int16* basePtr = &this.register.int16_0)
809+
{
810+
*(basePtr + 0) = (Int16)(object)values[0];
811+
*(basePtr + 1) = (Int16)(object)values[1];
812+
*(basePtr + 2) = (Int16)(object)values[2];
813+
*(basePtr + 3) = (Int16)(object)values[3];
814+
*(basePtr + 4) = (Int16)(object)values[4];
815+
*(basePtr + 5) = (Int16)(object)values[5];
816+
*(basePtr + 6) = (Int16)(object)values[6];
817+
*(basePtr + 7) = (Int16)(object)values[7];
818+
}
819+
}
820+
else if (typeof(T) == typeof(UInt32))
821+
{
822+
fixed (UInt32* basePtr = &this.register.uint32_0)
823+
{
824+
*(basePtr + 0) = (UInt32)(object)values[0];
825+
*(basePtr + 1) = (UInt32)(object)values[1];
826+
*(basePtr + 2) = (UInt32)(object)values[2];
827+
*(basePtr + 3) = (UInt32)(object)values[3];
828+
}
829+
}
830+
else if (typeof(T) == typeof(Int32))
831+
{
832+
fixed (Int32* basePtr = &this.register.int32_0)
833+
{
834+
*(basePtr + 0) = (Int32)(object)values[0];
835+
*(basePtr + 1) = (Int32)(object)values[1];
836+
*(basePtr + 2) = (Int32)(object)values[2];
837+
*(basePtr + 3) = (Int32)(object)values[3];
838+
}
839+
}
840+
else if (typeof(T) == typeof(UInt64))
841+
{
842+
fixed (UInt64* basePtr = &this.register.uint64_0)
843+
{
844+
*(basePtr + 0) = (UInt64)(object)values[0];
845+
*(basePtr + 1) = (UInt64)(object)values[1];
846+
}
847+
}
848+
else if (typeof(T) == typeof(Int64))
849+
{
850+
fixed (Int64* basePtr = &this.register.int64_0)
851+
{
852+
*(basePtr + 0) = (Int64)(object)values[0];
853+
*(basePtr + 1) = (Int64)(object)values[1];
854+
}
855+
}
856+
else if (typeof(T) == typeof(Single))
857+
{
858+
fixed (Single* basePtr = &this.register.single_0)
859+
{
860+
*(basePtr + 0) = (Single)(object)values[0];
861+
*(basePtr + 1) = (Single)(object)values[1];
862+
*(basePtr + 2) = (Single)(object)values[2];
863+
*(basePtr + 3) = (Single)(object)values[3];
864+
}
865+
}
866+
else if (typeof(T) == typeof(Double))
867+
{
868+
fixed (Double* basePtr = &this.register.double_0)
869+
{
870+
*(basePtr + 0) = (Double)(object)values[0];
871+
*(basePtr + 1) = (Double)(object)values[1];
872+
}
873+
}
874+
}
875+
}
876+
626877
#pragma warning disable 3001 // void* is not a CLS-Compliant argument type
627878
internal unsafe Vector(void* dataPointer) : this(dataPointer, 0) { }
628879
#pragma warning restore 3001 // void* is not a CLS-Compliant argument type

src/Common/src/CoreLib/System/Numerics/Vector.tt

+61
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,67 @@ namespace System.Numerics
245245
}
246246
}
247247

248+
/// <summary>
249+
/// Constructs a vector from the given span.
250+
/// The span must contain at least Vector'T.Count elements.
251+
/// </summary>
252+
public unsafe Vector(Span<T> values)
253+
: this()
254+
{
255+
if (values == null)
256+
{
257+
// Match the JIT's exception type here. For perf, a NullReference is thrown instead of an ArgumentNull.
258+
throw new NullReferenceException(SR.Arg_NullArgumentNullRef);
259+
}
260+
if (values.Length < Count)
261+
{
262+
throw new IndexOutOfRangeException();
263+
}
264+
265+
if (Vector.IsHardwareAccelerated)
266+
{
267+
<# foreach (Type type in supportedTypes)
268+
{
269+
#>
270+
<#=GenerateIfStatementHeader(type)#>
271+
{
272+
fixed (<#=type.Name#>* basePtr = &this.<#=GetRegisterFieldName(type, 0)#>)
273+
{
274+
for (int g = 0; g < Count; g++)
275+
{
276+
*(basePtr + g) = (<#=type.Name#>)(object)values[g];
277+
}
278+
}
279+
}
280+
<#
281+
}
282+
#>
283+
}
284+
else
285+
{
286+
<# foreach (Type type in supportedTypes)
287+
{
288+
#>
289+
<#=GenerateIfStatementHeader(type)#>
290+
{
291+
fixed (<#=type.Name#>* basePtr = &this.<#=GetRegisterFieldName(type, 0)#>)
292+
{
293+
<#
294+
for (int g = 0; g < GetNumFields(type, totalSize); g++)
295+
{
296+
#>
297+
*(basePtr + <#=g#>) = (<#=type.Name#>)(object)values[<#=g#>];
298+
<#
299+
}
300+
#>
301+
}
302+
}
303+
<#
304+
}
305+
#>
306+
}
307+
}
308+
248309
#pragma warning disable 3001 // void* is not a CLS-Compliant argument type
249310
internal unsafe Vector(void* dataPointer) : this(dataPointer, 0) { }
250311
#pragma warning restore 3001 // void* is not a CLS-Compliant argument type

src/System.Numerics.Vectors/ref/System.Numerics.Vectors.cs

+1
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ public partial struct Vector<T> : System.IEquatable<System.Numerics.Vector<T>>,
301301
public Vector(T value) { throw null; }
302302
public Vector(T[] values) { throw null; }
303303
public Vector(T[] values, int index) { throw null; }
304+
public Vector(System.Span<T> values) { throw null; }
304305
public static int Count { get { throw null; } }
305306
public T this[int index] { get { throw null; } }
306307
public static System.Numerics.Vector<T> One { get { throw null; } }

src/System.Numerics.Vectors/tests/GenericVectorTests.cs

+34-1
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,39 @@ private void TestConstructorDefault<T>() where T : struct
155155
});
156156
}
157157

158+
[Fact]
159+
public void ConstructorWithSpanByte() { TestConstructorWithSpan<Byte>(); }
160+
[Fact]
161+
public void ConstructorWithSpanSByte() { TestConstructorWithSpan<SByte>(); }
162+
[Fact]
163+
public void ConstructorWithSpanUInt16() { TestConstructorWithSpan<UInt16>(); }
164+
[Fact]
165+
public void ConstructorWithSpanInt16() { TestConstructorWithSpan<Int16>(); }
166+
[Fact]
167+
public void ConstructorWithSpanUInt32() { TestConstructorWithSpan<UInt32>(); }
168+
[Fact]
169+
public void ConstructorWithSpanInt32() { TestConstructorWithSpan<Int32>(); }
170+
[Fact]
171+
public void ConstructorWithSpanUInt64() { TestConstructorWithSpan<UInt64>(); }
172+
[Fact]
173+
public void ConstructorWithSpanInt64() { TestConstructorWithSpan<Int64>(); }
174+
[Fact]
175+
public void ConstructorWithSpanSingle() { TestConstructorWithSpan<Single>(); }
176+
[Fact]
177+
public void ConstructorWithSpanDouble() { TestConstructorWithSpan<Double>(); }
178+
private void TestConstructorWithSpan<T>() where T : struct
179+
{
180+
T[] values = GenerateRandomValuesForVector<T>().ToArray();
181+
Span<T> valueSpan = new Span<T>(values);
182+
183+
var vector = new Vector<T>(valueSpan);
184+
ValidateVector(vector,
185+
(index, val) =>
186+
{
187+
Assert.Equal(val, values[index]);
188+
});
189+
}
190+
158191
[Fact]
159192
public void ConstructorExceptionByte() { TestConstructorArrayTooSmallException<Byte>(); }
160193
[Fact]
@@ -2776,4 +2809,4 @@ internal static T GetValueWithAllOnesSet<T>() where T : struct
27762809
}
27772810
#endregion
27782811
}
2779-
}
2812+
}

0 commit comments

Comments
 (0)