Skip to content

Commit

Permalink
Switched from using std::optional to std::shared_ptr for easier s…
Browse files Browse the repository at this point in the history
…tate handling
  • Loading branch information
dankmolot committed Jan 21, 2025
1 parent 03de526 commit 29f9fd0
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 66 deletions.
7 changes: 4 additions & 3 deletions source/async_postgres.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include <chrono>
#include <exception>
#include <memory>
#include <optional>
#include <queue>
#include <string_view>
#include <variant>
Expand Down Expand Up @@ -99,6 +98,8 @@ namespace async_postgres {
CreatePreparedCommand, PreparedCommand,
DescribePreparedCommand, DescribePortalCommand>;

Query(CommandVariant&& command) : command(std::move(command)) {}

CommandVariant command;
GLua::AutoReference callback;
bool sent = false;
Expand All @@ -112,8 +113,8 @@ namespace async_postgres {

struct Connection {
pg::conn conn;
std::optional<Query> query;
std::optional<ResetEvent> reset_event;
std::shared_ptr<Query> query;
std::shared_ptr<ResetEvent> reset_event;
GLua::AutoReference on_notify;

Connection(GLua::ILuaInterface* lua, pg::conn&& conn);
Expand Down
16 changes: 7 additions & 9 deletions source/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ void async_postgres::reset(GLua::ILuaInterface* lua, Connection* state,
throw std::runtime_error(PQerrorMessage(state->conn.get()));
}

state->reset_event = ResetEvent();
state->reset_event = std::make_shared<ResetEvent>();
}

if (callback) {
Expand All @@ -119,26 +119,24 @@ void async_postgres::process_reset(GLua::ILuaInterface* lua,
return;
}

auto& event = state->reset_event.value();
auto event = state->reset_event;
if (!socket_is_ready(state->conn.get(), state->reset_event->status)) {
return;
}

event.status = PQresetPoll(state->conn.get());
if (event.status == PGRES_POLLING_OK) {
auto callbacks = std::move(event.callbacks);
event->status = PQresetPoll(state->conn.get());
if (event->status == PGRES_POLLING_OK) {
state->reset_event.reset();

for (auto& callback : callbacks) {
for (auto& callback : event->callbacks) {
callback.Push();
lua->PushBool(true);
pcall(lua, 1, 0);
}
} else if (event.status == PGRES_POLLING_FAILED) {
auto callbacks = std::move(event.callbacks);
} else if (event->status == PGRES_POLLING_FAILED) {
state->reset_event.reset();

for (auto& callback : callbacks) {
for (auto& callback : event->callbacks) {
callback.Push();
lua->PushBool(false);
lua->PushString(PQerrorMessage(state->conn.get()));
Expand Down
80 changes: 40 additions & 40 deletions source/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,15 @@ namespace async_postgres::lua {
throw std::runtime_error("query already in progress");
}

async_postgres::SimpleCommand command = {lua->GetString(2)};
async_postgres::Query query = {std::move(command)};
state->query = std::make_shared<async_postgres::Query>(
async_postgres::SimpleCommand{
lua->GetString(2),
});

if (lua->IsType(3, GLua::Type::Function)) {
query.callback = GLua::AutoReference(lua, 3);
state->query->callback = GLua::AutoReference(lua, 3);
}

state->query = std::move(query);
return 0;
}

Expand All @@ -71,17 +72,16 @@ namespace async_postgres::lua {
throw std::runtime_error("query already in progress");
}

async_postgres::ParameterizedCommand command = {
lua->GetString(2),
async_postgres::array_to_params(lua, 3),
};
async_postgres::Query query = {std::move(command)};
state->query = std::make_shared<async_postgres::Query>(
async_postgres::ParameterizedCommand{
lua->GetString(2),
async_postgres::array_to_params(lua, 3),
});

if (lua->IsType(4, GLua::Type::Function)) {
query.callback = GLua::AutoReference(lua, 4);
state->query->callback = GLua::AutoReference(lua, 4);
}

state->query = std::move(query);
return 0;
}

Expand All @@ -95,15 +95,16 @@ namespace async_postgres::lua {
throw std::runtime_error("query already in progress");
}

async_postgres::CreatePreparedCommand command = {lua->GetString(2),
lua->GetString(3)};
async_postgres::Query query = {std::move(command)};
state->query = std::make_shared<async_postgres::Query>(
async_postgres::CreatePreparedCommand{
lua->GetString(2),
lua->GetString(3),
});

if (lua->IsType(4, GLua::Type::Function)) {
query.callback = GLua::AutoReference(lua, 4);
state->query->callback = GLua::AutoReference(lua, 4);
}

state->query = std::move(query);
return 0;
}

Expand All @@ -117,17 +118,16 @@ namespace async_postgres::lua {
throw std::runtime_error("query already in progress");
}

async_postgres::PreparedCommand command = {
lua->GetString(2),
async_postgres::array_to_params(lua, 3),
};
async_postgres::Query query = {std::move(command)};
state->query = std::make_shared<async_postgres::Query>(
async_postgres::PreparedCommand{
lua->GetString(2),
async_postgres::array_to_params(lua, 3),
});

if (lua->IsType(4, GLua::Type::Function)) {
query.callback = GLua::AutoReference(lua, 4);
state->query->callback = GLua::AutoReference(lua, 4);
}

state->query = std::move(query);
return 0;
}

Expand All @@ -140,14 +140,15 @@ namespace async_postgres::lua {
throw std::runtime_error("query already in progress");
}

async_postgres::DescribePreparedCommand command = {lua->GetString(2)};
async_postgres::Query query = {std::move(command)};
state->query = std::make_shared<async_postgres::Query>(
async_postgres::DescribePreparedCommand{
lua->GetString(2),
});

if (lua->IsType(3, GLua::Type::Function)) {
query.callback = GLua::AutoReference(lua, 3);
state->query->callback = GLua::AutoReference(lua, 3);
}

state->query = std::move(query);
return 0;
}

Expand All @@ -160,14 +161,15 @@ namespace async_postgres::lua {
throw std::runtime_error("query already in progress");
}

async_postgres::DescribePortalCommand command = {lua->GetString(2)};
async_postgres::Query query = {std::move(command)};
state->query = std::make_shared<async_postgres::Query>(
async_postgres::DescribePortalCommand{
lua->GetString(2),
});

if (lua->IsType(3, GLua::Type::Function)) {
query.callback = GLua::AutoReference(lua, 3);
state->query->callback = GLua::AutoReference(lua, 3);
}

state->query = std::move(query);
return 0;
}

Expand Down Expand Up @@ -201,9 +203,8 @@ namespace async_postgres::lua {

auto state = lua_connection_state();
if (state->reset_event) {
auto& event = state->reset_event.value();
while (state->reset_event.has_value() &&
&event == &state->reset_event.value()) {
auto event = state->reset_event;
while (event == state->reset_event) {
bool write =
state->reset_event->status == PGRES_POLLING_WRITING;
bool read = state->reset_event->status == PGRES_POLLING_READING;
Expand All @@ -220,16 +221,15 @@ namespace async_postgres::lua {
}

if (state->query) {
auto& query = state->query.value();
auto query = state->query;

// if query wasn't sent, send in through process_query
if (!query.sent) {
if (!query->sent) {
async_postgres::process_query(lua, state);
}

// while query is the same and it's not done
while (state->query.has_value() &&
&query == &state->query.value()) {
while (query == state->query) {
async_postgres::process_result(lua, state,
pg::getResult(state->conn));
}
Expand All @@ -253,14 +253,14 @@ namespace async_postgres::lua {
lua_protected_fn(querying) {
lua->CheckType(1, async_postgres::connection_meta);
auto state = lua_connection_state();
lua->PushBool(state->query.has_value());
lua->PushBool(!!state->query);
return 1;
}

lua_protected_fn(resetting) {
lua->CheckType(1, async_postgres::connection_meta);
auto state = lua_connection_state();
lua->PushBool(state->reset_event.has_value());
lua->PushBool(!!state->reset_event);
return 1;
}
} // namespace async_postgres::lua
Expand Down
27 changes: 13 additions & 14 deletions source/query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
using namespace async_postgres;

#define get_if_command(type) \
const auto* command = std::get_if<type>(&query.command)
const auto* command = std::get_if<type>(&query->command)

// returns true if query was sent
// returns false on error
inline bool send_query(PGconn* conn, Query& query) {
inline bool send_query(PGconn* conn, Query* query) {
if (get_if_command(SimpleCommand)) {
return PQsendQuery(conn, command->command.c_str()) == 1;
} else if (get_if_command(ParameterizedCommand)) {
Expand Down Expand Up @@ -39,11 +39,10 @@ void query_failed(GLua::ILuaInterface* lua, Connection* state) {
return;
}

auto query = std::move(*state->query);
auto query = state->query;
state->query.reset();

if (query.callback) {
query.callback.Push();
if (query->callback.Push()) {
lua->PushBool(false);
lua->PushString(PQerrorMessage(state->conn.get()));
pcall(lua, 2, 0);
Expand Down Expand Up @@ -112,15 +111,15 @@ void query_result(GLua::ILuaInterface* lua, pg::result&& result,

// returns true if poll was successful
// returns false if there was an error
inline bool poll_query(PGconn* conn, Query& query) {
inline bool poll_query(PGconn* conn, Query* query) {
auto socket = check_socket_status(conn);
if (socket.read_ready || socket.write_ready) {
if (socket.read_ready && PQconsumeInput(conn) == 0) {
return false;
}

if (!query.flushed) {
query.flushed = PQflush(conn) == 0;
if (!query->flushed) {
query->flushed = PQflush(conn) == 0;
}
}
return true;
Expand All @@ -142,10 +141,10 @@ void async_postgres::process_result(GLua::ILuaInterface* lua, Connection* state,
auto next_result = pg::getResult(state->conn);
if (!next_result) {
// query is done, we need to remove query from the state
Query query = std::move(*state->query);
auto query = state->query;
state->query.reset();

query_result(lua, std::move(result), query.callback);
query_result(lua, std::move(result), query->callback);

// callback might added another query, process it rightaway
process_query(lua, state);
Expand All @@ -169,15 +168,15 @@ void async_postgres::process_query(GLua::ILuaInterface* lua,
return;
}

auto& query = state->query.value();
if (!query.sent) {
auto* query = state->query.get();
if (!query->sent) {
if (!send_query(state->conn.get(), query)) {
query_failed(lua, state);
return process_query(lua, state);
}

query.sent = true;
query.flushed = PQflush(state->conn.get()) == 0;
query->sent = true;
query->flushed = PQflush(state->conn.get()) == 0;
}

// if (!poll_query(state->conn.get(), query)) {
Expand Down

0 comments on commit 29f9fd0

Please sign in to comment.