Skip to content

Commit

Permalink
fixed some bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
tarinduj committed Jan 17, 2025
1 parent 7a26e44 commit 625633f
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 47 deletions.
4 changes: 4 additions & 0 deletions mlir/include/mlir/Analysis/Presburger/Matrix-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ void Matrix<Int>::resize(unsigned newNRows, unsigned newNColumns) {
throwOverflowIf(true);
else
std::abort();
} else if (isMatrixized)
{
if (newNColumns > MatrixSize || newNRows > MatrixSize)
std::abort();
}
unsigned newNReservedColumns = nextPowOfTwo(newNColumns);
data.resize(newNRows * newNReservedColumns);
Expand Down
91 changes: 45 additions & 46 deletions mlir/include/mlir/Analysis/Presburger/Simplex-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -476,11 +476,17 @@ void Simplex<Int>::pivot(Pivot pair) { pivot(pair.row, pair.column); }
}

template <typename Int>
inline void Simplex<Int>::SMEPivotHelper(Int *matrix, int rows, int cols, int pivot_row, int pivot_col) {
inline void Simplex<Int>::SMEPivotHelper(Int *matrix, int rows, int cols, int pivot_row, int pivot_col, int reserved_cols) {

// Matrix<Int, numRows, numCols> test_matrix;
// test_matrix.initZero();

// Int c1 = 0;
// Int c2 = 0;

// printf("coeff1: %f\n", matrix[pivot_row * cols + 0]);
// printf("coeff2: %f\n", matrix[pivot_row * cols + pivot_col]);

__asm__ __volatile__(

// IMP: smstart za only enables SME and not SVE. So, use smart to enable both.
Expand All @@ -490,11 +496,12 @@ inline void Simplex<Int>::SMEPivotHelper(Int *matrix, int rows, int cols, int pi
"mov x1, %[src] \n" // Source matrix pointer
// "mov x6, %[test_mat] \n" // Test matrix pointer

"mov w2, %w[coeff1] \n" // Coefficient 1/ Coefficient 3
"mov w3, %w[coeff2] \n" // Coefficient 2
"fmov s2, %w[coeff1] \n" // Coefficient 1/ Coefficient 3
"fmov s3, %w[coeff2] \n" // Coefficient 2

"mov w4, %w[nrows] \n" // Move rows to w4
"mov w5, %w[ncols] \n" // Move cols to w5
"mov w6, %w[reservedcols] \n" // Move reservedcols to w6

"mov x12, #0 \n" // Zeroth column index
"mov w13, %w[prow] \n" // Move pivot_row to w13
Expand Down Expand Up @@ -527,7 +534,7 @@ inline void Simplex<Int>::SMEPivotHelper(Int *matrix, int rows, int cols, int pi

// Increment loop counter
"add w15, w15, #1 \n" // i++
"add x9, x9, x5 \n" // Increment offset by number of columns
"add x9, x9, x6 \n" // Increment offset by number of columns

// Loop back
"b 1b \n"
Expand All @@ -538,7 +545,9 @@ inline void Simplex<Int>::SMEPivotHelper(Int *matrix, int rows, int cols, int pi
"mov z0.s, p0/m, za0h.s[w13, 0] \n" // Move pivot row from ZA to z0
"mov z1.s, p0/m, za0v.s[w14, 0] \n" // Move pivot column from ZA to z1

/* ******************** */
// If the denominator is negative, we canonicalize the row.
"fcvtzs w2, s2 \n" // Convert coeff1 to integer
"cmp w2, #0 \n" // Compare w2 (coeff1) with 0
"b.ge 1f \n" // If >= 0, skip negation

Expand All @@ -547,10 +556,15 @@ inline void Simplex<Int>::SMEPivotHelper(Int *matrix, int rows, int cols, int pi
"mov za0h.s[w13, 0], p0/m, z0.s \n" // Store back negated pivot row to ZA

// Negate coeff1 and coeff2
"neg w2, w2 \n"
"neg w3, w3 \n"
"fneg s2, s2 \n" // Negate coeff1
"fneg s3, s3 \n" // Negate coeff2

"1: \n"
/* ******************** */

// move to general purpose registers
"fmov w2, s2 \n" // Coefficient 1/ Coefficient 3
"fmov w3, s3 \n" // Coefficient 2

// Take the outer product of the pivot row and pivot column
"zero {za1.s} \n" // Zero ZA
Expand All @@ -559,7 +573,7 @@ inline void Simplex<Int>::SMEPivotHelper(Int *matrix, int rows, int cols, int pi
// TODO: Reimplement the above instruction using SMOPA


/* *************** */
/* ******************** */
// Masked Column Multiplies

"mov z0.s, p0/m, za0v.s[W12, 0] \n" // Move zeroth column from ZA to z0
Expand All @@ -586,38 +600,8 @@ inline void Simplex<Int>::SMEPivotHelper(Int *matrix, int rows, int cols, int pi
"mov za0v.s[w12, 0], p0/m, z0.s \n" // Move zeroth column from z0 to ZA
"mov za0v.s[w14, 0], p0/m, z1.s \n" // Move pivot column from z1 to ZA

// /*========================================================*/
// // test matrix store

// "whilelo p0.s, xzr, x5 \n" // Set predicate register to cols to mask seg faults

// // Initialize registers
// "mov w15, #0 \n" // Loop counter i = 0
// "mov x9, #0 \n" // Offset in source matrix

// // Loop label
// "1: \n"
// "cmp w15, w4 \n" // Compare i with nrows
// "b.ge 2f \n" // If i >= nrows, exit loop

// // Store ith row of ZA into matrix
// "st1w {za0h.s[w15, 0]}, p0, [x6, x9, lsl #2] \n"

// // Increment loop counter
// "add w15, w15, #1 \n" // i++
// "add x9, x9, x5 \n" // Increment offset by number of columns

// // Loop back
// "b 1b \n"

// // Loop exit label
// "2: \n"

// "ptrue p0.s \n" // Predicate vector

// /*========================================================*/

/* *************** */
/* ******************** */
// Masked Multiply-Add

// create a predicate with masking pivot_row
Expand Down Expand Up @@ -690,6 +674,7 @@ inline void Simplex<Int>::SMEPivotHelper(Int *matrix, int rows, int cols, int pi
// Loop exit label
"4: \n"

/* ******************** */
// Store the result back into the matrix
"whilelo p0.s, xzr, x5 \n" // Set predicate register to cols to mask seg faults

Expand All @@ -707,27 +692,35 @@ inline void Simplex<Int>::SMEPivotHelper(Int *matrix, int rows, int cols, int pi

// Increment loop counter
"add w15, w15, #1 \n" // i++
"add x9, x9, x5 \n" // Increment offset by number of columns
"add x9, x9, x6 \n" // Increment offset by number of columns

// Loop back
"b 1b \n"

// Loop exit label
"2: \n"

// "str w2, %[c1] \n" // Coefficient 1/ Coefficient 3
// "str w3, %[c2] \n" // Coefficient 2

"smstop \n"
:
: /*[c1] "=m"(c1),
[c2] "=m"(c2)*/
: [src] "r"(matrix),
/*[test_mat] "r"(&test_matrix.m[0][0]),*/
/* [test_mat] "r"(&test_matrix.m[0][0]), */
[prow] "r"(pivot_row),
[pcol] "r"(pivot_col),
[nrows] "r"(rows),
[ncols] "r"(cols),
[coeff1] "r"(matrix[pivot_row * cols + 0]),
[coeff2] "r"(matrix[pivot_row * cols + pivot_col])
: "x1", "x2", "x3", "x4", "x5", "x6", "x9", "x12", "x14", "x15", "w13", "w14", "w15", "za", "z0", "z1", "p0", "p1"
[reservedcols] "r"(reserved_cols),
[coeff1] "r"(matrix[pivot_row * reserved_cols + 0]),
[coeff2] "r"(matrix[pivot_row * reserved_cols + pivot_col])
: "x1", "x2", "x3", "x4", "x5", "x6", "x9", "x12", "x14", "x15", "w13", "w14", "w15", "za", "z0", "z1", "p0", "p1", "s2", "s3"
);

// printf("coeff1: %f\n", c1);
// printf("coeff2: %f\n", c2);

}

template <typename Int>
Expand Down Expand Up @@ -777,7 +770,13 @@ void Simplex<Int>::pivot(unsigned pivotRow, unsigned pivotCol) {
tableau(pivotRow, 0) = -tableau(pivotRow, pivotCol);
tableau(pivotRow, pivotCol) = -tmp;

SMEPivotHelper(dataptr, tableau.getNumRows(), tableau.getNumColumns(), pivotRow, pivotCol);
// if (tableau(pivotRow, 0) < 0) {
// for (unsigned col = 0; col < tableau.getNumColumns(); col++) {
// tableau(pivotRow, col) = -tableau(pivotRow, col);
// }
// }

SMEPivotHelper(dataptr, tableau.getNumRows(), tableau.getNumColumns(), pivotRow, pivotCol, tableau.getNReservedColumns());

// for (unsigned row = 0; row < nRow; ++row) {
// normalizeRowScalar(row);
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Analysis/Presburger/Simplex.h
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ class Simplex {
void swapRowWithCol(unsigned row, unsigned col);

/// Pivot the row with the column.
inline void SMEPivotHelper(Int *matrix, int rows, int cols, int pivot_row, int pivot_col);
inline void SMEPivotHelper(Int *matrix, int rows, int cols, int pivot_row, int pivot_col, int reserved_col );
void pivot(unsigned row, unsigned col);
void pivot(Pivot pair);

Expand Down

0 comments on commit 625633f

Please sign in to comment.