Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update TensorPrimitives aggregations to vectorize handling of remaining elements #92672

Merged
merged 2 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ public static float Distance(ReadOnlySpan<float> x, ReadOnlySpan<float> y)
ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
}

return MathF.Sqrt(Aggregate<SubtractSquaredOperator, AddOperator>(0f, x, y));
return MathF.Sqrt(Aggregate<SubtractSquaredOperator, AddOperator>(x, y));
}

/// <summary>Computes the element-wise result of: <c><paramref name="x" /> / <paramref name="y" /></c>.</summary>
Expand Down Expand Up @@ -162,7 +162,7 @@ public static float Dot(ReadOnlySpan<float> x, ReadOnlySpan<float> y) // BLAS1:
ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
}

return Aggregate<MultiplyOperator, AddOperator>(0f, x, y);
return Aggregate<MultiplyOperator, AddOperator>(x, y);
}

/// <summary>Computes the element-wise result of: <c>pow(e, <paramref name="x" />)</c>.</summary>
Expand Down Expand Up @@ -545,7 +545,7 @@ public static void Negate(ReadOnlySpan<float> x, Span<float> destination) =>
/// <param name="x">The first tensor, represented as a span.</param>
/// <returns>The L2 norm.</returns>
public static float Norm(ReadOnlySpan<float> x) => // BLAS1: nrm2
MathF.Sqrt(Aggregate<SquaredOperator, AddOperator>(0f, x));
MathF.Sqrt(Aggregate<SquaredOperator, AddOperator>(x));

/// <summary>Computes the product of all elements in <paramref name="x"/>.</summary>
/// <param name="x">The tensor, represented as a span.</param>
Expand All @@ -558,7 +558,7 @@ public static float Product(ReadOnlySpan<float> x)
ThrowHelper.ThrowArgument_SpansMustBeNonEmpty();
}

return Aggregate<IdentityOperator, MultiplyOperator>(1.0f, x);
return Aggregate<IdentityOperator, MultiplyOperator>(x);
}

/// <summary>Computes the product of the element-wise result of: <c><paramref name="x" /> - <paramref name="y" /></c>.</summary>
Expand All @@ -580,7 +580,7 @@ public static float ProductOfDifferences(ReadOnlySpan<float> x, ReadOnlySpan<flo
ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
}

return Aggregate<SubtractOperator, MultiplyOperator>(1.0f, x, y);
return Aggregate<SubtractOperator, MultiplyOperator>(x, y);
}

/// <summary>Computes the product of the element-wise result of: <c><paramref name="x" /> + <paramref name="y" /></c>.</summary>
Expand All @@ -602,7 +602,7 @@ public static float ProductOfSums(ReadOnlySpan<float> x, ReadOnlySpan<float> y)
ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
}

return Aggregate<AddOperator, MultiplyOperator>(1.0f, x, y);
return Aggregate<AddOperator, MultiplyOperator>(x, y);
}

/// <summary>
Expand Down Expand Up @@ -703,7 +703,7 @@ public static void Subtract(ReadOnlySpan<float> x, float y, Span<float> destinat
/// <param name="x">The tensor, represented as a span.</param>
/// <returns>The result of adding all elements in <paramref name="x"/>, or zero if <paramref name="x"/> is empty.</returns>
public static float Sum(ReadOnlySpan<float> x) =>
Aggregate<IdentityOperator, AddOperator>(0f, x);
Aggregate<IdentityOperator, AddOperator>(x);

/// <summary>Computes the sum of the absolute values of every element in <paramref name="x"/>.</summary>
/// <param name="x">The tensor, represented as a span.</param>
Expand All @@ -713,14 +713,14 @@ public static float Sum(ReadOnlySpan<float> x) =>
/// <para>This method corresponds to the <c>asum</c> method defined by <c>BLAS1</c>.</para>
/// </remarks>
public static float SumOfMagnitudes(ReadOnlySpan<float> x) =>
Aggregate<AbsoluteOperator, AddOperator>(0f, x);
Aggregate<AbsoluteOperator, AddOperator>(x);

/// <summary>Computes the sum of the squares of every element in <paramref name="x"/>.</summary>
/// <param name="x">The tensor, represented as a span.</param>
/// <returns>The result of adding every element in <paramref name="x"/> multiplied by itself, or zero if <paramref name="x"/> is empty.</returns>
/// <remarks>This method effectively does <c><see cref="TensorPrimitives" />.Sum(<see cref="TensorPrimitives" />.Multiply(<paramref name="x" />, <paramref name="x" />))</c>.</remarks>
public static float SumOfSquares(ReadOnlySpan<float> x) =>
Aggregate<SquaredOperator, AddOperator>(0f, x);
Aggregate<SquaredOperator, AddOperator>(x);

/// <summary>Computes the element-wise result of: <c>tanh(<paramref name="x" />)</c>.</summary>
/// <param name="x">The tensor, represented as a span.</param>
Expand All @@ -739,5 +739,31 @@ public static void Tanh(ReadOnlySpan<float> x, Span<float> destination)
destination[i] = MathF.Tanh(x[i]);
}
}

/// <summary>Mask used to handle remaining elements after vectorized handling of the input.</summary>
/// <remarks>
/// Logically 16 rows of 16 uints. The Nth row should be used to handle N remaining elements at the
/// end of the input, where elements in the vector prior to that will be zero'd.
/// </remarks>
private static ReadOnlySpan<uint> RemainderUInt32Mask_16x16 => new uint[]
{
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
};
}
}
Loading