
Commit 7ec2df6

Added: tri, cumsum. Still a mess.

1 parent 6d0ad37 commit 7ec2df6

File tree: 8 files changed, +996 −205 lines

ggml/include/ggml.h (39 additions, 0 deletions)
@@ -243,6 +243,8 @@
 
 #define GGML_MROPE_SECTIONS 4
 
+#define GGML_DELTA_NET_CHUNK 64
+
 #define GGML_UNUSED(x) (void)(x)
 #ifdef __CUDACC__
 template<typename... Args>
@@ -472,6 +474,7 @@ extern "C" {
         GGML_OP_COS,
         GGML_OP_SUM,
         GGML_OP_SUM_ROWS,
+        GGML_OP_CUMSUM,
         GGML_OP_MEAN,
         GGML_OP_ARGMAX,
         GGML_OP_COUNT_EQUAL,
@@ -527,6 +530,7 @@ extern "C" {
         GGML_OP_TIMESTEP_EMBEDDING,
         GGML_OP_ARGSORT,
         GGML_OP_LEAKY_RELU,
+        GGML_OP_TRI,
 
         GGML_OP_FLASH_ATTN_EXT,
         GGML_OP_FLASH_ATTN_BACK,
@@ -539,6 +543,7 @@ extern "C" {
         GGML_OP_RWKV_WKV6,
         GGML_OP_GATED_LINEAR_ATTN,
         GGML_OP_RWKV_WKV7,
+        GGML_OP_DELTA_NET,
 
         GGML_OP_UNARY,
 
@@ -612,6 +617,13 @@ extern "C" {
         GGML_TENSOR_FLAG_LOSS   =  8, // ...defines loss for numerical optimization (multiple loss tensors add up)
     };
 
+    enum ggml_tri_type {
+        GGML_TRI_TYPE_UPPER_DIAG = 0,
+        GGML_TRI_TYPE_UPPER      = 1,
+        GGML_TRI_TYPE_LOWER_DIAG = 2,
+        GGML_TRI_TYPE_LOWER      = 3
+    };
+
     struct ggml_init_params {
         // memory pool
         size_t mem_size;   // bytes
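
Reading the enum names together with the comment on ggml_tri further down ("upper, upper + diagonal, lower or lower + diagonal"), the plausible mapping is that GGML_TRI_TYPE_UPPER and GGML_TRI_TYPE_LOWER select the strict triangles, excluding the main diagonal, while the *_DIAG variants include it. The commit itself adds no documentation for the enum.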
@@ -975,6 +987,10 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_cumsum(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // mean along rows
     GGML_API struct ggml_tensor * ggml_mean(
             struct ggml_context * ctx,
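
The declaration sits directly after ggml_sum_rows and takes the same arguments, which suggests a running (prefix) sum along each row. A minimal CPU-side sketch, assuming ggml_cumsum computes an inclusive prefix sum (the header does not document the semantics); ggml_graph_compute_with_ctx is the helper declared in ggml-cpu.h:

#include <stdio.h>

#include "ggml.h"
#include "ggml-cpu.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // one row of four elements: 1 2 3 4
    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    float * in = (float *) a->data;
    in[0] = 1.0f; in[1] = 2.0f; in[2] = 3.0f; in[3] = 4.0f;

    struct ggml_tensor * c = ggml_cumsum(ctx, a);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 1);

    // expected under the inclusive-prefix-sum assumption: 1 3 6 10
    const float * out = (const float *) c->data;
    for (int i = 0; i < 4; ++i) {
        printf("%g ", out[i]);
    }
    printf("\n");

    ggml_free(ctx);
    return 0;
}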
@@ -2119,6 +2135,17 @@ extern "C" {
             int                   shift2,
             int                   shift3);
 
+    // Make matrix into a triangular one (upper, upper + diagonal, lower or lower + diagonal) with constant value
+    GGML_API struct ggml_tensor * ggml_tri(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 constant,
+            enum ggml_tri_type    tritype);
+
+    GGML_API struct ggml_tensor * ggml_tri_keep(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_tri_type    tritype);
 
     // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
     // timesteps: [N,]
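
Going by the comment, ggml_tri makes the result triangular with the selected triangle holding `constant`; ggml_tri_keep, which takes no constant, presumably keeps a's own values inside the triangle instead. A hedged sketch of the obvious use, a lower-triangular ones mask of the kind a causal kernel would consume:

#include "ggml.h"

// Hedged sketch: assumes ggml_tri writes `constant` into the selected
// triangle and zeros everything else, per the header comment; not
// verified against the kernel.
static struct ggml_tensor * lower_ones_mask(struct ggml_context * ctx, int64_t n) {
    struct ggml_tensor * m = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n, n);
    // 1.0f on and below the main diagonal, 0.0f above it
    return ggml_tri(ctx, m, 1.0f, GGML_TRI_TYPE_LOWER_DIAG);
}

Under the same reading, ggml_tri_keep(ctx, a, GGML_TRI_TYPE_LOWER_DIAG) would be the masking variant: keep a's lower triangle, zero the rest. Again an assumption; neither function is documented beyond the one comment.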
@@ -2289,6 +2316,18 @@ extern "C" {
             struct ggml_tensor  * b,
             struct ggml_tensor  * state);
 
+    GGML_API struct ggml_tensor * ggml_delta_net(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * g,
+            struct ggml_tensor  * beta,
+            struct ggml_tensor  * state,
+            bool                  use_qk_l2norm,
+            float                 scale,
+            float                 eps_norm);
+
     // custom operators
 
     typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
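
Nothing in the commit documents these parameters, but the op is grouped with GGML_OP_GATED_LINEAR_ATTN and GGML_OP_RWKV_WKV7, and the names line up with the gated delta rule: q/k/v are per-token projections, g a gate/decay term, beta the per-token update strength, and state the recurrent state carried across calls. A hypothetical wrapper, with every scalar a guess rather than a default taken from the commit:

#include "ggml.h"

// Hypothetical wrapper around the new op; flag and scalar values are
// assumptions, and the expected tensor shapes are not documented here.
static struct ggml_tensor * delta_net_step(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,
        struct ggml_tensor  * k,
        struct ggml_tensor  * v,
        struct ggml_tensor  * g,
        struct ggml_tensor  * beta,
        struct ggml_tensor  * state) {
    return ggml_delta_net(ctx, q, k, v, g, beta, state,
            /*use_qk_l2norm=*/ true,   // name suggests optional L2-norm of q/k
            /*scale=*/         1.0f,   // attention-style scaling on q
            /*eps_norm=*/      1e-6f); // epsilon for that normalization
}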

ggml/src/ggml-cpu/ggml-cpu.c (15 additions, 0 deletions)
@@ -1731,6 +1731,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_sum_rows(params, tensor);
             } break;
+        case GGML_OP_CUMSUM:
+            {
+                ggml_compute_forward_cumsum(params, tensor);
+            } break;
         case GGML_OP_MEAN:
             {
                 ggml_compute_forward_mean(params, tensor);
@@ -1943,6 +1947,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_leaky_relu(params, tensor);
             } break;
+        case GGML_OP_TRI:
+            {
+                ggml_compute_forward_tri(params, tensor);
+            } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 ggml_compute_forward_flash_attn_ext(params, tensor);
@@ -1998,6 +2006,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv7(params, tensor);
             } break;
+        case GGML_OP_DELTA_NET:
+            {
+                ggml_compute_forward_delta_net_f32(params, tensor);
+            } break;
         case GGML_OP_MAP_CUSTOM1:
             {
                 ggml_compute_forward_map_custom1(params, tensor);
@@ -2153,6 +2165,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_SUM_ROWS:
        case GGML_OP_MEAN:
         case GGML_OP_ARGMAX:
+        case GGML_OP_CUMSUM:
+        case GGML_OP_TRI:
             {
                 n_tasks = 1;
             } break;
@@ -2297,6 +2311,7 @@
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
+        case GGML_OP_DELTA_NET:
             {
                 n_tasks = 1;
             } break;
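
Note that all three new ops land in the n_tasks = 1 buckets, so the CPU scheduler runs each of them on a single thread, like the other ops grouped in these case lists; any parallelism would have to live inside the kernels themselves.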
