Skip to content

Commit 5273aad

Browse files
committed
Issue #24343 Vector Ctor using Span
1 parent 4852538 commit 5273aad

File tree

5 files changed

+370
-2
lines changed

5 files changed

+370
-2
lines changed

src/Common/src/CoreLib/System/Numerics/Vector.cs

+251
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,257 @@ public unsafe Vector(T[] values, int index)
610610
}
611611
}
612612

613+
/// <summary>
614+
/// Constructs a vector from the given span.
615+
/// The span must contain at least Vector'T.Count elements.
616+
/// </summary>
617+
public unsafe Vector(Span<T> values)
618+
: this()
619+
{
620+
if (values == null)
621+
{
622+
// Match the JIT's exception type here. For perf, a NullReference is thrown instead of an ArgumentNull.
623+
throw new NullReferenceException(SR.Arg_NullArgumentNullRef);
624+
}
625+
if (values.Length < Count)
626+
{
627+
throw new IndexOutOfRangeException();
628+
}
629+
630+
if (Vector.IsHardwareAccelerated)
631+
{
632+
if (typeof(T) == typeof(Byte))
633+
{
634+
fixed (Byte* basePtr = &this.register.byte_0)
635+
{
636+
for (int g = 0; g < Count; g++)
637+
{
638+
*(basePtr + g) = (Byte)(object)values[g];
639+
}
640+
}
641+
}
642+
else if (typeof(T) == typeof(SByte))
643+
{
644+
fixed (SByte* basePtr = &this.register.sbyte_0)
645+
{
646+
for (int g = 0; g < Count; g++)
647+
{
648+
*(basePtr + g) = (SByte)(object)values[g];
649+
}
650+
}
651+
}
652+
else if (typeof(T) == typeof(UInt16))
653+
{
654+
fixed (UInt16* basePtr = &this.register.uint16_0)
655+
{
656+
for (int g = 0; g < Count; g++)
657+
{
658+
*(basePtr + g) = (UInt16)(object)values[g];
659+
}
660+
}
661+
}
662+
else if (typeof(T) == typeof(Int16))
663+
{
664+
fixed (Int16* basePtr = &this.register.int16_0)
665+
{
666+
for (int g = 0; g < Count; g++)
667+
{
668+
*(basePtr + g) = (Int16)(object)values[g];
669+
}
670+
}
671+
}
672+
else if (typeof(T) == typeof(UInt32))
673+
{
674+
fixed (UInt32* basePtr = &this.register.uint32_0)
675+
{
676+
for (int g = 0; g < Count; g++)
677+
{
678+
*(basePtr + g) = (UInt32)(object)values[g];
679+
}
680+
}
681+
}
682+
else if (typeof(T) == typeof(Int32))
683+
{
684+
fixed (Int32* basePtr = &this.register.int32_0)
685+
{
686+
for (int g = 0; g < Count; g++)
687+
{
688+
*(basePtr + g) = (Int32)(object)values[g];
689+
}
690+
}
691+
}
692+
else if (typeof(T) == typeof(UInt64))
693+
{
694+
fixed (UInt64* basePtr = &this.register.uint64_0)
695+
{
696+
for (int g = 0; g < Count; g++)
697+
{
698+
*(basePtr + g) = (UInt64)(object)values[g];
699+
}
700+
}
701+
}
702+
else if (typeof(T) == typeof(Int64))
703+
{
704+
fixed (Int64* basePtr = &this.register.int64_0)
705+
{
706+
for (int g = 0; g < Count; g++)
707+
{
708+
*(basePtr + g) = (Int64)(object)values[g];
709+
}
710+
}
711+
}
712+
else if (typeof(T) == typeof(Single))
713+
{
714+
fixed (Single* basePtr = &this.register.single_0)
715+
{
716+
for (int g = 0; g < Count; g++)
717+
{
718+
*(basePtr + g) = (Single)(object)values[g];
719+
}
720+
}
721+
}
722+
else if (typeof(T) == typeof(Double))
723+
{
724+
fixed (Double* basePtr = &this.register.double_0)
725+
{
726+
for (int g = 0; g < Count; g++)
727+
{
728+
*(basePtr + g) = (Double)(object)values[g];
729+
}
730+
}
731+
}
732+
}
733+
else
734+
{
735+
if (typeof(T) == typeof(Byte))
736+
{
737+
fixed (Byte* basePtr = &this.register.byte_0)
738+
{
739+
*(basePtr + 0) = (Byte)(object)values[0];
740+
*(basePtr + 1) = (Byte)(object)values[1];
741+
*(basePtr + 2) = (Byte)(object)values[2];
742+
*(basePtr + 3) = (Byte)(object)values[3];
743+
*(basePtr + 4) = (Byte)(object)values[4];
744+
*(basePtr + 5) = (Byte)(object)values[5];
745+
*(basePtr + 6) = (Byte)(object)values[6];
746+
*(basePtr + 7) = (Byte)(object)values[7];
747+
*(basePtr + 8) = (Byte)(object)values[8];
748+
*(basePtr + 9) = (Byte)(object)values[9];
749+
*(basePtr + 10) = (Byte)(object)values[10];
750+
*(basePtr + 11) = (Byte)(object)values[11];
751+
*(basePtr + 12) = (Byte)(object)values[12];
752+
*(basePtr + 13) = (Byte)(object)values[13];
753+
*(basePtr + 14) = (Byte)(object)values[14];
754+
*(basePtr + 15) = (Byte)(object)values[15];
755+
}
756+
}
757+
else if (typeof(T) == typeof(SByte))
758+
{
759+
fixed (SByte* basePtr = &this.register.sbyte_0)
760+
{
761+
*(basePtr + 0) = (SByte)(object)values[0];
762+
*(basePtr + 1) = (SByte)(object)values[1];
763+
*(basePtr + 2) = (SByte)(object)values[2];
764+
*(basePtr + 3) = (SByte)(object)values[3];
765+
*(basePtr + 4) = (SByte)(object)values[4];
766+
*(basePtr + 5) = (SByte)(object)values[5];
767+
*(basePtr + 6) = (SByte)(object)values[6];
768+
*(basePtr + 7) = (SByte)(object)values[7];
769+
*(basePtr + 8) = (SByte)(object)values[8];
770+
*(basePtr + 9) = (SByte)(object)values[9];
771+
*(basePtr + 10) = (SByte)(object)values[10];
772+
*(basePtr + 11) = (SByte)(object)values[11];
773+
*(basePtr + 12) = (SByte)(object)values[12];
774+
*(basePtr + 13) = (SByte)(object)values[13];
775+
*(basePtr + 14) = (SByte)(object)values[14];
776+
*(basePtr + 15) = (SByte)(object)values[15];
777+
}
778+
}
779+
else if (typeof(T) == typeof(UInt16))
780+
{
781+
fixed (UInt16* basePtr = &this.register.uint16_0)
782+
{
783+
*(basePtr + 0) = (UInt16)(object)values[0];
784+
*(basePtr + 1) = (UInt16)(object)values[1];
785+
*(basePtr + 2) = (UInt16)(object)values[2];
786+
*(basePtr + 3) = (UInt16)(object)values[3];
787+
*(basePtr + 4) = (UInt16)(object)values[4];
788+
*(basePtr + 5) = (UInt16)(object)values[5];
789+
*(basePtr + 6) = (UInt16)(object)values[6];
790+
*(basePtr + 7) = (UInt16)(object)values[7];
791+
}
792+
}
793+
else if (typeof(T) == typeof(Int16))
794+
{
795+
fixed (Int16* basePtr = &this.register.int16_0)
796+
{
797+
*(basePtr + 0) = (Int16)(object)values[0];
798+
*(basePtr + 1) = (Int16)(object)values[1];
799+
*(basePtr + 2) = (Int16)(object)values[2];
800+
*(basePtr + 3) = (Int16)(object)values[3];
801+
*(basePtr + 4) = (Int16)(object)values[4];
802+
*(basePtr + 5) = (Int16)(object)values[5];
803+
*(basePtr + 6) = (Int16)(object)values[6];
804+
*(basePtr + 7) = (Int16)(object)values[7];
805+
}
806+
}
807+
else if (typeof(T) == typeof(UInt32))
808+
{
809+
fixed (UInt32* basePtr = &this.register.uint32_0)
810+
{
811+
*(basePtr + 0) = (UInt32)(object)values[0];
812+
*(basePtr + 1) = (UInt32)(object)values[1];
813+
*(basePtr + 2) = (UInt32)(object)values[2];
814+
*(basePtr + 3) = (UInt32)(object)values[3];
815+
}
816+
}
817+
else if (typeof(T) == typeof(Int32))
818+
{
819+
fixed (Int32* basePtr = &this.register.int32_0)
820+
{
821+
*(basePtr + 0) = (Int32)(object)values[0];
822+
*(basePtr + 1) = (Int32)(object)values[1];
823+
*(basePtr + 2) = (Int32)(object)values[2];
824+
*(basePtr + 3) = (Int32)(object)values[3];
825+
}
826+
}
827+
else if (typeof(T) == typeof(UInt64))
828+
{
829+
fixed (UInt64* basePtr = &this.register.uint64_0)
830+
{
831+
*(basePtr + 0) = (UInt64)(object)values[0];
832+
*(basePtr + 1) = (UInt64)(object)values[1];
833+
}
834+
}
835+
else if (typeof(T) == typeof(Int64))
836+
{
837+
fixed (Int64* basePtr = &this.register.int64_0)
838+
{
839+
*(basePtr + 0) = (Int64)(object)values[0];
840+
*(basePtr + 1) = (Int64)(object)values[1];
841+
}
842+
}
843+
else if (typeof(T) == typeof(Single))
844+
{
845+
fixed (Single* basePtr = &this.register.single_0)
846+
{
847+
*(basePtr + 0) = (Single)(object)values[0];
848+
*(basePtr + 1) = (Single)(object)values[1];
849+
*(basePtr + 2) = (Single)(object)values[2];
850+
*(basePtr + 3) = (Single)(object)values[3];
851+
}
852+
}
853+
else if (typeof(T) == typeof(Double))
854+
{
855+
fixed (Double* basePtr = &this.register.double_0)
856+
{
857+
*(basePtr + 0) = (Double)(object)values[0];
858+
*(basePtr + 1) = (Double)(object)values[1];
859+
}
860+
}
861+
}
862+
}
863+
613864
#pragma warning disable 3001 // void* is not a CLS-Compliant argument type
614865
internal unsafe Vector(void* dataPointer) : this(dataPointer, 0) { }
615866
#pragma warning restore 3001 // void* is not a CLS-Compliant argument type

