Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support UUID for SQLAlchemy #359

Merged
merged 1 commit into from
Apr 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions tests/integration/test_sqlalchemy_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import uuid

import pytest
import sqlalchemy as sqla
from sqlalchemy.sql import and_, not_, or_
Expand Down Expand Up @@ -133,6 +135,60 @@ def test_insert(trino_connection):
metadata.drop_all(engine)


@pytest.mark.skipif(
sqlalchemy_version() < "2.0",
reason="sqlalchemy.Uuid only exists with SQLAlchemy 2.0 and above"
)
@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True)
def test_define_and_create_table_uuid(trino_connection):
engine, conn = trino_connection
if not engine.dialect.has_schema(conn, "test"):
with engine.begin() as connection:
connection.execute(sqla.schema.CreateSchema("test"))
metadata = sqla.MetaData()
try:
sqla.Table('users',
metadata,
sqla.Column('guid', sqla.Uuid),
schema="test")
metadata.create_all(engine)
assert sqla.inspect(engine).has_table('users', schema="test")
users = sqla.Table('users', metadata, schema='test', autoload_with=conn)
assert_column(users, "guid", sqla.sql.sqltypes.Uuid)
finally:
metadata.drop_all(engine)


@pytest.mark.skipif(
sqlalchemy_version() < "2.0",
reason="sqlalchemy.Uuid only exists with SQLAlchemy 2.0 and above"
)
@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True)
def test_insert_uuid(trino_connection):
engine, conn = trino_connection

if not engine.dialect.has_schema(conn, "test"):
with engine.begin() as connection:
connection.execute(sqla.schema.CreateSchema("test"))
metadata = sqla.MetaData()
try:
users = sqla.Table('users',
metadata,
sqla.Column('guid', sqla.Uuid),
schema="test")
metadata.create_all(engine)
ins = users.insert()
guid = uuid.uuid4()
conn.execute(ins, {"guid": guid})
query = sqla.select(users)
result = conn.execute(query)
rows = result.fetchall()
assert len(rows) == 1
assert rows[0] == (guid,)
finally:
metadata.drop_all(engine)


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
Expand Down
4 changes: 4 additions & 0 deletions trino/sqlalchemy/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import re
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union

import sqlalchemy
from sqlalchemy import util
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql.type_api import TypeDecorator, TypeEngine
Expand Down Expand Up @@ -129,6 +130,9 @@ def get_col_spec(self, **kw):
# 'tdigest': TDIGEST,
}

if hasattr(sqlalchemy, "Uuid"):
_type_map["uuid"] = sqlalchemy.Uuid


def unquote(string: str, quote: str = '"', escape: str = "\\") -> str:
"""
Expand Down