Skip to content

Commit

Permalink
Same implementation for Sparse Multiplication for aligned and unalign…
Browse files Browse the repository at this point in the history
…ed arrays (#1274)

* sparse vector corrected

* Removind Dead Code, correcting names, adding assert checks to correct place, span overloads and function for common code

* fixing build on unix

* cmake file corrected, if def removed from sse.cpp and unitest name modified

* Performance test corrected, resolved merge conflicts, fma supported added
  • Loading branch information
Anipik authored Oct 24, 2018
1 parent 00021b6 commit 263a67b
Show file tree
Hide file tree
Showing 14 changed files with 543 additions and 5,913 deletions.
1,165 changes: 0 additions & 1,165 deletions src/Microsoft.ML.CpuMath/Avx.cs

This file was deleted.

192 changes: 136 additions & 56 deletions src/Microsoft.ML.CpuMath/AvxIntrinsics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,6 @@ internal static class AvxIntrinsics

private static readonly Vector256<float> _absMask256 = Avx.StaticCast<int, float>(Avx.SetAllVector256(0x7FFFFFFF));

private const int Vector256Alignment = 32;

[MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
private static bool HasCompatibleAlignment(AlignedArray alignedArray)
{
Contracts.AssertValue(alignedArray);
Contracts.Assert(alignedArray.Size > 0);
return (alignedArray.CbAlign % Vector256Alignment) == 0;
}

[MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
private static unsafe float* GetAlignedBase(AlignedArray alignedArray, float* unalignedBase)
{
Contracts.AssertValue(alignedArray);
float* alignedBase = unalignedBase + alignedArray.GetBase((long)unalignedBase);
Contracts.Assert(((long)alignedBase % Vector256Alignment) == 0);
return alignedBase;
}

[MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
private static Vector128<float> GetHigh(in Vector256<float> x)
=> Avx.ExtractVector128(x, 1);
Expand Down Expand Up @@ -170,19 +151,19 @@ private static Vector256<float> MultiplyAdd(Vector256<float> src1, Vector256<flo
}

// Multiply matrix times vector into vector.
public static unsafe void MatMulX(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
public static unsafe void MatMul(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
{
Contracts.Assert(crow % 4 == 0);
Contracts.Assert(ccol % 4 == 0);

MatMulX(mat.Items, src.Items, dst.Items, crow, ccol);
MatMul(mat.Items, src.Items, dst.Items, crow, ccol);
}

public static unsafe void MatMulX(float[] mat, float[] src, float[] dst, int crow, int ccol)
public static unsafe void MatMul(ReadOnlySpan<float> mat, ReadOnlySpan<float> src, Span<float> dst, int crow, int ccol)
{
fixed (float* psrc = &src[0])
fixed (float* pdst = &dst[0])
fixed (float* pmat = &mat[0])
Contracts.Assert(crow % 4 == 0);
Contracts.Assert(ccol % 4 == 0);

fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
fixed (float* pmat = &MemoryMarshal.GetReference(mat))
fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
{
Expand Down Expand Up @@ -312,32 +293,134 @@ public static unsafe void MatMulX(float[] mat, float[] src, float[] dst, int cro
}

// Partial sparse source vector.
public static unsafe void MatMulPX(AlignedArray mat, int[] rgposSrc, AlignedArray src,
int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol)
public static unsafe void MatMulP(AlignedArray mat, ReadOnlySpan<int> rgposSrc, AlignedArray src,
int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol)
{
Contracts.Assert(HasCompatibleAlignment(mat));
Contracts.Assert(HasCompatibleAlignment(src));
Contracts.Assert(HasCompatibleAlignment(dst));
MatMulP(mat.Items, rgposSrc, src.Items, posMin, iposMin, iposEnd, dst.Items, crow, ccol);
}

public static unsafe void MatMulP(ReadOnlySpan<float> mat, ReadOnlySpan<int> rgposSrc, ReadOnlySpan<float> src,
int posMin, int iposMin, int iposEnd, Span<float> dst, int crow, int ccol)
{
Contracts.Assert(crow % 8 == 0);
Contracts.Assert(ccol % 8 == 0);

// REVIEW: For extremely sparse inputs, interchanging the loops would
// likely be more efficient.
fixed (float* pSrcStart = &src.Items[0])
fixed (float* pDstStart = &dst.Items[0])
fixed (float* pMatStart = &mat.Items[0])
fixed (int* pposSrc = &rgposSrc[0])
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
fixed (float* pmat = &MemoryMarshal.GetReference(mat))
fixed (int* pposSrc = &MemoryMarshal.GetReference(rgposSrc))
fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
{
float* psrc = GetAlignedBase(src, pSrcStart);
float* pdst = GetAlignedBase(dst, pDstStart);
float* pmat = GetAlignedBase(mat, pMatStart);

int* pposMin = pposSrc + iposMin;
int* pposEnd = pposSrc + iposEnd;
float* pDstEnd = pdst + crow;
float* pm0 = pmat - posMin;
float* pSrcCurrent = psrc - posMin;
float* pDstCurrent = pdst;

while (pDstCurrent < pDstEnd)
nuint address = (nuint)(pDstCurrent);
int misalignment = (int)(address % 32);
int length = crow;
int remainder = 0;

if ((misalignment & 3) != 0)
{
while (pDstCurrent < pDstEnd)
{
Avx.Store(pDstCurrent, SparseMultiplicationAcrossRow());
pDstCurrent += 8;
pm0 += 8 * ccol;
}
}
else
{
if (misalignment != 0)
{
misalignment >>= 2;
misalignment = 8 - misalignment;

Vector256<float> mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8));

float* pm1 = pm0 + ccol;
float* pm2 = pm1 + ccol;
float* pm3 = pm2 + ccol;
Vector256<float> result = Avx.SetZeroVector256<float>();

int* ppos = pposMin;

while (ppos < pposEnd)
{
int col1 = *ppos;
int col2 = col1 + 4 * ccol;
Vector256<float> x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
pm3[col1], pm2[col1], pm1[col1], pm0[col1]);

x1 = Avx.And(mask, x1);
Vector256<float> x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
result = MultiplyAdd(x2, x1, result);
ppos++;
}

Avx.Store(pDstCurrent, result);
pDstCurrent += misalignment;
pm0 += misalignment * ccol;
length -= misalignment;
}

if (length > 7)
{
remainder = length % 8;
while (pDstCurrent < pDstEnd)
{
Avx.Store(pDstCurrent, SparseMultiplicationAcrossRow());
pDstCurrent += 8;
pm0 += 8 * ccol;
}
}
else
{
remainder = length;
}

if (remainder != 0)
{
pDstCurrent -= (8 - remainder);
pm0 -= (8 - remainder) * ccol;
Vector256<float> trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8));
Vector256<float> leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8));

float* pm1 = pm0 + ccol;
float* pm2 = pm1 + ccol;
float* pm3 = pm2 + ccol;
Vector256<float> result = Avx.SetZeroVector256<float>();

int* ppos = pposMin;

while (ppos < pposEnd)
{
int col1 = *ppos;
int col2 = col1 + 4 * ccol;
Vector256<float> x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
x1 = Avx.And(x1, trailingMask);

Vector256<float> x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
result = MultiplyAdd(x2, x1, result);
ppos++;
}

result = Avx.Add(result, Avx.And(leadingMask, Avx.LoadVector256(pDstCurrent)));

Avx.Store(pDstCurrent, result);
pDstCurrent += 8;
pm0 += 8 * ccol;
}
}

Vector256<float> SparseMultiplicationAcrossRow()
{
float* pm1 = pm0 + ccol;
float* pm2 = pm1 + ccol;
Expand All @@ -351,33 +434,30 @@ public static unsafe void MatMulPX(AlignedArray mat, int[] rgposSrc, AlignedArra
int col1 = *ppos;
int col2 = col1 + 4 * ccol;
Vector256<float> x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
Vector256<float> x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
result = MultiplyAdd(x2, x1, result);

ppos++;
}

Avx.StoreAligned(pDstCurrent, result);
pDstCurrent += 8;
pm0 += 8 * ccol;
return result;
}
}
}

public static unsafe void MatMulTranX(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
{
Contracts.Assert(crow % 4 == 0);
Contracts.Assert(ccol % 4 == 0);

MatMulTranX(mat.Items, src.Items, dst.Items, crow, ccol);
MatMulTran(mat.Items, src.Items, dst.Items, crow, ccol);
}

public static unsafe void MatMulTranX(float[] mat, float[] src, float[] dst, int crow, int ccol)
public static unsafe void MatMulTran(ReadOnlySpan<float> mat, ReadOnlySpan<float> src, Span<float> dst, int crow, int ccol)
{
fixed (float* psrc = &src[0])
fixed (float* pdst = &dst[0])
fixed (float* pmat = &mat[0])
Contracts.Assert(crow % 4 == 0);
Contracts.Assert(ccol % 4 == 0);

fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
fixed (float* pmat = &MemoryMarshal.GetReference(mat))
fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
{
Expand Down
8 changes: 4 additions & 4 deletions src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, Al
if (!tran)
{
Contracts.Assert(crun <= dst.Size);
AvxIntrinsics.MatMulX(mat, src, dst, crun, src.Size);
AvxIntrinsics.MatMul(mat, src, dst, crun, src.Size);
}
else
{
Contracts.Assert(crun <= src.Size);
AvxIntrinsics.MatMulTranX(mat, src, dst, dst.Size, crun);
AvxIntrinsics.MatMulTran(mat, src, dst, dst.Size, crun);
}
}
else if (Sse.IsSupported)
Expand Down Expand Up @@ -109,12 +109,12 @@ public static void MatTimesSrc(AlignedArray mat, int[] rgposSrc, AlignedArray sr
if (Avx.IsSupported)
{
Contracts.Assert(crun <= dst.Size);
AvxIntrinsics.MatMulPX(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size);
AvxIntrinsics.MatMulP(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size);
}
else if (Sse.IsSupported)
{
Contracts.Assert(crun <= dst.Size);
SseIntrinsics.MatMulPA(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size);
SseIntrinsics.MatMulP(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size);
}
else
{
Expand Down
Loading

0 comments on commit 263a67b

Please sign in to comment.