Skip to content

Commit

Permalink
more testing
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelgsharp committed Jun 4, 2024
1 parent de8cff8 commit 03c8268
Show file tree
Hide file tree
Showing 7 changed files with 627 additions and 69 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -41,26 +41,6 @@ public static Tensor<T> Resize<T>(Tensor<T> input, ReadOnlySpan<nint> shape)
return output;
}

/// <summary>
/// Creates a new <see cref="TensorSpan{T}"/>, allocates new managed memory, and copies the data from <paramref name="input"/>. If the final shape is smaller all data after
/// </summary>
/// <param name="input">Input <see cref="TensorSpan{T}"/>.</param>
/// <param name="shape"><see cref="ReadOnlySpan{T}"/> of the desired new shape.</param>
public static TensorSpan<T> Resize<T>(TensorSpan<T> input, ReadOnlySpan<nint> shape)
where T : IEquatable<T>, IEqualityOperators<T, T, bool>
{
nint newSize = TensorSpanHelpers.CalculateTotalLength(shape);
T[] values = new T[newSize];
TensorSpan<T> output = new TensorSpan<T>(values, 0, shape, default);
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref input._reference, (int)input.FlattenedLength);
Span<T> ospan = MemoryMarshal.CreateSpan(ref output._reference, (int)output.FlattenedLength);
if (newSize > input.FlattenedLength)
TensorSpanHelpers.Memmove(ospan, span, input.FlattenedLength);
else
TensorSpanHelpers.Memmove(ospan, span, newSize);

return output;
}
#endregion

