diff --git a/cpp/include/cuopt/linear_programming/utilities/cython_solve.hpp b/cpp/include/cuopt/linear_programming/utilities/cython_solve.hpp index de4d9cf66..cc30ff7a0 100644 --- a/cpp/include/cuopt/linear_programming/utilities/cython_solve.hpp +++ b/cpp/include/cuopt/linear_programming/utilities/cython_solve.hpp @@ -103,7 +103,8 @@ struct solver_ret_t { // Wrapper for solve to expose the API to cython. std::unique_ptr call_solve(cuopt::mps_parser::data_model_view_t*, - linear_programming::solver_settings_t*); + linear_programming::solver_settings_t*, + unsigned int flags = cudaStreamNonBlocking); std::pair>, double> call_batch_solve( std::vector*>, diff --git a/cpp/src/linear_programming/utilities/cython_solve.cu b/cpp/src/linear_programming/utilities/cython_solve.cu index b333e99a4..2b784beeb 100644 --- a/cpp/src/linear_programming/utilities/cython_solve.cu +++ b/cpp/src/linear_programming/utilities/cython_solve.cu @@ -208,12 +208,13 @@ mip_ret_t call_solve_mip( std::unique_ptr call_solve( cuopt::mps_parser::data_model_view_t* data_model, - cuopt::linear_programming::solver_settings_t* solver_settings) + cuopt::linear_programming::solver_settings_t* solver_settings, + unsigned int flags) { raft::common::nvtx::range fun_scope("Call Solve"); cudaStream_t stream; - RAFT_CUDA_TRY(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + RAFT_CUDA_TRY(cudaStreamCreateWithFlags(&stream, flags)); const raft::handle_t handle_{stream}; auto op_problem = data_model_to_optimization_problem(data_model, solver_settings, &handle_); @@ -283,9 +284,11 @@ std::pair>, double> call_batch_solve( solver_settings->set_parameter(CUOPT_METHOD, CUOPT_METHOD_PDLP); } + // Use a default stream instead of a non-blocking to avoid invalid operations while some CUDA + // Graph might be capturing in another stream #pragma omp parallel for num_threads(max_thread) for (std::size_t i = 0; i < size; ++i) - list[i] = std::move(call_solve(data_models[i], solver_settings)); + list[i] = std::move(call_solve(data_models[i], solver_settings, cudaStreamDefault)); auto end = std::chrono::high_resolution_clock::now(); auto duration = std::chrono::duration_cast(end - start_solver);