forked from codeplaysoftware/portBLAS
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathblas3_interface.h
113 lines (103 loc) · 5.53 KB
/
blas3_interface.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
/***************************************************************************
*
* @license
* Copyright (C) Codeplay Software Limited
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* For your convenience, a copy of the License has been included in this
* repository.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* SYCL-BLAS: BLAS implementation using SYCL
*
* @filename blas3_interface.h
*
**************************************************************************/
#ifndef SYCL_BLAS_BLAS3_INTERFACE_H
#define SYCL_BLAS_BLAS3_INTERFACE_H
#include "operations/blas3_trees.h"
namespace blas {
namespace internal {
/*!
 * @brief This is a top-level wrapper for GemmFactory, which provides a
 * "standard" BLAS gemm interface.
 *
 * See the netlib BLAS interface documentation for more details of the
 * high-level interface:
 * http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html
 *
 * Only declared in this header; the definition is provided elsewhere.
 *
 * NOTE: parameter names in this file deliberately avoid the
 * `_Uppercase` form (e.g. `_M`, `_C`). C++ reserves such identifiers for
 * the implementation, and some C libraries define them as macros.
 */
template <typename executor_t, typename container_0_t, typename container_1_t,
          typename container_2_t, typename element_t, typename index_t>
typename executor_t::policy_t::event_t _gemm(
    executor_t& ex, char trans_a, char trans_b, index_t m, index_t n,
    index_t k, element_t alpha, container_0_t a, index_t lda, container_1_t b,
    index_t ldb, element_t beta, container_2_t c, index_t ldc);

/*!
 * @brief Batched gemm entry point (declaration only).
 *
 * Takes the same leading parameters as internal::_gemm, plus:
 * @param batch_size number of problems in the batch.
 * @param batch_type how the batch is laid out; defaults to
 *                   gemm_batch_type_t::strided.
 */
template <typename executor_t, typename container_0_t, typename container_1_t,
          typename container_2_t, typename element_t, typename index_t>
typename executor_t::policy_t::event_t _gemm_batched(
    executor_t& ex, char trans_a, char trans_b, index_t m, index_t n,
    index_t k, element_t alpha, container_0_t a, index_t lda, container_1_t b,
    index_t ldb, element_t beta, container_2_t c, index_t ldc,
    index_t batch_size,
    gemm_batch_type_t batch_type = gemm_batch_type_t::strided);

/*!
 * @brief Triangular-solve (trsm) entry point (declaration only).
 */
template <typename executor_t, typename container_0_t, typename container_1_t,
          typename element_t, typename index_t>
typename executor_t::policy_t::event_t _trsm(executor_t& ex, char side,
                                             char uplo, char trans, char diag,
                                             index_t M, index_t N,
                                             element_t alpha, container_0_t A,
                                             index_t lda, container_1_t B,
                                             index_t ldb);
}  // namespace internal

/*!
 * @brief Public gemm entry point with a netlib-style argument list.
 *
 * Resolves each matrix container through
 * `ex.get_policy_handler().get_buffer(...)`. It then forwards every argument,
 * in the original order, to internal::_gemm, which does the actual work.
 * Presumably computes C := alpha * op(A) * op(B) + beta * C as netlib ?gemm
 * does; confirm against internal::_gemm.
 *
 * @param ex       executor whose policy handler resolves the containers.
 * @param trans_a  transpose flag for A (netlib-style character).
 * @param trans_b  transpose flag for B (netlib-style character).
 * @param m, n, k  problem dimensions, forwarded unchanged.
 * @param alpha    scalar, forwarded unchanged.
 * @param a, lda   matrix A container and its leading dimension.
 * @param b, ldb   matrix B container and its leading dimension.
 * @param beta     scalar, forwarded unchanged.
 * @param c, ldc   matrix C container and its leading dimension.
 * @return the event returned by internal::_gemm.
 */
template <typename executor_t, typename container_0_t, typename container_1_t,
          typename container_2_t, typename element_t, typename index_t>
typename executor_t::policy_t::event_t _gemm(
    executor_t& ex, char trans_a, char trans_b, index_t m, index_t n,
    index_t k, element_t alpha, container_0_t a, index_t lda, container_1_t b,
    index_t ldb, element_t beta, container_2_t c, index_t ldc) {
  return internal::_gemm(ex, trans_a, trans_b, m, n, k, alpha,
                         ex.get_policy_handler().get_buffer(a), lda,
                         ex.get_policy_handler().get_buffer(b), ldb, beta,
                         ex.get_policy_handler().get_buffer(c), ldc);
}

/*!
 * @brief Public batched gemm entry point.
 *
 * Resolves each container through `ex.get_policy_handler().get_buffer(...)`
 * and forwards all arguments, in order, to internal::_gemm_batched.
 *
 * @param ex, trans_a, trans_b, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc
 *        as for blas::_gemm.
 * @param batch_size number of problems in the batch.
 * @param batch_type batch layout; defaults to gemm_batch_type_t::strided.
 * @return the event returned by internal::_gemm_batched.
 */
template <typename executor_t, typename container_0_t, typename container_1_t,
          typename container_2_t, typename element_t, typename index_t>
typename executor_t::policy_t::event_t _gemm_batched(
    executor_t& ex, char trans_a, char trans_b, index_t m, index_t n,
    index_t k, element_t alpha, container_0_t a, index_t lda, container_1_t b,
    index_t ldb, element_t beta, container_2_t c, index_t ldc,
    index_t batch_size,
    gemm_batch_type_t batch_type = gemm_batch_type_t::strided) {
  return internal::_gemm_batched(ex, trans_a, trans_b, m, n, k, alpha,
                                 ex.get_policy_handler().get_buffer(a), lda,
                                 ex.get_policy_handler().get_buffer(b), ldb,
                                 beta, ex.get_policy_handler().get_buffer(c),
                                 ldc, batch_size, batch_type);
}

/*!
 * @brief Public trsm entry point with a netlib-style argument list.
 *
 * Resolves A and B through `ex.get_policy_handler().get_buffer(...)` and
 * forwards all arguments, in order, to internal::_trsm. Presumably solves
 * op(A) * X = alpha * B or X * op(A) = alpha * B and overwrites B with X, as
 * netlib ?trsm does; confirm against internal::_trsm.
 *
 * @param ex                       executor providing the policy handler.
 * @param side, uplo, trans, diag  netlib-style option characters.
 * @param M, N                     dimensions of B, forwarded unchanged.
 * @param alpha                    scalar, forwarded unchanged.
 * @param A, lda                   triangular matrix container and its
 *                                 leading dimension.
 * @param B, ldb                   right-hand-side container and its leading
 *                                 dimension.
 * @return the event returned by internal::_trsm.
 */
template <typename executor_t, typename container_0_t, typename container_1_t,
          typename element_t, typename index_t>
inline typename executor_t::policy_t::event_t _trsm(
    executor_t& ex, char side, char uplo, char trans, char diag, index_t M,
    index_t N, element_t alpha, container_0_t A, index_t lda, container_1_t B,
    index_t ldb) {
  return internal::_trsm(ex, side, uplo, trans, diag, M, N, alpha,
                         ex.get_policy_handler().get_buffer(A), lda,
                         ex.get_policy_handler().get_buffer(B), ldb);
}
}  // namespace blas
#endif  // SYCL_BLAS_BLAS3_INTERFACE_H