Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring of Statement and Column classes #349

Merged
merged 7 commits into from
Mar 29, 2022
12 changes: 7 additions & 5 deletions include/SQLiteCpp/Column.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
#include <SQLiteCpp/Exception.h>

#include <string>
#include <memory>
#include <climits> // For INT_MAX

// Forward declarations to avoid inclusion of <sqlite3.h> in a header
struct sqlite3_stmt;

namespace SQLite
{
Expand All @@ -26,7 +29,6 @@ extern const int TEXT; ///< SQLITE_TEXT
extern const int BLOB; ///< SQLITE_BLOB
extern const int Null; ///< SQLITE_NULL


/**
* @brief Encapsulation of a Column in a row of the result pointed by the prepared Statement.
*
Expand All @@ -52,7 +54,7 @@ class Column
* @param[in] aStmtPtr Shared pointer to the prepared SQLite Statement Object.
* @param[in] aIndex Index of the column in the row of result, starting at 0
*/
Column(Statement::Ptr& aStmtPtr, int aIndex) noexcept;
explicit Column(const Statement::TStatementPtr& aStmtPtr, int aIndex);

// default destructor: the finalization will be done by the destructor of the last shared pointer
// default copy constructor and assignment operator are perfectly suited :
Expand Down Expand Up @@ -250,8 +252,8 @@ class Column
}

private:
Statement::Ptr mStmtPtr; ///< Shared Pointer to the prepared SQLite Statement Object
int mIndex; ///< Index of the column in the row of result, starting at 0
Statement::TStatementPtr mStmtPtr; ///< Shared Pointer to the prepared SQLite Statement Object
int mIndex; ///< Index of the column in the row of result, starting at 0
};

