-
Notifications
You must be signed in to change notification settings - Fork 0
/
wmma_overlap.cu
140 lines (112 loc) · 4.23 KB
/
wmma_overlap.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#include<mma.h>
#include<cuda/pipeline>
#include<cooperative_groups.h>
#include<stdio.h>
using namespace nvcuda;
#define CUDA_CHECK_RETURN(X) X
#define NUM_ITERS 10
// Disables `pipeline_shared_state` initialization warning.
#pragma nv_diag_suppress static_var_with_dynamic_init
// Define some error checking macros.
#define cudaErrCheck(stat) \
{ \
cudaErrCheck_((stat), __FILE__, __LINE__); \
}
// Reports a failed CUDA runtime call on stderr together with its call site.
//   stat : status code returned by the CUDA call being checked
//   file : source file of the call (supplied by the cudaErrCheck macro)
//   line : source line of the call (supplied by the cudaErrCheck macro)
// A successful status is ignored. A failure is only logged: execution
// continues and nothing is reported back to the caller.
void cudaErrCheck_(cudaError_t stat, const char *file, int line)
{
    if (stat == cudaSuccess)
        return;

    const char *message = cudaGetErrorString(stat);
    fprintf(stderr, "CUDA Error: %s %s %d\n", message, file, line);
}
// Shape (m x n x k) of a single WMMA tile operation.
constexpr int WMMA_M = 16;
constexpr int WMMA_N = 16;
constexpr int WMMA_K = 16;

// Problem size: C (M x N, float) = A (M x K, half) * B (K x N, half).
// M was originally 327660; it is rounded up to 327664, the closest multiple
// of 16, so that A splits into whole 16-row tiles.
constexpr int M = 327664;
constexpr int N = 1536;
constexpr int K = 512;

// Threads per block used for the (single-block) kernel launch.
constexpr int num_threads = 512;

// Number of half elements in one shared-memory stage: a 16-row x K tile of A.
constexpr int smem = K * 16;

// Preconditions the kernel relies on but never checks at run time.
static_assert(M % WMMA_M == 0, "M must be a multiple of the WMMA tile height");
static_assert(K % WMMA_K == 0, "K must be a multiple of the WMMA tile depth");
static_assert(num_threads % 32 == 0, "num_threads must be a whole number of warps");
static_assert(N % (WMMA_N * (num_threads / 32)) == 0,
              "N must split evenly into WMMA_N-wide tiles across all warps");
static_assert(smem == K * WMMA_M, "one shared-memory stage must hold a WMMA_M x K tile of A");
// Multiplies one 16 x K tile of A by the column slice of B owned by the
// calling warp and writes the resulting 16 x (work_per_warp * WMMA_N) strip of C.
//   a_tile        : first element of the A tile (row-major, leading dimension K);
//                   the kernel passes a pointer into its shared-memory staging buffer
//   b             : B matrix in global memory (row-major, leading dimension N)
//   c_rows        : first element of the 16 output rows of C (row-major, leading dimension N)
//   warp_id       : index of the calling warp inside the block
//   work_per_warp : number of WMMA_N-wide column tiles each warp computes
// Must be called by all 32 lanes of the warp (the wmma *_sync operations are warp-wide).
static __device__ __forceinline__ void wmma_multiply_tile(const half* a_tile,
                                                          const half* b,
                                                          float* c_rows,
                                                          int warp_id,
                                                          int work_per_warp)
{
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> frag_a;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> frag_b;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> frag_c;

    for (int j = 0; j < work_per_warp; j++) {
        // First column of the output tile this warp produces in this step.
        const int col = (warp_id * work_per_warp + j) * WMMA_N;

        wmma::fill_fragment(frag_c, 0.0f);
        for (int k = 0; k < (K / WMMA_K); k++) {
            // Advance along K inside the tile. The stage base and the k offset
            // are added as pointers, so the offset can never be swallowed by a
            // conditional expression.
            wmma::load_matrix_sync(frag_a, a_tile + (k * WMMA_K), K);
            wmma::load_matrix_sync(frag_b, b + (k * WMMA_K * N) + col, N);
            wmma::mma_sync(frag_c, frag_a, frag_b, frag_c);
        }
        wmma::store_matrix_sync(c_rows + col, frag_c, N, wmma::mem_row_major);
    }
}