src/Common/src/CoreLib/System/Numerics/Vector.tt

+61
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,67 @@ namespace System.Numerics
232232
}
233233
}
234234

235+
/// <summary>
236+
/// Constructs a vector from the given span.
237+
/// The span must contain at least Vector'T.Count elements.
238+
/// </summary>
239+
public unsafe Vector(Span<T> values)
240+
: this()
241+
{
242+
if (values == null)
243+
{
244+
// Match the JIT's exception type here. For perf, a NullReference is thrown instead of an ArgumentNull.
245+
throw new NullReferenceException(SR.Arg_NullArgumentNullRef);
246+
}
247+
if (values.Length < Count)
248+
{
249+
throw new IndexOutOfRangeException();
250+
}
251+
252+
if (Vector.IsHardwareAccelerated)
253+
{
254+
<# foreach (Type type in supportedTypes)
255+
{
256+
#>
257+
<#=GenerateIfStatementHeader(type)#>
258+
{
259+
fixed (<#=type.Name#>* basePtr = &this.<#=GetRegisterFieldName(type, 0)#>)
260+
{
261+
for (int g = 0; g < Count; g++)
262+
{
263+
*(basePtr + g) = (<#=type.Name#>)(object)values[g];
264+
}
265+
}
266+
}
267+
<#
268+
}
269+
#>
270+
}
271+
else
272+
{
273+
<# foreach (Type type in supportedTypes)
274+
{
275+
#>
276+
<#=GenerateIfStatementHeader(type)#>
277+
{
278+
fixed (<#=type.Name#>* basePtr = &this.<#=GetRegisterFieldName(type, 0)#>)
279+
{
280+
<#
281+
for (int g = 0; g < GetNumFields(type, totalSize); g++)
282+
{
283+
#>
284+
*(basePtr + <#=g#>) = (<#=type.Name#>)(object)values[<#=g#>];
285+
<#
286+
}
287+
#>
288+
}
289+
}
290+
<#
291+
}
292+
#>
293+
}
294+
}
295+
235296
#pragma warning disable 3001 // void* is not a CLS-Compliant argument type
236297
internal unsafe Vector(void* dataPointer) : this(dataPointer, 0) { }
237298
#pragma warning restore 3001 // void* is not a CLS-Compliant argument type

src/System.Numerics.Vectors/ref/System.Numerics.Vectors.cs

+1
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ public partial struct Vector<T> : System.IEquatable<System.Numerics.Vector<T>>,
301301
public Vector(T value) { throw null; }
302302
public Vector(T[] values) { throw null; }
303303
public Vector(T[] values, int index) { throw null; }
304+
public Vector(System.Span<T> values) { throw null; }
304305
public static int Count { get { throw null; } }
305306
public T this[int index] { get { throw null; } }
306307
public static System.Numerics.Vector<T> One { get { throw null; } }

src/System.Numerics.Vectors/tests/GenericVectorTests.cs

+34-1
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,39 @@ private void TestConstructorDefault<T>() where T : struct
155155
});
156156
}
157157

158+
[Fact]
159+
public void ConstructorWithSpanByte() { TestConstructorWithSpan<Byte>(); }
160+
[Fact]
161+
public void ConstructorWithSpanSByte() { TestConstructorWithSpan<SByte>(); }
162+
[Fact]
163+
public void ConstructorWithSpanUInt16() { TestConstructorWithSpan<UInt16>(); }
164+
[Fact]
165+
public void ConstructorWithSpanInt16() { TestConstructorWithSpan<Int16>(); }
166+
[Fact]
167+
public void ConstructorWithSpanUInt32() { TestConstructorWithSpan<UInt32>(); }
168+
[Fact]
169+
public void ConstructorWithSpanInt32() { TestConstructorWithSpan<Int32>(); }
170+
[Fact]
171+
public void ConstructorWithSpanUInt64() { TestConstructorWithSpan<UInt64>(); }
172+
[Fact]
173+
public void ConstructorWithSpanInt64() { TestConstructorWithSpan<Int64>(); }
174+
[Fact]
175+
public void ConstructorWithSpanSingle() { TestConstructorWithSpan<Single>(); }
176+
[Fact]
177+
public void ConstructorWithSpanDouble() { TestConstructorWithSpan<Double>(); }
178+
private void TestConstructorWithSpan<T>() where T : struct
179+
{
180+
T[] values = GenerateRandomValuesForVector<T>().ToArray();
181+
Span<T> valueSpan = new Span<T>(values);
182+
183+
var vector = new Vector<T>(valueSpan);
184+
ValidateVector(vector,
185+
(index, val) =>
186+
{
187+
Assert.Equal(val, values[index]);
188+
});
189+
}
190+
158191
[Fact]
159192
public void ConstructorExceptionByte() { TestConstructorArrayTooSmallException<Byte>(); }
160193
[Fact]
@@ -2776,4 +2809,4 @@ internal static T GetValueWithAllOnesSet<T>() where T : struct
27762809
}
27772810
#endregion
27782811
}
2779-
}
2812+
}

0 commit comments

Comments
 (0)