Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion mssql_python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@
# Constants
from .constants import ConstantsDDBC

# Export specific constants for setencoding()
SQL_CHAR = ConstantsDDBC.SQL_CHAR.value
SQL_WCHAR = ConstantsDDBC.SQL_WCHAR.value
SQL_WMETADATA = -99

# GLOBALS
# Read-Only
apilevel = "2.0"
Expand All @@ -71,4 +76,3 @@ def pooling(max_size=100, idle_timeout=600, enabled=True):
PoolingManager.disable()
else:
PoolingManager.enable(max_size, idle_timeout)

302 changes: 301 additions & 1 deletion mssql_python/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,44 @@
"""
import weakref
import re
import codecs
from mssql_python.cursor import Cursor
from mssql_python.helpers import add_driver_to_connection_str, sanitize_connection_string, log
from mssql_python.helpers import add_driver_to_connection_str, sanitize_connection_string, sanitize_user_input, log
from mssql_python import ddbc_bindings
from mssql_python.pooling import PoolingManager
from mssql_python.exceptions import InterfaceError, ProgrammingError
from mssql_python.auth import process_connection_string
from mssql_python.constants import ConstantsDDBC

# Add SQL_WMETADATA constant for metadata decoding configuration
SQL_WMETADATA = -99 # Special flag for column name decoding

# UTF-16 encoding variants that should use SQL_WCHAR by default
UTF16_ENCODINGS = frozenset([
'utf-16',
'utf-16le',
'utf-16be'
])

def _validate_encoding(encoding: str) -> bool:
"""
Cached encoding validation using codecs.lookup().

Args:
encoding (str): The encoding name to validate.

Returns:
bool: True if encoding is valid, False otherwise.

Note:
Uses LRU cache to avoid repeated expensive codecs.lookup() calls.
Cache size is limited to 128 entries which should cover most use cases.
"""
try:
codecs.lookup(encoding)
return True
except LookupError:
return False

# Import all DB-API 2.0 exception classes for Connection attributes
from mssql_python.exceptions import (
Expand Down Expand Up @@ -68,6 +101,9 @@ class Connection:
close() -> None:
__enter__() -> Connection:
__exit__() -> None:
setencoding(encoding=None, ctype=None) -> None:
setdecoding(sqltype, encoding=None, ctype=None) -> None:
getdecoding(sqltype) -> dict:
"""

# DB-API 2.0 Exception attributes
Expand Down Expand Up @@ -108,6 +144,29 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef
)
self._attrs_before = attrs_before or {}

# Initialize encoding settings with defaults for Python 3
# Python 3 only has str (which is Unicode), so we use utf-16le by default
self._encoding_settings = {
'encoding': 'utf-16le',
'ctype': ConstantsDDBC.SQL_WCHAR.value
}

# Initialize decoding settings with Python 3 defaults
self._decoding_settings = {
ConstantsDDBC.SQL_CHAR.value: {
'encoding': 'utf-8',
'ctype': ConstantsDDBC.SQL_CHAR.value
},
ConstantsDDBC.SQL_WCHAR.value: {
'encoding': 'utf-16le',
'ctype': ConstantsDDBC.SQL_WCHAR.value
},
SQL_WMETADATA: {
'encoding': 'utf-16le',
'ctype': ConstantsDDBC.SQL_WCHAR.value
}
}

# Check if the connection string contains authentication parameters
# This is important for processing the connection string correctly.
# If authentication is specified, it will be processed to handle
Expand Down Expand Up @@ -204,6 +263,247 @@ def setautocommit(self, value: bool = False) -> None:
"""
self._conn.set_autocommit(value)

def setencoding(self, encoding=None, ctype=None):
"""
Sets the text encoding for SQL statements and text parameters.

Since Python 3 only has str (which is Unicode), this method configures
how text is encoded when sending to the database.

Args:
encoding (str, optional): The encoding to use. This must be a valid Python
encoding that converts text to bytes. If None, defaults to 'utf-16le'.
ctype (int, optional): The C data type to use when passing data:
SQL_CHAR or SQL_WCHAR. If not provided, SQL_WCHAR is used for
UTF-16 variants (see UTF16_ENCODINGS constant). SQL_CHAR is used for all other encodings.

Returns:
None

Raises:
ProgrammingError: If the encoding is not valid or not supported.
InterfaceError: If the connection is closed.

Example:
# For databases that only communicate with UTF-8
cnxn.setencoding(encoding='utf-8')

