Skip to content

Commit

Permalink
Support for TIME(p) and TIMESTAMP(p) to SQLAlchemy
Browse files Browse the repository at this point in the history
  • Loading branch information
hovaesco committed Jun 27, 2022
1 parent 771eec3 commit b7a59a0
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 18 deletions.
11 changes: 10 additions & 1 deletion tests/unit/sqlalchemy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pytest
from sqlalchemy.sql.sqltypes import ARRAY

from trino.sqlalchemy.datatype import MAP, ROW, SQLType
from trino.sqlalchemy.datatype import MAP, ROW, SQLType, TIMESTAMP, TIME


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -40,6 +40,15 @@ def _assert_sqltype(this: SQLType, that: SQLType):
for (this_attr, that_attr) in zip(this.attr_types, that.attr_types):
assert this_attr[0] == that_attr[0]
_assert_sqltype(this_attr[1], that_attr[1])

elif isinstance(this, TIME):
assert this.precision == that.precision
assert this.timezone == that.timezone

elif isinstance(this, TIMESTAMP):
assert this.precision == that.precision
assert this.timezone == that.timezone

else:
assert str(this) == str(that)

Expand Down
33 changes: 21 additions & 12 deletions tests/unit/sqlalchemy/test_datatype_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
ARRAY,
INTEGER,
DECIMAL,
DATE,
TIME,
TIMESTAMP,
DATE
)
from sqlalchemy.sql.type_api import TypeEngine

from trino.sqlalchemy import datatype
from trino.sqlalchemy.datatype import MAP, ROW
from trino.sqlalchemy.datatype import (
MAP,
ROW,
TIME,
TIMESTAMP
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -65,8 +68,7 @@ def test_parse_cases(type_str: str, sql_type: TypeEngine, assert_sqltype):
"CHAR(10)": CHAR(10),
"VARCHAR(10)": VARCHAR(10),
"DECIMAL(20)": DECIMAL(20),
"DECIMAL(20, 3)": DECIMAL(20, 3),
# TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
"DECIMAL(20, 3)": DECIMAL(20, 3)
}


Expand Down Expand Up @@ -142,8 +144,8 @@ def test_parse_map(type_str: str, sql_type: ARRAY, assert_sqltype):
),
"row(min timestamp(6) with time zone, max timestamp(6) with time zone)": ROW(
attr_types=[
("min", TIMESTAMP(timezone=True)),
("max", TIMESTAMP(timezone=True)),
("min", TIMESTAMP(6, timezone=True)),
("max", TIMESTAMP(6, timezone=True)),
]
),
'row("first name" varchar, "last name" varchar)': ROW(
Expand Down Expand Up @@ -173,12 +175,19 @@ def test_parse_row(type_str: str, sql_type: ARRAY, assert_sqltype):


parse_datetime_testcases = {
# TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
"date": DATE(),
"time": TIME(),
"time(0)": TIME(0),
"time(3)": TIME(3, timezone=False),
"time(6)": TIME(6),
"time(13)": TIME(13),
"time(12) with time zone": TIME(12, timezone=True),
"time with time zone": TIME(timezone=True),
"timestamp": TIMESTAMP(),
"timestamp with time zone": TIMESTAMP(timezone=True),
"timestamp(0)": TIMESTAMP(0),
"timestamp(3)": TIMESTAMP(3, timezone=False),
"timestamp(6)": TIMESTAMP(6),
"timestamp(12) with time zone": TIMESTAMP(12, timezone=True),
"timestamp with time zone": TIMESTAMP(timezone=True)
}


Expand All @@ -187,6 +196,6 @@ def test_parse_row(type_str: str, sql_type: ARRAY, assert_sqltype):
parse_datetime_testcases.items(),
ids=parse_datetime_testcases.keys(),
)
def test_parse_datetime(type_str: str, sql_type: ARRAY, assert_sqltype):
def test_parse_datetime(type_str: str, sql_type: TypeEngine, assert_sqltype):
actual_type = datatype.parse_sqltype(type_str)
assert_sqltype(actual_type, sql_type)
23 changes: 23 additions & 0 deletions trino/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,29 @@ def visit_BLOB(self, type_, **kw):
def visit_DATETIME(self, type_, **kw):
return self.visit_TIMESTAMP(type_, **kw)

def visit_TIMESTAMP(self, type_, **kw):
datatype = "TIMESTAMP"
precision = getattr(type_, "precision", None)
if precision not in range(0, 13) and precision is not None:
raise ValueError(f"invalid precision={precision}, it must be from range 0-12")
if precision is not None:
datatype += f"({precision})"
if getattr(type_, "timezone", False):
datatype += " WITH TIME ZONE"

return datatype

def visit_TIME(self, type_, **kw):
datatype = "TIME"
precision = getattr(type_, "precision", None)
if precision not in range(0, 13) and precision is not None:
raise ValueError(f"invalid precision={precision}, it must be from range 0-12")
if precision is not None:
datatype += f"({precision})"
if getattr(type_, "timezone", False):
datatype += " WITH TIME ZONE"
return datatype


class TrinoIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = RESERVED_WORDS
31 changes: 26 additions & 5 deletions trino/sqlalchemy/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Iterator, List, Optional, Tuple, Type, Union
from typing import Iterator, List, Optional, Tuple, Type, Union, Dict, Any

from sqlalchemy import util
from sqlalchemy.sql import sqltypes
Expand Down Expand Up @@ -55,6 +55,22 @@ def python_type(self):
return list


class TIME(sqltypes.TIME):
__visit_name__ = "TIME"

def __init__(self, precision=None, timezone=False):
super(TIME, self).__init__(timezone=timezone)
self.precision = precision


class TIMESTAMP(sqltypes.TIMESTAMP):
__visit_name__ = "TIMESTAMP"

def __init__(self, precision=None, timezone=False):
super(TIMESTAMP, self).__init__(timezone=timezone)
self.precision = precision


# https://trino.io/docs/current/language/types.html
_type_map = {
# === Boolean ===
Expand All @@ -77,8 +93,10 @@ def python_type(self):
"json": sqltypes.JSON,
# === Date and time ===
"date": sqltypes.DATE,
"time": sqltypes.TIME,
"timestamp": sqltypes.TIMESTAMP,
"time": TIME,
"time with time zone": TIME,
"timestamp": TIMESTAMP,
"timestamp with time zone": TIMESTAMP,
# 'interval year to month':
# 'interval day to second':
#
Expand Down Expand Up @@ -193,7 +211,10 @@ def parse_sqltype(type_str: str) -> TypeEngine:
type_class = _type_map[type_name]
type_args = [int(o.strip()) for o in type_opts.split(",")] if type_opts else []
if type_name in ("time", "timestamp"):
type_kwargs = dict(timezone=type_str.endswith("with time zone"))
# TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
type_kwargs: Dict[str, Any] = dict()
if type_str.endswith("with time zone"):
type_kwargs["timezone"] = True
if type_opts is not None:
type_kwargs["precision"] = int(type_opts)
return type_class(**type_kwargs)
return type_class(*type_args)

0 comments on commit b7a59a0

Please sign in to comment.