From df249f9eef26571e820601b2d218bd0386946685 Mon Sep 17 00:00:00 2001
From: beshrislambouli
Date: Sun, 20 Jul 2025 15:40:03 -0400
Subject: [PATCH] add Matrix Transpose CPU annot, impl, and test

---
Review notes (this section is not part of the commit message):
 * src/codegen/codegen.cpp: the added bare `#include` named no header and
   nothing else in that file changes, so only the end-of-file newline fix
   is kept.
 * `std::vector<std::string>` is spelled out for getHeader() and for the
   argument of test::cpuRunner.
 * The unused locals `row` and `col` are removed from getAnnotation().
 * test/test-cpu-matrix.cpp now ends with a newline.
 * Indentation and blank context lines were reconstructed and the `index`
   blob ids are those of the original patch; if the patch does not apply
   cleanly, use `git apply --ignore-whitespace` and regenerate it.

 src/codegen/codegen.cpp                |  2 +-
 test/library/matrix/annot/cpu-matrix.h | 50 ++++++++++++++++++++++++++
 test/library/matrix/impl/cpu-matrix.h  | 18 +++++++++
 test/test-cpu-matrix.cpp               | 50 +++++++++++++++++++++++++-
 4 files changed, 118 insertions(+), 2 deletions(-)

diff --git a/src/codegen/codegen.cpp b/src/codegen/codegen.cpp
index 5ce90b1d..641886cc 100644
--- a/src/codegen/codegen.cpp
+++ b/src/codegen/codegen.cpp
@@ -734,4 +734,4 @@ CGStmt CodeGenerator::assertGrid(const Grid::Dim &dim) {
 
 } // namespace codegen
 
-} // namespace gern
\ No newline at end of file
+} // namespace gern
diff --git a/test/library/matrix/annot/cpu-matrix.h b/test/library/matrix/annot/cpu-matrix.h
index 0249903a..328f466b 100644
--- a/test/library/matrix/annot/cpu-matrix.h
+++ b/test/library/matrix/annot/cpu-matrix.h
@@ -125,6 +125,56 @@ class MatrixAddCPU : public AbstractFunction {
     Variable end{"end"};
 };
 
+/**
+ * @brief Annotation for gern::impl::transpose on CPU matrices.
+ *
+ * x tiles output["row"] and y tiles output["col"]; the output subset
+ * {x, y, l_x, l_y} is produced from the input subset {y, x, l_y, l_x}.
+ */
+class MatrixTransposeCPU : public AbstractFunction {
+public:
+    MatrixTransposeCPU()
+        : input(new const MatrixCPU("input")),
+          output(new const MatrixCPU("output")) {
+    }
+
+    std::string getName() {
+        return "gern::impl::transpose";
+    }
+
+    Annotation getAnnotation() override {
+        Variable x("x");
+        Variable y("y");
+        Variable l_x("l_x");
+        Variable l_y("l_y");
+
+        // The consumed input subset reuses the output's tiling variables
+        // with the row and column fields swapped.
+        return annotate(Tileable(x = Expr(0), output["row"], l_x,
+                                 Tileable(y = Expr(0), output["col"], l_y,
+                                          Produces::Subset(output, {x, y, l_x, l_y}),
+                                          Consumes::Subset(input, {y, x, l_y, l_x}))));
+    }
+
+    std::vector<std::string> getHeader() override {
+        return {
+            "cpu-matrix.h",
+        };
+    }
+
+    virtual FunctionSignature getFunction() override {
+        FunctionSignature f;
+        f.name = "gern::impl::transpose";
+        f.args = {Parameter(input), Parameter(output)};
+        return f;
+    }
+
+protected:
+    AbstractDataTypePtr input;
+    AbstractDataTypePtr output;
+    Variable end{"end"};
+};
+
 class SumRow : public AbstractFunction {
 public:
     SumRow()
diff --git a/test/library/matrix/impl/cpu-matrix.h b/test/library/matrix/impl/cpu-matrix.h
index d4593de1..f675b141 100644
--- a/test/library/matrix/impl/cpu-matrix.h
+++ b/test/library/matrix/impl/cpu-matrix.h
@@ -131,6 +131,24 @@ inline void add(MatrixCPU a, MatrixCPU b) {
     }
 }
 
+/**
+ * @brief Writes the transpose of a into b: b(i, j) = a(j, i).
+ *
+ * For every i < b.row and j < b.col, reads a at row j, column i.
+ * Row r of each matrix is taken to start at data + r * lda.
+ */
+inline void transpose(MatrixCPU a, MatrixCPU b) {
+    float *a_data;
+    float *b_data;
+    for (int64_t i = 0; i < b.row; i++) {
+        b_data = b.data + (i * b.lda);
+        for (int64_t j = 0; j < b.col; j++) {
+            a_data = a.data + (j * a.lda);
+            b_data[j] = a_data[i];
+        }
+    }
+}
+
 inline void exp_matrix(MatrixCPU a, MatrixCPU b) {
     float *a_data;
     float *b_data;
diff --git a/test/test-cpu-matrix.cpp b/test/test-cpu-matrix.cpp
index 9e4797e1..8464a388 100644
--- a/test/test-cpu-matrix.cpp
+++ b/test/test-cpu-matrix.cpp
@@ -145,4 +145,52 @@ TEST(LoweringCPU, MatrixCPUAdd) {
 //     b.destroy();
 //     reference.destroy();
 //     max_row_ref.destroy();
-// }
\ No newline at end of file
+// }
+
+// Transposes a 21x15 input into a 15x21 output, tiled 5x7 over the output's
+// rows and columns, and checks every element against its mirrored source.
+TEST(LoweringCPU, MatrixCPUTranspose) {
+    auto inputDS = AbstractDataTypePtr(new const annot::MatrixCPU("input"));
+    auto outputDS = AbstractDataTypePtr(new const annot::MatrixCPU("output"));
+
+    annot::MatrixTransposeCPU transpose;
+    Variable l_x("l_x");
+    Variable l_y("l_y");
+
+    Composable program = {
+        (Tile(outputDS["row"], l_x))(
+            Tile(outputDS["col"], l_y)(
+                transpose(inputDS, outputDS)))};
+
+    Runner run(program);
+
+    run.compile(test::cpuRunner(std::vector<std::string>{"matrix"}));
+
+    int64_t row_val_input = 21;
+    int64_t col_val_input = 15;
+    impl::MatrixCPU a(row_val_input, col_val_input, col_val_input);
+    a.ascending();
+    // The output has the input's dimensions swapped.
+    int64_t row_val_output = col_val_input;
+    int64_t col_val_output = row_val_input;
+    impl::MatrixCPU b(row_val_output, col_val_output, col_val_output);
+
+    int64_t l_x_val = 5;
+    int64_t l_y_val = 7;
+
+    ASSERT_NO_THROW(run.evaluate({
+        {inputDS.getName(), &a},
+        {outputDS.getName(), &b},
+        {l_x.getName(), &l_x_val},
+        {l_y.getName(), &l_y_val},
+    }));
+
+    for (int64_t i = 0; i < row_val_input; i++) {
+        for (int64_t j = 0; j < col_val_input; j++) {
+            ASSERT_EQ(a(i, j), b(j, i)); // note: a(i,j) = a.data[i * a.lda + j]
+        }
+    }
+
+    a.destroy();
+    b.destroy();
+}