forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CPUBlas.h
139 lines (118 loc) · 4.15 KB
/
CPUBlas.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#pragma once
#include <ATen/native/DispatchStub.h>
#include <c10/util/complex.h>
#include <c10/core/ScalarType.h>
#include <c10/core/Scalar.h>
namespace at {
namespace native {
namespace cpublas {
// Flag type used for the `transa` / `transb` parameters of the gemm entry
// points and of gemm_fn in this header.
// This is an unscoped enum, so the enumerators are visible directly in
// at::native::cpublas. Note the declaration order: Transpose has value 0 and
// NoTranspose has value 1.
// NOTE(review): presumably selects whether the corresponding operand is used
// transposed, as in BLAS — confirm against the kernel implementation.
enum TransposeType {
Transpose,
NoTranspose,
// ConjTranspose, -- Not implemented
};
namespace internal {
// Helper invoked by the templated gemm() in this header right before it
// dispatches to gemm_stub. The three leading dimensions are passed by pointer,
// so the callee is able to overwrite the caller's lda/ldb/ldc values.
// Only the declaration lives in this header.
// NOTE(review): presumably canonicalizes the leading dimensions for degenerate
// shapes (e.g. a dimension of size 1) — confirm against the definition.
void normalize_last_dims(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
int64_t *lda, int64_t *ldb, int64_t *ldc);
} // namespace internal
// Function-pointer type of the type-erased gemm kernel reached through
// gemm_stub. The element type is carried at runtime in `type`; the matrices
// a, b and c are passed as untyped pointers, and alpha/beta as Scalar.
using gemm_fn = void(*)(
at::ScalarType type,
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
const Scalar& alpha,
const void *a, int64_t lda,
const void *b, int64_t ldb,
const Scalar& beta,
void *c, int64_t ldc);
// Declares the dispatch stub `gemm_stub` for kernels of type gemm_fn
// (the macro is presumably provided by ATen/native/DispatchStub.h, which is
// included at the top of this header).
DECLARE_DISPATCH(gemm_fn, gemm_stub);
// Generic gemm entry point, used for element types that have no dedicated
// non-template overload.
//
// The leading dimensions are first handed to internal::normalize_last_dims,
// which may rewrite them through the pointers it receives. All arguments are
// then forwarded to gemm_stub for kCPU, together with the ScalarType tag that
// corresponds to scalar_t.
template <typename scalar_t>
void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    scalar_t alpha,
    const scalar_t *a, int64_t lda,
    const scalar_t *b, int64_t ldb,
    scalar_t beta,
    scalar_t *c, int64_t ldc) {
  // lda/ldb/ldc are by-value parameters, so any rewrite stays local to this call.
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);

  // Runtime tag that tells the type-erased kernel how to interpret a, b and c.
  const auto dtype = c10::CppTypeToScalarType<scalar_t>::value;

