Skip to content

Commit 1be2b8c

Browse files
authored
ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275)
ggml-ci
1 parent 2f3a46f commit 1be2b8c

File tree

3 files changed

+80
-12
lines changed

3 files changed

+80
-12
lines changed

ggml.c

+49-1
Original file line numberDiff line numberDiff line change
@@ -6343,7 +6343,7 @@ static struct ggml_tensor * ggml_cpy_impl(
63436343
}
63446344

63456345
// make a view of the destination
6346-
struct ggml_tensor * result = b->op == GGML_OP_NONE ? b : ggml_view_tensor(ctx, b);
6346+
struct ggml_tensor * result = ggml_view_tensor(ctx, b);
63476347
if (strlen(b->name) > 0) {
63486348
ggml_format_name(result, "%s (copy of %s)", b->name, a->name);
63496349
} else {
@@ -6406,6 +6406,54 @@ struct ggml_tensor * ggml_cont_inplace(
64066406
return ggml_cont_impl(ctx, a, true);
64076407
}
64086408

6409+
6410+
// make contiguous, with new shape
6411+
GGML_API struct ggml_tensor * ggml_cont_1d(
6412+
struct ggml_context * ctx,
6413+
struct ggml_tensor * a,
6414+
int64_t ne0) {
6415+
return ggml_cont_4d(ctx, a, ne0, 1, 1, 1);
6416+
}
6417+
6418+
GGML_API struct ggml_tensor * ggml_cont_2d(
6419+
struct ggml_context * ctx,
6420+
struct ggml_tensor * a,
6421+
int64_t ne0,
6422+
int64_t ne1) {
6423+
return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1);
6424+
}
6425+
6426+
GGML_API struct ggml_tensor * ggml_cont_3d(
6427+
struct ggml_context * ctx,
6428+
struct ggml_tensor * a,
6429+
int64_t ne0,
6430+
int64_t ne1,
6431+
int64_t ne2) {
6432+
return ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1);
6433+
}
6434+
6435+
struct ggml_tensor * ggml_cont_4d(
6436+
struct ggml_context * ctx,
6437+
struct ggml_tensor * a,
6438+
int64_t ne0,
6439+
int64_t ne1,
6440+
int64_t ne2,
6441+
int64_t ne3) {
6442+
GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3));
6443+
6444+
bool is_node = false;
6445+
6446+
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
6447+
ggml_format_name(result, "%s (cont)", a->name);
6448+
6449+
result->op = GGML_OP_CONT;
6450+
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6451+
result->src[0] = a;
6452+
6453+
return result;
6454+
}
6455+
6456+
64096457
// ggml_reshape
64106458

64116459
struct ggml_tensor * ggml_reshape(

ggml.h

+27-1
Original file line numberDiff line numberDiff line change
@@ -1049,7 +1049,6 @@ extern "C" {
10491049
size_t nb1,
10501050
size_t offset);
10511051

1052-
10531052
// a -> b, return view(b)
10541053
GGML_API struct ggml_tensor * ggml_cpy(
10551054
struct ggml_context * ctx,
@@ -1072,6 +1071,33 @@ extern "C" {
10721071
struct ggml_context * ctx,
10731072
struct ggml_tensor * a);
10741073

1074+
// make contiguous, with new shape
1075+
GGML_API struct ggml_tensor * ggml_cont_1d(
1076+
struct ggml_context * ctx,
1077+
struct ggml_tensor * a,
1078+
int64_t ne0);
1079+
1080+
GGML_API struct ggml_tensor * ggml_cont_2d(
1081+
struct ggml_context * ctx,
1082+
struct ggml_tensor * a,
1083+
int64_t ne0,
1084+
int64_t ne1);
1085+
1086+
GGML_API struct ggml_tensor * ggml_cont_3d(
1087+
struct ggml_context * ctx,
1088+
struct ggml_tensor * a,
1089+
int64_t ne0,
1090+
int64_t ne1,
1091+
int64_t ne2);
1092+
1093+
GGML_API struct ggml_tensor * ggml_cont_4d(
1094+
struct ggml_context * ctx,
1095+
struct ggml_tensor * a,
1096+
int64_t ne0,
1097+
int64_t ne1,
1098+
int64_t ne2,
1099+
int64_t ne3);
1100+
10751101
// return view(a), b specifies the new shape
10761102
// TODO: when we start computing gradient, make a copy instead of view
10771103
GGML_API struct ggml_tensor * ggml_reshape(

llama.cpp

+4-10
Original file line numberDiff line numberDiff line change
@@ -2893,9 +2893,7 @@ static struct ggml_cgraph * llm_build_llama(
28932893
ggml_set_name(KQV_merged, "KQV_merged");
28942894

28952895
// cur = KQV_merged.contiguous().view(n_embd, n_tokens)
2896-
cur = ggml_cpy(ctx0,
2897-
KQV_merged,
2898-
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens));
2896+
cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
28992897
offload_func_v(cur);
29002898
ggml_set_name(cur, "KQV_merged_contiguous");
29012899

@@ -3302,9 +3300,7 @@ static struct ggml_cgraph * llm_build_baichaun(
33023300
ggml_set_name(KQV_merged, "KQV_merged");
33033301

33043302
// cur = KQV_merged.contiguous().view(n_embd, n_tokens)
3305-
cur = ggml_cpy(ctx0,
3306-
KQV_merged,
3307-
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens));
3303+
cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
33083304
offload_func_v(cur);
33093305
ggml_set_name(cur, "KQV_merged_contiguous");
33103306

@@ -3710,7 +3706,7 @@ static struct ggml_cgraph * llm_build_falcon(
37103706
offload_func_v(KQV_merged);
37113707
ggml_set_name(KQV_merged, "KQV_merged");
37123708

3713-
cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens));
3709+
cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
37143710
offload_func_v(cur);
37153711
ggml_set_name(cur, "KQV_merged_contiguous");
37163712

@@ -3964,9 +3960,7 @@ static struct ggml_cgraph * llm_build_starcoder(
39643960
ggml_set_name(KQV_merged, "KQV_merged");
39653961

39663962
// cur = KQV_merged.contiguous().view(n_embd, n_tokens)
3967-
cur = ggml_cpy(ctx0,
3968-
KQV_merged,
3969-
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens));
3963+
cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
39703964
ggml_set_name(cur, "KQV_merged_contiguous");
39713965
}
39723966

0 commit comments

Comments
 (0)