Skip to content

Commit 7e51126

Browse files
authored
Add generic overloads to TensorPrimitives (#94555)
* Add generic overloads to TensorPrimitives This overhauls the implementation and tests to have a generic overload for each existing float-based overload. I've avoided touching the core logic, but have augmented the structure in a few ways, e.g. only taking vectorized code paths when the type supports vectorization. To keep the shared definitions of the float-based APIs, on .NET 9 they delegate to shims that are implemented on top of the generic variants. The tests have all been made instance members, with an abstract base class containing most of the tests, and calling into abstract methods for the core operations and validation routines. Derived types then fill in this logic, letting us use all the tests for both the non-generic and generic overloads. Generic tests are validating most of the primitive types that implement the required interfaces. This does not yet: - Provide generic overloads for the IndexOfMin/Max{Magnitude} methods - Vectorize the trig-related functions for Ts other than floats * Disable tests on mono due to Vector128 bug * Change "Float" to "Single" in various file, type, and member names
1 parent 5455432 commit 7e51126

16 files changed

+8302
-5308
lines changed

src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs

+42
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,49 @@ namespace System.Numerics.Tensors
88
{
99
public static partial class TensorPrimitives
1010
{
11+
public static void Abs<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.INumberBase<T> { }
12+
public static void AddMultiply<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.ReadOnlySpan<T> multiplier, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T> { }
13+
public static void AddMultiply<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, T multiplier, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T> { }
14+
public static void AddMultiply<T>(System.ReadOnlySpan<T> x, T y, System.ReadOnlySpan<T> multiplier, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T> { }
15+
public static void Add<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T> { }
16+
public static void Add<T>(System.ReadOnlySpan<T> x, T y, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T> { }
1117
public static void ConvertToHalf(System.ReadOnlySpan<float> source, System.Span<System.Half> destination) { throw null; }
1218
public static void ConvertToSingle(System.ReadOnlySpan<System.Half> source, System.Span<float> destination) { throw null; }
19+
public static void Cosh<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IHyperbolicFunctions<T> { }
20+
public static T CosineSimilarity<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y) where T : System.Numerics.IRootFunctions<T> { throw null; }
21+
public static T Distance<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y) where T : System.Numerics.IRootFunctions<T> { throw null; }
22+
public static void Divide<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.IDivisionOperators<T, T, T> { }
23+
public static void Divide<T>(System.ReadOnlySpan<T> x, T y, System.Span<T> destination) where T : System.Numerics.IDivisionOperators<T, T, T> { }
24+
public static T Dot<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T>, System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
25+
public static void Exp<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IExponentialFunctions<T> { }
26+
public static void Log2<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.ILogarithmicFunctions<T> { }
27+
public static void Log<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.ILogarithmicFunctions<T> { }
28+
public static T MaxMagnitude<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.INumberBase<T> { throw null; }
29+
public static void MaxMagnitude<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.INumberBase<T> { }
30+
public static T Max<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.INumber<T> { throw null; }
31+
public static void Max<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.INumber<T> { }
32+
public static T MinMagnitude<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.INumberBase<T> { throw null; }
33+
public static void MinMagnitude<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.INumberBase<T> { }
34+
public static T Min<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.INumber<T> { throw null; }
35+
public static void Min<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.INumber<T> { }
36+
public static void MultiplyAdd<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.ReadOnlySpan<T> addend, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T> { }
37+
public static void MultiplyAdd<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, T addend, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T> { }
38+
public static void MultiplyAdd<T>(System.ReadOnlySpan<T> x, T y, System.ReadOnlySpan<T> addend, System.Span<T> destination) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T> { }
39+
public static void Multiply<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { }
40+
public static void Multiply<T>(System.ReadOnlySpan<T> x, T y, System.Span<T> destination) where T : System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { }
41+
public static void Negate<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IUnaryNegationOperators<T, T> { }
42+
public static T Norm<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.IRootFunctions<T> { throw null; }
43+
public static T ProductOfDifferences<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y) where T : System.Numerics.ISubtractionOperators<T, T, T>, System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
44+
public static T ProductOfSums<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T>, System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
45+
public static T Product<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.IMultiplyOperators<T, T, T>, System.Numerics.IMultiplicativeIdentity<T, T> { throw null; }
46+
public static void Sigmoid<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IExponentialFunctions<T> { }
47+
public static void Sinh<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IHyperbolicFunctions<T> { }
48+
public static void SoftMax<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IExponentialFunctions<T> { }
49+
public static void Subtract<T>(System.ReadOnlySpan<T> x, System.ReadOnlySpan<T> y, System.Span<T> destination) where T : System.Numerics.ISubtractionOperators<T, T, T> { }
50+
public static void Subtract<T>(System.ReadOnlySpan<T> x, T y, System.Span<T> destination) where T : System.Numerics.ISubtractionOperators<T, T, T> { }
51+
public static T SumOfMagnitudes<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.INumberBase<T> { throw null; }
52+
public static T SumOfSquares<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T>, System.Numerics.IMultiplyOperators<T, T, T> { throw null; }
53+
public static T Sum<T>(System.ReadOnlySpan<T> x) where T : System.Numerics.IAdditionOperators<T, T, T>, System.Numerics.IAdditiveIdentity<T, T> { throw null; }
54+
public static void Tanh<T>(System.ReadOnlySpan<T> x, System.Span<T> destination) where T : System.Numerics.IHyperbolicFunctions<T> { }
1355
}
1456
}

src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx

+3
Original file line numberDiff line numberDiff line change
@@ -129,4 +129,7 @@
129129
<data name="Argument_InputAndDestinationSpanMustNotOverlap" xml:space="preserve">
130130
<value>The destination span may only overlap with an input span if the two spans start at the same memory location.</value>
131131
</data>
132+
<data name="Overflow_NegateTwosCompNum" xml:space="preserve">
133+
<value>Negating the minimum value of a twos complement number is invalid.</value>
134+
</data>
132135
</root>

src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj

+6-3
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,19 @@
99
</PropertyGroup>
1010

1111
<ItemGroup>
12-
<Compile Include="System\Numerics\Tensors\TensorPrimitives.cs" />
12+
<Compile Include="System\Numerics\Tensors\TensorPrimitives.Single.cs" />
13+
<Compile Include="System\Numerics\Tensors\TensorPrimitives.Helpers.cs" />
1314
<Compile Include="System\ThrowHelper.cs" />
1415
</ItemGroup>
1516

1617
<ItemGroup Condition="'$(TargetFrameworkIdentifier)' == '.NETCoreApp'">
17-
<Compile Include="System\Numerics\Tensors\TensorPrimitives.netcore.cs" />
18+
<Compile Include="System\Numerics\Tensors\netcore\TensorPrimitives.Single.netcore.cs" />
19+
<Compile Include="System\Numerics\Tensors\netcore\TensorPrimitives.T.cs" />
20+
<Compile Include="System\Numerics\Tensors\netcore\TensorPrimitives.netcore.cs" />
1821
</ItemGroup>
1922

2023
<ItemGroup Condition="'$(TargetFrameworkIdentifier)' != '.NETCoreApp'">
21-
<Compile Include="System\Numerics\Tensors\TensorPrimitives.netstandard.cs" />
24+
<Compile Include="System\Numerics\Tensors\netstandard\TensorPrimitives.Single.netstandard.cs" />
2225
<PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
2326
<ProjectReference Include="$(LibrariesProjectRoot)Microsoft.Bcl.Numerics\src\Microsoft.Bcl.Numerics.csproj" />
2427
</ItemGroup>

0 commit comments

Comments
 (0)