forked from a1k0n/a1gpt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkernel.h
14 lines (10 loc) · 775 Bytes
/
kernel.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
#pragma once
// Prototype only; the definition is in another translation unit.
// NOTE(review): everything below is inferred from parameter names -- confirm
// against the definition before relying on it.
//   output        - destination buffer, presumably embeddingSize floats
//   token, pos    - presumably the token id and sequence position used to
//                   select rows of wte and wpe respectively
//   embeddingSize - presumably the number of floats per embedding row
//   wte, wpe      - presumably the token- and position-embedding tables
extern void loadEmbedding(float *output, int token, int pos, int embeddingSize, float* wte, float* wpe);
// Prototype only; the definition is in another translation unit.
// NOTE(review): inferred from names -- presumably normalizes `input` over
// `embedding_dim` elements, applies scale `gamma` and shift `beta`, and writes
// the result to `output`. Confirm against the definition, including whether
// `output` may alias `input`.
extern void layerNorm(float* output, int embedding_dim, float* gamma, float* beta, float* input);
// Prototype only; the definition is in another translation unit.
// NOTE(review): inferred from names -- presumably projects `xbuf` through
// `attn_weight` / `attn_bias` into a query (`qbuf`) and a key/value entry
// stored in `kvbuf` at slot `kv_idx`. Buffer sizes and the kvbuf layout are
// not visible from this header; confirm against the definition.
extern void qkv(int kv_idx, float* xbuf, float *qbuf, float* kvbuf,
float* attn_weight, float* attn_bias, int embedding_dim);
// Prototype only; the definition is in another translation unit.
// NOTE(review): by its name this looks like a matrix-vector product,
// presumably y = A*x + b. Which of `m` / `k` is the row count, the storage
// order of `A`, and whether `b` may be NULL are not visible here -- confirm
// against the definition.
extern void gemv(float *y, float *A, float *x, float *b, int m, int k);
// Prototype only; the definition is in another translation unit.
// Takes the same parameter list as gemv.
// NOTE(review): the "Sum" suffix suggests the result is accumulated into `y`
// (presumably y += A*x + b) rather than overwriting it -- confirm against the
// definition.
extern void gemvSum(float *y, float *A, float *x, float *b, int m, int k);
// Prototype only; the definition is in another translation unit.
// Takes the same parameter list as gemv.
// NOTE(review): the "Gelu" suffix suggests a GELU activation is applied to the
// matrix-vector result (presumably y = gelu(A*x + b)); which GELU variant
// (exact or tanh approximation) is not visible here -- confirm against the
// definition.
extern void gemvGelu(float *y, float *A, float *x, float *b, int m, int k);
// Prototype only; the definition is in another translation unit.
// NOTE(review): inferred from names -- presumably an attention step that reads
// the query in `qbuf` and cached keys/values in `kvbuf` (entries up to
// `kv_idx`), split across `num_heads` heads of an `emb_siz`-wide embedding,
// and writes its result to `xbuf`. Confirm against the definition.
extern void attn(int kv_idx, float *xbuf, float *qbuf, float *kvbuf, int emb_siz, int num_heads);
// Prototype only; the definition is in another translation unit.
// Takes the same parameter list as attn.
// NOTE(review): how this differs from attn (alternative implementation,
// second pass, etc.) cannot be determined from this header -- check the
// definition before choosing between the two.
extern void attn2(int kv_idx, float *xbuf, float *qbuf, float *kvbuf, int emb_siz, int num_heads);