Skip to content

Commit

Permalink
Some fine tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
aalmada committed Apr 9, 2024
1 parent d6fa12f commit 63a6f54
Show file tree
Hide file tree
Showing 11 changed files with 111 additions and 132 deletions.
2 changes: 0 additions & 2 deletions src/NetFabric.Numerics.Tensors.Benchmarks/Baseline.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace NetFabric.Numerics.Tensors.Benchmarks;

Expand Down
4 changes: 2 additions & 2 deletions src/NetFabric.Numerics.Tensors/AggregateNumber.cs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public static TResult AggregateNumber<TSource, TTransformed, TResult, TTransform
var partial1 = TAggregateOperator.Seed;
var partial2 = TAggregateOperator.Seed;
var partial3 = TAggregateOperator.Seed;
for (; indexSource + 3 < source.Length; indexSource += 4)
for (; indexSource < source.Length - 3; indexSource += 4)
{
aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource)));
partial1 = TAggregateOperator.Invoke(partial1, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 1)));
Expand Down Expand Up @@ -261,7 +261,7 @@ public static TResult AggregateNumber<T1, T2, TTransformed, TResult, TTransformO
var partial1 = TAggregateOperator.Seed;
var partial2 = TAggregateOperator.Seed;
var partial3 = TAggregateOperator.Seed;
for (; indexSource + 3 < x.Length; indexSource += 4)
for (; indexSource < x.Length - 3; indexSource += 4)
{
aggregate = TAggregateOperator.Invoke(aggregate, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource), Unsafe.Add(ref yRef, indexSource)));
partial1 = TAggregateOperator.Invoke(partial1, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 1), Unsafe.Add(ref yRef, indexSource + 1)));
Expand Down
8 changes: 4 additions & 4 deletions src/NetFabric.Numerics.Tensors/AggregateNumber2D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public static (TResult, TResult) AggregateNumber2D<TSource, TTransformed, TResul
}

// aggregate the aggregate vector into the aggregate
for (var index = 0; index + 1 < Vector<TResult>.Count; index += 2)
for (var index = 0; index < Vector<TResult>.Count - 1; index += 2)
{
aggregateX = TAggregateOperator.Invoke(aggregateX, resultVector[index]);
aggregateY = TAggregateOperator.Invoke(aggregateY, resultVector[index + 1]);
Expand All @@ -85,7 +85,7 @@ public static (TResult, TResult) AggregateNumber2D<TSource, TTransformed, TResul
{
var partialX1 = TAggregateOperator.Seed;
var partialY1 = TAggregateOperator.Seed;
for (; indexSource + 3 < source.Length; indexSource += 4)
for (; indexSource < source.Length - 3; indexSource += 4)
{
aggregateX = TAggregateOperator.Invoke(aggregateX, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource)));
aggregateY = TAggregateOperator.Invoke(aggregateY, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 1)));
Expand Down Expand Up @@ -192,7 +192,7 @@ public static (TResult, TResult) AggregateNumber2D<T1, T2, TTransformed, TResult
}

// aggregate the aggregate vector into the aggregate
for (var index = 0; index + 1 < Vector<TResult>.Count; index += 2)
for (var index = 0; index < Vector<TResult>.Count - 1; index += 2)
{
aggregateX = TAggregateOperator.Invoke(aggregateX, resultVector[index]);
aggregateY = TAggregateOperator.Invoke(aggregateY, resultVector[index + 1]);
Expand All @@ -211,7 +211,7 @@ public static (TResult, TResult) AggregateNumber2D<T1, T2, TTransformed, TResult
{
var partialX1 = TAggregateOperator.Seed;
var partialY1 = TAggregateOperator.Seed;
for (; indexSource + 3 < x.Length; indexSource += 4)
for (; indexSource < x.Length - 3; indexSource += 4)
{
aggregateX = TAggregateOperator.Invoke(aggregateX, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource), Unsafe.Add(ref yRef, indexSource)));
aggregateY = TAggregateOperator.Invoke(aggregateY, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 1), Unsafe.Add(ref yRef, indexSource + 1)));
Expand Down
42 changes: 12 additions & 30 deletions src/NetFabric.Numerics.Tensors/AggregateNumber3D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ public static (TResult, TResult, TResult) AggregateNumber3D<T1, T2, TResult, TTr
// convert source span to vector span without copies
var sourceVectors = MemoryMarshal.Cast<T1, Vector<T1>>(source);

