Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 73 additions & 6 deletions mssql_python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
This module initializes the mssql_python package.
"""
import threading
import locale

# Exceptions
# https://www.python.org/dev/peps/pep-0249/#exceptions

Expand All @@ -13,25 +15,90 @@
paramstyle = "qmark"
threadsafety = 1

_settings_lock = threading.Lock()
# Initialize the locale setting only once at module import time
# This avoids thread-safety issues with locale
_DEFAULT_DECIMAL_SEPARATOR = "."
try:
# Get the locale setting once during module initialization
_locale_separator = locale.localeconv()['decimal_point']
if _locale_separator and len(_locale_separator) == 1:
_DEFAULT_DECIMAL_SEPARATOR = _locale_separator
except (AttributeError, KeyError, TypeError, ValueError):
pass # Keep the default "." if locale access fails

# Create a settings object to hold configuration
class Settings:
def __init__(self):
self.lowercase = False
# Use the pre-determined separator - no locale access here
self.decimal_separator = _DEFAULT_DECIMAL_SEPARATOR

# Create a global settings instance
# Global settings instance
_settings = Settings()
_settings_lock = threading.Lock()

# Define the get_settings function for internal use
def get_settings():
"""Return the global settings object"""
with _settings_lock:
_settings.lowercase = lowercase
return _settings

# Expose lowercase as a regular module variable that users can access and set
lowercase = _settings.lowercase
lowercase = _settings.lowercase # Default is False

# Set the initial decimal separator in C++
from .ddbc_bindings import DDBCSetDecimalSeparator
DDBCSetDecimalSeparator(_settings.decimal_separator)

# New functions for decimal separator control
def setDecimalSeparator(separator):
"""
Sets the decimal separator character used when parsing NUMERIC/DECIMAL values
from the database, e.g. the "." in "1,234.56".

The default is to use the current locale's "decimal_point" value when the module
was first imported, or "." if the locale is not available. This function overrides
the default.

Args:
separator (str): The character to use as decimal separator

Raises:
ValueError: If the separator is not a single character string
"""
# Type validation
if not isinstance(separator, str):
raise ValueError("Decimal separator must be a string")

# Length validation
if len(separator) == 0:
raise ValueError("Decimal separator cannot be empty")

if len(separator) > 1:
raise ValueError("Decimal separator must be a single character")

# Character validation
if separator.isspace():
raise ValueError("Whitespace characters are not allowed as decimal separators")

# Check for specific disallowed characters
if separator in ['\t', '\n', '\r', '\v', '\f']:
raise ValueError(f"Control character '{repr(separator)}' is not allowed as a decimal separator")

# Set in Python side settings
_settings.decimal_separator = separator

# Update the C++ side
from .ddbc_bindings import DDBCSetDecimalSeparator
DDBCSetDecimalSeparator(separator)

def getDecimalSeparator():
"""
Returns the decimal separator character used when parsing NUMERIC/DECIMAL values
from the database.

Returns:
str: The current decimal separator character
"""
return _settings.decimal_separator

# Import necessary modules
from .exceptions import (
Expand Down
1 change: 1 addition & 0 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
MONEY_MIN = decimal.Decimal('-922337203685477.5808')
MONEY_MAX = decimal.Decimal('922337203685477.5807')


class Cursor:
"""
Represents a database cursor, which is used to manage the context of a fetch operation.
Expand Down
84 changes: 67 additions & 17 deletions mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2037,20 +2037,48 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
ret = SQLGetData_ptr(hStmt, i, SQL_C_CHAR, numericStr, sizeof(numericStr), &indicator);

