Skip to content

Commit

Permalink
Pull request OpenMathLib#118: Fix gemmt
Browse files Browse the repository at this point in the history
Merge in PL/openblas from dev/k.zaytseva/LM-538 to dev-riscv
  • Loading branch information
Kseniya Zaytseva authored and kseniyazaytseva committed Dec 18, 2023
1 parent 8f23df4 commit df731dd
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 19 deletions.
8 changes: 8 additions & 0 deletions cblas.h
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,14 @@ void cblas_zgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLA
void cblas_zgemm3m(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
OPENBLAS_CONST void *alpha, OPENBLAS_CONST void *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST void *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST void *beta, void *C, OPENBLAS_CONST blasint ldc);

void cblas_sgemmt(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_UPLO Uplo, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint K,
OPENBLAS_CONST float alpha, OPENBLAS_CONST float *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST float *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc);
void cblas_dgemmt(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_UPLO Uplo, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint K,
OPENBLAS_CONST double alpha, OPENBLAS_CONST double *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST double *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST double beta, double *C, OPENBLAS_CONST blasint ldc);
void cblas_cgemmt(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_UPLO Uplo, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint K,
OPENBLAS_CONST void *alpha, OPENBLAS_CONST void *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST void *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST void *beta, void *C, OPENBLAS_CONST blasint ldc);
void cblas_zgemmt(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_UPLO Uplo, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint K,
OPENBLAS_CONST void *alpha, OPENBLAS_CONST void *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST void *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST void *beta, void *C, OPENBLAS_CONST blasint ldc);

void cblas_ssymm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_SIDE Side, OPENBLAS_CONST enum CBLAS_UPLO Uplo, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N,
OPENBLAS_CONST float alpha, OPENBLAS_CONST float *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST float *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc);
Expand Down
9 changes: 9 additions & 0 deletions common_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,15 @@ void BLASFUNC(zgemm3m)(char *, char *, blasint *, blasint *, blasint *, double *
void BLASFUNC(xgemm3m)(char *, char *, blasint *, blasint *, blasint *, xdouble *,
xdouble *, blasint *, xdouble *, blasint *, xdouble *, xdouble *, blasint *);

void BLASFUNC(sgemmt)(char*, char *, char *, blasint *, blasint *, float *,
float *, blasint *, float *, blasint *, float *, float *, blasint *);
void BLASFUNC(dgemmt)(char*, char *, char *, blasint *, blasint *, double *,
double *, blasint *, double *, blasint *, double *, double *, blasint *);
void BLASFUNC(cgemmt)(char*, char *, char *, blasint *, blasint *, float *,
float *, blasint *, float *, blasint *, float *, float *, blasint *);
void BLASFUNC(zgemmt)(char*, char *, char *, blasint *, blasint *, double *,
double *, blasint *, double *, blasint *, double *, double *, blasint *);

int BLASFUNC(sge2mm)(char *, char *, char *, blasint *, blasint *,
float *, float *, blasint *, float *, blasint *,
float *, float *, blasint *);
Expand Down
116 changes: 97 additions & 19 deletions interface/gemmt.c
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB,
blasint info;

char transA, transB, Uplo;
blasint nrowa, nrowb;
#if defined(COMPLEX)
blasint ncolb;
#endif
IFLOAT *buffer;
IFLOAT *aa, *bb;
FLOAT *cc;
Expand Down Expand Up @@ -159,12 +163,29 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB,
if (Uplo == 'L')
uplo = 1;

nrowa = m;
if (transa & 1) nrowa = k;
nrowb = k;
#if defined(COMPLEX)
ncolb = m;
#endif
if (transb & 1) {
nrowb = m;
#if defined(COMPLEX)
ncolb = k;
#endif
}

info = 0;

if (uplo < 0)
info = 14;
if (ldc < m)
info = 13;
if (ldb < MAX(1, nrowb))
info = 10;
if (lda < MAX(1, nrowa))
info = 8;
if (k < 0)
info = 5;
if (n < 0)
Expand Down Expand Up @@ -207,6 +228,9 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
blasint info;
blasint m, n, lda, ldb;
FLOAT *a, *b;
#if defined(COMPLEX)
blasint nrowb, ncolb;
#endif
XFLOAT *buffer;

PRINT_DEBUG_CNAME;
Expand Down Expand Up @@ -258,7 +282,24 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,

info = -1;

if (ldc < m)
blasint nrowa;
#if !defined(COMPLEX)
blasint nrowb;
#endif
nrowa = m;
if (transa & 1) nrowa = k;
nrowb = k;
#if defined(COMPLEX)
ncolb = m;
#endif
if (transb & 1) {
nrowb = m;
#if defined(COMPLEX)
ncolb = k;
#endif
}

if (ldc < MAX(1, m))
info = 13;
if (k < 0)
info = 5;
Expand Down Expand Up @@ -315,17 +356,39 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,

info = -1;

if (ldc < m)
blasint ncola;
#if !defined(COMPLEX)
blasint ncolb;
#endif
ncola = m;
if (transa & 1) ncola = k;
ncolb = k;
#if defined(COMPLEX)
nrowb = m;
#endif

if (transb & 1) {
#if defined(COMPLEX)
nrowb = k;
#endif
ncolb = m;
}

if (ldc < MAX(1,m))
info = 13;
if (ldb < MAX(1, ncolb))
info = 8;
if (lda < MAX(1, ncola))
info = 10;
if (k < 0)
info = 5;
if (n < 0)
info = 4;
if (m < 0)
info = 3;
if (transb < 0)
info = 2;
if (transa < 0)
info = 3;
if (uplo < 0)
info = 1;

}
Expand Down Expand Up @@ -412,9 +475,20 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,

