
ggml: new optimization interface #988

Merged: 4 commits into ggerganov:master on Nov 16, 2024

Conversation

@JohannesGaessler (Collaborator)

This PR adapts the training code from the MNIST example into GGML, with the goal of establishing a new interface for training models. The aim is to provide downstream projects with a common, higher-level training interface that can be tested and debugged more easily. The general design is procedural and relies on data structures for optimization contexts, datasets, and results.

As of right now, essentially only feed-forward classifiers are supported. I put the code into a new file, ggml-opt.cpp, with a corresponding new header, ggml-opt.h. One reason for this is that I am using some C++ functionality that is convenient but not performance-critical. Another reason is that with the current GGML code there is no need to mess with the internals of a GGML graph, so I think it makes sense to split functionality that will only be used by a subset of the user base into a separate header (the general impression I get is also that people find ggml.c hard to navigate due to its size).

There is still a lot to do, but I would like to get feedback on the interface early if possible. In particular, one thing that is still missing is testing code for the new interface. For now the prefix that I am using for the new interface is ggml_opt_new; I plan to change this to ggml_opt and remove the old ggml_opt code prior to merging.
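
To make the intended procedural flow concrete, here is a rough usage sketch. Everything below other than the ggml_opt_new prefix mentioned above is a hypothetical placeholder for illustration and does not claim to be the final API:

// Hypothetical sketch of the intended flow; the function names, signatures and
// numbers are placeholders, not the final interface.

// 1. Wrap the training data (inputs + labels) in a dataset structure.
ggml_opt_new_dataset_t dataset = ggml_opt_new_dataset_init(
    /*ne_datapoint =*/ 28*28, /*ne_label =*/ 10, /*ndata =*/ 60000);

// 2. Create an optimization context from the user-defined forward graph:
//    weights and inputs are allocated by the user, outputs are only defined.
ggml_opt_new_context_t opt_ctx = ggml_opt_new_init(
    backend, ctx_compute, inputs, outputs, GGML_OPT_NEW_LOSS_TYPE_CROSS_ENTROPY);

// 3. Train for a number of epochs; per-epoch statistics (loss, accuracy, ...)
//    are collected in a result structure.
ggml_opt_new_result_t result = ggml_opt_new_result_init();
for (int epoch = 0; epoch < 30; ++epoch) {
    ggml_opt_new_epoch(opt_ctx, dataset, result);
}

ggml_opt_new_result_free(result);
ggml_opt_new_free(opt_ctx);
ggml_opt_new_dataset_free(dataset);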

@slaren (Collaborator) commented Oct 11, 2024

How do you plan to support multiple GPUs? Currently this interface is taking a ggml_backend, so I think that would limit it to a single GPU.

@JohannesGaessler (Collaborator, Author)

For some reason ggml_backend_sched didn't work correctly, but I wasn't able to nail down why because the MNIST example is already relatively complex. My plan is to implement tests for this interface, do the transition to ggml_backend_sched, and use those tests for debugging. (I'm flexible regarding whether that should already happen in this PR or later.)

@slaren (Collaborator) commented Oct 14, 2024

I mentioned some possible solutions to the problem with ggml_backend_sched in the previous PR. I think that whatever solution is finally implemented, it will need to be factored into the design early on, because it is going to fundamentally affect the way the tensors are allocated, and it is not a detail that can be left until the last moment.

@JohannesGaessler (Collaborator, Author)

I added tests for the new optimization interface. I'll do the transition towards ggml_backend_sched next and then I'll try to get this PR into a state where it can be merged.

@JohannesGaessler (Collaborator, Author)

I adapted the new optimization interface to use ggml_backend_sched (see also the discussion in #999). The tests I wrote seem to work correctly, but the MNIST training results using CUDA have for some reason become worse compared to master, so there still seem to be issues. The interface should be mostly stable now, though.

@slaren (Collaborator) commented Nov 2, 2024

With some changes it can be used with BLAS and Metal. On M3 Max with BLAS it takes just 3 seconds to train, compared to 15 seconds with 3090 Ti CUDA or ~9 seconds with 13900k CPU.

ggml_opt_new_fit: epoch 0029:
train: [========================================| data=057000/057000, loss=0.011858+-0.000376, accuracy=99.86+-0.02%, t=00:00:00, ETA=00:00:00]
val:   [========================================| data=003000/003000, loss=0.065221+-0.011987, accuracy=97.80+-0.27%, t=00:00:00, ETA=00:00:00]

ggml_opt_new_fit: training took 00:00:03
diff --git a/examples/mnist/mnist-common.h b/examples/mnist/mnist-common.h
index 6e2d235..c2a4464 100644
--- a/examples/mnist/mnist-common.h
+++ b/examples/mnist/mnist-common.h
@@ -134,6 +134,17 @@ struct mnist_model {
             devices.push_back(dev);
         }

+        // add accel devices
+        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+            if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) {
+                ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+                GGML_ASSERT(backend);
+                backends.push_back(backend);
+                devices.push_back(dev);
+            }
+        }
+
         ggml_backend_dev_t dev_cpu = ggml_backend_dev_by_name("CPU");
         GGML_ASSERT(dev_cpu);
         ggml_backend_t backend_cpu = ggml_backend_dev_init(dev_cpu, nullptr);
@@ -151,12 +162,17 @@ struct mnist_model {
         if (backends.size() == 1) {
             fprintf(stderr, "%s: using %s (%s) backend\n",
                     __func__, ggml_backend_name(backends[0]), ggml_backend_dev_description(devices[0]));
-        } else if (backends.size() == 2) {
-            fprintf(stderr, "%s: using %s (%s) backend with %s (%s) fallback\n",
-                    __func__, ggml_backend_name(backends[0]), ggml_backend_dev_description(devices[0]),
-                    ggml_backend_name(backends[1]), ggml_backend_dev_description(devices[1]));
         } else {
-            GGML_ASSERT(false);
+
+            fprintf(stderr, "%s: using %s (%s) backend with fallbacks: ",
+                    __func__, ggml_backend_name(backends[0]), ggml_backend_dev_description(devices[0]));
+            for (size_t i = 1; i < backends.size(); ++i) {
+                fprintf(stderr, "%s (%s)", ggml_backend_name(backends[i]), ggml_backend_dev_description(devices[i]));
+                if (i + 1 < backends.size()) {
+                    fprintf(stderr, ", ");
+                }
+            }
+            fprintf(stderr, "\n");
         }

         {
diff --git a/src/ggml-metal.m b/src/ggml-metal.m
index fb2efc6..a9f35c7 100644
--- a/src/ggml-metal.m
+++ b/src/ggml-metal.m
@@ -3285,6 +3285,12 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
     return ctx->all_data;
 }

+static void ggml_backend_metal_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    memset((char *)tensor->data + offset, value, size);
+
+    UNUSED(buffer);
+}
+
 static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     memcpy((char *)tensor->data + offset, data, size);

@@ -3318,7 +3324,7 @@ static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_
     /* .free_buffer     = */ ggml_backend_metal_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_metal_buffer_get_base,
     /* .init_tensor     = */ NULL,
-    /* .memset_tensor   = */ NULL,
+    /* .memset_tensor   = */ ggml_backend_metal_buffer_memset_tensor,
     /* .set_tensor      = */ ggml_backend_metal_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_metal_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_metal_buffer_cpy_tensor,

@JohannesGaessler (Collaborator, Author)

I changed the MNIST code slightly to a version that I think is simpler. Am I right in assuming that it's unproblematic to initialize two backends for the same device and to then pass those backends to the same instance of ggml_backend_sched?

@JohannesGaessler (Collaborator, Author)

compared to 15 seconds with 3090 Ti CUDA

The performance seems very poor, but since the model is so small, that is basically just a measure of the overhead. I seem to remember that you are using WSL2, so maybe that has to do with it? On my machines (all running native Linux) I see the following performance:

Hardware                      | Time [s]
------------------------------|---------
RTX 4090                      | 3
RTX 3090                      | 4
P40                           | 5
Epyc 7742                     | 9
RX 6800                       | 14
Ryzen 5950X (65W power limit) | 17
Xeon E5-2683 v4               | 21
Thinkpad T16 Gen 1            | 30

Notably, the RX 6800 also performs much worse than the P40, even though with llama.cpp their performance is very similar.

@slaren (Collaborator) commented Nov 2, 2024

I changed the MNIST code slightly to a version that I think is simpler. Am I right in assuming that it's unproblematic to initialize two backends for the same device and to then pass those backends to the same instance of ggml_backend_sched?

It may waste some resources and make graph splitting a bit slower, but not much. Generally I don't think it is very useful to have multiple GPU backends; the CPU backend is usually a better fallback since the cost of copying the state is lower.

I seem to remember that you are using WSL2, so maybe that has to do with it?

Kernel launch overhead is higher on Windows (it's the same reason -sm row is so slow). I assume that it would be possible to remove nearly all of that using CUDA graphs, since it is the same graph being evaluated many times.

@JohannesGaessler (Collaborator, Author)

I removed the use of GGML graph exports from the MNIST example. In its current state the feature is fundamentally incompatible with the new interface because it relies on statically allocated CPU tensors (it would also be necessary to mess with the internals of the optimization context).

Currently the optimization interface works by having the user statically allocate the model weights and inputs and define the computation of the outputs without allocation. The optimization context then statically allocates tensors for e.g. the optimizer momenta and defines the backward pass without allocation. The unallocated tensors are then handed to ggml_backend_sched. I think the correct way to reintroduce graph exports would be to selectively save the data of only those tensors that were statically allocated by the user and to save only the metadata of the other forward-graph tensors. The logic needed to minimize disk space is essentially the same as what I have already implemented for allocating the forward/backward graphs.
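
For illustration, here is a minimal sketch of that allocation split using existing ggml context/allocation calls; the shapes, the batch size, and the already-initialized ggml_backend_t backend are assumptions made up for the example:

// Sketch: the user statically allocates weights and inputs, but only *defines*
// the outputs; everything left unallocated is later placed via ggml_backend_sched.
const int64_t batch_size = 500; // example value

struct ggml_init_params ip = {
    /*mem_size   =*/ ggml_tensor_overhead() * 16,
    /*mem_buffer =*/ nullptr,
    /*no_alloc   =*/ true, // only tensor metadata lives in the ggml_context
};
struct ggml_context * ctx_static  = ggml_init(ip); // weights + inputs
struct ggml_context * ctx_compute = ggml_init(ip); // outputs, never allocated here

struct ggml_tensor * weights = ggml_new_tensor_2d(ctx_static, GGML_TYPE_F32, 28*28, 10);
struct ggml_tensor * inputs  = ggml_new_tensor_2d(ctx_static, GGML_TYPE_F32, 28*28, batch_size);
ggml_backend_buffer_t buf_static = ggml_backend_alloc_ctx_tensors(ctx_static, backend); // static allocation

struct ggml_tensor * logits = ggml_mul_mat(ctx_compute, weights, inputs); // defined, not allocated

// The optimization context then statically allocates e.g. the optimizer momenta,
// defines the backward pass in the same unallocated fashion, and hands all
// remaining tensors to ggml_backend_sched for allocation.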

@slaren (Collaborator) commented Nov 2, 2024

After this is merged, can all the "opt" functions from ggml.h/ggml.c be removed, or are any of them still used? I am moving all the CPU-backend-specific code to a separate file, and it would be easier if I could just remove these functions, since they only work with the CPU backend.

@slaren (Collaborator) commented Nov 2, 2024

About the graph exports: I don't think these are used; it seems it was an experimental feature that never really took off. It may be better to remove these functions entirely. cc @ggerganov.

@ggerganov (Owner)

Yes, everything related to graph export should be removed.

@JohannesGaessler (Collaborator, Author) commented Nov 2, 2024

After this is merged, can all the "opt" functions from ggml.h/ggml.c be removed, or are any of them still used?

Actually, my plan was to nail down the features for the new interface, then remove the old ggml_opt functions and rename the new interface to ggml_opt (I just thought it would be easier that way). The ggml_opt functionality on master can already be removed ahead of time I think.

@JohannesGaessler (Collaborator, Author)

I fixed gradient accumulation and I think that this PR is now feature-complete and just needs the ggml_opt_new -> ggml_opt transition. @slaren, since you are currently also working on a change that would remove the old optimization interface, how should we coordinate this?

@slaren (Collaborator) commented Nov 2, 2024

I am almost done with the change, I was planning to open a PR later tonight. It's moving code around so there will be merge conflicts, but it should be fairly straightforward to resolve them since I am not changing the functions that you are modifying here.

@JohannesGaessler (Collaborator, Author)

Unless I'm forgetting something, I now have all the features I was targeting for this PR. After #1006 is merged, all that is left to do is to rebase the code and change the prefix from ggml_opt_new to ggml_opt.

@JohannesGaessler marked this pull request as ready for review on November 4, 2024, 22:08
@ggerganov (Owner)

Minor patch to clear some compile warnings with clang:

diff --git a/src/ggml-opt.cpp b/src/ggml-opt.cpp
index ec9bccd..a1fb512 100644
--- a/src/ggml-opt.cpp
+++ b/src/ggml-opt.cpp
@@ -635,7 +635,7 @@ void ggml_opt_epoch_callback_progress_bar(
     const int64_t t_eta_m = t_eta_s / 60;
     t_eta_s -= t_eta_m * 60;
 
-    fprintf(stderr, "| data=%06ld/%06ld, loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, t=%02ld:%02ld:%02ld, ETA=%02ld:%02ld:%02ld]\r",
+    fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r",
             idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc,
             t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s);
     if (ibatch == ibatch_max) {
@@ -712,7 +712,7 @@ void ggml_opt_fit(
         t_total_s -= t_total_h * 3600;
         const int64_t t_total_m = t_total_s / 60;
         t_total_s -= t_total_m * 60;
-        fprintf(stderr, "%s: training took %02ld:%02ld:%02ld\n", __func__, t_total_h, t_total_m, t_total_s);
+        fprintf(stderr, "%s: training took %02" PRId64 ":%02" PRId64 ":%02" PRId64 "\n", __func__, t_total_h, t_total_m, t_total_s);
     }
 
     ggml_opt_free(opt_ctx);

@JohannesGaessler (Collaborator, Author)

  1. Should we put me in ggml-opt.h as the maintainer for this module? I check llama.cpp and ggml issues for things relevant to my contributions at least once a day, but maybe this would make it easier for downstream users to identify the right person to ask for help?
  2. I think test2.c and test3.c can be removed. There is a similar convergence test in test-opt.cpp, and combined with the tests in test-backend-ops.cpp, any bug that would cause those tests to fail would already be covered.
  3. What is the reasoning behind e.g. typedef struct ggml_backend * ggml_backend_t;? Is it just to make the code more readable? Should something similar be done for the new structs I added?

@slaren (Collaborator) commented Nov 5, 2024

3. What is the reasoning behind e.g. typedef struct ggml_backend * ggml_backend_t;?

The reasoning is that they are opaque types and it is not relevant to the user whether they are structs or not. This is done with all the structs that are hidden from user code.
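
For reference, the pattern looks like this in isolation (a generic illustration of opaque handles, not code from this PR):

// Public header: the struct is only forward-declared; user code holds an
// opaque handle and never depends on the layout.
typedef struct my_object * my_object_t;

my_object_t my_object_init(void);
void        my_object_free(my_object_t obj);

// Implementation file: the definition stays private and can change freely
// without breaking user code that only uses the handle.
struct my_object {
    int internal_state;
};

my_object_t my_object_init(void) { return new my_object{}; }
void        my_object_free(my_object_t obj) { delete obj; }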

@ggerganov (Owner)

  1. Should we put me in ggml-opt.h as the maintainer for this module?

It's up to you. We can also add a CODEOWNERS file where maintainers can add themselves if they would like to be notified about PRs.

  2. I think test2.c and test3.c can be removed.

Sounds good.

@JohannesGaessler (Collaborator, Author)

I noticed that the carriage return for the progress bar only results in the expected animation-like behavior if the progress bar is short enough to fit the terminal; otherwise it only returns to the point where the line wraps and spams the terminal with one new line per minibatch. I just reduced the size of the progress bar, but maybe there is a better solution.
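
One possible alternative (not what was done here) would be to query the terminal width and size the bar to fit; a minimal POSIX-only sketch:

#include <sys/ioctl.h>
#include <unistd.h>

// Return a progress bar width that fits the current terminal, so that the
// trailing '\r' rewinds over the whole line instead of only the wrapped part.
static int progress_bar_width(int text_overhead) {
    struct winsize ws;
    int cols = 80; // fallback if stderr is not attached to a terminal
    if (ioctl(STDERR_FILENO, TIOCGWINSZ, &ws) == 0 && ws.ws_col > 0) {
        cols = ws.ws_col;
    }
    const int width = cols - text_overhead; // leave room for loss/accuracy/ETA text
    return width > 10 ? width : 10;
}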

@JohannesGaessler (Collaborator, Author) commented Nov 8, 2024

I pushed a version that stores the optimization parameters in the graph. They are allocated in a CPU buffer and written at the start of an eval. The CPU backend can use the parameters directly. I changed the CUDA backend to expect a device buffer with the parameters instead of passing them as kernel arguments. To change the optimization parameters from their defaults, users need to pass a custom function that calculates them.

Edit: no changes to ggml_backend_sched were necessary.
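
As a sketch of what such a user-supplied parameter callback could look like; the struct fields, the callback signature, and the schedule are illustrative assumptions rather than the actual ggml-opt interface:

#include <cmath>
#include <cstdint>

// Hypothetical parameter struct; field names are assumptions for illustration.
struct opt_pars_example {
    float alpha; // learning rate
    float beta1;
    float beta2;
    float eps;
    float wd;    // weight decay
};

// User-supplied callback returning the parameters; in the scheme described
// above it would be evaluated when the parameters are written at the start of
// an eval, so it can implement e.g. a cosine learning-rate schedule.
static opt_pars_example get_opt_pars_cosine(void * userdata) {
    const int64_t iter     = *(const int64_t *) userdata;
    const int64_t iter_max = 10000;
    opt_pars_example p = {1e-3f, 0.9f, 0.999f, 1e-8f, 0.0f};
    p.alpha *= 0.5f * (1.0f + std::cos(3.14159265f * (float) iter / (float) iter_max));
    return p;
}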

@JohannesGaessler (Collaborator, Author)

I did some fixup; from my end this PR is ready to merge. There is still the issue of refactoring the GGML code in such a way that ggml_tensor.grad can be removed, but that should not result in any externally visible changes. This is what I'll do next, but for me it doesn't matter whether it is done in this PR or another one, so I'll just go with whatever variant is easier for you to review.

I adapted the MNIST example README, and while doing so I noticed that the convolutional model can now be trained with partial CUDA support, which is faster than CPU only.

@JohannesGaessler (Collaborator, Author)

I pushed a refactor of the code around gradients. ggml_tensor.grad has been removed; the tensor <-> gradient mapping is now defined entirely by the compute graph. I don't have a good understanding of the gradient checkpointing code and I am not aware of any other code (particularly tests) that uses it, so I decided to remove it since it is now probably broken (I think something like this should be part of ggml-opt anyways). I removed test-grad0 and test1 since my impression is that those tests don't provide enough utility to justify maintaining them (but I have no problem adapting them if someone disagrees).

Currently, to get the gradient or gradient accumulator for a tensor there is a loop over ggml_cgraph.nodes, which results in $O(N)$ time to get a gradient (accumulator) and $O(N^2)$ time to build the backward graph. It may make sense to use a hash table here. This would be a relatively simple change in C++ but a bit tedious in C, so before touching that particular part of the code I would like to ask what the long-term plans for ggml.c are in terms of which language to use.

@JohannesGaessler (Collaborator, Author)

Actually, I think the hash table should be part of ggml_cgraph, but since that struct is defined in ggml-impl.h and presumably that header should remain C-compatible, wouldn't that imply that a C implementation is necessary anyway?

@slaren (Collaborator) commented Nov 14, 2024

The goal is to progressively port the code to C++, but modifying ggml_cgraph to require C++ may require updating a lot of code. If you only need a map from ggml_tensor * to node index, that should be possible with a small addition to the hash table that is already in ggml_cgraph: just add an array of values with the same size as the hash set and set it in ggml_visit_parents, e.g. value[ggml_hash(node)] = index.
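
A self-contained toy model of that idea (generic code, not the actual ggml_hash_set implementation): an open-addressing set of tensor pointers with linear probing, plus a parallel values array mapping each stored pointer to its node index.

#include <cstddef>
#include <cstdint>
#include <vector>

// Toy stand-in for a graph node; only pointer identity matters here.
struct tensor_stub { int id; };

// Open-addressing map from tensor pointer to node index. The table must be
// sized larger than the number of entries, otherwise probing a missing key
// would never hit an empty cell.
struct node_index_map {
    std::vector<const tensor_stub *> keys;   // nullptr marks an empty cell
    std::vector<int>                 values; // node indices, parallel to keys

    explicit node_index_map(size_t size) : keys(size, nullptr), values(size, -1) {}

    size_t find_slot(const tensor_stub * t) const {
        size_t i = (reinterpret_cast<uintptr_t>(t) >> 4) % keys.size(); // cheap pointer hash
        // Linear probing: stop at the matching key or at the first empty cell.
        while (keys[i] != nullptr && keys[i] != t) {
            i = (i + 1) % keys.size();
        }
        return i;
    }

    void set(const tensor_stub * t, int index) {
        const size_t i = find_slot(t);
        keys[i]   = t;
        values[i] = index;
    }

    int get(const tensor_stub * t) const {
        const size_t i = find_slot(t);
        return keys[i] == t ? values[i] : -1; // -1: tensor has no entry
    }
};

With the table sized somewhat larger than the number of graph nodes, a lookup is effectively constant time and building the backward graph becomes linear in the number of nodes.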

@JohannesGaessler (Collaborator, Author)

I did an implementation using hash sets, but I realized that with the current GGML hash sets building the backward pass would still take quadratic time. If a tensor is not contained in the hash set, ggml_hash_find will iterate over the entire set (and that will always be slower than the current implementation). To get a linear runtime for building the backward pass, the hash set would need an explicit value for unused cells that causes the function to return when encountered. But since the backward pass is only constructed once per instance of ggml_opt_ctx, I think this optimization would not be worthwhile and would just add unnecessary complexity. I would merge the backward pass construction as-is and revisit the issue if, at a later date, it turns out that the optimization is needed.

@slaren (Collaborator) commented Nov 14, 2024

If a tensor is not contained in the hash set, ggml_hash_find will iterate over the entire set (and that will always be slower than the current implementation).

Is this correct? It should only iterate until the first empty slot. That's just the way of dealing with collisions, but if the table is correctly sized, the number of collisions will be very close to zero.

@JohannesGaessler (Collaborator, Author)

You are absolutely right; looking at the code again, it seems I missed part of the condition in the while loop.

@JohannesGaessler (Collaborator, Author)

In any case, thank you for the clarification; my hash-set-based implementation (which I have now pushed) had a defect related to this misunderstanding that by chance did not manifest as a bug.

@JohannesGaessler (Collaborator, Author)

Are there still pending (re-)reviews or should we merge this?

@ggerganov (Owner)

Let me have a look now.

@ggerganov (Owner) left a review comment

Very cool stuff!

Before merging, I would like to first sync the backend split from llama.cpp and resolve the PR conflicts here. Otherwise, if we merge it now, I will have a much harder time resolving these conflicts through the sync scripts and the git-am commands, since some of the files have moved.

Would that be OK? The sync might be ready tonight, but more likely tomorrow.

@JohannesGaessler (Collaborator, Author)

From my end I am in no particular rush; I am definitely not running out of things to work on (even ignoring projects unrelated to llama.cpp/GGML). I just don't want to do more rebases than necessary, since this PR touches a lot of lines.

@ggerganov (Owner)

Should be ready to rebase and merge.

remove test2.c, test3.c

store adamw params in tensor

move grads from tensor to graph
@JohannesGaessler (Collaborator, Author)

I noticed that the memory for the gradient (accumulator) pointers is not explicitly cleared upon graph creation, so it was possible to provoke a segfault via API misuse. This is fixed.

There also seem to be build issues on Apple where ggml-cpu.o is not available when linking ggml-opt.o, which causes compilation to fail. I don't understand why this is happening and would appreciate help.

(I also noticed that ggml-opt.h was missing from the list of public headers, but this seems to be completely unrelated.)

@JohannesGaessler merged commit 0ce2226 into ggerganov:master on Nov 16, 2024
4 checks passed
lyapple2008 pushed a commit to lyapple2008/ggml_mars that referenced this pull request Nov 20, 2024
* ggml: new optimization interface

remove test2.c, test3.c

store adamw params in tensor

move grads from tensor to graph

* avoid segfault upon API misuse

* add ggml-opt.h to public headers

* remove dependence of ggml-opt.cpp on ggml-cpu.h