155 changes: 154 additions & 1 deletion mssql_python/cursor.py
@@ -1387,7 +1387,160 @@ def fetchall_with_mapping():
self.fetchmany = fetchmany_with_mapping
self.fetchall = fetchall_with_mapping

# Return the fetched rows
return result_rows

def columns(self, table=None, catalog=None, schema=None, column=None):
"""
Creates a result set of column information for the specified tables
using the ODBC SQLColumns catalog function.

Args:
table (str, optional): The table name pattern. Default is None (all tables).
catalog (str, optional): The catalog name. Default is None (current catalog).
schema (str, optional): The schema name pattern. Default is None (all schemas).
column (str, optional): The column name pattern. Default is None (all columns).

Returns:
cursor: The cursor itself, containing the result set. Use fetchone(), fetchmany(),
or fetchall() to retrieve the results.

Each row contains the following columns:
- table_cat (str): Catalog name
- table_schem (str): Schema name
- table_name (str): Table name
- column_name (str): Column name
- data_type (int): The ODBC SQL data type constant (e.g. SQL_CHAR)
- type_name (str): Data source dependent type name
- column_size (int): Column size
- buffer_length (int): Length of the column in bytes
- decimal_digits (int): Number of fractional digits
- num_prec_radix (int): Radix (typically 10 or 2)
- nullable (int): One of SQL_NO_NULLS, SQL_NULLABLE, SQL_NULLABLE_UNKNOWN
- remarks (str): Comments about the column
- column_def (str): Default value for the column
- sql_data_type (int): The SQL data type as it appears in the ODBC SQL_DESC_TYPE descriptor field
- sql_datetime_sub (int): Subcode for datetime types
- char_octet_length (int): Maximum length in bytes for char types
- ordinal_position (int): Column position in the table (starting at 1)
- is_nullable (str): "YES", "NO", or "" (unknown)

Warning:
Calling this method without any filters (all parameters as None) will enumerate
EVERY column in EVERY table in the database. This can be extremely expensive in
large databases, potentially causing high memory usage, slow execution times,
and in extreme cases, timeout errors. Always use filters (catalog, schema, table,
or column) whenever possible to limit the result set.

Example:
# Get all columns in table 'Customers'
columns = cursor.columns(table='Customers')

# Get all columns in table 'Customers' in schema 'dbo'
columns = cursor.columns(table='Customers', schema='dbo')

# Get column named 'CustomerID' in any table
columns = cursor.columns(column='CustomerID')
"""
self._check_closed()

# Always reset the cursor first to ensure clean state
self._reset_cursor()

# Call the SQLColumns function
retcode = ddbc_bindings.DDBCSQLColumns(
self.hstmt,
catalog,
schema,
table,
column
)
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode)

# Initialize description from column metadata
column_metadata = []
try:
ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata)
self._initialize_description(column_metadata)
except InterfaceError as e:
log('error', f"Driver interface error during metadata retrieval: {e}")

except Exception as e:
# Log the exception with appropriate context
log('error', f"Failed to retrieve column metadata: {e}. Using standard ODBC column definitions instead.")

if not self.description:
# If describe fails, create a manual description for the standard columns
column_types = [str, str, str, str, int, str, int, int, int, int, int, str, str, int, int, int, int, str]
self.description = [
("table_cat", column_types[0], None, 128, 128, 0, True),
("table_schem", column_types[1], None, 128, 128, 0, True),
("table_name", column_types[2], None, 128, 128, 0, False),
("column_name", column_types[3], None, 128, 128, 0, False),
("data_type", column_types[4], None, 10, 10, 0, False),
("type_name", column_types[5], None, 128, 128, 0, False),
("column_size", column_types[6], None, 10, 10, 0, True),
("buffer_length", column_types[7], None, 10, 10, 0, True),
("decimal_digits", column_types[8], None, 10, 10, 0, True),
("num_prec_radix", column_types[9], None, 10, 10, 0, True),
("nullable", column_types[10], None, 10, 10, 0, False),
("remarks", column_types[11], None, 254, 254, 0, True),
("column_def", column_types[12], None, 254, 254, 0, True),
("sql_data_type", column_types[13], None, 10, 10, 0, False),
("sql_datetime_sub", column_types[14], None, 10, 10, 0, True),
("char_octet_length", column_types[15], None, 10, 10, 0, True),
("ordinal_position", column_types[16], None, 10, 10, 0, False),
("is_nullable", column_types[17], None, 254, 254, 0, True)
]

# Store the column mappings for this specific columns() call
column_names = [desc[0] for desc in self.description]

# Create a specialized column map for this result set
columns_map = {}
for i, name in enumerate(column_names):
columns_map[name] = i
columns_map[name.lower()] = i

# Define wrapped fetch methods that preserve existing column mapping
# but add our specialized mapping just for column results
def fetchone_with_columns_mapping():
row = self._original_fetchone()
if row is not None:
# Create a merged map with columns result taking precedence
merged_map = getattr(row, '_column_map', {}).copy()
merged_map.update(columns_map)
row._column_map = merged_map
return row

def fetchmany_with_columns_mapping(size=None):
rows = self._original_fetchmany(size)
for row in rows:
# Create a merged map with columns result taking precedence
merged_map = getattr(row, '_column_map', {}).copy()
merged_map.update(columns_map)
row._column_map = merged_map
return rows

def fetchall_with_columns_mapping():
rows = self._original_fetchall()
for row in rows:
# Create a merged map with columns result taking precedence
merged_map = getattr(row, '_column_map', {}).copy()
merged_map.update(columns_map)
row._column_map = merged_map
return rows

# Save original fetch methods
if not hasattr(self, '_original_fetchone'):
self._original_fetchone = self.fetchone
self._original_fetchmany = self.fetchmany
self._original_fetchall = self.fetchall

# Override fetch methods with our wrapped versions
self.fetchone = fetchone_with_columns_mapping
self.fetchmany = fetchmany_with_columns_mapping
self.fetchall = fetchall_with_columns_mapping

return self

@staticmethod
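A minimal usage sketch for the new cursor.columns() API (illustration only: the connection object `conn` is assumed, and name-based access assumes Row resolves names through the merged _column_map attached by the wrapped fetch methods above — adjust to the actual Row API if it differs):

# Sketch, not part of the change: enumerate columns of matching tables.
cursor = conn.cursor()

# schema/table/column are ODBC pattern-value arguments, so '%' and '_' act as
# wildcards; filtering keeps the result set small instead of scanning the whole DB.
cursor.columns(schema='dbo', table='Cust%')

for row in cursor.fetchall():
    # Positional access always works; named access relies on the merged column map.
    print(row.table_name, row.column_name, row.type_name, row.is_nullable)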
60 changes: 59 additions & 1 deletion mssql_python/pybind/ddbc_bindings.cpp
@@ -129,6 +129,7 @@ SQLForeignKeysFunc SQLForeignKeys_ptr = nullptr;
SQLPrimaryKeysFunc SQLPrimaryKeys_ptr = nullptr;
SQLSpecialColumnsFunc SQLSpecialColumns_ptr = nullptr;
SQLStatisticsFunc SQLStatistics_ptr = nullptr;
SQLColumnsFunc SQLColumns_ptr = nullptr;

// Transaction APIs
SQLEndTranFunc SQLEndTran_ptr = nullptr;
@@ -791,6 +792,7 @@ DriverHandle LoadDriverOrThrowException() {
SQLPrimaryKeys_ptr = GetFunctionPointer<SQLPrimaryKeysFunc>(handle, "SQLPrimaryKeysW");
SQLSpecialColumns_ptr = GetFunctionPointer<SQLSpecialColumnsFunc>(handle, "SQLSpecialColumnsW");
SQLStatistics_ptr = GetFunctionPointer<SQLStatisticsFunc>(handle, "SQLStatisticsW");
SQLColumns_ptr = GetFunctionPointer<SQLColumnsFunc>(handle, "SQLColumnsW");

SQLEndTran_ptr = GetFunctionPointer<SQLEndTranFunc>(handle, "SQLEndTran");
SQLDisconnect_ptr = GetFunctionPointer<SQLDisconnectFunc>(handle, "SQLDisconnect");
@@ -810,7 +812,8 @@ DriverHandle LoadDriverOrThrowException() {
SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr &&
SQLFreeStmt_ptr && SQLGetDiagRec_ptr &&
SQLGetTypeInfo_ptr && SQLProcedures_ptr && SQLForeignKeys_ptr &&
SQLPrimaryKeys_ptr && SQLSpecialColumns_ptr && SQLStatistics_ptr;
SQLPrimaryKeys_ptr && SQLSpecialColumns_ptr && SQLStatistics_ptr &&
SQLColumns_ptr;

if (!success) {
ThrowStdException("Failed to load required function pointers from driver.");
@@ -1051,6 +1054,53 @@ SQLRETURN SQLStatistics_wrap(SqlHandlePtr StatementHandle,
#endif
}

SQLRETURN SQLColumns_wrap(SqlHandlePtr StatementHandle,
const py::object& catalogObj,
const py::object& schemaObj,
const py::object& tableObj,
const py::object& columnObj) {
if (!SQLColumns_ptr) {
ThrowStdException("SQLColumns function not loaded");
}

// Convert py::object to std::wstring, treating None as empty string
std::wstring catalogStr = catalogObj.is_none() ? L"" : catalogObj.cast<std::wstring>();
std::wstring schemaStr = schemaObj.is_none() ? L"" : schemaObj.cast<std::wstring>();
std::wstring tableStr = tableObj.is_none() ? L"" : tableObj.cast<std::wstring>();
std::wstring columnStr = columnObj.is_none() ? L"" : columnObj.cast<std::wstring>();

#if defined(__APPLE__) || defined(__linux__)
// Unix implementation
std::vector<SQLWCHAR> catalogBuf = WStringToSQLWCHAR(catalogStr);
std::vector<SQLWCHAR> schemaBuf = WStringToSQLWCHAR(schemaStr);
std::vector<SQLWCHAR> tableBuf = WStringToSQLWCHAR(tableStr);
std::vector<SQLWCHAR> columnBuf = WStringToSQLWCHAR(columnStr);

return SQLColumns_ptr(
StatementHandle->get(),
catalogStr.empty() ? nullptr : catalogBuf.data(),
catalogStr.empty() ? 0 : SQL_NTS,
schemaStr.empty() ? nullptr : schemaBuf.data(),
schemaStr.empty() ? 0 : SQL_NTS,
tableStr.empty() ? nullptr : tableBuf.data(),
tableStr.empty() ? 0 : SQL_NTS,
columnStr.empty() ? nullptr : columnBuf.data(),
columnStr.empty() ? 0 : SQL_NTS);
#else
// Windows implementation
return SQLColumns_ptr(
StatementHandle->get(),
catalogStr.empty() ? nullptr : (SQLWCHAR*)catalogStr.c_str(),
catalogStr.empty() ? 0 : SQL_NTS,
schemaStr.empty() ? nullptr : (SQLWCHAR*)schemaStr.c_str(),
schemaStr.empty() ? 0 : SQL_NTS,
tableStr.empty() ? nullptr : (SQLWCHAR*)tableStr.c_str(),
tableStr.empty() ? 0 : SQL_NTS,
columnStr.empty() ? nullptr : (SQLWCHAR*)columnStr.c_str(),
columnStr.empty() ? 0 : SQL_NTS);
#endif
}

// Helper function to check for driver errors
ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRETURN retcode) {
LOG("Checking errors for retcode - {}" , retcode);
@@ -2857,6 +2907,14 @@ PYBIND11_MODULE(ddbc_bindings, m) {
SQLUSMALLINT reserved) {
return SQLStatistics_wrap(StatementHandle, catalog, schema, table, unique, reserved);
});
m.def("DDBCSQLColumns", [](SqlHandlePtr StatementHandle,
const py::object& catalog,
const py::object& schema,
const py::object& table,
const py::object& column) {
return SQLColumns_wrap(StatementHandle, catalog, schema, table, column);
});


// Add a version attribute
m.attr("__version__") = "1.0.0";
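A side note on the argument handling above: SQLColumns_wrap collapses both None and "" to a NULL pointer with length 0, so either value means "no filter" at the ODBC level, and the Unix branch re-encodes std::wstring into a 2-byte SQLWCHAR buffer because wchar_t is typically 4 bytes there. A self-contained Python sketch of that normalization (illustration only, not driver code):

SQL_NTS = -3  # ODBC length code meaning "null-terminated string"

def normalize_columns_args(catalog=None, schema=None, table=None, column=None):
    """Mirror the (buffer, length) pairs SQLColumns_wrap hands to the driver."""
    def norm(value):
        text = "" if value is None else value
        # None and "" both become (NULL, 0); anything else is passed with SQL_NTS.
        return (None, 0) if text == "" else (text, SQL_NTS)
    return tuple(norm(v) for v in (catalog, schema, table, column))

print(normalize_columns_args(table="Customers"))
# ((None, 0), (None, 0), ('Customers', -3), (None, 0))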
4 changes: 4 additions & 0 deletions mssql_python/pybind/ddbc_bindings.h
@@ -119,6 +119,9 @@ typedef SQLRETURN (SQL_API* SQLSpecialColumnsFunc)(SQLHSTMT, SQLUSMALLINT, SQLWC
typedef SQLRETURN (SQL_API* SQLStatisticsFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*,
SQLSMALLINT, SQLWCHAR*, SQLSMALLINT,
SQLUSMALLINT, SQLUSMALLINT);
typedef SQLRETURN (SQL_API* SQLColumnsFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*,
SQLSMALLINT, SQLWCHAR*, SQLSMALLINT,
SQLWCHAR*, SQLSMALLINT);

// Transaction APIs
typedef SQLRETURN (SQL_API* SQLEndTranFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT);
@@ -168,6 +171,7 @@ extern SQLForeignKeysFunc SQLForeignKeys_ptr;
extern SQLPrimaryKeysFunc SQLPrimaryKeys_ptr;
extern SQLSpecialColumnsFunc SQLSpecialColumns_ptr;
extern SQLStatisticsFunc SQLStatistics_ptr;
extern SQLColumnsFunc SQLColumns_ptr;

// Transaction APIs
extern SQLEndTranFunc SQLEndTran_ptr;