#region Broadcast
Expand Down Expand Up @@ -1229,7 +1209,7 @@ public static Tensor<T> Permute<T>(Tensor<T> input, params ReadOnlySpan<int> axi
/// </summary>
/// <param name="input">The <see cref="Tensor{T}"/> to take the sin of.</param>
public static Tensor<T> Abs<T>(Tensor<T> input)
where T : IEquatable<T>, IEqualityOperators<T, T, bool>, ITrigonometricFunctions<T>
where T : IEquatable<T>, IEqualityOperators<T, T, bool>, INumberBase<T>
{
return TensorPrimitivesHelperSpanInSpanOut(input, TensorPrimitives.Abs);
}
Expand All @@ -1239,7 +1219,7 @@ public static Tensor<T> Abs<T>(Tensor<T> input)
/// </summary>
/// <param name="input">The <see cref="Tensor{T}"/> to take the sin of.</param>
public static Tensor<T> AbsInPlace<T>(Tensor<T> input)
where T : IEquatable<T>, IEqualityOperators<T, T, bool>, ITrigonometricFunctions<T>
where T : IEquatable<T>, IEqualityOperators<T, T, bool>, INumberBase<T>
{
return TensorPrimitivesHelperSpanInSpanOut(input, TensorPrimitives.Abs, true);
}
Expand Down Expand Up @@ -2520,7 +2500,7 @@ public static Tensor<T> Negate<T>(Tensor<T> input)

/// <summary>Computes the element-wise negation of each number in the specified tensor.</summary>
/// <param name="input">The <see cref="Tensor{T}"/></param>
public static Tensor<T> NegatePlace<T>(Tensor<T> input)
public static Tensor<T> NegateInPlace<T>(Tensor<T> input)
where T : IEquatable<T>, IEqualityOperators<T, T, bool>, IUnaryNegationOperators<T, T>
{
return TensorPrimitivesHelperSpanInSpanOut(input, TensorPrimitives.Negate, true);
Expand Down Expand Up @@ -2550,7 +2530,7 @@ public static Tensor<T> OnesComplement<T>(Tensor<T> input)

/// <summary>Computes the element-wise one's complement of numbers in the specified tensor.</summary>
/// <param name="input">The <see cref="Tensor{T}"/></param>
public static Tensor<T> OnesComplementPlace<T>(Tensor<T> input)
public static Tensor<T> OnesComplementInPlace<T>(Tensor<T> input)
where T : IEquatable<T>, IEqualityOperators<T, T, bool>, IBitwiseOperators<T, T, T>
{
return TensorPrimitivesHelperSpanInSpanOut(input, TensorPrimitives.OnesComplement, true);
Expand All @@ -2568,7 +2548,7 @@ public static Tensor<T> PopCount<T>(Tensor<T> input)

/// <summary>Computes the element-wise population count of numbers in the specified tensor.</summary>
/// <param name="input">The <see cref="Tensor{T}"/></param>
public static Tensor<T> PopCountPlace<T>(Tensor<T> input)
public static Tensor<T> PopCountInPlace<T>(Tensor<T> input)
where T : IEquatable<T>, IEqualityOperators<T, T, bool>, IBinaryInteger<T>
{
return TensorPrimitivesHelperSpanInSpanOut(input, TensorPrimitives.PopCount, true);
Expand Down Expand Up @@ -3058,7 +3038,7 @@ private static Tensor<TTo> TensorPrimitivesHelperTFromSpanInTToSpanOut<TFrom, TT
private static Tensor<T> TensorPrimitivesHelperTwoSpanInSpanOut<T>(Tensor<T> left, Tensor<T> right, PerformCalculationTwoSpanInSpanOut<T> performCalculation, bool inPlace = false)
where T : IEquatable<T>, IEqualityOperators<T, T, bool>
{
if (inPlace && left.Lengths != right.Lengths)
if (inPlace && !left.Lengths.SequenceEqual(right.Lengths))
ThrowHelper.ThrowArgument_InPlaceInvalidShape();

Tensor<T> output;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,27 @@ public static class TensorSpan
/// <summary>
/// Determines whether two sequences are equal by comparing the elements using IEquatable{T}.Equals(T).
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe bool SequenceEqual<T>(this TensorSpan<T> span, TensorSpan<T> other) where T : IEquatable<T>? => span.FlattenedLength == other.FlattenedLength && span.Lengths.SequenceEqual(other.Lengths) && TensorSpanHelpers.SequenceEqual(ref span.GetPinnableReference(), ref other.GetPinnableReference(), (nuint)span.FlattenedLength);
public static bool SequenceEqual<T>(this ReadOnlyTensorSpan<T> span, in ReadOnlyTensorSpan<T> other) where T : IEquatable<T>?
{
return span.FlattenedLength == other.FlattenedLength
&& MemoryMarshal.CreateReadOnlySpan(in span.GetPinnableReference(), (int)span.FlattenedLength).SequenceEqual(MemoryMarshal.CreateReadOnlySpan(in other.GetPinnableReference(), (int)other.FlattenedLength));
}

/// <summary>
/// Determines whether two sequences are equal by comparing the elements using IEquatable{T}.Equals(T).
/// </summary>
public static bool SequenceEqual<T>(this TensorSpan<T> span, in TensorSpan<T> other) where T : IEquatable<T>?
{
return ((ReadOnlyTensorSpan<T>)span).SequenceEqual((ReadOnlyTensorSpan<T>)other);
}

/// <summary>
/// Determines whether two sequences are equal by comparing the elements using IEquatable{T}.Equals(T).
/// </summary>
public static bool SequenceEqual<T>(this TensorSpan<T> span, in ReadOnlyTensorSpan<T> other) where T : IEquatable<T>?
{
return ((ReadOnlyTensorSpan<T>)span).SequenceEqual(other);
}
#endregion

#region AsTensorSpan
Expand Down Expand Up @@ -275,6 +294,29 @@ public static TensorSpan<T> Reshape<T>(this TensorSpan<T> input, params ReadOnly
}
#endregion

#region Resize
/// <summary>
/// Creates a new <see cref="TensorSpan{T}"/>, allocates new managed memory, and copies the data from <paramref name="input"/>. If the final shape is smaller all data after
/// </summary>
/// <param name="input">Input <see cref="TensorSpan{T}"/>.</param>
/// <param name="shape"><see cref="ReadOnlySpan{T}"/> of the desired new shape.</param>
public static TensorSpan<T> Resize<T>(TensorSpan<T> input, ReadOnlySpan<nint> shape)
where T : IEquatable<T>, IEqualityOperators<T, T, bool>
{
nint newSize = TensorSpanHelpers.CalculateTotalLength(shape);
T[] values = new T[newSize];
TensorSpan<T> output = new TensorSpan<T>(values, 0, shape, default);
ReadOnlySpan<T> span = MemoryMarshal.CreateSpan(ref input._reference, (int)input.FlattenedLength);
Span<T> ospan = MemoryMarshal.CreateSpan(ref output._reference, (int)output.FlattenedLength);
if (newSize > input.FlattenedLength)
TensorSpanHelpers.Memmove(ospan, span, input.FlattenedLength);
else
TensorSpanHelpers.Memmove(ospan, span, newSize);

return output;
}
#endregion

#region Squeeze
// REVIEW: NAME?
/// <summary>
Expand Down Expand Up @@ -428,7 +470,7 @@ public static TResult Mean<T, TResult>(TensorSpan<T> input)
/// </summary>
/// <param name="input">The <see cref="TensorSpan{T}"/> to take the sin of.</param>
public static TensorSpan<T> Abs<T>(TensorSpan<T> input)
where T : IEquatable<T>, IEqualityOperators<T, T, bool>, ITrigonometricFunctions<T>
where T : IEquatable<T>, IEqualityOperators<T, T, bool>, INumberBase<T>
{
return TensorPrimitivesHelperSpanInSpanOut(input, TensorPrimitives.Abs);
}
Expand All @@ -438,7 +480,7 @@ public static TensorSpan<T> Abs<T>(TensorSpan<T> input)
/// </summary>
/// <param name="input">The <see cref="TensorSpan{T}"/> to take the sin of.</param>
public static TensorSpan<T> AbsInPlace<T>(TensorSpan<T> input)
where T : IEquatable<T>, IEqualityOperators<T, T, bool>, ITrigonometricFunctions<T>
where T : IEquatable<T>, IEqualityOperators<T, T, bool>, INumberBase<T>
{
return TensorPrimitivesHelperSpanInSpanOut(input, TensorPrimitives.Abs, true);
}
Expand Down Expand Up @@ -1717,7 +1759,7 @@ public static TensorSpan<T> Negate<T>(TensorSpan<T> input)

/// <summary>Computes the element-wise negation of each number in the specified tensor.</summary>
/// <param name="input">The <see cref="TensorSpan{T}"/></param>
public static TensorSpan<T> NegatePlace<T>(TensorSpan<T> input)
public static TensorSpan<T> NegateInPlace<T>(TensorSpan<T> input)
where T : IEquatable<T>, IEqualityOperators<T, T, bool>, IUnaryNegationOperators<T, T>
{
return TensorPrimitivesHelperSpanInSpanOut(input, TensorPrimitives.Negate, true);
Expand Down Expand Up @@ -1749,7 +1791,7 @@ public static TensorSpan<T> OnesComplement<T>(TensorSpan<T> input)

/// <summary>Computes the element-wise one's complement of numbers in the specified tensor.</summary>
/// <param name="input">The <see cref="TensorSpan{T}"/></param>
public static TensorSpan<T> OnesComplementPlace<T>(TensorSpan<T> input)
public static TensorSpan<T> OnesComplementInPlace<T>(TensorSpan<T> input)
where T : IEquatable<T>, IEqualityOperators<T, T, bool>, IBitwiseOperators<T, T, T>
{
return TensorPrimitivesHelperSpanInSpanOut(input, TensorPrimitives.OnesComplement, true);
Expand All @@ -1767,7 +1809,7 @@ public static TensorSpan<T> PopCount<T>(TensorSpan<T> input)

/// <summary>Computes the element-wise population count of numbers in the specified tensor.</summary>
/// <param name="input">The <see cref="TensorSpan{T}"/></param>
public static TensorSpan<T> PopCountPlace<T>(TensorSpan<T> input)
public static TensorSpan<T> PopCountInPlace<T>(TensorSpan<T> input)
where T : IEquatable<T>, IEqualityOperators<T, T, bool>, IBinaryInteger<T>
{
return TensorPrimitivesHelperSpanInSpanOut(input, TensorPrimitives.PopCount, true);
Expand Down Expand Up @@ -2285,7 +2327,7 @@ private static TensorSpan<TTo> TensorPrimitivesHelperTFromSpanInTToSpanOut<TFrom
private static TensorSpan<T> TensorPrimitivesHelperTwoSpanInSpanOut<T>(TensorSpan<T> left, TensorSpan<T> right, PerformCalculationTwoSpanInSpanOut<T> performCalculation, bool inPlace = false)
where T : IEquatable<T>, IEqualityOperators<T, T, bool>
{
if (inPlace && left.Lengths != right.Lengths)
if (inPlace && !left.Lengths.SequenceEqual(right.Lengths))
ThrowHelper.ThrowArgument_InPlaceInvalidShape();

TensorSpan<T> output;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,5 @@ public static unsafe void Fill<T>(ref T dest, nuint numElements, T value)
numElements -= toFill;
}
}

public static bool SequenceEqual<T>(ref T first, ref T second, nuint length) where T : IEquatable<T>?
{
bool equal = true;
while (length > 0)
{
nuint toCompare = Math.Min(length, int.MaxValue);
equal &= MemoryMarshal.CreateSpan(ref first, (int)toCompare).SequenceEqual(MemoryMarshal.CreateSpan(ref second, (int)toCompare));
first = ref Unsafe.Add(ref first, toCompare);
second = ref Unsafe.Add(ref second, toCompare);
length -= toCompare;
}

return equal;
}
}
}
1 change: 1 addition & 0 deletions src/libraries/System.Numerics.Tensors/tests/Helpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ public static class Helpers
public static IEnumerable<int> TensorLengthsIncluding0 => Enumerable.Range(0, 257);

public static IEnumerable<int> TensorLengths => Enumerable.Range(1, 256);
public static IEnumerable<nint[]> TensorShapes => [[1], [2], [10], [1,1], [1,2], [2,2], [5, 5], [2, 2, 2], [5, 5, 5], [3, 3, 3, 3], [4, 4, 4, 4, 4]];

// Tolerances taken from testing in the scalar math routines:
// cf. https://github.com/dotnet/runtime/blob/89f7ad3b276fb0b48f20cb4e8408bdce85c2b415/src/libraries/System.Runtime/tests/System.Runtime.Extensions.Tests/System/Math.cs
Expand Down
Loading

0 comments on commit 03c8268

Please sign in to comment.