IDEBUG_START;

FUNCTION_PROFILE_START();
#if defined(COMPLEX)
if (transb > 1){
#ifndef CBLAS
IMATCOPY_K_CNC(nrowb, ncolb, (FLOAT)(1.0), (FLOAT)(0.0), b, ldb);
#else
if (order == CblasColMajor)
IMATCOPY_K_CNC(nrowb, ncolb, (FLOAT)(1.0), (FLOAT)(0.0), b, ldb);
if (order == CblasRowMajor)
IMATCOPY_K_RNC(nrowb, ncolb, (FLOAT)(1.0), (FLOAT)(0.0), b, ldb);
#endif
}
#endif

const blasint incb = (transb == 0) ? 1 : ldb;
const blasint incb = ((transb & 1) == 0) ? 1 : ldb;

if (uplo == 1) {
for (i = 0; i < n; i++) {
Expand All @@ -424,18 +498,20 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
#if defined(COMPLEX)
aa = a + i * 2;
bb = b + i * ldb * 2;
if (transa) {
l = k;
if (transa & 1) {
aa = a + lda * i * 2;
}
if (transb & 1)
bb = b + i * 2;
}
cc = c + i * 2 * ldc + i * 2;
#else
aa = a + i;
bb = b + i * ldb;
if (transa) {
l = k;
if (transa & 1) {
aa = a + lda * i;
}
if (transb & 1)
bb = b + i;
}
cc = c + i * ldc + i;
Expand All @@ -447,7 +523,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
NULL, 0);

if (alpha_r == ZERO && alpha_i == ZERO)
return;
continue;
#else
if (beta != ONE)
SCAL_K(l, 0, 0, beta, cc, 1, NULL, 0, NULL, 0);
Expand Down Expand Up @@ -479,16 +555,18 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
#endif

#if defined(COMPLEX)
if (!(transa & 1))
(gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i,
aa, lda, bb, incb, cc, 1,
buffer);
#else
if (!(transa & 1))
(gemv[(int)transa]) (j, k, 0, alpha, aa, lda,
bb, incb, cc, 1, buffer);
#endif
#ifdef SMP
} else {

if (!(transa & 1))
(gemv_thread[(int)transa]) (j, k, alpha, aa,
lda, bb, incb, cc,
1, buffer,
Expand All @@ -507,15 +585,13 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
l = j;
#if defined COMPLEX
bb = b + i * ldb * 2;
if (transa) {
l = k;
if (transb & 1) {
bb = b + i * 2;
}
cc = c + i * 2 * ldc;
#else
bb = b + i * ldb;
if (transa) {
l = k;
if (transb & 1) {
bb = b + i;
}
cc = c + i * ldc;
Expand All @@ -527,7 +603,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
NULL, 0);

if (alpha_r == ZERO && alpha_i == ZERO)
return;
continue;
#else
if (beta != ONE)
SCAL_K(l, 0, 0, beta, cc, 1, NULL, 0, NULL, 0);
Expand Down Expand Up @@ -558,17 +634,19 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
#endif

#if defined(COMPLEX)
if (!(transa & 1))
(gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i,
a, lda, bb, incb, cc, 1,
buffer);
#else
if (!(transa & 1))
(gemv[(int)transa]) (j, k, 0, alpha, a, lda, bb,
incb, cc, 1, buffer);
#endif

#ifdef SMP
} else {

if (!(transa & 1))
(gemv_thread[(int)transa]) (j, k, alpha, a, lda,
bb, incb, cc, 1,
buffer, nthreads);
Expand All @@ -586,4 +664,4 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
IDEBUG_END;

return;
}
}

0 comments on commit df731dd

Please sign in to comment.