diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index b6b309cf..f0b3a9e4 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -338,6 +338,16 @@ def _map_sql_type(self, param, parameters_list, i, min_val=None, max_val=None): parameters_list[i].scale, False, ) + + if isinstance(param, uuid.UUID): + parameters_list[i] = param.bytes_le + return ( + ddbc_sql_const.SQL_GUID.value, + ddbc_sql_const.SQL_C_GUID.value, + 16, + 0, + False, + ) if isinstance(param, str): if ( @@ -352,6 +362,20 @@ def _map_sql_type(self, param, parameters_list, i, min_val=None, max_val=None): 0, False, ) + + try: + val = uuid.UUID(param) + parameters_list[i] = val.bytes_le + return ( + ddbc_sql_const.SQL_GUID.value, + ddbc_sql_const.SQL_C_GUID.value, + 16, + 0, + False + ) + except ValueError: + pass + # Attempt to parse as date, datetime, datetime2, timestamp, smalldatetime or time if self._parse_date(param): diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index fe8197fd..a62866ee 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -504,7 +504,33 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, break; } case SQL_C_GUID: { - // TODO + if (!py::isinstance(param)) { + ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); + } + py::bytes uuid_bytes = param.cast(); + const unsigned char* uuid_data = reinterpret_cast(PyBytes_AS_STRING(uuid_bytes.ptr())); + if (PyBytes_GET_SIZE(uuid_bytes.ptr()) != 16) { + LOG("Invalid UUID parameter at index {}: expected 16 bytes, got {} bytes, type {}", paramIndex, PyBytes_GET_SIZE(uuid_bytes.ptr()), paramInfo.paramCType); + ThrowStdException("UUID binary data must be exactly 16 bytes long."); + } + SQLGUID* guid_data_ptr = AllocateParamBuffer(paramBuffers); + guid_data_ptr->Data1 = + (static_cast(uuid_data[3]) << 24) | + (static_cast(uuid_data[2]) << 16) | + (static_cast(uuid_data[1]) << 8) | + (static_cast(uuid_data[0])); + guid_data_ptr->Data2 = + (static_cast(uuid_data[5]) << 8) | + (static_cast(uuid_data[4])); + guid_data_ptr->Data3 = + (static_cast(uuid_data[7]) << 8) | + (static_cast(uuid_data[6])); + std::memcpy(guid_data_ptr->Data4, &uuid_data[8], 8); + dataPtr = static_cast(guid_data_ptr); + bufferLength = sizeof(SQLGUID); + strLenOrIndPtr = AllocateParamBuffer(paramBuffers); + *strLenOrIndPtr = sizeof(SQLGUID); + break; } default: { std::ostringstream errorString; @@ -2553,20 +2579,27 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p #if (ODBCVER >= 0x0350) case SQL_GUID: { SQLGUID guidValue; - ret = SQLGetData_ptr(hStmt, i, SQL_C_GUID, &guidValue, sizeof(guidValue), NULL); - if (SQL_SUCCEEDED(ret)) { - std::ostringstream oss; - oss << std::hex << std::setfill('0') << std::setw(8) << guidValue.Data1 << '-' - << std::setw(4) << guidValue.Data2 << '-' << std::setw(4) << guidValue.Data3 - << '-' << std::setw(2) << static_cast(guidValue.Data4[0]) - << std::setw(2) << static_cast(guidValue.Data4[1]) << '-' << std::hex - << std::setw(2) << static_cast(guidValue.Data4[2]) << std::setw(2) - << static_cast(guidValue.Data4[3]) << std::setw(2) - << static_cast(guidValue.Data4[4]) << std::setw(2) - << static_cast(guidValue.Data4[5]) << std::setw(2) - << static_cast(guidValue.Data4[6]) << std::setw(2) - << static_cast(guidValue.Data4[7]); - row.append(oss.str()); // Append GUID as a string + SQLLEN indicator; + ret = SQLGetData_ptr(hStmt, i, SQL_C_GUID, &guidValue, sizeof(guidValue), &indicator); + + if (SQL_SUCCEEDED(ret) && indicator != SQL_NULL_DATA) { + std::vector guid_bytes(16); + guid_bytes[0] = ((char*)&guidValue.Data1)[3]; + guid_bytes[1] = ((char*)&guidValue.Data1)[2]; + guid_bytes[2] = ((char*)&guidValue.Data1)[1]; + guid_bytes[3] = ((char*)&guidValue.Data1)[0]; + guid_bytes[4] = ((char*)&guidValue.Data2)[1]; + guid_bytes[5] = ((char*)&guidValue.Data2)[0]; + guid_bytes[6] = ((char*)&guidValue.Data3)[1]; + guid_bytes[7] = ((char*)&guidValue.Data3)[0]; + std::memcpy(&guid_bytes[8], guidValue.Data4, sizeof(guidValue.Data4)); + + py::bytes py_guid_bytes(guid_bytes.data(), guid_bytes.size()); + py::object uuid_module = py::module_::import("uuid"); + py::object uuid_obj = uuid_module.attr("UUID")(py::arg("bytes")=py_guid_bytes); + row.append(uuid_obj); + } else if (indicator == SQL_NULL_DATA) { + row.append(py::none()); } else { LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return " "code - {}. Returning NULL value instead", @@ -2957,9 +2990,23 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum break; } case SQL_GUID: { - row.append( - py::bytes(reinterpret_cast(&buffers.guidBuffers[col - 1][i]), - sizeof(SQLGUID))); + SQLGUID* guidValue = &buffers.guidBuffers[col - 1][i]; + uint8_t reordered[16]; + reordered[0] = ((char*)&guidValue->Data1)[3]; + reordered[1] = ((char*)&guidValue->Data1)[2]; + reordered[2] = ((char*)&guidValue->Data1)[1]; + reordered[3] = ((char*)&guidValue->Data1)[0]; + reordered[4] = ((char*)&guidValue->Data2)[1]; + reordered[5] = ((char*)&guidValue->Data2)[0]; + reordered[6] = ((char*)&guidValue->Data3)[1]; + reordered[7] = ((char*)&guidValue->Data3)[0]; + std::memcpy(reordered + 8, guidValue->Data4, 8); + + py::bytes py_guid_bytes(reinterpret_cast(reordered), 16); + py::dict kwargs; + kwargs["bytes"] = py_guid_bytes; + py::object uuid_obj = py::module_::import("uuid").attr("UUID")(**kwargs); + row.append(uuid_obj); break; } case SQL_BINARY: diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 9b7276ab..6b28a378 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -14,6 +14,8 @@ import decimal from contextlib import closing import mssql_python +import uuid + # Setup test table TEST_TABLE = """ @@ -6942,6 +6944,208 @@ def test_money_smallmoney_invalid_values(cursor, db_connection): drop_table_if_exists(cursor, "dbo.money_test") db_connection.commit() +def test_uuid_insert_and_select_none(cursor, db_connection): + """Test inserting and retrieving None in a nullable UUID column.""" + table_name = "#pytest_uuid_nullable" + try: + cursor.execute(f"DROP TABLE IF EXISTS {table_name}") + cursor.execute(f""" + CREATE TABLE {table_name} ( + id UNIQUEIDENTIFIER, + name NVARCHAR(50) + ) + """) + db_connection.commit() + + # Insert a row with None for the UUID + cursor.execute(f"INSERT INTO {table_name} (id, name) VALUES (?, ?)", [None, "Bob"]) + db_connection.commit() + + # Fetch the row + cursor.execute(f"SELECT id, name FROM {table_name}") + retrieved_uuid, retrieved_name = cursor.fetchone() + + # Assert correct results + assert retrieved_uuid is None, f"Expected None, got {retrieved_uuid}" + assert retrieved_name == "Bob" + finally: + cursor.execute(f"DROP TABLE IF EXISTS {table_name}") + db_connection.commit() + + +def test_insert_multiple_uuids(cursor, db_connection): + """Test inserting multiple UUIDs and verifying retrieval.""" + table_name = "#pytest_uuid_multiple" + try: + cursor.execute(f"DROP TABLE IF EXISTS {table_name}") + cursor.execute(f""" + CREATE TABLE {table_name} ( + id UNIQUEIDENTIFIER PRIMARY KEY, + description NVARCHAR(50) + ) + """) + db_connection.commit() + + # Prepare test data + uuids_to_insert = {f"Item {i}": uuid.uuid4() for i in range(5)} + + # Insert UUIDs and descriptions + for desc, uid in uuids_to_insert.items(): + cursor.execute(f"INSERT INTO {table_name} (id, description) VALUES (?, ?)", [uid, desc]) + db_connection.commit() + + # Fetch all rows + cursor.execute(f"SELECT id, description FROM {table_name}") + rows = cursor.fetchall() + + # Verify each fetched row + assert len(rows) == len(uuids_to_insert), "Fetched row count mismatch" + + for retrieved_uuid, retrieved_desc in rows: + assert isinstance(retrieved_uuid, uuid.UUID), f"Expected uuid.UUID, got {type(retrieved_uuid)}" + expected_uuid = uuids_to_insert[retrieved_desc] + assert retrieved_uuid == expected_uuid, f"UUID mismatch for '{retrieved_desc}': expected {expected_uuid}, got {retrieved_uuid}" + finally: + cursor.execute(f"DROP TABLE IF EXISTS {table_name}") + db_connection.commit() + + +def test_fetchmany_uuids(cursor, db_connection): + """Test fetching multiple UUID rows with fetchmany().""" + table_name = "#pytest_uuid_fetchmany" + try: + cursor.execute(f"DROP TABLE IF EXISTS {table_name}") + cursor.execute(f""" + CREATE TABLE {table_name} ( + id UNIQUEIDENTIFIER PRIMARY KEY, + description NVARCHAR(50) + ) + """) + db_connection.commit() + + uuids_to_insert = {f"Item {i}": uuid.uuid4() for i in range(10)} + + for desc, uid in uuids_to_insert.items(): + cursor.execute(f"INSERT INTO {table_name} (id, description) VALUES (?, ?)", [uid, desc]) + db_connection.commit() + + cursor.execute(f"SELECT id, description FROM {table_name}") + + # Fetch in batches of 3 + batch_size = 3 + fetched_rows = [] + while True: + batch = cursor.fetchmany(batch_size) + if not batch: + break + fetched_rows.extend(batch) + + # Verify all rows + assert len(fetched_rows) == len(uuids_to_insert), "Fetched row count mismatch" + for retrieved_uuid, retrieved_desc in fetched_rows: + assert isinstance(retrieved_uuid, uuid.UUID) + expected_uuid = uuids_to_insert[retrieved_desc] + assert retrieved_uuid == expected_uuid + finally: + cursor.execute(f"DROP TABLE IF EXISTS {table_name}") + db_connection.commit() + + +def test_uuid_insert_with_none(cursor, db_connection): + """Test inserting None into a UUID column results in a NULL value.""" + table_name = "#pytest_uuid_none" + try: + cursor.execute(f"DROP TABLE IF EXISTS {table_name}") + cursor.execute(f""" + CREATE TABLE {table_name} ( + id UNIQUEIDENTIFIER, + name NVARCHAR(50) + ) + """) + db_connection.commit() + + cursor.execute(f"INSERT INTO {table_name} (id, name) VALUES (?, ?)", [None, "Alice"]) + db_connection.commit() + + cursor.execute(f"SELECT id, name FROM {table_name}") + retrieved_uuid, retrieved_name = cursor.fetchone() + + assert retrieved_uuid is None, f"Expected NULL UUID, got {retrieved_uuid}" + assert retrieved_name == "Alice" + finally: + cursor.execute(f"DROP TABLE IF EXISTS {table_name}") + db_connection.commit() + +def test_invalid_uuid_inserts(cursor, db_connection): + """Test inserting invalid UUID values raises appropriate errors.""" + table_name = "#pytest_uuid_invalid" + try: + cursor.execute(f"DROP TABLE IF EXISTS {table_name}") + cursor.execute(f"CREATE TABLE {table_name} (id UNIQUEIDENTIFIER)") + db_connection.commit() + + invalid_values = [ + "12345", # Too short + "not-a-uuid", # Not a UUID string + 123456789, # Integer + 12.34, # Float + object() # Arbitrary object + ] + + for val in invalid_values: + with pytest.raises(Exception): + cursor.execute(f"INSERT INTO {table_name} (id) VALUES (?)", [val]) + db_connection.commit() + finally: + cursor.execute(f"DROP TABLE IF EXISTS {table_name}") + db_connection.commit() + +def test_duplicate_uuid_inserts(cursor, db_connection): + """Test that inserting duplicate UUIDs into a PK column raises an error.""" + table_name = "#pytest_uuid_duplicate" + try: + cursor.execute(f"DROP TABLE IF EXISTS {table_name}") + cursor.execute(f"CREATE TABLE {table_name} (id UNIQUEIDENTIFIER PRIMARY KEY)") + db_connection.commit() + + uid = uuid.uuid4() + cursor.execute(f"INSERT INTO {table_name} (id) VALUES (?)", [uid]) + db_connection.commit() + + with pytest.raises(Exception): + cursor.execute(f"INSERT INTO {table_name} (id) VALUES (?)", [uid]) + db_connection.commit() + finally: + cursor.execute(f"DROP TABLE IF EXISTS {table_name}") + db_connection.commit() + +def test_extreme_uuids(cursor, db_connection): + """Test inserting extreme but valid UUIDs.""" + table_name = "#pytest_uuid_extreme" + try: + cursor.execute(f"DROP TABLE IF EXISTS {table_name}") + cursor.execute(f"CREATE TABLE {table_name} (id UNIQUEIDENTIFIER)") + db_connection.commit() + + extreme_uuids = [ + uuid.UUID(int=0), # All zeros + uuid.UUID(int=(1 << 128) - 1), # All ones + ] + + for uid in extreme_uuids: + cursor.execute(f"INSERT INTO {table_name} (id) VALUES (?)", [uid]) + db_connection.commit() + + cursor.execute(f"SELECT id FROM {table_name}") + rows = cursor.fetchall() + fetched_uuids = [row[0] for row in rows] + + for uid in extreme_uuids: + assert uid in fetched_uuids, f"Extreme UUID {uid} not retrieved correctly" + finally: + cursor.execute(f"DROP TABLE IF EXISTS {table_name}") + db_connection.commit() + def test_decimal_separator_with_multiple_values(cursor, db_connection): """Test decimal separator with multiple different decimal values""" original_separator = mssql_python.getDecimalSeparator() @@ -10193,7 +10397,6 @@ def test_decimal_separator_calculations(cursor, db_connection): # Cleanup cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_calc_test") - db_connection.commit() def test_close(db_connection): """Test closing the cursor"""