Skip to content

Commit 3a2eb06

Browse files
authored
Initial commit (#5790)
1 parent d7b76fc commit 3a2eb06

File tree

3 files changed

+217
-5
lines changed

3 files changed

+217
-5
lines changed

source/module_base/blas_connector.cpp

Lines changed: 142 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d
8282
}
8383

8484
// C = a * A.? * B.? + b * C
85+
// Row-Major part
8586
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
8687
const float alpha, const float *a, const int lda, const float *b, const int ldb,
8788
const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type)
@@ -154,6 +155,147 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
154155
#endif
155156
}
156157

158+
// Col-Major part
159+
void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
160+
const float alpha, const float *a, const int lda, const float *b, const int ldb,
161+
const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type)
162+
{
163+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
164+
sgemm_(&transa, &transb, &m, &n, &k,
165+
&alpha, a, &lda, b, &ldb,
166+
&beta, c, &ldc);
167+
}
168+
#ifdef __DSP
169+
else if (device_type == base_device::AbacusDevice_t::DspDevice){
170+
sgemm_mth_(&transb, &transa, &m, &n, &k,
171+
&alpha, a, &lda, b, &ldb,
172+
&beta, c, &ldc, GlobalV::MY_RANK);
173+
}
174+
#endif
175+
}
176+
177+
void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
178+
const double alpha, const double *a, const int lda, const double *b, const int ldb,
179+
const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type)
180+
{
181+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
182+
dgemm_(&transa, &transb, &m, &n, &k,
183+
&alpha, a, &lda, b, &ldb,
184+
&beta, c, &ldc);
185+
}
186+
#ifdef __DSP
187+
else if (device_type == base_device::AbacusDevice_t::DspDevice){
188+
dgemm_mth_(&transa, &transb, &m, &n, &k,
189+
&alpha, a, &lda, b, &ldb,
190+
&beta, c, &ldc, GlobalV::MY_RANK);
191+
}
192+
#endif
193+
}
194+
195+
void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
196+
const std::complex<float> alpha, const std::complex<float> *a, const int lda, const std::complex<float> *b, const int ldb,
197+
const std::complex<float> beta, std::complex<float> *c, const int ldc, base_device::AbacusDevice_t device_type)
198+
{
199+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
200+
cgemm_(&transa, &transb, &m, &n, &k,
201+
&alpha, a, &lda, b, &ldb,
202+
&beta, c, &ldc);
203+
}
204+
#ifdef __DSP
205+
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
206+
cgemm_mth_(&transa, &transb, &m, &n, &k,
207+
&alpha, a, &lda, b, &ldb,
208+
&beta, c, &ldc, GlobalV::MY_RANK);
209+
}
210+
#endif
211+
}
212+
213+
void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
214+
const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
215+
const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type)
216+
{
217+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
218+
zgemm_(&transa, &transb, &m, &n, &k,
219+
&alpha, a, &lda, b, &ldb,
220+
&beta, c, &ldc);
221+
}
222+
#ifdef __DSP
223+
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
224+
zgemm_mth_(&transa, &transb, &m, &n, &k,
225+
&alpha, a, &lda, b, &ldb,
226+
&beta, c, &ldc, GlobalV::MY_RANK);
227+
}
228+
#endif
229+
}
230+
231+
// Symm and Hemm part. Only col-major is supported.
232+
233+
void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
234+
const float alpha, const float *a, const int lda, const float *b, const int ldb,
235+
const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type)
236+
{
237+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
238+
ssymm_(&side, &uplo, &m, &n,
239+
&alpha, a, &lda, b, &ldb,
240+
&beta, c, &ldc);
241+
}
242+
}
243+
244+
void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
245+
const double alpha, const double *a, const int lda, const double *b, const int ldb,
246+
const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type)
247+
{
248+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
249+
dsymm_(&side, &uplo, &m, &n,
250+
&alpha, a, &lda, b, &ldb,
251+
&beta, c, &ldc);
252+
}
253+
}
254+
255+
void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
256+
const std::complex<float> alpha, const std::complex<float> *a, const int lda, const std::complex<float> *b, const int ldb,
257+
const std::complex<float> beta, std::complex<float> *c, const int ldc, base_device::AbacusDevice_t device_type)
258+
{
259+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
260+
csymm_(&side, &uplo, &m, &n,
261+
&alpha, a, &lda, b, &ldb,
262+
&beta, c, &ldc);
263+
}
264+
}
265+
266+
void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
267+
const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
268+
const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type)
269+
{
270+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
271+
zsymm_(&side, &uplo, &m, &n,
272+
&alpha, a, &lda, b, &ldb,
273+
&beta, c, &ldc);
274+
}
275+
}
276+
277+
void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
278+
std::complex<float> alpha, std::complex<float> *a, int lda, std::complex<float> *b, int ldb,
279+
std::complex<float> beta, std::complex<float> *c, int ldc, base_device::AbacusDevice_t device_type)
280+
{
281+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
282+
chemm_(&side, &uplo, &m, &n,
283+
&alpha, a, &lda, b, &ldb,
284+
&beta, c, &ldc);
285+
}
286+
}
287+
288+
void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
289+
std::complex<double> alpha, std::complex<double> *a, int lda, std::complex<double> *b, int ldb,
290+
std::complex<double> beta, std::complex<double> *c, int ldc, base_device::AbacusDevice_t device_type)
291+
{
292+
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
293+
zhemm_(&side, &uplo, &m, &n,
294+
&alpha, a, &lda, b, &ldb,
295+
&beta, c, &ldc);
296+
}
297+
}
298+
157299
void BlasConnector::gemv(const char trans, const int m, const int n,
158300
const float alpha, const float* A, const int lda, const float* X, const int incx,
159301
const float beta, float* Y, const int incy, base_device::AbacusDevice_t device_type)
@@ -190,7 +332,6 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
190332
}
191333
}
192334