  gemm_stub(
      kCPU, dtype,
      transa, transb, m, n, k,
      alpha, a, lda, b, ldb, beta, c, ldc);
}
// Non-template gemm overload for double. Only the declaration is in this
// header; for a call whose arguments are exactly double, overload resolution
// prefers this function over the gemm<scalar_t> template.
// NOTE(review): presumably the BLAS-backed path when a BLAS library is
// available — confirm against the definition.
void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
double alpha,
const double *a, int64_t lda,
const double *b, int64_t ldb,
double beta,
double *c, int64_t ldc);
// Non-template gemm overload for float. Only the declaration is in this
// header; for a call whose arguments are exactly float, overload resolution
// prefers this function over the gemm<scalar_t> template.
// NOTE(review): presumably the BLAS-backed path when a BLAS library is
// available — confirm against the definition.
void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
float alpha,
const float *a, int64_t lda,
const float *b, int64_t ldb,
float beta,
float *c, int64_t ldc);
// Non-template gemm overload for c10::complex<double>. Only the declaration is
// in this header; for an exact match on this element type, overload resolution
// prefers this function over the gemm<scalar_t> template.
// NOTE(review): presumably the BLAS-backed path when a BLAS library is
// available — confirm against the definition.
void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
c10::complex<double> alpha,
const c10::complex<double> *a, int64_t lda,
const c10::complex<double> *b, int64_t ldb,
c10::complex<double> beta,
c10::complex<double> *c, int64_t ldc);
// Non-template gemm overload for c10::complex<float>. Only the declaration is
// in this header; for an exact match on this element type, overload resolution
// prefers this function over the gemm<scalar_t> template.
// NOTE(review): presumably the BLAS-backed path when a BLAS library is
// available — confirm against the definition.
void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
c10::complex<float> alpha,
const c10::complex<float> *a, int64_t lda,
const c10::complex<float> *b, int64_t ldb,
c10::complex<float> beta,
c10::complex<float> *c, int64_t ldc);
// Non-template gemm overload for int64_t elements. Only the declaration is in
// this header; for an exact match on int64_t, overload resolution prefers this
// function over the gemm<scalar_t> template.
// NOTE(review): the definition is not visible here — confirm how it differs
// from the generic template path.
void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
int64_t alpha,
const int64_t *a, int64_t lda,
const int64_t *b, int64_t ldb,
int64_t beta,
int64_t *c, int64_t ldc);
// Function-pointer type of the type-erased axpy kernel reached through
// axpy_stub. The element type is carried at runtime in `type`; x and y are
// untyped pointers with increments incx/incy, and the scale factor `a` is a
// Scalar.
using axpy_fn = void(*)(at::ScalarType type, int64_t n, const Scalar& a, const void *x, int64_t incx, void *y, int64_t incy);
// Declares the dispatch stub `axpy_stub` for kernels of type axpy_fn.
DECLARE_DISPATCH(axpy_fn, axpy_stub);
// Generic axpy entry point, used for element types that have no dedicated
// non-template overload.
//
// When n == 1 both increments are forced to 1 before dispatch; otherwise the
// caller's incx/incy are passed through unchanged. The call is forwarded to
// axpy_stub for kCPU with the ScalarType tag that corresponds to scalar_t.
template<typename scalar_t>
void axpy(
    int64_t n, scalar_t a,
    const scalar_t *x, int64_t incx,
    scalar_t *y, int64_t incy) {
  // NOTE(review): presumably a stride is meaningless for a single element, so
  // it is normalized here — confirm the motivation against the kernels.
  if (n == 1) {
    incx = incy = 1;
  }

  // Runtime tag that tells the type-erased kernel how to interpret x and y.
  const auto dtype = c10::CppTypeToScalarType<scalar_t>::value;

  axpy_stub(kCPU, dtype, n, a, x, incx, y, incy);
}
// Non-template axpy overloads for double, float, c10::complex<double> and
// c10::complex<float>. Only the declarations are in this header; for an exact
// match on one of these element types, overload resolution prefers the
// matching function here over the axpy<scalar_t> template.
// NOTE(review): presumably BLAS-backed when a BLAS library is available —
// confirm against the definitions.
void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t incy);
void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t incy);
void axpy(int64_t n, c10::complex<double> a, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
void axpy(int64_t n, c10::complex<float> a, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
// Function-pointer type of the type-erased copy kernel reached through
// copy_stub. The element type is carried at runtime in `type`; x (const) and
// y are untyped pointers with increments incx/incy.
using copy_fn = void(*)(at::ScalarType type, int64_t n, const void *x, int64_t incx, void *y, int64_t incy);
// Declares the dispatch stub `copy_stub` for kernels of type copy_fn.
DECLARE_DISPATCH(copy_fn, copy_stub);
// Generic copy entry point, used for element types that have no dedicated
// non-template overload.
//
// When n == 1 both increments are forced to 1 before dispatch; otherwise the
// caller's incx/incy are passed through unchanged. The call is forwarded to
// copy_stub for kCPU with the ScalarType tag that corresponds to scalar_t.
template<typename scalar_t>
void copy(
    int64_t n,
    const scalar_t *x, int64_t incx,
    scalar_t *y, int64_t incy) {
  // NOTE(review): presumably a stride is meaningless for a single element, so
  // it is normalized here — confirm the motivation against the kernels.
  if (n == 1) {
    incx = incy = 1;
  }

  // Runtime tag that tells the type-erased kernel how to interpret x and y.
  const auto dtype = c10::CppTypeToScalarType<scalar_t>::value;

  copy_stub(kCPU, dtype, n, x, incx, y, incy);
}
// Non-template copy overloads for double, float, c10::complex<double> and
// c10::complex<float>. Only the declarations are in this header; for an exact
// match on one of these element types, overload resolution prefers the
// matching function here over the copy<scalar_t> template.
// NOTE(review): presumably BLAS-backed when a BLAS library is available —
// confirm against the definitions.
void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy);
void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy);
void copy(int64_t n, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
}}} // namespace at::native::cpublas