# For explicitly using SQL_CHAR
cnxn.setencoding(encoding='utf-8', ctype=mssql_python.SQL_CHAR)
"""
if self._closed:
raise InterfaceError(
driver_error="Connection is closed",
ddbc_error="Connection is closed",
)

# Set default encoding if not provided
if encoding is None:
encoding = 'utf-16le'

# Validate encoding using cached validation for better performance
if not _validate_encoding(encoding):
# Log the sanitized encoding for security
log('warning', "Invalid encoding attempted: %s", sanitize_user_input(str(encoding)))
raise ProgrammingError(
driver_error=f"Unsupported encoding: {encoding}",
ddbc_error=f"The encoding '{encoding}' is not supported by Python",
)

# Normalize encoding to casefold for more robust Unicode handling
encoding = encoding.casefold()

# Set default ctype based on encoding if not provided
if ctype is None:
if encoding in UTF16_ENCODINGS:
ctype = ConstantsDDBC.SQL_WCHAR.value
else:
ctype = ConstantsDDBC.SQL_CHAR.value

# Validate ctype
valid_ctypes = [ConstantsDDBC.SQL_CHAR.value, ConstantsDDBC.SQL_WCHAR.value]
if ctype not in valid_ctypes:
# Log the sanitized ctype for security
log('warning', "Invalid ctype attempted: %s", sanitize_user_input(str(ctype)))
raise ProgrammingError(
driver_error=f"Invalid ctype: {ctype}",
ddbc_error=f"ctype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}) or SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value})",
)

# Store the encoding settings
self._encoding_settings = {
'encoding': encoding,
'ctype': ctype
}

# Log with sanitized values for security
log('info', "Text encoding set to %s with ctype %s",
sanitize_user_input(encoding), sanitize_user_input(str(ctype)))

def getencoding(self):
"""
Gets the current text encoding settings.

Returns:
dict: A dictionary containing 'encoding' and 'ctype' keys.

Raises:
InterfaceError: If the connection is closed.

Example:
settings = cnxn.getencoding()
print(f"Current encoding: {settings['encoding']}")
print(f"Current ctype: {settings['ctype']}")
"""
if self._closed:
raise InterfaceError(
driver_error="Connection is closed",
ddbc_error="Connection is closed",
)

return self._encoding_settings.copy()

def setdecoding(self, sqltype, encoding=None, ctype=None):
"""
Sets the text decoding used when reading SQL_CHAR and SQL_WCHAR from the database.

This method configures how text data is decoded when reading from the database.
In Python 3, all text is Unicode (str), so this primarily affects the encoding
used to decode bytes from the database.

Args:
sqltype (int): The SQL type being configured: SQL_CHAR, SQL_WCHAR, or SQL_WMETADATA.
SQL_WMETADATA is a special flag for configuring column name decoding.
encoding (str, optional): The Python encoding to use when decoding the data.
If None, uses default encoding based on sqltype.
ctype (int, optional): The C data type to request from SQLGetData:
SQL_CHAR or SQL_WCHAR. If None, uses default based on encoding.

Returns:
None

Raises:
ProgrammingError: If the sqltype, encoding, or ctype is invalid.
InterfaceError: If the connection is closed.

Example:
# Configure SQL_CHAR to use UTF-8 decoding
cnxn.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8')

# Configure column metadata decoding
cnxn.setdecoding(mssql_python.SQL_WMETADATA, encoding='utf-16le')