if (SQL_SUCCEEDED(ret)) {
if (indicator == SQL_NULL_DATA) {
row.append(py::none());
} else {
try {
std::string s(reinterpret_cast<const char*>(numericStr));
auto Decimal = py::module_::import("decimal").attr("Decimal");
row.append(Decimal(s));
} catch (const py::error_already_set& e) {
LOG("Error converting to Decimal: {}", e.what());
row.append(py::none());
try {
// Validate 'indicator' to avoid buffer overflow and fallback to a safe
// null-terminated read when length is unknown or out-of-range.
const char* cnum = reinterpret_cast<const char*>(numericStr);
size_t bufSize = sizeof(numericStr);
size_t safeLen = 0;

if (indicator > 0 && indicator <= static_cast<SQLLEN>(bufSize)) {
// indicator appears valid and within the buffer size
safeLen = static_cast<size_t>(indicator);
} else {
// indicator is unknown, zero, negative, or too large; determine length
// by searching for a terminating null (safe bounded scan)
for (size_t j = 0; j < bufSize; ++j) {
if (cnum[j] == '\0') {
safeLen = j;
break;
}
}
// if no null found, use the full buffer size as a conservative fallback
if (safeLen == 0 && bufSize > 0 && cnum[0] != '\0') {
safeLen = bufSize;
}
}

// Use the validated length to construct the string for Decimal
std::string numStr(cnum, safeLen);

// Create Python Decimal object
py::object decimalObj = py::module_::import("decimal").attr("Decimal")(numStr);

// Add to row
row.append(decimalObj);
} catch (const py::error_already_set& e) {
// If conversion fails, append None
LOG("Error converting to decimal: {}", e.what());
row.append(py::none());
}
} else {
LOG("Error retrieving data for column - {}, data type - {}, SQLGetData rc - {}",
}
else {
LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return "
"code - {}. Returning NULL value instead",
i, dataType, ret);
row.append(py::none());
}
Expand Down Expand Up @@ -2560,11 +2588,24 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
case SQL_DECIMAL:
case SQL_NUMERIC: {
try {
// Convert numericStr to py::decimal.Decimal and append to row
row.append(py::module_::import("decimal").attr("Decimal")(std::string(
reinterpret_cast<const char*>(
&buffers.charBuffers[col - 1][i * MAX_DIGITS_IN_NUMERIC]),
buffers.indicators[col - 1][i])));
// Convert the string to use the current decimal separator
std::string numStr(reinterpret_cast<const char*>(
&buffers.charBuffers[col - 1][i * MAX_DIGITS_IN_NUMERIC]),
buffers.indicators[col - 1][i]);

// Get the current separator in a thread-safe way
std::string separator = GetDecimalSeparator();

if (separator != ".") {
// Replace the driver's decimal point with our configured separator
size_t pos = numStr.find('.');
if (pos != std::string::npos) {
numStr.replace(pos, 1, separator);
}
}

// Convert to Python decimal
row.append(py::module_::import("decimal").attr("Decimal")(numStr));
} catch (const py::error_already_set& e) {
// Handle the exception, e.g., log the error and append py::none()
LOG("Error converting to decimal: {}", e.what());
Expand Down Expand Up @@ -3016,6 +3057,13 @@ void enable_pooling(int maxSize, int idleTimeout) {
});
}

// Thread-safe decimal separator setting
ThreadSafeDecimalSeparator g_decimalSeparator;

void DDBCSetDecimalSeparator(const std::string& separator) {
SetDecimalSeparator(separator);
}

// Architecture-specific defines
#ifndef ARCHITECTURE
#define ARCHITECTURE "win64" // Default to win64 if not defined during compilation
Expand Down Expand Up @@ -3102,6 +3150,8 @@ PYBIND11_MODULE(ddbc_bindings, m) {
py::arg("tableType") = std::wstring());
m.def("DDBCSQLFetchScroll", &SQLFetchScroll_wrap,
"Scroll to a specific position in the result set and optionally fetch data");
m.def("DDBCSetDecimalSeparator", &DDBCSetDecimalSeparator, "Set the decimal separator character");


// Add a version attribute
m.attr("__version__") = "1.0.0";
Expand Down
44 changes: 44 additions & 0 deletions mssql_python/pybind/ddbc_bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -413,3 +413,47 @@ inline std::wstring Utf8ToWString(const std::string& str) {
return converter.from_bytes(str);
#endif
}

// Thread-safe decimal separator accessor class
class ThreadSafeDecimalSeparator {
private:
std::string value;
mutable std::mutex mutex;

public:
// Constructor with default value
ThreadSafeDecimalSeparator() : value(".") {}

// Set the decimal separator with thread safety
void set(const std::string& separator) {
std::lock_guard<std::mutex> lock(mutex);
value = separator;
}

// Get the decimal separator with thread safety
std::string get() const {
std::lock_guard<std::mutex> lock(mutex);
return value;
}

// Returns whether the current separator is different from the default "."
bool isCustomSeparator() const {
std::lock_guard<std::mutex> lock(mutex);
return value != ".";
}
};

// Global instance
extern ThreadSafeDecimalSeparator g_decimalSeparator;

// Helper functions to replace direct access
inline void SetDecimalSeparator(const std::string& separator) {
g_decimalSeparator.set(separator);
}

inline std::string GetDecimalSeparator() {
return g_decimalSeparator.get();
}

// Function to set the decimal separator
void DDBCSetDecimalSeparator(const std::string& separator);
20 changes: 19 additions & 1 deletion mssql_python/row.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,25 @@ def __iter__(self):

def __str__(self):
"""Return string representation of the row"""
return str(tuple(self._values))
from decimal import Decimal
from mssql_python import getDecimalSeparator

parts = []
for value in self:
if isinstance(value, Decimal):
# Apply custom decimal separator for display
sep = getDecimalSeparator()
if sep != '.' and value is not None:
s = str(value)
if '.' in s:
s = s.replace('.', sep)
parts.append(s)
else:
parts.append(str(value))
else:
parts.append(repr(value))

return "(" + ", ".join(parts) + ")"

def __repr__(self):
"""Return a detailed string representation for debugging"""
Expand Down
Loading
Loading