Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for TIME(p) and TIMESTAMP(p) to SQLAlchemy #181

Merged
merged 1 commit into from
Jun 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion tests/unit/sqlalchemy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pytest
from sqlalchemy.sql.sqltypes import ARRAY

from trino.sqlalchemy.datatype import MAP, ROW, SQLType
from trino.sqlalchemy.datatype import MAP, ROW, SQLType, TIMESTAMP, TIME


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -40,6 +40,15 @@ def _assert_sqltype(this: SQLType, that: SQLType):
for (this_attr, that_attr) in zip(this.attr_types, that.attr_types):
assert this_attr[0] == that_attr[0]
_assert_sqltype(this_attr[1], that_attr[1])

elif isinstance(this, TIME):
assert this.precision == that.precision
assert this.timezone == that.timezone

elif isinstance(this, TIMESTAMP):
assert this.precision == that.precision
assert this.timezone == that.timezone

else:
assert str(this) == str(that)

Expand Down
33 changes: 21 additions & 12 deletions tests/unit/sqlalchemy/test_datatype_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
ARRAY,
INTEGER,
DECIMAL,
DATE,
TIME,
TIMESTAMP,
DATE
)
from sqlalchemy.sql.type_api import TypeEngine

from trino.sqlalchemy import datatype
from trino.sqlalchemy.datatype import MAP, ROW
from trino.sqlalchemy.datatype import (
MAP,
ROW,
TIME,
TIMESTAMP
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -65,8 +68,7 @@ def test_parse_cases(type_str: str, sql_type: TypeEngine, assert_sqltype):
"CHAR(10)": CHAR(10),
"VARCHAR(10)": VARCHAR(10),
"DECIMAL(20)": DECIMAL(20),
"DECIMAL(20, 3)": DECIMAL(20, 3),
# TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
"DECIMAL(20, 3)": DECIMAL(20, 3)
}


Expand Down Expand Up @@ -142,8 +144,8 @@ def test_parse_map(type_str: str, sql_type: ARRAY, assert_sqltype):
),
"row(min timestamp(6) with time zone, max timestamp(6) with time zone)": ROW(
attr_types=[
("min", TIMESTAMP(timezone=True)),
("max", TIMESTAMP(timezone=True)),
("min", TIMESTAMP(6, timezone=True)),
("max", TIMESTAMP(6, timezone=True)),
]
),
'row("first name" varchar, "last name" varchar)': ROW(
Expand Down Expand Up @@ -173,12 +175,19 @@ def test_parse_row(type_str: str, sql_type: ARRAY, assert_sqltype):


parse_datetime_testcases = {
# TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
"date": DATE(),
"time": TIME(),
"time(0)": TIME(0),
"time(3)": TIME(3, timezone=False),
"time(6)": TIME(6),
"time(13)": TIME(13),
"time(12) with time zone": TIME(12, timezone=True),
"time with time zone": TIME(timezone=True),
"timestamp": TIMESTAMP(),
"timestamp with time zone": TIMESTAMP(timezone=True),
"timestamp(0)": TIMESTAMP(0),
"timestamp(3)": TIMESTAMP(3, timezone=False),
"timestamp(6)": TIMESTAMP(6),
"timestamp(12) with time zone": TIMESTAMP(12, timezone=True),
"timestamp with time zone": TIMESTAMP(timezone=True)
}


Expand All @@ -187,6 +196,6 @@ def test_parse_row(type_str: str, sql_type: ARRAY, assert_sqltype):
parse_datetime_testcases.items(),
ids=parse_datetime_testcases.keys(),
)
def test_parse_datetime(type_str: str, sql_type: ARRAY, assert_sqltype):
def test_parse_datetime(type_str: str, sql_type: TypeEngine, assert_sqltype):
actual_type = datatype.parse_sqltype(type_str)
assert_sqltype(actual_type, sql_type)
23 changes: 23 additions & 0 deletions trino/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,29 @@ def visit_BLOB(self, type_, **kw):
def visit_DATETIME(self, type_, **kw):
return self.visit_TIMESTAMP(type_, **kw)

def visit_TIMESTAMP(self, type_, **kw):
datatype = "TIMESTAMP"
precision = getattr(type_, "precision", None)
if precision not in range(0, 13) and precision is not None:
raise ValueError(f"invalid precision={precision}, it must be from range 0-12")
if precision is not None:
hovaesco marked this conversation as resolved.
Show resolved Hide resolved
datatype += f"({precision})"
if getattr(type_, "timezone", False):
datatype += " WITH TIME ZONE"

return datatype

def visit_TIME(self, type_, **kw):
datatype = "TIME"
precision = getattr(type_, "precision", None)
if precision not in range(0, 13) and precision is not None:
raise ValueError(f"invalid precision={precision}, it must be from range 0-12")
if precision is not None:
datatype += f"({precision})"
if getattr(type_, "timezone", False):
datatype += " WITH TIME ZONE"
return datatype


class TrinoIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = RESERVED_WORDS
31 changes: 26 additions & 5 deletions trino/sqlalchemy/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Iterator, List, Optional, Tuple, Type, Union
from typing import Iterator, List, Optional, Tuple, Type, Union, Dict, Any

from sqlalchemy import util
from sqlalchemy.sql import sqltypes
Expand Down Expand Up @@ -55,6 +55,22 @@ def python_type(self):
return list


class TIME(sqltypes.TIME):
__visit_name__ = "TIME"

def __init__(self, precision=None, timezone=False):
super(TIME, self).__init__(timezone=timezone)
self.precision = precision


class TIMESTAMP(sqltypes.TIMESTAMP):
__visit_name__ = "TIMESTAMP"

def __init__(self, precision=None, timezone=False):
super(TIMESTAMP, self).__init__(timezone=timezone)
self.precision = precision


# https://trino.io/docs/current/language/types.html
_type_map = {
# === Boolean ===
Expand All @@ -77,8 +93,10 @@ def python_type(self):
"json": sqltypes.JSON,
# === Date and time ===
"date": sqltypes.DATE,
"time": sqltypes.TIME,
"timestamp": sqltypes.TIMESTAMP,
"time": TIME,
"time with time zone": TIME,
"timestamp": TIMESTAMP,
"timestamp with time zone": TIMESTAMP,
# 'interval year to month':
# 'interval day to second':
#
Expand Down Expand Up @@ -193,7 +211,10 @@ def parse_sqltype(type_str: str) -> TypeEngine:
type_class = _type_map[type_name]
type_args = [int(o.strip()) for o in type_opts.split(",")] if type_opts else []
if type_name in ("time", "timestamp"):
type_kwargs = dict(timezone=type_str.endswith("with time zone"))
# TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
type_kwargs: Dict[str, Any] = dict()
if type_str.endswith("with time zone"):
type_kwargs["timezone"] = True
if type_opts is not None:
type_kwargs["precision"] = int(type_opts)
return type_class(**type_kwargs)
return type_class(*type_args)