@@ -82,6 +82,7 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d
8282}
8383
8484// C = a * A.? * B.? + b * C
85+ // Row-Major part
8586void BlasConnector::gemm (const char transa, const char transb, const int m, const int n, const int k,
8687 const float alpha, const float *a, const int lda, const float *b, const int ldb,
8788 const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type)
@@ -154,6 +155,147 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
154155 #endif
155156}
156157
158+ // Col-Major part
159+ void BlasConnector::gemm_cm (const char transa, const char transb, const int m, const int n, const int k,
160+ const float alpha, const float *a, const int lda, const float *b, const int ldb,
161+ const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type)
162+ {
163+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
164+ sgemm_ (&transa, &transb, &m, &n, &k,
165+ &alpha, a, &lda, b, &ldb,
166+ &beta, c, &ldc);
167+ }
168+ #ifdef __DSP
169+ else if (device_type == base_device::AbacusDevice_t::DspDevice){
170+ sgemm_mth_ (&transb, &transa, &m, &n, &k,
171+ &alpha, a, &lda, b, &ldb,
172+ &beta, c, &ldc, GlobalV::MY_RANK);
173+ }
174+ #endif
175+ }
176+
177+ void BlasConnector::gemm_cm (const char transa, const char transb, const int m, const int n, const int k,
178+ const double alpha, const double *a, const int lda, const double *b, const int ldb,
179+ const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type)
180+ {
181+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
182+ dgemm_ (&transa, &transb, &m, &n, &k,
183+ &alpha, a, &lda, b, &ldb,
184+ &beta, c, &ldc);
185+ }
186+ #ifdef __DSP
187+ else if (device_type == base_device::AbacusDevice_t::DspDevice){
188+ dgemm_mth_ (&transa, &transb, &m, &n, &k,
189+ &alpha, a, &lda, b, &ldb,
190+ &beta, c, &ldc, GlobalV::MY_RANK);
191+ }
192+ #endif
193+ }
194+
195+ void BlasConnector::gemm_cm (const char transa, const char transb, const int m, const int n, const int k,
196+ const std::complex <float > alpha, const std::complex <float > *a, const int lda, const std::complex <float > *b, const int ldb,
197+ const std::complex <float > beta, std::complex <float > *c, const int ldc, base_device::AbacusDevice_t device_type)
198+ {
199+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
200+ cgemm_ (&transa, &transb, &m, &n, &k,
201+ &alpha, a, &lda, b, &ldb,
202+ &beta, c, &ldc);
203+ }
204+ #ifdef __DSP
205+ else if (device_type == base_device::AbacusDevice_t::DspDevice) {
206+ cgemm_mth_ (&transa, &transb, &m, &n, &k,
207+ &alpha, a, &lda, b, &ldb,
208+ &beta, c, &ldc, GlobalV::MY_RANK);
209+ }
210+ #endif
211+ }
212+
213+ void BlasConnector::gemm_cm (const char transa, const char transb, const int m, const int n, const int k,
214+ const std::complex <double > alpha, const std::complex <double > *a, const int lda, const std::complex <double > *b, const int ldb,
215+ const std::complex <double > beta, std::complex <double > *c, const int ldc, base_device::AbacusDevice_t device_type)
216+ {
217+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
218+ zgemm_ (&transa, &transb, &m, &n, &k,
219+ &alpha, a, &lda, b, &ldb,
220+ &beta, c, &ldc);
221+ }
222+ #ifdef __DSP
223+ else if (device_type == base_device::AbacusDevice_t::DspDevice) {
224+ zgemm_mth_ (&transa, &transb, &m, &n, &k,
225+ &alpha, a, &lda, b, &ldb,
226+ &beta, c, &ldc, GlobalV::MY_RANK);
227+ }
228+ #endif
229+ }
230+
231+ // Symm and Hemm part. Only col-major is supported.
232+
233+ void BlasConnector::symm_cm (const char side, const char uplo, const int m, const int n,
234+ const float alpha, const float *a, const int lda, const float *b, const int ldb,
235+ const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type)
236+ {
237+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
238+ ssymm_ (&side, &uplo, &m, &n,
239+ &alpha, a, &lda, b, &ldb,
240+ &beta, c, &ldc);
241+ }
242+ }
243+
244+ void BlasConnector::symm_cm (const char side, const char uplo, const int m, const int n,
245+ const double alpha, const double *a, const int lda, const double *b, const int ldb,
246+ const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type)
247+ {
248+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
249+ dsymm_ (&side, &uplo, &m, &n,
250+ &alpha, a, &lda, b, &ldb,
251+ &beta, c, &ldc);
252+ }
253+ }
254+
255+ void BlasConnector::symm_cm (const char side, const char uplo, const int m, const int n,
256+ const std::complex <float > alpha, const std::complex <float > *a, const int lda, const std::complex <float > *b, const int ldb,
257+ const std::complex <float > beta, std::complex <float > *c, const int ldc, base_device::AbacusDevice_t device_type)
258+ {
259+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
260+ csymm_ (&side, &uplo, &m, &n,
261+ &alpha, a, &lda, b, &ldb,
262+ &beta, c, &ldc);
263+ }
264+ }
265+
266+ void BlasConnector::symm_cm (const char side, const char uplo, const int m, const int n,
267+ const std::complex <double > alpha, const std::complex <double > *a, const int lda, const std::complex <double > *b, const int ldb,
268+ const std::complex <double > beta, std::complex <double > *c, const int ldc, base_device::AbacusDevice_t device_type)
269+ {
270+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
271+ zsymm_ (&side, &uplo, &m, &n,
272+ &alpha, a, &lda, b, &ldb,
273+ &beta, c, &ldc);
274+ }
275+ }
276+
277+ void BlasConnector::hemm_cm (char side, char uplo, int m, int n,
278+ std::complex <float > alpha, std::complex <float > *a, int lda, std::complex <float > *b, int ldb,
279+ std::complex <float > beta, std::complex <float > *c, int ldc, base_device::AbacusDevice_t device_type)
280+ {
281+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
282+ chemm_ (&side, &uplo, &m, &n,
283+ &alpha, a, &lda, b, &ldb,
284+ &beta, c, &ldc);
285+ }
286+ }
287+
288+ void BlasConnector::hemm_cm (char side, char uplo, int m, int n,
289+ std::complex <double > alpha, std::complex <double > *a, int lda, std::complex <double > *b, int ldb,
290+ std::complex <double > beta, std::complex <double > *c, int ldc, base_device::AbacusDevice_t device_type)
291+ {
292+ if (device_type == base_device::AbacusDevice_t::CpuDevice) {
293+ zhemm_ (&side, &uplo, &m, &n,
294+ &alpha, a, &lda, b, &ldb,
295+ &beta, c, &ldc);
296+ }
297+ }
298+
157299void BlasConnector::gemv (const char trans, const int m, const int n,
158300 const float alpha, const float * A, const int lda, const float * X, const int incx,
159301 const float beta, float * Y, const int incy, base_device::AbacusDevice_t device_type)
@@ -190,7 +332,6 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
190332}
191333}
192334
193-
194335// out = ||x||_2
195336float BlasConnector::nrm2 ( const int n, const float *X, const int incX, base_device::AbacusDevice_t device_type )
196337{
0 commit comments