Skip to content

Commit

Permalink
Refactor value mappers to separate classes
Browse files Browse the repository at this point in the history
  • Loading branch information
mdesmet committed Nov 4, 2022
1 parent b008848 commit 69af6a3
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 101 deletions.
2 changes: 2 additions & 0 deletions tests/integration/test_types_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ def test_array(trino_connection):
SqlTest(trino_connection) \
.add_field(sql="CAST(null AS ARRAY(VARCHAR))", python=None) \
.add_field(sql="ARRAY['a', 'b', null]", python=['a', 'b', None]) \
.add_field(sql="ARRAY[1.2, 2.4, null]", python=[Decimal("1.2"), Decimal("2.4"), None]) \
.add_field(sql="ARRAY[CAST(4.9E-324 AS DOUBLE), null]", python=[5e-324, None]) \
.execute()


Expand Down
271 changes: 170 additions & 101 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,18 @@
>> query = TrinoQuery(request, sql)
>> rows = list(query.execute())
"""

import abc
import copy
import functools
import os
import random
import re
import threading
import time
import urllib.parse
from datetime import datetime, timedelta, timezone
from datetime import date, datetime, time, timedelta, timezone
from decimal import Decimal
from typing import Any, Dict, List, Optional, Tuple, Union
from time import sleep
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union

import pytz
import requests
Expand Down Expand Up @@ -295,7 +295,7 @@ def __init__(

def retry(self, func, args, kwargs, err, attempt):
delay = self._get_delay(attempt)
time.sleep(delay)
sleep(delay)


class TrinoRequest(object):
Expand Down Expand Up @@ -850,123 +850,195 @@ def map(self, rows):
return rows


class RowMapperFactory:
"""
Given the 'columns' result from Trino, generate a list of
lambda functions (one for each column) which will process a data value
and returns a RowMapper instance which will process rows of data
"""
no_op_row_mapper = NoOpRowMapper()
T = TypeVar("T")

def create(self, columns, experimental_python_types):
assert columns is not None

if experimental_python_types:
return RowMapper([self._col_func(column['typeSignature']) for column in columns])
return RowMapperFactory.no_op_row_mapper
class ValueMapper(abc.ABC, Generic[T]):
@abc.abstractmethod
def map(self, value: Any) -> Optional[T]:
pass

def _col_func(self, column):
col_type = column['rawType']

if col_type == 'array':
return self._array_map_func(column)
elif col_type == 'row':
return self._row_map_func(column)
elif col_type == 'map':
return self._map_map_func(column)
elif col_type.startswith('decimal'):
return lambda val: Decimal(val)
elif col_type.startswith('double') or col_type.startswith('real'):
return self._double_map_func()
elif col_type.startswith('timestamp'):
return self._timestamp_map_func(column, col_type)
elif col_type.startswith('time'):
return self._time_map_func(column, col_type)
elif col_type == 'date':
return lambda val: datetime.strptime(val, '%Y-%m-%d').date()
else:
return lambda val: val
class NoOpValueMapper(ValueMapper[Any]):
def map(self, value) -> Optional[Any]:
return value

def _array_map_func(self, column):
element_mapping_func = self._col_func(column['arguments'][0]['value'])
return lambda values: [element_mapping_func(value) for value in values]

def _row_map_func(self, column):
element_mapping_func = [self._col_func(arg['value']['typeSignature']) for arg in column['arguments']]
return lambda values: tuple(element_mapping_func[idx](value) for idx, value in enumerate(values))
class DecimalValueMapper(ValueMapper[Decimal]):
def map(self, value) -> Optional[Decimal]:
if value is None:
return None
return Decimal(value)

def _map_map_func(self, column):
key_mapping_func = self._col_func(column['arguments'][0]['value'])
value_mapping_func = self._col_func(column['arguments'][1]['value'])
return lambda values: {key_mapping_func(key): value_mapping_func(value) for key, value in values.items()}

def _double_map_func(self):
return lambda val: INF if val == 'Infinity' \
else NEGATIVE_INF if val == '-Infinity' \
else NAN if val == 'NaN' \
else float(val)
class DoubleValueMapper(ValueMapper[float]):
def map(self, value) -> Optional[float]:
if value is None:
return None
if value == 'Infinity':
return INF
if value == '-Infinity':
return NEGATIVE_INF
if value == 'NaN':
return NAN
return float(value)

def _timestamp_map_func(self, column, col_type):
datetime_default_size = 20 # size of 'YYYY-MM-DD HH:MM:SS.' (the datetime string up to the milliseconds)
pattern = "%Y-%m-%d %H:%M:%S"
ms_size, ms_to_trim = self._get_number_of_digits(column)
if ms_size > 0:
pattern += ".%f"

dt_size = datetime_default_size + ms_size - ms_to_trim
dt_tz_offset = datetime_default_size + ms_size
if 'with time zone' in col_type:

if ms_to_trim > 0:
return lambda val: \
[datetime.strptime(val[:dt_size] + val[dt_tz_offset:], pattern + ' %z')
if tz.startswith('+') or tz.startswith('-')
else datetime.strptime(dt[:dt_size] + dt[dt_tz_offset:], pattern)
.replace(tzinfo=pytz.timezone(tz))
for dt, tz in [val.rsplit(' ', 1)]][0]
else:
return lambda val: [datetime.strptime(val, pattern + ' %z')
if tz.startswith('+') or tz.startswith('-')
else datetime.strptime(dt, pattern).replace(tzinfo=pytz.timezone(tz))
for dt, tz in [val.rsplit(' ', 1)]][0]
class TemporalValueMapper():
def _get_number_of_digits(self, column):
args = column['arguments']
if len(args) == 0:
return 3, 0
ms_size = column['arguments'][0]['value']
if ms_size == 0:
return -1, 0
ms_to_trim = ms_size - min(ms_size, 6)
return ms_size, ms_to_trim

if ms_to_trim > 0:
return lambda val: datetime.strptime(val[:dt_size] + val[dt_tz_offset:], pattern)
else:
return lambda val: datetime.strptime(val, pattern)

def _time_map_func(self, column, col_type):
class TimeValueMapper(ValueMapper[time], TemporalValueMapper):
def __init__(self, column):
pattern = "%H:%M:%S"
ms_size, ms_to_trim = self._get_number_of_digits(column)
if ms_size > 0:
pattern += ".%f"
self.pattern = pattern
self.time_size = 9 + ms_size - ms_to_trim

def map(self, value) -> Optional[time]:
if value is None:
return None
return datetime.strptime(value[:self.time_size], self.pattern).time()

time_size = 9 + ms_size - ms_to_trim

if 'with time zone' in col_type:
return lambda val: self._get_time_with_timezome(val, time_size, pattern)
else:
return lambda val: datetime.strptime(val[:time_size], pattern).time()
class TimeWithTimeZoneValueMapper(TimeValueMapper):
PATTERN = r'^(.*)([\+\-])(\d{2}):(\d{2})$'

def _get_time_with_timezome(self, value, time_size, pattern):
matches = re.match(r'^(.*)([\+\-])(\d{2}):(\d{2})$', value)
def map(self, value) -> Optional[time]:
if value is None:
return None
matches = re.match(TimeWithTimeZoneValueMapper.PATTERN, value)
assert matches is not None
assert len(matches.groups()) == 4
if matches.group(2) == '-':
tz = -timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4)))
else:
tz = timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4)))
return datetime.strptime(matches.group(1)[:time_size], pattern).time().replace(tzinfo=timezone(tz))
return datetime.strptime(matches.group(1)[:self.time_size], self.pattern).time().replace(tzinfo=timezone(tz))

def _get_number_of_digits(self, column):
args = column['arguments']
if len(args) == 0:
return 3, 0
ms_size = column['arguments'][0]['value']
if ms_size == 0:
return -1, 0
ms_to_trim = ms_size - min(ms_size, 6)
return ms_size, ms_to_trim

class DateValueMapper(ValueMapper[date]):
def map(self, value) -> Optional[date]:
if value is None:
return None
return datetime.strptime(value, '%Y-%m-%d').date()


class TimestampValueMapper(ValueMapper[datetime], TemporalValueMapper):
def __init__(self, column):
datetime_default_size = 20 # size of 'YYYY-MM-DD HH:MM:SS.' (the datetime string up to the milliseconds)
pattern = "%Y-%m-%d %H:%M:%S"
ms_size, ms_to_trim = self._get_number_of_digits(column)
if ms_size > 0:
pattern += ".%f"
self.pattern = pattern
self.dt_size = datetime_default_size + ms_size - ms_to_trim
self.dt_tz_offset = datetime_default_size + ms_size

def map(self, value) -> Optional[datetime]:
if value is None:
return None
return datetime.strptime(value[:self.dt_size] + value[self.dt_tz_offset:], self.pattern)


class TimestampWithTimeZoneValueMapper(TimestampValueMapper):
def map(self, value) -> Optional[datetime]:
if value is None:
return None
dt, tz = value.rsplit(' ', 1)
if tz.startswith('+') or tz.startswith('-'):
return datetime.strptime(value[:self.dt_size] + value[self.dt_tz_offset:], self.pattern + ' %z')
date_str = dt[:self.dt_size] + dt[self.dt_tz_offset:]
return datetime.strptime(date_str, self.pattern).replace(tzinfo=pytz.timezone(tz))


class ArrayValueMapper(ValueMapper[List[Optional[Any]]]):
def __init__(self, mapper: ValueMapper[Any]):
self.mapper = mapper

def map(self, values: List[Any]) -> Optional[List[Any]]:
if values is None:
return None
return [self.mapper.map(value) for value in values]


class RowValueMapper(ValueMapper[Tuple[Optional[Any], ...]]):
def __init__(self, mappers: List[ValueMapper[Any]]):
self.mappers = mappers

def map(self, values: List[Any]) -> Optional[Tuple[Optional[Any], ...]]:
if values is None:
return None
return tuple(self.mappers[idx].map(value) for idx, value in enumerate(values))


class MapValueMapper(ValueMapper[Dict[Any, Optional[Any]]]):
def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any]):
self.key_mapper = key_mapper
self.value_mapper = value_mapper

def map(self, values: Any) -> Optional[Dict[Any, Optional[Any]]]:
if values is None:
return None
return {
self.key_mapper.map(key): self.value_mapper.map(value) for key, value in values.items()
}


class RowMapperFactory:
"""
Given the 'columns' result from Trino, generate a list of
lambda functions (one for each column) which will process a data value
and returns a RowMapper instance which will process rows of data
"""
NO_OP_ROW_MAPPER = NoOpRowMapper()

def create(self, columns, experimental_python_types):
assert columns is not None

if experimental_python_types:
return RowMapper([self._create_value_mapper(column['typeSignature']) for column in columns])
return RowMapperFactory.NO_OP_ROW_MAPPER

def _create_value_mapper(self, column) -> ValueMapper:
col_type = column['rawType']

if col_type == 'array':
value_mapper = self._create_value_mapper(column['arguments'][0]['value'])
return ArrayValueMapper(value_mapper)
elif col_type == 'row':
mappers = [self._create_value_mapper(arg['value']['typeSignature']) for arg in column['arguments']]
return RowValueMapper(mappers)
elif col_type == 'map':
key_mapper = self._create_value_mapper(column['arguments'][0]['value'])
value_mapper = self._create_value_mapper(column['arguments'][1]['value'])
return MapValueMapper(key_mapper, value_mapper)
elif col_type.startswith('decimal'):
return DecimalValueMapper()
elif col_type.startswith('double') or col_type.startswith('real'):
return DoubleValueMapper()
elif col_type.startswith('timestamp') and 'with time zone' in col_type:
return TimestampWithTimeZoneValueMapper(column)
elif col_type.startswith('timestamp'):
return TimestampValueMapper(column)
elif col_type.startswith('time') and 'with time zone' in col_type:
return TimeWithTimeZoneValueMapper(column)
elif col_type.startswith('time'):
return TimeValueMapper(column)
elif col_type == 'date':
return DateValueMapper()
else:
return NoOpValueMapper()


class RowMapper:
Expand All @@ -984,12 +1056,9 @@ def map(self, rows):
def _map_row(self, row):
return [self._map_value(value, self.columns[idx]) for idx, value in enumerate(row)]

def _map_value(self, value, col_mapping_func):
if value is None:
return None

def _map_value(self, value, value_mapper: ValueMapper[T]) -> Optional[T]:
try:
return col_mapping_func(value)
return value_mapper.map(value)
except ValueError as e:
error_str = f"Could not convert '{value}' into the associated python type"
raise trino.exceptions.TrinoDataError(error_str) from e

0 comments on commit 69af6a3

Please sign in to comment.