Skip to content

Commit

Permalink
Pull request OpenMathLib#118: Fix gemmt
Browse files: browse the repository at this point in the history
Merge in PL/openblas from dev/k.zaytseva/LM-538 to dev-riscv
  • Loading branch information
Kseniya Zaytseva authored and kseniyazaytseva committed Dec 18, 2023
1 parent 51cef97 commit 08f9517
Showing 1 changed file with 73 additions and 29 deletions.
102 changes: 73 additions & 29 deletions interface/gemmt.c
Original file line number | Diff line number | Diff line change
Expand Up @@ -160,17 +160,25 @@ void NAME(char *UPLO, char *TRANSA, char *TRANSB,
uplo = 1;

nrowa = m;
if (transa) nrowa = k;
if (transa & 1) nrowa = k;
nrowb = k;
if (transb) nrowb = m;
#if defined(COMPLEX)
ncolb = m;
#endif
if (transb & 1) {
nrowb = m;
#if defined(COMPLEX)
ncolb = k;
#endif
}

info = 0;

if (ldc < MAX(1, m))
info = 13;
if (ldb < MAX(1, nrowa))
if (ldb < MAX(1, nrowb))
info = 10;
if (lda < MAX(1, nrowb))
if (lda < MAX(1, nrowa))
info = 8;
if (k < 0)
info = 5;
Expand Down Expand Up @@ -268,11 +276,22 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,

info = -1;

blasint nrowa, nrowb;
blasint nrowa;
#if !defined(COMPLEX)
blasint nrowb;
#endif
nrowa = m;
if (transa) nrowa = k;
if (transa & 1) nrowa = k;
nrowb = k;
if (transb) nrowb = m;
#if defined(COMPLEX)
ncolb = m;
#endif
if (transb & 1) {
nrowb = m;
#if defined(COMPLEX)
ncolb = k;
#endif
}

if (ldc < MAX(1, m))
info = 13;
Expand Down Expand Up @@ -336,26 +355,38 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,

info = -1;

blasint ncola, ncolb;
ncola = k;
if (transa) ncola = m;
ncolb = m;
if (transb) ncolb = k;
blasint ncola;
#if !defined(COMPLEX)
blasint ncolb;
#endif
ncola = m;
if (transa & 1) ncola = k;
ncolb = k;
#if defined(COMPLEX)
nrowb = m;
#endif

if (transb & 1) {
#if defined(COMPLEX)
nrowb = k;
#endif
ncolb = m;
}

if (ldc < MAX(1,m))
info = 13;
if (ldb < MAX(1, ncolb))
info = 10;
if (lda < MAX(1, ncola))
info = 8;
if (lda < MAX(1, ncola))
info = 10;
if (k < 0)
info = 5;
if (m < 0)
info = 4;
if (transb < 0)
info = 3;
if (transa < 0)
info = 2;
if (transa < 0)
info = 3;
if (uplo < 0)
info = 1;
}
Expand Down Expand Up @@ -434,7 +465,20 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,

IDEBUG_START;

const blasint incb = (transb == 0) ? 1 : ldb;
#if defined(COMPLEX)
if (transb > 1){
#ifndef CBLAS
IMATCOPY_K_CNC(nrowb, ncolb, (FLOAT)(1.0), (FLOAT)(0.0), b, ldb);
#else
if (order == CblasColMajor)
IMATCOPY_K_CNC(nrowb, ncolb, (FLOAT)(1.0), (FLOAT)(0.0), b, ldb);
if (order == CblasRowMajor)
IMATCOPY_K_RNC(nrowb, ncolb, (FLOAT)(1.0), (FLOAT)(0.0), b, ldb);
#endif
}
#endif

const blasint incb = ((transb & 1) == 0) ? 1 : ldb;

if (uplo == 1) {
for (i = 0; i < m; i++) {
Expand All @@ -444,19 +488,19 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
#if defined(COMPLEX)
aa = a + i * 2;
bb = b + i * ldb * 2;
if (transa) {
if (transa & 1) {
aa = a + lda * i * 2;
}
if (transb)
if (transb & 1)
bb = b + i * 2;
cc = c + i * 2 * ldc + i * 2;
#else
aa = a + i;
bb = b + i * ldb;
if (transa) {
if (transa & 1) {
aa = a + lda * i;
}
if (transb)
if (transb & 1)
bb = b + i;
cc = c + i * ldc + i;
#endif
Expand Down Expand Up @@ -497,7 +541,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
#endif

#if defined(COMPLEX)
if (!transa)
if (!(transa & 1))
(gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i,
aa, lda, bb, incb, cc, 1,
buffer);
Expand All @@ -506,7 +550,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
aa, lda, bb, incb, cc, 1,
buffer);
#else
if (!transa)
if (!(transa & 1))
(gemv[(int)transa]) (j, k, 0, alpha, aa, lda,
bb, incb, cc, 1, buffer);
else
Expand All @@ -515,7 +559,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
#endif
#ifdef SMP
} else {
if (!transa)
if (!(transa & 1))
(gemv_thread[(int)transa]) (j, k, alpha, aa,
lda, bb, incb, cc,
1, buffer,
Expand All @@ -539,13 +583,13 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
l = j;
#if defined COMPLEX
bb = b + i * ldb * 2;
if (transb) {
if (transb & 1) {
bb = b + i * 2;
}
cc = c + i * 2 * ldc;
#else
bb = b + i * ldb;
if (transb) {
if (transb & 1) {
bb = b + i;
}
cc = c + i * ldc;
Expand Down Expand Up @@ -586,7 +630,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
#endif

#if defined(COMPLEX)
if (!transa)
if (!(transa & 1))
(gemv[(int)transa]) (j, k, 0, alpha_r, alpha_i,
a, lda, bb, incb, cc, 1,
buffer);
Expand All @@ -595,7 +639,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
a, lda, bb, incb, cc, 1,
buffer);
#else
if (!transa)
if (!(transa & 1))
(gemv[(int)transa]) (j, k, 0, alpha, a, lda, bb,
incb, cc, 1, buffer);
else
Expand All @@ -605,7 +649,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,

#ifdef SMP
} else {
if (!transa)
if (!(transa & 1))
(gemv_thread[(int)transa]) (j, k, alpha, a, lda,
bb, incb, cc, 1,
buffer, nthreads);
Expand Down

0 comments on commit 08f9517

Please sign in to comment.