# Use explicit ctype
cnxn.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16le', ctype=mssql_python.SQL_WCHAR)
"""
if self._closed:
raise InterfaceError(
driver_error="Connection is closed",
ddbc_error="Connection is closed",
)

# Validate sqltype
valid_sqltypes = [
ConstantsDDBC.SQL_CHAR.value,
ConstantsDDBC.SQL_WCHAR.value,
SQL_WMETADATA
]
if sqltype not in valid_sqltypes:
log('warning', "Invalid sqltype attempted: %s", sanitize_user_input(str(sqltype)))
raise ProgrammingError(
driver_error=f"Invalid sqltype: {sqltype}",
ddbc_error=f"sqltype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}), SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value}), or SQL_WMETADATA ({SQL_WMETADATA})",
)

# Set default encoding based on sqltype if not provided
if encoding is None:
if sqltype == ConstantsDDBC.SQL_CHAR.value:
encoding = 'utf-8' # Default for SQL_CHAR in Python 3
else: # SQL_WCHAR or SQL_WMETADATA
encoding = 'utf-16le' # Default for SQL_WCHAR in Python 3

# Validate encoding using cached validation for better performance
if not _validate_encoding(encoding):
log('warning', "Invalid encoding attempted: %s", sanitize_user_input(str(encoding)))
raise ProgrammingError(
driver_error=f"Unsupported encoding: {encoding}",
ddbc_error=f"The encoding '{encoding}' is not supported by Python",
)

# Normalize encoding to lowercase for consistency
encoding = encoding.lower()

# Set default ctype based on encoding if not provided
if ctype is None:
if encoding in UTF16_ENCODINGS:
ctype = ConstantsDDBC.SQL_WCHAR.value
else:
ctype = ConstantsDDBC.SQL_CHAR.value

# Validate ctype
valid_ctypes = [ConstantsDDBC.SQL_CHAR.value, ConstantsDDBC.SQL_WCHAR.value]
if ctype not in valid_ctypes:
log('warning', "Invalid ctype attempted: %s", sanitize_user_input(str(ctype)))
raise ProgrammingError(
driver_error=f"Invalid ctype: {ctype}",
ddbc_error=f"ctype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}) or SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value})",
)

# Store the decoding settings for the specified sqltype
self._decoding_settings[sqltype] = {
'encoding': encoding,
'ctype': ctype
}

# Log with sanitized values for security
sqltype_name = {
ConstantsDDBC.SQL_CHAR.value: "SQL_CHAR",
ConstantsDDBC.SQL_WCHAR.value: "SQL_WCHAR",
SQL_WMETADATA: "SQL_WMETADATA"
}.get(sqltype, str(sqltype))

log('info', "Text decoding set for %s to %s with ctype %s",
sqltype_name, sanitize_user_input(encoding), sanitize_user_input(str(ctype)))

def getdecoding(self, sqltype):
"""
Gets the current text decoding settings for the specified SQL type.

Args:
sqltype (int): The SQL type to get settings for: SQL_CHAR, SQL_WCHAR, or SQL_WMETADATA.

Returns:
dict: A dictionary containing 'encoding' and 'ctype' keys for the specified sqltype.

Raises:
ProgrammingError: If the sqltype is invalid.
InterfaceError: If the connection is closed.

Example:
settings = cnxn.getdecoding(mssql_python.SQL_CHAR)
print(f"SQL_CHAR encoding: {settings['encoding']}")
print(f"SQL_CHAR ctype: {settings['ctype']}")
"""
if self._closed:
raise InterfaceError(
driver_error="Connection is closed",
ddbc_error="Connection is closed",
)

# Validate sqltype
valid_sqltypes = [
ConstantsDDBC.SQL_CHAR.value,
ConstantsDDBC.SQL_WCHAR.value,
SQL_WMETADATA
]
if sqltype not in valid_sqltypes:
raise ProgrammingError(
driver_error=f"Invalid sqltype: {sqltype}",
ddbc_error=f"sqltype must be SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}), SQL_WCHAR ({ConstantsDDBC.SQL_WCHAR.value}), or SQL_WMETADATA ({SQL_WMETADATA})",
)

return self._decoding_settings[sqltype].copy()

def cursor(self) -> Cursor:
"""
Return a new Cursor object using the connection.
Expand Down
28 changes: 28 additions & 0 deletions mssql_python/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,34 @@ def sanitize_connection_string(conn_str: str) -> str:
return re.sub(r"(Pwd\s*=\s*)[^;]*", r"\1***", conn_str, flags=re.IGNORECASE)


def sanitize_user_input(user_input: str, max_length: int = 50) -> str:
"""
Sanitize user input for safe logging by removing control characters,
limiting length, and ensuring safe characters only.

Args:
user_input (str): The user input to sanitize.
max_length (int): Maximum length of the sanitized output.

Returns:
str: The sanitized string safe for logging.
"""
if not isinstance(user_input, str):
return "<non-string>"

# Remove control characters and non-printable characters
import re
# Allow alphanumeric, dash, underscore, and dot (common in encoding names)
sanitized = re.sub(r'[^\w\-\.]', '', user_input)

# Limit length to prevent log flooding
if len(sanitized) > max_length:
sanitized = sanitized[:max_length] + "..."

# Return placeholder if nothing remains after sanitization
return sanitized if sanitized else "<invalid>"


def log(level: str, message: str, *args) -> None:
"""
Universal logging helper that gets a fresh logger instance.
Expand Down
2 changes: 1 addition & 1 deletion mssql_python/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def Binary(value) -> bytes:
"""
Converts a string or bytes to bytes for use with binary database columns.

This function follows the DB-API 2.0 specification and pyodbc compatibility.
This function follows the DB-API 2.0 specification.
It accepts only str and bytes/bytearray types to ensure type safety.

Args:
Expand Down
Loading
Loading