// Computes C = A * B with WMMA, streaming A through a two-stage shared-memory
// pipeline so that the copy of tile m overlaps the multiply of tile m-1.
// The whole product is recomputed NUM_ITERS times.
//   a : A matrix, M x K halves, row-major, global memory
//   b : B matrix, K x N halves, row-major, global memory
//   c : C matrix, M x N floats, row-major, global memory
// Launch layout: the code never reads blockIdx and derives the warp count from
// the compile-time constant num_threads, so it expects exactly one block of
// num_threads threads.
// Shared memory: 2 * smem halves of static storage plus the pipeline state.
// NOTE(review): WMMA and cuda::memcpy_async have minimum compute-capability
// requirements - confirm the build targets a suitable architecture.
__global__ void wmma_kernel(half* a, half* b, float* c){
    // Two staging buffers of one A tile each: stage 0 at SMEM[0], stage 1 at SMEM[smem].
    __shared__ half SMEM[2*smem];

    const int warp_id = threadIdx.x/32;
    const int num_warps = num_threads/32;
    const int work_per_warp = N/(WMMA_N*num_warps);
    constexpr int num_tiles = M/16;   // number of 16-row tiles of A

    auto group = cooperative_groups::this_thread_block();
    constexpr auto scope = cuda::thread_scope_block;
    constexpr auto stages_count = 2;
    __shared__ cuda::pipeline_shared_state<scope, stages_count> shared_state;
    auto pipeline = cuda::make_pipeline(group, &shared_state);

    for(int it=0; it<NUM_ITERS; it++){
        // Warm the pipeline: start copying tile 0 into stage 0.
        pipeline.producer_acquire();
        cuda::memcpy_async(group,&SMEM[0],&a[0],sizeof(half)*K*16,pipeline);
        pipeline.producer_commit();

        // Run the pipeline: fetch tile m while multiplying tile m-1.
        for(int m=1; m<num_tiles; m++){
            // Tile m lands in stage (m % 2); tile m-1 sits in the other stage.
            half* fill_stage = &SMEM[(m%2) ? smem : 0];
            const half* ready_stage = &SMEM[(m%2) ? 0 : smem];

            pipeline.producer_acquire();
            cuda::memcpy_async(group,fill_stage,&a[(size_t)m*K*16],sizeof(half)*K*16,pipeline);
            pipeline.producer_commit();

            pipeline.consumer_wait();
            wmma_multiply_tile(ready_stage, b, c + (size_t)(m-1)*16*N, warp_id, work_per_warp);
            pipeline.consumer_release();
        }

        // Drain the pipeline: the last tile (num_tiles-1) has been fetched but
        // not yet multiplied; it sits in stage ((num_tiles-1) % 2).
        pipeline.consumer_wait();
        wmma_multiply_tile(&SMEM[(num_tiles % 2) ? 0 : smem], b,
                           c + (size_t)(num_tiles-1)*16*N, warp_id, work_per_warp);
        pipeline.consumer_release();
    }
}
// Host driver: fills A (M x K) and B (K x N) with ones, runs wmma_kernel once
// in a single block of num_threads threads, times it with CUDA events, and
// checks that every element of C equals K (the dot product of K ones).
// CUDA failures are reported through cudaErrCheck, which logs and continues.
// Always returns 0.
int main(){
    // Element counts are computed in size_t so the byte sizes below cannot overflow int.
    const size_t a_elems = (size_t)M * K;
    const size_t b_elems = (size_t)K * N;
    const size_t c_elems = (size_t)M * N;

    half *d_a = nullptr, *d_b = nullptr;
    float *d_c = nullptr;
    float *h_c = new float[c_elems];
    half *h_b = new half[b_elems];
    half *h_a = new half[a_elems];

    cudaErrCheck(cudaMalloc(&d_a, a_elems*sizeof(half)));
    cudaErrCheck(cudaMalloc(&d_b, b_elems*sizeof(half)));
    cudaErrCheck(cudaMalloc(&d_c, c_elems*sizeof(float)));

    for (size_t i = 0; i < a_elems; i++)
        h_a[i] = 1.0f;
    for (size_t i = 0; i < b_elems; i++)
        h_b[i] = 1.0f;

    cudaErrCheck(cudaMemcpy(d_a, h_a, a_elems*sizeof(half), cudaMemcpyHostToDevice));
    cudaErrCheck(cudaMemcpy(d_b, h_b, b_elems*sizeof(half), cudaMemcpyHostToDevice));

    cudaEvent_t start, stop;
    cudaErrCheck(cudaEventCreate(&start));
    cudaErrCheck(cudaEventCreate(&stop));

    cudaErrCheck(cudaEventRecord(start));
    wmma_kernel<<<1,num_threads>>>(d_a, d_b, d_c);
    cudaErrCheck(cudaGetLastError());          // launch-configuration errors
    cudaErrCheck(cudaEventRecord(stop));
    // Wait for the kernel to finish so that execution errors surface here and
    // the elapsed time read below is valid.
    cudaErrCheck(cudaEventSynchronize(stop));

    cudaErrCheck(cudaMemcpy(h_c, d_c, c_elems*sizeof(float), cudaMemcpyDeviceToHost));

    // Verify the result. Only the first few mismatches are printed so that a
    // broken run does not emit hundreds of millions of lines.
    const long long max_reported = 16;
    long long mismatches = 0;
    for(int i=0; i<M*N; i++){
        if(h_c[i] != K){
            if(mismatches < max_reported)
                printf("Error at: %d %f\n",i,h_c[i]);
            mismatches++;
        }
    }
    if(mismatches > max_reported)
        printf("... %lld mismatching elements in total\n", mismatches);

    float elapsedTime = 0.0f;
    cudaErrCheck(cudaEventElapsedTime(&elapsedTime, start, stop));
    printf("Elapsed Time : %f\n",elapsedTime);

    // Release everything that was acquired above.
    cudaErrCheck(cudaEventDestroy(start));
    cudaErrCheck(cudaEventDestroy(stop));
    cudaErrCheck(cudaFree(d_a));
    cudaErrCheck(cudaFree(d_b));
    cudaErrCheck(cudaFree(d_c));
    delete[] h_a;
    delete[] h_b;
    delete[] h_c;
    return 0;
}