From d564cfd8b33c2fa6180a1390cd53f8885fab9c33 Mon Sep 17 00:00:00 2001 From: Cole Miller Date: Wed, 15 Feb 2023 16:21:34 -0500 Subject: [PATCH] gateway: Support QUERY and QUERY_SQL that modify the database Signed-off-by: Cole Miller --- src/gateway.c | 77 ++++++++++++++++++++++++++++-- test/integration/test_cluster.c | 83 +++++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+), 3 deletions(-) diff --git a/src/gateway.c b/src/gateway.c index d3a332fd8..33e7bee8b 100644 --- a/src/gateway.c +++ b/src/gateway.c @@ -186,6 +186,18 @@ static void failure(struct handle *req, int code, const char *message) req->cb(req, 0, DQLITE_RESPONSE_FAILURE, 0); } +static void emptyRows(struct handle *req) +{ + void *cursor = buffer__advance(req->buffer, 8 + 8); + uint64_t val; + assert(cursor != NULL); + val = 0; + uint64__encode(&val, &cursor); + val = DQLITE_RESPONSE_ROWS_DONE; + uint64__encode(&val, &cursor); + req->cb(req, 0, DQLITE_RESPONSE_ROWS, 0); +} + static int handle_leader_legacy(struct gateway *g, struct handle *req) { tracef("handle leader legacy"); @@ -536,6 +548,24 @@ static void query_barrier_cb(struct barrier *barrier, int status) query_batch(g); } +static void leaderModifyingQueryCb(struct exec *exec, int status) +{ + struct gateway *g = exec->data; + struct handle *req = g->req; + assert(req != NULL); + g->req = NULL; + struct stmt *stmt = stmt__registry_get(&g->stmts, req->stmt_id); + assert(stmt != NULL); + + if (status == SQLITE_DONE) { + emptyRows(req); + } else { + assert(g->leader != NULL); + failure(req, status, error_message(g->leader->conn, status)); + sqlite3_reset(stmt->stmt); + } +} + static int handle_query(struct gateway *g, struct handle *req) { tracef("handle query schema:%" PRIu8, req->schema); @@ -543,6 +573,8 @@ static int handle_query(struct gateway *g, struct handle *req) struct stmt *stmt; struct request_query request = {0}; int tuple_format; + bool is_readonly; + uint64_t req_id; int rv; switch (req->schema) { @@ -576,9 +608,15 @@ static int handle_query(struct gateway *g, struct handle *req) } req->stmt_id = stmt->id; g->req = req; - rv = leader__barrier(g->leader, &g->barrier, query_barrier_cb); + + is_readonly = (bool)sqlite3_stmt_readonly(stmt->stmt); + if (is_readonly) { + rv = leader__barrier(g->leader, &g->barrier, query_barrier_cb); + } else { + req_id = idNext(&g->random_state); + rv = leader__exec(g->leader, &g->exec, stmt->stmt, req_id, leaderModifyingQueryCb); + } if (rv != 0) { - tracef("handle query leader barrier failed %d", rv); g->req = NULL; return rv; } @@ -748,6 +786,25 @@ static int handle_exec_sql(struct gateway *g, struct handle *req) return 0; } +static void leaderModifyingQuerySqlCb(struct exec *exec, int status) +{ + struct gateway *g = exec->data; + struct handle *req = g->req; + assert(req != NULL); + g->req = NULL; + sqlite3_stmt *stmt = exec->stmt; + assert(stmt != NULL); + + sqlite3_finalize(stmt); + + if (status == SQLITE_DONE) { + emptyRows(req); + } else { + assert(g->leader != NULL); + failure(req, status, error_message(g->leader->conn, status)); + } +} + static void querySqlBarrierCb(struct barrier *barrier, int status) { tracef("query sql barrier cb status:%d", status); @@ -761,6 +818,8 @@ static void querySqlBarrierCb(struct barrier *barrier, int status) const char *tail; sqlite3_stmt *tail_stmt; int tuple_format; + bool is_readonly; + uint64_t req_id; int rv; if (status != 0) { @@ -810,7 +869,19 @@ static void querySqlBarrierCb(struct barrier *barrier, int status) req->stmt = stmt; g->req = req; - query_batch(g); + + is_readonly = (bool)sqlite3_stmt_readonly(stmt); + if (is_readonly) { + query_batch(g); + } else { + req_id = idNext(&g->random_state); + rv = leader__exec(g->leader, &g->exec, stmt, req_id, leaderModifyingQuerySqlCb); + if (rv != 0) { + sqlite3_finalize(stmt); + g->req = NULL; + failure(req, rv, "leader exec"); + } + } } static int handle_query_sql(struct gateway *g, struct handle *req) diff --git a/test/integration/test_cluster.c b/test/integration/test_cluster.c index c1d0a69cc..db6a81092 100644 --- a/test/integration/test_cluster.c +++ b/test/integration/test_cluster.c @@ -200,3 +200,86 @@ TEST(cluster, hugeRow, setUp, tearDown, 0, NULL) return MUNIT_OK; } + +TEST(cluster, modifyingQuery, setUp, tearDown, 0, cluster_params) +{ + struct fixture *f = data; + uint32_t stmt_id; + uint64_t last_insert_id; + uint64_t rows_affected; + struct rows rows; + long n_records = strtol(munit_parameters_get(params, "num_records"), NULL, 0); + char sql[128]; + unsigned id = 2; + const char *address = "@2"; + + HANDSHAKE; + OPEN; + PREPARE("CREATE TABLE test (n INT)", &stmt_id); + EXEC(stmt_id, &last_insert_id, &rows_affected); + + for (int i = 0; i < n_records; ++i) { + sprintf(sql, "INSERT INTO test(n) VALUES(%d)", i + 1); + PREPARE(sql, &stmt_id); + QUERY(stmt_id, &rows); + munit_assert_uint64(rows.column_count, ==, 0); + munit_assert_ptr(rows.next, ==, NULL); + clientCloseRows(&rows); + } + + ADD(id, address); + ASSIGN(id, DQLITE_VOTER); + + REMOVE(1); + sleep(1); + + SELECT(2); + HANDSHAKE; + OPEN; + PREPARE("SELECT COUNT(*) from test", &stmt_id); + QUERY(stmt_id, &rows); + munit_assert_long(rows.next->values->integer, ==, n_records); + clientCloseRows(&rows); + return MUNIT_OK; +} + +TEST(cluster, modifyingQuerySql, setUp, tearDown, 0, cluster_params) +{ + struct fixture *f = data; + uint32_t stmt_id; + uint64_t last_insert_id; + uint64_t rows_affected; + struct rows rows; + long n_records = strtol(munit_parameters_get(params, "num_records"), NULL, 0); + char sql[128]; + unsigned id = 2; + const char *address = "@2"; + + HANDSHAKE; + OPEN; + PREPARE("CREATE TABLE test (n INT)", &stmt_id); + EXEC(stmt_id, &last_insert_id, &rows_affected); + + for (int i = 0; i < n_records; ++i) { + sprintf(sql, "INSERT INTO test(n) VALUES(%d)", i + 1); + QUERY_SQL(sql, &rows); + munit_assert_uint64(rows.column_count, ==, 0); + munit_assert_ptr(rows.next, ==, NULL); + clientCloseRows(&rows); + } + + ADD(id, address); + ASSIGN(id, DQLITE_VOTER); + + REMOVE(1); + sleep(1); + + SELECT(2); + HANDSHAKE; + OPEN; + PREPARE("SELECT COUNT(*) from test", &stmt_id); + QUERY(stmt_id, &rows); + munit_assert_long(rows.next->values->integer, ==, n_records); + clientCloseRows(&rows); + return MUNIT_OK; +}