/**
Expand Down Expand Up @@ -281,7 +283,7 @@ T Statement::getColumns()
template<typename T, const int... Is>
T Statement::getColumns(const std::integer_sequence<int, Is...>)
{
return T{Column(mStmtPtr, Is)...};
return T{Column(mpPreparedStatement, Is)...};
}

#endif
Expand Down
175 changes: 44 additions & 131 deletions include/SQLiteCpp/Statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

#include <string>
#include <map>
#include <climits> // For INT_MAX
#include <memory>

// Forward declarations to avoid inclusion of <sqlite3.h> in a header
struct sqlite3;
Expand Down Expand Up @@ -51,8 +51,6 @@ extern const int OK; ///< SQLITE_OK
*/
class Statement
{
friend class Column; // For access to Statement::Ptr inner class

public:
/**
* @brief Compile and register the SQL query for the provided SQLite Database Connection
Expand All @@ -62,7 +60,7 @@ class Statement
*
* Exception is thrown in case of error, then the Statement object is NOT constructed.
*/
Statement(Database& aDatabase, const char* apQuery);
Statement(const Database& aDatabase, const char* apQuery);

/**
* @brief Compile and register the SQL query for the provided SQLite Database Connection
Expand All @@ -72,7 +70,7 @@ class Statement
*
* Exception is thrown in case of error, then the Statement object is NOT constructed.
*/
Statement(Database &aDatabase, const std::string& aQuery) :
Statement(const Database& aDatabase, const std::string& aQuery) :
Statement(aDatabase, aQuery.c_str())
{}

Expand All @@ -82,6 +80,7 @@ class Statement
* @param[in] aStatement Statement to move
*/
Statement(Statement&& aStatement) noexcept;
Statement& operator=(Statement&& aStatement) noexcept = default;

// Statement is non-copyable
Statement(const Statement&) = delete;
Expand Down Expand Up @@ -123,39 +122,20 @@ class Statement
// => if you know what you are doing, use bindNoCopy() instead of bind()

SQLITECPP_PURE_FUNC
int getIndex(const char * const apName);
int getIndex(const char * const apName) const;

/**
* @brief Bind an int value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const int aIndex, const int aValue);
void bind(const int aIndex, const int32_t aValue);
/**
* @brief Bind a 32bits unsigned int value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const int aIndex, const unsigned aValue);

#if (LONG_MAX == INT_MAX) // 4 bytes "long" type means the data model is ILP32 or LLP64 (Win64 Visual C++ and MinGW)
/**
* @brief Bind a 32bits long value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const int aIndex, const long aValue)
{
bind(aIndex, static_cast<int>(aValue));
}
#else // 8 bytes "long" type means the data model is LP64 (Most Unix-like, Windows when using Cygwin; z/OS)
/**
* @brief Bind a 64bits long value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const int aIndex, const long aValue)
{
bind(aIndex, static_cast<long long>(aValue));
}
#endif

void bind(const int aIndex, const uint32_t aValue);
/**
* @brief Bind a 64bits int value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const int aIndex, const long long aValue);
void bind(const int aIndex, const int64_t aValue);
/**
* @brief Bind a double (64bits float) value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
Expand Down Expand Up @@ -210,39 +190,21 @@ class Statement
/**
* @brief Bind an int value to a named parameter "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const char* apName, const int aValue)
void bind(const char* apName, const int32_t aValue)
{
bind(getIndex(apName), aValue);
}
/**
* @brief Bind a 32bits unsigned int value to a named parameter "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const char* apName, const unsigned aValue)
void bind(const char* apName, const uint32_t aValue)
{
bind(getIndex(apName), aValue);
}

#if (LONG_MAX == INT_MAX) // 4 bytes "long" type means the data model is ILP32 or LLP64 (Win64 Visual C++ and MinGW)
/**
* @brief Bind a 32bits long value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const char* apName, const long aValue)
{
bind(apName, static_cast<int>(aValue));
}
#else // 8 bytes "long" type means the data model is LP64 (Most Unix-like, Windows when using Cygwin; z/OS)
/**
* @brief Bind a 64bits long value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const char* apName, const long aValue)
{
bind(apName, static_cast<long long>(aValue));
}
#endif
/**
* @brief Bind a 64bits int value to a named parameter "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const char* apName, const long long aValue)
void bind(const char* apName, const int64_t aValue)
{
bind(getIndex(apName), aValue);
}
Expand Down Expand Up @@ -325,46 +287,28 @@ class Statement
/**
* @brief Bind an int value to a named parameter "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const std::string& aName, const int aValue)
void bind(const std::string& aName, const int32_t aValue)
{
bind(aName.c_str(), aValue);
}
/**
* @brief Bind a 32bits unsigned int value to a named parameter "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const std::string& aName, const unsigned aValue)
void bind(const std::string& aName, const uint32_t aValue)
{
bind(aName.c_str(), aValue);
}

#if (LONG_MAX == INT_MAX) // 4 bytes "long" type means the data model is ILP32 or LLP64 (Win64 Visual C++ and MinGW)
/**
* @brief Bind a 32bits long value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const std::string& aName, const long aValue)
{
bind(aName.c_str(), static_cast<int>(aValue));
}
#else // 8 bytes "long" type means the data model is LP64 (Most Unix-like, Windows when using Cygwin; z/OS)
/**
* @brief Bind a 64bits long value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const std::string& aName, const long aValue)
{
bind(aName.c_str(), static_cast<long long>(aValue));
}
#endif
/**
* @brief Bind a 64bits int value to a named parameter "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const std::string& aName, const long long aValue)
void bind(const std::string& aName, const int64_t aValue)
{
bind(aName.c_str(), aValue);
}
/**
* @brief Bind a double (64bits float) value to a named parameter "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const std::string& aName, const double aValue)
void bind(const std::string& aName, const double aValue)
{
bind(aName.c_str(), aValue);
}
Expand Down Expand Up @@ -519,7 +463,7 @@ class Statement
* Thus, you should instead extract immediately its data (getInt(), getText()...)
* and use or copy this data for any later usage.
*/
Column getColumn(const int aIndex);
Column getColumn(const int aIndex) const;

/**
* @brief Return a copy of the column data specified by its column name (less efficient than using an index)
Expand Down Expand Up @@ -550,7 +494,7 @@ class Statement
*
* Throw an exception if the specified name is not one of the aliased name of the columns in the result.
*/
Column getColumn(const char* apName);
Column getColumn(const char* apName) const;

#if __cplusplus >= 201402L || (defined(_MSC_VER) && _MSC_VER >= 1900) // c++14: Visual Studio 2015
/**
Expand Down Expand Up @@ -673,7 +617,7 @@ class Statement
}

// Return a UTF-8 string containing the SQL text of prepared statement with bound parameters expanded.
std::string getExpandedSQL();
std::string getExpandedSQL() const;

/// Return the number of columns in the result set returned by the prepared statement
int getColumnCount() const
Expand Down Expand Up @@ -701,52 +645,8 @@ class Statement
/// Return UTF-8 encoded English language explanation of the most recent failed API call (if any).
const char* getErrorMsg() const noexcept;

private:
/**
* @brief Shared pointer to the sqlite3_stmt SQLite Statement Object.
*
* Manage the finalization of the sqlite3_stmt with a reference counter.
*
* This is a internal class, not part of the API (hence full documentation is in the cpp).
*/
// TODO Convert this whole custom pointer to a C++11 std::shared_ptr with a custom deleter
class Ptr
{
public:
// Prepare the statement and initialize its reference counter
Ptr(sqlite3* apSQLite, std::string& aQuery);
// Copy constructor increments the ref counter
Ptr(const Ptr& aPtr);

// Move constructor
Ptr(Ptr&& aPtr);

// Decrement the ref counter and finalize the sqlite3_stmt when it reaches 0
~Ptr();

/// Inline cast operator returning the pointer to SQLite Database Connection Handle
operator sqlite3*() const
{
return mpSQLite;
}

/// Inline cast operator returning the pointer to SQLite Statement Object
operator sqlite3_stmt*() const
{
return mpStmt;
}

private:
/// @{ Unused/forbidden copy/assignment operator
Ptr& operator=(const Ptr& aPtr);
/// @}

private:
sqlite3* mpSQLite; //!< Pointer to SQLite Database Connection Handle
sqlite3_stmt* mpStmt; //!< Pointer to SQLite Statement Object
unsigned int* mpRefCount; //!< Pointer to the heap allocated reference counter of the sqlite3_stmt
//!< (to share it with Column objects)
};
/// Shared pointer to SQLite Prepared Statement Object
using TStatementPtr = std::shared_ptr<sqlite3_stmt>;

private:
/**
Expand All @@ -758,7 +658,7 @@ class Statement
{
if (SQLite::OK != aRet)
{
throw SQLite::Exception(mStmtPtr, aRet);
throw SQLite::Exception(mpSQLite, aRet);
}
}

Expand All @@ -784,17 +684,30 @@ class Statement
}
}

private:
/// Map of columns index by name (mutable so getColumnIndex can be const)
typedef std::map<std::string, int> TColumnNames;
/**
* @brief Prepare statement object.
*
* @return Shared pointer to prepared statement object
*/
TStatementPtr prepareStatement();

private:
std::string mQuery; //!< UTF-8 SQL Query
Ptr mStmtPtr; //!< Shared Pointer to the prepared SQLite Statement Object
int mColumnCount; //!< Number of columns in the result of the prepared statement
mutable TColumnNames mColumnNames; //!< Map of columns index by name (mutable so getColumnIndex can be const)
bool mbHasRow; //!< true when a row has been fetched with executeStep()
bool mbDone; //!< true when the last executeStep() had no more row to fetch
/**
* @brief Return a prepared statement object.
*
* Throw an exception if the statement object was not prepared.
* @return raw pointer to Prepared Statement Object
*/
sqlite3_stmt* getPreparedStatement() const;

std::string mQuery; //!< UTF-8 SQL Query
sqlite3* mpSQLite; //!< Pointer to SQLite Database Connection Handle
TStatementPtr mpPreparedStatement; //!< Shared Pointer to the prepared SQLite Statement Object
int mColumnCount{0}; //!< Number of columns in the result of the prepared statement
bool mbHasRow{false}; //!< true when a row has been fetched with executeStep()
bool mbDone{false}; //!< true when the last executeStep() had no more row to fetch

/// Map of columns index by name (mutable so getColumnIndex can be const)
mutable std::map<std::string, int> mColumnNames{};
};


Expand Down
Loading