-
Notifications
You must be signed in to change notification settings - Fork 0
/
sgemm.cpp
113 lines (93 loc) · 3.8 KB
/
sgemm.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <iostream>
#include <limits>
#include <random>
#include <vector>
#define CL_HPP_MINIMUM_OPENCL_VERSION 120
#define CL_HPP_TARGET_OPENCL_VERSION 120
#include "CL/opencl.hpp"
#include "clblast.h"
#include "cblas.h"
// Fills every element of `v` with an independent sample from a uniform
// distribution over [0, 1). The generator is a std::mt19937 seeded from
// std::random_device, so contents differ from call to call.
void InitVector(std::vector<float> &v) {
  std::random_device rd;
  std::mt19937 gen(rd());
  std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
  // std::generate visits elements in order (same sample sequence as an index
  // loop) without the signed/unsigned `int i < v.size()` comparison, and it
  // cannot overflow an int index on very large vectors.
  std::generate(v.begin(), v.end(), [&] { return distribution(gen); });
}
// Returns the largest absolute element-wise difference between two M x N
// matrices stored contiguously (row stride N), i.e. max |C[k] - C_ref[k]|.
// Non-positive M or N means an empty matrix and yields 0.
//
// If any element difference is NaN (for example, one side contains NaN), the
// mismatch cannot be measured. In that case +infinity is returned, so a caller
// testing `max_diff >= tolerance` reports a failure instead of a silent pass.
// A plain `diff > max_diff` comparison would skip NaN entirely.
float Compare(const int M, const int N, const float *C, const float *C_ref) {
  const std::size_t rows = M > 0 ? static_cast<std::size_t>(M) : 0;
  const std::size_t cols = N > 0 ? static_cast<std::size_t>(N) : 0;
  // The matrices are contiguous, so one size_t loop covers them without the
  // int overflow risk of `i * N + j`.
  const std::size_t count = rows * cols;
  float max_diff = 0.0f;
  for (std::size_t k = 0; k < count; ++k) {
    const float diff = std::abs(C[k] - C_ref[k]);
    if (std::isnan(diff)) {
      return std::numeric_limits<float>::infinity();
    }
    max_diff = std::max(max_diff, diff);
  }
  return max_diff;
}
int main() {
std::vector<cl::Platform> platforms;
cl::Platform::get(&platforms);
auto platform = platforms.front();
std::vector<cl::Device> devices;
platform.getDevices(CL_DEVICE_TYPE_ALL, &devices);
auto device = devices.front();
// Create an OpenCL context
cl::Context context({device});
// Create a command queue
cl::CommandQueue queue(context, device);
int range_start = 1000;
int range_end = 3200;
int scale = (range_start + range_end) / 2;
while(scale > range_start && scale < range_end) {
std::vector<float> A(scale * scale),
B(scale * scale),
C_Openblas(scale * scale, 0),
C_Clblast(scale * scale, 0);
InitVector(A);
InitVector(B);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, scale, scale, scale, 1.f, A.data(), scale, B.data(), scale, 0.f, C_Openblas.data(), scale);
// OpenCL computation
cl::Buffer A_Buffer(context, CL_MEM_READ_ONLY, sizeof(float) * A.size()),
B_Buffer(context, CL_MEM_READ_ONLY, sizeof(float) * B.size()),
C_Clblast_Buffer(context, CL_MEM_WRITE_ONLY, sizeof(float) * C_Clblast.size());
queue.enqueueWriteBuffer(A_Buffer, CL_TRUE, 0, A.size() * sizeof(float), A.data());
queue.enqueueWriteBuffer(B_Buffer, CL_TRUE, 0, B.size() * sizeof(float), B.data());
queue.enqueueWriteBuffer(C_Clblast_Buffer, CL_TRUE, 0, C_Clblast.size() * sizeof(float), C_Clblast.data());
auto raw_queue = queue();
cl::Event event{nullptr};
auto raw_event = event();
auto status = clblast::Gemm(clblast::Layout::kRowMajor, clblast::Transpose::kNo, clblast::Transpose::kNo,
scale, scale, scale,
1.f,
A_Buffer(), 0, scale,
B_Buffer(), 0, scale,
0.f,
C_Clblast_Buffer(), 0, scale,
&raw_queue, &raw_event);
if (status == clblast::StatusCode::kSuccess) {
clWaitForEvents(1, &raw_event);
clFinish(raw_queue);
}
queue.enqueueReadBuffer(C_Clblast_Buffer, CL_TRUE, 0, C_Clblast.size() * sizeof(float), C_Clblast.data(), nullptr, &event);
cl::WaitForEvents({event});
// Compare
float max_diff = max_diff = Compare(scale, scale, C_Clblast.data(), C_Openblas.data());
int next_scale;
if (max_diff >= 1e-3) {
printf("scale: %d, max_diff: %f\n", scale, max_diff);
range_end = scale;
} else {
printf("sacle: %d, OK\n", scale);
range_start = scale;
}
next_scale = (range_start + range_end) / 2;
if (next_scale == scale) {
break;
} else {
scale = next_scale;
}
}
return 0;
}