Skip to content

Commit

Permalink
Update TensorPrimitives aggregations to vectorize handling of remaini…
Browse files Browse the repository at this point in the history
…ng elements (#92672)

* Update TensorPrimitives.CosineSimilarity to vectorize handling of remaining elements

* Vectorize remainder handling for Aggregate helpers
  • Loading branch information
stephentoub authored Sep 28, 2023
1 parent dc1f86a commit 3bf40a3
Show file tree
Hide file tree
Showing 3 changed files with 443 additions and 205 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ public static float Distance(ReadOnlySpan<float> x, ReadOnlySpan<float> y)
ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
}

return MathF.Sqrt(Aggregate<SubtractSquaredOperator, AddOperator>(0f, x, y));
return MathF.Sqrt(Aggregate<SubtractSquaredOperator, AddOperator>(x, y));
}

/// <summary>Computes the element-wise result of: <c><paramref name="x" /> / <paramref name="y" /></c>.</summary>
Expand Down Expand Up @@ -162,7 +162,7 @@ public static float Dot(ReadOnlySpan<float> x, ReadOnlySpan<float> y) // BLAS1:
ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
}

return Aggregate<MultiplyOperator, AddOperator>(0f, x, y);
return Aggregate<MultiplyOperator, AddOperator>(x, y);
}

/// <summary>Computes the element-wise result of: <c>pow(e, <paramref name="x" />)</c>.</summary>
Expand Down Expand Up @@ -545,7 +545,7 @@ public static void Negate(ReadOnlySpan<float> x, Span<float> destination) =>
/// <param name="x">The first tensor, represented as a span.</param>
/// <returns>The L2 norm.</returns>
public static float Norm(ReadOnlySpan<float> x) => // BLAS1: nrm2
MathF.Sqrt(Aggregate<SquaredOperator, AddOperator>(0f, x));
MathF.Sqrt(Aggregate<SquaredOperator, AddOperator>(x));

/// <summary>Computes the product of all elements in <paramref name="x"/>.</summary>
/// <param name="x">The tensor, represented as a span.</param>
Expand All @@ -558,7 +558,7 @@ public static float Product(ReadOnlySpan<float> x)
ThrowHelper.ThrowArgument_SpansMustBeNonEmpty();
}

return Aggregate<IdentityOperator, MultiplyOperator>(1.0f, x);
return Aggregate<IdentityOperator, MultiplyOperator>(x);
}

/// <summary>Computes the product of the element-wise result of: <c><paramref name="x" /> - <paramref name="y" /></c>.</summary>
Expand All @@ -580,7 +580,7 @@ public static float ProductOfDifferences(ReadOnlySpan<float> x, ReadOnlySpan<flo
ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
}

return Aggregate<SubtractOperator, MultiplyOperator>(1.0f, x, y);
return Aggregate<SubtractOperator, MultiplyOperator>(x, y);
}

/// <summary>Computes the product of the element-wise result of: <c><paramref name="x" /> + <paramref name="y" /></c>.</summary>
Expand All @@ -602,7 +602,7 @@ public static float ProductOfSums(ReadOnlySpan<float> x, ReadOnlySpan<float> y)
ThrowHelper.ThrowArgument_SpansMustHaveSameLength();
}

return Aggregate<AddOperator, MultiplyOperator>(1.0f, x, y);
return Aggregate<AddOperator, MultiplyOperator>(x, y);
}

/// <summary>
Expand Down Expand Up @@ -703,7 +703,7 @@ public static void Subtract(ReadOnlySpan<float> x, float y, Span<float> destinat
/// <param name="x">The tensor, represented as a span.</param>
/// <returns>The result of adding all elements in <paramref name="x"/>, or zero if <paramref name="x"/> is empty.</returns>
public static float Sum(ReadOnlySpan<float> x) =>
Aggregate<IdentityOperator, AddOperator>(0f, x);
Aggregate<IdentityOperator, AddOperator>(x);

/// <summary>Computes the sum of the absolute values of every element in <paramref name="x"/>.</summary>
/// <param name="x">The tensor, represented as a span.</param>
Expand All @@ -713,14 +713,14 @@ public static float Sum(ReadOnlySpan<float> x) =>
/// <para>This method corresponds to the <c>asum</c> method defined by <c>BLAS1</c>.</para>
/// </remarks>
public static float SumOfMagnitudes(ReadOnlySpan<float> x) =>
Aggregate<AbsoluteOperator, AddOperator>(0f, x);
Aggregate<AbsoluteOperator, AddOperator>(x);

/// <summary>Computes the sum of the squares of every element in <paramref name="x"/>.</summary>
/// <param name="x">The tensor, represented as a span.</param>
/// <returns>The result of adding every element in <paramref name="x"/> multiplied by itself, or zero if <paramref name="x"/> is empty.</returns>
/// <remarks>This method effectively does <c><see cref="TensorPrimitives" />.Sum(<see cref="TensorPrimitives" />.Multiply(<paramref name="x" />, <paramref name="x" />))</c>.</remarks>
public static float SumOfSquares(ReadOnlySpan<float> x) =>
Aggregate<SquaredOperator, AddOperator>(0f, x);
Aggregate<SquaredOperator, AddOperator>(x);

/// <summary>Computes the element-wise result of: <c>tanh(<paramref name="x" />)</c>.</summary>
/// <param name="x">The tensor, represented as a span.</param>
Expand All @@ -739,5 +739,31 @@ public static void Tanh(ReadOnlySpan<float> x, Span<float> destination)
destination[i] = MathF.Tanh(x[i]);
}
}

/// <summary>Mask used to handle remaining elements after vectorized handling of the input.</summary>
/// <remarks>
/// Logically 16 rows of 16 uints. The Nth row should be used to handle N remaining elements at the
/// end of the input, where elements in the vector prior to that will be zero'd.
/// </remarks>
private static ReadOnlySpan<uint> RemainderUInt32Mask_16x16 => new uint[]
{
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
};
}
}
Loading

0 comments on commit 3bf40a3

Please sign in to comment.