Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance Cursor.description with type information #315

Merged
merged 1 commit into from
Jan 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 48 additions & 14 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def test_none_query_param(trino_connection):
rows = cur.fetchall()

assert rows[0][0] is None
assert_cursor_description(cur, trino_type="unknown")


def test_string_query_param(trino_connection):
Expand All @@ -128,6 +129,7 @@ def test_string_query_param(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == "six'"
assert_cursor_description(cur, trino_type="varchar(4)", size=4)


def test_execute_many(trino_connection):
Expand Down Expand Up @@ -241,10 +243,11 @@ def test_legacy_primitive_types_with_connection_and_cursor(
def test_decimal_query_param(trino_connection):
cur = trino_connection.cursor()

cur.execute("SELECT ?", params=(Decimal('0.142857'),))
cur.execute("SELECT ?", params=(Decimal('1112.142857'),))
rows = cur.fetchall()

assert rows[0][0] == Decimal('0.142857')
assert rows[0][0] == Decimal('1112.142857')
assert_cursor_description(cur, trino_type="decimal(10, 6)", precision=10, scale=6)


def test_null_decimal(trino_connection):
Expand All @@ -254,6 +257,7 @@ def test_null_decimal(trino_connection):
rows = cur.fetchall()

assert rows[0][0] is None
assert_cursor_description(cur, trino_type="decimal(38, 0)", precision=38, scale=0)


def test_biggest_decimal(trino_connection):
Expand All @@ -264,6 +268,7 @@ def test_biggest_decimal(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == params
assert_cursor_description(cur, trino_type="decimal(38, 0)", precision=38, scale=0)


def test_smallest_decimal(trino_connection):
Expand All @@ -274,6 +279,7 @@ def test_smallest_decimal(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == params
assert_cursor_description(cur, trino_type="decimal(38, 0)", precision=38, scale=0)


def test_highest_precision_decimal(trino_connection):
Expand All @@ -284,6 +290,7 @@ def test_highest_precision_decimal(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == params
assert_cursor_description(cur, trino_type="decimal(38, 38)", precision=38, scale=38)


def test_datetime_query_param(trino_connection):
Expand All @@ -295,7 +302,7 @@ def test_datetime_query_param(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == params
assert cur.description[0][1] == "timestamp(6)"
assert_cursor_description(cur, trino_type="timestamp(6)", precision=6)


def test_datetime_with_utc_time_zone_query_param(trino_connection):
Expand All @@ -307,7 +314,7 @@ def test_datetime_with_utc_time_zone_query_param(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == params
assert cur.description[0][1] == "timestamp(6) with time zone"
assert_cursor_description(cur, trino_type="timestamp(6) with time zone", precision=6)


def test_datetime_with_numeric_offset_time_zone_query_param(trino_connection):
Expand All @@ -321,7 +328,7 @@ def test_datetime_with_numeric_offset_time_zone_query_param(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == params
assert cur.description[0][1] == "timestamp(6) with time zone"
assert_cursor_description(cur, trino_type="timestamp(6) with time zone", precision=6)


def test_datetime_with_named_time_zone_query_param(trino_connection):
Expand All @@ -333,7 +340,7 @@ def test_datetime_with_named_time_zone_query_param(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == params
assert cur.description[0][1] == "timestamp(6) with time zone"
assert_cursor_description(cur, trino_type="timestamp(6) with time zone", precision=6)


def test_datetime_with_trailing_zeros(trino_connection):
Expand All @@ -343,6 +350,7 @@ def test_datetime_with_trailing_zeros(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == datetime.strptime("2001-08-22 03:04:05.321000", "%Y-%m-%d %H:%M:%S.%f")
assert_cursor_description(cur, trino_type="timestamp(6)", precision=6)


def test_null_datetime_with_time_zone(trino_connection):
Expand All @@ -352,6 +360,7 @@ def test_null_datetime_with_time_zone(trino_connection):
rows = cur.fetchall()

assert rows[0][0] is None
assert_cursor_description(cur, trino_type="timestamp(3) with time zone", precision=3)


def test_datetime_with_time_zone_numeric_offset(trino_connection):
Expand All @@ -361,6 +370,7 @@ def test_datetime_with_time_zone_numeric_offset(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == datetime.strptime("2001-08-22 03:04:05.321 -08:00", "%Y-%m-%d %H:%M:%S.%f %z")
assert_cursor_description(cur, trino_type="timestamp(3) with time zone", precision=3)


def test_datetimes_with_time_zone_in_dst_gap_query_param(trino_connection):
Expand Down Expand Up @@ -404,6 +414,7 @@ def test_date_query_param(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == params
assert_cursor_description(cur, trino_type="date")


def test_null_date(trino_connection):
Expand All @@ -413,6 +424,7 @@ def test_null_date(trino_connection):
rows = cur.fetchall()

assert rows[0][0] is None
assert_cursor_description(cur, trino_type="date")


def test_unsupported_python_dates(trino_connection):
Expand Down Expand Up @@ -462,6 +474,16 @@ def test_supported_special_dates_query_param(trino_connection):
assert rows[0][0] == params


def test_char(trino_connection):
cur = trino_connection.cursor()

cur.execute("SELECT CHAR 'trino'")
rows = cur.fetchall()

assert rows[0][0] == 'trino'
assert_cursor_description(cur, trino_type="char(5)", size=5)


def test_time_query_param(trino_connection):
cur = trino_connection.cursor()

Expand All @@ -471,7 +493,7 @@ def test_time_query_param(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == params
assert cur.description[0][1] == "time(6)"
assert_cursor_description(cur, trino_type="time(6)", precision=6)


def test_time_with_named_time_zone_query_param(trino_connection):
Expand Down Expand Up @@ -501,7 +523,7 @@ def test_time(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == time(1, 2, 3, 456000)
assert cur.description[0][1] == "time(3)"
assert_cursor_description(cur, trino_type="time(3)", precision=3)


def test_null_time(trino_connection):
Expand All @@ -511,6 +533,7 @@ def test_null_time(trino_connection):
rows = cur.fetchall()

assert rows[0][0] is None
assert_cursor_description(cur, trino_type="time(3)", precision=3)


def test_time_with_time_zone_negative_offset(trino_connection):
Expand All @@ -522,7 +545,7 @@ def test_time_with_time_zone_negative_offset(trino_connection):
tz = timezone(-timedelta(hours=8, minutes=0))

assert rows[0][0] == time(1, 2, 3, 456000, tzinfo=tz)
assert cur.description[0][1] == "time(3) with time zone"
assert_cursor_description(cur, trino_type="time(3) with time zone", precision=3)


def test_time_with_time_zone_positive_offset(trino_connection):
Expand All @@ -534,7 +557,7 @@ def test_time_with_time_zone_positive_offset(trino_connection):
tz = timezone(timedelta(hours=8, minutes=0))

assert rows[0][0] == time(1, 2, 3, 456000, tzinfo=tz)
assert cur.description[0][1] == "time(3) with time zone"
assert_cursor_description(cur, trino_type="time(3) with time zone", precision=3)


def test_null_date_with_time_zone(trino_connection):
Expand All @@ -544,6 +567,7 @@ def test_null_date_with_time_zone(trino_connection):
rows = cur.fetchall()

assert rows[0][0] is None
assert_cursor_description(cur, trino_type="time(3) with time zone", precision=3)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -717,7 +741,7 @@ def test_float_query_param(trino_connection):
cur.execute("SELECT ?", params=(1.1,))
rows = cur.fetchall()

assert cur.description[0][1] == "double"
assert_cursor_description(cur, trino_type="double")
assert rows[0][0] == 1.1


Expand All @@ -726,7 +750,7 @@ def test_float_nan_query_param(trino_connection):
cur.execute("SELECT ?", params=(float("nan"),))
rows = cur.fetchall()

assert cur.description[0][1] == "double"
assert_cursor_description(cur, trino_type="double")
assert isinstance(rows[0][0], float)
assert math.isnan(rows[0][0])

Expand All @@ -736,6 +760,7 @@ def test_float_inf_query_param(trino_connection):
cur.execute("SELECT ?", params=(float("inf"),))
rows = cur.fetchall()

assert_cursor_description(cur, trino_type="double")
assert rows[0][0] == float("inf")

cur.execute("SELECT ?", params=(float("-inf"),))
Expand All @@ -750,13 +775,13 @@ def test_int_query_param(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == 3
assert cur.description[0][1] == "integer"
assert_cursor_description(cur, trino_type="integer")

cur.execute("SELECT ?", params=(9223372036854775807,))
rows = cur.fetchall()

assert rows[0][0] == 9223372036854775807
assert cur.description[0][1] == "bigint"
assert_cursor_description(cur, trino_type="bigint")


@pytest.mark.parametrize('params', [
Expand Down Expand Up @@ -1234,3 +1259,12 @@ def test_describe_table_query(run_trino):
aliased=False,
)
]


def assert_cursor_description(cur, trino_type, size=None, precision=None, scale=None):
assert cur.description[0][1] == trino_type
assert cur.description[0][2] is None
mdesmet marked this conversation as resolved.
Show resolved Hide resolved
assert cur.description[0][3] is size
assert cur.description[0][4] is precision
assert cur.description[0][5] is scale
assert cur.description[0][6] is None
4 changes: 4 additions & 0 deletions trino/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,7 @@
HEADER_SET_CATALOG = "X-Trino-Set-Catalog"

HEADER_CLIENT_CAPABILITIES = "X-Trino-Client-Capabilities"

LENGTH_TYPES = ["char", "varchar"]
mdesmet marked this conversation as resolved.
Show resolved Hide resolved
PRECISION_TYPES = ["time", "time with time zone", "timestamp", "timestamp with time zone", "decimal"]
SCALE_TYPES = ["decimal"]
31 changes: 28 additions & 3 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import trino.exceptions
import trino.logging
from trino import constants
from trino.constants import LENGTH_TYPES, PRECISION_TYPES, SCALE_TYPES
from trino.exceptions import (
DatabaseError,
DataError,
Expand Down Expand Up @@ -237,6 +238,31 @@ def from_row(cls, row: List[Any]):
return cls(*row)


class ColumnDescription(NamedTuple):
name: str
type_code: int
display_size: int
internal_size: int
precision: int
scale: int
null_ok: bool

@classmethod
def from_column(cls, column: Dict[str, Any]):
type_signature = column["typeSignature"]
raw_type = type_signature["rawType"]
arguments = type_signature["arguments"]
return cls(
column["name"], # name
column["type"], # type_code
None, # display_size
arguments[0]["value"] if raw_type in LENGTH_TYPES else None, # internal_size
arguments[0]["value"] if raw_type in PRECISION_TYPES else None, # precision
arguments[1]["value"] if raw_type in SCALE_TYPES else None, # scale
None # null_ok
)


class Cursor(object):
"""Database cursor.

Expand Down Expand Up @@ -278,14 +304,13 @@ def update_type(self):
return None

@property
def description(self):
def description(self) -> List[ColumnDescription]:
if self._query.columns is None:
return None

# [ (name, type_code, display_size, internal_size, precision, scale, null_ok) ]
return [
(col["name"], col["type"], None, None, None, None, None)
for col in self._query.columns
ColumnDescription.from_column(col) for col in self._query.columns
]

@property
Expand Down