-
Notifications
You must be signed in to change notification settings - Fork 0
/
lstm.h
51 lines (41 loc) · 1.11 KB
/
lstm.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#ifndef __LSTM_H__
#define __LSTM_H__
#include "cuda_runtime.h"
// Problem-size constants for the LSTM model. They size every array in the
// model/runtime structs of this header.
enum LSTMScaleParams {
    // Leading dimension of every per-gate array (weights_w, weights_u, bias,
    // gemvw_temp, gemvu_temp).
    kLstmGateNumber = 4,
    // Middle dimension of weights_w.
    kInputSize      = 256,
    // Length of state_h / state_c and the trailing dimension of the weights.
    kHiddenSize     = 256,
    // Number of CellModel / CellRuntime entries (cells in the model).
    kCellNumber     = 10,
    // Number of StepInput entries in InputParams (sequence length).
    kLstmTimestep   = 100
};
// Launch-geometry constants for the LSTM kernels, derived from kHiddenSize.
// NOTE(review): the kernels are not in this header; the per-constant notes
// below follow from the names and arithmetic only - confirm against the .cu code.
enum LSTMKernelScaleParams {
    kThreadsPerWarp  = 32,
    kWarpsPerBlock   = 8,
    // One warp-width of columns per block.
    kColumnsPerBlock = kThreadsPerWarp,
    // Number of blocks needed to cover kHiddenSize columns. Exact only when
    // kHiddenSize is a multiple of kColumnsPerBlock (checked below).
    kGemvBlockNumber = kHiddenSize / kColumnsPerBlock,
    // kHiddenSize rows split evenly across the warps of a block. Exact only
    // when kHiddenSize is a multiple of kWarpsPerBlock (checked below).
    kRowsPerWarp     = kHiddenSize / kWarpsPerBlock,
};
#ifdef __cplusplus
// The divisions above truncate silently. If kHiddenSize is ever changed to a
// value that is not a multiple of these divisors, any kernel tiling by these
// counts would skip the remainder, so fail the build instead.
static_assert(kGemvBlockNumber * kColumnsPerBlock == kHiddenSize,
              "kHiddenSize must be a multiple of kColumnsPerBlock");
static_assert(kRowsPerWarp * kWarpsPerBlock == kHiddenSize,
              "kHiddenSize must be a multiple of kWarpsPerBlock");
#endif
// Parameters of one LSTM cell, stored as fixed-size arrays with one slice per gate.
// With the current constants this is 4*(2*4*256*256 + 4*256) bytes, about 2 MiB,
// assuming 4-byte float.
// NOTE(review): field order and sizes define the in-memory layout; keep them
// stable if this struct is copied to/from the device as a raw block.
typedef struct {
float weights_w[kLstmGateNumber][kInputSize][kHiddenSize];  // [gate][input][hidden]; presumably input-to-gate weights W - confirm
float weights_u[kLstmGateNumber][kHiddenSize][kHiddenSize];  // [gate][hidden][hidden]; presumably recurrent weights U - confirm
float bias[kLstmGateNumber][kHiddenSize];  // one kHiddenSize bias vector per gate
} CellModel;
// One timestep of model input. InputParams holds kLstmTimestep of these.
// Sized by kInputSize so it matches the input dimension of
// CellModel::weights_w ([gate][kInputSize][kHiddenSize]). With the current
// constants kInputSize == kHiddenSize, so the layout is unchanged.
// NOTE(review): if later cells consume the previous cell's state_h as their
// input, kInputSize must also equal kHiddenSize - confirm against the kernels.
typedef struct {
    float data[kInputSize];
} StepInput;
// Mutable per-cell state and scratch buffers (one CellRuntime per cell in CellParams).
// NOTE(review): field order and sizes define the in-memory layout; keep them
// stable if this struct is copied to/from the device as a raw block.
typedef struct {
float state_h[kHiddenSize];  // presumably the hidden state h - confirm
float state_c[kHiddenSize];  // presumably the cell state c - confirm
float gemvw_temp[kLstmGateNumber][kHiddenSize];  // per-gate scratch; name suggests W*x GEMV results - TODO confirm
float gemvu_temp[kLstmGateNumber][kHiddenSize];  // per-gate scratch; name suggests U*h GEMV results - TODO confirm
} CellRuntime;
// Parameters for all kCellNumber cells, one CellModel per cell.
// With the current constants this is about 20 MiB (10 x ~2 MiB, assuming 4-byte float).
typedef struct {
CellModel cell_model[kCellNumber];
} ModelParams;
// Full input sequence: one StepInput per timestep, kLstmTimestep entries.
typedef struct {
StepInput step_input[kLstmTimestep];
} InputParams;
// Runtime state and scratch for all cells, one CellRuntime per cell (kCellNumber entries).
typedef struct {
CellRuntime cell_runtime[kCellNumber];
} CellParams;
#endif