From cd614ff3c37a7e3576bf6362a0c70d7d34a3f9d8 Mon Sep 17 00:00:00 2001 From: Laser Kaplan Date: Tue, 9 Aug 2022 09:07:56 -0700 Subject: [PATCH] Add JSON handling to dbapi and dialect --- .../test_sqlalchemy_integration.py | 40 +++++++++++++++++++ trino/sqlalchemy/compiler.py | 3 ++ trino/sqlalchemy/datatype.py | 19 ++++++++- 3 files changed, 60 insertions(+), 2 deletions(-) diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 1dc8f05a..3a1d231c 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -13,6 +13,8 @@ import sqlalchemy as sqla from sqlalchemy.sql import and_, or_, not_ +from trino.sqlalchemy.datatype import JSON + @pytest.fixture def trino_connection(run_trino, request): @@ -255,3 +257,41 @@ def test_cte(trino_connection): result = conn.execute(s) rows = result.fetchall() assert len(rows) == 15 + + +@pytest.mark.parametrize( + 'trino_connection,json_object', + [ + ('memory', None), + ('memory', 1), + ('memory', 'test'), + ('memory', [1, 'test']), + ('memory', {'test': 1}), + ], + indirect=['trino_connection'] +) +def test_json_column(trino_connection, json_object): + engine, conn = trino_connection + + if not engine.dialect.has_schema(engine, "test"): + engine.execute(sqla.schema.CreateSchema("test")) + metadata = sqla.MetaData() + + try: + table_with_json = sqla.Table( + 'table_with_json', + metadata, + sqla.Column('id', sqla.Integer), + sqla.Column('json_column', JSON), + schema="test" + ) + metadata.create_all(engine) + ins = table_with_json.insert() + conn.execute(ins, {"id": 1, "json_column": json_object}) + query = sqla.select(table_with_json) + result = conn.execute(query) + rows = result.fetchall() + assert len(rows) == 1 + assert rows[0] == (1, json_object) + finally: + metadata.drop_all(engine) diff --git a/trino/sqlalchemy/compiler.py b/trino/sqlalchemy/compiler.py index a085fbf3..99b272fd 100644 --- a/trino/sqlalchemy/compiler.py +++ b/trino/sqlalchemy/compiler.py @@ -206,6 +206,9 @@ def visit_TIME(self, type_, **kw): datatype += " WITH TIME ZONE" return datatype + def visit_JSON(self, type_, **kw): + return 'JSON' + class TrinoIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = RESERVED_WORDS diff --git a/trino/sqlalchemy/datatype.py b/trino/sqlalchemy/datatype.py index 8284ba9c..44961762 100644 --- a/trino/sqlalchemy/datatype.py +++ b/trino/sqlalchemy/datatype.py @@ -9,12 +9,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json import re from typing import Iterator, List, Optional, Tuple, Type, Union, Dict, Any from sqlalchemy import util from sqlalchemy.sql import sqltypes -from sqlalchemy.sql.type_api import TypeEngine +from sqlalchemy.sql.type_api import TypeDecorator, TypeEngine +from sqlalchemy.types import String SQLType = Union[TypeEngine, Type[TypeEngine]] @@ -71,6 +73,19 @@ def __init__(self, precision=None, timezone=False): self.precision = precision +class JSON(TypeDecorator): + impl = String + + def process_bind_param(self, value, dialect): + return json.dumps(value) + + def process_result_value(self, value, dialect): + return json.loads(value) + + def get_col_spec(self, **kw): + return 'JSON' + + # https://trino.io/docs/current/language/types.html _type_map = { # === Boolean === @@ -90,7 +105,7 @@ def __init__(self, precision=None, timezone=False): "varchar": sqltypes.VARCHAR, "char": sqltypes.CHAR, "varbinary": sqltypes.VARBINARY, - "json": sqltypes.JSON, + "json": JSON, # === Date and time === "date": sqltypes.DATE, "time": TIME,