diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 64fb4ff4cec..2b00c024fed 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -763,6 +763,7 @@ struct ggml_backend_cuda_context { cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; std::unique_ptr cuda_graph; + std::vector cuda_graphs; explicit ggml_backend_cuda_context(int device) : device(device), @@ -783,6 +784,13 @@ struct ggml_backend_cuda_context { CUBLAS_CHECK(cublasDestroy(cublas_handles[i])); } } + while (!cuda_graphs.empty()) { + auto graph = cuda_graphs.back(); + cuda_graphs.pop_back(); + if (graph != nullptr) { + delete graph; + } + } } cudaStream_t stream(int device, int stream) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index c442a649243..4a769bef7ec 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2581,18 +2581,26 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra return true; } -static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { +static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, ggml_cuda_graph * cuda_graph, bool update_properties) { bool cuda_graph_update_required = false; - if (cuda_ctx->cuda_graph->instance == nullptr) { + if (cuda_graph->instance == nullptr) { cuda_graph_update_required = true; + if (!update_properties) { + return cuda_graph_update_required; + } } // Check if the graph size has changed - if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) { + if (cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) { cuda_graph_update_required = true; - cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes); + if (update_properties) { + cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes); + } + else { + return cuda_graph_update_required; + } } // Loop over nodes in GGML graph to determine if CUDA graph update is required @@ -2600,15 +2608,22 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, for (int i = 0; i < cgraph->n_nodes; i++) { bool has_matching_properties = true; if (!cuda_graph_update_required) { - has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_graph->ggml_graph_properties[i]); } if (!has_matching_properties) { cuda_graph_update_required = true; + if (!update_properties) { + return cuda_graph_update_required; + } + } + if (update_properties) { + set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_graph->ggml_graph_properties[i]); } - set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); } return cuda_graph_update_required; + + GGML_UNUSED(cuda_ctx); } static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { @@ -2714,6 +2729,13 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); } + // the input node may change to a different address in layer split + // mode which cuases the graph to be invalidated. cache some number of graphs + // and search them all. + while (cuda_ctx->cuda_graphs.size() < 4) { + cuda_ctx->cuda_graphs.emplace_back(new ggml_cuda_graph()); + } + bool use_cuda_graph = true; bool cuda_graph_update_required = false; @@ -2737,7 +2759,27 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, } if (use_cuda_graph) { - cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph); + // find a matching graph, testing most recent one first, then check lru + if (is_cuda_graph_update_required(cuda_ctx, cgraph, cuda_ctx->cuda_graph.get(), false)) { + for (size_t graph_index = 0; graph_index < cuda_ctx->cuda_graphs.size(); graph_index++) { + auto cuda_graph = cuda_ctx->cuda_graphs[graph_index]; + + if (graph_index == cuda_ctx->cuda_graphs.size() - 1) { + cuda_ctx->cuda_graphs.erase(cuda_ctx->cuda_graphs.begin() + graph_index); + cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph, cuda_graph, true); + ggml_cuda_graph * existing = cuda_ctx->cuda_graph.release(); + cuda_ctx->cuda_graph.reset(cuda_graph); + cuda_ctx->cuda_graphs.insert(cuda_ctx->cuda_graphs.begin(), existing); + break; + } else if (!is_cuda_graph_update_required(cuda_ctx, cgraph, cuda_graph, false)) { + cuda_ctx->cuda_graphs.erase(cuda_ctx->cuda_graphs.begin() + graph_index); + ggml_cuda_graph * existing = cuda_ctx->cuda_graph.release(); + cuda_ctx->cuda_graph.reset(cuda_graph); + cuda_ctx->cuda_graphs.insert(cuda_ctx->cuda_graphs.begin(), existing); + break; + } + } + } use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);