193-
194335
// out = ||x||_2
195336
float BlasConnector::nrm2( const int n, const float *X, const int incX, base_device::AbacusDevice_t device_type )
196337
{

source/module_base/blas_connector.h

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,23 @@ extern "C"
111111
const std::complex<double> *alpha, const std::complex<double> *a, const int *lda, const std::complex<double> *b, const int *ldb,
112112
const std::complex<double> *beta, std::complex<double> *c, const int *ldc);
113113

114-
//a is symmetric
114+
// A is symmetric. side = 'L': C = alpha * A * B + beta * C; side = 'R': C = alpha * B * A + beta * C
115+
void ssymm_(const char *side, const char *uplo, const int *m, const int *n,
116+
const float *alpha, const float *a, const int *lda, const float *b, const int *ldb,
117+
const float *beta, float *c, const int *ldc);
115118
void dsymm_(const char *side, const char *uplo, const int *m, const int *n,
116119
const double *alpha, const double *a, const int *lda, const double *b, const int *ldb,
117120
const double *beta, double *c, const int *ldc);
118-
//a is hermitian
121+
void csymm_(const char *side, const char *uplo, const int *m, const int *n,
122+
const std::complex<float> *alpha, const std::complex<float> *a, const int *lda, const std::complex<float> *b, const int *ldb,
123+
const std::complex<float> *beta, std::complex<float> *c, const int *ldc);
124+
void zsymm_(const char *side, const char *uplo, const int *m, const int *n,
125+
const std::complex<double> *alpha, const std::complex<double> *a, const int *lda, const std::complex<double> *b, const int *ldb,
126+
const std::complex<double> *beta, std::complex<double> *c, const int *ldc);
127+
128+
// A is Hermitian. side = 'L': C = alpha * A * B + beta * C; side = 'R': C = alpha * B * A + beta * C
129+
void chemm_(char *side, char *uplo, int *m, int *n,std::complex<float> *alpha,
130+
std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb, std::complex<float> *beta, std::complex<float> *c, int *ldc);
119131
void zhemm_(char *side, char *uplo, int *m, int *n,std::complex<double> *alpha,
120132
std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb, std::complex<double> *beta, std::complex<double> *c, int *ldc);
121133

@@ -175,6 +187,7 @@ class BlasConnector
175187

