Skip to content

Commit ae549ba

Browse files
committed
fix(execute): fixed a bug that caused the pandas Timestamp datatype to be sent to Redshift as DATE instead of TIMESTAMP/TIMESTAMPTZ when statements are executed using bind parameters. Issue #206
1 parent 550006e commit ae549ba

File tree

3 files changed

+141
-0
lines changed

3 files changed

+141
-0
lines changed

redshift_connector/core.py

+28
Original file line numberDiff line numberDiff line change
@@ -1609,11 +1609,19 @@ def make_params(self: "Connection", values) -> typing.Tuple[typing.Tuple[int, in
16091609
for value in values:
16101610
typ: typing.Type = type(value)
16111611
try:
1612+
# 1) check if we have a direct match for this datatype in PY_TYPES mapping
16121613
params.append(self.py_types[typ])
16131614
except KeyError:
16141615
try:
1616+
# 2) if no match was found in 1) check if we have a match in inspect_funcs.
1617+
# note that inspect_funcs inspect the data value to determine the type.
1618+
# e.g. if the datatype is a Datetime and has a timezone, we want to map it
1619+
# to TIMESTAMPTZ rather than TIMESTAMP.
16151620
params.append(self.inspect_funcs[typ](value))
16161621
except KeyError as e:
1622+
# 3) if no match was found in 1) nor 2), we again iterate through PY_TYPES but
1623+
# check if our data is an instance of any datatypes found in PY_TYPES
1624+
# rather than looking for an exact match as was performed in 1)
16171625
param: typing.Optional[typing.Tuple[int, int, typing.Callable]] = None
16181626
for k, v in self.py_types.items():
16191627
try:
@@ -1624,6 +1632,9 @@ def make_params(self: "Connection", values) -> typing.Tuple[typing.Tuple[int, in
16241632
pass
16251633

16261634
if param is None:
1635+
# 4) if no match was found in 1) nor 2) nor 3), we again iterate through
1636+
# inspect_funcs but check if our data is an instance of any datatype
1637+
# found in inspect_funcs
16271638
for k, v in self.inspect_funcs.items(): # type: ignore
16281639
try:
16291640
if isinstance(value, k):
@@ -1634,6 +1645,23 @@ def make_params(self: "Connection", values) -> typing.Tuple[typing.Tuple[int, in
16341645
pass
16351646
except KeyError:
16361647
pass
1648+
elif param[0] == RedshiftOID.DATE:
1649+
# 5) if we classified this data as DATE in 3), we perform a secondary check to
1650+
# ensure this data was not misclassified in 3). Misclassification occurs in the case
1651+
# where data having a type that is a subclass of datetime.date is also a subclass of
1652+
# datetime.datetime. For this example, the misclassification leads to the loss of
1653+
# time precision in transformed data sent to Redshift on the wire. The simplest
1654+
# example of this edge case can be seen in
1655+
# https://github.com/aws/amazon-redshift-python-driver/issues/206
1656+
# where a pandas Timestamp is misclassified as Redshift DATE in 3).
1657+
if isinstance(value, Datetime) and Datetime in self.inspect_funcs:
1658+
try:
1659+
v_func = typing.cast(typing.Callable, self.inspect_funcs[Datetime])
1660+
param = v_func(value)
1661+
except TypeError:
1662+
pass
1663+
except KeyError:
1664+
pass
16371665

16381666
if param is None:
16391667
raise NotSupportedError("type " + str(e) + " not mapped to pg type")

test/unit/test_connection.py

+40
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import typing
33
from collections import deque
44
from decimal import Decimal
5+
from test.utils import pandas_only
56
from unittest import mock
67
from unittest.mock import patch
78

@@ -520,3 +521,42 @@ def test_broken_pipe_timeout_on_connect(db_kwargs) -> None:
520521
db_kwargs.pop("region")
521522
db_kwargs.pop("cluster_identifier")
522523
Connection(**db_kwargs)
524+
525+
526+
def make_mock_connection(db_kwargs):
    """Build a ``Connection`` while the socket layer is patched with mocks.

    ``socket.getaddrinfo``, ``socket.socket.connect`` and
    ``socket.socket.makefile`` are patched for the duration of the
    ``Connection`` constructor call.

    :param db_kwargs: keyword arguments intended for ``Connection``. The
        mapping is copied before it is adjusted, so the caller's dict is
        left unmodified.
    :return: the ``Connection`` built from the adjusted keyword arguments.
    """
    # Work on a copy so a caller (or fixture) that reuses its dict is unaffected.
    conn_kwargs = dict(db_kwargs)
    conn_kwargs["ssl"] = False
    conn_kwargs["timeout"] = 60
    # These keys are removed before the connection is constructed; supplying a
    # default means an input that never contained them is not an error.
    conn_kwargs.pop("region", None)
    conn_kwargs.pop("cluster_identifier", None)

    addr_tuple = [(0, 1, 2, "", ("3.226.18.73", 5439)), (2, 1, 6, "", ("3.226.18.73", 5439))]

    with mock.patch("socket.getaddrinfo") as mock_getaddrinfo:
        mock_getaddrinfo.return_value = addr_tuple
        with mock.patch("socket.socket.connect") as mock_usock:
            # connect() becomes a no-op that accepts any arguments
            mock_usock.side_effect = lambda *args, **kwargs: None

            with mock.patch("socket.socket.makefile") as mock_sock:
                # every read() on the mocked file object yields these canned bytes
                mock_file = mock_sock.return_value
                mock_file.read.return_value = b"Zasej"
                return Connection(**conn_kwargs)
542+
543+
544+
@pandas_only
def test_make_params_maps_pandas_timestamp_to_timestamp(db_kwargs):
    """``make_params`` must map a timezone-aware timestamp taken from a DataFrame to TIMESTAMPTZ.

    Regression test for
    https://github.com/aws/amazon-redshift-python-driver/issues/206
    """
    import datetime

    import pandas as pd

    from redshift_connector.utils.oids import RedshiftOID
    from redshift_connector.utils.type_utils import timestamptz_send_integer

    # Route the value through a DataFrame so that make_params receives whatever
    # object pandas hands back from ``.values.tolist()`` rather than the original
    # ``datetime.datetime`` instance.
    now_utc = datetime.datetime.now(datetime.timezone.utc)
    values = pd.DataFrame({"dw_inserted_at": [now_utc]}).values.tolist()[0]

    mock_connection: Connection = make_mock_connection(db_kwargs)
    res = mock_connection.make_params(values)

    first_param = res[0]
    assert first_param[0] == RedshiftOID.TIMESTAMPTZ
    assert first_param[1] == 1
    assert first_param[2] == timestamptz_send_integer

test/unit/test_cursor.py

+73
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import datetime
12
import typing
23
from io import StringIO
34
from math import ceil
@@ -502,3 +503,75 @@ def test_write_dataframe_handles_npdtyes(mocker):
502503
assert len(spy.mock_calls[1].args[1]) == 1
503504
# bind parameter list should not contain numpy objects
504505
assert not isinstance(spy.mock_calls[1].args[1][0], np.generic)
506+
507+
508+
@pandas_only
def test_write_dataframe_handles_pandas_types(mocker):
    """Check the bind parameters ``write_dataframe`` passes to ``execute`` for pandas dtypes.

    For each case a single-value ``pd.Series`` is wrapped in a DataFrame and
    written. The second ``execute`` call must receive a one-element ``list``
    that is not a pandas object and whose element is an instance of the
    expected type for that case.
    """
    import pandas as pd

    mocker.patch("redshift_connector.Cursor.execute", return_value=None)
    mocker.patch("redshift_connector.Cursor.fetchone", return_value=[1])

    # Bare instances: __init__ is bypassed for both the cursor and the connection.
    mock_cursor: Cursor = Cursor.__new__(Cursor)
    mock_connection: Connection = Connection.__new__(Connection)
    mock_cursor._c = mock_connection
    mock_cursor.paramstyle = "mocked_val"

    # (table name, input series, type expected for the single bind parameter)
    cases = (
        ("int64", pd.Series([42]), int),
        ("float64", pd.Series([3.14]), float),
        ("object", pd.Series(["Hello, Pandas!"]), str),
        ("bool", pd.Series([True]), bool),
        ("datetime64", pd.Series([pd.Timestamp("2022-01-01")]), int),
        ("timedelta64", pd.Series([pd.Timedelta(days=5)]), int),
    )

    for table_name, series, expected_type in cases:
        # a fresh spy per case so the call count starts from zero each time
        spy = mocker.spy(mock_cursor, "execute")
        mock_cursor.write_dataframe(df=pd.DataFrame(series), table=table_name)

        assert spy.called
        assert spy.call_count == 2  # once for __is_valid_table, once for write_dataframe

        bind_params = spy.mock_calls[1].args[1]
        assert not isinstance(bind_params, pd.core.base.PandasObject)
        assert isinstance(bind_params, list)
        assert len(bind_params) == 1
        # the lone bind parameter must be of the type expected for this case
        assert isinstance(bind_params[0], expected_type)
539+
540+
541+
@pandas_only
@pytest.mark.parametrize(
    "datatype,data,_type",
    (
        ("int", 42, int),
        ("float", 3.14, float),
        ("str", "H", str),
        ("bool", True, bool),
        ("list", [1, 2, 3], list),
        ("tuple", (4, 5, 6), tuple),
        ("set", {1, 2, 3}, set),
        ("datetime", datetime.datetime.now(datetime.timezone.utc), datetime.datetime),
    ),
)
def test_write_dataframe_handles_python_types(mocker, datatype, data, _type):
    """Check the bind parameters ``write_dataframe`` passes to ``execute`` for Python values.

    A one-cell DataFrame holding ``data`` (column named ``datatype``) is
    written. The second ``execute`` call must receive a one-element ``list``
    that is not a pandas object and whose element is an instance of ``_type``.
    """
    import pandas as pd

    mocker.patch("redshift_connector.Cursor.execute", return_value=None)
    mocker.patch("redshift_connector.Cursor.fetchone", return_value=[1])

    # Bare instances: __init__ is bypassed for both the cursor and the connection.
    mock_cursor: Cursor = Cursor.__new__(Cursor)
    mock_connection: Connection = Connection.__new__(Connection)
    mock_cursor._c = mock_connection
    mock_cursor.paramstyle = "mocked_val"

    spy = mocker.spy(mock_cursor, "execute")
    dataframe = pd.DataFrame({datatype: [data]})
    mock_cursor.write_dataframe(df=dataframe, table=datatype)

    assert spy.called
    assert spy.call_count == 2  # once for __is_valid_table, once for write_dataframe

    bind_params = spy.mock_calls[1].args[1]
    assert not isinstance(bind_params, pd.core.base.PandasObject)
    assert isinstance(bind_params, list)
    assert len(bind_params) == 1
    assert isinstance(bind_params[0], _type)

0 commit comments

Comments
 (0)