@@ -18123,36 +18123,75 @@ void ggml_build_backward_gradient_checkpointing(
1812318123 ggml_hash_map_free(replacements);
1812418124}
1812518125
18126- // functions to change gradients considering the case that input a might be initial gradient with zero value
18126+ // utility functions to change gradients
18127+ // by default, just add/subtract/etc. the gradients
18128+ // if a is in zero_table and not a gradient accumulator, replace a
18129+ // if a is in zero_table and a gradient accumulator, modify gradients in-place and mark result as gradient accumulator
1812718130
1812818131static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
1812918132 if (ggml_hash_contains(zero_table, a)) {
18130- return b;
18133+ if (a->flags & GGML_TENSOR_FLAG_GRAD_ACC) {
18134+ struct ggml_tensor * ret = ggml_add_impl(ctx, a, b, true);
18135+ ret->flags |= GGML_TENSOR_FLAG_GRAD_ACC;
18136+ const size_t insert_result = ggml_hash_insert(zero_table, ret);
18137+ GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
18138+ GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
18139+ return ret;
18140+ } else {
18141+ return b;
18142+ }
1813118143 } else {
1813218144 return ggml_add_impl(ctx, a, b, false);
1813318145 }
1813418146}
1813518147
1813618148static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set * zero_table) {
1813718149 if (ggml_hash_contains(zero_table, a)) {
18138- struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f);
18139- return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
18150+ if (a->flags & GGML_TENSOR_FLAG_GRAD_ACC) {
18151+ struct ggml_tensor * ret = ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
18152+ ret->flags |= GGML_TENSOR_FLAG_GRAD_ACC;
18153+ const size_t insert_result = ggml_hash_insert(zero_table, ret);
18154+ GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
18155+ GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
18156+ return ret;
18157+ } else {
18158+ struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
18159+ return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
18160+ }
1814018161 } else {
1814118162 return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
1814218163 }
1814318164}
1814418165
1814518166static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
1814618167 if (ggml_hash_contains(zero_table, a)) {
18147- return ggml_repeat(ctx, b, a);
18168+ if (a->flags & GGML_TENSOR_FLAG_GRAD_ACC) {
18169+ struct ggml_tensor * ret = ggml_add1_impl(ctx, a, b, true);
18170+ ret->flags |= GGML_TENSOR_FLAG_GRAD_ACC;
18171+ const size_t insert_result = ggml_hash_insert(zero_table, ret);
18172+ GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
18173+ GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
18174+ return ret;
18175+ } else {
18176+ return ggml_repeat(ctx, b, a);
18177+ }
1814818178 } else {
1814918179 return ggml_add1_impl(ctx, a, b, false);
1815018180 }
1815118181}
1815218182
1815318183static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
1815418184 if (ggml_hash_contains(zero_table, a)) {
18155- return ggml_neg(ctx, b);
18185+ if (a->flags & GGML_TENSOR_FLAG_GRAD_ACC) {
18186+ struct ggml_tensor * ret = ggml_sub_impl(ctx, a, b, true);
18187+ ret->flags |= GGML_TENSOR_FLAG_GRAD_ACC;
18188+ const size_t insert_result = ggml_hash_insert(zero_table, ret);
18189+ GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
18190+ GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
18191+ return ret;
18192+ } else {
18193+ return ggml_neg(ctx, b);
18194+ }
1815618195 } else {
1815718196 return ggml_sub_impl(ctx, a, b, false);
1815818197 }
@@ -19136,22 +19175,25 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
1913619175 }
1913719176 }
1913819177
19139- // hash table of original gradients that should be overwritten instead of incremented
19178+ // keep table of original gradients for replacement/accumulation logic
1914019179 struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size);
19180+ for (int i = 0; i < gf->n_nodes; i++) {
19181+ struct ggml_tensor * node = gf->nodes[i];
1914119182
19142- // when accumulating gradients the table is empty -> gradients always incremented
19143- if (!accumulate) {
19144- for (int i = 0; i < gf->n_nodes; i++) {
19145- if (gf->grads[i]) {
19146- ggml_hash_insert(&zero_table, gf->grads[i]);
19183+ if (node->grad) {
19184+ // only gradients of trainable parameters should be accumulated
19185+ if (accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
19186+ node->grad->flags |= GGML_TENSOR_FLAG_GRAD_ACC;
1914719187 }
19188+
19189+ ggml_hash_insert(&zero_table, node->grad);
1914819190 }
1914919191 }
1915019192
1915119193 for (int i = gf->n_nodes - 1; i >= 0; i--) {
1915219194 struct ggml_tensor * node = gf->nodes[i];
1915319195
19154- // inplace operations to add gradients are not created by ggml_compute_backward
19196+ // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
1915519197 // use allocator to automatically make inplace operations
1915619198 if (node->grad) {
1915719199 ggml_compute_backward(ctx, node, &zero_table);
@@ -19319,19 +19361,18 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
1931919361
1932019362 for (int i = 0; i < cgraph->n_nodes; i++) {
1932119363 struct ggml_tensor * node = cgraph->nodes[i];
19322- struct ggml_tensor * grad = cgraph->grads[i];
1932319364
1932419365 // initial gradients of loss should be 1, 0 otherwise
19325- if (grad) {
19366+ if (node->grad) {
1932619367 if (node->flags & GGML_TENSOR_FLAG_LOSS) {
19327- GGML_ASSERT(grad->buffer);
19368+ GGML_ASSERT(node->grad->buffer);
1932819369 GGML_ASSERT(node->type == GGML_TYPE_F32);
1932919370 GGML_ASSERT(ggml_is_scalar(node));
1933019371
1933119372 const float onef = 1.0f;
19332- ggml_backend_tensor_set(grad, &onef, 0, ggml_nbytes(grad));
19373+ ggml_backend_tensor_set(node->grad, &onef, 0, ggml_nbytes(node->grad));
1933319374 } else {
19334- ggml_set_zero(grad);
19375+ ggml_set_zero(node->grad);
1933519376 }
1933619377 }
1933719378
0 commit comments