From 2eeb2dc8ba79404fe181cd0894bcfe3f78ba642e Mon Sep 17 00:00:00 2001 From: Przemek Denkiewicz Date: Tue, 31 May 2022 13:46:56 +0200 Subject: [PATCH] Support for TIME(p) and TIMESTAMP(p) to SQLAlchemy --- tests/unit/sqlalchemy/conftest.py | 11 ++++++- tests/unit/sqlalchemy/test_datatype_parse.py | 30 ++++++++++-------- trino/sqlalchemy/compiler.py | 20 ++++++++++++ trino/sqlalchemy/datatype.py | 32 ++++++++++++++++---- 4 files changed, 74 insertions(+), 19 deletions(-) diff --git a/tests/unit/sqlalchemy/conftest.py b/tests/unit/sqlalchemy/conftest.py index 7857149a..e80f19b8 100644 --- a/tests/unit/sqlalchemy/conftest.py +++ b/tests/unit/sqlalchemy/conftest.py @@ -12,7 +12,7 @@ import pytest from sqlalchemy.sql.sqltypes import ARRAY -from trino.sqlalchemy.datatype import MAP, ROW, SQLType +from trino.sqlalchemy.datatype import MAP, ROW, SQLType, TIMESTAMP, TIME @pytest.fixture(scope="session") @@ -40,6 +40,15 @@ def _assert_sqltype(this: SQLType, that: SQLType): for (this_attr, that_attr) in zip(this.attr_types, that.attr_types): assert this_attr[0] == that_attr[0] _assert_sqltype(this_attr[1], that_attr[1]) + + elif isinstance(this, TIME): + assert this.precision == that.precision + assert this.timezone == that.timezone + + elif isinstance(this, TIMESTAMP): + assert this.precision == that.precision + assert this.timezone == that.timezone + else: assert str(this) == str(that) diff --git a/tests/unit/sqlalchemy/test_datatype_parse.py b/tests/unit/sqlalchemy/test_datatype_parse.py index ae2d297b..00f6727e 100644 --- a/tests/unit/sqlalchemy/test_datatype_parse.py +++ b/tests/unit/sqlalchemy/test_datatype_parse.py @@ -16,14 +16,17 @@ ARRAY, INTEGER, DECIMAL, - DATE, - TIME, - TIMESTAMP, + DATE ) from sqlalchemy.sql.type_api import TypeEngine from trino.sqlalchemy import datatype -from trino.sqlalchemy.datatype import MAP, ROW +from trino.sqlalchemy.datatype import ( + MAP, + ROW, + TIME, + TIMESTAMP +) @pytest.mark.parametrize( @@ -65,8 +68,7 @@ def test_parse_cases(type_str: str, sql_type: TypeEngine, assert_sqltype): "CHAR(10)": CHAR(10), "VARCHAR(10)": VARCHAR(10), "DECIMAL(20)": DECIMAL(20), - "DECIMAL(20, 3)": DECIMAL(20, 3), - # TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107) + "DECIMAL(20, 3)": DECIMAL(20, 3) } @@ -142,8 +144,8 @@ def test_parse_map(type_str: str, sql_type: ARRAY, assert_sqltype): ), "row(min timestamp(6) with time zone, max timestamp(6) with time zone)": ROW( attr_types=[ - ("min", TIMESTAMP(timezone=True)), - ("max", TIMESTAMP(timezone=True)), + ("min", TIMESTAMP(6, timezone=True)), + ("max", TIMESTAMP(6, timezone=True)), ] ), 'row("first name" varchar, "last name" varchar)': ROW( @@ -173,12 +175,16 @@ def test_parse_row(type_str: str, sql_type: ARRAY, assert_sqltype): parse_datetime_testcases = { - # TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107) "date": DATE(), "time": TIME(), + "time(3)": TIME(3, timezone=False), + "time(6)": TIME(6), + "time(12) with time zone": TIME(12, timezone=True), "time with time zone": TIME(timezone=True), - "timestamp": TIMESTAMP(), - "timestamp with time zone": TIMESTAMP(timezone=True), + "timestamp(3)": TIMESTAMP(3, timezone=False), + "timestamp(6)": TIMESTAMP(6), + "timestamp(12) with time zone": TIMESTAMP(12, timezone=True), + "timestamp with time zone": TIMESTAMP(timezone=True) } @@ -187,6 +193,6 @@ def test_parse_row(type_str: str, sql_type: ARRAY, assert_sqltype): parse_datetime_testcases.items(), ids=parse_datetime_testcases.keys(), ) -def test_parse_datetime(type_str: str, sql_type: ARRAY, assert_sqltype): +def test_parse_datetime(type_str: str, sql_type: TypeEngine, assert_sqltype): actual_type = datatype.parse_sqltype(type_str) assert_sqltype(actual_type, sql_type) diff --git a/trino/sqlalchemy/compiler.py b/trino/sqlalchemy/compiler.py index 4945aa58..77fce67b 100644 --- a/trino/sqlalchemy/compiler.py +++ b/trino/sqlalchemy/compiler.py @@ -147,6 +147,26 @@ def visit_BLOB(self, type_, **kw): def visit_DATETIME(self, type_, **kw): return self.visit_TIMESTAMP(type_, **kw) + def visit_TIMESTAMP(self, type_, **kw): + return "TIMESTAMP%s%s" % ( + "(%d)" % type_.precision + if getattr(type_, "precision", None) is not None + else "", + " WITH TIME ZONE" + if getattr(type_, "timezone", False) + else "" + ) + + def visit_TIME(self, type_, **kw): + return "TIME%s %s" % ( + "(%d)" % type_.precision + if getattr(type_, "precision", None) is not None + else "", + " WITH TIME ZONE" + if getattr(type_, "timezone", False) + else "" + ) + class TrinoIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = RESERVED_WORDS diff --git a/trino/sqlalchemy/datatype.py b/trino/sqlalchemy/datatype.py index 6b3af78e..450da325 100644 --- a/trino/sqlalchemy/datatype.py +++ b/trino/sqlalchemy/datatype.py @@ -10,7 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import re -from typing import Iterator, List, Optional, Tuple, Type, Union +from typing import Iterator, List, Optional, Tuple, Type, Union, Dict, Any from sqlalchemy import util from sqlalchemy.sql import sqltypes @@ -55,6 +55,22 @@ def python_type(self): return list +class TIME(sqltypes.TIME): + __visit_name__ = "TIME" + + def __init__(self, precision=None, timezone=False): + super(TIME, self).__init__(timezone=timezone) + self.precision = precision + + +class TIMESTAMP(sqltypes.TIMESTAMP): + __visit_name__ = "TIMESTAMP" + + def __init__(self, precision=None, timezone=False): + super(TIMESTAMP, self).__init__(timezone=timezone) + self.precision = precision + + # https://trino.io/docs/current/language/types.html _type_map = { # === Boolean === @@ -77,8 +93,10 @@ def python_type(self): "json": sqltypes.JSON, # === Date and time === "date": sqltypes.DATE, - "time": sqltypes.TIME, - "timestamp": sqltypes.TIMESTAMP, + "time": TIME, + "time with time zone": TIME, + "timestamp": TIMESTAMP, + "timestamp with time zone": TIMESTAMP, # 'interval year to month': # 'interval day to second': # @@ -193,7 +211,9 @@ def parse_sqltype(type_str: str) -> TypeEngine: type_class = _type_map[type_name] type_args = [int(o.strip()) for o in type_opts.split(",")] if type_opts else [] if type_name in ("time", "timestamp"): - type_kwargs = dict(timezone=type_str.endswith("with time zone")) - # TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107) - return type_class(**type_kwargs) + if type_str.endswith("with time zone"): + type_kwargs: Dict[str, Any] = dict(timezone=True) + if type_opts is not None: + type_kwargs["precision"] = int(type_opts) + return type_class(**type_kwargs) return type_class(*type_args)