ggml/examples: add backend support for numerical optimization #949
Actually, a much bigger issue is that for …
I don't think that's a problem. ggml-backend was designed to not require many changes to the core ggml code, but since then I think it has become the standard way to use ggml, and it doesn't make much sense to maintain the subset of the API that only works with the CPU backend. We should move all the CPU backend code to a separate file, and make all the core ggml functions explicitly compatible with ggml-backend. The design looks good to me. Something to consider is that to support multiple GPUs and fallback to CPU for unimplemented ops in the backends, it is necessary to use ggml_backend_sched.
Force-pushed from 0fc3efe to ed5cde0.
I forgot: the current code also has an extension to the GGML backend interface with …
The Adam optimizer needs to know the current iteration since it does a warmup. I'm currently passing this information via …
I pushed a working prototype for CUDA MNIST training/evaluation (fully connected only). Compared to PyTorch the training on my RTX 3090 is ~45x faster (1.25s vs. 56.58s), but with such a small model you're basically just measuring overhead. The CUDA evaluation is actually slower than the CPU evaluation, presumably because the model is too small to make GPU acceleration worthwhile given the additional overhead. One issue that I still have is how to handle the combination of GGUF + backends other than CPU. Right now I'm allocating a temporary context that just stores the data in RAM, but it feels kind of clunky. Is there a better way to do this?
Check the way the magika example does this: make a …
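For reference, a rough sketch of that loading pattern (not the magika example's actual code; the flow and helper name are assumptions, and error handling is omitted): read the GGUF metadata with no_alloc, allocate the tensors in a backend buffer, then stream the data from the file into the buffer.

```cpp
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

#include <cstdio>
#include <vector>

// Sketch: load all tensors from a GGUF file directly into a backend buffer,
// without keeping a long-lived CPU-side copy of the weights.
static bool load_gguf_to_backend(const char * fname, ggml_backend_t backend,
                                 struct ggml_context ** ctx_out, ggml_backend_buffer_t * buf_out) {
    struct ggml_context * ctx_meta = nullptr; // holds only tensor metadata (no_alloc == true)
    struct gguf_init_params params = {
        /*.no_alloc =*/ true,
        /*.ctx      =*/ &ctx_meta,
    };
    struct gguf_context * ctx_gguf = gguf_init_from_file(fname, params);
    if (ctx_gguf == nullptr) {
        return false;
    }

    // allocate all tensors of the metadata context in a single buffer of the target backend
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx_meta, backend);

    // stream the tensor data from the file into the backend buffer
    FILE * f = std::fopen(fname, "rb");
    std::vector<char> tmp;
    for (int i = 0; i < gguf_get_n_tensors(ctx_gguf); ++i) {
        const char         * name = gguf_get_tensor_name(ctx_gguf, i);
        struct ggml_tensor * t    = ggml_get_tensor(ctx_meta, name);
        const size_t offs = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i);
        tmp.resize(ggml_nbytes(t));
        std::fseek(f, (long) offs, SEEK_SET);
        const size_t n_read = std::fread(tmp.data(), 1, tmp.size(), f);
        ggml_backend_tensor_set(t, tmp.data(), 0, n_read);
    }
    std::fclose(f);
    gguf_free(ctx_gguf);

    *ctx_out = ctx_meta;
    *buf_out = buf;
    return true;
}
```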
src/ggml.c (Outdated)
```c
for (int i = 0; i < gf->n_nodes; i++) {
    struct ggml_tensor * node = gf->nodes[i];

    if (node->flags & GGML_TENSOR_FLAG_PARAM) {
        GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
        struct ggml_tensor * opt_step = ggml_opt_step_adam(ctx, node, 1.0f, 0.001f, 0.9f, 0.999f, 1e-8f);
        ggml_build_forward_expand(gb, opt_step);
    }
}
```
The overall design that I envision is that the optimizer is specified when creating the backwards graph. If no optimizer is specified, calculate the gradients without touching the weights. If an optimizer is specified, apply it to all parameters after the gradients have been calculated by adding an extra GGML op on top (could probably be optimized to overwrite gradients that are no longer needed).
Purely from an API PoV, it might be better to have separate calls that expand the graph with an optimizer computation (e.g. ggml_build_opt or something similar) that can optionally be called after ggml_build_backward.
Do you have an opinion on what to do with the current ggml_opt API? If we keep it, the addition of tensors for optimization could be done in ggml_opt_init.
I'm not really sure how good the design of the existing ggml_opt API is. I think we can afford to change it significantly, since it is not really adopted by other projects. We can even implement a new API in parallel and, when we know which one is better, remove the other. ggml_opt_init adding the optimization graph sounds OK to me. The question is whether there would be use cases where you would want to create an optimizer, but not immediately "apply" it to a graph. Maybe you might want to apply the same optimizer to multiple graphs? If there is some case like this, then a separate ggml_build_opt-like step might make sense.
I think it's definitely better to have a separate call for adding the optimizer. That way gradient accumulation can be implemented relatively easily by defining one graph that calculates just the gradients and one that also invokes the optimizer.
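As a sketch of what that two-graph setup could look like at the call site: the per-parameter loop below stands in for the proposed ggml_build_opt-style call, ggml_opt_step_adam is the new op from this PR (arguments taken from the snippet above), and the exact ggml_build_backward_expand signature may differ between ggml versions.

```cpp
#include "ggml.h"

// Sketch: derive two backward graphs from the same forward graph, assuming a scalar
// loss tensor built in ctx. gb_grad only accumulates gradients (used for all but the
// last micro-batch); gb_opt additionally applies the optimizer (used once per logical batch).
static void build_training_graphs(struct ggml_context * ctx, struct ggml_tensor * loss,
                                  struct ggml_cgraph ** gb_grad_out, struct ggml_cgraph ** gb_opt_out) {
    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
    ggml_build_forward_expand(gf, loss);

    // graph 1: forward pass + gradients only
    struct ggml_cgraph * gb_grad = ggml_graph_dup(ctx, gf);
    ggml_build_backward_expand(ctx, gf, gb_grad, /*keep =*/ true);

    // graph 2: the same, plus one optimizer step per parameter
    struct ggml_cgraph * gb_opt = ggml_graph_dup(ctx, gb_grad);
    for (int i = 0; i < gf->n_nodes; ++i) {
        struct ggml_tensor * node = gf->nodes[i];
        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
            struct ggml_tensor * opt_step = ggml_opt_step_adam(ctx, node, 1.0f, 0.001f, 0.9f, 0.999f, 1e-8f);
            ggml_build_forward_expand(gb_opt, opt_step);
        }
    }

    *gb_grad_out = gb_grad;
    *gb_opt_out  = gb_opt;
}
```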
include/ggml-backend.h (Outdated)

```diff
@@ -234,6 +234,7 @@ extern "C" {
     GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
     GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);

+    GGML_API bool ggml_backend_load_from_gguf(const char * fname, struct ggml_context * ctx_ggml, struct gguf_context * ctx_gguf);
```
Is this temporary? It seems like it does not belong in ggml-backend; it's more like a utility function in user code.
What API should we use long-term for loading data from GGUF? I was thinking that since the pattern for tensors created in program code is initialization -> backend allocation -> data setting this would be the equivalent way to do it for GGUF.
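For context, the pattern referred to for program-created tensors looks roughly like this (a minimal sketch; tensor shape and names are placeholders):

```cpp
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

#include <vector>

int main() {
    // 1. initialization: create tensor metadata in a no_alloc context
    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 8,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);
    struct ggml_tensor  * w   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 784, 500);

    // 2. backend allocation: place the tensor data in a backend buffer
    ggml_backend_t backend = ggml_backend_cpu_init(); // or a GPU backend
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    // 3. data setting: upload the values through the backend interface
    std::vector<float> data(784*500, 0.0f);
    ggml_backend_tensor_set(w, data.data(), 0, ggml_nbytes(w));

    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
    ggml_free(ctx);
    return 0;
}
```

The question above is essentially whether loading from GGUF should follow the same three steps, with the "data setting" part reading from the file instead of from program memory.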
There was some talk in llama.cpp about moving some of the loading code to ggml, including mmap support, so that other ggml applications can benefit from it. I am not sure how that API should look, though. It may be good to add this as a first step, but most likely it will need a different API to be able to achieve all the goals.
Force-pushed from 09d1703 to 397f617.
I think for …
Force-pushed from edefe47 to efaf8e5.
Force-pushed from efaf8e5 to 7dd2c94.
I get comparable results between PyTorch and GGML in terms of training loss when I add the following two modifications: disable dataset shuffling for PyTorch and set the GGML physical batch size to 1000. The latter is a bug, since by definition the physical batch size should have no effect beyond differences in rounding error. For the dataset shuffling I would have intuitively expected that this is only relevant for generalization, but it seems that it also improves the rate at which the model gets better on the training set.
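To illustrate the "rounding error only" point: summing per-micro-batch gradient sums and dividing once by the logical batch size is mathematically identical to averaging over the whole logical batch, so the physical batch size should only show up as float rounding noise. A small self-contained demonstration (values are random placeholders):

```cpp
#include <cstdio>
#include <random>
#include <vector>

int main() {
    const int n_logical  = 1000; // logical batch size
    const int n_physical = 100;  // physical (micro-)batch size

    std::mt19937 rng(1234);
    std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
    std::vector<float> grad(n_logical);
    for (float & g : grad) { g = dist(rng); }

    // single pass over the logical batch
    float acc_full = 0.0f;
    for (int i = 0; i < n_logical; ++i) { acc_full += grad[i]; }

    // accumulation over micro-batches
    float acc_micro = 0.0f;
    for (int i0 = 0; i0 < n_logical; i0 += n_physical) {
        float acc = 0.0f;
        for (int i = i0; i < i0 + n_physical; ++i) { acc += grad[i]; }
        acc_micro += acc;
    }

    std::printf("full: %.8f  accumulated: %.8f  diff: %g\n",
                acc_full / n_logical, acc_micro / n_logical, (acc_full - acc_micro) / n_logical);
    return 0;
}
```

Any difference beyond that last "diff" term points to a bug in how the gradients are accumulated or applied.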
I figured out the problem: I incorrectly assumed that the …
So if I understand correctly, the following call is basically a noop atm (ggml/examples/mnist/mnist-common.cpp, lines 563 to 566 in c1d13df). The reason is because … I tried your idea, which I think is simply to set the inplace argument in ggml_add_or_set to true:

```diff
diff --git a/src/ggml.c b/src/ggml.c
index de61438..483a3b2 100644
--- a/src/ggml.c
+++ b/src/ggml.c
@@ -18129,7 +18129,7 @@ static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct gg
     if (ggml_hash_contains(zero_table, a)) {
         return b;
     } else {
-        return ggml_add_impl(ctx, a, b, false);
+        return ggml_add_impl(ctx, a, b, true);
     }
 }
```

But it seems we are still missing something, as the training accuracy dropped: …
I've pushed a WIP fix that works but has bad performance to clarify the problem. The original gradients are initialized with zero and need to be incremented with the sum tensors after each accumulation step to get correct results. Unrelated to the problem with accumulation, there are also two other issues: the wrong graph was being copied for …
The problem is the upper branch where the tensor is in the zero table. In that case there needs to be an in-place addition instead of a replacement. But so far I have not been able to make that work, so there is likely still some other issue.
Force-pushed from 06e8adc to cac7aa1.
Sorry, the supposed fix had two bugs that happened to cancel each other out.
Force-pushed from cac7aa1 to 478472b.
I pushed a proper fix. The correct handling of gradient accumulation needs some extra bookkeeping to track the gradients of parameters and whether they should be accumulated; I added a new tensor flag for this.
Actually, now that I think about it, it would maybe be better to do this via a hash set instead of via tensor modification, since whether or not a gradient should be accumulated is a property of the compute graph rather than the gradient tensor. But by that logic the existing code in … Also: there was some inconsistent use of …
Going back to an earlier comment by @slaren:
> Something to consider is that to support multiple GPUs and fallback to CPU for unimplemented ops in the backends, it is necessary to use ggml_backend_sched.
Should we attempt to do this within this PR or after we merge the existing changes? Do we see any obstacles to achieving this?
Overall, I think the changes are quite good. I'm not familiar with other training codebases, so I'm not sure if we are missing something obvious from a functionality perspective.
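For readers unfamiliar with ggml_backend_sched, a rough sketch of what using it could look like is shown below. This is not code from this PR; the exact signatures (in particular of ggml_backend_sched_new) vary between ggml versions, and the CUDA backend is only an example.

```cpp
#include "ggml.h"
#include "ggml-backend.h"
#include "ggml-cuda.h"

// Sketch: let ggml_backend_sched split a graph between a GPU backend and the CPU
// backend, falling back to the CPU for ops the GPU backend does not implement.
static void compute_with_sched(struct ggml_cgraph * graph) {
    ggml_backend_t backends[2] = {
        ggml_backend_cuda_init(0), // higher-priority backend first
        ggml_backend_cpu_init(),   // CPU is the fallback and comes last
    };

    ggml_backend_sched_t sched = ggml_backend_sched_new(
        backends, /*bufts =*/ nullptr, /*n_backends =*/ 2,
        /*graph_size =*/ GGML_DEFAULT_GRAPH_SIZE, /*parallel =*/ false);

    // allocates the graph across the backends and runs it
    ggml_backend_sched_graph_compute(sched, graph);

    ggml_backend_sched_free(sched);
    for (ggml_backend_t b : backends) {
        ggml_backend_free(b);
    }
}
```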
My priorities: while … Right now I have a prototype for a dataset in user space.
The issue with …
Re asynchronous data loading: this may already be obvious, but you should look at …
The ultimate goal of this PR is to add backend support for numerical optimization, namely Adam and L-BFGS. As of right now the corresponding computations are done by a single thread outside any of the GGML graphs. As a consequence only a single thread is used and only the CPU backend is compatible. I think the correct way to remedy this is to make the optimizers part of the GGML compute graphs. This also fixes some allocation issues where the optimization code allocates extra tensors to hold persistent extra data for the optimizers.
As of right now this PR contains my WIP version that only supports stochastic gradient descent and the CPU backend. The training is ~3x faster than on master (but the overall rate of convergence is worse than fully featured Adam).
The overall design that I envision is that the optimizer is specified when creating the backwards graph. If no optimizer is specified, calculate the gradients without touching the weights. If an optimizer is specified, apply it to all parameters after the gradients have been calculated by adding an extra GGML op on top (could probably be optimized to overwrite gradients that are no longer needed). During backwards graph creation also specify any extra tensors needed for the optimizer so they can be correctly allocated for all backends. Functions like ggml_opt would then mainly call the backwards graph in a loop and check convergence. One potential issue is that the convergence logic would require calls to ggml_backend_tensor_get, which would make ggml.c depend on ggml-backend.c (which it currently does not). If that is a problem the optimization code could maybe be moved to a new file like ggml-algo.c. If there are issues with my design please let me know early.
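A condensed sketch of that loop, assuming the backward graph already ends in the optimizer ops (e.g. ggml_opt_step_adam); the convergence criterion and threshold are placeholders, not the intended final logic:

```cpp
#include <cmath>

#include "ggml.h"
#include "ggml-backend.h"

// Sketch of the envisioned optimization loop: gb is a backward graph that already
// contains the gradient computation plus the optimizer step, loss is the scalar
// loss tensor allocated in a backend buffer.
static void opt_loop_sketch(ggml_backend_t backend, struct ggml_cgraph * gb,
                            struct ggml_tensor * loss, int max_iters) {
    float loss_prev = INFINITY;
    for (int iter = 0; iter < max_iters; ++iter) {
        // forward pass, gradients, and optimizer step in a single graph evaluation
        ggml_backend_graph_compute(backend, gb);

        // the convergence check is the part that needs ggml_backend_tensor_get,
        // i.e. the part that would make ggml.c depend on ggml-backend
        float loss_val = 0.0f;
        ggml_backend_tensor_get(loss, &loss_val, 0, sizeof(loss_val));

        if (std::fabs(loss_prev - loss_val) < 1e-6f) { // placeholder convergence criterion
            break;
        }
        loss_prev = loss_val;
    }
}
```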