diff --git a/mssql_python/__init__.py b/mssql_python/__init__.py index ce34a7f3..bc7ea4b8 100644 --- a/mssql_python/__init__.py +++ b/mssql_python/__init__.py @@ -180,3 +180,28 @@ def _custom_setattr(name, value): # Replace the module's __setattr__ with our custom version sys.modules[__name__].__setattr__ = _custom_setattr + + +# Export SQL constants at module level +SQL_CHAR = ConstantsDDBC.SQL_CHAR.value +SQL_VARCHAR = ConstantsDDBC.SQL_VARCHAR.value +SQL_LONGVARCHAR = ConstantsDDBC.SQL_LONGVARCHAR.value +SQL_WCHAR = ConstantsDDBC.SQL_WCHAR.value +SQL_WVARCHAR = ConstantsDDBC.SQL_WVARCHAR.value +SQL_WLONGVARCHAR = ConstantsDDBC.SQL_WLONGVARCHAR.value +SQL_DECIMAL = ConstantsDDBC.SQL_DECIMAL.value +SQL_NUMERIC = ConstantsDDBC.SQL_NUMERIC.value +SQL_BIT = ConstantsDDBC.SQL_BIT.value +SQL_TINYINT = ConstantsDDBC.SQL_TINYINT.value +SQL_SMALLINT = ConstantsDDBC.SQL_SMALLINT.value +SQL_INTEGER = ConstantsDDBC.SQL_INTEGER.value +SQL_BIGINT = ConstantsDDBC.SQL_BIGINT.value +SQL_REAL = ConstantsDDBC.SQL_REAL.value +SQL_FLOAT = ConstantsDDBC.SQL_FLOAT.value +SQL_DOUBLE = ConstantsDDBC.SQL_DOUBLE.value +SQL_BINARY = ConstantsDDBC.SQL_BINARY.value +SQL_VARBINARY = ConstantsDDBC.SQL_VARBINARY.value +SQL_LONGVARBINARY = ConstantsDDBC.SQL_LONGVARBINARY.value +SQL_DATE = ConstantsDDBC.SQL_DATE.value +SQL_TIME = ConstantsDDBC.SQL_TIME.value +SQL_TIMESTAMP = ConstantsDDBC.SQL_TIMESTAMP.value diff --git a/mssql_python/constants.py b/mssql_python/constants.py index 3d6b4732..61380e1f 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -124,9 +124,63 @@ class ConstantsDDBC(Enum): SQL_FETCH_ABSOLUTE = 5 SQL_FETCH_RELATIVE = 6 SQL_FETCH_BOOKMARK = 8 + SQL_SCOPE_CURROW = 0 + SQL_BEST_ROWID = 1 + SQL_ROWVER = 2 + SQL_NO_NULLS = 0 + SQL_NULLABLE_UNKNOWN = 2 + SQL_INDEX_UNIQUE = 0 + SQL_INDEX_ALL = 1 + SQL_QUICK = 0 + SQL_ENSURE = 1 class AuthType(Enum): """Constants for authentication types""" INTERACTIVE = "activedirectoryinteractive" DEVICE_CODE = "activedirectorydevicecode" - DEFAULT = "activedirectorydefault" \ No newline at end of file + DEFAULT = "activedirectorydefault" + +class SQLTypes: + """Constants for valid SQL data types to use with setinputsizes""" + + @classmethod + def get_valid_types(cls) -> set: + """Returns a set of all valid SQL type constants""" + + return { + ConstantsDDBC.SQL_CHAR.value, ConstantsDDBC.SQL_VARCHAR.value, + ConstantsDDBC.SQL_LONGVARCHAR.value, ConstantsDDBC.SQL_WCHAR.value, + ConstantsDDBC.SQL_WVARCHAR.value, ConstantsDDBC.SQL_WLONGVARCHAR.value, + ConstantsDDBC.SQL_DECIMAL.value, ConstantsDDBC.SQL_NUMERIC.value, + ConstantsDDBC.SQL_BIT.value, ConstantsDDBC.SQL_TINYINT.value, + ConstantsDDBC.SQL_SMALLINT.value, ConstantsDDBC.SQL_INTEGER.value, + ConstantsDDBC.SQL_BIGINT.value, ConstantsDDBC.SQL_REAL.value, + ConstantsDDBC.SQL_FLOAT.value, ConstantsDDBC.SQL_DOUBLE.value, + ConstantsDDBC.SQL_BINARY.value, ConstantsDDBC.SQL_VARBINARY.value, + ConstantsDDBC.SQL_LONGVARBINARY.value, ConstantsDDBC.SQL_DATE.value, + ConstantsDDBC.SQL_TIME.value, ConstantsDDBC.SQL_TIMESTAMP.value, + ConstantsDDBC.SQL_GUID.value + } + + # Could also add category methods for convenience + @classmethod + def get_string_types(cls) -> set: + """Returns a set of string SQL type constants""" + + return { + ConstantsDDBC.SQL_CHAR.value, ConstantsDDBC.SQL_VARCHAR.value, + ConstantsDDBC.SQL_LONGVARCHAR.value, ConstantsDDBC.SQL_WCHAR.value, + ConstantsDDBC.SQL_WVARCHAR.value, ConstantsDDBC.SQL_WLONGVARCHAR.value + } + + @classmethod + def get_numeric_types(cls) -> set: + """Returns a set 
of numeric SQL type constants""" + + return { + ConstantsDDBC.SQL_DECIMAL.value, ConstantsDDBC.SQL_NUMERIC.value, + ConstantsDDBC.SQL_BIT.value, ConstantsDDBC.SQL_TINYINT.value, + ConstantsDDBC.SQL_SMALLINT.value, ConstantsDDBC.SQL_INTEGER.value, + ConstantsDDBC.SQL_BIGINT.value, ConstantsDDBC.SQL_REAL.value, + ConstantsDDBC.SQL_FLOAT.value, ConstantsDDBC.SQL_DOUBLE.value + } \ No newline at end of file diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 4a1e6a91..b6b309cf 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -11,8 +11,9 @@ import decimal import uuid import datetime -from typing import List, Union -from mssql_python.constants import ConstantsDDBC as ddbc_sql_const +import warnings +from typing import List, Union, Any +from mssql_python.constants import ConstantsDDBC as ddbc_sql_const, SQLTypes from mssql_python.helpers import check_error, log from mssql_python import ddbc_bindings from mssql_python.exceptions import InterfaceError, NotSupportedError, ProgrammingError @@ -54,6 +55,16 @@ class Cursor: setoutputsize(size, column=None) -> None. """ + # TODO(jathakkar): Thread safety considerations + # The cursor class contains methods that are not thread-safe due to: + # 1. Methods that mutate cursor state (_reset_cursor, self.description, etc.) + # 2. Methods that call ODBC functions with shared handles (self.hstmt) + # + # These methods should be properly synchronized or redesigned when implementing + # async functionality to prevent race conditions and data corruption. + # Consider using locks, redesigning for immutability, or ensuring + # cursor objects are never shared across threads. + def __init__(self, connection, timeout: int = 0) -> None: """ Initialize the cursor with a database connection. @@ -63,6 +74,7 @@ def __init__(self, connection, timeout: int = 0) -> None: """ self._connection = connection # Store as private attribute self._timeout = timeout + self._inputsizes = None # self.connection.autocommit = False self.hstmt = None self._initialize_cursor() @@ -529,6 +541,97 @@ def _check_closed(self): ddbc_error="" ) + def setinputsizes(self, sizes: List[Union[int, tuple]]) -> None: + """ + Sets the type information to be used for parameters in execute and executemany. + + This method can be used to explicitly declare the types and sizes of query parameters. + For example: + + sql = "INSERT INTO product (item, price) VALUES (?, ?)" + params = [('bicycle', 499.99), ('ham', 17.95)] + # specify that parameters are for NVARCHAR(50) and DECIMAL(18,4) columns + cursor.setinputsizes([(SQL_WVARCHAR, 50, 0), (SQL_DECIMAL, 18, 4)]) + cursor.executemany(sql, params) + + Args: + sizes: A sequence of tuples, one for each parameter. Each tuple contains + (sql_type, size, decimal_digits) where size and decimal_digits are optional. + """ + + # Get valid SQL types from centralized constants + valid_sql_types = SQLTypes.get_valid_types() + + self._inputsizes = [] + + if sizes: + for size_info in sizes: + if isinstance(size_info, tuple): + # Handle tuple format (sql_type, size, decimal_digits) + if len(size_info) == 1: + sql_type = size_info[0] + column_size = 0 + decimal_digits = 0 + elif len(size_info) == 2: + sql_type, column_size = size_info + decimal_digits = 0 + elif len(size_info) >= 3: + sql_type, column_size, decimal_digits = size_info + + # Validate SQL type + if not isinstance(sql_type, int) or sql_type not in valid_sql_types: + raise ValueError(f"Invalid SQL type: {sql_type}. 
Must be a valid SQL type constant.") + + # Validate size and precision + if not isinstance(column_size, int) or column_size < 0: + raise ValueError(f"Invalid column size: {column_size}. Must be a non-negative integer.") + + if not isinstance(decimal_digits, int) or decimal_digits < 0: + raise ValueError(f"Invalid decimal digits: {decimal_digits}. Must be a non-negative integer.") + + self._inputsizes.append((sql_type, column_size, decimal_digits)) + else: + # Handle single value (just sql_type) + sql_type = size_info + + # Validate SQL type + if not isinstance(sql_type, int) or sql_type not in valid_sql_types: + raise ValueError(f"Invalid SQL type: {sql_type}. Must be a valid SQL type constant.") + + self._inputsizes.append((sql_type, 0, 0)) + + def _reset_inputsizes(self): + """Reset input sizes after execution""" + self._inputsizes = None + + def _get_c_type_for_sql_type(self, sql_type: int) -> int: + """Map SQL type to appropriate C type for parameter binding""" + sql_to_c_type = { + ddbc_sql_const.SQL_CHAR.value: ddbc_sql_const.SQL_C_CHAR.value, + ddbc_sql_const.SQL_VARCHAR.value: ddbc_sql_const.SQL_C_CHAR.value, + ddbc_sql_const.SQL_LONGVARCHAR.value: ddbc_sql_const.SQL_C_CHAR.value, + ddbc_sql_const.SQL_WCHAR.value: ddbc_sql_const.SQL_C_WCHAR.value, + ddbc_sql_const.SQL_WVARCHAR.value: ddbc_sql_const.SQL_C_WCHAR.value, + ddbc_sql_const.SQL_WLONGVARCHAR.value: ddbc_sql_const.SQL_C_WCHAR.value, + ddbc_sql_const.SQL_DECIMAL.value: ddbc_sql_const.SQL_C_NUMERIC.value, + ddbc_sql_const.SQL_NUMERIC.value: ddbc_sql_const.SQL_C_NUMERIC.value, + ddbc_sql_const.SQL_BIT.value: ddbc_sql_const.SQL_C_BIT.value, + ddbc_sql_const.SQL_TINYINT.value: ddbc_sql_const.SQL_C_TINYINT.value, + ddbc_sql_const.SQL_SMALLINT.value: ddbc_sql_const.SQL_C_SHORT.value, + ddbc_sql_const.SQL_INTEGER.value: ddbc_sql_const.SQL_C_LONG.value, + ddbc_sql_const.SQL_BIGINT.value: ddbc_sql_const.SQL_C_SBIGINT.value, + ddbc_sql_const.SQL_REAL.value: ddbc_sql_const.SQL_C_FLOAT.value, + ddbc_sql_const.SQL_FLOAT.value: ddbc_sql_const.SQL_C_DOUBLE.value, + ddbc_sql_const.SQL_DOUBLE.value: ddbc_sql_const.SQL_C_DOUBLE.value, + ddbc_sql_const.SQL_BINARY.value: ddbc_sql_const.SQL_C_BINARY.value, + ddbc_sql_const.SQL_VARBINARY.value: ddbc_sql_const.SQL_C_BINARY.value, + ddbc_sql_const.SQL_LONGVARBINARY.value: ddbc_sql_const.SQL_C_BINARY.value, + ddbc_sql_const.SQL_DATE.value: ddbc_sql_const.SQL_C_TYPE_DATE.value, + ddbc_sql_const.SQL_TIME.value: ddbc_sql_const.SQL_C_TYPE_TIME.value, + ddbc_sql_const.SQL_TIMESTAMP.value: ddbc_sql_const.SQL_C_TYPE_TIMESTAMP.value, + } + return sql_to_c_type.get(sql_type, ddbc_sql_const.SQL_C_DEFAULT.value) + def _create_parameter_types_list(self, parameter, param_info, parameters_list, i, min_val=None, max_val=None): """ Maps parameter types for the given parameter. @@ -538,19 +641,51 @@ def _create_parameter_types_list(self, parameter, param_info, parameters_list, i paraminfo. 
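+
+        When setinputsizes() has supplied an explicit (sql_type, size, decimal_digits)
+        entry for position i, that entry takes precedence over automatic inference.
+        Illustrative sketch (assumes the SQL_WVARCHAR constant exported by mssql_python):
+
+            cursor.setinputsizes([(SQL_WVARCHAR, 50, 0)])
+            # the first parameter is then bound as SQL_WVARCHAR / SQL_C_WCHAR with
+            # columnSize=50 instead of whatever _map_sql_type would have inferred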
""" paraminfo = param_info() - sql_type, c_type, column_size, decimal_digits, is_dae = self._map_sql_type( - parameter, parameters_list, i, min_val=min_val, max_val=max_val - ) + + # Check if we have explicit type information from setinputsizes + if self._inputsizes and i < len(self._inputsizes): + # Use explicit type information + sql_type, column_size, decimal_digits = self._inputsizes[i] + + # Default is_dae to False for explicit types, but set to True for large strings/binary + is_dae = False + + if parameter is None: + # For NULL parameters, always use SQL_C_DEFAULT regardless of SQL type + c_type = ddbc_sql_const.SQL_C_DEFAULT.value + else: + # For non-NULL parameters, determine the appropriate C type based on SQL type + c_type = self._get_c_type_for_sql_type(sql_type) + + # Check if this should be a DAE (data at execution) parameter + # For string types with large column sizes + if isinstance(parameter, str) and column_size > MAX_INLINE_CHAR: + is_dae = True + # For binary types with large column sizes + elif isinstance(parameter, (bytes, bytearray)) and column_size > 8000: + is_dae = True + + # Sanitize precision/scale for numeric types + if sql_type in (ddbc_sql_const.SQL_DECIMAL.value, ddbc_sql_const.SQL_NUMERIC.value): + column_size = max(1, min(int(column_size) if column_size > 0 else 18, 38)) + decimal_digits = min(max(0, decimal_digits), column_size) + + else: + # Fall back to automatic type inference + sql_type, c_type, column_size, decimal_digits, is_dae = self._map_sql_type( + parameter, parameters_list, i, min_val=min_val, max_val=max_val + ) + paraminfo.paramCType = c_type paraminfo.paramSQLType = sql_type paraminfo.inputOutputType = ddbc_sql_const.SQL_PARAM_INPUT.value paraminfo.columnSize = column_size paraminfo.decimalDigits = decimal_digits paraminfo.isDAE = is_dae - + if is_dae: paraminfo.dataPtr = parameter # Will be converted to py::object* in C++ - + return paraminfo def _initialize_description(self, column_metadata=None): @@ -802,6 +937,16 @@ def execute( parameters = list(parameters) + # Validate that inputsizes matches parameter count if both are present + if parameters and self._inputsizes: + if len(self._inputsizes) != len(parameters): + + warnings.warn( + f"Number of input sizes ({len(self._inputsizes)}) does not match " + f"number of parameters ({len(parameters)}). This may lead to unexpected behavior.", + Warning + ) + if parameters: for i, param in enumerate(parameters): paraminfo = self._create_parameter_types_list( @@ -879,27 +1024,482 @@ def execute( self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt) self._clear_rownumber() + # After successful execution, initialize description if there are results + column_metadata = [] + try: + ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) + self._initialize_description(column_metadata) + except Exception as e: + # If describe fails, it's likely there are no results (e.g., for INSERT) + self.description = None + + self._reset_inputsizes() # Reset input sizes after execution # Return self for method chaining return self - def _transpose_rowwise_to_columnwise(self, seq_of_parameters: list) -> list: + def _prepare_metadata_result_set(self, column_metadata=None, fallback_description=None, specialized_mapping=None): """ - Convert list of rows (row-wise) into list of columns (column-wise), - for array binding via ODBC. + Prepares a metadata result set by: + 1. Retrieving column metadata if not provided + 2. Initializing the description attribute + 3. Setting up column name mappings + 4. 
Creating wrapper fetch methods with column mapping support + Args: - seq_of_parameters: Sequence of sequences or mappings of parameters. + column_metadata (list, optional): Pre-fetched column metadata. + If None, it will be retrieved. + fallback_description (list, optional): Fallback description to use if + metadata retrieval fails. + specialized_mapping (dict, optional): Custom column mapping for special cases. + + Returns: + Cursor: Self, for method chaining """ - if not seq_of_parameters: - return [] + # Retrieve column metadata if not provided + if column_metadata is None: + column_metadata = [] + try: + ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) + except InterfaceError as e: + log('error', f"Driver interface error during metadata retrieval: {e}") + except Exception as e: + # Log the exception with appropriate context + log('error', f"Failed to retrieve column metadata: {e}. Using standard ODBC column definitions instead.") + + # Initialize the description attribute with the column metadata + self._initialize_description(column_metadata) + + # Use fallback description if provided and current description is empty + if not self.description and fallback_description: + self.description = fallback_description + + # Define column names in ODBC standard order + self._column_map = {} + for i, (name, *_) in enumerate(self.description): + # Add standard name + self._column_map[name] = i + # Add lowercase alias + self._column_map[name.lower()] = i + + # If specialized mapping is provided, handle it differently + if specialized_mapping: + # Define specialized fetch methods that use the custom mapping + def fetchone_with_specialized_mapping(): + row = self._original_fetchone() + if row is not None: + merged_map = getattr(row, '_column_map', {}).copy() + merged_map.update(specialized_mapping) + row._column_map = merged_map + return row + + def fetchmany_with_specialized_mapping(size=None): + rows = self._original_fetchmany(size) + for row in rows: + merged_map = getattr(row, '_column_map', {}).copy() + merged_map.update(specialized_mapping) + row._column_map = merged_map + return rows + + def fetchall_with_specialized_mapping(): + rows = self._original_fetchall() + for row in rows: + merged_map = getattr(row, '_column_map', {}).copy() + merged_map.update(specialized_mapping) + row._column_map = merged_map + return rows + + # Save original fetch methods + if not hasattr(self, '_original_fetchone'): + self._original_fetchone = self.fetchone + self._original_fetchmany = self.fetchmany + self._original_fetchall = self.fetchall + + # Use specialized mapping methods + self.fetchone = fetchone_with_specialized_mapping + self.fetchmany = fetchmany_with_specialized_mapping + self.fetchall = fetchall_with_specialized_mapping + else: + # Standard column mapping + # Remember original fetch methods (store only once) + if not hasattr(self, '_original_fetchone'): + self._original_fetchone = self.fetchone + self._original_fetchmany = self.fetchmany + self._original_fetchall = self.fetchall + + # Create wrapper fetch methods that add column mappings + def fetchone_with_mapping(): + row = self._original_fetchone() + if row is not None: + row._column_map = self._column_map + return row + + def fetchmany_with_mapping(size=None): + rows = self._original_fetchmany(size) + for row in rows: + row._column_map = self._column_map + return rows + + def fetchall_with_mapping(): + rows = self._original_fetchall() + for row in rows: + row._column_map = self._column_map + return rows + + # Replace fetch methods 
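+            # with the mapping-aware wrappers. The originals remain available in
+            # self._original_fetch* (stored only once above), so repeated metadata
+            # calls keep wrapping the real fetch implementations rather than
+            # stacking wrappers on top of each other.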
+ self.fetchone = fetchone_with_mapping + self.fetchmany = fetchmany_with_mapping + self.fetchall = fetchall_with_mapping + + # Return the cursor itself for method chaining + return self + + def getTypeInfo(self, sqlType=None): + """ + Executes SQLGetTypeInfo and creates a result set with information about + the specified data type or all data types supported by the ODBC driver if not specified. + """ + self._check_closed() + self._reset_cursor() + + sql_all_types = 0 # SQL_ALL_TYPES = 0 + + try: + # Get information about data types + ret = ddbc_bindings.DDBCSQLGetTypeInfo( + self.hstmt, + sqlType if sqlType is not None else sql_all_types + ) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) + + # Use the helper method to prepare the result set + return self._prepare_metadata_result_set() + except Exception as e: + self._reset_cursor() + raise e + + def procedures(self, procedure=None, catalog=None, schema=None): + """ + Executes SQLProcedures and creates a result set of information about procedures in the data source. + + Args: + procedure (str, optional): Procedure name pattern. Default is None (all procedures). + catalog (str, optional): Catalog name pattern. Default is None (current catalog). + schema (str, optional): Schema name pattern. Default is None (all schemas). + """ + self._check_closed() + self._reset_cursor() + + # Call the SQLProcedures function + retcode = ddbc_bindings.DDBCSQLProcedures(self.hstmt, catalog, schema, procedure) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode) + + # Define fallback description for procedures + fallback_description = [ + ("procedure_cat", str, None, 128, 128, 0, True), + ("procedure_schem", str, None, 128, 128, 0, True), + ("procedure_name", str, None, 128, 128, 0, False), + ("num_input_params", int, None, 10, 10, 0, True), + ("num_output_params", int, None, 10, 10, 0, True), + ("num_result_sets", int, None, 10, 10, 0, True), + ("remarks", str, None, 254, 254, 0, True), + ("procedure_type", int, None, 10, 10, 0, False) + ] + + # Use the helper method to prepare the result set + return self._prepare_metadata_result_set(fallback_description=fallback_description) + + def primaryKeys(self, table, catalog=None, schema=None): + """ + Creates a result set of column names that make up the primary key for a table + by executing the SQLPrimaryKeys function. + + Args: + table (str): The name of the table + catalog (str, optional): The catalog name (database). Defaults to None. + schema (str, optional): The schema name. Defaults to None. 
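+
+        Example (illustrative; assumes a 'Customers' table exists):
+            cursor.primaryKeys('Customers', schema='dbo')
+            for row in cursor.fetchall():
+                # one row per key column: (table_cat, table_schem, table_name,
+                #                          column_name, key_seq, pk_name)
+                print(row)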
+ """ + self._check_closed() + self._reset_cursor() + + if not table: + raise ProgrammingError("Table name must be specified", "HY000") + + # Call the SQLPrimaryKeys function + retcode = ddbc_bindings.DDBCSQLPrimaryKeys(self.hstmt, catalog, schema, table) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode) + + # Define fallback description for primary keys + fallback_description = [ + ("table_cat", str, None, 128, 128, 0, True), + ("table_schem", str, None, 128, 128, 0, True), + ("table_name", str, None, 128, 128, 0, False), + ("column_name", str, None, 128, 128, 0, False), + ("key_seq", int, None, 10, 10, 0, False), + ("pk_name", str, None, 128, 128, 0, True) + ] + + # Use the helper method to prepare the result set + return self._prepare_metadata_result_set(fallback_description=fallback_description) + + def foreignKeys(self, table=None, catalog=None, schema=None, foreignTable=None, foreignCatalog=None, foreignSchema=None): + """ + Executes the SQLForeignKeys function and creates a result set of column names that are foreign keys. + + This function returns: + 1. Foreign keys in the specified table that reference primary keys in other tables, OR + 2. Foreign keys in other tables that reference the primary key in the specified table + """ + self._check_closed() + self._reset_cursor() + + # Check if we have at least one table specified + if table is None and foreignTable is None: + raise ProgrammingError("Either table or foreignTable must be specified", "HY000") + + # Call the SQLForeignKeys function + retcode = ddbc_bindings.DDBCSQLForeignKeys( + self.hstmt, + foreignCatalog, foreignSchema, foreignTable, + catalog, schema, table + ) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode) + + # Define fallback description for foreign keys + fallback_description = [ + ("pktable_cat", str, None, 128, 128, 0, True), + ("pktable_schem", str, None, 128, 128, 0, True), + ("pktable_name", str, None, 128, 128, 0, False), + ("pkcolumn_name", str, None, 128, 128, 0, False), + ("fktable_cat", str, None, 128, 128, 0, True), + ("fktable_schem", str, None, 128, 128, 0, True), + ("fktable_name", str, None, 128, 128, 0, False), + ("fkcolumn_name", str, None, 128, 128, 0, False), + ("key_seq", int, None, 10, 10, 0, False), + ("update_rule", int, None, 10, 10, 0, False), + ("delete_rule", int, None, 10, 10, 0, False), + ("fk_name", str, None, 128, 128, 0, True), + ("pk_name", str, None, 128, 128, 0, True), + ("deferrability", int, None, 10, 10, 0, False) + ] + + # Use the helper method to prepare the result set + return self._prepare_metadata_result_set(fallback_description=fallback_description) - num_params = len(seq_of_parameters[0]) - columnwise = [[] for _ in range(num_params)] + def rowIdColumns(self, table, catalog=None, schema=None, nullable=True): + """ + Executes SQLSpecialColumns with SQL_BEST_ROWID which creates a result set of + columns that uniquely identify a row. 
+ """ + self._check_closed() + self._reset_cursor() + + if not table: + raise ProgrammingError("Table name must be specified", "HY000") + + # Set the identifier type and options + identifier_type = ddbc_sql_const.SQL_BEST_ROWID.value + scope = ddbc_sql_const.SQL_SCOPE_CURROW.value + nullable_flag = ddbc_sql_const.SQL_NULLABLE.value if nullable else ddbc_sql_const.SQL_NO_NULLS.value + + # Call the SQLSpecialColumns function + retcode = ddbc_bindings.DDBCSQLSpecialColumns( + self.hstmt, identifier_type, catalog, schema, table, scope, nullable_flag + ) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode) + + # Define fallback description for special columns + fallback_description = [ + ("scope", int, None, 10, 10, 0, False), + ("column_name", str, None, 128, 128, 0, False), + ("data_type", int, None, 10, 10, 0, False), + ("type_name", str, None, 128, 128, 0, False), + ("column_size", int, None, 10, 10, 0, False), + ("buffer_length", int, None, 10, 10, 0, False), + ("decimal_digits", int, None, 10, 10, 0, True), + ("pseudo_column", int, None, 10, 10, 0, False) + ] + + # Use the helper method to prepare the result set + return self._prepare_metadata_result_set(fallback_description=fallback_description) + + def rowVerColumns(self, table, catalog=None, schema=None, nullable=True): + """ + Executes SQLSpecialColumns with SQL_ROWVER which creates a result set of + columns that are automatically updated when any value in the row is updated. + """ + self._check_closed() + self._reset_cursor() + + if not table: + raise ProgrammingError("Table name must be specified", "HY000") + + # Set the identifier type and options + identifier_type = ddbc_sql_const.SQL_ROWVER.value + scope = ddbc_sql_const.SQL_SCOPE_CURROW.value + nullable_flag = ddbc_sql_const.SQL_NULLABLE.value if nullable else ddbc_sql_const.SQL_NO_NULLS.value + + # Call the SQLSpecialColumns function + retcode = ddbc_bindings.DDBCSQLSpecialColumns( + self.hstmt, identifier_type, catalog, schema, table, scope, nullable_flag + ) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode) + + # Same fallback description as rowIdColumns + fallback_description = [ + ("scope", int, None, 10, 10, 0, False), + ("column_name", str, None, 128, 128, 0, False), + ("data_type", int, None, 10, 10, 0, False), + ("type_name", str, None, 128, 128, 0, False), + ("column_size", int, None, 10, 10, 0, False), + ("buffer_length", int, None, 10, 10, 0, False), + ("decimal_digits", int, None, 10, 10, 0, True), + ("pseudo_column", int, None, 10, 10, 0, False) + ] + + # Use the helper method to prepare the result set + return self._prepare_metadata_result_set(fallback_description=fallback_description) + + def statistics(self, table: str, catalog: str = None, schema: str = None, unique: bool = False, quick: bool = True) -> 'Cursor': + """ + Creates a result set of statistics about a single table and the indexes associated + with the table by executing SQLStatistics. 
+ """ + self._check_closed() + self._reset_cursor() + + if not table: + raise ProgrammingError("Table name is required", "HY000") + + # Set unique and quick flags + unique_option = ddbc_sql_const.SQL_INDEX_UNIQUE.value if unique else ddbc_sql_const.SQL_INDEX_ALL.value + reserved_option = ddbc_sql_const.SQL_QUICK.value if quick else ddbc_sql_const.SQL_ENSURE.value + + # Call the SQLStatistics function + retcode = ddbc_bindings.DDBCSQLStatistics( + self.hstmt, catalog, schema, table, unique_option, reserved_option + ) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode) + + # Define fallback description for statistics + fallback_description = [ + ("table_cat", str, None, 128, 128, 0, True), + ("table_schem", str, None, 128, 128, 0, True), + ("table_name", str, None, 128, 128, 0, False), + ("non_unique", bool, None, 1, 1, 0, False), + ("index_qualifier", str, None, 128, 128, 0, True), + ("index_name", str, None, 128, 128, 0, True), + ("type", int, None, 10, 10, 0, False), + ("ordinal_position", int, None, 10, 10, 0, False), + ("column_name", str, None, 128, 128, 0, True), + ("asc_or_desc", str, None, 1, 1, 0, True), + ("cardinality", int, None, 20, 20, 0, True), + ("pages", int, None, 20, 20, 0, True), + ("filter_condition", str, None, 128, 128, 0, True) + ] + + # Use the helper method to prepare the result set + return self._prepare_metadata_result_set(fallback_description=fallback_description) + + def columns(self, table=None, catalog=None, schema=None, column=None): + """ + Creates a result set of column information in the specified tables + using the SQLColumns function. + """ + self._check_closed() + self._reset_cursor() + + # Call the SQLColumns function + retcode = ddbc_bindings.DDBCSQLColumns( + self.hstmt, catalog, schema, table, column + ) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, retcode) + + # Define fallback description for columns + fallback_description = [ + ("table_cat", str, None, 128, 128, 0, True), + ("table_schem", str, None, 128, 128, 0, True), + ("table_name", str, None, 128, 128, 0, False), + ("column_name", str, None, 128, 128, 0, False), + ("data_type", int, None, 10, 10, 0, False), + ("type_name", str, None, 128, 128, 0, False), + ("column_size", int, None, 10, 10, 0, True), + ("buffer_length", int, None, 10, 10, 0, True), + ("decimal_digits", int, None, 10, 10, 0, True), + ("num_prec_radix", int, None, 10, 10, 0, True), + ("nullable", int, None, 10, 10, 0, False), + ("remarks", str, None, 254, 254, 0, True), + ("column_def", str, None, 254, 254, 0, True), + ("sql_data_type", int, None, 10, 10, 0, False), + ("sql_datetime_sub", int, None, 10, 10, 0, True), + ("char_octet_length", int, None, 10, 10, 0, True), + ("ordinal_position", int, None, 10, 10, 0, False), + ("is_nullable", str, None, 254, 254, 0, True) + ] + + # Use the helper method to prepare the result set + return self._prepare_metadata_result_set(fallback_description=fallback_description) + + @staticmethod + def _select_best_sample_value(column): + """ + Selects the most representative non-null value from a column for type inference. + + This is used during executemany() to infer SQL/C types based on actual data, + preferring a non-null value that is not the first row to avoid bias from placeholder defaults. + + Args: + column: List of values in the column. 
+ """ + non_nulls = [v for v in column if v is not None] + if not non_nulls: + return None + if all(isinstance(v, int) for v in non_nulls): + # Pick the value with the widest range (min/max) + return max(non_nulls, key=lambda v: abs(v)) + if all(isinstance(v, float) for v in non_nulls): + return 0.0 + if all(isinstance(v, decimal.Decimal) for v in non_nulls): + return max(non_nulls, key=lambda d: len(d.as_tuple().digits)) + if all(isinstance(v, str) for v in non_nulls): + return max(non_nulls, key=lambda s: len(str(s))) + if all(isinstance(v, datetime.datetime) for v in non_nulls): + return datetime.datetime.now() + if all(isinstance(v, datetime.date) for v in non_nulls): + return datetime.date.today() + return non_nulls[0] # fallback + + def _transpose_rowwise_to_columnwise(self, seq_of_parameters: list) -> tuple[list, int]: + """ + Convert sequence of rows (row-wise) into list of columns (column-wise), + for array binding via ODBC. Works with both iterables and generators. + + Args: + seq_of_parameters: Sequence of sequences or mappings of parameters. + + Returns: + tuple: (columnwise_data, row_count) + """ + columnwise = [] + first_row = True + row_count = 0 + for row in seq_of_parameters: - if len(row) != num_params: - raise ValueError("Inconsistent parameter row size in executemany()") + row_count += 1 + if first_row: + # Initialize columnwise lists based on first row + num_params = len(row) + columnwise = [[] for _ in range(num_params)] + first_row = False + else: + # Validate row size consistency + if len(row) != num_params: + raise ValueError("Inconsistent parameter row size in executemany()") + + # Add each value to its column list for i, val in enumerate(row): columnwise[i].append(val) - return columnwise + + return columnwise, row_count def _compute_column_type(self, column): """ @@ -962,50 +1562,166 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None: except Exception as e: log('warning', f"Failed to set query timeout: {e}") + # Get sample row for parameter type detection and validation + sample_row = seq_of_parameters[0] if hasattr(seq_of_parameters, '__getitem__') else next(iter(seq_of_parameters)) + param_count = len(sample_row) param_info = ddbc_bindings.ParamInfo - param_count = len(seq_of_parameters[0]) parameters_type = [] + # Check if we have explicit input sizes set + if self._inputsizes: + # Validate input sizes match parameter count + if len(self._inputsizes) != param_count: + warnings.warn( + f"Number of input sizes ({len(self._inputsizes)}) does not match " + f"number of parameters ({param_count}). 
This may lead to unexpected behavior.", + Warning + ) + + # Prepare parameter type information for col_index in range(param_count): - column = [row[col_index] for row in seq_of_parameters] + column = [row[col_index] for row in seq_of_parameters] if hasattr(seq_of_parameters, '__getitem__') else [] sample_value, min_val, max_val = self._compute_column_type(column) - modified_row = list(seq_of_parameters[0]) - modified_row[col_index] = sample_value - # sending original values for all rows here, we may change this if any inconsistent behavior is observed - paraminfo = self._create_parameter_types_list( - sample_value, param_info, modified_row, col_index, min_val=min_val, max_val=max_val - ) - parameters_type.append(paraminfo) - - columnwise_params = self._transpose_rowwise_to_columnwise(seq_of_parameters) + + if self._inputsizes and col_index < len(self._inputsizes): + # Use explicitly set input sizes + sql_type, column_size, decimal_digits = self._inputsizes[col_index] + + # Default is_dae to False + is_dae = False + + # Determine appropriate C type based on SQL type + c_type = self._get_c_type_for_sql_type(sql_type) + + # Check if this should be a DAE (data at execution) parameter based on column size + if sample_value is not None: + if isinstance(sample_value, str) and column_size > MAX_INLINE_CHAR: + is_dae = True + elif isinstance(sample_value, (bytes, bytearray)) and column_size > 8000: + is_dae = True + + # Sanitize precision/scale for numeric types + if sql_type in (ddbc_sql_const.SQL_DECIMAL.value, ddbc_sql_const.SQL_NUMERIC.value): + column_size = max(1, min(int(column_size) if column_size > 0 else 18, 38)) + decimal_digits = min(max(0, decimal_digits), column_size) + + # For binary data columns with mixed content, we need to find max size + if sql_type in (ddbc_sql_const.SQL_BINARY.value, ddbc_sql_const.SQL_VARBINARY.value, + ddbc_sql_const.SQL_LONGVARBINARY.value): + # Find the maximum size needed for any row's binary data + max_binary_size = 0 + for row in seq_of_parameters: + value = row[col_index] + if value is not None and isinstance(value, (bytes, bytearray)): + max_binary_size = max(max_binary_size, len(value)) + + # For SQL Server VARBINARY(MAX), we need to use large object binding + if column_size > 8000 or max_binary_size > 8000: + sql_type = ddbc_sql_const.SQL_LONGVARBINARY.value + is_dae = True + + # Update column_size to actual maximum size if it's larger + # Always ensure at least a minimum size of 1 for empty strings + column_size = max(max_binary_size, 1) + + paraminfo = param_info() + paraminfo.paramCType = c_type + paraminfo.paramSQLType = sql_type + paraminfo.inputOutputType = ddbc_sql_const.SQL_PARAM_INPUT.value + paraminfo.columnSize = column_size + paraminfo.decimalDigits = decimal_digits + paraminfo.isDAE = is_dae + + # Ensure we never have SQL_C_DEFAULT (0) for C-type + if paraminfo.paramCType == 0: + paraminfo.paramCType = ddbc_sql_const.SQL_C_DEFAULT.value + + parameters_type.append(paraminfo) + else: + # Use auto-detection for columns without explicit types + column = [row[col_index] for row in seq_of_parameters] if hasattr(seq_of_parameters, '__getitem__') else [] + if not column: + # For generators, use the sample row for inference + sample_value = sample_row[col_index] + else: + sample_value = self._select_best_sample_value(column) + + dummy_row = list(sample_row) + paraminfo = self._create_parameter_types_list( + sample_value, param_info, dummy_row, col_index, min_val=min_val, max_val=max_val + ) + # Special handling for binary data in auto-detected 
types + if paraminfo.paramSQLType in (ddbc_sql_const.SQL_BINARY.value, ddbc_sql_const.SQL_VARBINARY.value, + ddbc_sql_const.SQL_LONGVARBINARY.value): + # Find the maximum size needed for any row's binary data + max_binary_size = 0 + for row in seq_of_parameters: + value = row[col_index] + if value is not None and isinstance(value, (bytes, bytearray)): + max_binary_size = max(max_binary_size, len(value)) + + # For SQL Server VARBINARY(MAX), we need to use large object binding + if max_binary_size > 8000: + paraminfo.paramSQLType = ddbc_sql_const.SQL_LONGVARBINARY.value + paraminfo.isDAE = True + + # Update column_size to actual maximum size + # Always ensure at least a minimum size of 1 for empty strings + paraminfo.columnSize = max(max_binary_size, 1) + + parameters_type.append(paraminfo) + + # Process parameters into column-wise format with possible type conversions + # First, convert any Decimal types as needed for NUMERIC/DECIMAL columns + processed_parameters = [] + for row in seq_of_parameters: + processed_row = list(row) + for i, val in enumerate(processed_row): + if (parameters_type[i].paramSQLType in + (ddbc_sql_const.SQL_DECIMAL.value, ddbc_sql_const.SQL_NUMERIC.value) and + not isinstance(val, decimal.Decimal) and val is not None): + try: + processed_row[i] = decimal.Decimal(str(val)) + except: + pass # Keep original value if conversion fails + processed_parameters.append(processed_row) + + # Now transpose the processed parameters + columnwise_params, row_count = self._transpose_rowwise_to_columnwise(processed_parameters) + + # Add debug logging log('debug', "Executing batch query with %d parameter sets:\n%s", - len(seq_of_parameters), "\n".join(f" {i+1}: {tuple(p) if isinstance(p, (list, tuple)) else p}" for i, p in enumerate(seq_of_parameters)) + len(seq_of_parameters), "\n".join(f" {i+1}: {tuple(p) if isinstance(p, (list, tuple)) else p}" for i, p in enumerate(seq_of_parameters[:5])) # Limit to first 5 rows for large batches ) - + # Execute batched statement ret = ddbc_bindings.SQLExecuteMany( self.hstmt, operation, columnwise_params, parameters_type, - len(seq_of_parameters) + row_count ) - check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) - + # Capture any diagnostic messages after execution if self.hstmt: self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt)) - - self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt) - self.last_executed_stmt = operation - self._initialize_description() - - if self.description: - self.rowcount = -1 - self._reset_rownumber() - else: + + try: + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt) - self._clear_rownumber() + self.last_executed_stmt = operation + self._initialize_description() + + if self.description: + self.rowcount = -1 + self._reset_rownumber() + else: + self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt) + self._clear_rownumber() + finally: + # Reset input sizes after execution + self._reset_inputsizes() def fetchone(self) -> Union[None, Row]: """ @@ -1174,7 +1890,7 @@ def fetchval(self): Example: >>> count = cursor.execute('SELECT COUNT(*) FROM users').fetchval() - >>> max_id = cursor.execute('SELECT MAX(id) FROM products').fetchval() + >>> max_id = cursor.execute('SELECT MAX(id) FROM users').fetchval() >>> name = cursor.execute('SELECT name FROM users WHERE id = ?', user_id).fetchval() Note: @@ -1425,26 +2141,8 @@ def tables(self, table=None, catalog=None, schema=None, tableType=None): Returns: Cursor: The 
cursor object itself for method chaining with fetch methods. - - Example: - # Get all tables in the database - tables = cursor.tables().fetchall() - - # Get all tables in schema 'dbo' - tables = cursor.tables(schema='dbo').fetchall() - - # Get table named 'Customers' - tables = cursor.tables(table='Customers').fetchone() - - # Get all views with fetchmany - tables = cursor.tables(tableType='VIEW').fetchmany(10) """ self._check_closed() - - # Clear messages - self.messages = [] - - # Always reset the cursor first to ensure clean state self._reset_cursor() # Format table_type parameter - SQLTables expects comma-separated string @@ -1455,84 +2153,29 @@ def tables(self, table=None, catalog=None, schema=None, tableType=None): else: table_type_str = str(tableType) - # Call SQLTables via the helper method - self._execute_tables( - self.hstmt, - catalog_name=catalog, - schema_name=schema, - table_name=table, - table_type=table_type_str - ) - - # Initialize description from column metadata - column_metadata = [] try: - ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) - self._initialize_description(column_metadata) - except InterfaceError as e: - log('error', f"Driver interface error during metadata retrieval: {e}") - except Exception as e: - # Log the exception with appropriate context - log('error', f"Failed to retrieve column metadata: {e}. Using standard ODBC column definitions instead.") - - if not self.description: - # If describe fails, create a manual description for the standard columns - column_types = [str, str, str, str, str] - self.description = [ - ("table_cat", column_types[0], None, 128, 128, 0, True), - ("table_schem", column_types[1], None, 128, 128, 0, True), - ("table_name", column_types[2], None, 128, 128, 0, False), - ("table_type", column_types[3], None, 128, 128, 0, False), - ("remarks", column_types[4], None, 254, 254, 0, True) + # Call SQLTables via the helper method + self._execute_tables( + self.hstmt, + catalog_name=catalog, + schema_name=schema, + table_name=table, + table_type=table_type_str + ) + + # Define fallback description for tables + fallback_description = [ + ("table_cat", str, None, 128, 128, 0, True), + ("table_schem", str, None, 128, 128, 0, True), + ("table_name", str, None, 128, 128, 0, False), + ("table_type", str, None, 128, 128, 0, False), + ("remarks", str, None, 254, 254, 0, True) ] - - # Store the column mappings for this specific tables() call - column_names = [desc[0] for desc in self.description] - - # Create a specialized column map for this result set - columns_map = {} - for i, name in enumerate(column_names): - columns_map[name] = i - columns_map[name.lower()] = i - - # Define wrapped fetch methods that preserve existing column mapping - # but add our specialized mapping just for column results - def fetchone_with_columns_mapping(): - row = self._original_fetchone() - if row is not None: - # Create a merged map with columns result taking precedence - merged_map = getattr(row, '_column_map', {}).copy() - merged_map.update(columns_map) - row._column_map = merged_map - return row - - def fetchmany_with_columns_mapping(size=None): - rows = self._original_fetchmany(size) - for row in rows: - # Create a merged map with columns result taking precedence - merged_map = getattr(row, '_column_map', {}).copy() - merged_map.update(columns_map) - row._column_map = merged_map - return rows - - def fetchall_with_columns_mapping(): - rows = self._original_fetchall() - for row in rows: - # Create a merged map with columns result taking precedence 
- merged_map = getattr(row, '_column_map', {}).copy() - merged_map.update(columns_map) - row._column_map = merged_map - return rows - - # Save original fetch methods - if not hasattr(self, '_original_fetchone'): - self._original_fetchone = self.fetchone - self._original_fetchmany = self.fetchmany - self._original_fetchall = self.fetchall - - # Override fetch methods with our wrapped versions - self.fetchone = fetchone_with_columns_mapping - self.fetchmany = fetchmany_with_columns_mapping - self.fetchall = fetchall_with_columns_mapping - - return self \ No newline at end of file + + # Use the helper method to prepare the result set + return self._prepare_metadata_result_set(fallback_description=fallback_description) + + except Exception as e: + # Log the error and re-raise + log('error', f"Error executing tables query: {e}") + raise \ No newline at end of file diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index bac9c664..c7f7aefb 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -124,6 +124,13 @@ SQLBindColFunc SQLBindCol_ptr = nullptr; SQLDescribeColFunc SQLDescribeCol_ptr = nullptr; SQLMoreResultsFunc SQLMoreResults_ptr = nullptr; SQLColAttributeFunc SQLColAttribute_ptr = nullptr; +SQLGetTypeInfoFunc SQLGetTypeInfo_ptr = nullptr; +SQLProceduresFunc SQLProcedures_ptr = nullptr; +SQLForeignKeysFunc SQLForeignKeys_ptr = nullptr; +SQLPrimaryKeysFunc SQLPrimaryKeys_ptr = nullptr; +SQLSpecialColumnsFunc SQLSpecialColumns_ptr = nullptr; +SQLStatisticsFunc SQLStatistics_ptr = nullptr; +SQLColumnsFunc SQLColumns_ptr = nullptr; // Transaction APIs SQLEndTranFunc SQLEndTran_ptr = nullptr; @@ -813,6 +820,13 @@ DriverHandle LoadDriverOrThrowException() { SQLDescribeCol_ptr = GetFunctionPointer(handle, "SQLDescribeColW"); SQLMoreResults_ptr = GetFunctionPointer(handle, "SQLMoreResults"); SQLColAttribute_ptr = GetFunctionPointer(handle, "SQLColAttributeW"); + SQLGetTypeInfo_ptr = GetFunctionPointer(handle, "SQLGetTypeInfoW"); + SQLProcedures_ptr = GetFunctionPointer(handle, "SQLProceduresW"); + SQLForeignKeys_ptr = GetFunctionPointer(handle, "SQLForeignKeysW"); + SQLPrimaryKeys_ptr = GetFunctionPointer(handle, "SQLPrimaryKeysW"); + SQLSpecialColumns_ptr = GetFunctionPointer(handle, "SQLSpecialColumnsW"); + SQLStatistics_ptr = GetFunctionPointer(handle, "SQLStatisticsW"); + SQLColumns_ptr = GetFunctionPointer(handle, "SQLColumnsW"); SQLEndTran_ptr = GetFunctionPointer(handle, "SQLEndTran"); SQLDisconnect_ptr = GetFunctionPointer(handle, "SQLDisconnect"); @@ -838,7 +852,10 @@ DriverHandle LoadDriverOrThrowException() { SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr && SQLFreeStmt_ptr && SQLGetDiagRec_ptr && SQLParamData_ptr && SQLPutData_ptr && SQLTables_ptr && - SQLDescribeParam_ptr; + SQLDescribeParam_ptr && + SQLGetTypeInfo_ptr && SQLProcedures_ptr && SQLForeignKeys_ptr && + SQLPrimaryKeys_ptr && SQLSpecialColumns_ptr && SQLStatistics_ptr && + SQLColumns_ptr; if (!success) { ThrowStdException("Failed to load required function pointers from driver."); @@ -903,6 +920,244 @@ void SqlHandle::free() { } } +SQLRETURN SQLGetTypeInfo_Wrapper(SqlHandlePtr StatementHandle, SQLSMALLINT DataType) { + if (!SQLGetTypeInfo_ptr) { + ThrowStdException("SQLGetTypeInfo function not loaded"); + } + + return SQLGetTypeInfo_ptr(StatementHandle->get(), DataType); +} + +SQLRETURN SQLProcedures_wrap(SqlHandlePtr StatementHandle, + const py::object& catalogObj, + const py::object& schemaObj, + const py::object& procedureObj) { 
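+    // Thin wrapper over the driver's SQLProceduresW entry point: Python None (or an
+    // empty string) is forwarded as a null pattern pointer so the driver applies its
+    // default, while non-empty patterns are passed as null-terminated wide strings
+    // with SQL_NTS. The other catalog wrappers below (SQLForeignKeys_wrap,
+    // SQLPrimaryKeys_wrap, SQLStatistics_wrap, SQLColumns_wrap) follow the same pattern.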
+ if (!SQLProcedures_ptr) { + ThrowStdException("SQLProcedures function not loaded"); + } + + std::wstring catalog = py::isinstance(catalogObj) ? L"" : catalogObj.cast(); + std::wstring schema = py::isinstance(schemaObj) ? L"" : schemaObj.cast(); + std::wstring procedure = py::isinstance(procedureObj) ? L"" : procedureObj.cast(); + +#if defined(__APPLE__) || defined(__linux__) + // Unix implementation + std::vector catalogBuf = WStringToSQLWCHAR(catalog); + std::vector schemaBuf = WStringToSQLWCHAR(schema); + std::vector procedureBuf = WStringToSQLWCHAR(procedure); + + return SQLProcedures_ptr( + StatementHandle->get(), + catalog.empty() ? nullptr : catalogBuf.data(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : schemaBuf.data(), + schema.empty() ? 0 : SQL_NTS, + procedure.empty() ? nullptr : procedureBuf.data(), + procedure.empty() ? 0 : SQL_NTS); +#else + // Windows implementation + return SQLProcedures_ptr( + StatementHandle->get(), + catalog.empty() ? nullptr : (SQLWCHAR*)catalog.c_str(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : (SQLWCHAR*)schema.c_str(), + schema.empty() ? 0 : SQL_NTS, + procedure.empty() ? nullptr : (SQLWCHAR*)procedure.c_str(), + procedure.empty() ? 0 : SQL_NTS); +#endif +} + +SQLRETURN SQLForeignKeys_wrap(SqlHandlePtr StatementHandle, + const py::object& pkCatalogObj, + const py::object& pkSchemaObj, + const py::object& pkTableObj, + const py::object& fkCatalogObj, + const py::object& fkSchemaObj, + const py::object& fkTableObj) { + if (!SQLForeignKeys_ptr) { + ThrowStdException("SQLForeignKeys function not loaded"); + } + + std::wstring pkCatalog = py::isinstance(pkCatalogObj) ? L"" : pkCatalogObj.cast(); + std::wstring pkSchema = py::isinstance(pkSchemaObj) ? L"" : pkSchemaObj.cast(); + std::wstring pkTable = py::isinstance(pkTableObj) ? L"" : pkTableObj.cast(); + std::wstring fkCatalog = py::isinstance(fkCatalogObj) ? L"" : fkCatalogObj.cast(); + std::wstring fkSchema = py::isinstance(fkSchemaObj) ? L"" : fkSchemaObj.cast(); + std::wstring fkTable = py::isinstance(fkTableObj) ? L"" : fkTableObj.cast(); + +#if defined(__APPLE__) || defined(__linux__) + // Unix implementation + std::vector pkCatalogBuf = WStringToSQLWCHAR(pkCatalog); + std::vector pkSchemaBuf = WStringToSQLWCHAR(pkSchema); + std::vector pkTableBuf = WStringToSQLWCHAR(pkTable); + std::vector fkCatalogBuf = WStringToSQLWCHAR(fkCatalog); + std::vector fkSchemaBuf = WStringToSQLWCHAR(fkSchema); + std::vector fkTableBuf = WStringToSQLWCHAR(fkTable); + + return SQLForeignKeys_ptr( + StatementHandle->get(), + pkCatalog.empty() ? nullptr : pkCatalogBuf.data(), + pkCatalog.empty() ? 0 : SQL_NTS, + pkSchema.empty() ? nullptr : pkSchemaBuf.data(), + pkSchema.empty() ? 0 : SQL_NTS, + pkTable.empty() ? nullptr : pkTableBuf.data(), + pkTable.empty() ? 0 : SQL_NTS, + fkCatalog.empty() ? nullptr : fkCatalogBuf.data(), + fkCatalog.empty() ? 0 : SQL_NTS, + fkSchema.empty() ? nullptr : fkSchemaBuf.data(), + fkSchema.empty() ? 0 : SQL_NTS, + fkTable.empty() ? nullptr : fkTableBuf.data(), + fkTable.empty() ? 0 : SQL_NTS); +#else + // Windows implementation + return SQLForeignKeys_ptr( + StatementHandle->get(), + pkCatalog.empty() ? nullptr : (SQLWCHAR*)pkCatalog.c_str(), + pkCatalog.empty() ? 0 : SQL_NTS, + pkSchema.empty() ? nullptr : (SQLWCHAR*)pkSchema.c_str(), + pkSchema.empty() ? 0 : SQL_NTS, + pkTable.empty() ? nullptr : (SQLWCHAR*)pkTable.c_str(), + pkTable.empty() ? 0 : SQL_NTS, + fkCatalog.empty() ? nullptr : (SQLWCHAR*)fkCatalog.c_str(), + fkCatalog.empty() ? 
0 : SQL_NTS, + fkSchema.empty() ? nullptr : (SQLWCHAR*)fkSchema.c_str(), + fkSchema.empty() ? 0 : SQL_NTS, + fkTable.empty() ? nullptr : (SQLWCHAR*)fkTable.c_str(), + fkTable.empty() ? 0 : SQL_NTS); +#endif +} + +SQLRETURN SQLPrimaryKeys_wrap(SqlHandlePtr StatementHandle, + const py::object& catalogObj, + const py::object& schemaObj, + const std::wstring& table) { + if (!SQLPrimaryKeys_ptr) { + ThrowStdException("SQLPrimaryKeys function not loaded"); + } + + // Convert py::object to std::wstring, treating None as empty string + std::wstring catalog = catalogObj.is_none() ? L"" : catalogObj.cast(); + std::wstring schema = schemaObj.is_none() ? L"" : schemaObj.cast(); + +#if defined(__APPLE__) || defined(__linux__) + // Unix implementation + std::vector catalogBuf = WStringToSQLWCHAR(catalog); + std::vector schemaBuf = WStringToSQLWCHAR(schema); + std::vector tableBuf = WStringToSQLWCHAR(table); + + return SQLPrimaryKeys_ptr( + StatementHandle->get(), + catalog.empty() ? nullptr : catalogBuf.data(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : schemaBuf.data(), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : tableBuf.data(), + table.empty() ? 0 : SQL_NTS); +#else + // Windows implementation + return SQLPrimaryKeys_ptr( + StatementHandle->get(), + catalog.empty() ? nullptr : (SQLWCHAR*)catalog.c_str(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : (SQLWCHAR*)schema.c_str(), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : (SQLWCHAR*)table.c_str(), + table.empty() ? 0 : SQL_NTS); +#endif +} + +SQLRETURN SQLStatistics_wrap(SqlHandlePtr StatementHandle, + const py::object& catalogObj, + const py::object& schemaObj, + const std::wstring& table, + SQLUSMALLINT unique, + SQLUSMALLINT reserved) { + if (!SQLStatistics_ptr) { + ThrowStdException("SQLStatistics function not loaded"); + } + + // Convert py::object to std::wstring, treating None as empty string + std::wstring catalog = catalogObj.is_none() ? L"" : catalogObj.cast(); + std::wstring schema = schemaObj.is_none() ? L"" : schemaObj.cast(); + +#if defined(__APPLE__) || defined(__linux__) + // Unix implementation + std::vector catalogBuf = WStringToSQLWCHAR(catalog); + std::vector schemaBuf = WStringToSQLWCHAR(schema); + std::vector tableBuf = WStringToSQLWCHAR(table); + + return SQLStatistics_ptr( + StatementHandle->get(), + catalog.empty() ? nullptr : catalogBuf.data(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : schemaBuf.data(), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : tableBuf.data(), + table.empty() ? 0 : SQL_NTS, + unique, + reserved); +#else + // Windows implementation + return SQLStatistics_ptr( + StatementHandle->get(), + catalog.empty() ? nullptr : (SQLWCHAR*)catalog.c_str(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : (SQLWCHAR*)schema.c_str(), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : (SQLWCHAR*)table.c_str(), + table.empty() ? 0 : SQL_NTS, + unique, + reserved); +#endif +} + +SQLRETURN SQLColumns_wrap(SqlHandlePtr StatementHandle, + const py::object& catalogObj, + const py::object& schemaObj, + const py::object& tableObj, + const py::object& columnObj) { + if (!SQLColumns_ptr) { + ThrowStdException("SQLColumns function not loaded"); + } + + // Convert py::object to std::wstring, treating None as empty string + std::wstring catalogStr = catalogObj.is_none() ? L"" : catalogObj.cast(); + std::wstring schemaStr = schemaObj.is_none() ? L"" : schemaObj.cast(); + std::wstring tableStr = tableObj.is_none() ? 
L"" : tableObj.cast(); + std::wstring columnStr = columnObj.is_none() ? L"" : columnObj.cast(); + +#if defined(__APPLE__) || defined(__linux__) + // Unix implementation + std::vector catalogBuf = WStringToSQLWCHAR(catalogStr); + std::vector schemaBuf = WStringToSQLWCHAR(schemaStr); + std::vector tableBuf = WStringToSQLWCHAR(tableStr); + std::vector columnBuf = WStringToSQLWCHAR(columnStr); + + return SQLColumns_ptr( + StatementHandle->get(), + catalogStr.empty() ? nullptr : catalogBuf.data(), + catalogStr.empty() ? 0 : SQL_NTS, + schemaStr.empty() ? nullptr : schemaBuf.data(), + schemaStr.empty() ? 0 : SQL_NTS, + tableStr.empty() ? nullptr : tableBuf.data(), + tableStr.empty() ? 0 : SQL_NTS, + columnStr.empty() ? nullptr : columnBuf.data(), + columnStr.empty() ? 0 : SQL_NTS); +#else + // Windows implementation + return SQLColumns_ptr( + StatementHandle->get(), + catalogStr.empty() ? nullptr : (SQLWCHAR*)catalogStr.c_str(), + catalogStr.empty() ? 0 : SQL_NTS, + schemaStr.empty() ? nullptr : (SQLWCHAR*)schemaStr.c_str(), + schemaStr.empty() ? 0 : SQL_NTS, + tableStr.empty() ? nullptr : (SQLWCHAR*)tableStr.c_str(), + tableStr.empty() ? 0 : SQL_NTS, + columnStr.empty() ? nullptr : (SQLWCHAR*)columnStr.c_str(), + columnStr.empty() ? 0 : SQL_NTS); +#endif +} + // Helper function to check for driver errors ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRETURN retcode) { LOG("Checking errors for retcode - {}" , retcode); @@ -1738,6 +1993,54 @@ SQLRETURN SQLDescribeCol_wrap(SqlHandlePtr StatementHandle, py::list& ColumnMeta return SQL_SUCCESS; } +SQLRETURN SQLSpecialColumns_wrap(SqlHandlePtr StatementHandle, + SQLSMALLINT identifierType, + const py::object& catalogObj, + const py::object& schemaObj, + const std::wstring& table, + SQLSMALLINT scope, + SQLSMALLINT nullable) { + if (!SQLSpecialColumns_ptr) { + ThrowStdException("SQLSpecialColumns function not loaded"); + } + + // Convert py::object to std::wstring, treating None as empty string + std::wstring catalog = catalogObj.is_none() ? L"" : catalogObj.cast(); + std::wstring schema = schemaObj.is_none() ? L"" : schemaObj.cast(); + +#if defined(__APPLE__) || defined(__linux__) + // Unix implementation + std::vector catalogBuf = WStringToSQLWCHAR(catalog); + std::vector schemaBuf = WStringToSQLWCHAR(schema); + std::vector tableBuf = WStringToSQLWCHAR(table); + + return SQLSpecialColumns_ptr( + StatementHandle->get(), + identifierType, + catalog.empty() ? nullptr : catalogBuf.data(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : schemaBuf.data(), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : tableBuf.data(), + table.empty() ? 0 : SQL_NTS, + scope, + nullable); +#else + // Windows implementation + return SQLSpecialColumns_ptr( + StatementHandle->get(), + identifierType, + catalog.empty() ? nullptr : (SQLWCHAR*)catalog.c_str(), + catalog.empty() ? 0 : SQL_NTS, + schema.empty() ? nullptr : (SQLWCHAR*)schema.c_str(), + schema.empty() ? 0 : SQL_NTS, + table.empty() ? nullptr : (SQLWCHAR*)table.c_str(), + table.empty() ? 
0 : SQL_NTS, + scope, + nullable); +#endif +} + // Wrap SQLFetch to retrieve rows SQLRETURN SQLFetch_wrap(SqlHandlePtr StatementHandle) { LOG("Fetch next row"); @@ -3154,6 +3457,58 @@ PYBIND11_MODULE(ddbc_bindings, m) { m.def("DDBCSQLSetStmtAttr", [](SqlHandlePtr stmt, SQLINTEGER attr, SQLPOINTER value) { return SQLSetStmtAttr_ptr(stmt->get(), attr, value, 0); }, "Set statement attributes"); + m.def("DDBCSQLGetTypeInfo", &SQLGetTypeInfo_Wrapper, "Returns information about the data types that are supported by the data source", + py::arg("StatementHandle"), py::arg("DataType")); + m.def("DDBCSQLProcedures", [](SqlHandlePtr StatementHandle, + const py::object& catalog, + const py::object& schema, + const py::object& procedure) { + return SQLProcedures_wrap(StatementHandle, catalog, schema, procedure); + }); + + m.def("DDBCSQLForeignKeys", [](SqlHandlePtr StatementHandle, + const py::object& pkCatalog, + const py::object& pkSchema, + const py::object& pkTable, + const py::object& fkCatalog, + const py::object& fkSchema, + const py::object& fkTable) { + return SQLForeignKeys_wrap(StatementHandle, + pkCatalog, pkSchema, pkTable, + fkCatalog, fkSchema, fkTable); + }); + m.def("DDBCSQLPrimaryKeys", [](SqlHandlePtr StatementHandle, + const py::object& catalog, + const py::object& schema, + const std::wstring& table) { + return SQLPrimaryKeys_wrap(StatementHandle, catalog, schema, table); + }); + m.def("DDBCSQLSpecialColumns", [](SqlHandlePtr StatementHandle, + SQLSMALLINT identifierType, + const py::object& catalog, + const py::object& schema, + const std::wstring& table, + SQLSMALLINT scope, + SQLSMALLINT nullable) { + return SQLSpecialColumns_wrap(StatementHandle, + identifierType, catalog, schema, table, + scope, nullable); + }); + m.def("DDBCSQLStatistics", [](SqlHandlePtr StatementHandle, + const py::object& catalog, + const py::object& schema, + const std::wstring& table, + SQLUSMALLINT unique, + SQLUSMALLINT reserved) { + return SQLStatistics_wrap(StatementHandle, catalog, schema, table, unique, reserved); + }); + m.def("DDBCSQLColumns", [](SqlHandlePtr StatementHandle, + const py::object& catalog, + const py::object& schema, + const py::object& table, + const py::object& column) { + return SQLColumns_wrap(StatementHandle, catalog, schema, table, column); + }); // Add a version attribute diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 2afdd660..63bee543 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -208,7 +208,24 @@ typedef SQLRETURN (*SQLTablesFunc)( SQLWCHAR* TableType, SQLSMALLINT NameLength4 ); - + typedef SQLRETURN (SQL_API* SQLGetTypeInfoFunc)(SQLHSTMT, SQLSMALLINT); +typedef SQLRETURN (SQL_API* SQLProceduresFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, SQLSMALLINT); +typedef SQLRETURN (SQL_API* SQLForeignKeysFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT); +typedef SQLRETURN (SQL_API* SQLPrimaryKeysFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, SQLSMALLINT); +typedef SQLRETURN (SQL_API* SQLSpecialColumnsFunc)(SQLHSTMT, SQLUSMALLINT, SQLWCHAR*, SQLSMALLINT, + SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, + SQLUSMALLINT, SQLUSMALLINT); +typedef SQLRETURN (SQL_API* SQLStatisticsFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, + SQLUSMALLINT, SQLUSMALLINT); +typedef SQLRETURN 
(SQL_API* SQLColumnsFunc)(SQLHSTMT, SQLWCHAR*, SQLSMALLINT, SQLWCHAR*, + SQLSMALLINT, SQLWCHAR*, SQLSMALLINT, + SQLWCHAR*, SQLSMALLINT); + // Transaction APIs typedef SQLRETURN (SQL_API* SQLEndTranFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT); @@ -257,6 +274,13 @@ extern SQLDescribeColFunc SQLDescribeCol_ptr; extern SQLMoreResultsFunc SQLMoreResults_ptr; extern SQLColAttributeFunc SQLColAttribute_ptr; extern SQLTablesFunc SQLTables_ptr; +extern SQLGetTypeInfoFunc SQLGetTypeInfo_ptr; +extern SQLProceduresFunc SQLProcedures_ptr; +extern SQLForeignKeysFunc SQLForeignKeys_ptr; +extern SQLPrimaryKeysFunc SQLPrimaryKeys_ptr; +extern SQLSpecialColumnsFunc SQLSpecialColumns_ptr; +extern SQLStatisticsFunc SQLStatistics_ptr; +extern SQLColumnsFunc SQLColumns_ptr; // Transaction APIs extern SQLEndTranFunc SQLEndTran_ptr; diff --git a/tests/test_001_globals.py b/tests/test_001_globals.py index a7247823..2f1e3754 100644 --- a/tests/test_001_globals.py +++ b/tests/test_001_globals.py @@ -28,6 +28,37 @@ def test_paramstyle(): # Check if paramstyle has the expected value assert paramstyle == "qmark", "paramstyle should be 'qmark'" +def test_lowercase(): + # Check if lowercase has the expected default value + assert lowercase is False, "lowercase should default to False" + +def test_decimal_separator(): + """Test decimal separator functionality""" + + # Check default value + assert getDecimalSeparator() == '.', "Default decimal separator should be '.'" + + try: + # Test setting a new value + setDecimalSeparator(',') + assert getDecimalSeparator() == ',', "Decimal separator should be ',' after setting" + + # Test invalid input + with pytest.raises(ValueError): + setDecimalSeparator('too long') + + with pytest.raises(ValueError): + setDecimalSeparator('') + + with pytest.raises(ValueError): + setDecimalSeparator(123) # Non-string input + + finally: + # Restore default value + setDecimalSeparator('.') + assert getDecimalSeparator() == '.', "Decimal separator should be restored to '.'" + + def test_lowercase(): # Check if lowercase has the expected default value assert lowercase is False, "lowercase should default to False" @@ -174,9 +205,6 @@ def reader(): # Assert that no errors occurred in the threads assert not errors, f"Thread safety test failed with errors: {errors}" -def test_lowercase(): - # Check if lowercase has the expected default value - assert lowercase is False, "lowercase should default to False" def test_decimal_separator_edge_cases(): """Test decimal separator edge cases and boundary conditions""" diff --git a/tests/test_003_connection.py b/tests/test_003_connection.py index 9c6fd35e..ac77be9d 100644 --- a/tests/test_003_connection.py +++ b/tests/test_003_connection.py @@ -43,46 +43,6 @@ from datetime import datetime, timedelta, timezone from mssql_python.constants import ConstantsDDBC -@pytest.fixture(autouse=True) -def clean_connection_state(db_connection): - """Ensure connection is in a clean state before each test""" - # Create a cursor and clear any active results - try: - cleanup_cursor = db_connection.cursor() - cleanup_cursor.execute("SELECT 1") # Simple query to reset state - cleanup_cursor.fetchall() # Consume all results - cleanup_cursor.close() - except Exception: - pass # Ignore errors during cleanup - - yield # Run the test - - # Clean up after the test - try: - cleanup_cursor = db_connection.cursor() - cleanup_cursor.execute("SELECT 1") # Simple query to reset state - cleanup_cursor.fetchall() # Consume all results - cleanup_cursor.close() - except Exception: - pass # Ignore 
errors during cleanup - -# Import all exception classes for testing -from mssql_python.exceptions import ( - Warning, - Error, - InterfaceError, - DatabaseError, - DataError, - OperationalError, - IntegrityError, - InternalError, - ProgrammingError, - NotSupportedError, -) -import struct -from datetime import datetime, timedelta, timezone -from mssql_python.constants import ConstantsDDBC - @pytest.fixture(autouse=True) def clean_connection_state(db_connection): """Ensure connection is in a clean state before each test""" @@ -128,12 +88,7 @@ def handle_datetimeoffset(dto_value): ) def custom_string_converter(value): - """ - A simple converter that adds a prefix to string values. - Assumes SQL_WVARCHAR is UTF-16LE encoded by default, - but this may vary depending on the database configuration. - You can specify a different encoding if needed. - """ + """A simple converter that adds a prefix to string values""" if value is None: return None return "CONVERTED: " + value.decode('utf-16-le') # SQL_WVARCHAR is UTF-16LE encoded @@ -4256,6 +4211,533 @@ def test_timeout_affects_all_cursors(db_connection): result2 = cursor2.fetchone() assert result2[0] == 2, "Query with second cursor failed" + # No direct way to check cursor timeout, but both should succeed + # with the current timeout setting + finally: + # Reset timeout + db_connection.timeout = original_timeout +def test_connection_execute(db_connection): + """Test the execute() convenience method for Connection class""" + # Test basic execution + cursor = db_connection.execute("SELECT 1 AS test_value") + result = cursor.fetchone() + assert result is not None, "Execute failed: No result returned" + assert result[0] == 1, "Execute failed: Incorrect result" + + # Test with parameters + cursor = db_connection.execute("SELECT ? 
AS test_value", 42) + result = cursor.fetchone() + assert result is not None, "Execute with parameters failed: No result returned" + assert result[0] == 42, "Execute with parameters failed: Incorrect result" + + # Test that cursor is tracked by connection + assert cursor in db_connection._cursors, "Cursor from execute() not tracked by connection" + + # Test with data modification and verify it requires commit + if not db_connection.autocommit: + drop_table_if_exists(db_connection.cursor(), "#pytest_test_execute") + cursor1 = db_connection.execute("CREATE TABLE #pytest_test_execute (id INT, value VARCHAR(50))") + cursor2 = db_connection.execute("INSERT INTO #pytest_test_execute VALUES (1, 'test_value')") + cursor3 = db_connection.execute("SELECT * FROM #pytest_test_execute") + result = cursor3.fetchone() + assert result is not None, "Execute with table creation failed" + assert result[0] == 1, "Execute with table creation returned wrong id" + assert result[1] == 'test_value', "Execute with table creation returned wrong value" + + # Clean up + db_connection.execute("DROP TABLE #pytest_test_execute") + db_connection.commit() + +def test_connection_execute_error_handling(db_connection): + """Test that execute() properly handles SQL errors""" + with pytest.raises(Exception): + db_connection.execute("SELECT * FROM nonexistent_table") + +def test_connection_execute_empty_result(db_connection): + """Test execute() with a query that returns no rows""" + cursor = db_connection.execute("SELECT * FROM sys.tables WHERE name = 'nonexistent_table_name'") + result = cursor.fetchone() + assert result is None, "Query should return no results" + + # Test empty result with fetchall + rows = cursor.fetchall() + assert len(rows) == 0, "fetchall should return empty list for empty result set" + +def test_connection_execute_different_parameter_types(db_connection): + """Test execute() with different parameter data types""" + # Test with different data types + params = [ + 1234, # Integer + 3.14159, # Float + "test string", # String + bytearray(b'binary data'), # Binary data + True, # Boolean + None # NULL + ] + + for param in params: + cursor = db_connection.execute("SELECT ? 
AS value", param) + result = cursor.fetchone() + if param is None: + assert result[0] is None, "NULL parameter not handled correctly" + else: + assert result[0] == param, f"Parameter {param} of type {type(param)} not handled correctly" + +def test_connection_execute_with_transaction(db_connection): + """Test execute() in the context of explicit transactions""" + if db_connection.autocommit: + db_connection.autocommit = False + + cursor1 = db_connection.cursor() + drop_table_if_exists(cursor1, "#pytest_test_execute_transaction") + + try: + # Create table and insert data + db_connection.execute("CREATE TABLE #pytest_test_execute_transaction (id INT, value VARCHAR(50))") + db_connection.execute("INSERT INTO #pytest_test_execute_transaction VALUES (1, 'before rollback')") + + # Check data is there + cursor = db_connection.execute("SELECT * FROM #pytest_test_execute_transaction") + result = cursor.fetchone() + assert result is not None, "Data should be visible within transaction" + assert result[1] == 'before rollback', "Incorrect data in transaction" + + # Rollback and verify data is gone + db_connection.rollback() + + # Need to recreate table since it was rolled back + db_connection.execute("CREATE TABLE #pytest_test_execute_transaction (id INT, value VARCHAR(50))") + db_connection.execute("INSERT INTO #pytest_test_execute_transaction VALUES (2, 'after rollback')") + + cursor = db_connection.execute("SELECT * FROM #pytest_test_execute_transaction") + result = cursor.fetchone() + assert result is not None, "Data should be visible after new insert" + assert result[0] == 2, "Should see the new data after rollback" + assert result[1] == 'after rollback', "Incorrect data after rollback" + + # Commit and verify data persists + db_connection.commit() + finally: + # Clean up + try: + db_connection.execute("DROP TABLE #pytest_test_execute_transaction") + db_connection.commit() + except Exception: + pass + +def test_connection_execute_vs_cursor_execute(db_connection): + """Compare behavior of connection.execute() vs cursor.execute()""" + # Connection.execute creates a new cursor each time + cursor1 = db_connection.execute("SELECT 1 AS first_query") + # Consume the results from cursor1 before creating cursor2 + result1 = cursor1.fetchall() + assert result1[0][0] == 1, "First cursor should have result from first query" + + # Now it's safe to create a second cursor + cursor2 = db_connection.execute("SELECT 2 AS second_query") + result2 = cursor2.fetchall() + assert result2[0][0] == 2, "Second cursor should have result from second query" + + # These should be different cursor objects + assert cursor1 != cursor2, "Connection.execute should create a new cursor each time" + + # Now compare with reusing the same cursor + cursor3 = db_connection.cursor() + cursor3.execute("SELECT 3 AS third_query") + result3 = cursor3.fetchone() + assert result3[0] == 3, "Direct cursor execution failed" + + # Reuse the same cursor + cursor3.execute("SELECT 4 AS fourth_query") + result4 = cursor3.fetchone() + assert result4[0] == 4, "Reused cursor should have new results" + + # The previous results should no longer be accessible + cursor3.execute("SELECT 3 AS third_query_again") + result5 = cursor3.fetchone() + assert result5[0] == 3, "Cursor reexecution should work" + +def test_connection_execute_many_parameters(db_connection): + """Test execute() with many parameters""" + # First make sure no active results are pending + # by using a fresh cursor and fetching all results + cursor = db_connection.cursor() + 
cursor.execute("SELECT 1") + cursor.fetchall() + + # Create a query with 10 parameters + params = list(range(1, 11)) + query = "SELECT " + ", ".join(["?" for _ in params]) + " AS many_params" + + # Now execute with many parameters + cursor = db_connection.execute(query, *params) + result = cursor.fetchall() # Use fetchall to consume all results + + # Verify all parameters were correctly passed + for i, value in enumerate(params): + assert result[0][i] == value, f"Parameter at position {i} not correctly passed" + +def test_add_output_converter(db_connection): + """Test adding an output converter""" + # Add a converter + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + + # Verify it was added correctly + assert hasattr(db_connection, '_output_converters') + assert sql_wvarchar in db_connection._output_converters + assert db_connection._output_converters[sql_wvarchar] == custom_string_converter + + # Clean up + db_connection.clear_output_converters() + +def test_get_output_converter(db_connection): + """Test getting an output converter""" + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + + # Initial state - no converter + assert db_connection.get_output_converter(sql_wvarchar) is None + + # Add a converter + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + + # Get the converter + converter = db_connection.get_output_converter(sql_wvarchar) + assert converter == custom_string_converter + + # Get a non-existent converter + assert db_connection.get_output_converter(999) is None + + # Clean up + db_connection.clear_output_converters() + +def test_remove_output_converter(db_connection): + """Test removing an output converter""" + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + + # Add a converter + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + assert db_connection.get_output_converter(sql_wvarchar) is not None + + # Remove the converter + db_connection.remove_output_converter(sql_wvarchar) + assert db_connection.get_output_converter(sql_wvarchar) is None + + # Remove a non-existent converter (should not raise) + db_connection.remove_output_converter(999) + +def test_clear_output_converters(db_connection): + """Test clearing all output converters""" + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + sql_timestamp_offset = ConstantsDDBC.SQL_TIMESTAMPOFFSET.value + + # Add multiple converters + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + db_connection.add_output_converter(sql_timestamp_offset, handle_datetimeoffset) + + # Verify converters were added + assert db_connection.get_output_converter(sql_wvarchar) is not None + assert db_connection.get_output_converter(sql_timestamp_offset) is not None + + # Clear all converters + db_connection.clear_output_converters() + + # Verify all converters were removed + assert db_connection.get_output_converter(sql_wvarchar) is None + assert db_connection.get_output_converter(sql_timestamp_offset) is None + +def test_converter_integration(db_connection): + """ + Test that converters work during fetching. + + This test verifies that output converters work at the Python level + without requiring native driver support. 
+ """ + cursor = db_connection.cursor() + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + + # Test with string converter + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + + # Test a simple string query + cursor.execute("SELECT N'test string' AS test_col") + row = cursor.fetchone() + + # Check if the type matches what we expect for SQL_WVARCHAR + # For Cursor.description, the second element is the type code + column_type = cursor.description[0][1] + + # If the cursor description has SQL_WVARCHAR as the type code, + # then our converter should be applied + if column_type == sql_wvarchar: + assert row[0].startswith("CONVERTED:"), "Output converter not applied" + else: + # If the type code is different, adjust the test or the converter + print(f"Column type is {column_type}, not {sql_wvarchar}") + # Add converter for the actual type used + db_connection.clear_output_converters() + db_connection.add_output_converter(column_type, custom_string_converter) + + # Re-execute the query + cursor.execute("SELECT N'test string' AS test_col") + row = cursor.fetchone() + assert row[0].startswith("CONVERTED:"), "Output converter not applied" + + # Clean up + db_connection.clear_output_converters() + +def test_output_converter_with_null_values(db_connection): + """Test that output converters handle NULL values correctly""" + cursor = db_connection.cursor() + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + + # Add converter for string type + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + + # Execute a query with NULL values + cursor.execute("SELECT CAST(NULL AS NVARCHAR(50)) AS null_col") + value = cursor.fetchone()[0] + + # NULL values should remain None regardless of converter + assert value is None + + # Clean up + db_connection.clear_output_converters() + +def test_chaining_output_converters(db_connection): + """Test that output converters can be chained (replaced)""" + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + + # Define a second converter + def another_string_converter(value): + if value is None: + return None + return "ANOTHER: " + value.decode('utf-16-le') + + # Add first converter + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + + # Verify first converter is registered + assert db_connection.get_output_converter(sql_wvarchar) == custom_string_converter + + # Replace with second converter + db_connection.add_output_converter(sql_wvarchar, another_string_converter) + + # Verify second converter replaced the first + assert db_connection.get_output_converter(sql_wvarchar) == another_string_converter + + # Clean up + db_connection.clear_output_converters() + +def test_temporary_converter_replacement(db_connection): + """Test temporarily replacing a converter and then restoring it""" + sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value + + # Add a converter + db_connection.add_output_converter(sql_wvarchar, custom_string_converter) + + # Save original converter + original_converter = db_connection.get_output_converter(sql_wvarchar) + + # Define a temporary converter + def temp_converter(value): + if value is None: + return None + return "TEMP: " + value.decode('utf-16-le') + + # Replace with temporary converter + db_connection.add_output_converter(sql_wvarchar, temp_converter) + + # Verify temporary converter is in use + assert db_connection.get_output_converter(sql_wvarchar) == temp_converter + + # Restore original converter + db_connection.add_output_converter(sql_wvarchar, original_converter) + + # Verify 
original converter is restored + assert db_connection.get_output_converter(sql_wvarchar) == original_converter + + # Clean up + db_connection.clear_output_converters() + +def test_multiple_output_converters(db_connection): + """Test that multiple output converters can work together""" + cursor = db_connection.cursor() + + # Execute a query to get the actual type codes used + cursor.execute("SELECT CAST(42 AS INT) as int_col, N'test' as str_col") + int_type = cursor.description[0][1] # Type code for integer column + str_type = cursor.description[1][1] # Type code for string column + + # Add converter for string type + db_connection.add_output_converter(str_type, custom_string_converter) + + # Add converter for integer type + def int_converter(value): + if value is None: + return None + # Convert from bytes to int and multiply by 2 + if isinstance(value, bytes): + return int.from_bytes(value, byteorder='little') * 2 + elif isinstance(value, int): + return value * 2 + return value + + db_connection.add_output_converter(int_type, int_converter) + + # Test query with both types + cursor.execute("SELECT CAST(42 AS INT) as int_col, N'test' as str_col") + row = cursor.fetchone() + + # Verify converters worked + assert row[0] == 84, f"Integer converter failed, got {row[0]} instead of 84" + assert isinstance(row[1], str) and "CONVERTED:" in row[1], f"String converter failed, got {row[1]}" + + # Clean up + db_connection.clear_output_converters() + +def test_timeout_default(db_connection): + """Test that the default timeout value is 0 (no timeout)""" + assert hasattr(db_connection, 'timeout'), "Connection should have a timeout attribute" + assert db_connection.timeout == 0, "Default timeout should be 0" + +def test_timeout_setter(db_connection): + """Test setting and getting the timeout value""" + # Set a non-zero timeout + db_connection.timeout = 30 + assert db_connection.timeout == 30, "Timeout should be set to 30" + + # Test that timeout can be reset to zero + db_connection.timeout = 0 + assert db_connection.timeout == 0, "Timeout should be reset to 0" + + # Test setting invalid timeout values + with pytest.raises(ValueError): + db_connection.timeout = -1 + + with pytest.raises(TypeError): + db_connection.timeout = "30" + + # Reset timeout to default for other tests + db_connection.timeout = 0 + +def test_timeout_from_constructor(conn_str): + """Test setting timeout in the connection constructor""" + # Create a connection with timeout set + conn = connect(conn_str, timeout=45) + try: + assert conn.timeout == 45, "Timeout should be set to 45 from constructor" + + # Create a cursor and verify it inherits the timeout + cursor = conn.cursor() + # Execute a quick query to ensure the timeout doesn't interfere + cursor.execute("SELECT 1") + result = cursor.fetchone() + assert result[0] == 1, "Query execution should succeed with timeout set" + finally: + # Clean up + conn.close() + +def test_timeout_long_query(db_connection): + """Test that a query exceeding the timeout raises an exception if supported by driver""" + import time + import pytest + + cursor = db_connection.cursor() + + try: + # First execute a simple query to check if we can run tests + cursor.execute("SELECT 1") + cursor.fetchall() + except Exception as e: + pytest.skip(f"Skipping timeout test due to connection issue: {e}") + + # Set a short timeout + original_timeout = db_connection.timeout + db_connection.timeout = 2 # 2 seconds + + try: + # Try several different approaches to test timeout + start_time = time.perf_counter() + try: + # 
Method 1: CPU-intensive query with REPLICATE and large result set + cpu_intensive_query = """ + WITH numbers AS ( + SELECT TOP 1000000 ROW_NUMBER() OVER (ORDER BY (SELECT NULL)) AS n + FROM sys.objects a CROSS JOIN sys.objects b + ) + SELECT COUNT(*) FROM numbers WHERE n % 2 = 0 + """ + cursor.execute(cpu_intensive_query) + cursor.fetchall() + + elapsed_time = time.perf_counter() - start_time + + # If we get here without an exception, try a different approach + if elapsed_time < 4.5: + + # Method 2: Try with WAITFOR + start_time = time.perf_counter() + cursor.execute("WAITFOR DELAY '00:00:05'") + cursor.fetchall() + elapsed_time = time.perf_counter() - start_time + + # If we still get here, try one more approach + if elapsed_time < 4.5: + + # Method 3: Try with a join that generates many rows + start_time = time.perf_counter() + cursor.execute(""" + SELECT COUNT(*) FROM sys.objects a, sys.objects b, sys.objects c + WHERE a.object_id = b.object_id * c.object_id + """) + cursor.fetchall() + elapsed_time = time.perf_counter() - start_time + + # If we still get here without an exception + if elapsed_time < 4.5: + pytest.skip("Timeout feature not enforced by database driver") + + except Exception as e: + # Verify this is a timeout exception + elapsed_time = time.perf_counter() - start_time + assert elapsed_time < 4.5, "Exception occurred but after expected timeout" + error_text = str(e).lower() + + # Check for various error messages that might indicate timeout + timeout_indicators = [ + "timeout", "timed out", "hyt00", "hyt01", "cancel", + "operation canceled", "execution terminated", "query limit" + ] + + assert any(indicator in error_text for indicator in timeout_indicators), \ + f"Exception occurred but doesn't appear to be a timeout error: {e}" + finally: + # Reset timeout for other tests + db_connection.timeout = original_timeout + +def test_timeout_affects_all_cursors(db_connection): + """Test that changing timeout on connection affects all new cursors""" + # Create a cursor with default timeout + cursor1 = db_connection.cursor() + + # Change the connection timeout + original_timeout = db_connection.timeout + db_connection.timeout = 10 + + # Create a new cursor + cursor2 = db_connection.cursor() + + try: + # Execute quick queries to ensure both cursors work + cursor1.execute("SELECT 1") + result1 = cursor1.fetchone() + assert result1[0] == 1, "Query with first cursor failed" + + cursor2.execute("SELECT 2") + result2 = cursor2.fetchone() + assert result2[0] == 2, "Query with second cursor failed" + # No direct way to check cursor timeout, but both should succeed # with the current timeout setting finally: diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 072a5ec6..0d2fc232 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -10,6 +10,7 @@ import pytest from datetime import datetime, date, time +import time as time_module import decimal from contextlib import closing import mssql_python @@ -7355,6 +7356,2603 @@ def test_decimal_separator_calculations(cursor, db_connection): cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_calc_test") db_connection.commit() +def test_lowercase_attribute(cursor, db_connection): + """Test that the lowercase attribute properly converts column names to lowercase""" + + # Store original value to restore after test + original_lowercase = mssql_python.lowercase + drop_cursor = None + + try: + # Create a test table with mixed-case column names + cursor.execute(""" + CREATE TABLE #pytest_lowercase_test ( + ID INT PRIMARY 
KEY, + UserName VARCHAR(50), + EMAIL_ADDRESS VARCHAR(100), + PhoneNumber VARCHAR(20) + ) + """) + db_connection.commit() + + # Insert test data + cursor.execute(""" + INSERT INTO #pytest_lowercase_test (ID, UserName, EMAIL_ADDRESS, PhoneNumber) + VALUES (1, 'JohnDoe', 'john@example.com', '555-1234') + """) + db_connection.commit() + + # First test with lowercase=False (default) + mssql_python.lowercase = False + cursor1 = db_connection.cursor() + cursor1.execute("SELECT * FROM #pytest_lowercase_test") + + # Description column names should preserve original case + column_names1 = [desc[0] for desc in cursor1.description] + assert "ID" in column_names1, "Column 'ID' should be present with original case" + assert "UserName" in column_names1, "Column 'UserName' should be present with original case" + + # Make sure to consume all results and close the cursor + cursor1.fetchall() + cursor1.close() + + # Now test with lowercase=True + mssql_python.lowercase = True + cursor2 = db_connection.cursor() + cursor2.execute("SELECT * FROM #pytest_lowercase_test") + + # Description column names should be lowercase + column_names2 = [desc[0] for desc in cursor2.description] + assert "id" in column_names2, "Column names should be lowercase when lowercase=True" + assert "username" in column_names2, "Column names should be lowercase when lowercase=True" + + # Make sure to consume all results and close the cursor + cursor2.fetchall() + cursor2.close() + + # Create a fresh cursor for cleanup + drop_cursor = db_connection.cursor() + + finally: + # Restore original value + mssql_python.lowercase = original_lowercase + + try: + # Use a separate cursor for cleanup + if drop_cursor: + drop_cursor.execute("DROP TABLE IF EXISTS #pytest_lowercase_test") + db_connection.commit() + drop_cursor.close() + except Exception as e: + print(f"Warning: Failed to drop test table: {e}") + +def test_decimal_separator_function(cursor, db_connection): + """Test decimal separator functionality with database operations""" + # Store original value to restore after test + original_separator = mssql_python.getDecimalSeparator() + + try: + # Create test table + cursor.execute(""" + CREATE TABLE #pytest_decimal_separator_test ( + id INT PRIMARY KEY, + decimal_value DECIMAL(10, 2) + ) + """) + db_connection.commit() + + # Insert test values with default separator (.) + test_value = decimal.Decimal('123.45') + cursor.execute(""" + INSERT INTO #pytest_decimal_separator_test (id, decimal_value) + VALUES (1, ?) + """, [test_value]) + db_connection.commit() + + # First test with default decimal separator (.) 
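For reference, the module-level separator switch that these decimal tests exercise can be driven directly. The following is a minimal sketch, assuming only that mssql_python is importable; it restores the default separator afterwards so other code and tests are unaffected.

import mssql_python

# Remember the current separator (the library defaults to '.').
original = mssql_python.getDecimalSeparator()
try:
    # Switching to a comma only changes how DECIMAL values are rendered in a
    # row's string representation; the underlying Decimal values and any
    # calculations on them are unaffected, as the tests below verify.
    mssql_python.setDecimalSeparator(',')
    assert mssql_python.getDecimalSeparator() == ','
finally:
    # Always restore the original separator.
    mssql_python.setDecimalSeparator(original)
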
+ cursor.execute("SELECT id, decimal_value FROM #pytest_decimal_separator_test") + row = cursor.fetchone() + default_str = str(row) + assert '123.45' in default_str, "Default separator not found in string representation" + + # Now change to comma separator and test string representation + mssql_python.setDecimalSeparator(',') + cursor.execute("SELECT id, decimal_value FROM #pytest_decimal_separator_test") + row = cursor.fetchone() + + # This should format the decimal with a comma in the string representation + comma_str = str(row) + assert '123,45' in comma_str, f"Expected comma in string representation but got: {comma_str}" + + finally: + # Restore original decimal separator + mssql_python.setDecimalSeparator(original_separator) + + # Cleanup + cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_separator_test") + db_connection.commit() + +def test_decimal_separator_basic_functionality(): + """Test basic decimal separator functionality without database operations""" + # Store original value to restore after test + original_separator = mssql_python.getDecimalSeparator() + + try: + # Test default value + assert mssql_python.getDecimalSeparator() == '.', "Default decimal separator should be '.'" + + # Test setting to comma + mssql_python.setDecimalSeparator(',') + assert mssql_python.getDecimalSeparator() == ',', "Decimal separator should be ',' after setting" + + # Test setting to other valid separators + mssql_python.setDecimalSeparator(':') + assert mssql_python.getDecimalSeparator() == ':', "Decimal separator should be ':' after setting" + + # Test invalid inputs + with pytest.raises(ValueError): + mssql_python.setDecimalSeparator('') # Empty string + + with pytest.raises(ValueError): + mssql_python.setDecimalSeparator('too_long') # More than one character + + with pytest.raises(ValueError): + mssql_python.setDecimalSeparator(123) # Not a string + + finally: + # Restore original separator + mssql_python.setDecimalSeparator(original_separator) + +def test_decimal_separator_with_multiple_values(cursor, db_connection): + """Test decimal separator with multiple different decimal values""" + original_separator = mssql_python.getDecimalSeparator() + + try: + # Create test table + cursor.execute(""" + CREATE TABLE #pytest_decimal_multi_test ( + id INT PRIMARY KEY, + positive_value DECIMAL(10, 2), + negative_value DECIMAL(10, 2), + zero_value DECIMAL(10, 2), + small_value DECIMAL(10, 4) + ) + """) + db_connection.commit() + + # Insert test data + cursor.execute(""" + INSERT INTO #pytest_decimal_multi_test VALUES (1, 123.45, -67.89, 0.00, 0.0001) + """) + db_connection.commit() + + # Test with default separator first + cursor.execute("SELECT * FROM #pytest_decimal_multi_test") + row = cursor.fetchone() + default_str = str(row) + assert '123.45' in default_str, "Default positive value formatting incorrect" + assert '-67.89' in default_str, "Default negative value formatting incorrect" + + # Change to comma separator + mssql_python.setDecimalSeparator(',') + cursor.execute("SELECT * FROM #pytest_decimal_multi_test") + row = cursor.fetchone() + comma_str = str(row) + + # Verify comma is used in all decimal values + assert '123,45' in comma_str, "Positive value not formatted with comma" + assert '-67,89' in comma_str, "Negative value not formatted with comma" + assert '0,00' in comma_str, "Zero value not formatted with comma" + assert '0,0001' in comma_str, "Small value not formatted with comma" + + finally: + # Restore original separator + mssql_python.setDecimalSeparator(original_separator) + + # 
Cleanup + cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_multi_test") + db_connection.commit() + +def test_decimal_separator_calculations(cursor, db_connection): + """Test that decimal separator doesn't affect calculations""" + original_separator = mssql_python.getDecimalSeparator() + + try: + # Create test table + cursor.execute(""" + CREATE TABLE #pytest_decimal_calc_test ( + id INT PRIMARY KEY, + value1 DECIMAL(10, 2), + value2 DECIMAL(10, 2) + ) + """) + db_connection.commit() + + # Insert test data + cursor.execute(""" + INSERT INTO #pytest_decimal_calc_test VALUES (1, 10.25, 5.75) + """) + db_connection.commit() + + # Test with default separator + cursor.execute("SELECT value1 + value2 AS sum_result FROM #pytest_decimal_calc_test") + row = cursor.fetchone() + assert row.sum_result == decimal.Decimal('16.00'), "Sum calculation incorrect with default separator" + + # Change to comma separator + mssql_python.setDecimalSeparator(',') + + # Calculations should still work correctly + cursor.execute("SELECT value1 + value2 AS sum_result FROM #pytest_decimal_calc_test") + row = cursor.fetchone() + assert row.sum_result == decimal.Decimal('16.00'), "Sum calculation affected by separator change" + + # But string representation should use comma + assert '16,00' in str(row), "Sum result not formatted with comma in string representation" + + finally: + # Restore original separator + mssql_python.setDecimalSeparator(original_separator) + + # Cleanup + cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_calc_test") + db_connection.commit() + +def test_cursor_setinputsizes_basic(db_connection): + """Test the basic functionality of setinputsizes""" + + cursor = db_connection.cursor() + + # Create a test table + cursor.execute("DROP TABLE IF EXISTS #test_inputsizes") + cursor.execute(""" + CREATE TABLE #test_inputsizes ( + string_col NVARCHAR(100), + int_col INT + ) + """) + + # Set input sizes for parameters + cursor.setinputsizes([ + (mssql_python.SQL_WVARCHAR, 100, 0), + (mssql_python.SQL_INTEGER, 0, 0) + ]) + + # Execute with parameters + cursor.execute( + "INSERT INTO #test_inputsizes VALUES (?, ?)", + "Test String", 42 + ) + + # Verify data was inserted correctly + cursor.execute("SELECT * FROM #test_inputsizes") + row = cursor.fetchone() + + assert row[0] == "Test String" + assert row[1] == 42 + + # Clean up + cursor.execute("DROP TABLE IF EXISTS #test_inputsizes") + +def test_cursor_setinputsizes_with_executemany_float(db_connection): + """Test setinputsizes with executemany using float instead of Decimal""" + + cursor = db_connection.cursor() + + # Create a test table + cursor.execute("DROP TABLE IF EXISTS #test_inputsizes_float") + cursor.execute(""" + CREATE TABLE #test_inputsizes_float ( + id INT, + name NVARCHAR(50), + price REAL /* Use REAL instead of DECIMAL */ + ) + """) + + # Prepare data with float values + data = [ + (1, "Item 1", 10.99), + (2, "Item 2", 20.50), + (3, "Item 3", 30.75) + ] + + # Set input sizes for parameters + cursor.setinputsizes([ + (mssql_python.SQL_INTEGER, 0, 0), + (mssql_python.SQL_WVARCHAR, 50, 0), + (mssql_python.SQL_REAL, 0, 0) + ]) + + # Execute with parameters + cursor.executemany( + "INSERT INTO #test_inputsizes_float VALUES (?, ?, ?)", + data + ) + + # Verify all data was inserted correctly + cursor.execute("SELECT * FROM #test_inputsizes_float ORDER BY id") + rows = cursor.fetchall() + + assert len(rows) == 3 + assert rows[0][0] == 1 + assert rows[0][1] == "Item 1" + assert abs(rows[0][2] - 10.99) < 0.001 + + # Clean up + cursor.execute("DROP 
TABLE IF EXISTS #test_inputsizes_float") + +def test_cursor_setinputsizes_reset(db_connection): + """Test that setinputsizes is reset after execution""" + + cursor = db_connection.cursor() + + # Create a test table + cursor.execute("DROP TABLE IF EXISTS #test_inputsizes_reset") + cursor.execute(""" + CREATE TABLE #test_inputsizes_reset ( + col1 NVARCHAR(100), + col2 INT + ) + """) + + # Set input sizes for parameters + cursor.setinputsizes([ + (mssql_python.SQL_WVARCHAR, 100, 0), + (mssql_python.SQL_INTEGER, 0, 0) + ]) + + # Execute with parameters + cursor.execute( + "INSERT INTO #test_inputsizes_reset VALUES (?, ?)", + "Test String", 42 + ) + + # Verify inputsizes was reset + assert cursor._inputsizes is None + + # Now execute again without setting input sizes + cursor.execute( + "INSERT INTO #test_inputsizes_reset VALUES (?, ?)", + "Another String", 84 + ) + + # Verify both rows were inserted correctly + cursor.execute("SELECT * FROM #test_inputsizes_reset ORDER BY col2") + rows = cursor.fetchall() + + assert len(rows) == 2 + assert rows[0][0] == "Test String" + assert rows[0][1] == 42 + assert rows[1][0] == "Another String" + assert rows[1][1] == 84 + + # Clean up + cursor.execute("DROP TABLE IF EXISTS #test_inputsizes_reset") + +def test_cursor_setinputsizes_override_inference(db_connection): + """Test that setinputsizes overrides type inference""" + + cursor = db_connection.cursor() + + # Create a test table with specific types + cursor.execute("DROP TABLE IF EXISTS #test_inputsizes_override") + cursor.execute(""" + CREATE TABLE #test_inputsizes_override ( + small_int SMALLINT, + big_text NVARCHAR(MAX) + ) + """) + + # Set input sizes that override the default inference + # For SMALLINT, use a valid precision value (5 is typical for SMALLINT) + cursor.setinputsizes([ + (mssql_python.SQL_SMALLINT, 5, 0), # Use valid precision for SMALLINT + (mssql_python.SQL_WVARCHAR, 8000, 0) # Force short string to NVARCHAR(MAX) + ]) + + # Test with values that would normally be inferred differently + big_number = 30000 # Would normally be INTEGER or BIGINT + short_text = "abc" # Would normally be a regular NVARCHAR + + try: + cursor.execute( + "INSERT INTO #test_inputsizes_override VALUES (?, ?)", + big_number, short_text + ) + + # Verify the row was inserted (may have been truncated by SQL Server) + cursor.execute("SELECT * FROM #test_inputsizes_override") + row = cursor.fetchone() + + # SQL Server would either truncate or round the value + assert row[1] == short_text + + except Exception as e: + # If an exception occurs, it should be related to the data type conversion + # Add "invalid precision" to the expected error messages + error_text = str(e).lower() + assert any(text in error_text for text in ["overflow", "out of range", "convert", "invalid precision", "precision value"]), \ + f"Unexpected error: {e}" + + # Clean up + cursor.execute("DROP TABLE IF EXISTS #test_inputsizes_override") + +def test_setinputsizes_parameter_count_mismatch_fewer(db_connection): + """Test setinputsizes with fewer sizes than parameters""" + import warnings + + cursor = db_connection.cursor() + + # Create a test table + cursor.execute("DROP TABLE IF EXISTS #test_inputsizes_mismatch") + cursor.execute(""" + CREATE TABLE #test_inputsizes_mismatch ( + col1 INT, + col2 NVARCHAR(100), + col3 FLOAT + ) + """) + + # Set fewer input sizes than parameters + cursor.setinputsizes([ + (mssql_python.SQL_INTEGER, 0, 0), + (mssql_python.SQL_WVARCHAR, 100, 0) + # Missing third parameter type + ]) + + # Execute with more parameters 
than specified input sizes + # This should use automatic type inference for the third parameter + with warnings.catch_warnings(record=True) as w: + cursor.execute( + "INSERT INTO #test_inputsizes_mismatch VALUES (?, ?, ?)", + 1, "Test String", 3.14 + ) + assert len(w) > 0, "Warning should be issued for parameter count mismatch" + assert "number of input sizes" in str(w[0].message).lower() + + # Verify data was inserted correctly + cursor.execute("SELECT * FROM #test_inputsizes_mismatch") + row = cursor.fetchone() + + assert row[0] == 1 + assert row[1] == "Test String" + assert abs(row[2] - 3.14) < 0.0001 + + # Clean up + cursor.execute("DROP TABLE IF EXISTS #test_inputsizes_mismatch") + +def test_setinputsizes_parameter_count_mismatch_more(db_connection): + """Test setinputsizes with more sizes than parameters""" + import warnings + + cursor = db_connection.cursor() + + # Create a test table + cursor.execute("DROP TABLE IF EXISTS #test_inputsizes_mismatch") + cursor.execute(""" + CREATE TABLE #test_inputsizes_mismatch ( + col1 INT, + col2 NVARCHAR(100) + ) + """) + + # Set more input sizes than parameters + cursor.setinputsizes([ + (mssql_python.SQL_INTEGER, 0, 0), + (mssql_python.SQL_WVARCHAR, 100, 0), + (mssql_python.SQL_FLOAT, 0, 0) # Extra parameter type + ]) + + # Execute with fewer parameters than specified input sizes + with warnings.catch_warnings(record=True) as w: + cursor.execute( + "INSERT INTO #test_inputsizes_mismatch VALUES (?, ?)", + 1, "Test String" + ) + assert len(w) > 0, "Warning should be issued for parameter count mismatch" + assert "number of input sizes" in str(w[0].message).lower() + + # Verify data was inserted correctly + cursor.execute("SELECT * FROM #test_inputsizes_mismatch") + row = cursor.fetchone() + + assert row[0] == 1 + assert row[1] == "Test String" + + # Clean up + cursor.execute("DROP TABLE IF EXISTS #test_inputsizes_mismatch") + +def test_setinputsizes_with_null_values(db_connection): + """Test setinputsizes with NULL values for various data types""" + + cursor = db_connection.cursor() + + # Create a test table with multiple data types + cursor.execute("DROP TABLE IF EXISTS #test_inputsizes_null") + cursor.execute(""" + CREATE TABLE #test_inputsizes_null ( + int_col INT, + string_col NVARCHAR(100), + float_col FLOAT, + date_col DATE, + binary_col VARBINARY(100) + ) + """) + + # Set input sizes for all columns + cursor.setinputsizes([ + (mssql_python.SQL_INTEGER, 0, 0), + (mssql_python.SQL_WVARCHAR, 100, 0), + (mssql_python.SQL_FLOAT, 0, 0), + (mssql_python.SQL_DATE, 0, 0), + (mssql_python.SQL_VARBINARY, 100, 0) + ]) + + # Insert row with all NULL values + cursor.execute( + "INSERT INTO #test_inputsizes_null VALUES (?, ?, ?, ?, ?)", + None, None, None, None, None + ) + + # Insert row with mix of NULL and non-NULL values + cursor.execute( + "INSERT INTO #test_inputsizes_null VALUES (?, ?, ?, ?, ?)", + 42, None, 3.14, None, b'binary data' + ) + + # Verify data was inserted correctly + cursor.execute("SELECT * FROM #test_inputsizes_null ORDER BY CASE WHEN int_col IS NULL THEN 0 ELSE 1 END") + rows = cursor.fetchall() + + # First row should be all NULLs + assert len(rows) == 2 + assert rows[0][0] is None + assert rows[0][1] is None + assert rows[0][2] is None + assert rows[0][3] is None + assert rows[0][4] is None + + # Second row should have mix of NULL and non-NULL + assert rows[1][0] == 42 + assert rows[1][1] is None + assert abs(rows[1][2] - 3.14) < 0.0001 + assert rows[1][3] is None + assert rows[1][4] == b'binary data' + + # Clean up + 
cursor.execute("DROP TABLE IF EXISTS #test_inputsizes_null") + +def test_setinputsizes_sql_injection_protection(db_connection): + """Test that setinputsizes doesn't allow SQL injection""" + cursor = db_connection.cursor() + + # Create a test table + cursor.execute("CREATE TABLE #test_sql_injection (id INT, name VARCHAR(100))") + + # Insert legitimate data + cursor.execute("INSERT INTO #test_sql_injection VALUES (1, 'safe')") + + # Set input sizes with potentially malicious SQL types and sizes + try: + # This should fail with a validation error + cursor.setinputsizes([(999999, 1000000, 1000000)]) # Invalid SQL type + except ValueError: + pass # Expected + + # Test with valid types but attempt SQL injection in parameter + cursor.setinputsizes([(mssql_python.SQL_VARCHAR, 100, 0)]) + injection_attempt = "x'; DROP TABLE #test_sql_injection; --" + + # This should safely parameterize without executing the injection + cursor.execute("SELECT * FROM #test_sql_injection WHERE name = ?", injection_attempt) + + # Verify table still exists and injection didn't work + cursor.execute("SELECT COUNT(*) FROM #test_sql_injection") + count = cursor.fetchone()[0] + assert count == 1, "SQL injection protection failed" + + # Clean up + cursor.execute("DROP TABLE #test_sql_injection") + +def test_gettypeinfo_all_types(cursor): + """Test getTypeInfo with no arguments returns all data types""" + # Get all type information + type_info = cursor.getTypeInfo().fetchall() + + # Verify we got results + assert type_info is not None, "getTypeInfo() should return results" + assert len(type_info) > 0, "getTypeInfo() should return at least one data type" + + # Verify common data types are present + type_names = [str(row.type_name).upper() for row in type_info] + assert any('VARCHAR' in name for name in type_names), "VARCHAR type should be in results" + assert any('INT' in name for name in type_names), "INTEGER type should be in results" + + # Verify first row has expected columns + first_row = type_info[0] + assert hasattr(first_row, 'type_name'), "Result should have type_name column" + assert hasattr(first_row, 'data_type'), "Result should have data_type column" + assert hasattr(first_row, 'column_size'), "Result should have column_size column" + assert hasattr(first_row, 'nullable'), "Result should have nullable column" + +def test_gettypeinfo_specific_type(cursor): + """Test getTypeInfo with specific type argument""" + from mssql_python.constants import ConstantsDDBC + + # Test with VARCHAR type (SQL_VARCHAR) + varchar_info = cursor.getTypeInfo(ConstantsDDBC.SQL_VARCHAR.value).fetchall() + + # Verify we got results specific to VARCHAR + assert varchar_info is not None, "getTypeInfo(SQL_VARCHAR) should return results" + assert len(varchar_info) > 0, "getTypeInfo(SQL_VARCHAR) should return at least one row" + + # All rows should be related to VARCHAR type + for row in varchar_info: + assert 'varchar' in row.type_name or 'char' in row.type_name, \ + f"Expected VARCHAR type, got {row.type_name}" + assert row.data_type == ConstantsDDBC.SQL_VARCHAR.value, \ + f"Expected data_type={ConstantsDDBC.SQL_VARCHAR.value}, got {row.data_type}" + +def test_gettypeinfo_result_structure(cursor): + """Test the structure of getTypeInfo result rows""" + # Get info for a common type like INTEGER + from mssql_python.constants import ConstantsDDBC + + int_info = cursor.getTypeInfo(ConstantsDDBC.SQL_INTEGER.value).fetchall() + + # Make sure we have at least one result + assert len(int_info) > 0, "getTypeInfo for INTEGER should return results" + + # 
Check for all required columns in the result + first_row = int_info[0] + required_columns = [ + 'type_name', 'data_type', 'column_size', 'literal_prefix', + 'literal_suffix', 'create_params', 'nullable', 'case_sensitive', + 'searchable', 'unsigned_attribute', 'fixed_prec_scale', + 'auto_unique_value', 'local_type_name', 'minimum_scale', + 'maximum_scale', 'sql_data_type', 'sql_datetime_sub', + 'num_prec_radix', 'interval_precision' + ] + + for column in required_columns: + assert hasattr(first_row, column), f"Result missing required column: {column}" + +def test_gettypeinfo_numeric_type(cursor): + """Test getTypeInfo for numeric data types""" + from mssql_python.constants import ConstantsDDBC + + # Get information about DECIMAL type + decimal_info = cursor.getTypeInfo(ConstantsDDBC.SQL_DECIMAL.value).fetchall() + + # Verify decimal-specific attributes + assert len(decimal_info) > 0, "getTypeInfo for DECIMAL should return results" + + decimal_row = decimal_info[0] + # DECIMAL should have precision and scale parameters + assert decimal_row.create_params is not None, "DECIMAL should have create_params" + assert "PRECISION" in decimal_row.create_params.upper() or \ + "SCALE" in decimal_row.create_params.upper(), \ + "DECIMAL create_params should mention precision/scale" + + # Numeric types typically use base 10 for the num_prec_radix + assert decimal_row.num_prec_radix == 10, \ + f"Expected num_prec_radix=10 for DECIMAL, got {decimal_row.num_prec_radix}" + +def test_gettypeinfo_datetime_types(cursor): + """Test getTypeInfo for datetime types""" + from mssql_python.constants import ConstantsDDBC + + # Get information about TIMESTAMP type instead of DATETIME + # SQL_TYPE_TIMESTAMP (93) is more commonly used for datetime in ODBC + datetime_info = cursor.getTypeInfo(ConstantsDDBC.SQL_TYPE_TIMESTAMP.value).fetchall() + + # Verify we got datetime-related results + assert len(datetime_info) > 0, "getTypeInfo for TIMESTAMP should return results" + + # Check for datetime-specific attributes + first_row = datetime_info[0] + assert hasattr(first_row, 'type_name'), "Result should have type_name column" + + # Datetime type names often contain 'date', 'time', or 'datetime' + type_name_lower = first_row.type_name.lower() + assert any(term in type_name_lower for term in ['date', 'time', 'timestamp', 'datetime']), \ + f"Expected datetime-related type name, got {first_row.type_name}" + +def test_gettypeinfo_multiple_calls(cursor): + """Test calling getTypeInfo multiple times in succession""" + from mssql_python.constants import ConstantsDDBC + + # First call - get all types + all_types = cursor.getTypeInfo().fetchall() + assert len(all_types) > 0, "First call to getTypeInfo should return results" + + # Second call - get VARCHAR type + varchar_info = cursor.getTypeInfo(ConstantsDDBC.SQL_VARCHAR.value).fetchall() + assert len(varchar_info) > 0, "Second call to getTypeInfo should return results" + + # Third call - get INTEGER type + int_info = cursor.getTypeInfo(ConstantsDDBC.SQL_INTEGER.value).fetchall() + assert len(int_info) > 0, "Third call to getTypeInfo should return results" + + # Verify the results are different between calls + assert len(all_types) > len(varchar_info), "All types should return more rows than specific type" + +def test_gettypeinfo_binary_types(cursor): + """Test getTypeInfo for binary data types""" + from mssql_python.constants import ConstantsDDBC + + # Get information about BINARY or VARBINARY type + binary_info = cursor.getTypeInfo(ConstantsDDBC.SQL_BINARY.value).fetchall() + + # Verify 
we got binary-related results + assert len(binary_info) > 0, "getTypeInfo for BINARY should return results" + + # Check for binary-specific attributes + for row in binary_info: + type_name_lower = row.type_name.lower() + # Include 'timestamp' as SQL Server reports it as a binary type + assert any(term in type_name_lower for term in ['binary', 'blob', 'image', 'timestamp']), \ + f"Expected binary-related type name, got {row.type_name}" + + # Binary types typically don't support case sensitivity + assert row.case_sensitive == 0, f"Binary types should not be case sensitive, got {row.case_sensitive}" + +def test_gettypeinfo_cached_results(cursor): + """Test that multiple identical calls to getTypeInfo are efficient""" + from mssql_python.constants import ConstantsDDBC + import time + + # First call - might be slower + start_time = time.time() + first_result = cursor.getTypeInfo(ConstantsDDBC.SQL_VARCHAR.value).fetchall() + first_duration = time.time() - start_time + + # Give the system a moment + time.sleep(0.1) + + # Second call with same type - should be similar or faster + start_time = time.time() + second_result = cursor.getTypeInfo(ConstantsDDBC.SQL_VARCHAR.value).fetchall() + second_duration = time.time() - start_time + + # Results should be consistent + assert len(first_result) == len(second_result), "Multiple calls should return same number of results" + + # Both calls should return the correct type info + for row in second_result: + assert row.data_type == ConstantsDDBC.SQL_VARCHAR.value, \ + f"Expected SQL_VARCHAR type, got {row.data_type}" + +def test_procedures_setup(cursor, db_connection): + """Create a test schema and procedures for testing""" + try: + # Create a test schema for isolation + cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_proc_schema') EXEC('CREATE SCHEMA pytest_proc_schema')") + + # Create test stored procedures + cursor.execute(""" + CREATE OR ALTER PROCEDURE pytest_proc_schema.test_proc1 + AS + BEGIN + SELECT 1 AS result + END + """) + + cursor.execute(""" + CREATE OR ALTER PROCEDURE pytest_proc_schema.test_proc2 + @param1 INT, + @param2 VARCHAR(50) OUTPUT + AS + BEGIN + SELECT @param2 = 'Output ' + CAST(@param1 AS VARCHAR(10)) + RETURN @param1 + END + """) + + db_connection.commit() + except Exception as e: + pytest.fail(f"Test setup failed: {e}") + +def test_procedures_all(cursor, db_connection): + """Test getting information about all procedures""" + # First set up our test procedures + test_procedures_setup(cursor, db_connection) + + try: + # Get all procedures + procs = cursor.procedures().fetchall() + + # Verify we got results + assert procs is not None, "procedures() should return results" + assert len(procs) > 0, "procedures() should return at least one procedure" + + # Verify structure of results + first_row = procs[0] + assert hasattr(first_row, 'procedure_cat'), "Result should have procedure_cat column" + assert hasattr(first_row, 'procedure_schem'), "Result should have procedure_schem column" + assert hasattr(first_row, 'procedure_name'), "Result should have procedure_name column" + assert hasattr(first_row, 'num_input_params'), "Result should have num_input_params column" + assert hasattr(first_row, 'num_output_params'), "Result should have num_output_params column" + assert hasattr(first_row, 'num_result_sets'), "Result should have num_result_sets column" + assert hasattr(first_row, 'remarks'), "Result should have remarks column" + assert hasattr(first_row, 'procedure_type'), "Result should have procedure_type column" 
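As a usage reference for the cursor.procedures() helper asserted on above, here is a minimal sketch. The connection string and the 'dbo' schema are placeholders, and only calls that already appear in this change set (connect, cursor, procedures, fetchall, close) are used.

from mssql_python import connect

# Placeholder connection string; substitute real server/database details.
conn = connect("Server=localhost;Database=master;Trusted_Connection=yes;")
cursor = conn.cursor()

# Each returned row exposes procedure_cat, procedure_schem, procedure_name,
# num_input_params, num_output_params, num_result_sets, remarks and
# procedure_type, matching the columns asserted in these tests.
for proc in cursor.procedures(schema='dbo').fetchall():
    print(proc.procedure_schem, proc.procedure_name)

conn.close()
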
+ + finally: + # Clean up happens in test_procedures_cleanup + pass + +def test_procedures_specific(cursor, db_connection): + """Test getting information about a specific procedure""" + try: + # Get specific procedure + procs = cursor.procedures(procedure='test_proc1', schema='pytest_proc_schema').fetchall() + + # Verify we got the correct procedure + assert len(procs) == 1, "Should find exactly one procedure" + proc = procs[0] + assert proc.procedure_name == 'test_proc1;1', "Wrong procedure name returned" + assert proc.procedure_schem == 'pytest_proc_schema', "Wrong schema returned" + + finally: + # Clean up happens in test_procedures_cleanup + pass + +def test_procedures_with_schema(cursor, db_connection): + """Test getting procedures with schema filter""" + try: + # Get procedures for our test schema + procs = cursor.procedures(schema='pytest_proc_schema').fetchall() + + # Verify schema filter worked + assert len(procs) >= 2, "Should find at least two procedures in schema" + for proc in procs: + assert proc.procedure_schem == 'pytest_proc_schema', f"Expected schema pytest_proc_schema, got {proc.procedure_schem}" + + # Verify our specific procedures are in the results + proc_names = [p.procedure_name for p in procs] + assert 'test_proc1;1' in proc_names, "test_proc1;1 should be in results" + assert 'test_proc2;1' in proc_names, "test_proc2;1 should be in results" + + finally: + # Clean up happens in test_procedures_cleanup + pass + +def test_procedures_nonexistent(cursor): + """Test procedures() with non-existent procedure name""" + # Use a procedure name that's highly unlikely to exist + procs = cursor.procedures(procedure='nonexistent_procedure_xyz123').fetchall() + + # Should return empty list, not error + assert isinstance(procs, list), "Should return a list for non-existent procedure" + assert len(procs) == 0, "Should return empty list for non-existent procedure" + +def test_procedures_catalog_filter(cursor, db_connection): + """Test procedures() with catalog filter""" + # Get current database name + cursor.execute("SELECT DB_NAME() AS current_db") + current_db = cursor.fetchone().current_db + + try: + # Get procedures with current catalog + procs = cursor.procedures(catalog=current_db, schema='pytest_proc_schema').fetchall() + + # Verify catalog filter worked + assert len(procs) >= 2, "Should find procedures in current catalog" + for proc in procs: + assert proc.procedure_cat == current_db, f"Expected catalog {current_db}, got {proc.procedure_cat}" + + # Get procedures with non-existent catalog + fake_procs = cursor.procedures(catalog='nonexistent_db_xyz123').fetchall() + assert len(fake_procs) == 0, "Should return empty list for non-existent catalog" + + finally: + # Clean up happens in test_procedures_cleanup + pass + +def test_procedures_with_parameters(cursor, db_connection): + """Test that procedures() correctly reports parameter information""" + try: + # Create a simpler procedure with basic parameters + cursor.execute(""" + CREATE OR ALTER PROCEDURE pytest_proc_schema.test_params_proc + @in1 INT, + @in2 VARCHAR(50) + AS + BEGIN + SELECT @in1 AS value1, @in2 AS value2 + END + """) + db_connection.commit() + + # Get procedure info + procs = cursor.procedures(procedure='test_params_proc', schema='pytest_proc_schema').fetchall() + + # Verify we found the procedure + assert len(procs) == 1, "Should find exactly one procedure" + proc = procs[0] + + # Just check if columns exist, don't check specific values + assert hasattr(proc, 'num_input_params'), "Result should have 
num_input_params column" + assert hasattr(proc, 'num_output_params'), "Result should have num_output_params column" + + # Test simple execution without output parameters + cursor.execute("EXEC pytest_proc_schema.test_params_proc 10, 'Test'") + + # Verify the procedure returned expected values + row = cursor.fetchone() + assert row is not None, "Procedure should return results" + assert row[0] == 10, "First parameter value incorrect" + assert row[1] == 'Test', "Second parameter value incorrect" + + finally: + cursor.execute("DROP PROCEDURE IF EXISTS pytest_proc_schema.test_params_proc") + db_connection.commit() + +def test_procedures_result_set_info(cursor, db_connection): + """Test that procedures() reports information about result sets""" + try: + # Create procedures with different result set patterns + cursor.execute(""" + CREATE OR ALTER PROCEDURE pytest_proc_schema.test_no_results + AS + BEGIN + DECLARE @x INT = 1 + END + """) + + cursor.execute(""" + CREATE OR ALTER PROCEDURE pytest_proc_schema.test_one_result + AS + BEGIN + SELECT 1 AS col1, 'test' AS col2 + END + """) + + cursor.execute(""" + CREATE OR ALTER PROCEDURE pytest_proc_schema.test_multiple_results + AS + BEGIN + SELECT 1 AS result1 + SELECT 'test' AS result2 + SELECT GETDATE() AS result3 + END + """) + db_connection.commit() + + # Get procedure info for all test procedures + procs = cursor.procedures(schema='pytest_proc_schema', procedure='test_%').fetchall() + + # Verify we found at least some procedures + assert len(procs) > 0, "Should find at least some test procedures" + + # Get the procedure names we found + result_proc_names = [p.procedure_name for p in procs + if p.procedure_name.startswith('test_') and 'results' in p.procedure_name] + print(f"Found result procedures: {result_proc_names}") + + # The num_result_sets column exists but might not have correct values + for proc in procs: + assert hasattr(proc, 'num_result_sets'), "Result should have num_result_sets column" + + # Test execution of the procedures to verify they work + cursor.execute("EXEC pytest_proc_schema.test_no_results") + assert cursor.fetchall() == [], "test_no_results should return no results" + + cursor.execute("EXEC pytest_proc_schema.test_one_result") + rows = cursor.fetchall() + assert len(rows) == 1, "test_one_result should return one row" + assert len(rows[0]) == 2, "test_one_result row should have two columns" + + cursor.execute("EXEC pytest_proc_schema.test_multiple_results") + rows1 = cursor.fetchall() + assert len(rows1) == 1, "First result set should have one row" + assert cursor.nextset(), "Should have a second result set" + rows2 = cursor.fetchall() + assert len(rows2) == 1, "Second result set should have one row" + assert cursor.nextset(), "Should have a third result set" + rows3 = cursor.fetchall() + assert len(rows3) == 1, "Third result set should have one row" + + finally: + cursor.execute("DROP PROCEDURE IF EXISTS pytest_proc_schema.test_no_results") + cursor.execute("DROP PROCEDURE IF EXISTS pytest_proc_schema.test_one_result") + cursor.execute("DROP PROCEDURE IF EXISTS pytest_proc_schema.test_multiple_results") + db_connection.commit() + +def test_procedures_cleanup(cursor, db_connection): + """Clean up all test procedures and schema after testing""" + try: + # Drop all test procedures + cursor.execute("DROP PROCEDURE IF EXISTS pytest_proc_schema.test_proc1") + cursor.execute("DROP PROCEDURE IF EXISTS pytest_proc_schema.test_proc2") + cursor.execute("DROP PROCEDURE IF EXISTS pytest_proc_schema.test_params_proc") + 
cursor.execute("DROP PROCEDURE IF EXISTS pytest_proc_schema.test_no_results") + cursor.execute("DROP PROCEDURE IF EXISTS pytest_proc_schema.test_one_result") + cursor.execute("DROP PROCEDURE IF EXISTS pytest_proc_schema.test_multiple_results") + + # Drop the test schema + cursor.execute("DROP SCHEMA IF EXISTS pytest_proc_schema") + db_connection.commit() + except Exception as e: + pytest.fail(f"Test cleanup failed: {e}") + +def test_foreignkeys_setup(cursor, db_connection): + """Create tables with foreign key relationships for testing""" + try: + # Create a test schema for isolation + cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_fk_schema') EXEC('CREATE SCHEMA pytest_fk_schema')") + + # Drop tables if they exist (in reverse order to avoid constraint conflicts) + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + + # Create parent table + cursor.execute(""" + CREATE TABLE pytest_fk_schema.customers ( + customer_id INT PRIMARY KEY, + customer_name VARCHAR(100) NOT NULL + ) + """) + + # Create child table with foreign key + cursor.execute(""" + CREATE TABLE pytest_fk_schema.orders ( + order_id INT PRIMARY KEY, + order_date DATETIME NOT NULL, + customer_id INT NOT NULL, + total_amount DECIMAL(10, 2) NOT NULL, + CONSTRAINT FK_Orders_Customers FOREIGN KEY (customer_id) + REFERENCES pytest_fk_schema.customers (customer_id) + ) + """) + + # Insert test data + cursor.execute(""" + INSERT INTO pytest_fk_schema.customers (customer_id, customer_name) + VALUES (1, 'Test Customer 1'), (2, 'Test Customer 2') + """) + + cursor.execute(""" + INSERT INTO pytest_fk_schema.orders (order_id, order_date, customer_id, total_amount) + VALUES (101, GETDATE(), 1, 150.00), (102, GETDATE(), 2, 250.50) + """) + + db_connection.commit() + except Exception as e: + pytest.fail(f"Test setup failed: {e}") + +def test_foreignkeys_all(cursor, db_connection): + """Test getting all foreign keys""" + try: + # First set up our test tables + test_foreignkeys_setup(cursor, db_connection) + + # Get all foreign keys + fks = cursor.foreignKeys(table='orders', schema='pytest_fk_schema').fetchall() + + # Verify we got results + assert fks is not None, "foreignKeys() should return results" + assert len(fks) > 0, "foreignKeys() should return at least one foreign key" + + # Verify our test FK is in the results + # Search case-insensitively since the database might return different case + found_test_fk = False + for fk in fks: + if (fk.fktable_name.lower() == 'orders' and + fk.pktable_name.lower() == 'customers'): + found_test_fk = True + break + + assert found_test_fk, "Could not find the test foreign key in results" + + finally: + # Clean up + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + db_connection.commit() + +def test_foreignkeys_specific_table(cursor, db_connection): + """Test getting foreign keys for a specific table""" + try: + # First set up our test tables + test_foreignkeys_setup(cursor, db_connection) + + # Get foreign keys for the orders table + fks = cursor.foreignKeys(table='orders', schema='pytest_fk_schema').fetchall() + + # Verify we got results + assert len(fks) == 1, "Should find exactly one foreign key for orders table" + + # Verify the foreign key details + fk = fks[0] + assert fk.fktable_name.lower() == 'orders', "Wrong foreign key table name" + assert fk.pktable_name.lower() == 'customers', "Wrong primary key table 
name" + assert fk.fkcolumn_name.lower() == 'customer_id', "Wrong foreign key column name" + assert fk.pkcolumn_name.lower() == 'customer_id', "Wrong primary key column name" + + finally: + # Clean up + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + db_connection.commit() + +def test_foreignkeys_specific_foreign_table(cursor, db_connection): + """Test getting foreign keys that reference a specific table""" + try: + # First set up our test tables + test_foreignkeys_setup(cursor, db_connection) + + # Get foreign keys that reference the customers table + fks = cursor.foreignKeys(foreignTable='customers', foreignSchema='pytest_fk_schema').fetchall() + + # Verify we got results + assert len(fks) > 0, "Should find at least one foreign key referencing customers table" + + # Verify our test FK is in the results + found_test_fk = False + for fk in fks: + if (fk.fktable_name.lower() == 'orders' and + fk.pktable_name.lower() == 'customers'): + found_test_fk = True + break + + assert found_test_fk, "Could not find the test foreign key in results" + + finally: + # Clean up + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + db_connection.commit() + +def test_foreignkeys_both_tables(cursor, db_connection): + """Test getting foreign keys with both table and foreignTable specified""" + try: + # First set up our test tables + test_foreignkeys_setup(cursor, db_connection) + + # Get foreign keys between the two tables + fks = cursor.foreignKeys( + table='orders', schema='pytest_fk_schema', + foreignTable='customers', foreignSchema='pytest_fk_schema' + ).fetchall() + + # Verify we got results + assert len(fks) == 1, "Should find exactly one foreign key between specified tables" + + # Verify the foreign key details + fk = fks[0] + assert fk.fktable_name.lower() == 'orders', "Wrong foreign key table name" + assert fk.pktable_name.lower() == 'customers', "Wrong primary key table name" + assert fk.fkcolumn_name.lower() == 'customer_id', "Wrong foreign key column name" + assert fk.pkcolumn_name.lower() == 'customer_id', "Wrong primary key column name" + + finally: + # Clean up + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + db_connection.commit() + +def test_foreignkeys_nonexistent(cursor): + """Test foreignKeys() with non-existent table name""" + # Use a table name that's highly unlikely to exist + fks = cursor.foreignKeys(table='nonexistent_table_xyz123').fetchall() + + # Should return empty list, not error + assert isinstance(fks, list), "Should return a list for non-existent table" + assert len(fks) == 0, "Should return empty list for non-existent table" + +def test_foreignkeys_catalog_schema(cursor, db_connection): + """Test foreignKeys() with catalog and schema filters""" + try: + # First set up our test tables + test_foreignkeys_setup(cursor, db_connection) + + # Get current database name + cursor.execute("SELECT DB_NAME() AS current_db") + row = cursor.fetchone() + current_db = row.current_db + + # Get foreign keys with current catalog and pytest schema + fks = cursor.foreignKeys( + table='orders', + catalog=current_db, + schema='pytest_fk_schema' + ).fetchall() + + # Verify we got results + assert len(fks) > 0, "Should find foreign keys with correct catalog/schema" + + # Verify catalog/schema in results + for fk in fks: + assert fk.fktable_cat == 
current_db, "Wrong foreign key table catalog" + assert fk.fktable_schem == 'pytest_fk_schema', "Wrong foreign key table schema" + + finally: + # Clean up + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + db_connection.commit() + +def test_foreignkeys_result_structure(cursor, db_connection): + """Test the structure of foreignKeys result rows""" + try: + # First set up our test tables + test_foreignkeys_setup(cursor, db_connection) + + # Get foreign keys for the orders table + fks = cursor.foreignKeys(table='orders', schema='pytest_fk_schema').fetchall() + + # Verify we got results + assert len(fks) > 0, "Should find at least one foreign key" + + # Check for all required columns in the result + first_row = fks[0] + required_columns = [ + 'pktable_cat', 'pktable_schem', 'pktable_name', 'pkcolumn_name', + 'fktable_cat', 'fktable_schem', 'fktable_name', 'fkcolumn_name', + 'key_seq', 'update_rule', 'delete_rule', 'fk_name', 'pk_name', + 'deferrability' + ] + + for column in required_columns: + assert hasattr(first_row, column), f"Result missing required column: {column}" + + # Verify specific values + assert first_row.fktable_name.lower() == 'orders', "Wrong foreign key table name" + assert first_row.pktable_name.lower() == 'customers', "Wrong primary key table name" + assert first_row.fkcolumn_name.lower() == 'customer_id', "Wrong foreign key column name" + assert first_row.pkcolumn_name.lower() == 'customer_id', "Wrong primary key column name" + assert first_row.key_seq == 1, "Wrong key sequence number" + assert first_row.fk_name is not None, "Foreign key name should not be None" + assert first_row.pk_name is not None, "Primary key name should not be None" + + finally: + # Clean up + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + db_connection.commit() + +def test_foreignkeys_multiple_column_fk(cursor, db_connection): + """Test foreignKeys() with a multi-column foreign key""" + try: + # First create the schema if needed + cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_fk_schema') EXEC('CREATE SCHEMA pytest_fk_schema')") + + # Drop tables if they exist (in reverse order to avoid constraint conflicts) + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.order_details") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.product_variants") + + # Create parent table with composite primary key + cursor.execute(""" + CREATE TABLE pytest_fk_schema.product_variants ( + product_id INT NOT NULL, + variant_id INT NOT NULL, + variant_name VARCHAR(100) NOT NULL, + PRIMARY KEY (product_id, variant_id) + ) + """) + + # Create child table with composite foreign key + cursor.execute(""" + CREATE TABLE pytest_fk_schema.order_details ( + order_id INT NOT NULL, + product_id INT NOT NULL, + variant_id INT NOT NULL, + quantity INT NOT NULL, + PRIMARY KEY (order_id, product_id, variant_id), + CONSTRAINT FK_OrderDetails_ProductVariants FOREIGN KEY (product_id, variant_id) + REFERENCES pytest_fk_schema.product_variants (product_id, variant_id) + ) + """) + + db_connection.commit() + + # Get foreign keys for the order_details table + fks = cursor.foreignKeys(table='order_details', schema='pytest_fk_schema').fetchall() + + # Verify we got results + assert len(fks) == 2, "Should find two rows for the composite foreign key (one per column)" + + # Group by key_seq to verify both columns + fk_columns = {} + for fk in 
fks: + fk_columns[fk.key_seq] = { + 'pkcolumn': fk.pkcolumn_name.lower(), + 'fkcolumn': fk.fkcolumn_name.lower() + } + + # Verify both columns are present + assert 1 in fk_columns, "First column of composite key missing" + assert 2 in fk_columns, "Second column of composite key missing" + + # Verify column mappings + assert fk_columns[1]['pkcolumn'] == 'product_id', "Wrong primary key column 1" + assert fk_columns[1]['fkcolumn'] == 'product_id', "Wrong foreign key column 1" + assert fk_columns[2]['pkcolumn'] == 'variant_id', "Wrong primary key column 2" + assert fk_columns[2]['fkcolumn'] == 'variant_id', "Wrong foreign key column 2" + + finally: + # Clean up + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.order_details") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.product_variants") + db_connection.commit() + +def test_cleanup_schema(cursor, db_connection): + """Clean up the test schema after all tests""" + try: + # Make sure no tables remain + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.orders") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.customers") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.order_details") + cursor.execute("DROP TABLE IF EXISTS pytest_fk_schema.product_variants") + db_connection.commit() + + # Drop the schema + cursor.execute("DROP SCHEMA IF EXISTS pytest_fk_schema") + db_connection.commit() + except Exception as e: + pytest.fail(f"Schema cleanup failed: {e}") + +def test_primarykeys_setup(cursor, db_connection): + """Create tables with primary keys for testing""" + try: + # Create a test schema for isolation + cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_pk_schema') EXEC('CREATE SCHEMA pytest_pk_schema')") + + # Drop tables if they exist + cursor.execute("DROP TABLE IF EXISTS pytest_pk_schema.single_pk_test") + cursor.execute("DROP TABLE IF EXISTS pytest_pk_schema.composite_pk_test") + + # Create table with simple primary key + cursor.execute(""" + CREATE TABLE pytest_pk_schema.single_pk_test ( + id INT PRIMARY KEY, + name VARCHAR(100) NOT NULL, + description VARCHAR(200) NULL + ) + """) + + # Create table with composite primary key + cursor.execute(""" + CREATE TABLE pytest_pk_schema.composite_pk_test ( + dept_id INT NOT NULL, + emp_id INT NOT NULL, + hire_date DATE NOT NULL, + CONSTRAINT PK_composite_test PRIMARY KEY (dept_id, emp_id) + ) + """) + + db_connection.commit() + except Exception as e: + pytest.fail(f"Test setup failed: {e}") + +def test_primarykeys_simple(cursor, db_connection): + """Test primaryKeys returns information about a simple primary key""" + try: + # First set up our test tables + test_primarykeys_setup(cursor, db_connection) + + # Get primary key information + pks = cursor.primaryKeys('single_pk_test', schema='pytest_pk_schema').fetchall() + + # Verify we got results + assert len(pks) == 1, "Should find exactly one primary key column" + pk = pks[0] + + # Verify primary key details + assert pk.table_name.lower() == 'single_pk_test', "Wrong table name" + assert pk.column_name.lower() == 'id', "Wrong primary key column name" + assert pk.key_seq == 1, "Wrong key sequence number" + assert pk.pk_name is not None, "Primary key name should not be None" + + finally: + # Clean up happens in test_primarykeys_cleanup + pass + +def test_primarykeys_composite(cursor, db_connection): + """Test primaryKeys with a composite primary key""" + try: + # Get primary key information + pks = cursor.primaryKeys('composite_pk_test', schema='pytest_pk_schema').fetchall() + + # Verify 
we got results for both columns + assert len(pks) == 2, "Should find two primary key columns" + + # Sort by key_seq to ensure consistent order + pks = sorted(pks, key=lambda row: row.key_seq) + + # Verify first column + assert pks[0].table_name.lower() == 'composite_pk_test', "Wrong table name" + assert pks[0].column_name.lower() == 'dept_id', "Wrong first primary key column name" + assert pks[0].key_seq == 1, "Wrong key sequence number for first column" + + # Verify second column + assert pks[1].table_name.lower() == 'composite_pk_test', "Wrong table name" + assert pks[1].column_name.lower() == 'emp_id', "Wrong second primary key column name" + assert pks[1].key_seq == 2, "Wrong key sequence number for second column" + + # Both should have the same PK name + assert pks[0].pk_name == pks[1].pk_name, "Both columns should have the same primary key name" + + finally: + # Clean up happens in test_primarykeys_cleanup + pass + +def test_primarykeys_column_info(cursor, db_connection): + """Test that primaryKeys returns correct column information""" + try: + # Get primary key information + pks = cursor.primaryKeys('single_pk_test', schema='pytest_pk_schema').fetchall() + + # Verify column information + assert len(pks) == 1, "Should find exactly one primary key column" + pk = pks[0] + + # Verify expected columns are present + assert hasattr(pk, 'table_cat'), "Result should have table_cat column" + assert hasattr(pk, 'table_schem'), "Result should have table_schem column" + assert hasattr(pk, 'table_name'), "Result should have table_name column" + assert hasattr(pk, 'column_name'), "Result should have column_name column" + assert hasattr(pk, 'key_seq'), "Result should have key_seq column" + assert hasattr(pk, 'pk_name'), "Result should have pk_name column" + + # Verify values are correct + assert pk.table_schem.lower() == 'pytest_pk_schema', "Wrong schema name" + assert pk.table_name.lower() == 'single_pk_test', "Wrong table name" + assert pk.column_name.lower() == 'id', "Wrong column name" + assert isinstance(pk.key_seq, int), "key_seq should be an integer" + + finally: + # Clean up happens in test_primarykeys_cleanup + pass + +def test_primarykeys_nonexistent(cursor): + """Test primaryKeys() with non-existent table name""" + # Use a table name that's highly unlikely to exist + pks = cursor.primaryKeys('nonexistent_table_xyz123').fetchall() + + # Should return empty list, not error + assert isinstance(pks, list), "Should return a list for non-existent table" + assert len(pks) == 0, "Should return empty list for non-existent table" + +def test_primarykeys_catalog_filter(cursor, db_connection): + """Test primaryKeys() with catalog filter""" + try: + # Get current database name + cursor.execute("SELECT DB_NAME() AS current_db") + current_db = cursor.fetchone().current_db + + # Get primary keys with current catalog + pks = cursor.primaryKeys('single_pk_test', catalog=current_db, schema='pytest_pk_schema').fetchall() + + # Verify catalog filter worked + assert len(pks) == 1, "Should find exactly one primary key column" + pk = pks[0] + assert pk.table_cat == current_db, f"Expected catalog {current_db}, got {pk.table_cat}" + + # Get primary keys with non-existent catalog + fake_pks = cursor.primaryKeys('single_pk_test', catalog='nonexistent_db_xyz123').fetchall() + assert len(fake_pks) == 0, "Should return empty list for non-existent catalog" + + finally: + # Clean up happens in test_primarykeys_cleanup + pass + +def test_primarykeys_cleanup(cursor, db_connection): + """Clean up test tables after 
testing""" + try: + # Drop all test tables + cursor.execute("DROP TABLE IF EXISTS pytest_pk_schema.single_pk_test") + cursor.execute("DROP TABLE IF EXISTS pytest_pk_schema.composite_pk_test") + + # Drop the test schema + cursor.execute("DROP SCHEMA IF EXISTS pytest_pk_schema") + db_connection.commit() + except Exception as e: + pytest.fail(f"Test cleanup failed: {e}") + +def test_specialcolumns_setup(cursor, db_connection): + """Create test tables for testing rowIdColumns and rowVerColumns""" + try: + # Create a test schema for isolation + cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_special_schema') EXEC('CREATE SCHEMA pytest_special_schema')") + + # Drop tables if they exist + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.rowid_test") + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.timestamp_test") + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.multiple_unique_test") + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.identity_test") + + # Create table with primary key (for rowIdColumns) + cursor.execute(""" + CREATE TABLE pytest_special_schema.rowid_test ( + id INT PRIMARY KEY, + name NVARCHAR(100) NOT NULL, + unique_col NVARCHAR(100) UNIQUE, + non_unique_col NVARCHAR(100) + ) + """) + + # Create table with rowversion column (for rowVerColumns) + cursor.execute(""" + CREATE TABLE pytest_special_schema.timestamp_test ( + id INT PRIMARY KEY, + name NVARCHAR(100) NOT NULL, + last_updated ROWVERSION + ) + """) + + # Create table with multiple unique identifiers + cursor.execute(""" + CREATE TABLE pytest_special_schema.multiple_unique_test ( + id INT NOT NULL, + code VARCHAR(10) NOT NULL, + email VARCHAR(100) UNIQUE, + order_number VARCHAR(20) UNIQUE, + CONSTRAINT PK_multiple_unique_test PRIMARY KEY (id, code) + ) + """) + + # Create table with identity column + cursor.execute(""" + CREATE TABLE pytest_special_schema.identity_test ( + id INT IDENTITY(1,1) PRIMARY KEY, + name NVARCHAR(100) NOT NULL, + last_modified DATETIME DEFAULT GETDATE() + ) + """) + + db_connection.commit() + except Exception as e: + pytest.fail(f"Test setup failed: {e}") + +def test_rowid_columns_basic(cursor, db_connection): + """Test basic functionality of rowIdColumns""" + try: + # Get row identifier columns for simple table + rowid_cols = cursor.rowIdColumns( + table='rowid_test', + schema='pytest_special_schema' + ).fetchall() + + # LIMITATION: Only returns first column of primary key + assert len(rowid_cols) == 1, "Should find exactly one ROWID column (first column of PK)" + + # Verify column name in the results + col = rowid_cols[0] + assert col.column_name.lower() == 'id', "Primary key column should be included in ROWID results" + + # Verify result structure + assert hasattr(col, 'scope'), "Result should have scope column" + assert hasattr(col, 'column_name'), "Result should have column_name column" + assert hasattr(col, 'data_type'), "Result should have data_type column" + assert hasattr(col, 'type_name'), "Result should have type_name column" + assert hasattr(col, 'column_size'), "Result should have column_size column" + assert hasattr(col, 'buffer_length'), "Result should have buffer_length column" + assert hasattr(col, 'decimal_digits'), "Result should have decimal_digits column" + assert hasattr(col, 'pseudo_column'), "Result should have pseudo_column column" + + # The scope should be one of the valid values or NULL + assert col.scope in [0, 1, 2, None], f"Invalid scope value: {col.scope}" + + # The pseudo_column should be one 
of the valid values + assert col.pseudo_column in [0, 1, 2, None], f"Invalid pseudo_column value: {col.pseudo_column}" + + except Exception as e: + pytest.fail(f"rowIdColumns basic test failed: {e}") + finally: + # Clean up happens in test_specialcolumns_cleanup + pass + +def test_rowid_columns_identity(cursor, db_connection): + """Test rowIdColumns with identity column""" + try: + # Get row identifier columns for table with identity column + rowid_cols = cursor.rowIdColumns( + table='identity_test', + schema='pytest_special_schema' + ).fetchall() + + # LIMITATION: Only returns the identity column if it's the primary key + assert len(rowid_cols) == 1, "Should find exactly one ROWID column (identity column as PK)" + + # Verify it's the identity column + col = rowid_cols[0] + assert col.column_name.lower() == 'id', "Identity column should be included as it's the PK" + + except Exception as e: + pytest.fail(f"rowIdColumns identity test failed: {e}") + finally: + # Clean up happens in test_specialcolumns_cleanup + pass + +def test_rowid_columns_composite(cursor, db_connection): + """Test rowIdColumns with composite primary key""" + try: + # Get row identifier columns for table with composite primary key + rowid_cols = cursor.rowIdColumns( + table='multiple_unique_test', + schema='pytest_special_schema' + ).fetchall() + + # LIMITATION: Only returns first column of composite primary key + assert len(rowid_cols) >= 1, "Should find at least one ROWID column (first column of PK)" + + # Verify column names in the results - should be the first PK column + col_names = [col.column_name.lower() for col in rowid_cols] + assert 'id' in col_names, "First part of composite PK should be included" + + # LIMITATION: Other parts of the PK or unique constraints may not be included + if len(rowid_cols) > 1: + # If additional columns are returned, they should be valid + for col in rowid_cols: + assert col.column_name.lower() in ['id', 'code'], "Only PK columns should be returned" + + except Exception as e: + pytest.fail(f"rowIdColumns composite test failed: {e}") + finally: + # Clean up happens in test_specialcolumns_cleanup + pass + +def test_rowid_columns_nonexistent(cursor): + """Test rowIdColumns with non-existent table""" + # Use a table name that's highly unlikely to exist + rowid_cols = cursor.rowIdColumns('nonexistent_table_xyz123').fetchall() + + # Should return empty list, not error + assert isinstance(rowid_cols, list), "Should return a list for non-existent table" + assert len(rowid_cols) == 0, "Should return empty list for non-existent table" + +def test_rowid_columns_nullable(cursor, db_connection): + """Test rowIdColumns with nullable parameter""" + try: + # First create a table with nullable unique column and non-nullable PK + cursor.execute(""" + CREATE TABLE pytest_special_schema.nullable_test ( + id INT PRIMARY KEY, -- PK can't be nullable in SQL Server + data NVARCHAR(100) NULL + ) + """) + db_connection.commit() + + # Test with nullable=True (default) + rowid_cols_with_nullable = cursor.rowIdColumns( + table='nullable_test', + schema='pytest_special_schema' + ).fetchall() + + # Verify PK column is included + assert len(rowid_cols_with_nullable) == 1, "Should return exactly one column (PK)" + assert rowid_cols_with_nullable[0].column_name.lower() == 'id', "PK column should be returned" + + # Test with nullable=False + rowid_cols_no_nullable = cursor.rowIdColumns( + table='nullable_test', + schema='pytest_special_schema', + nullable=False + ).fetchall() + + # The behavior of SQLSpecialColumns 
with SQL_NO_NULLS is to only return + # non-nullable columns that uniquely identify a row, but SQL Server returns + # an empty set in this case - this is expected behavior + assert len(rowid_cols_no_nullable) == 0, "Should return empty list when nullable=False (ODBC API behavior)" + + except Exception as e: + pytest.fail(f"rowIdColumns nullable test failed: {e}") + finally: + cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.nullable_test") + db_connection.commit() + +def test_rowver_columns_basic(cursor, db_connection): + """Test basic functionality of rowVerColumns""" + try: + # Get version columns from timestamp test table + rowver_cols = cursor.rowVerColumns( + table='timestamp_test', + schema='pytest_special_schema' + ).fetchall() + + # Verify we got results + assert len(rowver_cols) == 1, "Should find exactly one ROWVER column" + + # Verify the column is the rowversion column + rowver_col = rowver_cols[0] + assert rowver_col.column_name.lower() == 'last_updated', "ROWVER column should be 'last_updated'" + assert rowver_col.type_name.lower() in ['rowversion', 'timestamp'], "ROWVER column should have rowversion or timestamp type" + + # Verify result structure - allowing for NULL values + assert hasattr(rowver_col, 'scope'), "Result should have scope column" + assert hasattr(rowver_col, 'column_name'), "Result should have column_name column" + assert hasattr(rowver_col, 'data_type'), "Result should have data_type column" + assert hasattr(rowver_col, 'type_name'), "Result should have type_name column" + assert hasattr(rowver_col, 'column_size'), "Result should have column_size column" + assert hasattr(rowver_col, 'buffer_length'), "Result should have buffer_length column" + assert hasattr(rowver_col, 'decimal_digits'), "Result should have decimal_digits column" + assert hasattr(rowver_col, 'pseudo_column'), "Result should have pseudo_column column" + + # The scope should be one of the valid values or NULL + assert rowver_col.scope in [0, 1, 2, None], f"Invalid scope value: {rowver_col.scope}" + + except Exception as e: + pytest.fail(f"rowVerColumns basic test failed: {e}") + finally: + # Clean up happens in test_specialcolumns_cleanup + pass + +def test_rowver_columns_nonexistent(cursor): + """Test rowVerColumns with non-existent table""" + # Use a table name that's highly unlikely to exist + rowver_cols = cursor.rowVerColumns('nonexistent_table_xyz123').fetchall() + + # Should return empty list, not error + assert isinstance(rowver_cols, list), "Should return a list for non-existent table" + assert len(rowver_cols) == 0, "Should return empty list for non-existent table" + +def test_rowver_columns_nullable(cursor, db_connection): + """Test rowVerColumns with nullable parameter (not expected to have effect)""" + try: + # First create a table with rowversion column + cursor.execute(""" + CREATE TABLE pytest_special_schema.nullable_rowver_test ( + id INT PRIMARY KEY, + ts ROWVERSION + ) + """) + db_connection.commit() + + # Test with nullable=True (default) + rowver_cols_with_nullable = cursor.rowVerColumns( + table='nullable_rowver_test', + schema='pytest_special_schema' + ).fetchall() + + # Verify rowversion column is included (rowversion can't be nullable) + assert len(rowver_cols_with_nullable) == 1, "Should find exactly one ROWVER column" + assert rowver_cols_with_nullable[0].column_name.lower() == 'ts', "ROWVERSION column should be included" + + # Test with nullable=False + rowver_cols_no_nullable = cursor.rowVerColumns( + table='nullable_rowver_test', + 
            schema='pytest_special_schema',
+            nullable=False
+        ).fetchall()
+
+        # Verify rowversion column is still included
+        assert len(rowver_cols_no_nullable) == 1, "Should find exactly one ROWVER column"
+        assert rowver_cols_no_nullable[0].column_name.lower() == 'ts', "ROWVERSION column should be included even with nullable=False"
+
+    except Exception as e:
+        pytest.fail(f"rowVerColumns nullable test failed: {e}")
+    finally:
+        cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.nullable_rowver_test")
+        db_connection.commit()
+
+def test_specialcolumns_catalog_filter(cursor, db_connection):
+    """Test special columns with catalog filter"""
+    try:
+        # Get current database name
+        cursor.execute("SELECT DB_NAME() AS current_db")
+        current_db = cursor.fetchone().current_db
+
+        # Test rowIdColumns with current catalog
+        rowid_cols = cursor.rowIdColumns(
+            table='rowid_test',
+            catalog=current_db,
+            schema='pytest_special_schema'
+        ).fetchall()
+
+        # Verify catalog filter worked
+        assert len(rowid_cols) > 0, "Should find ROWID columns with correct catalog"
+
+        # Test rowIdColumns with non-existent catalog
+        fake_rowid_cols = cursor.rowIdColumns(
+            table='rowid_test',
+            catalog='nonexistent_db_xyz123',
+            schema='pytest_special_schema'
+        ).fetchall()
+        assert len(fake_rowid_cols) == 0, "Should return empty list for non-existent catalog"
+
+        # Test rowVerColumns with current catalog
+        rowver_cols = cursor.rowVerColumns(
+            table='timestamp_test',
+            catalog=current_db,
+            schema='pytest_special_schema'
+        ).fetchall()
+
+        # Verify catalog filter worked
+        assert len(rowver_cols) > 0, "Should find ROWVER columns with correct catalog"
+
+        # Test rowVerColumns with non-existent catalog
+        fake_rowver_cols = cursor.rowVerColumns(
+            table='timestamp_test',
+            catalog='nonexistent_db_xyz123',
+            schema='pytest_special_schema'
+        ).fetchall()
+        assert len(fake_rowver_cols) == 0, "Should return empty list for non-existent catalog"
+
+    except Exception as e:
+        pytest.fail(f"Special columns catalog filter test failed: {e}")
+    finally:
+        # Clean up happens in test_specialcolumns_cleanup
+        pass
+
+def test_specialcolumns_cleanup(cursor, db_connection):
+    """Clean up test tables after testing"""
+    try:
+        # Drop all test tables (including those created inside individual tests)
+        cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.rowid_test")
+        cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.timestamp_test")
+        cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.multiple_unique_test")
+        cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.identity_test")
+        cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.nullable_test")
+        cursor.execute("DROP TABLE IF EXISTS pytest_special_schema.nullable_rowver_test")
+
+        # Drop the test schema
+        cursor.execute("DROP SCHEMA IF EXISTS pytest_special_schema")
+        db_connection.commit()
+    except Exception as e:
+        pytest.fail(f"Test cleanup failed: {e}")
+
+def test_statistics_setup(cursor, db_connection):
+    """Create test tables and indexes for statistics testing"""
+    try:
+        # Create a test schema for isolation
+        cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_stats_schema') EXEC('CREATE SCHEMA pytest_stats_schema')")
+
+        # Drop tables if they exist
+        cursor.execute("DROP TABLE IF EXISTS pytest_stats_schema.stats_test")
+        cursor.execute("DROP TABLE IF EXISTS pytest_stats_schema.empty_stats_test")
+
+        # Create test table with various indexes
+        cursor.execute("""
+            CREATE TABLE pytest_stats_schema.stats_test (
+                id INT PRIMARY KEY,
+                name VARCHAR(100) NOT NULL,
+
email VARCHAR(100) UNIQUE, + department VARCHAR(50) NOT NULL, + salary DECIMAL(10, 2) NULL, + hire_date DATE NOT NULL + ) + """) + + # Create a non-unique index + cursor.execute(""" + CREATE INDEX IX_stats_test_dept_date ON pytest_stats_schema.stats_test (department, hire_date) + """) + + # Create a unique index on multiple columns + cursor.execute(""" + CREATE UNIQUE INDEX UX_stats_test_name_dept ON pytest_stats_schema.stats_test (name, department) + """) + + # Create an empty table for testing + cursor.execute(""" + CREATE TABLE pytest_stats_schema.empty_stats_test ( + id INT PRIMARY KEY, + data VARCHAR(100) NULL + ) + """) + + db_connection.commit() + except Exception as e: + pytest.fail(f"Test setup failed: {e}") + +def test_statistics_basic(cursor, db_connection): + """Test basic functionality of statistics method""" + try: + # First set up our test tables + test_statistics_setup(cursor, db_connection) + + # Get statistics for the test table (all indexes) + stats = cursor.statistics( + table='stats_test', + schema='pytest_stats_schema' + ).fetchall() + + # Verify we got results - should include PK, unique index on email, and non-unique index + assert stats is not None, "statistics() should return results" + assert len(stats) > 0, "statistics() should return at least one row" + + # Count different types of indexes + table_stats = [s for s in stats if s.type == 0] # TABLE_STAT + indexes = [s for s in stats if s.type != 0] # Actual indexes + + # We should have at least one table statistics row and multiple index rows + assert len(table_stats) <= 1, "Should have at most one TABLE_STAT row" + assert len(indexes) >= 3, "Should have at least 3 index entries (PK, unique email, non-unique dept+date)" + + # Verify column names in results + first_row = stats[0] + assert hasattr(first_row, 'table_name'), "Result should have table_name column" + assert hasattr(first_row, 'non_unique'), "Result should have non_unique column" + assert hasattr(first_row, 'index_name'), "Result should have index_name column" + assert hasattr(first_row, 'type'), "Result should have type column" + assert hasattr(first_row, 'column_name'), "Result should have column_name column" + + # Check that we can find the primary key + pk_found = False + for stat in stats: + if (hasattr(stat, 'index_name') and + stat.index_name and + 'pk' in stat.index_name.lower()): + pk_found = True + break + + assert pk_found, "Primary key should be included in statistics results" + + # Check that we can find the unique index on email + email_index_found = False + for stat in stats: + if (hasattr(stat, 'column_name') and + stat.column_name and + stat.column_name.lower() == 'email' and + hasattr(stat, 'non_unique') and + stat.non_unique == 0): # 0 = unique + email_index_found = True + break + + assert email_index_found, "Unique index on email should be included in statistics results" + + finally: + # Clean up happens in test_statistics_cleanup + pass + +def test_statistics_unique_only(cursor, db_connection): + """Test statistics with unique=True to get only unique indexes""" + try: + # Get statistics for only unique indexes + stats = cursor.statistics( + table='stats_test', + schema='pytest_stats_schema', + unique=True + ).fetchall() + + # Verify we got results + assert stats is not None, "statistics() with unique=True should return results" + assert len(stats) > 0, "statistics() with unique=True should return at least one row" + + # All index entries should be for unique indexes (non_unique = 0) + for stat in stats: + if hasattr(stat, 'type') 
and stat.type != 0: # Skip TABLE_STAT entries + assert hasattr(stat, 'non_unique'), "Index entry should have non_unique column" + assert stat.non_unique == 0, "With unique=True, all indexes should be unique" + + # Count different types of indexes + indexes = [s for s in stats if hasattr(s, 'type') and s.type != 0] + + # We should have multiple unique indexes (PK, unique email, unique name+dept) + assert len(indexes) >= 3, "Should have at least 3 unique index entries" + + finally: + # Clean up happens in test_statistics_cleanup + pass + +def test_statistics_empty_table(cursor, db_connection): + """Test statistics on a table with no data (just schema)""" + try: + # Get statistics for the empty table + stats = cursor.statistics( + table='empty_stats_test', + schema='pytest_stats_schema' + ).fetchall() + + # Should still return metadata about the primary key + assert stats is not None, "statistics() should return results even for empty table" + assert len(stats) > 0, "statistics() should return at least one row for empty table" + + # Check for primary key + pk_found = False + for stat in stats: + if (hasattr(stat, 'index_name') and + stat.index_name and + 'pk' in stat.index_name.lower()): + pk_found = True + break + + assert pk_found, "Primary key should be included in statistics results for empty table" + + finally: + # Clean up happens in test_statistics_cleanup + pass + +def test_statistics_nonexistent(cursor): + """Test statistics with non-existent table name""" + # Use a table name that's highly unlikely to exist + stats = cursor.statistics('nonexistent_table_xyz123').fetchall() + + # Should return empty list, not error + assert isinstance(stats, list), "Should return a list for non-existent table" + assert len(stats) == 0, "Should return empty list for non-existent table" + +def test_statistics_result_structure(cursor, db_connection): + """Test the complete structure of statistics result rows""" + try: + # Get statistics for the test table + stats = cursor.statistics( + table='stats_test', + schema='pytest_stats_schema' + ).fetchall() + + # Verify we have results + assert len(stats) > 0, "Should have statistics results" + + # Find a row that's an actual index (not TABLE_STAT) + index_row = None + for stat in stats: + if hasattr(stat, 'type') and stat.type != 0: + index_row = stat + break + + assert index_row is not None, "Should have at least one index row" + + # Check for all required columns + required_columns = [ + 'table_cat', 'table_schem', 'table_name', 'non_unique', + 'index_qualifier', 'index_name', 'type', 'ordinal_position', + 'column_name', 'asc_or_desc', 'cardinality', 'pages', + 'filter_condition' + ] + + for column in required_columns: + assert hasattr(index_row, column), f"Result missing required column: {column}" + + # Check types of key columns + assert isinstance(index_row.table_name, str), "table_name should be a string" + assert isinstance(index_row.type, int), "type should be an integer" + + # Don't check the actual values of cardinality and pages as they may be NULL + # or driver-dependent, especially for empty tables + + finally: + # Clean up happens in test_statistics_cleanup + pass + +def test_statistics_catalog_filter(cursor, db_connection): + """Test statistics with catalog filter""" + try: + # Get current database name + cursor.execute("SELECT DB_NAME() AS current_db") + current_db = cursor.fetchone().current_db + + # Get statistics with current catalog + stats = cursor.statistics( + table='stats_test', + catalog=current_db, + schema='pytest_stats_schema' + 
).fetchall() + + # Verify catalog filter worked + assert len(stats) > 0, "Should find statistics with correct catalog" + + # Verify catalog in results + for stat in stats: + if hasattr(stat, 'table_cat'): + assert stat.table_cat.lower() == current_db.lower(), "Wrong table catalog" + + # Get statistics with non-existent catalog + fake_stats = cursor.statistics( + table='stats_test', + catalog='nonexistent_db_xyz123', + schema='pytest_stats_schema' + ).fetchall() + assert len(fake_stats) == 0, "Should return empty list for non-existent catalog" + + finally: + # Clean up happens in test_statistics_cleanup + pass + +def test_statistics_with_quick_parameter(cursor, db_connection): + """Test statistics with quick parameter variations""" + try: + # Test with quick=True (default) + quick_stats = cursor.statistics( + table='stats_test', + schema='pytest_stats_schema', + quick=True + ).fetchall() + + # Test with quick=False + thorough_stats = cursor.statistics( + table='stats_test', + schema='pytest_stats_schema', + quick=False + ).fetchall() + + # Both should return results, but we can't guarantee behavior differences + # since it depends on the ODBC driver and database system + assert len(quick_stats) > 0, "quick=True should return results" + assert len(thorough_stats) > 0, "quick=False should return results" + + # Just verify that changing the parameter didn't cause errors + + finally: + # Clean up happens in test_statistics_cleanup + pass + +def test_statistics_cleanup(cursor, db_connection): + """Clean up test tables after testing""" + try: + # Drop all test tables + cursor.execute("DROP TABLE IF EXISTS pytest_stats_schema.stats_test") + cursor.execute("DROP TABLE IF EXISTS pytest_stats_schema.empty_stats_test") + + # Drop the test schema + cursor.execute("DROP SCHEMA IF EXISTS pytest_stats_schema") + db_connection.commit() + except Exception as e: + pytest.fail(f"Test cleanup failed: {e}") + +def test_columns_setup(cursor, db_connection): + """Create test tables for columns method testing""" + try: + # Create a test schema for isolation + cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_cols_schema') EXEC('CREATE SCHEMA pytest_cols_schema')") + + # Drop tables if they exist + cursor.execute("DROP TABLE IF EXISTS pytest_cols_schema.columns_test") + cursor.execute("DROP TABLE IF EXISTS pytest_cols_schema.columns_special_test") + + # Create test table with various column types + cursor.execute(""" + CREATE TABLE pytest_cols_schema.columns_test ( + id INT PRIMARY KEY, + name NVARCHAR(100) NOT NULL, + description NVARCHAR(MAX) NULL, + price DECIMAL(10, 2) NULL, + created_date DATETIME DEFAULT GETDATE(), + is_active BIT NOT NULL DEFAULT 1, + binary_data VARBINARY(MAX) NULL, + notes TEXT NULL, + [computed_col] AS (name + ' - ' + CAST(id AS VARCHAR(10))) + ) + """) + + # Create table with special column names and edge cases - fix the problematic column name + cursor.execute(""" + CREATE TABLE pytest_cols_schema.columns_special_test ( + [ID] INT PRIMARY KEY, + [User Name] NVARCHAR(100) NULL, + [Spaces Multiple] VARCHAR(50) NULL, + [123_numeric_start] INT NULL, + [MAX] VARCHAR(20) NULL, -- SQL keyword as column name + [SELECT] INT NULL, -- SQL keyword as column name + [Column.With.Dots] VARCHAR(20) NULL, + [Column/With/Slashes] VARCHAR(20) NULL, + [Column_With_Underscores] VARCHAR(20) NULL -- Changed from problematic nested brackets + ) + """) + + db_connection.commit() + except Exception as e: + pytest.fail(f"Test setup failed: {e}") + +def test_columns_all(cursor, 
db_connection): + """Test columns returns information about all columns in all tables""" + try: + # First set up our test tables + test_columns_setup(cursor, db_connection) + + # Get all columns (no filters) + cols_cursor = cursor.columns() + cols = cols_cursor.fetchall() + + # Verify we got results + assert cols is not None, "columns() should return results" + assert len(cols) > 0, "columns() should return at least one column" + + # Verify our test tables' columns are in the results + # Use case-insensitive comparison to avoid driver case sensitivity issues + found_test_table = False + for col in cols: + if (hasattr(col, 'table_name') and + col.table_name and + col.table_name.lower() == 'columns_test' and + hasattr(col, 'table_schem') and + col.table_schem and + col.table_schem.lower() == 'pytest_cols_schema'): + found_test_table = True + break + + assert found_test_table, "Test table columns should be included in results" + + # Verify structure of results + first_row = cols[0] + assert hasattr(first_row, 'table_cat'), "Result should have table_cat column" + assert hasattr(first_row, 'table_schem'), "Result should have table_schem column" + assert hasattr(first_row, 'table_name'), "Result should have table_name column" + assert hasattr(first_row, 'column_name'), "Result should have column_name column" + assert hasattr(first_row, 'data_type'), "Result should have data_type column" + assert hasattr(first_row, 'type_name'), "Result should have type_name column" + assert hasattr(first_row, 'column_size'), "Result should have column_size column" + assert hasattr(first_row, 'buffer_length'), "Result should have buffer_length column" + assert hasattr(first_row, 'decimal_digits'), "Result should have decimal_digits column" + assert hasattr(first_row, 'num_prec_radix'), "Result should have num_prec_radix column" + assert hasattr(first_row, 'nullable'), "Result should have nullable column" + assert hasattr(first_row, 'remarks'), "Result should have remarks column" + assert hasattr(first_row, 'column_def'), "Result should have column_def column" + assert hasattr(first_row, 'sql_data_type'), "Result should have sql_data_type column" + assert hasattr(first_row, 'sql_datetime_sub'), "Result should have sql_datetime_sub column" + assert hasattr(first_row, 'char_octet_length'), "Result should have char_octet_length column" + assert hasattr(first_row, 'ordinal_position'), "Result should have ordinal_position column" + assert hasattr(first_row, 'is_nullable'), "Result should have is_nullable column" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_specific_table(cursor, db_connection): + """Test columns returns information about a specific table""" + try: + # Get columns for the test table + cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema' + ).fetchall() + + # Verify we got results + assert len(cols) == 9, "Should find exactly 9 columns in columns_test" + + # Verify all column names are present (case insensitive) + col_names = [col.column_name.lower() for col in cols] + expected_names = ['id', 'name', 'description', 'price', 'created_date', + 'is_active', 'binary_data', 'notes', 'computed_col'] + + for name in expected_names: + assert name in col_names, f"Column {name} should be in results" + + # Verify details of a specific column (id) + id_col = next(col for col in cols if col.column_name.lower() == 'id') + assert id_col.nullable == 0, "id column should be non-nullable" + assert id_col.ordinal_position == 1, "id should be the first column" + 
assert id_col.is_nullable == "NO", "is_nullable should be NO for id column" + + # Check data types (but don't assume specific ODBC type codes since they vary by driver) + # Instead check that the type_name is correct + id_type = id_col.type_name.lower() + assert 'int' in id_type, f"id column should be INTEGER type, got {id_type}" + + # Check a nullable column + desc_col = next(col for col in cols if col.column_name.lower() == 'description') + assert desc_col.nullable == 1, "description column should be nullable" + assert desc_col.is_nullable == "YES", "is_nullable should be YES for description column" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_special_chars(cursor, db_connection): + """Test columns with special characters and edge cases""" + try: + # Get columns for the special table + cols = cursor.columns( + table='columns_special_test', + schema='pytest_cols_schema' + ).fetchall() + + # Verify we got results + assert len(cols) == 9, "Should find exactly 9 columns in columns_special_test" + + # Check that special column names are handled correctly + col_names = [col.column_name for col in cols] + + # Create case-insensitive lookup + col_names_lower = [name.lower() if name else None for name in col_names] + + # Check for columns with special characters - note that column names might be + # returned with or without brackets/quotes depending on the driver + assert any('user name' in name.lower() for name in col_names), "Column with spaces should be in results" + assert any('id' == name.lower() for name in col_names), "ID column should be in results" + assert any('123_numeric_start' in name.lower() for name in col_names), "Column starting with numbers should be in results" + assert any('max' == name.lower() for name in col_names), "MAX column should be in results" + assert any('select' == name.lower() for name in col_names), "SELECT column should be in results" + assert any('column.with.dots' in name.lower() for name in col_names), "Column with dots should be in results" + assert any('column/with/slashes' in name.lower() for name in col_names), "Column with slashes should be in results" + assert any('column_with_underscores' in name.lower() for name in col_names), "Column with underscores should be in results" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_specific_column(cursor, db_connection): + """Test columns with specific column filter""" + try: + # Get specific column + cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema', + column='name' + ).fetchall() + + # Verify we got just one result + assert len(cols) == 1, "Should find exactly 1 column named 'name'" + + # Verify column details + col = cols[0] + assert col.column_name.lower() == 'name', "Column name should be 'name'" + assert col.table_name.lower() == 'columns_test', "Table name should be 'columns_test'" + assert col.table_schem.lower() == 'pytest_cols_schema', "Schema should be 'pytest_cols_schema'" + assert col.nullable == 0, "name column should be non-nullable" + + # Get column using pattern (% wildcard) + pattern_cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema', + column='%date%' + ).fetchall() + + # Should find created_date column + assert len(pattern_cols) == 1, "Should find 1 column matching '%date%'" + + assert pattern_cols[0].column_name.lower() == 'created_date', "Should find created_date column" + + # Get multiple columns with pattern + multi_cols = cursor.columns( + table='columns_test', 
+            schema='pytest_cols_schema',
+            column='%d%'  # Should match id, description, created_date
+        ).fetchall()
+
+        # At least 3 columns should match this pattern
+        assert len(multi_cols) >= 3, "Should find at least 3 columns matching '%d%'"
+        match_names = [col.column_name.lower() for col in multi_cols]
+        assert 'id' in match_names, "id should match '%d%'"
+        assert 'description' in match_names, "description should match '%d%'"
+        assert 'created_date' in match_names, "created_date should match '%d%'"
+
+    finally:
+        # Clean up happens in test_columns_cleanup
+        pass
+
+def test_columns_with_underscore_pattern(cursor):
+    """Test columns with underscore wildcard pattern"""
+    try:
+        # Get columns with underscore pattern (one character wildcard)
+        # Looking for 'id' (exactly 2 chars)
+        cols = cursor.columns(
+            table='columns_test',
+            schema='pytest_cols_schema',
+            column='__'
+        ).fetchall()
+
+        # Should find 'id' column
+        id_found = False
+        for col in cols:
+            if col.column_name.lower() == 'id' and col.table_name.lower() == 'columns_test':
+                id_found = True
+                break
+
+        assert id_found, "Should find 'id' column with pattern '__'"
+
+        # Try a more complex pattern with both % and _
+        # For example: '%_d%' matches any column with 'd' as the second or later character
+        pattern_cols = cursor.columns(
+            table='columns_test',
+            schema='pytest_cols_schema',
+            column='%_d%'
+        ).fetchall()
+
+        # Should match 'id' (if considering case-insensitive) and 'created_date'
+        match_names = [col.column_name.lower() for col in pattern_cols
+                       if col.table_name.lower() == 'columns_test']
+
+        # At least 'created_date' should match this pattern
+        assert 'created_date' in match_names, "created_date should match '%_d%'"
+
+    finally:
+        # Clean up happens in test_columns_cleanup
+        pass
+
+def test_columns_data_types(cursor):
+    """Test columns returns correct data type information"""
+    try:
+        # Get all columns from test table
+        cols = cursor.columns(
+            table='columns_test',
+            schema='pytest_cols_schema'
+        ).fetchall()
+
+        # Create a dictionary mapping column names to their details
+        col_dict = {col.column_name.lower(): col for col in cols}
+
+        # Check data types by name (case insensitive checks)
+        # Note: We're checking type_name as a string to avoid SQL type code inconsistencies
+        # between drivers
+
+        # INT column
+        assert 'int' in col_dict['id'].type_name.lower(), "id should be INT type"
+
+        # NVARCHAR column
+        assert any(name in col_dict['name'].type_name.lower()
+                   for name in ['nvarchar', 'varchar', 'char', 'wchar']), "name should be NVARCHAR type"
+
+        # DECIMAL column
+        assert any(name in col_dict['price'].type_name.lower()
+                   for name in ['decimal', 'numeric', 'money']), "price should be DECIMAL type"
+
+        # BIT column
+        assert any(name in
col_dict['is_active'].type_name.lower() + for name in ['bit', 'boolean']), "is_active should be BIT type" + + # TEXT column + assert any(name in col_dict['notes'].type_name.lower() + for name in ['text', 'char', 'varchar']), "notes should be TEXT type" + + # Check nullable flag + assert col_dict['id'].nullable == 0, "id should be non-nullable" + assert col_dict['description'].nullable == 1, "description should be nullable" + + # Check column size + assert col_dict['name'].column_size == 100, "name should have size 100" + + # Check decimal digits for numeric type + assert col_dict['price'].decimal_digits == 2, "price should have 2 decimal digits" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_nonexistent(cursor): + """Test columns with non-existent table or column""" + # Test with non-existent table + table_cols = cursor.columns(table='nonexistent_table_xyz123').fetchall() + assert len(table_cols) == 0, "Should return empty list for non-existent table" + + # Test with non-existent column in existing table + col_cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema', + column='nonexistent_column_xyz123' + ).fetchall() + assert len(col_cols) == 0, "Should return empty list for non-existent column" + + # Test with non-existent schema + schema_cols = cursor.columns( + table='columns_test', + schema='nonexistent_schema_xyz123' + ).fetchall() + assert len(schema_cols) == 0, "Should return empty list for non-existent schema" + +def test_columns_catalog_filter(cursor): + """Test columns with catalog filter""" + try: + # Get current database name + cursor.execute("SELECT DB_NAME() AS current_db") + current_db = cursor.fetchone().current_db + + # Get columns with current catalog + cols = cursor.columns( + table='columns_test', + catalog=current_db, + schema='pytest_cols_schema' + ).fetchall() + + # Verify catalog filter worked + assert len(cols) > 0, "Should find columns with correct catalog" + + # Check catalog in results + for col in cols: + # Some drivers might return None for catalog + if col.table_cat is not None: + assert col.table_cat.lower() == current_db.lower(), "Wrong table catalog" + + # Test with non-existent catalog + fake_cols = cursor.columns( + table='columns_test', + catalog='nonexistent_db_xyz123', + schema='pytest_cols_schema' + ).fetchall() + assert len(fake_cols) == 0, "Should return empty list for non-existent catalog" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_schema_pattern(cursor): + """Test columns with schema name pattern""" + try: + # Get columns with schema pattern + cols = cursor.columns( + table='columns_test', + schema='pytest_%' + ).fetchall() + + # Should find our test table columns + test_cols = [col for col in cols if col.table_name.lower() == 'columns_test'] + assert len(test_cols) > 0, "Should find columns using schema pattern" + + # Try a more specific pattern + specific_cols = cursor.columns( + table='columns_test', + schema='pytest_cols%' + ).fetchall() + + # Should still find our test table columns + test_cols = [col for col in specific_cols if col.table_name.lower() == 'columns_test'] + assert len(test_cols) > 0, "Should find columns using specific schema pattern" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_table_pattern(cursor): + """Test columns with table name pattern""" + try: + # Get columns with table pattern + cols = cursor.columns( + table='columns_%', + schema='pytest_cols_schema' + ).fetchall() + + # Should 
find columns from both test tables + tables_found = set() + for col in cols: + if col.table_name: + tables_found.add(col.table_name.lower()) + + assert 'columns_test' in tables_found, "Should find columns_test with pattern columns_%" + assert 'columns_special_test' in tables_found, "Should find columns_special_test with pattern columns_%" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_ordinal_position(cursor): + """Test ordinal_position is correct in columns results""" + try: + # Get columns for the test table + cols = cursor.columns( + table='columns_test', + schema='pytest_cols_schema' + ).fetchall() + + # Sort by ordinal position + sorted_cols = sorted(cols, key=lambda col: col.ordinal_position) + + # Verify positions are consecutive starting from 1 + for i, col in enumerate(sorted_cols, 1): + assert col.ordinal_position == i, f"Column {col.column_name} should have ordinal_position {i}" + + # First column should be id (primary key) + assert sorted_cols[0].column_name.lower() == 'id', "First column should be id" + + finally: + # Clean up happens in test_columns_cleanup + pass + +def test_columns_cleanup(cursor, db_connection): + """Clean up test tables after testing""" + try: + # Drop all test tables + cursor.execute("DROP TABLE IF EXISTS pytest_cols_schema.columns_test") + cursor.execute("DROP TABLE IF EXISTS pytest_cols_schema.columns_special_test") + + # Drop the test schema + cursor.execute("DROP SCHEMA IF EXISTS pytest_cols_schema") + db_connection.commit() + except Exception as e: + pytest.fail(f"Test cleanup failed: {e}") + def test_close(db_connection): """Test closing the cursor""" try: