Skip to content

Commit

Permalink
Add JSON handling to dbapi and dialect
Browse files Browse the repository at this point in the history
  • Loading branch information
laserkaplan authored and hashhar committed Sep 22, 2022
1 parent 89c2769 commit cd614ff
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 2 deletions.
40 changes: 40 additions & 0 deletions tests/integration/test_sqlalchemy_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import sqlalchemy as sqla
from sqlalchemy.sql import and_, or_, not_

from trino.sqlalchemy.datatype import JSON


@pytest.fixture
def trino_connection(run_trino, request):
Expand Down Expand Up @@ -255,3 +257,41 @@ def test_cte(trino_connection):
result = conn.execute(s)
rows = result.fetchall()
assert len(rows) == 15


@pytest.mark.parametrize(
'trino_connection,json_object',
[
('memory', None),
('memory', 1),
('memory', 'test'),
('memory', [1, 'test']),
('memory', {'test': 1}),
],
indirect=['trino_connection']
)
def test_json_column(trino_connection, json_object):
engine, conn = trino_connection

if not engine.dialect.has_schema(engine, "test"):
engine.execute(sqla.schema.CreateSchema("test"))
metadata = sqla.MetaData()

try:
table_with_json = sqla.Table(
'table_with_json',
metadata,
sqla.Column('id', sqla.Integer),
sqla.Column('json_column', JSON),
schema="test"
)
metadata.create_all(engine)
ins = table_with_json.insert()
conn.execute(ins, {"id": 1, "json_column": json_object})
query = sqla.select(table_with_json)
result = conn.execute(query)
rows = result.fetchall()
assert len(rows) == 1
assert rows[0] == (1, json_object)
finally:
metadata.drop_all(engine)
3 changes: 3 additions & 0 deletions trino/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ def visit_TIME(self, type_, **kw):
datatype += " WITH TIME ZONE"
return datatype

def visit_JSON(self, type_, **kw):
return 'JSON'


class TrinoIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = RESERVED_WORDS
Expand Down
19 changes: 17 additions & 2 deletions trino/sqlalchemy/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from typing import Iterator, List, Optional, Tuple, Type, Union, Dict, Any

from sqlalchemy import util
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy.sql.type_api import TypeDecorator, TypeEngine
from sqlalchemy.types import String

SQLType = Union[TypeEngine, Type[TypeEngine]]

Expand Down Expand Up @@ -71,6 +73,19 @@ def __init__(self, precision=None, timezone=False):
self.precision = precision


class JSON(TypeDecorator):
impl = String

def process_bind_param(self, value, dialect):
return json.dumps(value)

def process_result_value(self, value, dialect):
return json.loads(value)

def get_col_spec(self, **kw):
return 'JSON'


# https://trino.io/docs/current/language/types.html
_type_map = {
# === Boolean ===
Expand All @@ -90,7 +105,7 @@ def __init__(self, precision=None, timezone=False):
"varchar": sqltypes.VARCHAR,
"char": sqltypes.CHAR,
"varbinary": sqltypes.VARBINARY,
"json": sqltypes.JSON,
"json": JSON,
# === Date and time ===
"date": sqltypes.DATE,
"time": TIME,
Expand Down

0 comments on commit cd614ff

Please sign in to comment.