176188
// Peize Lin add 2017-10-27, fix bug trans 2019-01-17
177189
// C = a * A.? * B.? + b * C
190+
// Row Major by default
178191
static
179192
void gemm(const char transa, const char transb, const int m, const int n, const int k,
180193
const float alpha, const float *a, const int lda, const float *b, const int ldb,
@@ -195,6 +208,61 @@ class BlasConnector
195208
const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
196209
const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
197210

211+
// Col-Major if you need to use it
212+
213+
static
214+
void gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
215+
const float alpha, const float *a, const int lda, const float *b, const int ldb,
216+
const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
217+
218+
static
219+
void gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
220+
const double alpha, const double *a, const int lda, const double *b, const int ldb,
221+
const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
222+
223+
static
224+
void gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
225+
const std::complex<float> alpha, const std::complex<float> *a, const int lda, const std::complex<float> *b, const int ldb,
226+
const std::complex<float> beta, std::complex<float> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
227+
228+
static
229+
void gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
230+
const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
231+
const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
232+
233+
// symm and hemm cannot be mapped onto a row-major kernel simply by exchanging parameters, so only col-major functions are provided.
234+
static
235+
void symm_cm(const char side, const char uplo, const int m, const int n,
236+
const float alpha, const float *a, const int lda, const float *b, const int ldb,
237+
const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
238+
239+
static
240+
void symm_cm(const char side, const char uplo, const int m, const int n,
241+
const double alpha, const double *a, const int lda, const double *b, const int ldb,
242+
const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
243+
244+
static
245+
void symm_cm(const char side, const char uplo, const int m, const int n,
246+
const std::complex<float> alpha, const std::complex<float> *a, const int lda, const std::complex<float> *b, const int ldb,
247+
const std::complex<float> beta, std::complex<float> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
248+
249+
static
250+
void symm_cm(const char side, const char uplo, const int m, const int n,
251+
const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
252+
const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
253+
254+
static
255+
void hemm_cm(char side, char uplo, int m, int n,
256+
std::complex<float> alpha, std::complex<float> *a, int lda, std::complex<float> *b, int ldb,
257+
std::complex<float> beta, std::complex<float> *c, int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
258+
259+
static
260+
void hemm_cm(char side, char uplo, int m, int n,
261+
std::complex<double> alpha, std::complex<double> *a, int lda, std::complex<double> *b, int ldb,
262+
std::complex<double> beta, std::complex<double> *c, int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
263+
264+
// y = alpha * op(A) * x + beta * y
265+
198266
static
199267
void gemv(const char trans, const int m, const int n,
200268
const float alpha, const float* A, const int lda, const float* X, const int incx,
@@ -234,6 +302,8 @@ class BlasConnector
234302

235303
static
236304
void copy(const long n, const std::complex<double> *a, const int incx, std::complex<double> *b, const int incy, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
305+
306+
// A is symmetric
237307
};
238308

239309
// If GATHER_INFO is defined, the original function is replaced with a "i" suffix,

source/module_hamilt_lcao/module_gint/mult_psi_dmr.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "gint_tools.h"
22
#include "module_base/timer.h"
33
#include "module_base/ylm.h"
4+
#include "module_base/blas_connector.h"
45

56
namespace Gint_Tools{
67

@@ -60,8 +61,8 @@ void mult_psi_DMR(
6061

6162
const auto tmp_matrix_ptr = tmp_matrix->get_pointer();
6263
const int idx1 = block_index[ia1];
63-
dsymm_(&side, &uplo, &block_size[ia1], &ib_len, &alpha, tmp_matrix_ptr, &block_size[ia1],
64-
&psi[ib_start][idx1], &LD_pool, &beta, &psi_DMR[ib_start][idx1], &LD_pool);
64+
BlasConnector::symm_cm(side, uplo, block_size[ia1], ib_len, alpha, tmp_matrix_ptr, block_size[ia1],
65+
&psi[ib_start][idx1], LD_pool, beta, &psi_DMR[ib_start][idx1], LD_pool);
6566
}
6667

6768
//! get (j,beta,R2)

0 commit comments

Comments
 (0)