Skip to content

Commit

Permalink
Fix time zone localization issue by using zoneinfo rather than `pyt…
Browse files Browse the repository at this point in the history
…z` when specifying time zones
  • Loading branch information
john-bodley committed May 3, 2023
1 parent cd30eed commit cbd4c6b
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 40 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
python_requires='>=3.7',
install_requires=[
"backports.zoneinfo;python_version<'3.9'",
"python-dateutil",
"pytz",
"requests",
"tzlocal",
Expand Down
39 changes: 19 additions & 20 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@
from decimal import Decimal
from typing import Tuple

try:
from zoneinfo import ZoneInfo
except ModuleNotFoundError:
from backports.zoneinfo import ZoneInfo

import pytest
import pytz
import requests
from tzlocal import get_localzone_name # type: ignore

Expand Down Expand Up @@ -234,7 +238,7 @@ def test_legacy_primitive_types_with_connection_and_cursor(
assert rows[0][0] == Decimal('0.142857')
assert rows[0][1] == date(2018, 1, 1)
assert rows[0][2] == datetime(2019, 1, 1, tzinfo=timezone(timedelta(hours=1)))
assert rows[0][3] == datetime(2019, 1, 1, tzinfo=pytz.timezone('UTC'))
assert rows[0][3] == datetime(2019, 1, 1, tzinfo=ZoneInfo('UTC'))
assert rows[0][4] == datetime(2019, 1, 1)
assert rows[0][5] == time(0, 0, 0, 0)
else:
Expand Down Expand Up @@ -338,7 +342,7 @@ def test_datetime_query_param(trino_connection):
def test_datetime_with_utc_time_zone_query_param(trino_connection):
cur = trino_connection.cursor()

params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone('UTC'))
params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=ZoneInfo('UTC'))

cur.execute("SELECT ?", params=(params,))
rows = cur.fetchall()
Expand All @@ -364,7 +368,7 @@ def test_datetime_with_numeric_offset_time_zone_query_param(trino_connection):
def test_datetime_with_named_time_zone_query_param(trino_connection):
cur = trino_connection.cursor()

params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone('America/Los_Angeles'))
params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=ZoneInfo('America/Los_Angeles'))

cur.execute("SELECT ?", params=(params,))
rows = cur.fetchall()
Expand Down Expand Up @@ -407,32 +411,24 @@ def test_datetimes_with_time_zone_in_dst_gap_query_param(trino_connection):
cur = trino_connection.cursor()

# This is a datetime that lies within a DST transition and not actually exists.
params = datetime(2021, 3, 28, 2, 30, 0, tzinfo=pytz.timezone('Europe/Brussels'))
params = datetime(2021, 3, 28, 2, 30, 0, tzinfo=ZoneInfo('Europe/Brussels'))
with pytest.raises(trino.exceptions.TrinoUserError):
cur.execute("SELECT ?", params=(params,))
cur.fetchall()


def test_doubled_datetimes(trino_connection):
# Trino doesn't distinguish between doubled datetimes that lie within a DST transition. See also
@pytest.mark.parametrize('fold', [0, 1])
def test_doubled_datetimes(trino_connection, fold):
# Trino doesn't distinguish between doubled datetimes that lie within a DST transition.
# See also https://github.com/trinodb/trino/issues/5781
cur = trino_connection.cursor()

params = pytz.timezone('US/Eastern').localize(datetime(2002, 10, 27, 1, 30, 0), is_dst=True)
params = datetime(2002, 10, 27, 1, 30, 0, tzinfo=ZoneInfo('US/Eastern'), fold=fold)

cur.execute("SELECT ?", params=(params,))
rows = cur.fetchall()

assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone('US/Eastern'))

cur = trino_connection.cursor()

params = pytz.timezone('US/Eastern').localize(datetime(2002, 10, 27, 1, 30, 0), is_dst=False)

cur.execute("SELECT ?", params=(params,))
rows = cur.fetchall()

assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone('US/Eastern'))
assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=ZoneInfo('US/Eastern'))


def test_date_query_param(trino_connection):
Expand Down Expand Up @@ -529,7 +525,7 @@ def test_time_query_param(trino_connection):
def test_time_with_named_time_zone_query_param(trino_connection):
cur = trino_connection.cursor()

params = time(16, 43, 22, 320000, tzinfo=pytz.timezone('Asia/Shanghai'))
params = time(16, 43, 22, 320000, tzinfo=ZoneInfo('Asia/Shanghai'))

cur.execute("SELECT ?", params=(params,))
rows = cur.fetchall()
Expand Down Expand Up @@ -693,7 +689,10 @@ def test_array_timestamp_query_param(trino_connection):
def test_array_timestamp_with_timezone_query_param(trino_connection):
cur = trino_connection.cursor()

