diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py
index e7eb906b..07c31a55 100644
--- a/mssql_python/cursor.py
+++ b/mssql_python/cursor.py
@@ -1387,7 +1387,160 @@ def fetchall_with_mapping():
         self.fetchmany = fetchmany_with_mapping
         self.fetchall = fetchall_with_mapping
-        # Return the cursor itself
+        return result_rows
+
+    def columns(self, table=None, catalog=None, schema=None, column=None):
+        """
+        Creates a result set of column information in the specified tables
+        using the SQLColumns function.
+
+        Args:
+            table (str, optional): The table name pattern. Default is None (all tables).
+            catalog (str, optional): The catalog name. Default is None (current catalog).
+            schema (str, optional): The schema name pattern. Default is None (all schemas).
+            column (str, optional): The column name pattern. Default is None (all columns).
+
+        Returns:
+            cursor: The cursor itself, containing the result set. Use fetchone(),
+            fetchmany(), or fetchall() to retrieve the results.
+
+            Each row contains the following columns:
+            - table_cat (str): Catalog name
+            - table_schem (str): Schema name
+            - table_name (str): Table name
+            - column_name (str): Column name
+            - data_type (int): The ODBC SQL data type constant (e.g. SQL_CHAR)
+            - type_name (str): Data-source-dependent type name
+            - column_size (int): Column size
+            - buffer_length (int): Length of the column in bytes
+            - decimal_digits (int): Number of fractional digits
+            - num_prec_radix (int): Radix (typically 10 or 2)
+            - nullable (int): One of SQL_NO_NULLS, SQL_NULLABLE, SQL_NULLABLE_UNKNOWN
+            - remarks (str): Comments about the column
+            - column_def (str): Default value for the column
+            - sql_data_type (int): The ODBC SQL_DATA_TYPE value (matches data_type
+              except for datetime and interval types)
+            - sql_datetime_sub (int): Subtype code for datetime and interval types
+            - char_octet_length (int): Maximum length in bytes for character types
+            - ordinal_position (int): Column position in the table (starting at 1)
+            - is_nullable (str): "YES", "NO", or "" (unknown)
+
+        Warning:
+            Calling this method without any filters (all parameters None) enumerates
+            EVERY column of EVERY table in the database. In large databases this can
+            be extremely expensive, causing high memory usage, slow execution, and
+            in extreme cases timeout errors. Always filter by catalog, schema, table,
+            or column whenever possible to limit the result set.
+
+        Example:
+            # Get all columns in table 'Customers'
+            columns = cursor.columns(table='Customers')
+
+            # Get all columns in table 'Customers' in schema 'dbo'
+            columns = cursor.columns(table='Customers', schema='dbo')
+
+            # Get columns named 'CustomerID' in any table
+            columns = cursor.columns(column='CustomerID')
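+
+            # A minimal end-to-end sketch (table/schema names are examples):
+            for row in cursor.columns(table='Customers', schema='dbo').fetchall():
+                print(row.column_name, row.type_name, row.is_nullable)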
+        """
+        self._check_closed()
+
+        # Always reset the cursor first to ensure clean state
+        self._reset_cursor()
+
+        # Call the SQLColumns function
+        retcode = ddbc_bindings.DDBCSQLColumns(
+            self.hstmt,
+            catalog,
+            schema,
+            table,
+            column
+        )
+        check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode)
+
+        # Initialize description from column metadata
+        column_metadata = []
+        try:
+            ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata)
+            self._initialize_description(column_metadata)
+        except InterfaceError as e:
+            log('error', f"Driver interface error during metadata retrieval: {e}")
+        except Exception as e:
+            # Log the exception with appropriate context
+            log('error', f"Failed to retrieve column metadata: {e}. Using standard ODBC column definitions instead.")
+
+        if not self.description:
+            # If describe fails, create a manual description for the standard columns
+            column_types = [str, str, str, str, int, str, int, int, int, int,
+                            int, str, str, int, int, int, int, str]
+            self.description = [
+                ("table_cat", column_types[0], None, 128, 128, 0, True),
+                ("table_schem", column_types[1], None, 128, 128, 0, True),
+                ("table_name", column_types[2], None, 128, 128, 0, False),
+                ("column_name", column_types[3], None, 128, 128, 0, False),
+                ("data_type", column_types[4], None, 10, 10, 0, False),
+                ("type_name", column_types[5], None, 128, 128, 0, False),
+                ("column_size", column_types[6], None, 10, 10, 0, True),
+                ("buffer_length", column_types[7], None, 10, 10, 0, True),
+                ("decimal_digits", column_types[8], None, 10, 10, 0, True),
+                ("num_prec_radix", column_types[9], None, 10, 10, 0, True),
+                ("nullable", column_types[10], None, 10, 10, 0, False),
+                ("remarks", column_types[11], None, 254, 254, 0, True),
+                ("column_def", column_types[12], None, 254, 254, 0, True),
+                ("sql_data_type", column_types[13], None, 10, 10, 0, False),
+                ("sql_datetime_sub", column_types[14], None, 10, 10, 0, True),
+                ("char_octet_length", column_types[15], None, 10, 10, 0, True),
+                ("ordinal_position", column_types[16], None, 10, 10, 0, False),
+                ("is_nullable", column_types[17], None, 254, 254, 0, True)
+            ]
+
+        # Store the column mappings for this specific columns() call
+        column_names = [desc[0] for desc in self.description]
+
+        # Create a specialized column map for this result set
+        columns_map = {}
+        for i, name in enumerate(column_names):
+            columns_map[name] = i
+            columns_map[name.lower()] = i
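+        # With the standard result shape above, the map starts as
+        # {"table_cat": 0, "table_schem": 1, "table_name": 2, ...}; the
+        # lowercased duplicates keep attribute access working even when a
+        # driver reports metadata column names in a different case.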
+
+        # Define wrapped fetch methods that preserve existing column mapping
+        # but add our specialized mapping just for column results
+        def fetchone_with_columns_mapping():
+            row = self._original_fetchone()
+            if row is not None:
+                # Create a merged map with columns result taking precedence
+                merged_map = getattr(row, '_column_map', {}).copy()
+                merged_map.update(columns_map)
+                row._column_map = merged_map
+            return row
+
+        def fetchmany_with_columns_mapping(size=None):
+            rows = self._original_fetchmany(size)
+            for row in rows:
+                # Create a merged map with columns result taking precedence
+                merged_map = getattr(row, '_column_map', {}).copy()
+                merged_map.update(columns_map)
+                row._column_map = merged_map
+            return rows
+
+        def fetchall_with_columns_mapping():
+            rows = self._original_fetchall()
+            for row in rows:
+                # Create a merged map with columns result taking precedence
+                merged_map = getattr(row, '_column_map', {}).copy()
+                merged_map.update(columns_map)
+                row._column_map = merged_map
+            return rows
+
+        # Save original fetch methods once
+        if not hasattr(self, '_original_fetchone'):
+            self._original_fetchone = self.fetchone
+            self._original_fetchmany = self.fetchmany
+            self._original_fetchall = self.fetchall
+
+        # Override fetch methods with our wrapped versions
+        self.fetchone = fetchone_with_columns_mapping
+        self.fetchmany = fetchmany_with_columns_mapping
+        self.fetchall = fetchall_with_columns_mapping
+
+        return self
 
     @staticmethod
diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp
index dcc916f3..4c5ed8a0 100644
--- a/mssql_python/pybind/ddbc_bindings.cpp
+++ b/mssql_python/pybind/ddbc_bindings.cpp
@@ -129,6 +129,7 @@
 SQLForeignKeysFunc SQLForeignKeys_ptr = nullptr;
 SQLPrimaryKeysFunc SQLPrimaryKeys_ptr = nullptr;
 SQLSpecialColumnsFunc SQLSpecialColumns_ptr = nullptr;
 SQLStatisticsFunc SQLStatistics_ptr = nullptr;
+SQLColumnsFunc SQLColumns_ptr = nullptr;
 
 // Transaction APIs
 SQLEndTranFunc SQLEndTran_ptr = nullptr;
@@ -791,6 +792,7 @@ DriverHandle LoadDriverOrThrowException() {
     SQLPrimaryKeys_ptr = GetFunctionPointer(handle, "SQLPrimaryKeysW");
     SQLSpecialColumns_ptr = GetFunctionPointer(handle, "SQLSpecialColumnsW");
     SQLStatistics_ptr = GetFunctionPointer(handle, "SQLStatisticsW");
+    SQLColumns_ptr = GetFunctionPointer(handle, "SQLColumnsW");
 
     SQLEndTran_ptr = GetFunctionPointer(handle, "SQLEndTran");
     SQLDisconnect_ptr = GetFunctionPointer(handle, "SQLDisconnect");
@@ -810,7 +812,8 @@ DriverHandle LoadDriverOrThrowException() {
         SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr && SQLFreeStmt_ptr &&
         SQLGetDiagRec_ptr && SQLGetTypeInfo_ptr && SQLProcedures_ptr && SQLForeignKeys_ptr &&
-        SQLPrimaryKeys_ptr && SQLSpecialColumns_ptr && SQLStatistics_ptr;
+        SQLPrimaryKeys_ptr && SQLSpecialColumns_ptr && SQLStatistics_ptr &&
+        SQLColumns_ptr;
 
     if (!success) {
         ThrowStdException("Failed to load required function pointers from driver.");
@@ -1051,6 +1054,53 @@ SQLRETURN SQLStatistics_wrap(SqlHandlePtr StatementHandle,
 #endif
 }
 
+SQLRETURN SQLColumns_wrap(SqlHandlePtr StatementHandle,
+                          const py::object& catalogObj,
+                          const py::object& schemaObj,
+                          const py::object& tableObj,
+                          const py::object& columnObj) {
+    if (!SQLColumns_ptr) {
+        ThrowStdException("SQLColumns function not loaded");
+    }
+
+    // Convert py::object to std::wstring, treating None as empty string
+    std::wstring catalogStr = catalogObj.is_none() ? L"" : catalogObj.cast<std::wstring>();
+    std::wstring schemaStr = schemaObj.is_none() ? L"" : schemaObj.cast<std::wstring>();
+    std::wstring tableStr = tableObj.is_none() ? L"" : tableObj.cast<std::wstring>();
+    std::wstring columnStr = columnObj.is_none() ? L"" : columnObj.cast<std::wstring>();
+
+#if defined(__APPLE__) || defined(__linux__)
+    // Unix implementation
+    std::vector<SQLWCHAR> catalogBuf = WStringToSQLWCHAR(catalogStr);
+    std::vector<SQLWCHAR> schemaBuf = WStringToSQLWCHAR(schemaStr);
+    std::vector<SQLWCHAR> tableBuf = WStringToSQLWCHAR(tableStr);
+    std::vector<SQLWCHAR> columnBuf = WStringToSQLWCHAR(columnStr);
+
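+    // Per ODBC, a null pointer for one of these arguments means "no filter";
+    // this is why an unfiltered cursor.columns() call enumerates every
+    // visible column (see the warning on the Python side).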
+    return SQLColumns_ptr(
+        StatementHandle->get(),
+        catalogStr.empty() ? nullptr : catalogBuf.data(),
+        catalogStr.empty() ? 0 : SQL_NTS,
+        schemaStr.empty() ? nullptr : schemaBuf.data(),
+        schemaStr.empty() ? 0 : SQL_NTS,
+        tableStr.empty() ? nullptr : tableBuf.data(),
+        tableStr.empty() ? 0 : SQL_NTS,
+        columnStr.empty() ? nullptr : columnBuf.data(),
+        columnStr.empty() ? 0 : SQL_NTS);
+#else
+    // Windows implementation
+    return SQLColumns_ptr(
+        StatementHandle->get(),
+        catalogStr.empty() ? nullptr : (SQLWCHAR*)catalogStr.c_str(),
+        catalogStr.empty() ? 0 : SQL_NTS,
+        schemaStr.empty() ? nullptr : (SQLWCHAR*)schemaStr.c_str(),
+        schemaStr.empty() ? 0 : SQL_NTS,
+        tableStr.empty() ? nullptr : (SQLWCHAR*)tableStr.c_str(),
+        tableStr.empty() ? 0 : SQL_NTS,
+        columnStr.empty() ? nullptr : (SQLWCHAR*)columnStr.c_str(),
+        columnStr.empty() ? 0 : SQL_NTS);
+#endif
+}
+
 // Helper function to check for driver errors
 ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRETURN retcode) {
     LOG("Checking errors for retcode - {}", retcode);
@@ -2857,6 +2907,14 @@ PYBIND11_MODULE(ddbc_bindings, m) {
                              SQLUSMALLINT reserved) {
         return SQLStatistics_wrap(StatementHandle, catalog, schema, table, unique, reserved);
     });
+    m.def("DDBCSQLColumns", [](SqlHandlePtr StatementHandle,
+                               const py::object& catalog,
+                               const py::object& schema,
+                               const py::object& table,
+                               const py::object& column) {
+        return SQLColumns_wrap(StatementHandle, catalog, schema, table, column);
+    });
+
     // Add a version attribute
     m.attr("__version__") = "1.0.0";
diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h
index edaeb6b0..d757ad95 100644
--- a/mssql_python/pybind/ddbc_bindings.h
+++ b/mssql_python/pybind/ddbc_bindings.h
@@ -119,6 +119,9 @@ typedef SQLRETURN (SQL_API* SQLSpecialColumnsFunc)(SQLHSTMT, SQLUSMALLINT, SQLWCHAR*,
 typedef SQLRETURN (SQL_API* SQLStatisticsFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*,
                                                SQLSMALLINT, SQLWCHAR*, SQLSMALLINT,
                                                SQLUSMALLINT, SQLUSMALLINT);
+typedef SQLRETURN (SQL_API* SQLColumnsFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*,
+                                            SQLSMALLINT, SQLWCHAR*, SQLSMALLINT,
+                                            SQLWCHAR*, SQLSMALLINT);
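+// Arguments mirror SQLColumnsW: catalog, schema, table, and column are each
+// a (SQLWCHAR* value, SQLSMALLINT length) pair, with SQL_NTS marking a
+// NUL-terminated string.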
 
 // Transaction APIs
 typedef SQLRETURN (SQL_API* SQLEndTranFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT);
@@ -168,6 +171,7 @@
 extern SQLForeignKeysFunc SQLForeignKeys_ptr;
 extern SQLPrimaryKeysFunc SQLPrimaryKeys_ptr;
 extern SQLSpecialColumnsFunc SQLSpecialColumns_ptr;
 extern SQLStatisticsFunc SQLStatistics_ptr;
+extern SQLColumnsFunc SQLColumns_ptr;
 
 // Transaction APIs
 extern SQLEndTranFunc SQLEndTran_ptr;
diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py
index affacbba..d36bebfd 100644
--- a/tests/test_004_cursor.py
+++ b/tests/test_004_cursor.py
@@ -10,6 +10,7 @@
 import pytest
 from datetime import datetime, date, time
+import time as time_module
 import decimal
 from mssql_python import Connection
 import mssql_python
@@ -3267,6 +3268,477 @@ def test_statistics_cleanup(cursor, db_connection):
     except Exception as e:
         pytest.fail(f"Test cleanup failed: {e}")
 
+def test_columns_setup(cursor, db_connection):
+    """Create test tables for columns method testing"""
+    try:
+        # Create a test schema for isolation
+        cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_cols_schema') EXEC('CREATE SCHEMA pytest_cols_schema')")
+
+        # Drop tables if they exist
+        cursor.execute("DROP TABLE IF EXISTS pytest_cols_schema.columns_test")
+        cursor.execute("DROP TABLE IF EXISTS pytest_cols_schema.columns_special_test")
+
+        # Create test table with various column types
+        cursor.execute("""
+            CREATE TABLE pytest_cols_schema.columns_test (
+                id INT PRIMARY KEY,
+                name NVARCHAR(100) NOT NULL,
+                description NVARCHAR(MAX) NULL,
+                price DECIMAL(10, 2) NULL,
+                created_date DATETIME DEFAULT GETDATE(),
+                is_active BIT NOT NULL DEFAULT 1,
+                binary_data VARBINARY(MAX) NULL,
+                notes TEXT NULL,
+                [computed_col] AS (name + ' - ' + CAST(id AS VARCHAR(10)))
+            )
+        """)
+
+        # Create table with special column names and edge cases
+        cursor.execute("""
+            CREATE TABLE pytest_cols_schema.columns_special_test (
+                [ID] INT PRIMARY KEY,
+                [User Name] NVARCHAR(100) NULL,
+                [Spaces Multiple] VARCHAR(50) NULL,
+                [123_numeric_start] INT NULL,
+                [MAX] VARCHAR(20) NULL,      -- SQL keyword as column name
+                [SELECT] INT NULL,           -- SQL keyword as column name
+                [Column.With.Dots] VARCHAR(20) NULL,
+                [Column/With/Slashes] VARCHAR(20) NULL,
+                [Column_With_Underscores] VARCHAR(20) NULL
+            )
+        """)
+
+        db_connection.commit()
+    except Exception as e:
+        pytest.fail(f"Test setup failed: {e}")
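+
+# The two tables above cover a "typical types" case (columns_test) and a
+# "hostile identifiers" case (columns_special_test: spaces, dots, slashes,
+# reserved words); the assertions in the tests below rely on that split.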
+
+def test_columns_all(cursor, db_connection):
+    """Test columns returns information about all columns in all tables"""
+    try:
+        # First set up our test tables
+        test_columns_setup(cursor, db_connection)
+
+        # Get all columns (no filters)
+        cols_cursor = cursor.columns()
+        cols = cols_cursor.fetchall()
+
+        # Verify we got results
+        assert cols is not None, "columns() should return results"
+        assert len(cols) > 0, "columns() should return at least one column"
+
+        # Verify our test tables' columns are in the results.
+        # Use case-insensitive comparison to avoid driver case sensitivity issues.
+        found_test_table = False
+        for col in cols:
+            if (hasattr(col, 'table_name') and
+                    col.table_name and
+                    col.table_name.lower() == 'columns_test' and
+                    hasattr(col, 'table_schem') and
+                    col.table_schem and
+                    col.table_schem.lower() == 'pytest_cols_schema'):
+                found_test_table = True
+                break
+
+        assert found_test_table, "Test table columns should be included in results"
+
+        # Verify structure of results: every documented field must be present
+        first_row = cols[0]
+        expected_fields = [
+            'table_cat', 'table_schem', 'table_name', 'column_name',
+            'data_type', 'type_name', 'column_size', 'buffer_length',
+            'decimal_digits', 'num_prec_radix', 'nullable', 'remarks',
+            'column_def', 'sql_data_type', 'sql_datetime_sub',
+            'char_octet_length', 'ordinal_position', 'is_nullable',
+        ]
+        for field in expected_fields:
+            assert hasattr(first_row, field), f"Result should have {field} column"
+
+    finally:
+        # Clean up happens in test_columns_cleanup
+        pass
+
+def test_columns_specific_table(cursor, db_connection):
+    """Test columns returns information about a specific table"""
+    try:
+        # Get columns for the test table
+        cols = cursor.columns(
+            table='columns_test',
+            schema='pytest_cols_schema'
+        ).fetchall()
+
+        # Verify we got results
+        assert len(cols) == 9, "Should find exactly 9 columns in columns_test"
+
+        # Verify all column names are present (case insensitive)
+        col_names = [col.column_name.lower() for col in cols]
+        expected_names = ['id', 'name', 'description', 'price', 'created_date',
+                          'is_active', 'binary_data', 'notes', 'computed_col']
+        for name in expected_names:
+            assert name in col_names, f"Column {name} should be in results"
+
+        # Verify details of a specific column (id)
+        id_col = next(col for col in cols if col.column_name.lower() == 'id')
+        assert id_col.nullable == 0, "id column should be non-nullable"
+        assert id_col.ordinal_position == 1, "id should be the first column"
+        assert id_col.is_nullable == "NO", "is_nullable should be NO for id column"
+
+        # Check data types by type_name rather than ODBC type code, since the
+        # numeric codes vary by driver
+        id_type = id_col.type_name.lower()
+        assert 'int' in id_type, f"id column should be INTEGER type, got {id_type}"
+
+        # Check a nullable column
+        desc_col = next(col for col in cols if col.column_name.lower() == 'description')
+        assert desc_col.nullable == 1, "description column should be nullable"
+        assert desc_col.is_nullable == "YES", "is_nullable should be YES for description column"
+
+    finally:
+        # Clean up happens in test_columns_cleanup
+        pass
+
+def test_columns_special_chars(cursor, db_connection):
+    """Test columns with special characters and edge cases"""
+    try:
+        # Get columns for the special table
+        cols = cursor.columns(
+            table='columns_special_test',
+            schema='pytest_cols_schema'
+        ).fetchall()
+
+        # Verify we got results
+        assert len(cols) == 9, "Should find exactly 9 columns in columns_special_test"
+
+        # Check that special column names are handled correctly. Names may be
+        # returned with or without brackets/quotes depending on the driver, so
+        # compare case-insensitively.
+        col_names = [col.column_name for col in cols]
+
+        assert any('user name' in name.lower() for name in col_names), "Column with spaces should be in results"
+        assert any('id' == name.lower() for name in col_names), "ID column should be in results"
+        assert any('123_numeric_start' in name.lower() for name in col_names), "Column starting with numbers should be in results"
+        assert any('max' == name.lower() for name in col_names), "MAX column should be in results"
+        assert any('select' == name.lower() for name in col_names), "SELECT column should be in results"
+        assert any('column.with.dots' in name.lower() for name in col_names), "Column with dots should be in results"
+        assert any('column/with/slashes' in name.lower() for name in col_names), "Column with slashes should be in results"
+        assert any('column_with_underscores' in name.lower() for name in col_names), "Column with underscores should be in results"
+
+    finally:
+        # Clean up happens in test_columns_cleanup
+        pass
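+
+# Reminder on the ODBC search patterns used below: '%' matches any sequence
+# of zero or more characters, and '_' matches exactly one character.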
+
+def test_columns_specific_column(cursor, db_connection):
+    """Test columns with specific column filter"""
+    try:
+        # Get specific column
+        cols = cursor.columns(
+            table='columns_test',
+            schema='pytest_cols_schema',
+            column='name'
+        ).fetchall()
+
+        # Verify we got just one result
+        assert len(cols) == 1, "Should find exactly 1 column named 'name'"
+
+        # Verify column details
+        col = cols[0]
+        assert col.column_name.lower() == 'name', "Column name should be 'name'"
+        assert col.table_name.lower() == 'columns_test', "Table name should be 'columns_test'"
+        assert col.table_schem.lower() == 'pytest_cols_schema', "Schema should be 'pytest_cols_schema'"
+        assert col.nullable == 0, "name column should be non-nullable"
+
+        # Get column using pattern (% wildcard)
+        pattern_cols = cursor.columns(
+            table='columns_test',
+            schema='pytest_cols_schema',
+            column='%date%'
+        ).fetchall()
+
+        # Should find the created_date column
+        assert len(pattern_cols) == 1, "Should find 1 column matching '%date%'"
+        assert pattern_cols[0].column_name.lower() == 'created_date', "Should find created_date column"
+
+        # Get multiple columns with pattern
+        multi_cols = cursor.columns(
+            table='columns_test',
+            schema='pytest_cols_schema',
+            column='%d%'  # Matches id, description, created_date, binary_data
+        ).fetchall()
+
+        # At least 3 columns should match this pattern
+        assert len(multi_cols) >= 3, "Should find at least 3 columns matching '%d%'"
+        match_names = [col.column_name.lower() for col in multi_cols]
+        assert 'id' in match_names, "id should match '%d%'"
+        assert 'description' in match_names, "description should match '%d%'"
+        assert 'created_date' in match_names, "created_date should match '%d%'"
+
+    finally:
+        # Clean up happens in test_columns_cleanup
+        pass
+
+def test_columns_with_underscore_pattern(cursor):
+    """Test columns with underscore wildcard pattern"""
+    try:
+        # The one-character wildcard '__' matches names of exactly two
+        # characters, e.g. 'id'
+        cols = cursor.columns(
+            table='columns_test',
+            schema='pytest_cols_schema',
+            column='__'
+        ).fetchall()
+
+        # Should find 'id' column
+        id_found = False
+        for col in cols:
+            if col.column_name.lower() == 'id' and col.table_name.lower() == 'columns_test':
+                id_found = True
+                break
+
+        assert id_found, "Should find 'id' column with pattern '__'"
+
+        # Try a more complex pattern mixing % and _:
+        # '%_d%' matches any name with 'd' as the second or later character
+        pattern_cols = cursor.columns(
+            table='columns_test',
+            schema='pytest_cols_schema',
+            column='%_d%'
+        ).fetchall()
+
+        match_names = [col.column_name.lower() for col in pattern_cols
+                       if col.table_name.lower() == 'columns_test']
+
+        # Both 'id' and 'created_date' match; assert on the unambiguous one
+        assert 'created_date' in match_names, "created_date should match '%_d%'"
+
+    finally:
+        # Clean up happens in test_columns_cleanup
+        pass
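+
+# To match a literal '_' or '%', the pattern needs the driver's escape
+# character (discoverable via SQLGetInfo(SQL_SEARCH_PATTERN_ESCAPE); a
+# backslash for SQL Server). Hypothetical sketch:
+#   cursor.columns(table='columns_test', column='created\\_date')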
+
+def test_columns_data_types(cursor):
+    """Test columns returns correct data type information"""
+    try:
+        # Get all columns from test table
+        cols = cursor.columns(
+            table='columns_test',
+            schema='pytest_cols_schema'
+        ).fetchall()
+
+        # Create a dictionary mapping column names to their details
+        col_dict = {col.column_name.lower(): col for col in cols}
+
+        # Check data types by type_name string (case insensitive) to avoid
+        # SQL type code inconsistencies between drivers
+
+        # INT column
+        assert 'int' in col_dict['id'].type_name.lower(), "id should be INT type"
+
+        # NVARCHAR column
+        assert any(name in col_dict['name'].type_name.lower()
+                   for name in ['nvarchar', 'varchar', 'char', 'wchar']), "name should be NVARCHAR type"
+
+        # DECIMAL column
+        assert any(name in col_dict['price'].type_name.lower()
+                   for name in ['decimal', 'numeric', 'money']), "price should be DECIMAL type"
+
+        # BIT column
+        assert any(name in col_dict['is_active'].type_name.lower()
+                   for name in ['bit', 'boolean']), "is_active should be BIT type"
+
+        # TEXT column
+        assert any(name in col_dict['notes'].type_name.lower()
+                   for name in ['text', 'char', 'varchar']), "notes should be TEXT type"
+
+        # Check nullable flag
+        assert col_dict['id'].nullable == 0, "id should be non-nullable"
+        assert col_dict['description'].nullable == 1, "description should be nullable"
+
+        # Check column size
+        assert col_dict['name'].column_size == 100, "name should have size 100"
+
+        # Check decimal digits for numeric type
+        assert col_dict['price'].decimal_digits == 2, "price should have 2 decimal digits"
+
+    finally:
+        # Clean up happens in test_columns_cleanup
+        pass
+
+def test_columns_nonexistent(cursor):
+    """Test columns with non-existent table, column, or schema"""
+    # Test with non-existent table
+    table_cols = cursor.columns(table='nonexistent_table_xyz123').fetchall()
+    assert len(table_cols) == 0, "Should return empty list for non-existent table"
+
+    # Test with non-existent column in existing table
+    col_cols = cursor.columns(
+        table='columns_test',
+        schema='pytest_cols_schema',
+        column='nonexistent_column_xyz123'
+    ).fetchall()
+    assert len(col_cols) == 0, "Should return empty list for non-existent column"
+
+    # Test with non-existent schema
+    schema_cols = cursor.columns(
+        table='columns_test',
+        schema='nonexistent_schema_xyz123'
+    ).fetchall()
+    assert len(schema_cols) == 0, "Should return empty list for non-existent schema"
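+
+# Note: a name that matches nothing is not an error for SQLColumns; the
+# driver simply returns an empty result set, which is exactly what the
+# assertions above rely on.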
+
+def test_columns_catalog_filter(cursor):
+    """Test columns with catalog filter"""
+    try:
+        # Get current database name
+        cursor.execute("SELECT DB_NAME() AS current_db")
+        current_db = cursor.fetchone().current_db
+
+        # Get columns with current catalog
+        cols = cursor.columns(
+            table='columns_test',
+            catalog=current_db,
+            schema='pytest_cols_schema'
+        ).fetchall()
+
+        # Verify catalog filter worked
+        assert len(cols) > 0, "Should find columns with correct catalog"
+
+        # Check catalog in results (some drivers might return None)
+        for col in cols:
+            if col.table_cat is not None:
+                assert col.table_cat.lower() == current_db.lower(), "Wrong table catalog"
+
+        # Test with non-existent catalog
+        fake_cols = cursor.columns(
+            table='columns_test',
+            catalog='nonexistent_db_xyz123',
+            schema='pytest_cols_schema'
+        ).fetchall()
+        assert len(fake_cols) == 0, "Should return empty list for non-existent catalog"
+
+    finally:
+        # Clean up happens in test_columns_cleanup
+        pass
+
+def test_columns_schema_pattern(cursor):
+    """Test columns with schema name pattern"""
+    try:
+        # Get columns with schema pattern
+        cols = cursor.columns(
+            table='columns_test',
+            schema='pytest_%'
+        ).fetchall()
+
+        # Should find our test table columns
+        test_cols = [col for col in cols if col.table_name.lower() == 'columns_test']
+        assert len(test_cols) > 0, "Should find columns using schema pattern"
+
+        # Try a more specific pattern
+        specific_cols = cursor.columns(
+            table='columns_test',
+            schema='pytest_cols%'
+        ).fetchall()
+
+        # Should still find our test table columns
+        test_cols = [col for col in specific_cols if col.table_name.lower() == 'columns_test']
+        assert len(test_cols) > 0, "Should find columns using specific schema pattern"
+
+    finally:
+        # Clean up happens in test_columns_cleanup
+        pass
+
+def test_columns_table_pattern(cursor):
+    """Test columns with table name pattern"""
+    try:
+        # Get columns with table pattern
+        cols = cursor.columns(
+            table='columns_%',
+            schema='pytest_cols_schema'
+        ).fetchall()
+
+        # Should find columns from both test tables
+        tables_found = set()
+        for col in cols:
+            if col.table_name:
+                tables_found.add(col.table_name.lower())
+
+        assert 'columns_test' in tables_found, "Should find columns_test with pattern columns_%"
+        assert 'columns_special_test' in tables_found, "Should find columns_special_test with pattern columns_%"
+
+    finally:
+        # Clean up happens in test_columns_cleanup
+        pass
+
+def test_columns_ordinal_position(cursor):
+    """Test ordinal_position is correct in columns results"""
+    try:
+        # Get columns for the test table
+        cols = cursor.columns(
+            table='columns_test',
+            schema='pytest_cols_schema'
+        ).fetchall()
+
+        # Sort by ordinal position
+        sorted_cols = sorted(cols, key=lambda col: col.ordinal_position)
+
+        # Verify positions are consecutive starting from 1
+        for i, col in enumerate(sorted_cols, 1):
+            assert col.ordinal_position == i, f"Column {col.column_name} should have ordinal_position {i}"
+
+        # First column should be id (primary key)
+        assert sorted_cols[0].column_name.lower() == 'id', "First column should be id"
+
+    finally:
+        # Clean up happens in test_columns_cleanup
+        pass
+
+def test_columns_cleanup(cursor, db_connection):
+    """Clean up test tables after testing"""
+    try:
+        # Drop all test tables
+        cursor.execute("DROP TABLE IF EXISTS pytest_cols_schema.columns_test")
+        cursor.execute("DROP TABLE IF EXISTS pytest_cols_schema.columns_special_test")
+
+        # Drop the test schema
+        cursor.execute("DROP SCHEMA IF EXISTS pytest_cols_schema")
+        db_connection.commit()
+    except Exception as e:
+        pytest.fail(f"Test cleanup failed: {e}")
+
 def test_close(db_connection):
     """Test closing the cursor"""
     try: