-
Notifications
You must be signed in to change notification settings - Fork 147
/
Copy pathllm.hpp
95 lines (86 loc) · 2.24 KB
/
llm.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#ifndef LLM_HPP
#define LLM_HPP
#include "nn/nn-core.hpp"
#include "nn/nn-executor.hpp"
#include "nn/nn-network.hpp"
/**
 * Integer keys identifying entries of the model-file header.
 *
 * NOTE(review): these look like on-disk codes consumed by loadLlmHeader(), with
 * each key matching the like-named LlmHeader field. If so, the numeric values
 * are part of the file format: never renumber existing keys, only append new
 * ones. Confirm against the tool that writes the model file.
 */
enum LlmHeaderKey {
    VERSION = 0,
    ARCH_TYPE = 1,
    DIM = 2,
    HIDDEN_DIM = 3,
    N_LAYERS = 4,
    N_HEADS = 5,
    N_KV_HEADS = 6,
    N_EXPERTS = 7,
    N_ACTIVE_EXPERTS = 8,
    VOCAB_SIZE = 9,
    SEQ_LEN = 10,
    HIDDEN_ACT = 11,
    ROPE_THETA = 12,
    WEIGHT_FLOAT_TYPE = 13,
    ROPE_SCALING_FACTOR = 14,
    ROPE_SCALING_LOW_FREQ_FACTOR = 15,
    // Misspelling of "..._FACTOR"; kept so existing code keeps compiling.
    ROPE_SCALING_HIGH_FREQ_FACTORY = 16,
    ROPE_SCALING_ORIG_MAX_SEQ_LEN = 17,
    ROPE_TYPE = 18,
    // Correctly spelled alias of ROPE_SCALING_HIGH_FREQ_FACTORY (same value 16);
    // prefer this name in new code. Declared last so it cannot shift the
    // implicit numbering of any enumerator.
    ROPE_SCALING_HIGH_FREQ_FACTOR = ROPE_SCALING_HIGH_FREQ_FACTORY,
};
// Hidden-layer activation function (the type of LlmHeader::hiddenAct).
// Values are pinned explicitly: presumably they are read as integers from the
// HIDDEN_ACT header entry, so existing values must not be reordered.
enum LlmHiddenAct {
    HIDDEN_ACT_GELU = 0,
    HIDDEN_ACT_SILU = 1,
};
// Model architecture identifier (the type of LlmHeader::archType).
// NOTE(review): the value looks like a magic code stored under the ARCH_TYPE header key — confirm.
enum LlmArchType { LLAMA = 0x00ABCD00 };
/**
 * Model hyper-parameters and file metadata, as returned by loadLlmHeader().
 *
 * NOTE(review): most fields appear to be filled from the like-named
 * LlmHeaderKey entries of the model file; confirm the exact mapping in the
 * loader. The field order defines the struct layout, so append new fields at
 * the end instead of reordering.
 */
typedef struct {
NnSize headerSize; // Size of the file's header section — presumably in bytes; TODO confirm
NnSize fileSize; // Size of the model file — presumably in bytes; TODO confirm
int version;
LlmArchType archType;
NnUint dim;
NnUint nLayers;
NnUint nHeads;
NnUint headSize; // No header key exists for this; presumably derived as dim / nHeads — confirm
NnUint nKvHeads;
NnUint nExperts; // Mixture-of-experts count; presumably 0 for dense models — confirm
NnUint nActiveExperts;
NnUint origSeqLen; // Original context length of the model, before any `--max-seq-len` cap
NnUint seqLen; // Context length actually used; may be capped by the `--max-seq-len` argument
NnUint hiddenDim;
LlmHiddenAct hiddenAct;
NnUint kvDim; // No header key exists for this; presumably derived as nKvHeads * headSize — confirm
NnUint vocabSize;
float ropeTheta;
NnRopeType ropeType;
float ropeScalingFactor;
float ropeScalingLowFreqFactor;
float ropeScalingHighFreqFactory; // "Factory" is a misspelling of "Factor" (cf. ropeScalingLowFreqFactor); kept for compatibility
NnUint ropeScalingOrigMaxSeqLen;
float normEpsilon; // No header key exists for this, so presumably the loader sets it — confirm the value used
NnFloatType weightType;
NnFloatType syncType; // Presumably the `syncType` argument passed to loadLlmHeader() — confirm
} LlmHeader;
/**
 * Network description produced by buildLlmNet(); pass it to releaseLlmNet()
 * when it is no longer needed.
 *
 * NOTE(review): ownership of the pointer members is not visible in this
 * header. Presumably `nodeConfigs` is allocated by buildLlmNet() and released
 * by releaseLlmNet(), while `header` is borrowed from the caller and must
 * outlive this struct — confirm in the implementation.
 */
typedef struct {
LlmHeader *header; // Presumably the `h` passed to buildLlmNet()
NnNetConfig netConfig;
NnNodeConfig *nodeConfigs; // Presumably one entry per node (`nNodes` in buildLlmNet()) — confirm
NnRowMatmulSlice qSlice; // Presumably per-node row slices of the attention Q/K/V projections
NnRowMatmulSlice kSlice;
NnRowMatmulSlice vSlice;
NnColMatmulSlice woSlice; // Presumably the attention output projection, sliced by columns
NnRowMatmulSlice w1Slice; // Presumably feed-forward weights: w1/w3 row-sliced, w2 column-sliced
NnColMatmulSlice w2Slice;
NnRowMatmulSlice w3Slice;
NnRowMatmulSlice wclsSlice; // Presumably the final vocabulary (classifier) projection
NnUint positionPipeIndex; // Presumably indices into the pipes described by netConfig — confirm
NnUint tokenPipeIndex;
NnUint xPipeIndex;
NnUint logitsPipeIndex;
NnSize2D tokenEmbeddingSize; // Presumably the shape of the token-embedding weight
NnSize2D rmsNormSize; // Presumably the shape of an RMS-norm weight
} LlmNet;
/**
 * Reads the model header from the file at `path`.
 * @param path      Path to the model file.
 * @param maxSeqLen Context-length cap (the `--max-seq-len` option, reflected in
 *                  LlmHeader::seqLen). NOTE(review): how 0 is treated is not
 *                  visible here — confirm in the loader.
 * @param syncType  Presumably copied into LlmHeader::syncType — confirm.
 * @return The parsed header, by value.
 */
LlmHeader loadLlmHeader(const char *path, unsigned int maxSeqLen, NnFloatType syncType);

/** Prints the fields of `header` (output stream not visible here — presumably stdout). */
void printLlmHeader(LlmHeader *header);

/**
 * Builds the network description for the model described by `h`.
 * @param h        Model header; presumably must outlive the returned LlmNet.
 * @param nNodes   Number of nodes the model is split across (presumably).
 * @param nBatches Batch count used when sizing the network (presumably).
 * @return A network to be released with releaseLlmNet().
 */
LlmNet buildLlmNet(LlmHeader *h, NnUint nNodes, NnUint nBatches);

/** Releases a network returned by buildLlmNet(); presumably frees what it allocated — confirm. */
void releaseLlmNet(LlmNet *net);

/** Loads the weights for `net` from the model file at `path` through `loader`. */
void loadLlmNetWeight(const char *path, LlmNet *net, NnRootWeightLoader *loader);
#endif