params = [datetime(2020, 1, 1, 0, 0, 0, tzinfo=pytz.utc), datetime(2020, 1, 2, 0, 0, 0, tzinfo=pytz.utc)]
params = [
datetime(2020, 1, 1, 0, 0, 0, tzinfo=ZoneInfo('UTC')),
datetime(2020, 1, 2, 0, 0, 0, tzinfo=ZoneInfo('UTC')),
]

cur.execute("SELECT ?", params=(params,))
rows = cur.fetchall()
Expand Down
8 changes: 6 additions & 2 deletions tests/integration/test_types_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
from datetime import date, datetime, time, timedelta, timezone, tzinfo
from decimal import Decimal

try:
from zoneinfo import ZoneInfo
except ModuleNotFoundError:
from backports.zoneinfo import ZoneInfo

import pytest
import pytz

import trino
from tests.integration.conftest import trino_version
Expand Down Expand Up @@ -729,7 +733,7 @@ def create_timezone(timezone_str: str) -> tzinfo:
else:
return timezone(-timedelta(hours=hours, minutes=minutes))
else:
return pytz.timezone(timezone_str)
return ZoneInfo(timezone_str)


def test_interval(trino_connection):
Expand Down
22 changes: 9 additions & 13 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,18 @@
from time import sleep
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union

import pytz
try:
from zoneinfo import ZoneInfo
except ModuleNotFoundError:
from backports.zoneinfo import ZoneInfo

import requests
from pytz.tzinfo import BaseTzInfo
from dateutil import tz
from tzlocal import get_localzone_name # type: ignore

import trino.logging
from trino import constants, exceptions

try:
from zoneinfo import ZoneInfo # type: ignore

except ModuleNotFoundError:
from backports.zoneinfo import ZoneInfo # type: ignore


__all__ = ["ClientSession", "TrinoQuery", "TrinoRequest", "PROXIES"]

logger = trino.logging.get_logger(__name__)
Expand Down Expand Up @@ -946,7 +943,7 @@ def _create_tzinfo(timezone_str: str) -> tzinfo:
return timezone(-timedelta(hours=int(hours), minutes=int(minutes)))
return timezone(timedelta(hours=int(hours), minutes=int(minutes)))
else:
return pytz.timezone(timezone_str)
return ZoneInfo(timezone_str)


def _fraction_to_decimal(fractional_str: str) -> Decimal:
Expand Down Expand Up @@ -996,8 +993,7 @@ def add_time_delta(self, time_delta: timedelta) -> PythonTemporalType:
def normalize(self, value: PythonTemporalType) -> PythonTemporalType:
"""
If `add_time_delta` results in value crossing DST boundaries, this method should
return a normalized version of the value to account for it, for example,
using `pytz.timezone.normalize`.
return a normalized version of the value to account for it.
"""
return value

Expand Down Expand Up @@ -1041,7 +1037,7 @@ def new_instance(self, value: datetime, fraction: Decimal) -> TimestampWithTimeZ
return TimestampWithTimeZone(value, fraction)

def normalize(self, value: datetime) -> datetime:
if isinstance(self._whole_python_temporal_value.tzinfo, BaseTzInfo):
if tz.datetime_ambiguous(value):
return self._whole_python_temporal_value.tzinfo.normalize(value)
return value

Expand Down
13 changes: 8 additions & 5 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
from typing import Any, Dict, List, NamedTuple, Optional # NOQA for mypy types
from urllib.parse import urlparse

import pytz
try:
from zoneinfo import ZoneInfo
except ModuleNotFoundError:
from backports.zoneinfo import ZoneInfo

import trino.client
import trino.exceptions
Expand Down Expand Up @@ -425,8 +428,8 @@ def _format_prepared_param(self, param):
if isinstance(param, datetime.datetime) and param.tzinfo is not None:
datetime_str = param.strftime("%Y-%m-%d %H:%M:%S.%f")
# named timezones
if hasattr(param.tzinfo, 'zone'):
return "TIMESTAMP '%s %s'" % (datetime_str, param.tzinfo.zone)
if isinstance(param.tzinfo, ZoneInfo):
return "TIMESTAMP '%s %s'" % (datetime_str, param.tzinfo.key)
# offset-based timezones
return "TIMESTAMP '%s %s'" % (datetime_str, param.tzinfo.tzname(param))

Expand All @@ -438,8 +441,8 @@ def _format_prepared_param(self, param):
if isinstance(param, datetime.time) and param.tzinfo is not None:
time_str = param.strftime("%H:%M:%S.%f")
# named timezones
if hasattr(param.tzinfo, 'zone'):
utc_offset = datetime.datetime.now(pytz.timezone(param.tzinfo.zone)).strftime('%z')
if isinstance(param.tzinfo, ZoneInfo):
utc_offset = datetime.datetime.now(tz=param.tzinfo).strftime('%z')
return "TIME '%s %s:%s'" % (time_str, utc_offset[:3], utc_offset[3:])
# offset-based timezones
return "TIME '%s %s'" % (time_str, param.strftime('%Z')[3:])
Expand Down

0 comments on commit cbd4c6b

Please sign in to comment.