// check if there is at least one vector to aggregate
if (sourceVectors.Length > 0)
// check if there are more than 3 vector to aggregate
if (sourceVectors.Length > 3)
{
// initialize aggregate vectors
// use 3 vectors as 3 times the number of items in a vector is a multiple of 3
var values = new TResult[Vector<TResult>.Count * 3];
var values = GC.AllocateUninitializedArray<TResult>(Vector<TResult>.Count * 3);
Array.Fill(values, TAggregateOperator.Seed);
var resultValues = values.AsSpan();
var resultVectors = MemoryMarshal.Cast<TResult, Vector<TResult>>(resultValues);
Expand All @@ -66,7 +66,7 @@ public static (TResult, TResult, TResult) AggregateNumber3D<T1, T2, TResult, TTr
ref var sourceVectorsRef = ref MemoryMarshal.GetReference(sourceVectors);
ref var resultVectorsRef = ref MemoryMarshal.GetReference(resultVectors);
var indexVector = 0;
for (; indexVector + 2 < sourceVectors.Length; indexVector += 3)
for (; indexVector < sourceVectors.Length - 2; indexVector += 3)
{
var transformedVector0 = TTransformOperator.Invoke(ref Unsafe.Add(ref sourceVectorsRef, indexVector));
Unsafe.Add(ref resultVectorsRef, 0) = TAggregateOperator.Invoke(ref Unsafe.Add(ref resultVectorsRef, 0), ref transformedVector0);
Expand All @@ -78,7 +78,7 @@ public static (TResult, TResult, TResult) AggregateNumber3D<T1, T2, TResult, TTr

// aggregate the aggregate vector into the aggregate
ref var resultValuesRef = ref MemoryMarshal.GetReference(resultValues);
for (var index = 0; index + 2 < Vector<TResult>.Count * 3; index += 3)
for (var index = 0; index < resultValues.Length - 2; index += 3)
{
aggregateX = TAggregateOperator.Invoke(aggregateX, Unsafe.Add(ref resultValuesRef, index));
aggregateY = TAggregateOperator.Invoke(aggregateY, Unsafe.Add(ref resultValuesRef, index + 1));
Expand All @@ -92,22 +92,13 @@ public static (TResult, TResult, TResult) AggregateNumber3D<T1, T2, TResult, TTr

// aggregate the remaining elements in the source
ref var sourceRef = ref MemoryMarshal.GetReference(source);
for (; indexSource + 2 < source.Length; indexSource += 3)
for (; indexSource < source.Length - 2; indexSource += 3)
{
aggregateX = TAggregateOperator.Invoke(aggregateX, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource)));
aggregateY = TAggregateOperator.Invoke(aggregateY, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 1)));
aggregateZ = TAggregateOperator.Invoke(aggregateZ, TTransformOperator.Invoke(Unsafe.Add(ref sourceRef, indexSource + 2)));
}

switch (source.Length - indexSource)
{
case 0:
break;
default:
Throw.Exception("Should not happen!");
break;
}

return (aggregateX, aggregateY, aggregateZ);
}

Expand Down Expand Up @@ -172,12 +163,12 @@ public static (TResult, TResult, TResult) AggregateNumber3D<T1, T2, TTransformed
var xVectors = MemoryMarshal.Cast<T1, Vector<T1>>(x);
var yVectors = MemoryMarshal.Cast<T2, Vector<T2>>(y);

// check if there is at least one vector to aggregate
if (xVectors.Length > 0)
// check if there are more than 3 vector to aggregate
if (xVectors.Length > 3)
{
// initialize aggregate vectors
// use 3 vectors as 3 times the number of items in a vector is a multiple of 3
var values = new TResult[Vector<TResult>.Count * 3];
var values = GC.AllocateUninitializedArray<TResult>(Vector<TResult>.Count * 3);
Array.Fill(values, TAggregateOperator.Seed);
var resultValues = values.AsSpan();
var resultVectors = MemoryMarshal.Cast<TResult, Vector<TResult>>(resultValues);
Expand All @@ -187,7 +178,7 @@ public static (TResult, TResult, TResult) AggregateNumber3D<T1, T2, TTransformed
ref var yVectorsRef = ref MemoryMarshal.GetReference(yVectors);
ref var resultVectorsRef = ref MemoryMarshal.GetReference(resultVectors);
var indexVector = 0;
for (; indexVector + 2 < xVectors.Length; indexVector += 3)
for (; indexVector < xVectors.Length - 2; indexVector += 3)
{
var transformedVector0 = TTransformOperator.Invoke(ref Unsafe.Add(ref xVectorsRef, indexVector), ref Unsafe.Add(ref yVectorsRef, indexVector));
Unsafe.Add(ref resultVectorsRef, 0) = TAggregateOperator.Invoke(ref Unsafe.Add(ref resultVectorsRef, 0), ref transformedVector0);
Expand All @@ -199,7 +190,7 @@ public static (TResult, TResult, TResult) AggregateNumber3D<T1, T2, TTransformed

// aggregate the aggregate vector into the aggregate
ref var resultValuesRef = ref MemoryMarshal.GetReference(resultValues);
for (var index = 0; index + 2 < Vector<TResult>.Count * 3; index += 3)
for (var index = 0; index < resultValues.Length - 2; index += 3)
{
aggregateX = TAggregateOperator.Invoke(aggregateX, Unsafe.Add(ref resultValuesRef, index));
aggregateY = TAggregateOperator.Invoke(aggregateY, Unsafe.Add(ref resultValuesRef, index + 1));
Expand All @@ -214,22 +205,13 @@ public static (TResult, TResult, TResult) AggregateNumber3D<T1, T2, TTransformed
// aggregate the remaining elements in the source
ref var xRef = ref MemoryMarshal.GetReference(x);
ref var yRef = ref MemoryMarshal.GetReference(y);
for (; indexSource + 2 < x.Length; indexSource += 3)
for (; indexSource < x.Length - 2; indexSource += 3)
{
aggregateX = TAggregateOperator.Invoke(aggregateX, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource), Unsafe.Add(ref yRef, indexSource)));
aggregateY = TAggregateOperator.Invoke(aggregateY, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 1), Unsafe.Add(ref yRef, indexSource + 1)));
aggregateZ = TAggregateOperator.Invoke(aggregateZ, TTransformOperator.Invoke(Unsafe.Add(ref xRef, indexSource + 2), Unsafe.Add(ref yRef, indexSource + 2)));
}

switch (x.Length - indexSource)
{
case 0:
break;
default:
Throw.Exception("Should not happen!");
break;
}

return (aggregateX, aggregateY, aggregateZ);
}
}
Loading

0 comments on commit 63a6f54

Please sign in to comment.