From 82b15e66410f95b30d51d91d13d71a2c8d3ea602 Mon Sep 17 00:00:00 2001 From: Keenan Gugeler Date: Thu, 16 Nov 2023 14:46:16 -0500 Subject: [PATCH] prepared_statement: move parameters In the Java API, we had a bug where we took ownership of and freed parameters passed into executeWithParams. Inspecting the method itself, it was taking a shared_ptr, but then performing a deep copy, which is nonsense. Instead, we should take a unique_ptr, since we need to copy the parameters to guarantee that they are not modified for the duration of the query. This commit also fixes three other issues. First, the Java tests weren't running any tests from ConnectionTest.java, which is why we didn't observe this bug. Additionally, the constructor of KuzuConnection uses an assertion, but assertions are disabled by default, which causes our tests to fail (and if the assertion is skipped, we segfault). Also, since https://github.com/rust-lang/cc-rs/issues/900 has been fixed, we can remove the version pinning of `cc` on macOS. 
--- .github/workflows/ci-workflow.yml | 1 - Makefile | 4 +-- src/c_api/connection.cpp | 18 +++++++--- src/include/main/connection.h | 16 ++++----- src/main/connection.cpp | 8 ++--- tools/java_api/src/jni/kuzu_java.cpp | 26 ++++++-------- .../main/java/com/kuzudb/KuzuConnection.java | 3 +- .../src/main/java/com/kuzudb/KuzuValue.java | 1 + .../java/com/kuzudb/test/ConnectionTest.java | 11 ++++-- .../src_cpp/include/node_connection.h | 10 +++--- tools/nodejs_api/src_cpp/include/node_util.h | 11 +++--- tools/nodejs_api/src_cpp/node_connection.cpp | 4 +-- tools/nodejs_api/src_cpp/node_util.cpp | 17 +++++----- .../src_cpp/include/py_connection.h | 6 ---- tools/python_api/src_cpp/py_connection.cpp | 34 ++++++++++--------- 15 files changed, 85 insertions(+), 85 deletions(-) diff --git a/.github/workflows/ci-workflow.yml b/.github/workflows/ci-workflow.yml index 740cce9f6d..2512178f9a 100644 --- a/.github/workflows/ci-workflow.yml +++ b/.github/workflows/ci-workflow.yml @@ -359,7 +359,6 @@ jobs: run: | ulimit -n 10240 source /Users/runner/.cargo/env - cargo update -p cc --precise '1.0.83' make rusttest - name: Rust example diff --git a/Makefile b/Makefile index c779e796a1..d5e99124fc 100644 --- a/Makefile +++ b/Makefile @@ -128,11 +128,11 @@ nodejstest: nodejs javatest: java ifeq ($(OS),Windows_NT) $(call mkdirp,tools/java_api/build/test) && cd tools/java_api/ && \ - javac -d build/test -cp ".;build/kuzu_java.jar;third_party/junit-platform-console-standalone-1.9.3.jar" -sourcepath src/test/java/com/kuzudb/test/*.java && \ + javac -d build/test -cp ".;build/kuzu_java.jar;third_party/junit-platform-console-standalone-1.9.3.jar" src/test/java/com/kuzudb/test/*.java && \ java -jar third_party/junit-platform-console-standalone-1.9.3.jar -cp ".;build/kuzu_java.jar;build/test/" --scan-classpath --include-package=com.kuzudb.java_test --details=verbose else $(call mkdirp,tools/java_api/build/test) && cd tools/java_api/ && \ - javac -d build/test -cp 
".:build/kuzu_java.jar:third_party/junit-platform-console-standalone-1.9.3.jar" -sourcepath src/test/java/com/kuzudb/test/*.java && \ + javac -d build/test -cp ".:build/kuzu_java.jar:third_party/junit-platform-console-standalone-1.9.3.jar" src/test/java/com/kuzudb/test/*.java && \ java -jar third_party/junit-platform-console-standalone-1.9.3.jar -cp ".:build/kuzu_java.jar:build/test/" --scan-classpath --include-package=com.kuzudb.java_test --details=verbose endif diff --git a/src/c_api/connection.cpp b/src/c_api/connection.cpp index add525690e..80a1cfa9ab 100644 --- a/src/c_api/connection.cpp +++ b/src/c_api/connection.cpp @@ -67,7 +67,7 @@ kuzu_prepared_statement* kuzu_connection_prepare(kuzu_connection* connection, co auto* c_prepared_statement = new kuzu_prepared_statement; c_prepared_statement->_prepared_statement = prepared_statement; c_prepared_statement->_bound_values = - new std::unordered_map>; + new std::unordered_map>; return c_prepared_statement; } @@ -75,11 +75,19 @@ kuzu_query_result* kuzu_connection_execute( kuzu_connection* connection, kuzu_prepared_statement* prepared_statement) { auto prepared_statement_ptr = static_cast(prepared_statement->_prepared_statement); - auto bound_values = static_cast>*>( + auto bound_values = static_cast>*>( prepared_statement->_bound_values); - auto query_result = static_cast(connection->_connection) - ->executeWithParams(prepared_statement_ptr, *bound_values) - .release(); + + // Must copy the parameters for safety. 
+ std::unordered_map> copied_bound_values; + for (auto& [name, value] : *bound_values) { + copied_bound_values.emplace(name, value->copy()); + } + + auto query_result = + static_cast(connection->_connection) + ->executeWithParams(prepared_statement_ptr, std::move(copied_bound_values)) + .release(); if (query_result == nullptr) { return nullptr; } diff --git a/src/include/main/connection.h b/src/include/main/connection.h index c12ef2de42..48cca26242 100644 --- a/src/include/main/connection.h +++ b/src/include/main/connection.h @@ -82,8 +82,8 @@ class Connection { template inline std::unique_ptr execute( PreparedStatement* preparedStatement, std::pair... args) { - std::unordered_map> inputParameters; - return executeWithParams(preparedStatement, inputParameters, args...); + std::unordered_map> inputParameters; + return executeWithParams(preparedStatement, std::move(inputParameters), args...); } /** * @brief Executes the given prepared statement with inputParams and returns the result. @@ -93,7 +93,7 @@ class Connection { * @return the result of the query. */ KUZU_API std::unique_ptr executeWithParams(PreparedStatement* preparedStatement, - std::unordered_map>& inputParams); + std::unordered_map> inputParams); /** * @brief interrupts all queries currently executing within this connection. */ @@ -151,16 +151,16 @@ class Connection { template std::unique_ptr executeWithParams(PreparedStatement* preparedStatement, - std::unordered_map>& params, + std::unordered_map> params, std::pair arg, std::pair... 
args) { auto name = arg.first; - auto val = std::make_shared((T)arg.second); - params.insert({name, val}); - return executeWithParams(preparedStatement, params, args...); + auto val = std::make_unique((T)arg.second); + params.insert({name, std::move(val)}); + return executeWithParams(preparedStatement, std::move(params), args...); } void bindParametersNoLock(PreparedStatement* preparedStatement, - std::unordered_map>& inputParams); + std::unordered_map> inputParams); std::unique_ptr executeAndAutoCommitIfNecessaryNoLock( PreparedStatement* preparedStatement, uint32_t planIdx = 0u); diff --git a/src/main/connection.cpp b/src/main/connection.cpp index 77f3850563..1b2b6ae3cd 100644 --- a/src/main/connection.cpp +++ b/src/main/connection.cpp @@ -157,13 +157,13 @@ uint64_t Connection::getQueryTimeOut() { } std::unique_ptr Connection::executeWithParams(PreparedStatement* preparedStatement, - std::unordered_map>& inputParams) { + std::unordered_map> inputParams) { lock_t lck{mtx}; if (!preparedStatement->isSuccess()) { return queryResultWithError(preparedStatement->errMsg); } try { - bindParametersNoLock(preparedStatement, inputParams); + bindParametersNoLock(preparedStatement, std::move(inputParams)); } catch (Exception& exception) { std::string errMsg = exception.what(); return queryResultWithError(errMsg); @@ -172,7 +172,7 @@ std::unique_ptr Connection::executeWithParams(PreparedStatement* pr } void Connection::bindParametersNoLock(PreparedStatement* preparedStatement, - std::unordered_map>& inputParams) { + std::unordered_map> inputParams) { auto& parameterMap = preparedStatement->parameterMap; for (auto& [name, value] : inputParams) { if (!parameterMap.contains(name)) { @@ -184,7 +184,7 @@ void Connection::bindParametersNoLock(PreparedStatement* preparedStatement, value->getDataType()->toString() + " but expects " + expectParam->getDataType()->toString() + "."); } - parameterMap.at(name)->copyValueFrom(*value); + parameterMap.at(name) = std::move(value); } } diff 
--git a/tools/java_api/src/jni/kuzu_java.cpp b/tools/java_api/src/jni/kuzu_java.cpp index b5afacde4d..92a5a07195 100644 --- a/tools/java_api/src/jni/kuzu_java.cpp +++ b/tools/java_api/src/jni/kuzu_java.cpp @@ -1,7 +1,5 @@ -#include #include -#include "binder/bound_statement_result.h" // This header is generated at build time. See CMakeLists.txt. #include "com_kuzudb_KuzuNative.h" #include "common/exception/conversion.h" @@ -11,10 +9,8 @@ #include "common/types/value/node.h" #include "common/types/value/rel.h" #include "common/types/value/value.h" -#include "json.hpp" #include "main/kuzu.h" #include "main/query_summary.h" -#include "planner/operator/logical_plan.h" #include using namespace kuzu::main; @@ -116,8 +112,8 @@ std::string dataTypeToString(const LogicalType& dataType) { return LogicalTypeUtils::toString(dataType.getLogicalTypeID()); } -void javaMapToCPPMap( - JNIEnv* env, jobject javaMap, std::unordered_map>& cppMap) { +std::unordered_map> javaMapToCPPMap( + JNIEnv* env, jobject javaMap) { jclass mapClass = env->FindClass("java/util/Map"); jmethodID entrySet = env->GetMethodID(mapClass, "entrySet", "()Ljava/util/Set;"); @@ -132,20 +128,22 @@ void javaMapToCPPMap( jmethodID entryGetKey = env->GetMethodID(entryClass, "getKey", "()Ljava/lang/Object;"); jmethodID entryGetValue = env->GetMethodID(entryClass, "getValue", "()Ljava/lang/Object;"); + std::unordered_map> result; while (env->CallBooleanMethod(iter, hasNext)) { jobject entry = env->CallObjectMethod(iter, next); jstring key = (jstring)env->CallObjectMethod(entry, entryGetKey); jobject value = env->CallObjectMethod(entry, entryGetValue); const char* keyStr = env->GetStringUTFChars(key, JNI_FALSE); - Value* v = getValue(env, value); - std::shared_ptr value_ptr(v); - cppMap.insert({keyStr, value_ptr}); + const Value* v = getValue(env, value); + // Java code can keep a reference to the value, so we cannot move. 
+ result.insert({keyStr, v->copy()}); env->DeleteLocalRef(entry); env->ReleaseStringUTFChars(key, keyStr); env->DeleteLocalRef(key); env->DeleteLocalRef(value); } + return result; } /** @@ -301,14 +299,10 @@ JNIEXPORT jobject JNICALL Java_com_kuzudb_KuzuNative_kuzu_1connection_1execute( Connection* conn = getConnection(env, thisConn); PreparedStatement* ps = getPreparedStatement(env, preStm); - std::unordered_map> param; - javaMapToCPPMap(env, param_map, param); + std::unordered_map> params = + javaMapToCPPMap(env, param_map); - for (auto const& pair : param) { - std::cout << "{" << pair.first << ": " << pair.second.get()->toString() << "}\n"; - } - - auto query_result = conn->executeWithParams(ps, param).release(); + auto query_result = conn->executeWithParams(ps, std::move(params)).release(); if (query_result == nullptr) { return nullptr; } diff --git a/tools/java_api/src/main/java/com/kuzudb/KuzuConnection.java b/tools/java_api/src/main/java/com/kuzudb/KuzuConnection.java index fd80c51504..1e6d59d24c 100644 --- a/tools/java_api/src/main/java/com/kuzudb/KuzuConnection.java +++ b/tools/java_api/src/main/java/com/kuzudb/KuzuConnection.java @@ -16,7 +16,8 @@ public class KuzuConnection { * @param db: KuzuDatabase instance. 
*/ public KuzuConnection(KuzuDatabase db) { - assert db != null : "Cannot create connection, database is null."; + if (db == null) + throw new AssertionError("Cannot create connection, database is null."); conn_ref = KuzuNative.kuzu_connection_init(db); } diff --git a/tools/java_api/src/main/java/com/kuzudb/KuzuValue.java b/tools/java_api/src/main/java/com/kuzudb/KuzuValue.java index 547237f128..ffd978f75e 100644 --- a/tools/java_api/src/main/java/com/kuzudb/KuzuValue.java +++ b/tools/java_api/src/main/java/com/kuzudb/KuzuValue.java @@ -74,6 +74,7 @@ public void destroy() throws KuzuObjectRefDestroyedException { KuzuNative.kuzu_value_destroy(this); destroyed = true; } + System.gc(); } /** diff --git a/tools/java_api/src/test/java/com/kuzudb/test/ConnectionTest.java b/tools/java_api/src/test/java/com/kuzudb/test/ConnectionTest.java index eddeff1f45..4c99df3989 100644 --- a/tools/java_api/src/test/java/com/kuzudb/test/ConnectionTest.java +++ b/tools/java_api/src/test/java/com/kuzudb/test/ConnectionTest.java @@ -256,8 +256,10 @@ void ConnPrepareInterval() throws KuzuObjectRefDestroyedException { void ConnPrepareMultiParam() throws KuzuObjectRefDestroyedException { String query = "MATCH (a:person) WHERE a.lastJobDuration > $1 AND a.fName = $2 RETURN COUNT(*)"; Map m = new HashMap(); - m.put("1", new KuzuValue(Duration.ofDays(730))); - m.put("2", new KuzuValue("Alice")); + KuzuValue v1 = new KuzuValue(Duration.ofDays(730)); + KuzuValue v2 = new KuzuValue("Alice"); + m.put("1", v1); + m.put("2", v2); KuzuPreparedStatement statement = conn.prepare(query); assertNotNull(statement); KuzuQueryResult result = conn.execute(statement, m); @@ -270,6 +272,11 @@ void ConnPrepareMultiParam() throws KuzuObjectRefDestroyedException { assertEquals(((long) tuple.getValue(0).getValue()), 1); statement.destroy(); result.destroy(); + + // Not strictly necessary, but this makes sure if we freed v1 or v2 in + // the execute() call, we segfault here. 
+ v1.destroy(); + v2.destroy(); } @Test diff --git a/tools/nodejs_api/src_cpp/include/node_connection.h b/tools/nodejs_api/src_cpp/include/node_connection.h index 19104d11b8..d4a1761a43 100644 --- a/tools/nodejs_api/src_cpp/include/node_connection.h +++ b/tools/nodejs_api/src_cpp/include/node_connection.h @@ -1,7 +1,5 @@ #pragma once -#include - #include "main/kuzu.h" #include "node_database.h" #include "node_prepared_statement.h" @@ -65,15 +63,15 @@ class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker { public: ConnectionExecuteAsyncWorker(Napi::Function& callback, std::shared_ptr& connection, std::shared_ptr preparedStatement, NodeQueryResult* nodeQueryResult, - std::unordered_map>& params) + std::unordered_map> params) : Napi::AsyncWorker(callback), preparedStatement(preparedStatement), - nodeQueryResult(nodeQueryResult), connection(connection), params(params) {} + nodeQueryResult(nodeQueryResult), connection(connection), params(std::move(params)) {} ~ConnectionExecuteAsyncWorker() = default; inline void Execute() override { try { std::shared_ptr result = - std::move(connection->executeWithParams(preparedStatement.get(), params)); + connection->executeWithParams(preparedStatement.get(), std::move(params)); nodeQueryResult->SetQueryResult(result); if (!result->isSuccess()) { SetError(result->getErrorMessage()); @@ -90,5 +88,5 @@ class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker { std::shared_ptr connection; std::shared_ptr preparedStatement; NodeQueryResult* nodeQueryResult; - std::unordered_map> params; + std::unordered_map> params; }; diff --git a/tools/nodejs_api/src_cpp/include/node_util.h b/tools/nodejs_api/src_cpp/include/node_util.h index 26bbe6d156..e8980f61d5 100644 --- a/tools/nodejs_api/src_cpp/include/node_util.h +++ b/tools/nodejs_api/src_cpp/include/node_util.h @@ -1,10 +1,6 @@ #pragma once -#include -#include -#include - -#include "main/kuzu.h" +#include "common/types/value/value.h" #include using namespace kuzu::common; @@ 
-12,8 +8,9 @@ using namespace kuzu::common; class Util { public: static Napi::Value ConvertToNapiObject(const Value& value, Napi::Env env); - static std::unordered_map> TransformParametersForExec( - Napi::Array params, std::unordered_map>& parameterMap); + static std::unordered_map> TransformParametersForExec( + Napi::Array params, + const std::unordered_map>& parameterMap); private: static Napi::Object ConvertNodeIdToNapiObject(const nodeID_t& nodeId, Napi::Env env); diff --git a/tools/nodejs_api/src_cpp/node_connection.cpp b/tools/nodejs_api/src_cpp/node_connection.cpp index 7d7b542461..98c4bbcf46 100644 --- a/tools/nodejs_api/src_cpp/node_connection.cpp +++ b/tools/nodejs_api/src_cpp/node_connection.cpp @@ -75,10 +75,10 @@ Napi::Value NodeConnection::ExecuteAsync(const Napi::CallbackInfo& info) { auto nodeQueryResult = Napi::ObjectWrap::Unwrap(info[1].As()); auto callback = info[3].As(); try { - auto parameterMap = nodePreparedStatement->preparedStatement->getParameterMap(); + const auto& parameterMap = nodePreparedStatement->preparedStatement->getParameterMap(); auto params = Util::TransformParametersForExec(info[2].As(), parameterMap); auto asyncWorker = new ConnectionExecuteAsyncWorker(callback, connection, - nodePreparedStatement->preparedStatement, nodeQueryResult, params); + nodePreparedStatement->preparedStatement, nodeQueryResult, std::move(params)); asyncWorker->Queue(); } catch (const std::exception& exc) { Napi::Error::New(env, std::string(exc.what())).ThrowAsJavaScriptException(); diff --git a/tools/nodejs_api/src_cpp/node_util.cpp b/tools/nodejs_api/src_cpp/node_util.cpp index 8896de63a2..b7e77cd45c 100644 --- a/tools/nodejs_api/src_cpp/node_util.cpp +++ b/tools/nodejs_api/src_cpp/node_util.cpp @@ -162,9 +162,10 @@ Napi::Value Util::ConvertToNapiObject(const Value& value, Napi::Env env) { return Napi::Value(); } -std::unordered_map> Util::TransformParametersForExec( - Napi::Array params, std::unordered_map>& parameterMap) { - std::unordered_map> 
result; +std::unordered_map> Util::TransformParametersForExec( + Napi::Array params, + const std::unordered_map>& parameterMap) { + std::unordered_map> result; for (size_t i = 0; i < params.Length(); i++) { auto param = params.Get(i).As(); KU_ASSERT(param.Length() == 2); @@ -173,11 +174,10 @@ std::unordered_map> Util::TransformParameter if (!parameterMap.count(key)) { throw Exception("Parameter " + key + " is not defined in the prepared statement"); } - auto paramValue = parameterMap[key]; auto napiValue = param.Get(uint32_t(1)); - auto expectedDataType = paramValue->getDataType(); + auto expectedDataType = parameterMap.at(key)->getDataType(); auto transformedVal = TransformNapiValue(napiValue, expectedDataType, key); - result[key] = std::make_shared(transformedVal); + result[key] = std::make_unique(transformedVal); } return result; } @@ -266,8 +266,7 @@ Value Util::TransformNapiValue( return Value(normalizedInterval); } default: - throw Exception("Unsupported type " + - expectedDataType->toString() + - " for parameter: " + key); + throw Exception( + "Unsupported type " + expectedDataType->toString() + " for parameter: " + key); } } diff --git a/tools/python_api/src_cpp/include/py_connection.h b/tools/python_api/src_cpp/include/py_connection.h index 526725859f..4c8f43d5ac 100644 --- a/tools/python_api/src_cpp/include/py_connection.h +++ b/tools/python_api/src_cpp/include/py_connection.h @@ -32,12 +32,6 @@ class PyConnection { static bool isPandasDataframe(const py::object& object); -private: - std::unordered_map> transformPythonParameters( - py::dict params); - - kuzu::common::Value transformPythonValue(py::handle val); - private: std::unique_ptr storageDriver; std::unique_ptr conn; diff --git a/tools/python_api/src_cpp/py_connection.cpp b/tools/python_api/src_cpp/py_connection.cpp index 228ff0b743..3402b6ee25 100644 --- a/tools/python_api/src_cpp/py_connection.cpp +++ b/tools/python_api/src_cpp/py_connection.cpp @@ -1,10 +1,8 @@ #include "include/py_connection.h" 
-#include "binder/bound_statement_result.h" #include "common/string_format.h" #include "datetime.h" // from Python #include "main/connection.h" -#include "planner/operator/logical_plan.h" #include "pandas/pandas_scan.h" #include "processor/result/factorized_table.h" @@ -30,7 +28,7 @@ void PyConnection::initialize(py::handle& m) { PyConnection::PyConnection(PyDatabase* pyDatabase, uint64_t numThreads) { storageDriver = std::make_unique(pyDatabase->database.get()); conn = std::make_unique(pyDatabase->database.get()); - //TODO(Xiyang): We should implement a generic replacement framework in binder. + // TODO(Xiyang): We should implement a generic replacement framework in binder. conn->setReplaceFunc(kuzu::replacePD); if (numThreads > 0) { conn->setMaxNumThreadForExec(numThreads); @@ -41,12 +39,15 @@ void PyConnection::setQueryTimeout(uint64_t timeoutInMS) { conn->setQueryTimeOut(timeoutInMS); } -std::unique_ptr PyConnection::execute(PyPreparedStatement* preparedStatement, - py::dict params) { +static std::unordered_map> transformPythonParameters( + py::dict params); + +std::unique_ptr PyConnection::execute( + PyPreparedStatement* preparedStatement, py::dict params) { auto parameters = transformPythonParameters(params); py::gil_scoped_release release; auto queryResult = - conn->executeWithParams(preparedStatement->preparedStatement.get(), parameters); + conn->executeWithParams(preparedStatement->preparedStatement.get(), std::move(parameters)); py::gil_scoped_acquire acquire; if (!queryResult->isSuccess()) { throw std::runtime_error(queryResult->getErrorMessage()); @@ -103,10 +104,10 @@ void PyConnection::getAllEdgesForTorchGeometric(py::array_t& npArray, int64_t start = batch * queryBatchSize; int64_t end = (batch + 1) * queryBatchSize; end = end > numDstNodes ? 
numDstNodes : end; - std::unordered_map> parameters; - parameters["s"] = std::make_shared(start); - parameters["e"] = std::make_shared(end); - auto result = conn->executeWithParams(preparedStatement.get(), parameters); + std::unordered_map> parameters; + parameters["s"] = std::make_unique(start); + parameters["e"] = std::make_unique(end); + auto result = conn->executeWithParams(preparedStatement.get(), std::move(parameters)); if (!result->isSuccess()) { throw std::runtime_error(result->getErrorMessage()); } @@ -151,22 +152,23 @@ bool PyConnection::isPandasDataframe(const py::object& object) { return py::isinstance(object, pandas.attr("DataFrame")); } -std::unordered_map> PyConnection::transformPythonParameters( - py::dict params) { - std::unordered_map> result; +static Value transformPythonValue(py::handle val); + +std::unordered_map> transformPythonParameters(py::dict params) { + std::unordered_map> result; for (auto& [key, value] : params) { if (!py::isinstance(key)) { throw std::runtime_error("Parameter name must be of type string but get " + py::str(key.get_type()).cast()); } auto name = key.cast(); - auto val = std::make_shared(transformPythonValue(value)); - result.insert({name, val}); + auto val = std::make_unique(transformPythonValue(value)); + result.insert({name, std::move(val)}); } return result; } -Value PyConnection::transformPythonValue(py::handle val) { +Value transformPythonValue(py::handle val) { auto datetime_mod = py::module::import("datetime"); auto datetime_datetime = datetime_mod.attr("datetime"); auto time_delta = datetime_mod.attr("timedelta");