Skip to content

Commit

Permalink
Merge pull request #2 from ROCmSoftwarePlatform/mlir-rocblas-succeeds
Browse files Browse the repository at this point in the history
inc files added
  • Loading branch information
weihanmines authored Jan 19, 2022
2 parents dc0829b + 8563310 commit d9b2dec
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 92 deletions.
33 changes: 16 additions & 17 deletions third_party/hip/hip_stub.cc.inc
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,6 @@ hipError_t hipRuntimeGetVersion(int* runtimeVersion) {
"hipRuntimeGetVersion", runtimeVersion);
}

// Stub for hipGetLastError: forwards to DynamicCall under the symbol name
// "hipGetLastError" and returns DynamicCall's result unchanged.
// NOTE(review): DynamicCall is defined outside this view; presumably it
// resolves the named symbol from the HIP runtime library -- confirm.
hipError_t hipGetLastError(void) {
  using StubSignature = decltype(hipGetLastError);
  return DynamicCall<StubSignature, &hipGetLastError>("hipGetLastError");
}

// Stub for hipPeekAtLastError: forwards to DynamicCall under the symbol name
// "hipPeekAtLastError" and returns DynamicCall's result unchanged.
// NOTE(review): DynamicCall is defined outside this view; presumably it
// resolves the named symbol from the HIP runtime library -- confirm.
hipError_t hipPeekAtLastError(void) {
  using StubSignature = decltype(hipPeekAtLastError);
  return DynamicCall<StubSignature, &hipPeekAtLastError>(
      "hipPeekAtLastError");
}

hipError_t hipDeviceGet(hipDevice_t* device, int ordinal) {
return DynamicCall<decltype(hipDeviceGet), &hipDeviceGet>("hipDeviceGet",
device, ordinal);
Expand Down Expand Up @@ -73,6 +63,16 @@ hipError_t hipDeviceGetLimit(size_t* pValue, enum hipLimit_t limit) {
"hipDeviceGetLimit", pValue, limit);
}

// Stub for hipGetLastError: forwards to DynamicCall under the symbol name
// "hipGetLastError" and returns DynamicCall's result unchanged.
// NOTE(review): DynamicCall is defined outside this view; presumably it
// resolves the named symbol from the HIP runtime library -- confirm.
hipError_t hipGetLastError(void) {
  using StubSignature = decltype(hipGetLastError);
  return DynamicCall<StubSignature, &hipGetLastError>("hipGetLastError");
}

// Stub for hipPeekAtLastError: forwards to DynamicCall under the symbol name
// "hipPeekAtLastError" and returns DynamicCall's result unchanged.
// NOTE(review): DynamicCall is defined outside this view; presumably it
// resolves the named symbol from the HIP runtime library -- confirm.
hipError_t hipPeekAtLastError(void) {
  using StubSignature = decltype(hipPeekAtLastError);
  return DynamicCall<StubSignature, &hipPeekAtLastError>(
      "hipPeekAtLastError");
}

hipError_t hipStreamCreateWithFlags(hipStream_t* stream, unsigned int flags) {
return DynamicCall<decltype(hipStreamCreateWithFlags),
&hipStreamCreateWithFlags>("hipStreamCreateWithFlags",
Expand Down Expand Up @@ -209,6 +209,12 @@ hipError_t hipMemcpy(void* dst, const void* src, size_t sizeBytes,
sizeBytes, kind);
}

// Stub for hipModuleGetGlobal: passes dptr, bytes, hmod and name, in that
// order, to DynamicCall under the symbol name "hipModuleGetGlobal" and
// returns DynamicCall's result unchanged.
// NOTE(review): DynamicCall is defined outside this view; presumably it
// resolves the named symbol from the HIP runtime library -- confirm.
hipError_t hipModuleGetGlobal(hipDeviceptr_t* dptr, size_t* bytes,
                              hipModule_t hmod, const char* name) {
  using StubSignature = decltype(hipModuleGetGlobal);
  return DynamicCall<StubSignature, &hipModuleGetGlobal>(
      "hipModuleGetGlobal", dptr, bytes, hmod, name);
}

hipError_t hipMemcpyAsync(void* dst, const void* src, size_t sizeBytes,
hipMemcpyKind kind, hipStream_t stream __dparm(0)) {
return DynamicCall<decltype(hipMemcpyAsync), &hipMemcpyAsync>(
Expand Down Expand Up @@ -376,13 +382,6 @@ hipError_t hipModuleGetFunction(hipFunction_t* function, hipModule_t module,
"hipModuleGetFunction", function, module, kname);
}

// Stub for hipModuleGetGlobal (void** form): passes ptr, bytes, module and
// kname, in that order, to DynamicCall under the symbol name
// "hipModuleGetGlobal" and returns DynamicCall's result unchanged.
// NOTE(review): DynamicCall is defined outside this view; presumably it
// resolves the named symbol from the HIP runtime library -- confirm.
hipError_t hipModuleGetGlobal(void** ptr, size_t* bytes, hipModule_t module,
                              const char* kname) {
  using StubSignature = decltype(hipModuleGetGlobal);
  return DynamicCall<StubSignature, &hipModuleGetGlobal>(
      "hipModuleGetGlobal", ptr, bytes, module, kname);
}


hipError_t hipFuncGetAttributes(struct hipFuncAttributes* attr,
const void* func) {
return DynamicCall<decltype(hipFuncGetAttributes), &hipFuncGetAttributes>(
Expand Down
34 changes: 27 additions & 7 deletions third_party/hip/hip_stub.h.inc
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,22 @@ enum hipError_t {
hipErrorPeerAccessAlreadyEnabled = 704,
hipErrorPeerAccessNotEnabled = 705,
hipErrorSetOnActiveProcess = 708,
hipErrorContextIsDestroyed = 709,
hipErrorAssert = 710,
hipErrorHostMemoryAlreadyRegistered = 712,
hipErrorHostMemoryNotRegistered = 713,
hipErrorLaunchFailure = 719,
hipErrorCooperativeLaunchTooLarge = 720,
hipErrorNotSupported = 801,
hipErrorStreamCaptureUnsupported = 900,
hipErrorStreamCaptureInvalidated = 901,
hipErrorStreamCaptureMerge = 902,
hipErrorStreamCaptureUnmatched = 903,
hipErrorStreamCaptureUnjoined = 904,
hipErrorStreamCaptureIsolation = 905,
hipErrorStreamCaptureImplicit = 906,
hipErrorCapturedEvent = 907,
hipErrorStreamCaptureWrongThread = 908,
hipErrorUnknown = 999,
hipErrorRuntimeMemory = 1052,
hipErrorRuntimeOther = 1053,
Expand Down Expand Up @@ -153,6 +163,7 @@ enum hipFunction_attribute {
};

// Selector passed as the `limit` argument of hipDeviceGetLimit.
// NOTE(review): the enumerator names suggest "printf FIFO size" and
// "malloc heap size"; confirm the exact meaning against the HIP headers.
enum hipLimit_t {
  hipLimitPrintfFifoSize = 0x01,
  hipLimitMallocHeapSize = 0x02,
};

Expand Down Expand Up @@ -196,10 +207,6 @@ hipError_t hipDriverGetVersion(int* driverVersion);

// Stub declaration for hipRuntimeGetVersion.
// NOTE(review): presumably reports the runtime version through
// `runtimeVersion`; confirm against the HIP API reference.
hipError_t hipRuntimeGetVersion(int* runtimeVersion);

// Stub declaration for hipGetLastError; takes no arguments and returns a
// hipError_t.
hipError_t hipGetLastError(void);

// Stub declaration for hipPeekAtLastError; takes no arguments and returns a
// hipError_t.
hipError_t hipPeekAtLastError(void);

// Stub declaration for hipDeviceGet.
// NOTE(review): presumably writes the device handle for index `ordinal`
// into `device`; confirm against the HIP API reference.
hipError_t hipDeviceGet(hipDevice_t* device, int ordinal);

// Stub declaration for hipDeviceGetName.
// NOTE(review): presumably fills `name` (capacity `len`) for `device`;
// confirm against the HIP API reference.
hipError_t hipDeviceGetName(char* name, int len, hipDevice_t device);
Expand All @@ -221,6 +228,10 @@ hipError_t hipGetDeviceProperties(hipDeviceProp_t* prop, int deviceId);

// Stub declaration for hipDeviceGetLimit; `limit` is a hipLimit_t selector.
// NOTE(review): presumably writes the selected limit into `pValue`;
// confirm against the HIP API reference.
hipError_t hipDeviceGetLimit(size_t* pValue, enum hipLimit_t limit);

// Stub declaration for hipGetLastError; takes no arguments and returns a
// hipError_t.
hipError_t hipGetLastError(void);

// Stub declaration for hipPeekAtLastError; takes no arguments and returns a
// hipError_t.
hipError_t hipPeekAtLastError(void);

// Stub declaration for hipStreamCreateWithFlags.
// NOTE(review): presumably writes the new stream handle into `stream`;
// confirm the accepted `flags` values against the HIP API reference.
hipError_t hipStreamCreateWithFlags(hipStream_t* stream, unsigned int flags);

hipError_t hipStreamCreateWithPriority(hipStream_t* stream, unsigned int flags,
Expand Down Expand Up @@ -278,6 +289,9 @@ hipError_t hipHostFree(void* ptr);
// Stub declaration for hipMemcpy.
// NOTE(review): presumably copies `sizeBytes` bytes from `src` to `dst`
// in the direction given by `kind`; confirm against the HIP API reference.
hipError_t hipMemcpy(void* dst, const void* src, size_t sizeBytes,
hipMemcpyKind kind);

// Stub declaration for hipModuleGetGlobal (hipDeviceptr_t* form).
// NOTE(review): presumably looks up the global `name` in module `hmod` and
// reports its address and size via `dptr` / `bytes`; confirm.
hipError_t hipModuleGetGlobal(hipDeviceptr_t* dptr, size_t* bytes,
hipModule_t hmod, const char* name);

// Stub declaration for hipMemcpyAsync.
// NOTE(review): __dparm is a macro defined outside this view; presumably
// it supplies a default argument of 0 for `stream` -- confirm.
hipError_t hipMemcpyAsync(void* dst, const void* src, size_t sizeBytes,
hipMemcpyKind kind, hipStream_t stream __dparm(0));

Expand Down Expand Up @@ -353,9 +367,6 @@ hipError_t hipModuleUnload(hipModule_t module);
// Stub declaration for hipModuleGetFunction.
// NOTE(review): presumably looks up the kernel `kname` in `module` and
// writes its handle into `function`; confirm against the HIP API reference.
hipError_t hipModuleGetFunction(hipFunction_t* function, hipModule_t module,
const char* kname);

// Stub declaration for hipModuleGetGlobal (void** form).
// NOTE(review): presumably looks up the global `kname` in `module` and
// reports its address and size via `ptr` / `bytes`; confirm.
hipError_t hipModuleGetGlobal(void** ptr, size_t* bytes, hipModule_t module,
const char* kname);

// Stub declaration for hipFuncGetAttributes.
// NOTE(review): presumably fills `attr` with the attributes of `func`;
// confirm against the HIP API reference.
hipError_t hipFuncGetAttributes(struct hipFuncAttributes* attr,
const void* func);

Expand Down Expand Up @@ -385,3 +396,12 @@ hipError_t hipOccupancyMaxPotentialBlockSize(int* gridSize, int* blockSize,
const void* f,
size_t dynSharedMemPerBlk,
int blockSizeLimit);

// Data-type tags, listed in ascending numeric order.
// NOTE(review): the names suggest R = real, C = complex, and 16F/32F/64F =
// floating-point width in bits; confirm against HIP's library_types.h.
enum hipDataType {
  HIP_R_32F = 0,
  HIP_R_64F = 1,
  HIP_R_16F = 2,
  HIP_C_32F = 4,
  HIP_C_64F = 5,
  HIP_C_16F = 6,
};
1 change: 1 addition & 0 deletions third_party/hip/miopen_stub.h.inc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ typedef enum {
miopenInt8 = 3,
miopenInt8x4 = 4,
miopenBFloat16 = 5,
miopenDouble = 6,
} miopenDataType_t;

typedef enum {
Expand Down
84 changes: 42 additions & 42 deletions third_party/hip/rocblas_stub.cc.inc
Original file line number Diff line number Diff line change
Expand Up @@ -34,48 +34,6 @@ ROCBLAS_EXPORT rocblas_status rocblas_get_pointer_mode(
handle, pointer_mode);
}

// Stub for rocblas_gemm_ex: passes every argument, in declaration order, to
// DynamicCall under the symbol name "rocblas_gemm_ex" and returns
// DynamicCall's result unchanged.
// NOTE(review): DynamicCall is defined outside this view; presumably it
// resolves the named symbol from the rocBLAS library -- confirm.
ROCBLAS_EXPORT rocblas_status rocblas_gemm_ex(
    rocblas_handle handle, rocblas_operation transA, rocblas_operation transB,
    rocblas_int m, rocblas_int n, rocblas_int k, const void* alpha,
    const void* a, rocblas_datatype a_type, rocblas_int lda, const void* b,
    rocblas_datatype b_type, rocblas_int ldb, const void* beta, const void* c,
    rocblas_datatype c_type, rocblas_int ldc, void* d, rocblas_datatype d_type,
    rocblas_int ldd, rocblas_datatype compute_type, rocblas_gemm_algo algo,
    int32_t solution_index, uint32_t flags) {
  using StubSignature = decltype(rocblas_gemm_ex);
  return DynamicCall<StubSignature, &rocblas_gemm_ex>(
      "rocblas_gemm_ex", handle, transA, transB, m, n, k, alpha,
      a, a_type, lda,
      b, b_type, ldb,
      beta,
      c, c_type, ldc,
      d, d_type, ldd,
      compute_type, algo, solution_index, flags);
}

// Stub for rocblas_gemm_strided_batched_ex: passes every argument, in
// declaration order, to DynamicCall under the symbol name
// "rocblas_gemm_strided_batched_ex" and returns DynamicCall's result
// unchanged.
// NOTE(review): DynamicCall is defined outside this view; presumably it
// resolves the named symbol from the rocBLAS library -- confirm.
ROCBLAS_EXPORT rocblas_status rocblas_gemm_strided_batched_ex(
    rocblas_handle handle, rocblas_operation transA, rocblas_operation transB,
    rocblas_int m, rocblas_int n, rocblas_int k, const void* alpha,
    const void* a, rocblas_datatype a_type, rocblas_int lda,
    rocblas_stride stride_a, const void* b, rocblas_datatype b_type,
    rocblas_int ldb, rocblas_stride stride_b, const void* beta, const void* c,
    rocblas_datatype c_type, rocblas_int ldc, rocblas_stride stride_c, void* d,
    rocblas_datatype d_type, rocblas_int ldd, rocblas_stride stride_d,
    rocblas_int batch_count, rocblas_datatype compute_type,
    rocblas_gemm_algo algo, int32_t solution_index, uint32_t flags) {
  using StubSignature = decltype(rocblas_gemm_strided_batched_ex);
  return DynamicCall<StubSignature, &rocblas_gemm_strided_batched_ex>(
      "rocblas_gemm_strided_batched_ex", handle, transA, transB, m, n, k,
      alpha,
      a, a_type, lda, stride_a,
      b, b_type, ldb, stride_b,
      beta,
      c, c_type, ldc, stride_c,
      d, d_type, ldd, stride_d,
      batch_count, compute_type, algo, solution_index, flags);
}

// Stub for rocblas_axpy_ex: passes every argument, in declaration order, to
// DynamicCall under the symbol name "rocblas_axpy_ex" and returns
// DynamicCall's result unchanged.
// NOTE(review): DynamicCall is defined outside this view; presumably it
// resolves the named symbol from the rocBLAS library -- confirm.
ROCBLAS_EXPORT rocblas_status rocblas_axpy_ex(
    rocblas_handle handle, rocblas_int n, const void* alpha,
    rocblas_datatype alpha_type, const void* x, rocblas_datatype x_type,
    rocblas_int incx, void* y, rocblas_datatype y_type, rocblas_int incy,
    rocblas_datatype execution_type) {
  using StubSignature = decltype(rocblas_axpy_ex);
  return DynamicCall<StubSignature, &rocblas_axpy_ex>(
      "rocblas_axpy_ex", handle, n,
      alpha, alpha_type,
      x, x_type, incx,
      y, y_type, incy,
      execution_type);
}

ROCBLAS_EXPORT rocblas_status rocblas_strsm_batched(
rocblas_handle handle, rocblas_side side, rocblas_fill uplo,
rocblas_operation transA, rocblas_diagonal diag, rocblas_int m,
Expand Down Expand Up @@ -120,3 +78,45 @@ ROCBLAS_EXPORT rocblas_status rocblas_ztrsm_batched(
"rocblas_ztrsm_batched", handle, side, uplo, transA, diag, m, n, alpha, A,
lda, B, ldb, batch_count);
}

// Stub for rocblas_gemm_ex: passes every argument, in declaration order, to
// DynamicCall under the symbol name "rocblas_gemm_ex" and returns
// DynamicCall's result unchanged.
// NOTE(review): DynamicCall is defined outside this view; presumably it
// resolves the named symbol from the rocBLAS library -- confirm.
ROCBLAS_EXPORT rocblas_status rocblas_gemm_ex(
    rocblas_handle handle, rocblas_operation transA, rocblas_operation transB,
    rocblas_int m, rocblas_int n, rocblas_int k, const void* alpha,
    const void* a, rocblas_datatype a_type, rocblas_int lda, const void* b,
    rocblas_datatype b_type, rocblas_int ldb, const void* beta, const void* c,
    rocblas_datatype c_type, rocblas_int ldc, void* d, rocblas_datatype d_type,
    rocblas_int ldd, rocblas_datatype compute_type, rocblas_gemm_algo algo,
    int32_t solution_index, uint32_t flags) {
  using StubSignature = decltype(rocblas_gemm_ex);
  return DynamicCall<StubSignature, &rocblas_gemm_ex>(
      "rocblas_gemm_ex", handle, transA, transB, m, n, k, alpha,
      a, a_type, lda,
      b, b_type, ldb,
      beta,
      c, c_type, ldc,
      d, d_type, ldd,
      compute_type, algo, solution_index, flags);
}

// Stub for rocblas_gemm_strided_batched_ex: passes every argument, in
// declaration order, to DynamicCall under the symbol name
// "rocblas_gemm_strided_batched_ex" and returns DynamicCall's result
// unchanged.
// NOTE(review): DynamicCall is defined outside this view; presumably it
// resolves the named symbol from the rocBLAS library -- confirm.
ROCBLAS_EXPORT rocblas_status rocblas_gemm_strided_batched_ex(
    rocblas_handle handle, rocblas_operation transA, rocblas_operation transB,
    rocblas_int m, rocblas_int n, rocblas_int k, const void* alpha,
    const void* a, rocblas_datatype a_type, rocblas_int lda,
    rocblas_stride stride_a, const void* b, rocblas_datatype b_type,
    rocblas_int ldb, rocblas_stride stride_b, const void* beta, const void* c,
    rocblas_datatype c_type, rocblas_int ldc, rocblas_stride stride_c, void* d,
    rocblas_datatype d_type, rocblas_int ldd, rocblas_stride stride_d,
    rocblas_int batch_count, rocblas_datatype compute_type,
    rocblas_gemm_algo algo, int32_t solution_index, uint32_t flags) {
  using StubSignature = decltype(rocblas_gemm_strided_batched_ex);
  return DynamicCall<StubSignature, &rocblas_gemm_strided_batched_ex>(
      "rocblas_gemm_strided_batched_ex", handle, transA, transB, m, n, k,
      alpha,
      a, a_type, lda, stride_a,
      b, b_type, ldb, stride_b,
      beta,
      c, c_type, ldc, stride_c,
      d, d_type, ldd, stride_d,
      batch_count, compute_type, algo, solution_index, flags);
}

// Stub for rocblas_axpy_ex: passes every argument, in declaration order, to
// DynamicCall under the symbol name "rocblas_axpy_ex" and returns
// DynamicCall's result unchanged.
// NOTE(review): DynamicCall is defined outside this view; presumably it
// resolves the named symbol from the rocBLAS library -- confirm.
ROCBLAS_EXPORT rocblas_status rocblas_axpy_ex(
    rocblas_handle handle, rocblas_int n, const void* alpha,
    rocblas_datatype alpha_type, const void* x, rocblas_datatype x_type,
    rocblas_int incx, void* y, rocblas_datatype y_type, rocblas_int incy,
    rocblas_datatype execution_type) {
  using StubSignature = decltype(rocblas_axpy_ex);
  return DynamicCall<StubSignature, &rocblas_axpy_ex>(
      "rocblas_axpy_ex", handle, n,
      alpha, alpha_type,
      x, x_type, incx,
      y, y_type, incy,
      execution_type);
}
53 changes: 27 additions & 26 deletions third_party/hip/rocblas_stub.h.inc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ typedef enum rocblas_gemm_algo_ {
// Named values for rocBLAS GEMM flags.
// NOTE(review): presumably these are the values carried by the
// `uint32_t flags` argument of the *_gemm_ex stubs; the 0x1 / 0x2 pattern
// suggests independent bits -- confirm against the rocBLAS headers.
typedef enum rocblas_gemm_flags_ {
  rocblas_gemm_flags_none = 0x0,
  rocblas_gemm_flags_pack_int8x4 = 0x1,
  rocblas_gemm_flags_use_cu_efficiency = 0x2,
} rocblas_gemm_flags;

// Stub declaration for rocblas_create_handle.
// NOTE(review): presumably writes a newly created handle into `handle`;
// confirm against the rocBLAS API reference.
ROCBLAS_EXPORT rocblas_status rocblas_create_handle(rocblas_handle* handle);
Expand All @@ -89,32 +90,6 @@ ROCBLAS_EXPORT rocblas_status rocblas_set_pointer_mode(
// Stub declaration for rocblas_get_pointer_mode.
// NOTE(review): presumably reports the handle's pointer mode through
// `pointer_mode`; confirm against the rocBLAS API reference.
ROCBLAS_EXPORT rocblas_status rocblas_get_pointer_mode(
rocblas_handle handle, rocblas_pointer_mode* pointer_mode);

// Stub declaration for rocblas_gemm_ex. Operands a, b, c and d each carry a
// rocblas_datatype tag and a leading-dimension argument (lda/ldb/ldc/ldd).
// NOTE(review): see the rocBLAS API reference for the meaning of
// `compute_type`, `algo`, `solution_index` and `flags`.
ROCBLAS_EXPORT rocblas_status rocblas_gemm_ex(
rocblas_handle handle, rocblas_operation transA, rocblas_operation transB,
rocblas_int m, rocblas_int n, rocblas_int k, const void* alpha,
const void* a, rocblas_datatype a_type, rocblas_int lda, const void* b,
rocblas_datatype b_type, rocblas_int ldb, const void* beta, const void* c,
rocblas_datatype c_type, rocblas_int ldc, void* d, rocblas_datatype d_type,
rocblas_int ldd, rocblas_datatype compute_type, rocblas_gemm_algo algo,
int32_t solution_index, uint32_t flags);

// Stub declaration for rocblas_gemm_strided_batched_ex. Same parameter list
// as rocblas_gemm_ex plus a rocblas_stride per operand (stride_a..stride_d)
// and a `batch_count`.
ROCBLAS_EXPORT rocblas_status rocblas_gemm_strided_batched_ex(
rocblas_handle handle, rocblas_operation transA, rocblas_operation transB,
rocblas_int m, rocblas_int n, rocblas_int k, const void* alpha,
const void* a, rocblas_datatype a_type, rocblas_int lda,
rocblas_stride stride_a, const void* b, rocblas_datatype b_type,
rocblas_int ldb, rocblas_stride stride_b, const void* beta, const void* c,
rocblas_datatype c_type, rocblas_int ldc, rocblas_stride stride_c, void* d,
rocblas_datatype d_type, rocblas_int ldd, rocblas_stride stride_d,
rocblas_int batch_count, rocblas_datatype compute_type,
rocblas_gemm_algo algo, int32_t solution_index, uint32_t flags);

// Stub declaration for rocblas_axpy_ex. alpha, x and y each carry a
// rocblas_datatype tag; x and y also take an increment (incx/incy).
// NOTE(review): see the rocBLAS API reference for `execution_type`.
ROCBLAS_EXPORT rocblas_status rocblas_axpy_ex(
rocblas_handle handle, rocblas_int n, const void* alpha,
rocblas_datatype alpha_type, const void* x, rocblas_datatype x_type,
rocblas_int incx, void* y, rocblas_datatype y_type, rocblas_int incy,
rocblas_datatype execution_type);

ROCBLAS_EXPORT rocblas_status rocblas_strsm_batched(
rocblas_handle handle, rocblas_side side, rocblas_fill uplo,
rocblas_operation transA, rocblas_diagonal diag, rocblas_int m,
Expand Down Expand Up @@ -142,3 +117,29 @@ ROCBLAS_EXPORT rocblas_status rocblas_ztrsm_batched(
const rocblas_double_complex* const A[], rocblas_int lda,
rocblas_double_complex* const B[], rocblas_int ldb,
rocblas_int batch_count);

// Stub declaration for rocblas_gemm_ex. Operands a, b, c and d each carry a
// rocblas_datatype tag and a leading-dimension argument (lda/ldb/ldc/ldd).
// NOTE(review): see the rocBLAS API reference for the meaning of
// `compute_type`, `algo`, `solution_index` and `flags`.
ROCBLAS_EXPORT rocblas_status rocblas_gemm_ex(
rocblas_handle handle, rocblas_operation transA, rocblas_operation transB,
rocblas_int m, rocblas_int n, rocblas_int k, const void* alpha,
const void* a, rocblas_datatype a_type, rocblas_int lda, const void* b,
rocblas_datatype b_type, rocblas_int ldb, const void* beta, const void* c,
rocblas_datatype c_type, rocblas_int ldc, void* d, rocblas_datatype d_type,
rocblas_int ldd, rocblas_datatype compute_type, rocblas_gemm_algo algo,
int32_t solution_index, uint32_t flags);

// Stub declaration for rocblas_gemm_strided_batched_ex. Same parameter list
// as rocblas_gemm_ex plus a rocblas_stride per operand (stride_a..stride_d)
// and a `batch_count`.
ROCBLAS_EXPORT rocblas_status rocblas_gemm_strided_batched_ex(
rocblas_handle handle, rocblas_operation transA, rocblas_operation transB,
rocblas_int m, rocblas_int n, rocblas_int k, const void* alpha,
const void* a, rocblas_datatype a_type, rocblas_int lda,
rocblas_stride stride_a, const void* b, rocblas_datatype b_type,
rocblas_int ldb, rocblas_stride stride_b, const void* beta, const void* c,
rocblas_datatype c_type, rocblas_int ldc, rocblas_stride stride_c, void* d,
rocblas_datatype d_type, rocblas_int ldd, rocblas_stride stride_d,
rocblas_int batch_count, rocblas_datatype compute_type,
rocblas_gemm_algo algo, int32_t solution_index, uint32_t flags);

// Stub declaration for rocblas_axpy_ex. alpha, x and y each carry a
// rocblas_datatype tag; x and y also take an increment (incx/incy).
// NOTE(review): see the rocBLAS API reference for `execution_type`.
ROCBLAS_EXPORT rocblas_status rocblas_axpy_ex(
rocblas_handle handle, rocblas_int n, const void* alpha,
rocblas_datatype alpha_type, const void* x, rocblas_datatype x_type,
rocblas_int incx, void* y, rocblas_datatype y_type, rocblas_int incy,
rocblas_datatype execution_type);

0 comments on commit d9b2dec

Please sign in to comment.