Skip to content

Commit 204f0e1

Browse files
Tensor validations, scope changes, and TensorPrimitives methods (#103005)
* Additional validations and TensorPrimitives forwarding * TensorSpan extensions * ref updates * more testing * test fixes and ref fixes
1 parent 2ea80d6 commit 204f0e1

22 files changed

+6925
-1213
lines changed

src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs

+349-97
Large diffs are not rendered by default.

src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx

+6
Original file line numberDiff line numberDiff line change
@@ -219,4 +219,10 @@
219219
<data name="ThrowArgument_StrideLessThan0" xml:space="preserve">
220220
<value>Strides cannot be less than 0.</value>
221221
</data>
222+
<data name="Argument_2DTensorRequired" xml:space="preserve">
223+
<value>Must be a 2d Tensor.</value>
224+
</data>
225+
<data name="Argument_IncompatibleDimensions" xml:space="preserve">
226+
<value>Incompatible dimensions for provided tensors. left.Lengths[1] == {0} while right.Lengths[1] == {1}.</value>
227+
</data>
222228
</root>

src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616
</ItemGroup>
1717

1818
<ItemGroup Condition="'$(TargetFrameworkIdentifier)' == '.NETCoreApp'">
19-
<Compile Include="System\Numerics\Tensors\netcore\Common\IReadOnlyTensor.cs" />
2019
<Compile Include="System\Numerics\Tensors\netcore\TensorHelpers.cs" />
2120
<Compile Include="System\Numerics\Tensors\netcore\TensorExtensions.cs" />
2221
<Compile Include="System\Numerics\Tensors\netcore\Tensor.Factory.cs" />
2322
<Compile Include="System\Numerics\Tensors\netcore\Tensor.cs" />
2423
<Compile Include="System\Numerics\Tensors\netcore\ITensor.cs" />
24+
<Compile Include="System\Numerics\Tensors\netcore\IReadOnlyTensor.cs" />
2525
<Compile Include="System\Numerics\Tensors\netcore\TensorSpanDebugView.cs" />
2626
<Compile Include="System\Numerics\Tensors\netcore\TensorSpanExtensions.cs" />
2727
<Compile Include="System\Numerics\Tensors\netcore\ReadOnlyTensorSpan.cs" />

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Common/IReadOnlyTensor.cs src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/IReadOnlyTensor.cs

+2-6
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4-
using System;
54
using System.Buffers;
65
using System.Collections.Generic;
7-
using System.Linq;
8-
using System.Text;
9-
using System.Threading.Tasks;
106

117
namespace System.Numerics.Tensors
128
{
@@ -20,8 +16,8 @@ public interface IReadOnlyTensor<TSelf, T> : IEnumerable<T>
2016
nint FlattenedLength { get; }
2117
int Rank { get; }
2218

23-
T this[params ReadOnlySpan<nint> indexes] { get; }
24-
T this[params ReadOnlySpan<NIndex> indexes] { get; }
19+
T this[params scoped ReadOnlySpan<nint> indexes] { get; }
20+
T this[params scoped ReadOnlySpan<NIndex> indexes] { get; }
2521
TSelf this[params scoped ReadOnlySpan<NRange> ranges] { get; }
2622

2723
ReadOnlyTensorSpan<T> AsReadOnlyTensorSpan();

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/ITensor.cs

+2-7
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4-
using System;
54
using System.Buffers;
6-
using System.Collections.Generic;
7-
using System.Linq;
8-
using System.Text;
9-
using System.Threading.Tasks;
105

116
namespace System.Numerics.Tensors
127
{
@@ -26,8 +21,8 @@ public interface ITensor<TSelf, T>
2621

2722
bool IsReadOnly { get; }
2823

29-
new T this[params ReadOnlySpan<nint> indexes] { get; set; }
30-
new T this[params ReadOnlySpan<NIndex> indexes] { get; set; }
24+
new T this[params scoped ReadOnlySpan<nint> indexes] { get; set; }
25+
new T this[params scoped ReadOnlySpan<NIndex> indexes] { get; set; }
3126
new TSelf this[params scoped ReadOnlySpan<NRange> ranges] { get; set; }
3227

3328
TensorSpan<T> AsTensorSpan();

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/ReadOnlyTensorSpan.cs

+56-38
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
using System.Linq;
88
using System.Runtime.CompilerServices;
99
using System.Runtime.InteropServices;
10-
using System.Runtime.InteropServices.Marshalling;
11-
using System.Runtime.Versioning;
12-
using System.Text;
10+
using static System.Runtime.InteropServices.JavaScript.JSType;
1311
using EditorBrowsableAttribute = System.ComponentModel.EditorBrowsableAttribute;
1412
using EditorBrowsableState = System.ComponentModel.EditorBrowsableState;
1513

@@ -43,7 +41,7 @@ public readonly ref struct ReadOnlyTensorSpan<T>
4341
/// <param name="array">The target array.</param>
4442
/// <remarks>Returns default when <paramref name="array"/> is null.</remarks>
4543
/// <exception cref="ArrayTypeMismatchException">Thrown when <paramref name="array"/> is covariant and array's type is not exactly T[].</exception>
46-
public ReadOnlyTensorSpan(T[]? array) : this(array, 0, [], [])
44+
public ReadOnlyTensorSpan(T[]? array) : this(array, 0, [array?.Length ?? 0], [])
4745
{
4846
}
4947

@@ -81,6 +79,9 @@ public ReadOnlyTensorSpan(T[]? array, Index startIndex, scoped ReadOnlySpan<nint
8179
[MethodImpl(MethodImplOptions.AggressiveInlining)]
8280
public ReadOnlyTensorSpan(T[]? array, int start, scoped ReadOnlySpan<nint> lengths, scoped ReadOnlySpan<nint> strides)
8381
{
82+
if (lengths.IsEmpty && array != null)
83+
lengths = [array.Length];
84+
8485
nint linearLength = TensorSpanHelpers.CalculateTotalLength(lengths);
8586
if (array == null)
8687
{
@@ -92,18 +93,20 @@ public ReadOnlyTensorSpan(T[]? array, int start, scoped ReadOnlySpan<nint> lengt
9293
if (!typeof(T).IsValueType && array.GetType() != typeof(T[]))
9394
ThrowHelper.ThrowArrayTypeMismatchException();
9495

95-
strides = strides.IsEmpty ? (ReadOnlySpan<nint>)TensorSpanHelpers.CalculateStrides(lengths) : strides;
96-
nint maxElements = TensorSpanHelpers.ComputeMaxElementCount(strides, lengths);
96+
strides = strides.IsEmpty ? (ReadOnlySpan<nint>)TensorSpanHelpers.CalculateStrides(lengths, linearLength) : strides;
97+
TensorSpanHelpers.ValidateStrides(strides, lengths);
98+
nint maxElements = TensorSpanHelpers.ComputeMaxLinearIndex(strides, lengths);
99+
97100
if (Environment.Is64BitProcess)
98101
{
99102
// See comment in Span<T>.Slice for how this works.
100-
if ((ulong)(uint)start + (ulong)(uint)maxElements > (ulong)(uint)array.Length)
101-
ThrowHelper.ThrowArgumentOutOfRangeException();
103+
if ((ulong)(uint)start + (ulong)(uint)maxElements >= (ulong)(uint)array.Length && array.Length != 0)
104+
ThrowHelper.ThrowArgument_InvalidStridesAndLengths();
102105
}
103106
else
104107
{
105-
if ((uint)start > (uint)array.Length || (uint)maxElements > (uint)(array.Length - start))
106-
ThrowHelper.ThrowArgumentOutOfRangeException();
108+
if (((uint)start > (uint)array.Length || (uint)maxElements >= (uint)(array.Length - start)) && array.Length != 0)
109+
ThrowHelper.ThrowArgument_InvalidStridesAndLengths();
107110
}
108111

109112
_flattenedLength = linearLength;
@@ -115,8 +118,8 @@ public ReadOnlyTensorSpan(T[]? array, int start, scoped ReadOnlySpan<nint> lengt
115118
}
116119

117120
/// <summary>
118-
/// Creates a new <see cref="ReadOnlyTensorSpan{T}"/> over the provided <see cref="Span{T}"/>. The new <see cref="ReadOnlyTensorSpan{T}"/> will
119-
/// have a rank of 1 and a length equal to the length of the provided <see cref="Span{T}"/>.
121+
/// Creates a new <see cref="ReadOnlyTensorSpan{T}"/> over the provided <see cref="ReadOnlySpan{T}"/>. The new <see cref="ReadOnlyTensorSpan{T}"/> will
122+
/// have a rank of 1 and a length equal to the length of the provided <see cref="ReadOnlySpan{T}"/>.
120123
/// </summary>
121124
/// <param name="span">The target span.</param>
122125
public ReadOnlyTensorSpan(ReadOnlySpan<T> span) : this(span, [span.Length], []) { }
@@ -130,18 +133,15 @@ public ReadOnlyTensorSpan(ReadOnlySpan<T> span) : this(span, [span.Length], [])
130133
/// <param name="strides">The strides for each dimension. Will be automatically calculated if not provided.</param>
131134
public ReadOnlyTensorSpan(ReadOnlySpan<T> span, scoped ReadOnlySpan<nint> lengths, scoped ReadOnlySpan<nint> strides)
132135
{
136+
if (lengths.IsEmpty)
137+
lengths = [span.Length];
138+
133139
nint linearLength = TensorSpanHelpers.CalculateTotalLength(lengths);
134-
if (span.IsEmpty)
135-
{
136-
if (linearLength != 0)
137-
ThrowHelper.ThrowArgumentOutOfRangeException();
138-
this = default;
139-
return; // returns default
140-
}
141140

142-
strides = strides.IsEmpty ? (ReadOnlySpan<nint>)TensorSpanHelpers.CalculateStrides(lengths) : strides;
143-
nint maxElements = TensorSpanHelpers.ComputeMaxElementCount(strides, lengths);
144-
if (maxElements >= span.Length)
141+
strides = strides.IsEmpty ? (ReadOnlySpan<nint>)TensorSpanHelpers.CalculateStrides(lengths, linearLength) : strides;
142+
TensorSpanHelpers.ValidateStrides(strides, lengths);
143+
nint maxElements = TensorSpanHelpers.ComputeMaxLinearIndex(strides, lengths);
144+
if (maxElements >= span.Length && span.Length != 0)
145145
ThrowHelper.ThrowArgument_InvalidStridesAndLengths();
146146

147147
_flattenedLength = linearLength;
@@ -157,7 +157,7 @@ public ReadOnlyTensorSpan(ReadOnlySpan<T> span, scoped ReadOnlySpan<nint> length
157157
/// have a rank of 1 and a length equal to the length of the provided <see cref="Array"/>.
158158
/// </summary>
159159
/// <param name="array">The target array.</param>
160-
public ReadOnlyTensorSpan(Array? array) : this(array, ReadOnlySpan<int>.Empty, [], []) { }
160+
public ReadOnlyTensorSpan(Array? array) : this(array, ReadOnlySpan<int>.Empty, array == null ? [0] : (from dim in Enumerable.Range(0, array.Rank) select (nint)array.GetLength(dim)).ToArray(), []) { }
161161

162162
/// <summary>
163163
/// Creates a new <see cref="ReadOnlyTensorSpan{T}"/> over the provided <see cref="Array"/> using the specified start offsets, lengths, and strides.
@@ -169,9 +169,10 @@ public ReadOnlyTensorSpan(Array? array) : this(array, ReadOnlySpan<int>.Empty, [
169169
/// <param name="strides">The strides for each dimension. Will be automatically calculated if not provided.</param>
170170
public ReadOnlyTensorSpan(Array? array, scoped ReadOnlySpan<int> start, scoped ReadOnlySpan<nint> lengths, scoped ReadOnlySpan<nint> strides)
171171
{
172+
if (lengths.IsEmpty && array != null)
173+
lengths = (from dim in Enumerable.Range(0, array.Rank) select (nint)array.GetLength(dim)).ToArray();
174+
172175
nint linearLength = TensorSpanHelpers.CalculateTotalLength(lengths);
173-
strides = strides.IsEmpty ? (ReadOnlySpan<nint>)TensorSpanHelpers.CalculateStrides(lengths) : strides;
174-
nint startOffset = TensorSpanHelpers.ComputeLinearIndex(start, strides, lengths);
175176
if (array == null)
176177
{
177178
if (!start.IsEmpty || linearLength != 0)
@@ -182,16 +183,20 @@ public ReadOnlyTensorSpan(Array? array, scoped ReadOnlySpan<int> start, scoped R
182183
if (!typeof(T).IsValueType && array.GetType() != typeof(T[]))
183184
ThrowHelper.ThrowArrayTypeMismatchException();
184185

185-
nint maxElements = TensorSpanHelpers.ComputeMaxElementCount(strides, lengths);
186+
strides = strides.IsEmpty ? (ReadOnlySpan<nint>)TensorSpanHelpers.CalculateStrides(lengths, linearLength) : strides;
187+
TensorSpanHelpers.ValidateStrides(strides, lengths);
188+
189+
nint startOffset = TensorSpanHelpers.ComputeStartOffsetSystemArray(array, start);
190+
nint maxElements = TensorSpanHelpers.ComputeMaxLinearIndex(strides, lengths);
186191
if (Environment.Is64BitProcess)
187192
{
188193
// See comment in Span<T>.Slice for how this works.
189-
if ((ulong)(uint)startOffset + (ulong)(uint)maxElements > (ulong)(uint)array.Length)
194+
if ((ulong)(uint)startOffset + (ulong)(uint)maxElements >= (ulong)(uint)array.Length && array.Length != 0)
190195
ThrowHelper.ThrowArgumentOutOfRangeException();
191196
}
192197
else
193198
{
194-
if ((uint)startOffset > (uint)array.Length || (uint)maxElements > (uint)(array.Length - startOffset))
199+
if (((uint)startOffset > (uint)array.Length || (uint)maxElements >= (uint)(array.Length - startOffset)) && array.Length != 0)
195200
ThrowHelper.ThrowArgumentOutOfRangeException();
196201
}
197202

@@ -213,9 +218,10 @@ public ReadOnlyTensorSpan(Array? array, scoped ReadOnlySpan<int> start, scoped R
213218
/// <param name="strides">The strides for each dimension. Will be automatically calculated if not provided.</param>
214219
public ReadOnlyTensorSpan(Array? array, scoped ReadOnlySpan<NIndex> startIndex, scoped ReadOnlySpan<nint> lengths, scoped ReadOnlySpan<nint> strides)
215220
{
221+
if (lengths.IsEmpty && array != null)
222+
lengths = (from dim in Enumerable.Range(0, array.Rank) select (nint)array.GetLength(dim)).ToArray();
223+
216224
nint linearLength = TensorSpanHelpers.CalculateTotalLength(lengths);
217-
strides = strides.IsEmpty ? (ReadOnlySpan<nint>)TensorSpanHelpers.CalculateStrides(lengths) : strides;
218-
nint start = TensorSpanHelpers.ComputeLinearIndex(startIndex, strides, lengths);
219225
if (array == null)
220226
{
221227
if (!startIndex.IsEmpty || linearLength != 0)
@@ -226,22 +232,26 @@ public ReadOnlyTensorSpan(Array? array, scoped ReadOnlySpan<NIndex> startIndex,
226232
if (!typeof(T).IsValueType && array.GetType() != typeof(T[]))
227233
ThrowHelper.ThrowArrayTypeMismatchException();
228234

229-
nint maxElements = TensorSpanHelpers.ComputeMaxElementCount(strides, lengths);
235+
strides = strides.IsEmpty ? (ReadOnlySpan<nint>)TensorSpanHelpers.CalculateStrides(lengths, linearLength) : strides;
236+
TensorSpanHelpers.ValidateStrides(strides, lengths);
237+
238+
nint startOffset = TensorSpanHelpers.ComputeStartOffsetSystemArray(array, startIndex);
239+
nint maxElements = TensorSpanHelpers.ComputeMaxLinearIndex(strides, lengths);
230240
if (Environment.Is64BitProcess)
231241
{
232242
// See comment in Span<T>.Slice for how this works.
233-
if ((ulong)(uint)start + (ulong)(uint)maxElements > (ulong)(uint)array.Length)
243+
if ((ulong)(uint)startOffset + (ulong)(uint)maxElements > (ulong)(uint)array.Length)
234244
ThrowHelper.ThrowArgumentOutOfRangeException();
235245
}
236246
else
237247
{
238-
if ((uint)start > (uint)array.Length || (uint)maxElements > (uint)(array.Length - start))
248+
if ((uint)startOffset > (uint)array.Length || (uint)maxElements >= (uint)(array.Length - startOffset))
239249
ThrowHelper.ThrowArgumentOutOfRangeException();
240250
}
241251

242252
_flattenedLength = linearLength;
243253
_memoryLength = array.Length;
244-
_reference = ref Unsafe.Add(ref Unsafe.As<byte, T>(ref MemoryMarshal.GetArrayDataReference(array)), (nint)(uint)start /* force zero-extension */);
254+
_reference = ref Unsafe.Add(ref Unsafe.As<byte, T>(ref MemoryMarshal.GetArrayDataReference(array)), (nint)(uint)startOffset /* force zero-extension */);
245255

246256
_lengths = lengths.ToArray();
247257
_strides = strides.ToArray();
@@ -276,13 +286,21 @@ public unsafe ReadOnlyTensorSpan(T* data, nint dataLength) : this(data, dataLeng
276286
[MethodImpl(MethodImplOptions.AggressiveInlining)]
277287
public unsafe ReadOnlyTensorSpan(T* data, nint dataLength, scoped ReadOnlySpan<nint> lengths, scoped ReadOnlySpan<nint> strides)
278288
{
279-
nint linearLength = TensorSpanHelpers.CalculateTotalLength(lengths);
289+
if (dataLength < 0)
290+
ThrowHelper.ThrowArgumentOutOfRangeException();
291+
280292
if (RuntimeHelpers.IsReferenceOrContainsReferences<T>())
281293
ThrowHelper.ThrowInvalidTypeWithPointersNotSupported(typeof(T));
282294

283-
strides = strides.IsEmpty ? (ReadOnlySpan<nint>)TensorSpanHelpers.CalculateStrides(lengths) : strides;
284-
nint maxElements = TensorSpanHelpers.ComputeMaxElementCount(strides, lengths);
285-
if (maxElements >= dataLength)
295+
if (lengths.IsEmpty)
296+
lengths = [dataLength];
297+
298+
nint linearLength = TensorSpanHelpers.CalculateTotalLength(lengths);
299+
300+
strides = strides.IsEmpty ? (ReadOnlySpan<nint>)TensorSpanHelpers.CalculateStrides(lengths, linearLength) : strides;
301+
TensorSpanHelpers.ValidateStrides(strides, lengths);
302+
nint maxElements = TensorSpanHelpers.ComputeMaxLinearIndex(strides, lengths);
303+
if (maxElements >= dataLength && dataLength != 0)
286304
ThrowHelper.ThrowArgument_InvalidStridesAndLengths();
287305

288306
_flattenedLength = linearLength;

0 commit comments

Comments
 (0)