Skip to content

Commit

Permalink
prepared_statement: move parameters
Browse files Browse the repository at this point in the history
In the Java API, we had a bug where we take ownership of and free
parameters passed into executeWithParams.

Inspecting the method itself, it was taking a shared_ptr, but then
performing a deep copy, which is nonsense. Instead, we should take a
unique_ptr, since we need to copy the parameters to guarantee that they
are not modified for the duration of the query.

This commit also fixes three other issues. First, the Java tests weren't running
any tests from ConnectionTest.java, which is why we didn't observe this
bug. Additionally, the constructor of KuzuConnection uses an assertion,
but assertions are disabled by default, which causes our tests to fail
(and if the assertion is skipped, we segfault).

Also, since rust-lang/cc-rs#900 has been
fixed, we can remove the version pinning of `cc` on MacOS.
  • Loading branch information
Riolku committed Nov 16, 2023
1 parent 21f37be commit 82b15e6
Show file tree
Hide file tree
Showing 15 changed files with 85 additions and 85 deletions.
1 change: 0 additions & 1 deletion .github/workflows/ci-workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,6 @@ jobs:
run: |
ulimit -n 10240
source /Users/runner/.cargo/env
cargo update -p cc --precise '1.0.83'
make rusttest
- name: Rust example
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,11 @@ nodejstest: nodejs
javatest: java
ifeq ($(OS),Windows_NT)
$(call mkdirp,tools/java_api/build/test) && cd tools/java_api/ && \
javac -d build/test -cp ".;build/kuzu_java.jar;third_party/junit-platform-console-standalone-1.9.3.jar" -sourcepath src/test/java/com/kuzudb/test/*.java && \
javac -d build/test -cp ".;build/kuzu_java.jar;third_party/junit-platform-console-standalone-1.9.3.jar" src/test/java/com/kuzudb/test/*.java && \
java -jar third_party/junit-platform-console-standalone-1.9.3.jar -cp ".;build/kuzu_java.jar;build/test/" --scan-classpath --include-package=com.kuzudb.java_test --details=verbose
else
$(call mkdirp,tools/java_api/build/test) && cd tools/java_api/ && \
javac -d build/test -cp ".:build/kuzu_java.jar:third_party/junit-platform-console-standalone-1.9.3.jar" -sourcepath src/test/java/com/kuzudb/test/*.java && \
javac -d build/test -cp ".:build/kuzu_java.jar:third_party/junit-platform-console-standalone-1.9.3.jar" src/test/java/com/kuzudb/test/*.java && \
java -jar third_party/junit-platform-console-standalone-1.9.3.jar -cp ".:build/kuzu_java.jar:build/test/" --scan-classpath --include-package=com.kuzudb.java_test --details=verbose
endif

Expand Down
18 changes: 13 additions & 5 deletions src/c_api/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,27 @@ kuzu_prepared_statement* kuzu_connection_prepare(kuzu_connection* connection, co
auto* c_prepared_statement = new kuzu_prepared_statement;
c_prepared_statement->_prepared_statement = prepared_statement;
c_prepared_statement->_bound_values =
new std::unordered_map<std::string, std::shared_ptr<Value>>;
new std::unordered_map<std::string, std::unique_ptr<Value>>;
return c_prepared_statement;
}

kuzu_query_result* kuzu_connection_execute(
kuzu_connection* connection, kuzu_prepared_statement* prepared_statement) {
auto prepared_statement_ptr =
static_cast<PreparedStatement*>(prepared_statement->_prepared_statement);
auto bound_values = static_cast<std::unordered_map<std::string, std::shared_ptr<Value>>*>(
auto bound_values = static_cast<std::unordered_map<std::string, std::unique_ptr<Value>>*>(
prepared_statement->_bound_values);
auto query_result = static_cast<Connection*>(connection->_connection)
->executeWithParams(prepared_statement_ptr, *bound_values)
.release();

// Must copy the parameters for safety.
std::unordered_map<std::string, std::unique_ptr<Value>> copied_bound_values;
for (auto& [name, value] : *bound_values) {
copied_bound_values.emplace(name, value->copy());
}

auto query_result =
static_cast<Connection*>(connection->_connection)
->executeWithParams(prepared_statement_ptr, std::move(copied_bound_values))
.release();
if (query_result == nullptr) {
return nullptr;
}
Expand Down
16 changes: 8 additions & 8 deletions src/include/main/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ class Connection {
template<typename... Args>
inline std::unique_ptr<QueryResult> execute(
PreparedStatement* preparedStatement, std::pair<std::string, Args>... args) {
std::unordered_map<std::string, std::shared_ptr<common::Value>> inputParameters;
return executeWithParams(preparedStatement, inputParameters, args...);
std::unordered_map<std::string, std::unique_ptr<common::Value>> inputParameters;
return executeWithParams(preparedStatement, std::move(inputParameters), args...);
}
/**
* @brief Executes the given prepared statement with inputParams and returns the result.
Expand All @@ -93,7 +93,7 @@ class Connection {
* @return the result of the query.
*/
KUZU_API std::unique_ptr<QueryResult> executeWithParams(PreparedStatement* preparedStatement,
std::unordered_map<std::string, std::shared_ptr<common::Value>>& inputParams);
std::unordered_map<std::string, std::unique_ptr<common::Value>> inputParams);
/**
* @brief interrupts all queries currently executing within this connection.
*/
Expand Down Expand Up @@ -151,16 +151,16 @@ class Connection {

template<typename T, typename... Args>
std::unique_ptr<QueryResult> executeWithParams(PreparedStatement* preparedStatement,
std::unordered_map<std::string, std::shared_ptr<common::Value>>& params,
std::unordered_map<std::string, std::unique_ptr<common::Value>> params,
std::pair<std::string, T> arg, std::pair<std::string, Args>... args) {
auto name = arg.first;
auto val = std::make_shared<common::Value>((T)arg.second);
params.insert({name, val});
return executeWithParams(preparedStatement, params, args...);
auto val = std::make_unique<common::Value>((T)arg.second);
params.insert({name, std::move(val)});
return executeWithParams(preparedStatement, std::move(params), args...);
}

void bindParametersNoLock(PreparedStatement* preparedStatement,
std::unordered_map<std::string, std::shared_ptr<common::Value>>& inputParams);
std::unordered_map<std::string, std::unique_ptr<common::Value>> inputParams);

std::unique_ptr<QueryResult> executeAndAutoCommitIfNecessaryNoLock(
PreparedStatement* preparedStatement, uint32_t planIdx = 0u);
Expand Down
8 changes: 4 additions & 4 deletions src/main/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,13 @@ uint64_t Connection::getQueryTimeOut() {
}

std::unique_ptr<QueryResult> Connection::executeWithParams(PreparedStatement* preparedStatement,
std::unordered_map<std::string, std::shared_ptr<Value>>& inputParams) {
std::unordered_map<std::string, std::unique_ptr<Value>> inputParams) {
lock_t lck{mtx};
if (!preparedStatement->isSuccess()) {
return queryResultWithError(preparedStatement->errMsg);
}
try {
bindParametersNoLock(preparedStatement, inputParams);
bindParametersNoLock(preparedStatement, std::move(inputParams));
} catch (Exception& exception) {
std::string errMsg = exception.what();
return queryResultWithError(errMsg);
Expand All @@ -172,7 +172,7 @@ std::unique_ptr<QueryResult> Connection::executeWithParams(PreparedStatement* pr
}

void Connection::bindParametersNoLock(PreparedStatement* preparedStatement,
std::unordered_map<std::string, std::shared_ptr<Value>>& inputParams) {
std::unordered_map<std::string, std::unique_ptr<Value>> inputParams) {
auto& parameterMap = preparedStatement->parameterMap;
for (auto& [name, value] : inputParams) {
if (!parameterMap.contains(name)) {
Expand All @@ -184,7 +184,7 @@ void Connection::bindParametersNoLock(PreparedStatement* preparedStatement,
value->getDataType()->toString() + " but expects " +
expectParam->getDataType()->toString() + ".");
}
parameterMap.at(name)->copyValueFrom(*value);
parameterMap.at(name) = std::move(value);
}
}

Expand Down
26 changes: 10 additions & 16 deletions tools/java_api/src/jni/kuzu_java.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#include <iostream>
#include <unordered_map>

#include "binder/bound_statement_result.h"
// This header is generated at build time. See CMakeLists.txt.
#include "com_kuzudb_KuzuNative.h"
#include "common/exception/conversion.h"
Expand All @@ -11,10 +9,8 @@
#include "common/types/value/node.h"
#include "common/types/value/rel.h"
#include "common/types/value/value.h"
#include "json.hpp"
#include "main/kuzu.h"
#include "main/query_summary.h"
#include "planner/operator/logical_plan.h"
#include <jni.h>

using namespace kuzu::main;
Expand Down Expand Up @@ -116,8 +112,8 @@ std::string dataTypeToString(const LogicalType& dataType) {
return LogicalTypeUtils::toString(dataType.getLogicalTypeID());
}

void javaMapToCPPMap(
JNIEnv* env, jobject javaMap, std::unordered_map<std::string, std::shared_ptr<Value>>& cppMap) {
std::unordered_map<std::string, std::unique_ptr<Value>> javaMapToCPPMap(
JNIEnv* env, jobject javaMap) {

jclass mapClass = env->FindClass("java/util/Map");
jmethodID entrySet = env->GetMethodID(mapClass, "entrySet", "()Ljava/util/Set;");
Expand All @@ -132,20 +128,22 @@ void javaMapToCPPMap(
jmethodID entryGetKey = env->GetMethodID(entryClass, "getKey", "()Ljava/lang/Object;");
jmethodID entryGetValue = env->GetMethodID(entryClass, "getValue", "()Ljava/lang/Object;");

std::unordered_map<std::string, std::unique_ptr<Value>> result;
while (env->CallBooleanMethod(iter, hasNext)) {
jobject entry = env->CallObjectMethod(iter, next);
jstring key = (jstring)env->CallObjectMethod(entry, entryGetKey);
jobject value = env->CallObjectMethod(entry, entryGetValue);
const char* keyStr = env->GetStringUTFChars(key, JNI_FALSE);
Value* v = getValue(env, value);
std::shared_ptr<Value> value_ptr(v);
cppMap.insert({keyStr, value_ptr});
const Value* v = getValue(env, value);
// Java code can keep a reference to the value, so we cannot move.
result.insert({keyStr, v->copy()});

env->DeleteLocalRef(entry);
env->ReleaseStringUTFChars(key, keyStr);
env->DeleteLocalRef(key);
env->DeleteLocalRef(value);
}
return result;
}

/**
Expand Down Expand Up @@ -301,14 +299,10 @@ JNIEXPORT jobject JNICALL Java_com_kuzudb_KuzuNative_kuzu_1connection_1execute(
Connection* conn = getConnection(env, thisConn);
PreparedStatement* ps = getPreparedStatement(env, preStm);

std::unordered_map<std::string, std::shared_ptr<Value>> param;
javaMapToCPPMap(env, param_map, param);
std::unordered_map<std::string, std::unique_ptr<Value>> params =
javaMapToCPPMap(env, param_map);

for (auto const& pair : param) {
std::cout << "{" << pair.first << ": " << pair.second.get()->toString() << "}\n";
}

auto query_result = conn->executeWithParams(ps, param).release();
auto query_result = conn->executeWithParams(ps, std::move(params)).release();
if (query_result == nullptr) {
return nullptr;
}
Expand Down
3 changes: 2 additions & 1 deletion tools/java_api/src/main/java/com/kuzudb/KuzuConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ public class KuzuConnection {
* @param db: KuzuDatabase instance.
*/
public KuzuConnection(KuzuDatabase db) {
assert db != null : "Cannot create connection, database is null.";
if (db == null)
throw new AssertionError("Cannot create connection, database is null.");
conn_ref = KuzuNative.kuzu_connection_init(db);
}

Expand Down
1 change: 1 addition & 0 deletions tools/java_api/src/main/java/com/kuzudb/KuzuValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ public void destroy() throws KuzuObjectRefDestroyedException {
KuzuNative.kuzu_value_destroy(this);
destroyed = true;
}
System.gc();
}

/**
Expand Down
11 changes: 9 additions & 2 deletions tools/java_api/src/test/java/com/kuzudb/test/ConnectionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,10 @@ void ConnPrepareInterval() throws KuzuObjectRefDestroyedException {
void ConnPrepareMultiParam() throws KuzuObjectRefDestroyedException {
String query = "MATCH (a:person) WHERE a.lastJobDuration > $1 AND a.fName = $2 RETURN COUNT(*)";
Map<String, KuzuValue> m = new HashMap<String, KuzuValue>();
m.put("1", new KuzuValue(Duration.ofDays(730)));
m.put("2", new KuzuValue("Alice"));
KuzuValue v1 = new KuzuValue(Duration.ofDays(730));
KuzuValue v2 = new KuzuValue("Alice");
m.put("1", v1);
m.put("2", v2);
KuzuPreparedStatement statement = conn.prepare(query);
assertNotNull(statement);
KuzuQueryResult result = conn.execute(statement, m);
Expand All @@ -270,6 +272,11 @@ void ConnPrepareMultiParam() throws KuzuObjectRefDestroyedException {
assertEquals(((long) tuple.getValue(0).getValue()), 1);
statement.destroy();
result.destroy();

// Not strictly necessary, but this makes sure if we freed v1 or v2 in
// the execute() call, we segfault here.
v1.destroy();
v2.destroy();
}

@Test
Expand Down
10 changes: 4 additions & 6 deletions tools/nodejs_api/src_cpp/include/node_connection.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#pragma once

#include <iostream>

#include "main/kuzu.h"
#include "node_database.h"
#include "node_prepared_statement.h"
Expand Down Expand Up @@ -65,15 +63,15 @@ class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker {
public:
ConnectionExecuteAsyncWorker(Napi::Function& callback, std::shared_ptr<Connection>& connection,
std::shared_ptr<PreparedStatement> preparedStatement, NodeQueryResult* nodeQueryResult,
std::unordered_map<std::string, std::shared_ptr<Value>>& params)
std::unordered_map<std::string, std::unique_ptr<Value>> params)
: Napi::AsyncWorker(callback), preparedStatement(preparedStatement),
nodeQueryResult(nodeQueryResult), connection(connection), params(params) {}
nodeQueryResult(nodeQueryResult), connection(connection), params(std::move(params)) {}
~ConnectionExecuteAsyncWorker() = default;

inline void Execute() override {
try {
std::shared_ptr<QueryResult> result =
std::move(connection->executeWithParams(preparedStatement.get(), params));
connection->executeWithParams(preparedStatement.get(), std::move(params));
nodeQueryResult->SetQueryResult(result);
if (!result->isSuccess()) {
SetError(result->getErrorMessage());
Expand All @@ -90,5 +88,5 @@ class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker {
std::shared_ptr<Connection> connection;
std::shared_ptr<PreparedStatement> preparedStatement;
NodeQueryResult* nodeQueryResult;
std::unordered_map<std::string, std::shared_ptr<Value>> params;
std::unordered_map<std::string, std::unique_ptr<Value>> params;
};
11 changes: 4 additions & 7 deletions tools/nodejs_api/src_cpp/include/node_util.h
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
#pragma once

#include <chrono>
#include <ctime>
#include <iostream>

#include "main/kuzu.h"
#include "common/types/value/value.h"
#include <napi.h>

using namespace kuzu::common;

class Util {
public:
static Napi::Value ConvertToNapiObject(const Value& value, Napi::Env env);
static std::unordered_map<std::string, std::shared_ptr<Value>> TransformParametersForExec(
Napi::Array params, std::unordered_map<std::string, std::shared_ptr<Value>>& parameterMap);
static std::unordered_map<std::string, std::unique_ptr<Value>> TransformParametersForExec(
Napi::Array params,
const std::unordered_map<std::string, std::shared_ptr<Value>>& parameterMap);

private:
static Napi::Object ConvertNodeIdToNapiObject(const nodeID_t& nodeId, Napi::Env env);
Expand Down
4 changes: 2 additions & 2 deletions tools/nodejs_api/src_cpp/node_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ Napi::Value NodeConnection::ExecuteAsync(const Napi::CallbackInfo& info) {
auto nodeQueryResult = Napi::ObjectWrap<NodeQueryResult>::Unwrap(info[1].As<Napi::Object>());
auto callback = info[3].As<Napi::Function>();
try {
auto parameterMap = nodePreparedStatement->preparedStatement->getParameterMap();
const auto& parameterMap = nodePreparedStatement->preparedStatement->getParameterMap();
auto params = Util::TransformParametersForExec(info[2].As<Napi::Array>(), parameterMap);
auto asyncWorker = new ConnectionExecuteAsyncWorker(callback, connection,
nodePreparedStatement->preparedStatement, nodeQueryResult, params);
nodePreparedStatement->preparedStatement, nodeQueryResult, std::move(params));
asyncWorker->Queue();
} catch (const std::exception& exc) {
Napi::Error::New(env, std::string(exc.what())).ThrowAsJavaScriptException();
Expand Down
17 changes: 8 additions & 9 deletions tools/nodejs_api/src_cpp/node_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,10 @@ Napi::Value Util::ConvertToNapiObject(const Value& value, Napi::Env env) {
return Napi::Value();
}

std::unordered_map<std::string, std::shared_ptr<Value>> Util::TransformParametersForExec(
Napi::Array params, std::unordered_map<std::string, std::shared_ptr<Value>>& parameterMap) {
std::unordered_map<std::string, std::shared_ptr<Value>> result;
std::unordered_map<std::string, std::unique_ptr<Value>> Util::TransformParametersForExec(
Napi::Array params,
const std::unordered_map<std::string, std::shared_ptr<Value>>& parameterMap) {
std::unordered_map<std::string, std::unique_ptr<Value>> result;
for (size_t i = 0; i < params.Length(); i++) {
auto param = params.Get(i).As<Napi::Array>();
KU_ASSERT(param.Length() == 2);
Expand All @@ -173,11 +174,10 @@ std::unordered_map<std::string, std::shared_ptr<Value>> Util::TransformParameter
if (!parameterMap.count(key)) {
throw Exception("Parameter " + key + " is not defined in the prepared statement");
}
auto paramValue = parameterMap[key];
auto napiValue = param.Get(uint32_t(1));
auto expectedDataType = paramValue->getDataType();
auto expectedDataType = parameterMap.at(key)->getDataType();
auto transformedVal = TransformNapiValue(napiValue, expectedDataType, key);
result[key] = std::make_shared<Value>(transformedVal);
result[key] = std::make_unique<Value>(transformedVal);
}
return result;
}
Expand Down Expand Up @@ -266,8 +266,7 @@ Value Util::TransformNapiValue(
return Value(normalizedInterval);
}
default:
throw Exception("Unsupported type " +
expectedDataType->toString() +
" for parameter: " + key);
throw Exception(
"Unsupported type " + expectedDataType->toString() + " for parameter: " + key);
}
}
6 changes: 0 additions & 6 deletions tools/python_api/src_cpp/include/py_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,6 @@ class PyConnection {

static bool isPandasDataframe(const py::object& object);

private:
std::unordered_map<std::string, std::shared_ptr<kuzu::common::Value>> transformPythonParameters(
py::dict params);

kuzu::common::Value transformPythonValue(py::handle val);

private:
std::unique_ptr<StorageDriver> storageDriver;
std::unique_ptr<Connection> conn;
Expand Down
Loading

0 comments on commit 82b15e6

Please sign in to comment.