Skip to content

Commit

Permalink
Refactor database: Extracted into tfs::db namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
ramon-bernardo committed Nov 25, 2024
1 parent cf38586 commit b3ac5b5
Show file tree
Hide file tree
Showing 19 changed files with 376 additions and 460 deletions.
15 changes: 5 additions & 10 deletions src/ban.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ namespace IOBan {

const std::optional<BanInfo> getAccountBanInfo(uint32_t accountId)
{
Database& db = Database::getInstance();

DBResult_ptr result = db.storeQuery(fmt::format(
auto result = tfs::db::store_query(fmt::format(
"SELECT `reason`, `expires_at`, `banned_at`, `banned_by`, (SELECT `name` FROM `players` WHERE `id` = `banned_by`) AS `name` FROM `account_bans` WHERE `account_id` = {:d}",
accountId));
if (!result) {
Expand All @@ -27,8 +25,8 @@ const std::optional<BanInfo> getAccountBanInfo(uint32_t accountId)
// Move the ban to history if it has expired
g_databaseTasks.addTask(fmt::format(
"INSERT INTO `account_ban_history` (`account_id`, `reason`, `banned_at`, `expired_at`, `banned_by`) VALUES ({:d}, {:s}, {:d}, {:d}, {:d})",
accountId, db.escapeString(result->getString("reason")), result->getNumber<time_t>("banned_at"), expiresAt,
result->getNumber<uint32_t>("banned_by")));
accountId, tfs::db::escape_string(result->getString("reason")), result->getNumber<time_t>("banned_at"),
expiresAt, result->getNumber<uint32_t>("banned_by")));
g_databaseTasks.addTask(fmt::format("DELETE FROM `account_bans` WHERE `account_id` = {:d}", accountId));
return std::nullopt;
}
Expand All @@ -51,9 +49,7 @@ const std::optional<BanInfo> getIpBanInfo(const Connection::Address& clientIP)
return std::nullopt;
}

Database& db = Database::getInstance();

DBResult_ptr result = db.storeQuery(fmt::format(
auto result = tfs::db::store_query(fmt::format(
"SELECT `reason`, `expires_at`, (SELECT `name` FROM `players` WHERE `id` = `banned_by`) AS `name` FROM `ip_bans` WHERE `ip` = INET6_ATON('{:s}')",
clientIP.to_string()));
if (!result) {
Expand Down Expand Up @@ -81,8 +77,7 @@ const std::optional<BanInfo> getIpBanInfo(const Connection::Address& clientIP)

bool isPlayerNamelocked(uint32_t playerId)
{
return Database::getInstance()
.storeQuery(fmt::format("SELECT 1 FROM `player_namelocks` WHERE `player_id` = {:d}", playerId))
return tfs::db::store_query(fmt::format("SELECT 1 FROM `player_namelocks` WHERE `player_id` = {:d}", playerId))
.get();
}

Expand Down
173 changes: 96 additions & 77 deletions src/database.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,37 @@

#include <mysql/errmsg.h>

static tfs::detail::Mysql_ptr connectToDatabase(const bool retryIfError)
namespace {

std::recursive_mutex database_lock;
tfs::db::detail::Mysql_ptr handle = nullptr;
uint64_t packet_size = 1048576;

// Do not retry queries if we are in the middle of a transaction
bool retry_queries = true;

bool is_lost_connection_error(const unsigned error)
{
bool isFirstAttemptToConnect = true;
return error == CR_SERVER_LOST || error == CR_SERVER_GONE_ERROR || error == CR_CONN_HOST_ERROR ||
error == 1053 /*ER_SERVER_SHUTDOWN*/ || error == CR_CONNECTION_ERROR;
}

tfs::db::detail::Mysql_ptr connect_to_database(const bool retryIfError)
{
auto is_first_attempt_to_connect = true;

retry:
if (!isFirstAttemptToConnect) {
if (!is_first_attempt_to_connect) {
std::this_thread::sleep_for(std::chrono::seconds(1));
}
isFirstAttemptToConnect = false;
is_first_attempt_to_connect = false;

tfs::detail::Mysql_ptr handle{mysql_init(nullptr)};
tfs::db::detail::Mysql_ptr handle{mysql_init(nullptr)};
if (!handle) {
std::cout << std::endl << "Failed to initialize MySQL connection handle." << std::endl;
goto error;
}

// connects to database
if (!mysql_real_connect(handle.get(), getString(ConfigManager::MYSQL_HOST).c_str(),
getString(ConfigManager::MYSQL_USER).c_str(), getString(ConfigManager::MYSQL_PASS).c_str(),
Expand All @@ -41,122 +57,91 @@ static tfs::detail::Mysql_ptr connectToDatabase(const bool retryIfError)
return nullptr;
}

static bool isLostConnectionError(const unsigned error)
{
return error == CR_SERVER_LOST || error == CR_SERVER_GONE_ERROR || error == CR_CONN_HOST_ERROR ||
error == 1053 /*ER_SERVER_SHUTDOWN*/ || error == CR_CONNECTION_ERROR;
}

static bool executeQuery(tfs::detail::Mysql_ptr& handle, std::string_view query, const bool retryIfLostConnection)
bool execute_query_handle(tfs::db::detail::Mysql_ptr& handle, std::string_view query, const bool retryIfLostConnection)
{
while (mysql_real_query(handle.get(), query.data(), query.length()) != 0) {
std::cout << "[Error - mysql_real_query] Query: " << query.substr(0, 256) << std::endl
<< "Message: " << mysql_error(handle.get()) << std::endl;
const unsigned error = mysql_errno(handle.get());
if (!isLostConnectionError(error) || !retryIfLostConnection) {
if (!is_lost_connection_error(error) || !retryIfLostConnection) {
return false;
}
handle = connectToDatabase(true);
handle = connect_to_database(true);
}
return true;
}

bool Database::connect()
} // namespace

bool tfs::db::connect()
{
auto newHandle = connectToDatabase(false);
if (!newHandle) {
auto new_handle = connect_to_database(false);
if (!new_handle) {
return false;
}

handle = std::move(newHandle);
DBResult_ptr result = storeQuery("SHOW VARIABLES LIKE 'max_allowed_packet'");
if (result) {
maxPacketSize = result->getNumber<uint64_t>("Value");
handle = std::move(new_handle);
if (auto result = store_query("SHOW VARIABLES LIKE 'max_allowed_packet'")) {
packet_size = result->getNumber<uint64_t>("Value");
}
return true;
}

bool Database::beginTransaction()
bool tfs::db::execute_query(std::string_view query)
{
databaseLock.lock();
const bool result = executeQuery("START TRANSACTION");
retryQueries = !result;
if (!result) {
databaseLock.unlock();
}
return result;
}
std::lock_guard<std::recursive_mutex> lockGuard(database_lock);
auto success = ::execute_query_handle(handle, query, retry_queries);

bool Database::rollback()
{
const bool result = executeQuery("ROLLBACK");
retryQueries = true;
databaseLock.unlock();
return result;
}

bool Database::commit()
{
const bool result = executeQuery("COMMIT");
retryQueries = true;
databaseLock.unlock();
return result;
}

bool Database::executeQuery(const std::string& query)
{
std::lock_guard<std::recursive_mutex> lockGuard(databaseLock);
auto success = ::executeQuery(handle, query, retryQueries);

// executeQuery can be called with command that produces result (e.g. SELECT)
// execute_query_handle can be called with command that produces result (e.g. SELECT)
// we have to store that result, even though we do not need it, otherwise handle will get blocked
auto mysql_res = mysql_store_result(handle.get());
mysql_free_result(mysql_res);
auto store_result = mysql_store_result(handle.get());
mysql_free_result(store_result);

return success;
}

DBResult_ptr Database::storeQuery(std::string_view query)
DBResult_ptr tfs::db::store_query(std::string_view query)
{
std::lock_guard<std::recursive_mutex> lockGuard(databaseLock);
std::lock_guard<std::recursive_mutex> lockGuard(database_lock);

retry:
if (!::executeQuery(handle, query, retryQueries) && !retryQueries) {
if (!::execute_query_handle(handle, query, retry_queries) && !retry_queries) {
return nullptr;
}

// we should call that every time as someone would call executeQuery('SELECT...')
// we should call that every time as someone would call execute_query_handle('SELECT...')
// as it is described in MySQL manual: "it doesn't hurt" :P
tfs::detail::MysqlResult_ptr res{mysql_store_result(handle.get())};
tfs::db::detail::MysqlResult_ptr res{mysql_store_result(handle.get())};
if (!res) {
std::cout << "[Error - mysql_store_result] Query: " << query << std::endl
<< "Message: " << mysql_error(handle.get()) << std::endl;
const unsigned error = mysql_errno(handle.get());
if (!isLostConnectionError(error) || !retryQueries) {
if (!is_lost_connection_error(error) || !retry_queries) {
return nullptr;
}
goto retry;
}

// retrieving results of query
DBResult_ptr result = std::make_shared<DBResult>(std::move(res));
auto result = std::make_shared<DBResult>(std::move(res));
if (!result->hasNext()) {
return nullptr;
}
return result;
}

std::string Database::escapeBlob(const char* s, uint32_t length) const
{
// the worst case is 2n + 1
size_t maxLength = (length * 2) + 1;
std::string tfs::db::escape_string(std::string_view s) { return escape_blob(s.data(), s.length()); }

std::string tfs::db::escape_blob(const char* s, uint32_t length)
{ // the worst case is 2n + 1
size_t max_length = (length * 2) + 1;

std::string escaped;
escaped.reserve(maxLength + 2);
escaped.reserve(max_length + 2);
escaped.push_back('\'');

if (length != 0) {
char* output = new char[maxLength];
char* output = new char[max_length];
mysql_real_escape_string(handle.get(), output, s, length);
escaped.append(output);
delete[] output;
Expand All @@ -166,7 +151,40 @@ std::string Database::escapeBlob(const char* s, uint32_t length) const
return escaped;
}

DBResult::DBResult(tfs::detail::MysqlResult_ptr&& res) : handle{std::move(res)}
uint64_t tfs::db::last_insert_id() { return static_cast<uint64_t>(mysql_insert_id(handle.get())); }

const char* tfs::db::client_version() { return mysql_get_client_info(); }

uint64_t tfs::db::max_packet_size() { return packet_size; }

bool tfs::db::transaction::begin()
{
database_lock.lock();
const auto result = tfs::db::execute_query("START TRANSACTION");
retry_queries = !result;
if (!result) {
database_lock.unlock();
}
return result;
}

bool tfs::db::transaction::rollback()
{
const auto result = tfs::db::execute_query("ROLLBACK");
retry_queries = true;
database_lock.unlock();
return result;
}

bool tfs::db::transaction::commit()
{
const auto result = tfs::db::execute_query("COMMIT");
retry_queries = true;
database_lock.unlock();
return result;
}

DBResult::DBResult(tfs::db::detail::MysqlResult_ptr&& res) : handle{std::move(res)}
{
size_t i = 0;

Expand Down Expand Up @@ -209,19 +227,20 @@ DBInsert::DBInsert(std::string query) : query(std::move(query)) { this->length =
bool DBInsert::addRow(const std::string& row)
{
// adds new row to buffer
const size_t rowLength = row.length();
length += rowLength;
if (length > Database::getInstance().getMaxPacketSize() && !execute()) {
const size_t row_length = row.length();
length += row_length;

if (length > tfs::db::max_packet_size() && !execute()) {
return false;
}

if (values.empty()) {
values.reserve(rowLength + 2);
values.reserve(row_length + 2);
values.push_back('(');
values.append(row);
values.push_back(')');
} else {
values.reserve(values.length() + rowLength + 3);
values.reserve(values.length() + row_length + 3);
values.push_back(',');
values.push_back('(');
values.append(row);
Expand All @@ -232,9 +251,9 @@ bool DBInsert::addRow(const std::string& row)

bool DBInsert::addRow(std::ostringstream& row)
{
bool ret = addRow(row.str());
auto result = addRow(row.str());
row.str(std::string());
return ret;
return result;
}

bool DBInsert::execute()
Expand All @@ -244,8 +263,8 @@ bool DBInsert::execute()
}

// executes buffer
bool res = Database::getInstance().executeQuery(query + values);
auto result = tfs::db::execute_query(query + values);
values.clear();
length = query.length();
return res;
return result;
}
Loading

0 comments on commit b3ac5b5

Please sign in to comment.