diff --git a/mssql_python/constants.py b/mssql_python/constants.py index a4e0c707..e63fbd1b 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -117,6 +117,15 @@ class ConstantsDDBC(Enum): SQL_NULLABLE = 1 SQL_MAX_NUMERIC_LEN = 16 SQL_ATTR_QUERY_TIMEOUT = 0 + SQL_SCOPE_CURROW = 0 + SQL_BEST_ROWID = 1 + SQL_ROWVER = 2 + SQL_NO_NULLS = 0 + SQL_NULLABLE_UNKNOWN = 2 + SQL_INDEX_UNIQUE = 0 + SQL_INDEX_ALL = 1 + SQL_QUICK = 0 + SQL_ENSURE = 1 class AuthType(Enum): """Constants for authentication types""" diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 6bea2abf..a99dee65 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -16,7 +16,7 @@ from mssql_python.constants import ConstantsDDBC as ddbc_sql_const from mssql_python.helpers import check_error, log from mssql_python import ddbc_bindings -from mssql_python.exceptions import InterfaceError +from mssql_python.exceptions import InterfaceError, ProgrammingError from mssql_python.row import Row from mssql_python import get_settings @@ -46,6 +46,16 @@ class Cursor: setoutputsize(size, column=None) -> None. """ + # TODO(jathakkar): Thread safety considerations + # The cursor class contains methods that are not thread-safe due to: + # 1. Methods that mutate cursor state (_reset_cursor, self.description, etc.) + # 2. Methods that call ODBC functions with shared handles (self.hstmt) + # + # These methods should be properly synchronized or redesigned when implementing + # async functionality to prevent race conditions and data corruption. + # Consider using locks, redesigning for immutability, or ensuring + # cursor objects are never shared across threads. + def __init__(self, connection, timeout: int = 0) -> None: """ Initialize the cursor with a database connection. @@ -648,7 +658,7 @@ def execute( del self._original_fetchone del self._original_fetchmany del self._original_fetchall - + self._check_closed() # Check if the cursor is closed if reset_cursor: self._reset_cursor() @@ -904,6 +914,751 @@ def fetchall_with_mapping(): # Return the cursor itself return self + + def primaryKeys(self, table, catalog=None, schema=None): + """ + Creates a result set of column names that make up the primary key for a table + by executing the SQLPrimaryKeys function. + + Args: + table (str): The name of the table + catalog (str, optional): The catalog name (database). Defaults to None. + schema (str, optional): The schema name. Defaults to None. 
+ + Returns: + list: A list of rows with the following columns: + - table_cat: Catalog name + - table_schem: Schema name + - table_name: Table name + - column_name: Column name that is part of the primary key + - key_seq: Column sequence number in the primary key (starting with 1) + - pk_name: Primary key name + + Raises: + ProgrammingError: If the cursor is closed + """ + self._check_closed() + + # Always reset the cursor first to ensure clean state + self._reset_cursor() + + if not table: + raise ProgrammingError("Table name must be specified", "HY000") + + # Call the SQLPrimaryKeys function + retcode = ddbc_bindings.DDBCSQLPrimaryKeys( + self.hstmt, + catalog, + schema, + table + ) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode) + + # Initialize description from column metadata + column_metadata = [] + try: + ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) + self._initialize_description(column_metadata) + except InterfaceError as e: + log('error', f"Driver interface error during metadata retrieval: {e}") + + except Exception as e: + # Log the exception with appropriate context + log('error', f"Failed to retrieve column metadata: {e}. Using standard ODBC column definitions instead.") + + if not self.description: + # If describe fails, create a manual description for the standard columns + column_types = [str, str, str, str, int, str] + self.description = [ + ("table_cat", column_types[0], None, 128, 128, 0, True), + ("table_schem", column_types[1], None, 128, 128, 0, True), + ("table_name", column_types[2], None, 128, 128, 0, False), + ("column_name", column_types[3], None, 128, 128, 0, False), + ("key_seq", column_types[4], None, 10, 10, 0, False), + ("pk_name", column_types[5], None, 128, 128, 0, True) + ] + + # Define column names in ODBC standard order + self._column_map = {} + for i, (name, *_) in enumerate(self.description): + # Add standard name + self._column_map[name] = i + # Add lowercase alias + self._column_map[name.lower()] = i + + # Remember original fetch methods (store only once) + if not hasattr(self, '_original_fetchone'): + self._original_fetchone = self.fetchone + self._original_fetchmany = self.fetchmany + self._original_fetchall = self.fetchall + + # Create wrapper fetch methods that add column mappings + def fetchone_with_mapping(): + row = self._original_fetchone() + if row is not None: + row._column_map = self._column_map + return row + + def fetchmany_with_mapping(size=None): + rows = self._original_fetchmany(size) + for row in rows: + row._column_map = self._column_map + return rows + + def fetchall_with_mapping(): + rows = self._original_fetchall() + for row in rows: + row._column_map = self._column_map + return rows + + # Replace fetch methods + self.fetchone = fetchone_with_mapping + self.fetchmany = fetchmany_with_mapping + self.fetchall = fetchall_with_mapping + + # Return the cursor itself + return self + + def foreignKeys(self, table=None, catalog=None, schema=None, foreignTable=None, foreignCatalog=None, foreignSchema=None): + """ + Executes the SQLForeignKeys function and creates a result set of column names that are foreign keys. + + This function returns: + 1. Foreign keys in the specified table that reference primary keys in other tables, OR + 2. 
Foreign keys in other tables that reference the primary key in the specified table + + Args: + table (str, optional): The table containing the foreign key columns + catalog (str, optional): The catalog containing table + schema (str, optional): The schema containing table + foreignTable (str, optional): The table containing the primary key columns + foreignCatalog (str, optional): The catalog containing foreignTable + foreignSchema (str, optional): The schema containing foreignTable + + Returns: + List of Row objects, each containing foreign key information with these columns: + - pktable_cat (str): Primary key table catalog name + - pktable_schem (str): Primary key table schema name + - pktable_name (str): Primary key table name + - pkcolumn_name (str): Primary key column name + - fktable_cat (str): Foreign key table catalog name + - fktable_schem (str): Foreign key table schema name + - fktable_name (str): Foreign key table name + - fkcolumn_name (str): Foreign key column name + - key_seq (int): Sequence number of the column in the foreign key + - update_rule (int): Action for update (CASCADE, SET NULL, etc.) + - delete_rule (int): Action for delete (CASCADE, SET NULL, etc.) + - fk_name (str): Foreign key name + - pk_name (str): Primary key name + - deferrability (int): Indicates if constraint checking can be deferred + """ + self._check_closed() + + # Always reset the cursor first to ensure clean state + self._reset_cursor() + + # Check if we have at least one table specified - mimic pyodbc behavior + if table is None and foreignTable is None: + raise ProgrammingError("Either table or foreignTable must be specified", "HY000") + + # Call the SQLForeignKeys function + retcode = ddbc_bindings.DDBCSQLForeignKeys( + self.hstmt, + foreignCatalog, foreignSchema, foreignTable, + catalog, schema, table + ) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode) + + # Initialize description from column metadata + column_metadata = [] + try: + ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) + self._initialize_description(column_metadata) + + except InterfaceError as e: + log('error', f"Driver interface error during metadata retrieval: {e}") + + except Exception as e: + # Log the exception with appropriate context + log('error', f"Failed to retrieve column metadata: {e}. 
Using standard ODBC column definitions instead.") + + if not self.description: + # If describe fails, create a manual description for the standard columns + column_types = [str, str, str, str, str, str, str, str, int, int, int, str, str, int] + self.description = [ + ("pktable_cat", column_types[0], None, 128, 128, 0, True), + ("pktable_schem", column_types[1], None, 128, 128, 0, True), + ("pktable_name", column_types[2], None, 128, 128, 0, False), + ("pkcolumn_name", column_types[3], None, 128, 128, 0, False), + ("fktable_cat", column_types[4], None, 128, 128, 0, True), + ("fktable_schem", column_types[5], None, 128, 128, 0, True), + ("fktable_name", column_types[6], None, 128, 128, 0, False), + ("fkcolumn_name", column_types[7], None, 128, 128, 0, False), + ("key_seq", column_types[8], None, 10, 10, 0, False), + ("update_rule", column_types[9], None, 10, 10, 0, False), + ("delete_rule", column_types[10], None, 10, 10, 0, False), + ("fk_name", column_types[11], None, 128, 128, 0, True), + ("pk_name", column_types[12], None, 128, 128, 0, True), + ("deferrability", column_types[13], None, 10, 10, 0, False) + ] + + # Define column names in ODBC standard order + self._column_map = {} + for i, (name, *_) in enumerate(self.description): + # Add standard name + self._column_map[name] = i + # Add lowercase alias + self._column_map[name.lower()] = i + + # Remember original fetch methods (store only once) + if not hasattr(self, '_original_fetchone'): + self._original_fetchone = self.fetchone + self._original_fetchmany = self.fetchmany + self._original_fetchall = self.fetchall + + # Create wrapper fetch methods that add column mappings + def fetchone_with_mapping(): + row = self._original_fetchone() + if row is not None: + row._column_map = self._column_map + return row + + def fetchmany_with_mapping(size=None): + rows = self._original_fetchmany(size) + for row in rows: + row._column_map = self._column_map + return rows + + def fetchall_with_mapping(): + rows = self._original_fetchall() + for row in rows: + row._column_map = self._column_map + return rows + + # Replace fetch methods + self.fetchone = fetchone_with_mapping + self.fetchmany = fetchmany_with_mapping + self.fetchall = fetchall_with_mapping + + # Return the cursor itself + return self + + def rowIdColumns(self, table, catalog=None, schema=None, nullable=True): + """ + Executes SQLSpecialColumns with SQL_BEST_ROWID which creates a result set of + columns that uniquely identify a row. + + Args: + table (str): The table name + catalog (str, optional): The catalog name (database). Defaults to None. + schema (str, optional): The schema name. Defaults to None. + nullable (bool, optional): Whether to include nullable columns. Defaults to True. + + Returns: + list: A list of rows with the following columns: + - scope: One of SQL_SCOPE_CURROW, SQL_SCOPE_TRANSACTION, or SQL_SCOPE_SESSION + - column_name: Column name + - data_type: The ODBC SQL data type constant (e.g. 
SQL_CHAR) + - type_name: Type name + - column_size: Column size + - buffer_length: Buffer length + - decimal_digits: Decimal digits + - pseudo_column: One of SQL_PC_UNKNOWN, SQL_PC_NOT_PSEUDO, SQL_PC_PSEUDO + """ + self._check_closed() + + # Always reset the cursor first to ensure clean state + self._reset_cursor() + + # Convert None values to empty strings as required by ODBC API + if not table: + raise ProgrammingError("Table name must be specified", "HY000") + + # Set the identifier type to SQL_BEST_ROWID (1) + identifier_type = ddbc_sql_const.SQL_BEST_ROWID.value + + # Set scope to SQL_SCOPE_CURROW (0) - default scope + scope = ddbc_sql_const.SQL_SCOPE_CURROW.value + + # Set nullable flag + nullable_flag = ddbc_sql_const.SQL_NULLABLE.value if nullable else ddbc_sql_const.SQL_NO_NULLS.value + + # Call the SQLSpecialColumns function + retcode = ddbc_bindings.DDBCSQLSpecialColumns( + self.hstmt, + identifier_type, + catalog, + schema, + table, + scope, + nullable_flag + ) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode) + + # Initialize description from column metadata + column_metadata = [] + try: + ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) + self._initialize_description(column_metadata) + + except InterfaceError as e: + log('error', f"Driver interface error during metadata retrieval: {e}") + + except Exception as e: + # Log the exception with appropriate context + log('error', f"Failed to retrieve column metadata: {e}. Using standard ODBC column definitions instead.") + + if not self.description: + # If describe fails, create a manual description for the standard columns + column_types = [int, str, int, str, int, int, int, int] + self.description = [ + ("scope", column_types[0], None, 10, 10, 0, False), + ("column_name", column_types[1], None, 128, 128, 0, False), + ("data_type", column_types[2], None, 10, 10, 0, False), + ("type_name", column_types[3], None, 128, 128, 0, False), + ("column_size", column_types[4], None, 10, 10, 0, False), + ("buffer_length", column_types[5], None, 10, 10, 0, False), + ("decimal_digits", column_types[6], None, 10, 10, 0, True), + ("pseudo_column", column_types[7], None, 10, 10, 0, False) + ] + + # Create a column map with both ODBC standard names and lowercase aliases + self._column_map = {} + for i, (name, *_) in enumerate(self.description): + # Add standard name + self._column_map[name] = i + # Add lowercase alias + self._column_map[name.lower()] = i + + # Remember original fetch methods (store only once) + if not hasattr(self, '_original_fetchone'): + self._original_fetchone = self.fetchone + self._original_fetchmany = self.fetchmany + self._original_fetchall = self.fetchall + + # Create wrapper fetch methods that add column mappings + def fetchone_with_mapping(): + row = self._original_fetchone() + if row is not None: + row._column_map = self._column_map + return row + + def fetchmany_with_mapping(size=None): + rows = self._original_fetchmany(size) + for row in rows: + row._column_map = self._column_map + return rows + + def fetchall_with_mapping(): + rows = self._original_fetchall() + for row in rows: + row._column_map = self._column_map + return rows + + # Replace fetch methods + self.fetchone = fetchone_with_mapping + self.fetchmany = fetchmany_with_mapping + self.fetchall = fetchall_with_mapping + + # Return the cursor itself + return self + + def rowVerColumns(self, table, catalog=None, schema=None, nullable=True): + """ + Executes SQLSpecialColumns with SQL_ROWVER which creates a result set of + 
columns that are automatically updated when any value in the row is updated. + + Args: + table (str): The table name + catalog (str, optional): The catalog name (database). Defaults to None. + schema (str, optional): The schema name. Defaults to None. + nullable (bool, optional): Whether to include nullable columns. Defaults to True. + + Returns: + list: A list of rows with the following columns: + - scope: One of SQL_SCOPE_CURROW, SQL_SCOPE_TRANSACTION, or SQL_SCOPE_SESSION + - column_name: Column name + - data_type: The ODBC SQL data type constant (e.g. SQL_CHAR) + - type_name: Type name + - column_size: Column size + - buffer_length: Buffer length + - decimal_digits: Decimal digits + - pseudo_column: One of SQL_PC_UNKNOWN, SQL_PC_NOT_PSEUDO, SQL_PC_PSEUDO + """ + self._check_closed() + + # Always reset the cursor first to ensure clean state + self._reset_cursor() + + if not table: + raise ProgrammingError("Table name must be specified", "HY000") + + # Set the identifier type to SQL_ROWVER (2) + identifier_type = ddbc_sql_const.SQL_ROWVER.value + + # Set scope to SQL_SCOPE_CURROW (0) - default scope + scope = ddbc_sql_const.SQL_SCOPE_CURROW.value + + # Set nullable flag + nullable_flag = ddbc_sql_const.SQL_NULLABLE.value if nullable else ddbc_sql_const.SQL_NO_NULLS.value + + # Call the SQLSpecialColumns function + retcode = ddbc_bindings.DDBCSQLSpecialColumns( + self.hstmt, + identifier_type, + catalog, + schema, + table, + scope, + nullable_flag + ) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode) + + # Initialize description from column metadata + column_metadata = [] + try: + ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) + self._initialize_description(column_metadata) + + except InterfaceError as e: + log('error', f"Driver interface error during metadata retrieval: {e}") + + except Exception as e: + # Log the exception with appropriate context + log('error', f"Failed to retrieve column metadata: {e}. 
Using standard ODBC column definitions instead.") + + if not self.description: + # If describe fails, create a manual description for the standard columns + column_types = [int, str, int, str, int, int, int, int] + self.description = [ + ("scope", column_types[0], None, 10, 10, 0, False), + ("column_name", column_types[1], None, 128, 128, 0, False), + ("data_type", column_types[2], None, 10, 10, 0, False), + ("type_name", column_types[3], None, 128, 128, 0, False), + ("column_size", column_types[4], None, 10, 10, 0, False), + ("buffer_length", column_types[5], None, 10, 10, 0, False), + ("decimal_digits", column_types[6], None, 10, 10, 0, True), + ("pseudo_column", column_types[7], None, 10, 10, 0, False) + ] + + # Create a column map with both ODBC standard names and lowercase aliases + self._column_map = {} + for i, (name, *_) in enumerate(self.description): + # Add standard name + self._column_map[name] = i + # Add lowercase alias + self._column_map[name.lower()] = i + + # Remember original fetch methods (store only once) + if not hasattr(self, '_original_fetchone'): + self._original_fetchone = self.fetchone + self._original_fetchmany = self.fetchmany + self._original_fetchall = self.fetchall + + # Create wrapper fetch methods that add column mappings + def fetchone_with_mapping(): + row = self._original_fetchone() + if row is not None: + row._column_map = self._column_map + return row + + def fetchmany_with_mapping(size=None): + rows = self._original_fetchmany(size) + for row in rows: + row._column_map = self._column_map + return rows + + def fetchall_with_mapping(): + rows = self._original_fetchall() + for row in rows: + row._column_map = self._column_map + return rows + + # Replace fetch methods + self.fetchone = fetchone_with_mapping + self.fetchmany = fetchmany_with_mapping + self.fetchall = fetchall_with_mapping + + # Return the cursor itself + return self + + def statistics(self, table: str, catalog: str = None, schema: str = None, unique: bool = False, quick: bool = True) -> 'Cursor': + """ + Creates a result set of statistics about a single table and the indexes associated + with the table by executing SQLStatistics. + + Args: + table (str): The name of the table. + catalog (str, optional): The catalog name. Defaults to None. + schema (str, optional): The schema name. Defaults to None. + unique (bool, optional): If True, only unique indexes are returned. + If False, all indexes are returned. Defaults to False. + quick (bool, optional): If True, CARDINALITY and PAGES are returned only + if readily available. Defaults to True. + + Returns: + cursor: The cursor itself, containing the result set. Use fetchone(), fetchmany(), + or fetchall() to retrieve the results. 
+ + Example: + # Get statistics for the 'Customers' table + stats_cursor = cursor.statistics(table='Customers') + + # Fetch rows as needed + first_stat = stats_cursor.fetchone() + next_10_stats = stats_cursor.fetchmany(10) + all_remaining = stats_cursor.fetchall() + """ + self._check_closed() + + # Always reset the cursor first to ensure clean state + self._reset_cursor() + + # Table name is required + if not table: + raise ProgrammingError("Table name is required", "HY000") + + # Set unique flag (SQL_INDEX_UNIQUE = 0, SQL_INDEX_ALL = 1) + unique_option = ddbc_sql_const.SQL_INDEX_UNIQUE.value if unique else ddbc_sql_const.SQL_INDEX_ALL.value + + # Set quick flag (SQL_QUICK = 0, SQL_ENSURE = 1) + reserved_option = ddbc_sql_const.SQL_QUICK.value if quick else ddbc_sql_const.SQL_ENSURE.value + + # Call the SQLStatistics function + retcode = ddbc_bindings.DDBCSQLStatistics( + self.hstmt, + catalog, + schema, + table, + unique_option, + reserved_option + ) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode) + + # Initialize description from column metadata + column_metadata = [] + try: + ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) + self._initialize_description(column_metadata) + except InterfaceError as e: + log('error', f"Driver interface error during metadata retrieval: {e}") + + except Exception as e: + # Log the exception with appropriate context + log('error', f"Failed to retrieve column metadata: {e}. Using standard ODBC column definitions instead.") + + if not self.description: + # If describe fails, create a manual description for the standard columns + column_types = [str, str, str, bool, str, str, int, int, str, str, int, int, str] + self.description = [ + ("table_cat", column_types[0], None, 128, 128, 0, True), + ("table_schem", column_types[1], None, 128, 128, 0, True), + ("table_name", column_types[2], None, 128, 128, 0, False), + ("non_unique", column_types[3], None, 1, 1, 0, False), + ("index_qualifier", column_types[4], None, 128, 128, 0, True), + ("index_name", column_types[5], None, 128, 128, 0, True), + ("type", column_types[6], None, 10, 10, 0, False), + ("ordinal_position", column_types[7], None, 10, 10, 0, False), + ("column_name", column_types[8], None, 128, 128, 0, True), + ("asc_or_desc", column_types[9], None, 1, 1, 0, True), + ("cardinality", column_types[10], None, 20, 20, 0, True), + ("pages", column_types[11], None, 20, 20, 0, True), + ("filter_condition", column_types[12], None, 128, 128, 0, True) + ] + + # Create a column map with both ODBC standard names and lowercase aliases + self._column_map = {} + for i, (name, *_) in enumerate(self.description): + # Add standard name + self._column_map[name] = i + # Add lowercase alias + self._column_map[name.lower()] = i + + # Remember original fetch methods (store only once) + if not hasattr(self, '_original_fetchone'): + self._original_fetchone = self.fetchone + self._original_fetchmany = self.fetchmany + self._original_fetchall = self.fetchall + + # Create wrapper fetch methods that add column mappings + def fetchone_with_mapping(): + row = self._original_fetchone() + if row is not None: + row._column_map = self._column_map + return row + + def fetchmany_with_mapping(size=None): + rows = self._original_fetchmany(size) + for row in rows: + row._column_map = self._column_map + return rows + + def fetchall_with_mapping(): + rows = self._original_fetchall() + for row in rows: + row._column_map = self._column_map + return rows + + # Replace fetch methods + self.fetchone = 
fetchone_with_mapping + self.fetchmany = fetchmany_with_mapping + self.fetchall = fetchall_with_mapping + + return self + + def columns(self, table=None, catalog=None, schema=None, column=None): + """ + Creates a result set of column information in the specified tables + using the SQLColumns function. + + Args: + table (str, optional): The table name pattern. Default is None (all tables). + catalog (str, optional): The catalog name. Default is None (current catalog). + schema (str, optional): The schema name pattern. Default is None (all schemas). + column (str, optional): The column name pattern. Default is None (all columns). + + Returns: + cursor: The cursor itself, containing the result set. Use fetchone(), fetchmany(), + or fetchall() to retrieve the results. + + Each row contains the following columns: + - table_cat (str): Catalog name + - table_schem (str): Schema name + - table_name (str): Table name + - column_name (str): Column name + - data_type (int): The ODBC SQL data type constant (e.g. SQL_CHAR) + - type_name (str): Data source dependent type name + - column_size (int): Column size + - buffer_length (int): Length of the column in bytes + - decimal_digits (int): Number of fractional digits + - num_prec_radix (int): Radix (typically 10 or 2) + - nullable (int): One of SQL_NO_NULLS, SQL_NULLABLE, SQL_NULLABLE_UNKNOWN + - remarks (str): Comments about the column + - column_def (str): Default value for the column + - sql_data_type (int): The SQL data type from java.sql.Types + - sql_datetime_sub (int): Subcode for datetime types + - char_octet_length (int): Maximum length in bytes for char types + - ordinal_position (int): Column position in the table (starting at 1) + - is_nullable (str): "YES", "NO", or "" (unknown) + + Warning: + Calling this method without any filters (all parameters as None) will enumerate + EVERY column in EVERY table in the database. This can be extremely expensive in + large databases, potentially causing high memory usage, slow execution times, + and in extreme cases, timeout errors. Always use filters (catalog, schema, table, + or column) whenever possible to limit the result set. + + Example: + # Get all columns in table 'Customers' + columns = cursor.columns(table='Customers') + + # Get all columns in table 'Customers' in schema 'dbo' + columns = cursor.columns(table='Customers', schema='dbo') + + # Get column named 'CustomerID' in any table + columns = cursor.columns(column='CustomerID') + """ + self._check_closed() + + # Always reset the cursor first to ensure clean state + self._reset_cursor() + + # Call the SQLColumns function + retcode = ddbc_bindings.DDBCSQLColumns( + self.hstmt, + catalog, + schema, + table, + column + ) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode) + + # Initialize description from column metadata + column_metadata = [] + try: + ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) + self._initialize_description(column_metadata) + except InterfaceError as e: + log('error', f"Driver interface error during metadata retrieval: {e}") + + except Exception as e: + # Log the exception with appropriate context + log('error', f"Failed to retrieve column metadata: {e}. 
Using standard ODBC column definitions instead.") + + if not self.description: + # If describe fails, create a manual description for the standard columns + column_types = [str, str, str, str, int, str, int, int, int, int, int, str, str, int, int, int, int, str] + self.description = [ + ("table_cat", column_types[0], None, 128, 128, 0, True), + ("table_schem", column_types[1], None, 128, 128, 0, True), + ("table_name", column_types[2], None, 128, 128, 0, False), + ("column_name", column_types[3], None, 128, 128, 0, False), + ("data_type", column_types[4], None, 10, 10, 0, False), + ("type_name", column_types[5], None, 128, 128, 0, False), + ("column_size", column_types[6], None, 10, 10, 0, True), + ("buffer_length", column_types[7], None, 10, 10, 0, True), + ("decimal_digits", column_types[8], None, 10, 10, 0, True), + ("num_prec_radix", column_types[9], None, 10, 10, 0, True), + ("nullable", column_types[10], None, 10, 10, 0, False), + ("remarks", column_types[11], None, 254, 254, 0, True), + ("column_def", column_types[12], None, 254, 254, 0, True), + ("sql_data_type", column_types[13], None, 10, 10, 0, False), + ("sql_datetime_sub", column_types[14], None, 10, 10, 0, True), + ("char_octet_length", column_types[15], None, 10, 10, 0, True), + ("ordinal_position", column_types[16], None, 10, 10, 0, False), + ("is_nullable", column_types[17], None, 254, 254, 0, True) + ] + + # Store the column mappings for this specific columns() call + column_names = [desc[0] for desc in self.description] + + # Create a specialized column map for this result set + columns_map = {} + for i, name in enumerate(column_names): + columns_map[name] = i + columns_map[name.lower()] = i + + # Define wrapped fetch methods that preserve existing column mapping + # but add our specialized mapping just for column results + def fetchone_with_columns_mapping(): + row = self._original_fetchone() + if row is not None: + # Create a merged map with columns result taking precedence + merged_map = getattr(row, '_column_map', {}).copy() + merged_map.update(columns_map) + row._column_map = merged_map + return row + + def fetchmany_with_columns_mapping(size=None): + rows = self._original_fetchmany(size) + for row in rows: + # Create a merged map with columns result taking precedence + merged_map = getattr(row, '_column_map', {}).copy() + merged_map.update(columns_map) + row._column_map = merged_map + return rows + + def fetchall_with_columns_mapping(): + rows = self._original_fetchall() + for row in rows: + # Create a merged map with columns result taking precedence + merged_map = getattr(row, '_column_map', {}).copy() + merged_map.update(columns_map) + row._column_map = merged_map + return rows + + # Save original fetch methods + if not hasattr(self, '_original_fetchone'): + self._original_fetchone = self.fetchone + self._original_fetchmany = self.fetchmany + self._original_fetchall = self.fetchall + + # Override fetch methods with our wrapped versions + self.fetchone = fetchone_with_columns_mapping + self.fetchmany = fetchmany_with_columns_mapping + self.fetchall = fetchall_with_columns_mapping + + return self @staticmethod def _select_best_sample_value(column): diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 6bc96f42..f3aed22a 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -125,6 +125,11 @@ SQLMoreResultsFunc SQLMoreResults_ptr = nullptr; SQLColAttributeFunc SQLColAttribute_ptr = nullptr; SQLGetTypeInfoFunc SQLGetTypeInfo_ptr = 
nullptr; SQLProceduresFunc SQLProcedures_ptr = nullptr; +SQLForeignKeysFunc SQLForeignKeys_ptr = nullptr; +SQLPrimaryKeysFunc SQLPrimaryKeys_ptr = nullptr; +SQLSpecialColumnsFunc SQLSpecialColumns_ptr = nullptr; +SQLStatisticsFunc SQLStatistics_ptr = nullptr; +SQLColumnsFunc SQLColumns_ptr = nullptr; // Transaction APIs SQLEndTranFunc SQLEndTran_ptr = nullptr; @@ -783,6 +788,11 @@ DriverHandle LoadDriverOrThrowException() { SQLColAttribute_ptr = GetFunctionPointer(handle, "SQLColAttributeW"); SQLGetTypeInfo_ptr = GetFunctionPointer(handle, "SQLGetTypeInfoW"); SQLProcedures_ptr = GetFunctionPointer(handle, "SQLProceduresW"); + SQLForeignKeys_ptr = GetFunctionPointer(handle, "SQLForeignKeysW"); + SQLPrimaryKeys_ptr = GetFunctionPointer(handle, "SQLPrimaryKeysW"); + SQLSpecialColumns_ptr = GetFunctionPointer(handle, "SQLSpecialColumnsW"); + SQLStatistics_ptr = GetFunctionPointer(handle, "SQLStatisticsW"); + SQLColumns_ptr = GetFunctionPointer(handle, "SQLColumnsW"); SQLEndTran_ptr = GetFunctionPointer(handle, "SQLEndTran"); SQLDisconnect_ptr = GetFunctionPointer(handle, "SQLDisconnect"); @@ -801,7 +811,9 @@ DriverHandle LoadDriverOrThrowException() { SQLDescribeCol_ptr && SQLMoreResults_ptr && SQLColAttribute_ptr && SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr && SQLFreeStmt_ptr && SQLGetDiagRec_ptr && - SQLGetTypeInfo_ptr && SQLProcedures_ptr; + SQLGetTypeInfo_ptr && SQLProcedures_ptr && SQLForeignKeys_ptr && + SQLPrimaryKeys_ptr && SQLSpecialColumns_ptr && SQLStatistics_ptr && + SQLColumns_ptr; if (!success) { ThrowStdException("Failed to load required function pointers from driver."); @@ -913,6 +925,197 @@ SQLRETURN SQLProcedures_wrap(SqlHandlePtr StatementHandle, #endif } +SQLRETURN SQLForeignKeys_wrap(SqlHandlePtr StatementHandle, + const py::object& pkCatalogObj, + const py::object& pkSchemaObj, + const py::object& pkTableObj, + const py::object& fkCatalogObj, + const py::object& fkSchemaObj, + const py::object& fkTableObj) { + if (!SQLForeignKeys_ptr) { + ThrowStdException("SQLForeignKeys function not loaded"); + } + + std::wstring pkCatalog = py::isinstance(pkCatalogObj) ? L"" : pkCatalogObj.cast(); + std::wstring pkSchema = py::isinstance(pkSchemaObj) ? L"" : pkSchemaObj.cast(); + std::wstring pkTable = py::isinstance(pkTableObj) ? L"" : pkTableObj.cast(); + std::wstring fkCatalog = py::isinstance(fkCatalogObj) ? L"" : fkCatalogObj.cast(); + std::wstring fkSchema = py::isinstance(fkSchemaObj) ? L"" : fkSchemaObj.cast(); + std::wstring fkTable = py::isinstance(fkTableObj) ? L"" : fkTableObj.cast(); + +#if defined(__APPLE__) || defined(__linux__) + // Unix implementation + std::vector pkCatalogBuf = WStringToSQLWCHAR(pkCatalog); + std::vector pkSchemaBuf = WStringToSQLWCHAR(pkSchema); + std::vector pkTableBuf = WStringToSQLWCHAR(pkTable); + std::vector fkCatalogBuf = WStringToSQLWCHAR(fkCatalog); + std::vector fkSchemaBuf = WStringToSQLWCHAR(fkSchema); + std::vector fkTableBuf = WStringToSQLWCHAR(fkTable); + + return SQLForeignKeys_ptr( + StatementHandle->get(), + pkCatalog.empty() ? nullptr : pkCatalogBuf.data(), + pkCatalog.empty() ? 0 : SQL_NTS, + pkSchema.empty() ? nullptr : pkSchemaBuf.data(), + pkSchema.empty() ? 0 : SQL_NTS, + pkTable.empty() ? nullptr : pkTableBuf.data(), + pkTable.empty() ? 0 : SQL_NTS, + fkCatalog.empty() ? nullptr : fkCatalogBuf.data(), + fkCatalog.empty() ? 0 : SQL_NTS, + fkSchema.empty() ? nullptr : fkSchemaBuf.data(), + fkSchema.empty() ? 0 : SQL_NTS, + fkTable.empty() ? nullptr : fkTableBuf.data(), + fkTable.empty() ? 
0 : SQL_NTS); +#else + // Windows implementation + return SQLForeignKeys_ptr( + StatementHandle->get(), + pkCatalog.empty() ? nullptr : (SQLWCHAR*)pkCatalog.c_str(), + pkCatalog.empty() ? 0 : SQL_NTS, + pkSchema.empty() ? nullptr : (SQLWCHAR*)pkSchema.c_str(), + pkSchema.empty() ? 0 : SQL_NTS, + pkTable.empty() ? nullptr : (SQLWCHAR*)pkTable.c_str(), + pkTable.empty() ? 0 : SQL_NTS, + fkCatalog.empty() ? nullptr : (SQLWCHAR*)fkCatalog.c_str(), + fkCatalog.empty() ? 0 : SQL_NTS, + fkSchema.empty() ? nullptr : (SQLWCHAR*)fkSchema.c_str(), + fkSchema.empty() ? 0 : SQL_NTS, + fkTable.empty() ? nullptr : (SQLWCHAR*)fkTable.c_str(), + fkTable.empty() ? 0 : SQL_NTS); +#endif +} + +SQLRETURN SQLPrimaryKeys_wrap(SqlHandlePtr StatementHandle, + const py::object& catalogObj, + const py::object& schemaObj, + const std::wstring& table) { + if (!SQLPrimaryKeys_ptr) { + ThrowStdException("SQLPrimaryKeys function not loaded"); + } + + // Convert py::object to std::wstring, treating None as empty string + std::wstring catalog = catalogObj.is_none() ? L"" : catalogObj.cast(); + std::wstring schema = schemaObj.is_none() ? L"" : schemaObj.cast(); + +#if defined(__APPLE__) || defined(__linux__) + // Unix implementation + std::vector catalogBuf = WStringToSQLWCHAR(catalog); + std::vector schemaBuf = WStringToSQLWCHAR(schema); + std::vector tableBuf = WStringToSQLWCHAR(table); + + return SQLPrimaryKeys_ptr( + StatementHandle->get(), + catalog.empty() ? nullptr : catalogBuf.data(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : schemaBuf.data(), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : tableBuf.data(), + table.empty() ? 0 : SQL_NTS); +#else + // Windows implementation + return SQLPrimaryKeys_ptr( + StatementHandle->get(), + catalog.empty() ? nullptr : (SQLWCHAR*)catalog.c_str(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : (SQLWCHAR*)schema.c_str(), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : (SQLWCHAR*)table.c_str(), + table.empty() ? 0 : SQL_NTS); +#endif +} + +SQLRETURN SQLStatistics_wrap(SqlHandlePtr StatementHandle, + const py::object& catalogObj, + const py::object& schemaObj, + const std::wstring& table, + SQLUSMALLINT unique, + SQLUSMALLINT reserved) { + if (!SQLStatistics_ptr) { + ThrowStdException("SQLStatistics function not loaded"); + } + + // Convert py::object to std::wstring, treating None as empty string + std::wstring catalog = catalogObj.is_none() ? L"" : catalogObj.cast(); + std::wstring schema = schemaObj.is_none() ? L"" : schemaObj.cast(); + +#if defined(__APPLE__) || defined(__linux__) + // Unix implementation + std::vector catalogBuf = WStringToSQLWCHAR(catalog); + std::vector schemaBuf = WStringToSQLWCHAR(schema); + std::vector tableBuf = WStringToSQLWCHAR(table); + + return SQLStatistics_ptr( + StatementHandle->get(), + catalog.empty() ? nullptr : catalogBuf.data(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : schemaBuf.data(), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : tableBuf.data(), + table.empty() ? 0 : SQL_NTS, + unique, + reserved); +#else + // Windows implementation + return SQLStatistics_ptr( + StatementHandle->get(), + catalog.empty() ? nullptr : (SQLWCHAR*)catalog.c_str(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : (SQLWCHAR*)schema.c_str(), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : (SQLWCHAR*)table.c_str(), + table.empty() ? 
0 : SQL_NTS, + unique, + reserved); +#endif +} + +SQLRETURN SQLColumns_wrap(SqlHandlePtr StatementHandle, + const py::object& catalogObj, + const py::object& schemaObj, + const py::object& tableObj, + const py::object& columnObj) { + if (!SQLColumns_ptr) { + ThrowStdException("SQLColumns function not loaded"); + } + + // Convert py::object to std::wstring, treating None as empty string + std::wstring catalogStr = catalogObj.is_none() ? L"" : catalogObj.cast(); + std::wstring schemaStr = schemaObj.is_none() ? L"" : schemaObj.cast(); + std::wstring tableStr = tableObj.is_none() ? L"" : tableObj.cast(); + std::wstring columnStr = columnObj.is_none() ? L"" : columnObj.cast(); + +#if defined(__APPLE__) || defined(__linux__) + // Unix implementation + std::vector catalogBuf = WStringToSQLWCHAR(catalogStr); + std::vector schemaBuf = WStringToSQLWCHAR(schemaStr); + std::vector tableBuf = WStringToSQLWCHAR(tableStr); + std::vector columnBuf = WStringToSQLWCHAR(columnStr); + + return SQLColumns_ptr( + StatementHandle->get(), + catalogStr.empty() ? nullptr : catalogBuf.data(), + catalogStr.empty() ? 0 : SQL_NTS, + schemaStr.empty() ? nullptr : schemaBuf.data(), + schemaStr.empty() ? 0 : SQL_NTS, + tableStr.empty() ? nullptr : tableBuf.data(), + tableStr.empty() ? 0 : SQL_NTS, + columnStr.empty() ? nullptr : columnBuf.data(), + columnStr.empty() ? 0 : SQL_NTS); +#else + // Windows implementation + return SQLColumns_ptr( + StatementHandle->get(), + catalogStr.empty() ? nullptr : (SQLWCHAR*)catalogStr.c_str(), + catalogStr.empty() ? 0 : SQL_NTS, + schemaStr.empty() ? nullptr : (SQLWCHAR*)schemaStr.c_str(), + schemaStr.empty() ? 0 : SQL_NTS, + tableStr.empty() ? nullptr : (SQLWCHAR*)tableStr.c_str(), + tableStr.empty() ? 0 : SQL_NTS, + columnStr.empty() ? nullptr : (SQLWCHAR*)columnStr.c_str(), + columnStr.empty() ? 0 : SQL_NTS); +#endif +} + // Helper function to check for driver errors ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRETURN retcode) { LOG("Checking errors for retcode - {}" , retcode); @@ -1470,6 +1673,54 @@ SQLRETURN SQLDescribeCol_wrap(SqlHandlePtr StatementHandle, py::list& ColumnMeta return SQL_SUCCESS; } +SQLRETURN SQLSpecialColumns_wrap(SqlHandlePtr StatementHandle, + SQLSMALLINT identifierType, + const py::object& catalogObj, + const py::object& schemaObj, + const std::wstring& table, + SQLSMALLINT scope, + SQLSMALLINT nullable) { + if (!SQLSpecialColumns_ptr) { + ThrowStdException("SQLSpecialColumns function not loaded"); + } + + // Convert py::object to std::wstring, treating None as empty string + std::wstring catalog = catalogObj.is_none() ? L"" : catalogObj.cast(); + std::wstring schema = schemaObj.is_none() ? L"" : schemaObj.cast(); + +#if defined(__APPLE__) || defined(__linux__) + // Unix implementation + std::vector catalogBuf = WStringToSQLWCHAR(catalog); + std::vector schemaBuf = WStringToSQLWCHAR(schema); + std::vector tableBuf = WStringToSQLWCHAR(table); + + return SQLSpecialColumns_ptr( + StatementHandle->get(), + identifierType, + catalog.empty() ? nullptr : catalogBuf.data(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : schemaBuf.data(), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : tableBuf.data(), + table.empty() ? 0 : SQL_NTS, + scope, + nullable); +#else + // Windows implementation + return SQLSpecialColumns_ptr( + StatementHandle->get(), + identifierType, + catalog.empty() ? nullptr : (SQLWCHAR*)catalog.c_str(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? 
nullptr : (SQLWCHAR*)schema.c_str(), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : (SQLWCHAR*)table.c_str(), + table.empty() ? 0 : SQL_NTS, + scope, + nullable); +#endif +} + // Wrap SQLFetch to retrieve rows SQLRETURN SQLFetch_wrap(SqlHandlePtr StatementHandle) { LOG("Fetch next row"); @@ -2633,12 +2884,57 @@ PYBIND11_MODULE(ddbc_bindings, m) { }, "Set statement attributes"); m.def("DDBCSQLGetTypeInfo", &SQLGetTypeInfo_Wrapper, "Returns information about the data types that are supported by the data source", py::arg("StatementHandle"), py::arg("DataType")); - m.def("DDBCSQLProcedures", [](SqlHandlePtr StatementHandle, + m.def("DDBCSQLProcedures", [](SqlHandlePtr StatementHandle, const py::object& catalog, const py::object& schema, const py::object& procedure) { - return SQLProcedures_wrap(StatementHandle, catalog, schema, procedure); -}); + return SQLProcedures_wrap(StatementHandle, catalog, schema, procedure); + }); + + m.def("DDBCSQLForeignKeys", [](SqlHandlePtr StatementHandle, + const py::object& pkCatalog, + const py::object& pkSchema, + const py::object& pkTable, + const py::object& fkCatalog, + const py::object& fkSchema, + const py::object& fkTable) { + return SQLForeignKeys_wrap(StatementHandle, + pkCatalog, pkSchema, pkTable, + fkCatalog, fkSchema, fkTable); + }); + m.def("DDBCSQLPrimaryKeys", [](SqlHandlePtr StatementHandle, + const py::object& catalog, + const py::object& schema, + const std::wstring& table) { + return SQLPrimaryKeys_wrap(StatementHandle, catalog, schema, table); + }); + m.def("DDBCSQLSpecialColumns", [](SqlHandlePtr StatementHandle, + SQLSMALLINT identifierType, + const py::object& catalog, + const py::object& schema, + const std::wstring& table, + SQLSMALLINT scope, + SQLSMALLINT nullable) { + return SQLSpecialColumns_wrap(StatementHandle, + identifierType, catalog, schema, table, + scope, nullable); + }); + m.def("DDBCSQLStatistics", [](SqlHandlePtr StatementHandle, + const py::object& catalog, + const py::object& schema, + const std::wstring& table, + SQLUSMALLINT unique, + SQLUSMALLINT reserved) { + return SQLStatistics_wrap(StatementHandle, catalog, schema, table, unique, reserved); + }); + m.def("DDBCSQLColumns", [](SqlHandlePtr StatementHandle, + const py::object& catalog, + const py::object& schema, + const py::object& table, + const py::object& column) { + return SQLColumns_wrap(StatementHandle, catalog, schema, table, column); + }); + // Add a version attribute m.attr("__version__") = "1.0.0"; diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index f2cae7b3..d757ad95 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -108,6 +108,20 @@ typedef SQLRETURN (SQL_API* SQLColAttributeFunc)(SQLHSTMT, SQLUSMALLINT, SQLUSMA typedef SQLRETURN (SQL_API* SQLGetTypeInfoFunc)(SQLHSTMT, SQLSMALLINT); typedef SQLRETURN (SQL_API* SQLProceduresFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT); +typedef SQLRETURN (SQL_API* SQLForeignKeysFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT); +typedef SQLRETURN (SQL_API* SQLPrimaryKeysFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, SQLSMALLINT); +typedef SQLRETURN (SQL_API* SQLSpecialColumnsFunc)(SQLHSTMT, SQLUSMALLINT, SQLWCHAR*, SQLSMALLINT, + SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, + SQLUSMALLINT, SQLUSMALLINT); +typedef SQLRETURN (SQL_API* 
SQLStatisticsFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, + SQLUSMALLINT, SQLUSMALLINT); +typedef SQLRETURN (SQL_API* SQLColumnsFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, + SQLWCHAR*, SQLSMALLINT); // Transaction APIs typedef SQLRETURN (SQL_API* SQLEndTranFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT); @@ -153,6 +167,11 @@ extern SQLMoreResultsFunc SQLMoreResults_ptr; extern SQLColAttributeFunc SQLColAttribute_ptr; extern SQLGetTypeInfoFunc SQLGetTypeInfo_ptr; extern SQLProceduresFunc SQLProcedures_ptr; +extern SQLForeignKeysFunc SQLForeignKeys_ptr; +extern SQLPrimaryKeysFunc SQLPrimaryKeys_ptr; +extern SQLSpecialColumnsFunc SQLSpecialColumns_ptr; +extern SQLStatisticsFunc SQLStatistics_ptr; +extern SQLColumnsFunc SQLColumns_ptr; // Transaction APIs extern SQLEndTranFunc SQLEndTran_ptr; diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 653d713b..e07d370a 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -10,6 +10,7 @@ import pytest from datetime import datetime, date, time +import time as time_module import decimal from mssql_python import Connection import mssql_python @@ -2161,6 +2162,1583 @@ def test_procedures_cleanup(cursor, db_connection): except Exception as e: pytest.fail(f"Test cleanup failed: {e}") +def test_foreignkeys_setup(cursor, db_connection): + """Create tables with foreign key relationships for testing""" + try: + # Create a test schema for isolation + cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_fk_schema') EXEC('CREATE SCHEMA pytest_fk_schema')") + + # Drop tables if they exist (in reverse order to avoid constraint conflicts) + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + + # Create parent table + cursor.execute(""" + CREATE TABLE pytest_fk_schema.customers ( + customer_id INT PRIMARY KEY, + customer_name VARCHAR(100) NOT NULL + ) + """) + + # Create child table with foreign key + cursor.execute(""" + CREATE TABLE pytest_fk_schema.orders ( + order_id INT PRIMARY KEY, + order_date DATETIME NOT NULL, + customer_id INT NOT NULL, + total_amount DECIMAL(10, 2) NOT NULL, + CONSTRAINT FK_Orders_Customers FOREIGN KEY (customer_id) + REFERENCES pytest_fk_schema.customers (customer_id) + ) + """) + + # Insert test data + cursor.execute(""" + INSERT INTO pytest_fk_schema.customers (customer_id, customer_name) + VALUES (1, 'Test Customer 1'), (2, 'Test Customer 2') + """) + + cursor.execute(""" + INSERT INTO pytest_fk_schema.orders (order_id, order_date, customer_id, total_amount) + VALUES (101, GETDATE(), 1, 150.00), (102, GETDATE(), 2, 250.50) + """) + + db_connection.commit() + except Exception as e: + pytest.fail(f"Test setup failed: {e}") + +def test_foreignkeys_all(cursor, db_connection): + """Test getting all foreign keys""" + try: + # First set up our test tables + test_foreignkeys_setup(cursor, db_connection) + + # Get all foreign keys + fks = cursor.foreignKeys(table='orders', schema='pytest_fk_schema').fetchall() + + # Verify we got results + assert fks is not None, "foreignKeys() should return results" + assert len(fks) > 0, "foreignKeys() should return at least one foreign key" + + # Verify our test FK is in the results + # Search case-insensitively since the database might return different case + found_test_fk = False + for fk in fks: + if (fk.fktable_name.lower() == 'orders' and + fk.pktable_name.lower() == 
'customers'): + found_test_fk = True + break + + assert found_test_fk, "Could not find the test foreign key in results" + + finally: + # Clean up + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + db_connection.commit() + +def test_foreignkeys_specific_table(cursor, db_connection): + """Test getting foreign keys for a specific table""" + try: + # First set up our test tables + test_foreignkeys_setup(cursor, db_connection) + + # Get foreign keys for the orders table + fks = cursor.foreignKeys(table='orders', schema='pytest_fk_schema').fetchall() + + # Verify we got results + assert len(fks) == 1, "Should find exactly one foreign key for orders table" + + # Verify the foreign key details + fk = fks[0] + assert fk.fktable_name.lower() == 'orders', "Wrong foreign key table name" + assert fk.pktable_name.lower() == 'customers', "Wrong primary key table name" + assert fk.fkcolumn_name.lower() == 'customer_id', "Wrong foreign key column name" + assert fk.pkcolumn_name.lower() == 'customer_id', "Wrong primary key column name" + + finally: + # Clean up + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + db_connection.commit() + +def test_foreignkeys_specific_foreign_table(cursor, db_connection): + """Test getting foreign keys that reference a specific table""" + try: + # First set up our test tables + test_foreignkeys_setup(cursor, db_connection) + + # Get foreign keys that reference the customers table + fks = cursor.foreignKeys(foreignTable='customers', foreignSchema='pytest_fk_schema').fetchall() + + # Verify we got results + assert len(fks) > 0, "Should find at least one foreign key referencing customers table" + + # Verify our test FK is in the results + found_test_fk = False + for fk in fks: + if (fk.fktable_name.lower() == 'orders' and + fk.pktable_name.lower() == 'customers'): + found_test_fk = True + break + + assert found_test_fk, "Could not find the test foreign key in results" + + finally: + # Clean up + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + db_connection.commit() + +def test_foreignkeys_both_tables(cursor, db_connection): + """Test getting foreign keys with both table and foreignTable specified""" + try: + # First set up our test tables + test_foreignkeys_setup(cursor, db_connection) + + # Get foreign keys between the two tables + fks = cursor.foreignKeys( + table='orders', schema='pytest_fk_schema', + foreignTable='customers', foreignSchema='pytest_fk_schema' + ).fetchall() + + # Verify we got results + assert len(fks) == 1, "Should find exactly one foreign key between specified tables" + + # Verify the foreign key details + fk = fks[0] + assert fk.fktable_name.lower() == 'orders', "Wrong foreign key table name" + assert fk.pktable_name.lower() == 'customers', "Wrong primary key table name" + assert fk.fkcolumn_name.lower() == 'customer_id', "Wrong foreign key column name" + assert fk.pkcolumn_name.lower() == 'customer_id', "Wrong primary key column name" + + finally: + # Clean up + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + db_connection.commit() + +def test_foreignkeys_nonexistent(cursor): + """Test foreignKeys() with non-existent table name""" + # Use a table name that's highly unlikely to exist + fks = 
cursor.foreignKeys(table='nonexistent_table_xyz123').fetchall() + + # Should return empty list, not error + assert isinstance(fks, list), "Should return a list for non-existent table" + assert len(fks) == 0, "Should return empty list for non-existent table" + +def test_foreignkeys_catalog_schema(cursor, db_connection): + """Test foreignKeys() with catalog and schema filters""" + try: + # First set up our test tables + test_foreignkeys_setup(cursor, db_connection) + + # Get current database name + cursor.execute("SELECT DB_NAME() AS current_db") + row = cursor.fetchone() + current_db = row.current_db + + # Get foreign keys with current catalog and pytest schema + fks = cursor.foreignKeys( + table='orders', + catalog=current_db, + schema='pytest_fk_schema' + ).fetchall() + + # Verify we got results + assert len(fks) > 0, "Should find foreign keys with correct catalog/schema" + + # Verify catalog/schema in results + for fk in fks: + assert fk.fktable_cat == current_db, "Wrong foreign key table catalog" + assert fk.fktable_schem == 'pytest_fk_schema', "Wrong foreign key table schema" + + finally: + # Clean up + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + db_connection.commit() + +def test_foreignkeys_result_structure(cursor, db_connection): + """Test the structure of foreignKeys result rows""" + try: + # First set up our test tables + test_foreignkeys_setup(cursor, db_connection) + + # Get foreign keys for the orders table + fks = cursor.foreignKeys(table='orders', schema='pytest_fk_schema').fetchall() + + # Verify we got results + assert len(fks) > 0, "Should find at least one foreign key" + + # Check for all required columns in the result + first_row = fks[0] + required_columns = [ + 'pktable_cat', 'pktable_schem', 'pktable_name', 'pkcolumn_name', + 'fktable_cat', 'fktable_schem', 'fktable_name', 'fkcolumn_name', + 'key_seq', 'update_rule', 'delete_rule', 'fk_name', 'pk_name', + 'deferrability' + ] + + for column in required_columns: + assert hasattr(first_row, column), f"Result missing required column: {column}" + + # Verify specific values + assert first_row.fktable_name.lower() == 'orders', "Wrong foreign key table name" + assert first_row.pktable_name.lower() == 'customers', "Wrong primary key table name" + assert first_row.fkcolumn_name.lower() == 'customer_id', "Wrong foreign key column name" + assert first_row.pkcolumn_name.lower() == 'customer_id', "Wrong primary key column name" + assert first_row.key_seq == 1, "Wrong key sequence number" + assert first_row.fk_name is not None, "Foreign key name should not be None" + assert first_row.pk_name is not None, "Primary key name should not be None" + + finally: + # Clean up + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + db_connection.commit() + +def test_foreignkeys_multiple_column_fk(cursor, db_connection): + """Test foreignKeys() with a multi-column foreign key""" + try: + # First create the schema if needed + cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_fk_schema') EXEC('CREATE SCHEMA pytest_fk_schema')") + + # Drop tables if they exist (in reverse order to avoid constraint conflicts) + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.order_details") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.product_variants") + + # Create parent table with composite primary key + cursor.execute(""" + CREATE TABLE 
pytest_fk_schema.product_variants ( + product_id INT NOT NULL, + variant_id INT NOT NULL, + variant_name VARCHAR(100) NOT NULL, + PRIMARY KEY (product_id, variant_id) + ) + """) + + # Create child table with composite foreign key + cursor.execute(""" + CREATE TABLE pytest_fk_schema.order_details ( + order_id INT NOT NULL, + product_id INT NOT NULL, + variant_id INT NOT NULL, + quantity INT NOT NULL, + PRIMARY KEY (order_id, product_id, variant_id), + CONSTRAINT FK_OrderDetails_ProductVariants FOREIGN KEY (product_id, variant_id) + REFERENCES pytest_fk_schema.product_variants (product_id, variant_id) + ) + """) + + db_connection.commit() + + # Get foreign keys for the order_details table + fks = cursor.foreignKeys(table='order_details', schema='pytest_fk_schema').fetchall() + + # Verify we got results + assert len(fks) == 2, "Should find two rows for the composite foreign key (one per column)" + + # Group by key_seq to verify both columns + fk_columns = {} + for fk in fks: + fk_columns[fk.key_seq] = { + 'pkcolumn': fk.pkcolumn_name.lower(), + 'fkcolumn': fk.fkcolumn_name.lower() + } + + # Verify both columns are present + assert 1 in fk_columns, "First column of composite key missing" + assert 2 in fk_columns, "Second column of composite key missing" + + # Verify column mappings + assert fk_columns[1]['pkcolumn'] == 'product_id', "Wrong primary key column 1" + assert fk_columns[1]['fkcolumn'] == 'product_id', "Wrong foreign key column 1" + assert fk_columns[2]['pkcolumn'] == 'variant_id', "Wrong primary key column 2" + assert fk_columns[2]['fkcolumn'] == 'variant_id', "Wrong foreign key column 2" + + finally: + # Clean up + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.order_details") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.product_variants") + db_connection.commit() + +def test_cleanup_schema(cursor, db_connection): + """Clean up the test schema after all tests""" + try: + # Make sure no tables remain + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.order_details") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.product_variants") + db_connection.commit() + + # Drop the schema + cursor.execute("DROP SCHEMA IF EXISTS pytest_fk_schema") + db_connection.commit() + except Exception as e: + pytest.fail(f"Schema cleanup failed: {e}") + +def test_primarykeys_setup(cursor, db_connection): + """Create tables with primary keys for testing""" + try: + # Create a test schema for isolation + cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_pk_schema') EXEC('CREATE SCHEMA pytest_pk_schema')") + + # Drop tables if they exist + cursor.execute("DROP TABLE IF EXISTS pytest_pk_schema.single_pk_test") + cursor.execute("DROP TABLE IF EXISTS pytest_pk_schema.composite_pk_test") + + # Create table with simple primary key + cursor.execute(""" + CREATE TABLE pytest_pk_schema.single_pk_test ( + id INT PRIMARY KEY, + name VARCHAR(100) NOT NULL, + description VARCHAR(200) NULL + ) + """) + + # Create table with composite primary key + cursor.execute(""" + CREATE TABLE pytest_pk_schema.composite_pk_test ( + dept_id INT NOT NULL, + emp_id INT NOT NULL, + hire_date DATE NOT NULL, + CONSTRAINT PK_composite_test PRIMARY KEY (dept_id, emp_id) + ) + """) + + db_connection.commit() + except Exception as e: + pytest.fail(f"Test setup failed: {e}") + +def test_primarykeys_simple(cursor, db_connection): + """Test primaryKeys 
returns information about a simple primary key""" + try: + # First set up our test tables + test_primarykeys_setup(cursor, db_connection) + + # Get primary key information + pks = cursor.primaryKeys('single_pk_test', schema='pytest_pk_schema').fetchall() + + # Verify we got results + assert len(pks) == 1, "Should find exactly one primary key column" + pk = pks[0] + + # Verify primary key details + assert pk.table_name.lower() == 'single_pk_test', "Wrong table name" + assert pk.column_name.lower() == 'id', "Wrong primary key column name" + assert pk.key_seq == 1, "Wrong key sequence number" + assert pk.pk_name is not None, "Primary key name should not be None" + + finally: + # Clean up happens in test_primarykeys_cleanup + pass + +def test_primarykeys_composite(cursor, db_connection): + """Test primaryKeys with a composite primary key""" + try: + # Get primary key information + pks = cursor.primaryKeys('composite_pk_test', schema='pytest_pk_schema').fetchall() + + # Verify we got results for both columns + assert len(pks) == 2, "Should find two primary key columns" + + # Sort by key_seq to ensure consistent order + pks = sorted(pks, key=lambda row: row.key_seq) + + # Verify first column + assert pks[0].table_name.lower() == 'composite_pk_test', "Wrong table name" + assert pks[0].column_name.lower() == 'dept_id', "Wrong first primary key column name" + assert pks[0].key_seq == 1, "Wrong key sequence number for first column" + + # Verify second column + assert pks[1].table_name.lower() == 'composite_pk_test', "Wrong table name" + assert pks[1].column_name.lower() == 'emp_id', "Wrong second primary key column name" + assert pks[1].key_seq == 2, "Wrong key sequence number for second column" + + # Both should have the same PK name + assert pks[0].pk_name == pks[1].pk_name, "Both columns should have the same primary key name" + + finally: + # Clean up happens in test_primarykeys_cleanup + pass + +def test_primarykeys_column_info(cursor, db_connection): + """Test that primaryKeys returns correct column information""" + try: + # Get primary key information + pks = cursor.primaryKeys('single_pk_test', schema='pytest_pk_schema').fetchall() + + # Verify column information + assert len(pks) == 1, "Should find exactly one primary key column" + pk = pks[0] + + # Verify expected columns are present + assert hasattr(pk, 'table_cat'), "Result should have table_cat column" + assert hasattr(pk, 'table_schem'), "Result should have table_schem column" + assert hasattr(pk, 'table_name'), "Result should have table_name column" + assert hasattr(pk, 'column_name'), "Result should have column_name column" + assert hasattr(pk, 'key_seq'), "Result should have key_seq column" + assert hasattr(pk, 'pk_name'), "Result should have pk_name column" + + # Verify values are correct + assert pk.table_schem.lower() == 'pytest_pk_schema', "Wrong schema name" + assert pk.table_name.lower() == 'single_pk_test', "Wrong table name" + assert pk.column_name.lower() == 'id', "Wrong column name" + assert isinstance(pk.key_seq, int), "key_seq should be an integer" + + finally: + # Clean up happens in test_primarykeys_cleanup + pass + +def test_primarykeys_nonexistent(cursor): + """Test primaryKeys() with non-existent table name""" + # Use a table name that's highly unlikely to exist + pks = cursor.primaryKeys('nonexistent_table_xyz123').fetchall() + + # Should return empty list, not error + assert isinstance(pks, list), "Should return a list for non-existent table" + assert len(pks) == 0, "Should return empty list for non-existent table" 
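+
+# Illustrative sketch, not a test: one way calling code might consume the
+# primaryKeys() result set exercised above. The helper name and return format
+# are assumptions for illustration only; the row attributes used here
+# (column_name, key_seq) are the ODBC SQLPrimaryKeys column names asserted in
+# these tests.
+def _example_describe_primary_key(cursor, table, schema=None):
+    """Return the primary-key column names of a table in key_seq order."""
+    rows = cursor.primaryKeys(table, schema=schema).fetchall()
+    return [row.column_name for row in sorted(rows, key=lambda r: r.key_seq)]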
+ +def test_primarykeys_catalog_filter(cursor, db_connection): + """Test primaryKeys() with catalog filter""" + try: + # Get current database name + cursor.execute("SELECT DB_NAME() AS current_db") + current_db = cursor.fetchone().current_db + + # Get primary keys with current catalog + pks = cursor.primaryKeys('single_pk_test', catalog=current_db, schema='pytest_pk_schema').fetchall() + + # Verify catalog filter worked + assert len(pks) == 1, "Should find exactly one primary key column" + pk = pks[0] + assert pk.table_cat == current_db, f"Expected catalog {current_db}, got {pk.table_cat}" + + # Get primary keys with non-existent catalog + fake_pks = cursor.primaryKeys('single_pk_test', catalog='nonexistent_db_xyz123').fetchall() + assert len(fake_pks) == 0, "Should return empty list for non-existent catalog" + + finally: + # Clean up happens in test_primarykeys_cleanup + pass + +def test_primarykeys_cleanup(cursor, db_connection): + """Clean up test tables after testing""" + try: + # Drop all test tables + cursor.execute("DROP TABLE IF EXISTS pytest_pk_schema.single_pk_test") + cursor.execute("DROP TABLE IF EXISTS pytest_pk_schema.composite_pk_test") + + # Drop the test schema + cursor.execute("DROP SCHEMA IF EXISTS pytest_pk_schema") + db_connection.commit() + except Exception as e: + pytest.fail(f"Test cleanup failed: {e}") + +def test_specialcolumns_setup(cursor, db_connection): + """Create test tables for testing rowIdColumns and rowVerColumns""" + try: + # Create a test schema for isolation + cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_special_schema') EXEC('CREATE SCHEMA pytest_special_schema')") + + # Drop tables if they exist + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.rowid_test") + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.timestamp_test") + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.multiple_unique_test") + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.identity_test") + + # Create table with primary key (for rowIdColumns) + cursor.execute(""" + CREATE TABLE pytest_special_schema.rowid_test ( + id INT PRIMARY KEY, + name NVARCHAR(100) NOT NULL, + unique_col NVARCHAR(100) UNIQUE, + non_unique_col NVARCHAR(100) + ) + """) + + # Create table with rowversion column (for rowVerColumns) + cursor.execute(""" + CREATE TABLE pytest_special_schema.timestamp_test ( + id INT PRIMARY KEY, + name NVARCHAR(100) NOT NULL, + last_updated ROWVERSION + ) + """) + + # Create table with multiple unique identifiers + cursor.execute(""" + CREATE TABLE pytest_special_schema.multiple_unique_test ( + id INT NOT NULL, + code VARCHAR(10) NOT NULL, + email VARCHAR(100) UNIQUE, + order_number VARCHAR(20) UNIQUE, + CONSTRAINT PK_multiple_unique_test PRIMARY KEY (id, code) + ) + """) + + # Create table with identity column + cursor.execute(""" + CREATE TABLE pytest_special_schema.identity_test ( + id INT IDENTITY(1,1) PRIMARY KEY, + name NVARCHAR(100) NOT NULL, + last_modified DATETIME DEFAULT GETDATE() + ) + """) + + db_connection.commit() + except Exception as e: + pytest.fail(f"Test setup failed: {e}") + +def test_rowid_columns_basic(cursor, db_connection): + """Test basic functionality of rowIdColumns""" + try: + # Get row identifier columns for simple table + rowid_cols = cursor.rowIdColumns( + table='rowid_test', + schema='pytest_special_schema' + ).fetchall() + + # LIMITATION: Only returns first column of primary key + assert len(rowid_cols) == 1, "Should find exactly one ROWID column (first column of PK)" + 
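+        # SQLSpecialColumns(SQL_BEST_ROWID) reports an optimal set of columns
+        # that uniquely identifies a row; for this table's single-column
+        # primary key that is exactly one result row.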
+ # Verify column name in the results + col = rowid_cols[0] + assert col.column_name.lower() == 'id', "Primary key column should be included in ROWID results" + + # Verify result structure + assert hasattr(col, 'scope'), "Result should have scope column" + assert hasattr(col, 'column_name'), "Result should have column_name column" + assert hasattr(col, 'data_type'), "Result should have data_type column" + assert hasattr(col, 'type_name'), "Result should have type_name column" + assert hasattr(col, 'column_size'), "Result should have column_size column" + assert hasattr(col, 'buffer_length'), "Result should have buffer_length column" + assert hasattr(col, 'decimal_digits'), "Result should have decimal_digits column" + assert hasattr(col, 'pseudo_column'), "Result should have pseudo_column column" + + # The scope should be one of the valid values or NULL + assert col.scope in [0, 1, 2, None], f"Invalid scope value: {col.scope}" + + # The pseudo_column should be one of the valid values + assert col.pseudo_column in [0, 1, 2, None], f"Invalid pseudo_column value: {col.pseudo_column}" + + except Exception as e: + pytest.fail(f"rowIdColumns basic test failed: {e}") + finally: + # Clean up happens in test_specialcolumns_cleanup + pass + +def test_rowid_columns_identity(cursor, db_connection): + """Test rowIdColumns with identity column""" + try: + # Get row identifier columns for table with identity column + rowid_cols = cursor.rowIdColumns( + table='identity_test', + schema='pytest_special_schema' + ).fetchall() + + # LIMITATION: Only returns the identity column if it's the primary key + assert len(rowid_cols) == 1, "Should find exactly one ROWID column (identity column as PK)" + + # Verify it's the identity column + col = rowid_cols[0] + assert col.column_name.lower() == 'id', "Identity column should be included as it's the PK" + + except Exception as e: + pytest.fail(f"rowIdColumns identity test failed: {e}") + finally: + # Clean up happens in test_specialcolumns_cleanup + pass + +def test_rowid_columns_composite(cursor, db_connection): + """Test rowIdColumns with composite primary key""" + try: + # Get row identifier columns for table with composite primary key + rowid_cols = cursor.rowIdColumns( + table='multiple_unique_test', + schema='pytest_special_schema' + ).fetchall() + + # LIMITATION: Only returns first column of composite primary key + assert len(rowid_cols) >= 1, "Should find at least one ROWID column (first column of PK)" + + # Verify column names in the results - should be the first PK column + col_names = [col.column_name.lower() for col in rowid_cols] + assert 'id' in col_names, "First part of composite PK should be included" + + # LIMITATION: Other parts of the PK or unique constraints may not be included + if len(rowid_cols) > 1: + # If additional columns are returned, they should be valid + for col in rowid_cols: + assert col.column_name.lower() in ['id', 'code'], "Only PK columns should be returned" + + except Exception as e: + pytest.fail(f"rowIdColumns composite test failed: {e}") + finally: + # Clean up happens in test_specialcolumns_cleanup + pass + +def test_rowid_columns_nonexistent(cursor): + """Test rowIdColumns with non-existent table""" + # Use a table name that's highly unlikely to exist + rowid_cols = cursor.rowIdColumns('nonexistent_table_xyz123').fetchall() + + # Should return empty list, not error + assert isinstance(rowid_cols, list), "Should return a list for non-existent table" + assert len(rowid_cols) == 0, "Should return empty list for non-existent table" 
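+
+# Illustrative sketch, not a test: SQL_BEST_ROWID results are typically used
+# to build a minimal row-identifying WHERE clause (e.g. for targeted updates).
+# The helper name and the bracket quoting are assumptions for illustration;
+# it relies only on the column_name attribute asserted in the tests above.
+def _example_rowid_where_clause(cursor, table, schema=None):
+    """Return a parameterized WHERE clause over the best row-identifier columns."""
+    cols = cursor.rowIdColumns(table=table, schema=schema).fetchall()
+    return " AND ".join(f"[{col.column_name}] = ?" for col in cols)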
+ +def test_rowid_columns_nullable(cursor, db_connection): + """Test rowIdColumns with nullable parameter""" + try: + # First create a table with nullable unique column and non-nullable PK + cursor.execute(""" + CREATE TABLE pytest_special_schema.nullable_test ( + id INT PRIMARY KEY, -- PK can't be nullable in SQL Server + data NVARCHAR(100) NULL + ) + """) + db_connection.commit() + + # Test with nullable=True (default) + rowid_cols_with_nullable = cursor.rowIdColumns( + table='nullable_test', + schema='pytest_special_schema' + ).fetchall() + + # Verify PK column is included + assert len(rowid_cols_with_nullable) == 1, "Should return exactly one column (PK)" + assert rowid_cols_with_nullable[0].column_name.lower() == 'id', "PK column should be returned" + + # Test with nullable=False + rowid_cols_no_nullable = cursor.rowIdColumns( + table='nullable_test', + schema='pytest_special_schema', + nullable=False + ).fetchall() + + # The behavior of SQLSpecialColumns with SQL_NO_NULLS is to only return + # non-nullable columns that uniquely identify a row, but SQL Server returns + # an empty set in this case - this is expected behavior + assert len(rowid_cols_no_nullable) == 0, "Should return empty list when nullable=False (ODBC API behavior)" + + except Exception as e: + pytest.fail(f"rowIdColumns nullable test failed: {e}") + finally: + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.nullable_test") + db_connection.commit() + +def test_rowver_columns_basic(cursor, db_connection): + """Test basic functionality of rowVerColumns""" + try: + # Get version columns from timestamp test table + rowver_cols = cursor.rowVerColumns( + table='timestamp_test', + schema='pytest_special_schema' + ).fetchall() + + # Verify we got results + assert len(rowver_cols) == 1, "Should find exactly one ROWVER column" + + # Verify the column is the rowversion column + rowver_col = rowver_cols[0] + assert rowver_col.column_name.lower() == 'last_updated', "ROWVER column should be 'last_updated'" + assert rowver_col.type_name.lower() in ['rowversion', 'timestamp'], "ROWVER column should have rowversion or timestamp type" + + # Verify result structure - allowing for NULL values + assert hasattr(rowver_col, 'scope'), "Result should have scope column" + assert hasattr(rowver_col, 'column_name'), "Result should have column_name column" + assert hasattr(rowver_col, 'data_type'), "Result should have data_type column" + assert hasattr(rowver_col, 'type_name'), "Result should have type_name column" + assert hasattr(rowver_col, 'column_size'), "Result should have column_size column" + assert hasattr(rowver_col, 'buffer_length'), "Result should have buffer_length column" + assert hasattr(rowver_col, 'decimal_digits'), "Result should have decimal_digits column" + assert hasattr(rowver_col, 'pseudo_column'), "Result should have pseudo_column column" + + # The scope should be one of the valid values or NULL + assert rowver_col.scope in [0, 1, 2, None], f"Invalid scope value: {rowver_col.scope}" + + except Exception as e: + pytest.fail(f"rowVerColumns basic test failed: {e}") + finally: + # Clean up happens in test_specialcolumns_cleanup + pass + +def test_rowver_columns_nonexistent(cursor): + """Test rowVerColumns with non-existent table""" + # Use a table name that's highly unlikely to exist + rowver_cols = cursor.rowVerColumns('nonexistent_table_xyz123').fetchall() + + # Should return empty list, not error + assert isinstance(rowver_cols, list), "Should return a list for non-existent table" + assert len(rowver_cols) == 0, 
"Should return empty list for non-existent table" + +def test_rowver_columns_nullable(cursor, db_connection): + """Test rowVerColumns with nullable parameter (not expected to have effect)""" + try: + # First create a table with rowversion column + cursor.execute(""" + CREATE TABLE pytest_special_schema.nullable_rowver_test ( + id INT PRIMARY KEY, + ts ROWVERSION + ) + """) + db_connection.commit() + + # Test with nullable=True (default) + rowver_cols_with_nullable = cursor.rowVerColumns( + table='nullable_rowver_test', + schema='pytest_special_schema' + ).fetchall() + + # Verify rowversion column is included (rowversion can't be nullable) + assert len(rowver_cols_with_nullable) == 1, "Should find exactly one ROWVER column" + assert rowver_cols_with_nullable[0].column_name.lower() == 'ts', "ROWVERSION column should be included" + + # Test with nullable=False + rowver_cols_no_nullable = cursor.rowVerColumns( + table='nullable_rowver_test', + schema='pytest_special_schema', + nullable=False + ).fetchall() + + # Verify rowversion column is still included + assert len(rowver_cols_no_nullable) == 1, "Should find exactly one ROWVER column" + assert rowver_cols_no_nullable[0].column_name.lower() == 'ts', "ROWVERSION column should be included even with nullable=False" + + except Exception as e: + pytest.fail(f"rowVerColumns nullable test failed: {e}") + finally: + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.nullable_rowver_test") + db_connection.commit() + +def test_specialcolumns_catalog_filter(cursor, db_connection): + """Test special columns with catalog filter""" + try: + # Get current database name + cursor.execute("SELECT DB_NAME() AS current_db") + current_db = cursor.fetchone().current_db + + # Test rowIdColumns with current catalog + rowid_cols = cursor.rowIdColumns( + table='rowid_test', + catalog=current_db, + schema='pytest_special_schema' + ).fetchall() + + # Verify catalog filter worked + assert len(rowid_cols) > 0, "Should find ROWID columns with correct catalog" + + # Test rowIdColumns with non-existent catalog + fake_rowid_cols = cursor.rowIdColumns( + table='rowid_test', + catalog='nonexistent_db_xyz123', + schema='pytest_special_schema' + ).fetchall() + assert len(fake_rowid_cols) == 0, "Should return empty list for non-existent catalog" + + # Test rowVerColumns with current catalog + rowver_cols = cursor.rowVerColumns( + table='timestamp_test', + catalog=current_db, + schema='pytest_special_schema' + ).fetchall() + + # Verify catalog filter worked + assert len(rowver_cols) > 0, "Should find ROWVER columns with correct catalog" + + # Test rowVerColumns with non-existent catalog + fake_rowver_cols = cursor.rowVerColumns( + table='timestamp_test', + catalog='nonexistent_db_xyz123', + schema='pytest_special_schema' + ).fetchall() + assert len(fake_rowver_cols) == 0, "Should return empty list for non-existent catalog" + + except Exception as e: + pytest.fail(f"Special columns catalog filter test failed: {e}") + finally: + # Clean up happens in test_specialcolumns_cleanup + pass + +def test_specialcolumns_cleanup(cursor, db_connection): + """Clean up test tables after testing""" + try: + # Drop all test tables + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.rowid_test") + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.timestamp_test") + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.multiple_unique_test") + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.identity_test") + cursor.execute("DROP TABLE IF EXISTS 
pytest_special_schema.nullable_unique_test") + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.nullable_timestamp_test") + + # Drop the test schema + cursor.execute("DROP SCHEMA IF EXISTS pytest_special_schema") + db_connection.commit() + except Exception as e: + pytest.fail(f"Test cleanup failed: {e}") + +def test_statistics_setup(cursor, db_connection): + """Create test tables and indexes for statistics testing""" + try: + # Create a test schema for isolation + cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_stats_schema') EXEC('CREATE SCHEMA pytest_stats_schema')") + + # Drop tables if they exist + cursor.execute("DROP TABLE IF EXISTS pytest_stats_schema.stats_test") + cursor.execute("DROP TABLE IF EXISTS pytest_stats_schema.empty_stats_test") + + # Create test table with various indexes + cursor.execute(""" + CREATE TABLE pytest_stats_schema.stats_test ( + id INT PRIMARY KEY, + name VARCHAR(100) NOT NULL, + email VARCHAR(100) UNIQUE, + department VARCHAR(50) NOT NULL, + salary DECIMAL(10, 2) NULL, + hire_date DATE NOT NULL + ) + """) + + # Create a non-unique index + cursor.execute(""" + CREATE INDEX IX_stats_test_dept_date ON pytest_stats_schema.stats_test (department, hire_date) + """) + + # Create a unique index on multiple columns + cursor.execute(""" + CREATE UNIQUE INDEX UX_stats_test_name_dept ON pytest_stats_schema.stats_test (name, department) + """) + + # Create an empty table for testing + cursor.execute(""" + CREATE TABLE pytest_stats_schema.empty_stats_test ( + id INT PRIMARY KEY, + data VARCHAR(100) NULL + ) + """) + + db_connection.commit() + except Exception as e: + pytest.fail(f"Test setup failed: {e}") + +def test_statistics_basic(cursor, db_connection): + """Test basic functionality of statistics method""" + try: + # First set up our test tables + test_statistics_setup(cursor, db_connection) + + # Get statistics for the test table (all indexes) + stats = cursor.statistics( + table='stats_test', + schema='pytest_stats_schema' + ).fetchall() + + # Verify we got results - should include PK, unique index on email, and non-unique index + assert stats is not None, "statistics() should return results" + assert len(stats) > 0, "statistics() should return at least one row" + + # Count different types of indexes + table_stats = [s for s in stats if s.type == 0] # TABLE_STAT + indexes = [s for s in stats if s.type != 0] # Actual indexes + + # We should have at least one table statistics row and multiple index rows + assert len(table_stats) <= 1, "Should have at most one TABLE_STAT row" + assert len(indexes) >= 3, "Should have at least 3 index entries (PK, unique email, non-unique dept+date)" + + # Verify column names in results + first_row = stats[0] + assert hasattr(first_row, 'table_name'), "Result should have table_name column" + assert hasattr(first_row, 'non_unique'), "Result should have non_unique column" + assert hasattr(first_row, 'index_name'), "Result should have index_name column" + assert hasattr(first_row, 'type'), "Result should have type column" + assert hasattr(first_row, 'column_name'), "Result should have column_name column" + + # Check that we can find the primary key + pk_found = False + for stat in stats: + if (hasattr(stat, 'index_name') and + stat.index_name and + 'pk' in stat.index_name.lower()): + pk_found = True + break + + assert pk_found, "Primary key should be included in statistics results" + + # Check that we can find the unique index on email + email_index_found = False + for stat in stats: + if 
(hasattr(stat, 'column_name') and + stat.column_name and + stat.column_name.lower() == 'email' and + hasattr(stat, 'non_unique') and + stat.non_unique == 0): # 0 = unique + email_index_found = True + break + + assert email_index_found, "Unique index on email should be included in statistics results" + + finally: + # Clean up happens in test_statistics_cleanup + pass + +def test_statistics_unique_only(cursor, db_connection): + """Test statistics with unique=True to get only unique indexes""" + try: + # Get statistics for only unique indexes + stats = cursor.statistics( + table='stats_test', + schema='pytest_stats_schema', + unique=True + ).fetchall() + + # Verify we got results + assert stats is not None, "statistics() with unique=True should return results" + assert len(stats) > 0, "statistics() with unique=True should return at least one row" + + # All index entries should be for unique indexes (non_unique = 0) + for stat in stats: + if hasattr(stat, 'type') and stat.type != 0: # Skip TABLE_STAT entries + assert hasattr(stat, 'non_unique'), "Index entry should have non_unique column" + assert stat.non_unique == 0, "With unique=True, all indexes should be unique" + + # Count different types of indexes + indexes = [s for s in stats if hasattr(s, 'type') and s.type != 0] + + # We should have multiple unique indexes (PK, unique email, unique name+dept) + assert len(indexes) >= 3, "Should have at least 3 unique index entries" + + finally: + # Clean up happens in test_statistics_cleanup + pass + +def test_statistics_empty_table(cursor, db_connection): + """Test statistics on a table with no data (just schema)""" + try: + # Get statistics for the empty table + stats = cursor.statistics( + table='empty_stats_test', + schema='pytest_stats_schema' + ).fetchall() + + # Should still return metadata about the primary key + assert stats is not None, "statistics() should return results even for empty table" + assert len(stats) > 0, "statistics() should return at least one row for empty table" + + # Check for primary key + pk_found = False + for stat in stats: + if (hasattr(stat, 'index_name') and + stat.index_name and + 'pk' in stat.index_name.lower()): + pk_found = True + break + + assert pk_found, "Primary key should be included in statistics results for empty table" + + finally: + # Clean up happens in test_statistics_cleanup + pass + +def test_statistics_nonexistent(cursor): + """Test statistics with non-existent table name""" + # Use a table name that's highly unlikely to exist + stats = cursor.statistics('nonexistent_table_xyz123').fetchall() + + # Should return empty list, not error + assert isinstance(stats, list), "Should return a list for non-existent table" + assert len(stats) == 0, "Should return empty list for non-existent table" + +def test_statistics_result_structure(cursor, db_connection): + """Test the complete structure of statistics result rows""" + try: + # Get statistics for the test table + stats = cursor.statistics( + table='stats_test', + schema='pytest_stats_schema' + ).fetchall() + + # Verify we have results + assert len(stats) > 0, "Should have statistics results" + + # Find a row that's an actual index (not TABLE_STAT) + index_row = None + for stat in stats: + if hasattr(stat, 'type') and stat.type != 0: + index_row = stat + break + + assert index_row is not None, "Should have at least one index row" + + # Check for all required columns + required_columns = [ + 'table_cat', 'table_schem', 'table_name', 'non_unique', + 'index_qualifier', 'index_name', 'type', 
'ordinal_position', + 'column_name', 'asc_or_desc', 'cardinality', 'pages', + 'filter_condition' + ] + + for column in required_columns: + assert hasattr(index_row, column), f"Result missing required column: {column}" + + # Check types of key columns + assert isinstance(index_row.table_name, str), "table_name should be a string" + assert isinstance(index_row.type, int), "type should be an integer" + + # Don't check the actual values of cardinality and pages as they may be NULL + # or driver-dependent, especially for empty tables + + finally: + # Clean up happens in test_statistics_cleanup + pass + +def test_statistics_catalog_filter(cursor, db_connection): + """Test statistics with catalog filter""" + try: + # Get current database name + cursor.execute("SELECT DB_NAME() AS current_db") + current_db = cursor.fetchone().current_db + + # Get statistics with current catalog + stats = cursor.statistics( + table='stats_test', + catalog=current_db, + schema='pytest_stats_schema' + ).fetchall() + + # Verify catalog filter worked + assert len(stats) > 0, "Should find statistics with correct catalog" + + # Verify catalog in results + for stat in stats: + if hasattr(stat, 'table_cat'): + assert stat.table_cat.lower() == current_db.lower(), "Wrong table catalog" + + # Get statistics with non-existent catalog + fake_stats = cursor.statistics( + table='stats_test', + catalog='nonexistent_db_xyz123', + schema='pytest_stats_schema' + ).fetchall() + assert len(fake_stats) == 0, "Should return empty list for non-existent catalog" + + finally: + # Clean up happens in test_statistics_cleanup + pass + +def test_statistics_with_quick_parameter(cursor, db_connection): + """Test statistics with quick parameter variations""" + try: + # Test with quick=True (default) + quick_stats = cursor.statistics( + table='stats_test', + schema='pytest_stats_schema', + quick=True + ).fetchall() + + # Test with quick=False + thorough_stats = cursor.statistics( + table='stats_test', + schema='pytest_stats_schema', + quick=False + ).fetchall() + + # Both should return results, but we can't guarantee behavior differences + # since it depends on the ODBC driver and database system + assert len(quick_stats) > 0, "quick=True should return results" + assert len(thorough_stats) > 0, "quick=False should return results" + + # Just verify that changing the parameter didn't cause errors + + finally: + # Clean up happens in test_statistics_cleanup + pass + +def test_statistics_cleanup(cursor, db_connection): + """Clean up test tables after testing""" + try: + # Drop all test tables + cursor.execute("DROP TABLE IF EXISTS pytest_stats_schema.stats_test") + cursor.execute("DROP TABLE IF EXISTS pytest_stats_schema.empty_stats_test") + + # Drop the test schema + cursor.execute("DROP SCHEMA IF EXISTS pytest_stats_schema") + db_connection.commit() + except Exception as e: + pytest.fail(f"Test cleanup failed: {e}") + +def test_columns_setup(cursor, db_connection): + """Create test tables for columns method testing""" + try: + # Create a test schema for isolation + cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_cols_schema') EXEC('CREATE SCHEMA pytest_cols_schema')") + + # Drop tables if they exist + cursor.execute("DROP TABLE IF EXISTS pytest_cols_schema.columns_test") + cursor.execute("DROP TABLE IF EXISTS pytest_cols_schema.columns_special_test") + + # Create test table with various column types + cursor.execute(""" + CREATE TABLE pytest_cols_schema.columns_test ( + id INT PRIMARY KEY, + name NVARCHAR(100) NOT NULL, 
+ description NVARCHAR(MAX) NULL, + price DECIMAL(10, 2) NULL, + created_date DATETIME DEFAULT GETDATE(), + is_active BIT NOT NULL DEFAULT 1, + binary_data VARBINARY(MAX) NULL, + notes TEXT NULL, + [computed_col] AS (name + ' - ' + CAST(id AS VARCHAR(10))) + ) + """) + + # Create table with special column names and edge cases - fix the problematic column name + cursor.execute(""" + CREATE TABLE pytest_cols_schema.columns_special_test ( + [ID] INT PRIMARY KEY, + [User Name] NVARCHAR(100) NULL, + [Spaces Multiple] VARCHAR(50) NULL, + [123_numeric_start] INT NULL, + [MAX] VARCHAR(20) NULL, -- SQL keyword as column name + [SELECT] INT NULL, -- SQL keyword as column name + [Column.With.Dots] VARCHAR(20) NULL, + [Column/With/Slashes] VARCHAR(20) NULL, + [Column_With_Underscores] VARCHAR(20) NULL -- Changed from problematic nested brackets + ) + """) + + db_connection.commit() + except Exception as e: + pytest.fail(f"Test setup failed: {e}") + +def test_columns_all(cursor, db_connection): + """Test columns returns information about all columns in all tables""" + try: + # First set up our test tables + test_columns_setup(cursor, db_connection) + + # Get all columns (no filters) + cols_cursor = cursor.columns() + cols = cols_cursor.fetchall() + + # Verify we got results + assert cols is not None, "columns() should return results" + assert len(cols) > 0, "columns() should return at least one column" + + # Verify our test tables' columns are in the results + # Use case-insensitive comparison to avoid driver case sensitivity issues + found_test_table = False + for col in cols: + if (hasattr(col, 'table_name') and + col.table_name and + col.table_name.lower() == 'columns_test' and + hasattr(col, 'table_schem') and + col.table_schem and + col.table_schem.lower() == 'pytest_cols_schema'): + found_test_table = True + break + + assert found_test_table, "Test table columns should be included in results" + + # Verify structure of results + first_row = cols[0] + assert hasattr(first_row, 'table_cat'), "Result should have table_cat column" + assert hasattr(first_row, 'table_schem'), "Result should have table_schem column" + assert hasattr(first_row, 'table_name'), "Result should have table_name column" + assert hasattr(first_row, 'column_name'), "Result should have column_name column" + assert hasattr(first_row, 'data_type'), "Result should have data_type column" + assert hasattr(first_row, 'type_name'), "Result should have type_name column" + assert hasattr(first_row, 'column_size'), "Result should have column_size column" + assert hasattr(first_row, 'buffer_length'), "Result should have buffer_length column" + assert hasattr(first_row, 'decimal_digits'), "Result should have decimal_digits column" + assert hasattr(first_row, 'num_prec_radix'), "Result should have num_prec_radix column" + assert hasattr(first_row, 'nullable'), "Result should have nullable column" + assert hasattr(first_row, 'remarks'), "Result should have remarks column" + assert hasattr(first_row, 'column_def'), "Result should have column_def column" + assert hasattr(first_row, 'sql_data_type'), "Result should have sql_data_type column" + assert hasattr(first_row, 'sql_datetime_sub'), "Result should have sql_datetime_sub column" + assert hasattr(first_row, 'char_octet_length'), "Result should have char_octet_length column" + assert hasattr(first_row, 'ordinal_position'), "Result should have ordinal_position column" + assert hasattr(first_row, 'is_nullable'), "Result should have is_nullable column" + + finally: + # Clean up happens in 
test_columns_cleanup + pass + +def test_columns_specific_table(cursor, db_connection): + """Test columns returns information about a specific table""" + try: + # Get columns for the test table + cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema' + ).fetchall() + + # Verify we got results + assert len(cols) == 9, "Should find exactly 9 columns in columns_test" + + # Verify all column names are present (case insensitive) + col_names = [col.column_name.lower() for col in cols] + expected_names = ['id', 'name', 'description', 'price', 'created_date', + 'is_active', 'binary_data', 'notes', 'computed_col'] + + for name in expected_names: + assert name in col_names, f"Column {name} should be in results" + + # Verify details of a specific column (id) + id_col = next(col for col in cols if col.column_name.lower() == 'id') + assert id_col.nullable == 0, "id column should be non-nullable" + assert id_col.ordinal_position == 1, "id should be the first column" + assert id_col.is_nullable == "NO", "is_nullable should be NO for id column" + + # Check data types (but don't assume specific ODBC type codes since they vary by driver) + # Instead check that the type_name is correct + id_type = id_col.type_name.lower() + assert 'int' in id_type, f"id column should be INTEGER type, got {id_type}" + + # Check a nullable column + desc_col = next(col for col in cols if col.column_name.lower() == 'description') + assert desc_col.nullable == 1, "description column should be nullable" + assert desc_col.is_nullable == "YES", "is_nullable should be YES for description column" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_special_chars(cursor, db_connection): + """Test columns with special characters and edge cases""" + try: + # Get columns for the special table + cols = cursor.columns( + table='columns_special_test', + schema='pytest_cols_schema' + ).fetchall() + + # Verify we got results + assert len(cols) == 9, "Should find exactly 9 columns in columns_special_test" + + # Check that special column names are handled correctly + col_names = [col.column_name for col in cols] + + # Create case-insensitive lookup + col_names_lower = [name.lower() if name else None for name in col_names] + + # Check for columns with special characters - note that column names might be + # returned with or without brackets/quotes depending on the driver + assert any('user name' in name.lower() for name in col_names), "Column with spaces should be in results" + assert any('id' == name.lower() for name in col_names), "ID column should be in results" + assert any('123_numeric_start' in name.lower() for name in col_names), "Column starting with numbers should be in results" + assert any('max' == name.lower() for name in col_names), "MAX column should be in results" + assert any('select' == name.lower() for name in col_names), "SELECT column should be in results" + assert any('column.with.dots' in name.lower() for name in col_names), "Column with dots should be in results" + assert any('column/with/slashes' in name.lower() for name in col_names), "Column with slashes should be in results" + assert any('column_with_underscores' in name.lower() for name in col_names), "Column with underscores should be in results" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_specific_column(cursor, db_connection): + """Test columns with specific column filter""" + try: + # Get specific column + cols = cursor.columns( + table='columns_test', + 
schema='pytest_cols_schema',
+            column='name'
+        ).fetchall()
+
+        # Verify we got just one result
+        assert len(cols) == 1, "Should find exactly 1 column named 'name'"
+
+        # Verify column details
+        col = cols[0]
+        assert col.column_name.lower() == 'name', "Column name should be 'name'"
+        assert col.table_name.lower() == 'columns_test', "Table name should be 'columns_test'"
+        assert col.table_schem.lower() == 'pytest_cols_schema', "Schema should be 'pytest_cols_schema'"
+        assert col.nullable == 0, "name column should be non-nullable"
+
+        # Get column using pattern (% wildcard)
+        pattern_cols = cursor.columns(
+            table='columns_test',
+            schema='pytest_cols_schema',
+            column='%date%'
+        ).fetchall()
+
+        # Should find created_date column
+        assert len(pattern_cols) == 1, "Should find 1 column matching '%date%'"
+
+        assert pattern_cols[0].column_name.lower() == 'created_date', "Should find created_date column"
+
+        # Get multiple columns with pattern
+        multi_cols = cursor.columns(
+            table='columns_test',
+            schema='pytest_cols_schema',
+            column='%d%'  # Should match id, description, created_date
+        ).fetchall()
+
+        # At least 3 columns should match this pattern
+        assert len(multi_cols) >= 3, "Should find at least 3 columns matching '%d%'"
+        match_names = [col.column_name.lower() for col in multi_cols]
+        assert 'id' in match_names, "id should match '%d%'"
+        assert 'description' in match_names, "description should match '%d%'"
+        assert 'created_date' in match_names, "created_date should match '%d%'"
+
+    finally:
+        # Clean up happens in test_columns_cleanup
+        pass
+
+def test_columns_with_underscore_pattern(cursor):
+    """Test columns with underscore wildcard pattern"""
+    try:
+        # Get columns with underscore pattern (one character wildcard)
+        # Looking for 'id' (exactly 2 chars)
+        cols = cursor.columns(
+            table='columns_test',
+            schema='pytest_cols_schema',
+            column='__'
+        ).fetchall()
+
+        # Should find 'id' column
+        id_found = False
+        for col in cols:
+            if col.column_name.lower() == 'id' and col.table_name.lower() == 'columns_test':
+                id_found = True
+                break
+
+        assert id_found, "Should find 'id' column with pattern '__'"
+
+        # Try a more complex pattern with both % and _
+        # For example: '%_d%' matches any column with 'd' as the second or later character
+        pattern_cols = cursor.columns(
+            table='columns_test',
+            schema='pytest_cols_schema',
+            column='%_d%'
+        ).fetchall()
+
+        # Should match 'id' and 'created_date'
+        match_names = [col.column_name.lower() for col in pattern_cols
+                       if col.table_name.lower() == 'columns_test']
+
+        # At least 'created_date' should match this pattern
+        assert 'created_date' in match_names, "created_date should match '%_d%'"
+
+    finally:
+        # Clean up happens in test_columns_cleanup
+        pass
+
+def 
test_columns_data_types(cursor): + """Test columns returns correct data type information""" + try: + # Get all columns from test table + cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema' + ).fetchall() + + # Create a dictionary mapping column names to their details + col_dict = {col.column_name.lower(): col for col in cols} + + # Check data types by name (case insensitive checks) + # Note: We're checking type_name as a string to avoid SQL type code inconsistencies + # between drivers + + # INT column + assert 'int' in col_dict['id'].type_name.lower(), "id should be INT type" + + # NVARCHAR column + assert any(name in col_dict['name'].type_name.lower() + for name in ['nvarchar', 'varchar', 'char', 'wchar']), "name should be NVARCHAR type" + + # DECIMAL column + assert any(name in col_dict['price'].type_name.lower() + for name in ['decimal', 'numeric', 'money']), "price should be DECIMAL type" + + # BIT column + assert any(name in col_dict['is_active'].type_name.lower() + for name in ['bit', 'boolean']), "is_active should be BIT type" + + # TEXT column + assert any(name in col_dict['notes'].type_name.lower() + for name in ['text', 'char', 'varchar']), "notes should be TEXT type" + + # Check nullable flag + assert col_dict['id'].nullable == 0, "id should be non-nullable" + assert col_dict['description'].nullable == 1, "description should be nullable" + + # Check column size + assert col_dict['name'].column_size == 100, "name should have size 100" + + # Check decimal digits for numeric type + assert col_dict['price'].decimal_digits == 2, "price should have 2 decimal digits" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_nonexistent(cursor): + """Test columns with non-existent table or column""" + # Test with non-existent table + table_cols = cursor.columns(table='nonexistent_table_xyz123').fetchall() + assert len(table_cols) == 0, "Should return empty list for non-existent table" + + # Test with non-existent column in existing table + col_cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema', + column='nonexistent_column_xyz123' + ).fetchall() + assert len(col_cols) == 0, "Should return empty list for non-existent column" + + # Test with non-existent schema + schema_cols = cursor.columns( + table='columns_test', + schema='nonexistent_schema_xyz123' + ).fetchall() + assert len(schema_cols) == 0, "Should return empty list for non-existent schema" + +def test_columns_catalog_filter(cursor): + """Test columns with catalog filter""" + try: + # Get current database name + cursor.execute("SELECT DB_NAME() AS current_db") + current_db = cursor.fetchone().current_db + + # Get columns with current catalog + cols = cursor.columns( + table='columns_test', + catalog=current_db, + schema='pytest_cols_schema' + ).fetchall() + + # Verify catalog filter worked + assert len(cols) > 0, "Should find columns with correct catalog" + + # Check catalog in results + for col in cols: + # Some drivers might return None for catalog + if col.table_cat is not None: + assert col.table_cat.lower() == current_db.lower(), "Wrong table catalog" + + # Test with non-existent catalog + fake_cols = cursor.columns( + table='columns_test', + catalog='nonexistent_db_xyz123', + schema='pytest_cols_schema' + ).fetchall() + assert len(fake_cols) == 0, "Should return empty list for non-existent catalog" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_schema_pattern(cursor): + """Test columns with schema name 
pattern""" + try: + # Get columns with schema pattern + cols = cursor.columns( + table='columns_test', + schema='pytest_%' + ).fetchall() + + # Should find our test table columns + test_cols = [col for col in cols if col.table_name.lower() == 'columns_test'] + assert len(test_cols) > 0, "Should find columns using schema pattern" + + # Try a more specific pattern + specific_cols = cursor.columns( + table='columns_test', + schema='pytest_cols%' + ).fetchall() + + # Should still find our test table columns + test_cols = [col for col in specific_cols if col.table_name.lower() == 'columns_test'] + assert len(test_cols) > 0, "Should find columns using specific schema pattern" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_table_pattern(cursor): + """Test columns with table name pattern""" + try: + # Get columns with table pattern + cols = cursor.columns( + table='columns_%', + schema='pytest_cols_schema' + ).fetchall() + + # Should find columns from both test tables + tables_found = set() + for col in cols: + if col.table_name: + tables_found.add(col.table_name.lower()) + + assert 'columns_test' in tables_found, "Should find columns_test with pattern columns_%" + assert 'columns_special_test' in tables_found, "Should find columns_special_test with pattern columns_%" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_ordinal_position(cursor): + """Test ordinal_position is correct in columns results""" + try: + # Get columns for the test table + cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema' + ).fetchall() + + # Sort by ordinal position + sorted_cols = sorted(cols, key=lambda col: col.ordinal_position) + + # Verify positions are consecutive starting from 1 + for i, col in enumerate(sorted_cols, 1): + assert col.ordinal_position == i, f"Column {col.column_name} should have ordinal_position {i}" + + # First column should be id (primary key) + assert sorted_cols[0].column_name.lower() == 'id', "First column should be id" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_cleanup(cursor, db_connection): + """Clean up test tables after testing""" + try: + # Drop all test tables + cursor.execute("DROP TABLE IF EXISTS pytest_cols_schema.columns_test") + cursor.execute("DROP TABLE IF EXISTS pytest_cols_schema.columns_special_test") + + # Drop the test schema + cursor.execute("DROP SCHEMA IF EXISTS pytest_cols_schema") + db_connection.commit() + except Exception as e: + pytest.fail(f"Test cleanup failed: {e}") + def test_close(db_connection): """Test closing the cursor""" try: