Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(#102): Implementation of CudaMatrixVectorProduct #102 #106

Merged
merged 1 commit into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions kernels/cuda/matrixVectorMultiplicationKernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
extern "C"
__global__ void matrixVectorMultiplicationKernel(double *matrix, double *vector, double *result, int numRows, int numCols) {
int row = blockIdx.x * blockDim.x + threadIdx.x;
if (row < numRows) {
double sum = 0.0;
for (int col = 0; col < numCols; ++col) {
sum += matrix[row * numCols + col] * vector[col];
}
result[row] = sum;
}
}
112 changes: 112 additions & 0 deletions lib/cuda_kernels/matrixVectorMultiplicationKernel.ptx
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-31833905
// Cuda compilation tools, release 11.8, V11.8.89
// Based on NVVM 7.0.1
//

.version 7.8
.target sm_52
.address_size 64

// .globl matrixVectorMultiplicationKernel

// Kernel entry point. Parameter order matches the .cu source:
// param_0 = matrix (f64*), param_1 = vector (f64*), param_2 = result (f64*),
// param_3 = numRows (s32), param_4 = numCols (s32).
.visible .entry matrixVectorMultiplicationKernel(
.param .u64 matrixVectorMultiplicationKernel_param_0,
.param .u64 matrixVectorMultiplicationKernel_param_1,
.param .u64 matrixVectorMultiplicationKernel_param_2,
.param .u32 matrixVectorMultiplicationKernel_param_3,
.param .u32 matrixVectorMultiplicationKernel_param_4
)
{
.reg .pred %p<7>;
.reg .b32 %r<25>;
.reg .f64 %fd<30>;
.reg .b64 %rd<28>;


// Load parameters: %rd15 = matrix, %rd16 = vector, %rd14 = result,
// %r12 = numRows, %r11 = numCols.
ld.param.u64 %rd15, [matrixVectorMultiplicationKernel_param_0];
ld.param.u64 %rd16, [matrixVectorMultiplicationKernel_param_1];
ld.param.u64 %rd14, [matrixVectorMultiplicationKernel_param_2];
ld.param.u32 %r12, [matrixVectorMultiplicationKernel_param_3];
ld.param.u32 %r11, [matrixVectorMultiplicationKernel_param_4];
// Convert vector (%rd1) and matrix (%rd2) pointers to global addresses.
cvta.to.global.u64 %rd1, %rd16;
cvta.to.global.u64 %rd2, %rd15;
// %r1 = blockIdx.x * blockDim.x + threadIdx.x  (the row this thread owns).
mov.u32 %r13, %ntid.x;
mov.u32 %r14, %ctaid.x;
mov.u32 %r15, %tid.x;
mad.lo.s32 %r1, %r14, %r13, %r15;
// row >= numRows: nothing to do, jump to exit.
setp.ge.s32 %p1, %r1, %r12;
@%p1 bra $L__BB0_9;

// numCols < 1: skip the loop entirely and store the zero accumulator.
setp.lt.s32 %p2, %r11, 1;
mov.f64 %fd29, 0d0000000000000000;   // %fd29 = running dot-product sum
@%p2 bra $L__BB0_8;

// The compiler unrolled the column loop by 4:
// %r24 = numCols & 3 (remainder iterations), and if numCols < 4 the
// unrolled main loop is skipped ($L__BB0_5 handles the remainder).
add.s32 %r17, %r11, -1;
and.b32 %r24, %r11, 3;
setp.lt.u32 %p3, %r17, 3;
mov.f64 %fd29, 0d0000000000000000;
mov.u32 %r23, 0;                     // %r23 = columns consumed so far
@%p3 bra $L__BB0_5;

// Main-loop setup: %r22 = numCols - remainder (multiple-of-4 trip count),
// %rd25 = &matrix[row*numCols] + 16 bytes (addressed with +/-16 offsets
// below), %rd24 = vector base.
sub.s32 %r22, %r11, %r24;
mul.lo.s32 %r19, %r11, %r1;
mul.wide.s32 %rd17, %r19, 8;
add.s64 %rd18, %rd2, %rd17;
add.s64 %rd25, %rd18, 16;
mov.u64 %rd24, %rd1;

// Unrolled-by-4 dot-product loop: four vector/matrix element pairs per
// iteration, folded into %fd29 via fused multiply-add; pointers advance
// 32 bytes (4 doubles).
$L__BB0_4:
ld.global.f64 %fd12, [%rd24];
ld.global.f64 %fd13, [%rd25+-16];
fma.rn.f64 %fd14, %fd13, %fd12, %fd29;
ld.global.f64 %fd15, [%rd24+8];
ld.global.f64 %fd16, [%rd25+-8];
fma.rn.f64 %fd17, %fd16, %fd15, %fd14;
ld.global.f64 %fd18, [%rd24+16];
ld.global.f64 %fd19, [%rd25];
fma.rn.f64 %fd20, %fd19, %fd18, %fd17;
ld.global.f64 %fd21, [%rd24+24];
ld.global.f64 %fd22, [%rd25+8];
fma.rn.f64 %fd29, %fd22, %fd21, %fd20;
add.s32 %r23, %r23, 4;
add.s64 %rd25, %rd25, 32;
add.s64 %rd24, %rd24, 32;
add.s32 %r22, %r22, -4;
setp.ne.s32 %p4, %r22, 0;
@%p4 bra $L__BB0_4;

// Remainder handling: skip straight to the store if numCols was a
// multiple of 4.
$L__BB0_5:
setp.eq.s32 %p5, %r24, 0;
@%p5 bra $L__BB0_8;

// Re-derive element pointers at column %r23 for the remainder loop:
// %rd27 = &vector[%r23], %rd26 = &matrix[row*numCols + %r23].
mul.wide.s32 %rd19, %r23, 8;
add.s64 %rd27, %rd1, %rd19;
mad.lo.s32 %r20, %r11, %r1, %r23;
mul.wide.s32 %rd20, %r20, 8;
add.s64 %rd26, %rd2, %rd20;

// Remainder loop: one element pair per iteration (at most 3 trips).
$L__BB0_7:
.pragma "nounroll";
ld.global.f64 %fd23, [%rd27];
ld.global.f64 %fd24, [%rd26];
fma.rn.f64 %fd29, %fd24, %fd23, %fd29;
add.s64 %rd27, %rd27, 8;
add.s64 %rd26, %rd26, 8;
add.s32 %r24, %r24, -1;
setp.ne.s32 %p6, %r24, 0;
@%p6 bra $L__BB0_7;

// Store the finished sum: result[row] = %fd29.
$L__BB0_8:
cvta.to.global.u64 %rd21, %rd14;
mul.wide.s32 %rd22, %r1, 8;
add.s64 %rd23, %rd21, %rd22;
st.global.f64 [%rd23], %fd29;

$L__BB0_9:
ret;

}

Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package de.edux.core.math.matrix.cuda;

import jcuda.driver.CUfunction;

/**
 * Implemented by CUDA-backed operations that execute a native kernel.
 *
 * <p>Implementations load their PTX module from disk and resolve the kernel
 * function handle that will later be passed to {@code cuLaunchKernel}.
 *
 * <p>Annotated {@code @FunctionalInterface} so the compiler rejects any
 * accidental second abstract method.
 */
@FunctionalInterface
public interface CUDAKernelUser {

  /**
   * Loads the PTX module for this operation and resolves its kernel function.
   *
   * @return the handle of the CUDA kernel to launch
   */
  CUfunction loadKernel();
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,27 @@ private void printCudaDeviceInformation() {

@Override
public double[][] multiply(double[][] matrixA, double[][] matrixB) {
  // Fail fast with descriptive messages before delegating to the
  // CUDA-backed implementation.
  requireMultipliable(matrixA, matrixB);
  return matrixProduct.multiply(matrixA, matrixB);
}

/** Validates that {@code a * b} is a well-defined matrix product. */
private static void requireMultipliable(double[][] a, double[][] b) {
  if (a == null || b == null) {
    throw new IllegalArgumentException("Matrices must not be null.");
  }
  if (a.length == 0 || b.length == 0) {
    throw new IllegalArgumentException("Matrices must not be empty.");
  }
  if (a[0].length != b.length) {
    throw new IllegalArgumentException("Matrix A columns must match Matrix B rows.");
  }
}

/**
 * Multiplies {@code matrix} by {@code vector} after validating the operands.
 *
 * @param matrix row-major matrix, must be non-null and non-empty
 * @param vector must be non-null and its length must equal the matrix column count
 * @return the product vector (one entry per matrix row)
 * @throws IllegalArgumentException if an operand is null, empty, or the
 *     dimensions do not match
 */
@Override
public double[] multiply(double[][] matrix, double[] vector) {
  // Consistency fix: the matrix-matrix overload rejects null with a
  // descriptive IllegalArgumentException; without this check a null
  // operand here surfaced as a bare NullPointerException instead.
  if (matrix == null || vector == null) {
    throw new IllegalArgumentException("Matrix and vector must not be null.");
  }
  if (matrix.length == 0 || matrix[0].length == 0 || vector.length == 0) {
    throw new IllegalArgumentException("Matrix and vector must not be empty.");
  }
  if (matrix[0].length != vector.length) {
    throw new IllegalArgumentException("Matrix columns and vector size do not match.");
  }
  return matrixVectorProduct.multiply(matrix, vector);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import static jcuda.driver.JCudaDriver.cuMemFree;

import de.edux.core.math.IMatrixProduct;
import de.edux.core.math.matrix.cuda.CUDAKernelUser;
import java.io.File;
import java.util.List;
import jcuda.Pointer;
import jcuda.Sizeof;
import jcuda.driver.*;

public class CudaMatrixProduct implements IMatrixProduct {
public class CudaMatrixProduct implements IMatrixProduct, CUDAKernelUser {

static {
JCudaDriver.setExceptionsEnabled(true);
Expand All @@ -31,12 +32,7 @@ public double[][] multiply(double[][] matrixA, double[][] matrixB) {
throw new IllegalArgumentException("Inner dimensions do not match.");
}

String ptxFileName = preparePtxFile();

CUmodule module = new CUmodule();
cuModuleLoad(module, ptxFileName);
CUfunction function = new CUfunction();
cuModuleGetFunction(function, module, "matrixMultiply");
CUfunction function = loadKernel();

double[] hostInputA = flatten(matrixA);
double[] hostInputB = flatten(matrixB);
Expand Down Expand Up @@ -113,7 +109,14 @@ private double[] flatten(double[][] matrix) {
return flat;
}

private String preparePtxFile() {
return "cuda_kernels" + File.separator + "matrixMultiplicationKernel.ptx";
/**
 * {@inheritDoc}
 *
 * <p>Loads the matrix-multiplication PTX module from the {@code cuda_kernels}
 * directory and resolves the {@code matrixMultiply} kernel.
 */
@Override
public CUfunction loadKernel() {
  CUmodule module = new CUmodule();
  cuModuleLoad(module, "cuda_kernels" + File.separator + "matrixMultiplicationKernel.ptx");
  CUfunction kernel = new CUfunction();
  cuModuleGetFunction(kernel, module, "matrixMultiply");
  return kernel;
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,110 @@
package de.edux.core.math.matrix.cuda.operations;

import static jcuda.driver.JCudaDriver.*;
import static jcuda.driver.JCudaDriver.cuMemFree;

import de.edux.core.math.IMatrixVectorProduct;
import de.edux.core.math.matrix.cuda.CUDAKernelUser;
import java.io.File;
import java.util.List;
import jcuda.Pointer;
import jcuda.Sizeof;
import jcuda.driver.*;
import jcuda.driver.CUfunction;

public class CudaMatrixVectorProduct implements IMatrixVectorProduct, CUDAKernelUser {

static {
JCudaDriver.setExceptionsEnabled(true);
cuInit(0);
CUdevice device = new CUdevice();
cuDeviceGet(device, 0);
CUcontext context = new CUcontext();
cuCtxCreate(context, 0, device);
}

public class CudaMatrixVectorProduct implements IMatrixVectorProduct {
@Override
public double[] multiply(double[][] matrix, double[] vector) {
return new double[0];
int numRows = matrix.length;
int numCols = matrix[0].length;

if (numCols != vector.length) {
throw new IllegalArgumentException("Matrix columns and vector size do not match.");
}

CUfunction function = loadKernel();

double[] hostMatrix = flatten(matrix);
double[] hostVector = vector.clone();
double[] hostOutput = new double[numRows];

CUdeviceptr deviceMatrix = new CUdeviceptr();
cuMemAlloc(deviceMatrix, hostMatrix.length * Sizeof.DOUBLE);
cuMemcpyHtoD(deviceMatrix, Pointer.to(hostMatrix), hostMatrix.length * Sizeof.DOUBLE);

CUdeviceptr deviceVector = new CUdeviceptr();
cuMemAlloc(deviceVector, hostVector.length * Sizeof.DOUBLE);
cuMemcpyHtoD(deviceVector, Pointer.to(hostVector), hostVector.length * Sizeof.DOUBLE);

CUdeviceptr deviceOutput = new CUdeviceptr();
cuMemAlloc(deviceOutput, hostOutput.length * Sizeof.DOUBLE);

Pointer kernelParameters =
Pointer.to(
Pointer.to(deviceMatrix),
Pointer.to(deviceVector),
Pointer.to(deviceOutput),
Pointer.to(new int[] {numRows}),
Pointer.to(new int[] {numCols}));

int blockSize = 256; // This should be tuned according to your hardware capability
int gridSize = (int) Math.ceil((double) numRows / blockSize);
cuLaunchKernel(
function,
gridSize,
1,
1, // Grid dimension
blockSize,
1,
1, // Block dimension
0,
null, // Shared memory size and stream
kernelParameters,
null // Kernel- and extra parameters
);
cuCtxSynchronize();

cuMemcpyDtoH(Pointer.to(hostOutput), deviceOutput, hostOutput.length * Sizeof.DOUBLE);

List<CUdeviceptr> list = List.of(deviceMatrix, deviceVector, deviceOutput);
cleanUp(list);

return hostOutput;
}

private void cleanUp(List<CUdeviceptr> devicePtrs) {
for (CUdeviceptr devicePtr : devicePtrs) {
cuMemFree(devicePtr);
}
}

private double[] flatten(double[][] matrix) {
int rows = matrix.length;
int cols = matrix[0].length;
double[] flat = new double[rows * cols];
for (int i = 0; i < rows; i++) {
System.arraycopy(matrix[i], 0, flat, i * cols, cols);
}
return flat;
}

@Override
public CUfunction loadKernel() {
String ptxFileName = "cuda_kernels" + File.separator + "matrixVectorMultiplicationKernel.ptx";
CUmodule module = new CUmodule();
cuModuleLoad(module, ptxFileName);
CUfunction function = new CUfunction();
cuModuleGetFunction(function, module, "matrixVectorMultiplicationKernel");
return function;
}
}
Loading